FasterTransformer Decoding 源码分析(一)-整体框架介绍

FasterTransformer 是 NVIDIA 推出的一个用于加速 Transformer 模型推理的库。该库主要通过使用 NVIDIA 的深度学习加速库 cuBLAS、cuDNN 和 TensorRT,以及深度学习框架 TensorFlow 和 PyTorch 的扩展,对 Transformer 模型进行优化和加速。本系列文章试图对FasterTransformer中的Decoding Model进行详细的分析,主要探究其代码模块设计、性能加速优化方案和CUDA Kernel实现技巧,通过学习源码掌握其实现精髓。Decoding Model是经典Transformer中的第二部分,也是推理耗时最高的部分,对这个模块的大量优化值得深入学习借鉴。

image.png

一、整体框架

Decoding整体框架如下图所示,最右侧是处理流程,即embedding->decoder->laynorm->logit GEMM->beam search,本质上和经典Transformer的Decoder结构基本一致。其中decoder部分是核心处理流程,也拆解成各个子模块处理,在FasterTransformer Decoding 源码分析(二)-Decoder框架介绍中进行了详解。

image.png
Decoding框架

图中最右模块序列具体拆解到Layer和Kernel粒度的流程如下图所示,蓝色为kernel,红色是由复杂op封装成Layer。针对实现的流程来看主要做了这几点优化:

  1. 将参数initialize和paddingEmbedding操作(即图中前两个kernel)改写成了CUDA kernel进行了并行化加速。
  2. 将embedding查表和Position Encoding流程合并成了一个CUDA kernel(即图中第三个kernel)来实现。
  3. 针对核心也是最复杂的decoder和beam_search等操作封装成独立的Layer(即图中两个红色Layer)进行单独优化。

接下来围绕这张具体流程图分别介绍每个模块的处理逻辑,因为内容比较多本文不会很深入地介绍细节,会分别单独撰文介绍。

image.png
具体执行流程

二、数据处理流程

输入输出

template<typename T>void Decoding<T>::forward(std::vector<Tensor>*       output_tensors,
                          const std::vector<Tensor>* input_tensors,
                          const DecodingWeight<T>*   decoding_weights){
    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    // input_tensors:    //      encoder_output [batch_size * beam, mem_max_seq_len, memory_hidden_dimension]    //      encoder_sequence_length [batch_size * beam]
    // output_tensors:    //      output_ids [max_seq_len, batch_size, beam]    //      parent_ids [max_seq_len, batch_size, beam]    //      sequence_length [batch_size, beam], record the number of generated token, except the start token
    // Step is from 1 ~ max_seq_len,    // When step = k,  we put output ids and caches at step k, and the sequence_length would be k - 1 before    // complete this step.

先看下Decoding框架的整体输入和输出,再逐步剖析。

输入tensor:

  1. encoder模块输出的feature,作为decoding的输入这个没什么好讲的,这里按照beam_size的大小进行了复制,主要是为了在cross-attention阶段和beam_size化的输出进行计算,所以大小是[batch_size * beam, mem_max_seq_len, memory_hidden_dimension]。
  2. 输入句子的长度,也按照beam_size的大小进行了复制,大小是[batch_size * beam]。

输出tensor:

  1. 结果句子的id列表,这是最重要的输出,当然也是beam_size化的。[max_seq_len, batch_size, beam]
  2. 结果句子的父id列表,为了在beam_search中查找最佳路径,这个目前已经废弃了。[max_seq_len, batch_size, beam]
  3. 所有句子的最终长度。[batch_size, beam]

权重参数decoding_weights,这个是一些解码过程中要用的权重参数,可以先了解个大概:

  • embedding 词表
  • position_encoding表
  • post_decoder_embedding
  • post_decoder_layernorm
  • decoder_layer_weights

decodingInitialize

    invokeDecodingInitialize(finished_buf_,
                             output_tensors->at("sequence_length").getPtr<int>(),
                             output_ids_buf_,
                             cum_log_probs_,
                             start_ids_buf_,
                             batch_size,
                             beam_width_,
                             max_input_length,
                             stream_);
    sync_check_cuda_error();--------------------------------------------------------------------------------------------------------------------template<typename T>__global__ void decodingInitialize(bool*      finished,
                                   int*       sequence_length,
                                   int*       word_ids,
                                   T*         cum_log_probs,
                                   const int* sentence_ids,
                                   const int  batch_size,
                                   const int  beam_width,
                                   const int  max_input_length){
    const bool IS_FP16   = std::is_same<T, half>::value;
    const T    MAX_T_VAL = (IS_FP16) ? (T)HALF_FLT_MAX : (T)1e20f;  // BF16 and FP32 have the same dynamic range    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
         index += blockDim.x * gridDim.x) {
        finished[index]        = false;
        sequence_length[index] = max_input_length;
        if (word_ids != nullptr) {
            word_ids[index] = sentence_ids[index / beam_width];
        }
        cum_log_probs[index] = (index % beam_width == 0) ? (T)0.0f : (T)-MAX_T_VAL;
    }}

这个函数非常简单,主要是用cuda kernel并行地对finished、sequence_length、cum_log_probs这些后续会使用到的变量进行初始化,形如其名initialize。

paddingEmbedding

        invokePaddingEmbedding(padded_embedding_kernel_,
                               padded_embedding_bias_,
                               decoding_weights->post_decoder_embedding.kernel,
                               decoding_weights->post_decoder_embedding.bias,
                               hidden_units_,
                               vocab_size_,
                               vocab_size_padded_,
                               stream_);
        sync_check_cuda_error();---------------------------------------------------------------------------------------------------------------------template<typename T>__global__ void paddingEmbedding(T*            padded_embedding_kernel,
                                 T*            padded_embedding_bias,
                                 const T*      embedding_kernel,
                                 const T*      embedding_bias,
                                 const int64_t hidden_unit,
                                 const int64_t vocab_size,
                                 const int64_t vocab_size_padded){
    for (int64_t id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;
         id += blockDim.x * gridDim.x) {
        int row_id = id / vocab_size_padded;
        int col_id = id % vocab_size_padded;
        if (col_id < vocab_size) {
            padded_embedding_kernel[id] = embedding_kernel[row_id * vocab_size + col_id];
        }
        else {
            padded_embedding_kernel[id] = (T)(0.0f);
        }
    }

    for (int id = threadIdx.x + blockIdx.x * blockDim.x; id < vocab_size_padded; id += blockDim.x * gridDim.x) {
        if (id < vocab_size) {
            padded_embedding_bias[id] = embedding_bias[id];
        }
        else {
            padded_embedding_bias[id] = (T)(0.0f);
        }
    }}

这个函数也极其简单,用cuda kernel并行地对padded_embedding_kernel和padded_embedding_bias多余的部分padding为0。

embeddingLookupPosEncoding

        invokeEmbeddingLookupPosEncodingPadCount(decoder_input_buf_,
                                                 decoding_weights->pre_decoder_embedding_table,
                                                 decoding_weights->position_encoding_table,
                                                 output_ids_buf_,
                                                 nullptr,
                                                 batch_size * beam_width_,
                                                 hidden_units_,
                                                 (T)sqrtf(float(hidden_units_)),
                                                 step - 1,
                                                 batch_size * beam_width_,
                                                 0,
                                                 stream_);
        sync_check_cuda_error();------------------------------------------------------------------------------------------------------template<typename T>__global__ void embeddingLookupPosEncoding(T*             from_tensor,
                                           const T*       embedding_table,
                                           const T*       position_encoding,
                                           const int*     all_ids,
                                           const int*     padding_count,
                                           const int*     input_lengths,
                                           const int      local_token_num,
                                           const int64_t  hidden_units,
                                           const int      step,
                                           const int      max_input_length,
                                           const int      token_num,
                                           const int      ite,
                                           const T        scale){
    // 1. lookup from embedding table    // 2. multiply scale    // 3. add the position encoding    const int id_offset = step * token_num + ite * local_token_num;

    const bool use_padding_count = padding_count != nullptr;
    const bool use_input_len     = input_lengths != nullptr;

    for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
         index += blockDim.x * gridDim.x) {
        const int row_index   = index / hidden_units;
        const int col_index   = index % hidden_units;
        int       step_offset = step;
        if (use_padding_count) {
            step_offset -= padding_count[row_index];
        }
        else if (use_input_len) {
            step_offset -= max_input_length - input_lengths[row_index];
        }
        step_offset *= hidden_units;

        T val = embedding_table[all_ids[id_offset + row_index] * hidden_units + col_index] * scale;
        val   = val + position_encoding[step_offset + col_index];

        from_tensor[index] = val;
    }}

从这个函数开始进入逐个step的解码流程,这里将embedding表查找(词向量)和PosEncoding的流程做了一个融合,在一个CUDA kernel实现了,本质上就是2行查表操作,数据表来自输入的参数。对PosEncoding不了解的可以去看这篇文章猛猿:Transformer学习笔记一:Positional Encoding(位置编码)。

decoderLayer

        std::vector<Tensor> decoder_input_tensors{
            Tensor{MEMORY_GPU, data_type, {batch_size * beam_width_, hidden_units_}, decoder_input_buf_},
            input_tensors->at("encoder_output"),
            input_tensors->at("encoder_sequence_length"),
            Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width_}, finished_buf_},
            Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step},
            output_tensors->at("sequence_length"),
            Tensor{MEMORY_GPU,
                   TYPE_INT32,
                   {(size_t)local_batch_size, beam_width_, max_seq_len_},
                   beam_width_ > 1 ? cache_indirections_[src_indir_idx] + id_offset * max_seq_len_ : nullptr}};

        std::vector<Tensor> decoder_output_tensors{
            Tensor{MEMORY_GPU, data_type, {batch_size * beam_width_, hidden_units_}, decoder_output_buf_},
            Tensor{MEMORY_GPU, data_type, self_k_cache_size, key_cache_},
            Tensor{MEMORY_GPU, data_type, self_v_cache_size, value_cache_},
            Tensor{MEMORY_GPU,
                   data_type,
                   {num_layer_, batch_size * beam_width_, mem_max_seq_len, hidden_units_},
                   key_mem_cache_},
            Tensor{MEMORY_GPU,
                   data_type,
                   {num_layer_, batch_size * beam_width_, mem_max_seq_len, hidden_units_},
                   value_mem_cache_}};
        decoder_->forward(&decoder_output_tensors, &decoder_input_tensors, &decoding_weights->decoder_layer_weights);

这里进入到decoding的核心流程,封装在decoder中,decoder的流程逻辑较为复杂,单独在FasterTransformer Decoding 源码分析(二)-Decoder框架介绍。 这里再对输入输出进行一轮讲解。

输入Tensor:

  1. batch_size * beam_size个单词的embedding表示或上一个step的解码输出。
  2. encoder层的输出。
  3. encoder层输入序列的实际长度。
  4. batch_size * beam_size中是否已经解码完成。
  5. 当前解码的步长。
  6. 已解码句子的序列长度。
  7. 中间缓存。(这个暂时还未看到)

输出Tensor:

  1. batch_size * beam_size个解码器的词向量输出。
  2. self-attention中前面steps所计算出来的key buffer。
  3. self-attention中前面steps所计算出来的value buffer。
  4. cross-attention中前面steps所计算出来的key buffer。
  5. cross-attention中前面steps所计算出来的value buffer。

layerNorm

        invokeGeneralLayerNorm(normed_decoder_output_buf_,
                               decoder_output_buf_,
                               decoding_weights->post_decoder_layernorm.gamma,
                               decoding_weights->post_decoder_layernorm.beta,
                               layernorm_eps_,
                               batch_size * beam_width_,
                               hidden_units_,
                               (float*)nullptr,
                               0,
                               stream_);
        sync_check_cuda_error();

layerNorm比较简单,是一个单独的kernel实现,在进击的Killua:FasterTransformer Decoding 源码分析(三)-LayerNorm介绍进行了详细介绍。

cuBLAS Gemm

            cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                  CUBLAS_OP_N,
                                  vocab_size_padded_,  // n                                  batch_size * beam_width_,
                                  hidden_units_,  // k                                  padded_embedding_kernel_ptr_,
                                  vocab_size_padded_,  // n                                  normed_decoder_output_buf_,
                                  hidden_units_,  // k                                  logits_buf_,
                                  vocab_size_padded_ /* n */);
        }

使用对cuBLAS封装的wrapper对上步norm生成的tensor和后处理embeding进行矩阵乘运算,即linear层,后面经过softmax即可得到logits向量(索引为word.id,值为该word.id的概率)。

dynamicDecodeLayer

        TensorMap dynamic_decode_input_tensors(
            {{"logits", Tensor{MEMORY_GPU, data_type, {batch_size, beam_width_, vocab_size_padded_}, logits_buf_}},
             {"embedding_bias",
              Tensor{MEMORY_GPU, data_type, {vocab_size_padded_}, is_bf16 ? nullptr : padded_embedding_bias_ptr_}},
             {"end_id", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids_buf_}},
             {"step", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &step}},
             {"max_input_length", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &max_input_length}},
             // {"input_lengths", Tensor{MEMORY_GPU, TYPE_INT32, {batch_size, beam_width_}, nullptr}},
             {"ite", Tensor{MEMORY_CPU, TYPE_UINT32, {1}, &tmp_ite}},
             {"src_cache_indirection",
              Tensor{
                  MEMORY_GPU, TYPE_INT32, {batch_size, beam_width_, max_seq_len_}, cache_indirections_[src_indir_idx]}},
             {"local_batch_size", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &tmp_local_batch_size}},
             {"beam_search_diversity_rate", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &beam_search_diversity_rate_}},
             {"temperature", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &temperature_}},
             {"len_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &len_penalty_}},
             {"repetition_penalty", Tensor{MEMORY_CPU, TYPE_FP32, {1}, &repetition_penalty_}}});

        // TODO(bhsueh) Need to modify the forward function to use unordered_map
        // for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) {
        //     dynamic_decode_input_tensors.insert(*t);
        // }

        // common outputs
        TensorMap dynamic_decode_output_tensors(
            {{"output_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len_, batch_size, beam_width_}, output_ids_buf_}},
             {"finished", Tensor{MEMORY_GPU, TYPE_BOOL, {batch_size * beam_width_}, finished_buf_}},
             {"cum_log_probs", Tensor{MEMORY_GPU, TYPE_FP32, {batch_size * beam_width_}, cum_log_probs_}},
             {"parent_ids", Tensor{MEMORY_GPU, TYPE_INT32, {max_seq_len_, batch_size, beam_width_}, parent_ids_buf_}},
             {"sequence_length", output_tensors->at("sequence_length")},
             {"tgt_cache_indirection",
              Tensor{MEMORY_GPU,
                     TYPE_INT32,
                     {batch_size, beam_width_, max_seq_len_},
                     cache_indirections_[tgt_indir_idx]}}});

        // TODO(bhsueh) Need to modify the forward function to use unordered_map
        // for (auto t = output_tensors->begin(); t != output_tensors->end(); ++t) {
        //     dynamic_decode_output_tensors.insert(*t);
        // }

        dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);

这个Layer内部实现了BeamSearch和TopKSampling的逻辑,详见 进击的Killua:FasterTransformer Decoding 源码分析(九)-DynamicDecodeLayer。到这个函数就完成了最多max_seq_len_个step的解码流程。

gatherTree

        invokeGatherTree(output_tensors->at("output_ids").getPtr<int>(),
                         output_tensors->at("sequence_length").getPtr<int>(),
                         max_seq_len_ - 1,
                         batch_size,
                         beam_width_,
                         output_ids_buf_ + batch_size * beam_width_,
                         parent_ids_buf_ + batch_size * beam_width_,
                         end_ids_buf_,
                         stream_);

在 Beam Search 的解码过程中,会生成多个假设,每个假设都是一个可能的序列。gatherTree 函数用于追踪生成最优序列的路径,使用了CUDA kernel进行加速,提高了Beam Search 过程的效率。

三、总结

本文试图对FasterTransformer中的Decoding框架源码进行了初步分析,梳理出整体框架图,同时也拆解出了核心的处理模块,针对加速优化主要采用了将串行流程改写并行kernel、紧密OP融合、利用cuBLAS库和核心模块定制优化等方法,接下来还会深入到各个模块细节去分析开发优化手段。

四、参考

https://github.com/NVIDIA/FasterTransformer/blob/main/docs/decoder_guide.md

  • The End -
作者:进击的Killua
文章来源:GiantPandaCV

推荐阅读

更多嵌入式AI干货请关注嵌入式AI专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
17800
内容数
1242
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息