AI学习者 · 6月18日 · 广东

TransformerEngine代码走读

在几个月前nv就发表过关于fp8数制训练和推理的白皮书,大概三四个月前公开了他们针对transformer模型的fp8训练的软件库TransformerEngine,由于最近在关注fp8训练,因此想了解一下他们的实现方案,但是由于没有H100的卡,目前还不能跑通te的代码,很多细节没办法验证,只能通过源码的阅读来观察它插入了什么功能实现fp8量化训练,特此记录一下。

“Transformer Engine (TE) 是一个用于在 NVIDIA GPU 上加速 Transformer 模型的库,包括在 Hopper GPU 上使用 8 位浮点 (FP8) 精度,以在训练和推理中以较低的内存利用率提供更好的性能。TE 为流行的 Transformer 架构提供了一系列高度优化的构建块,以及可与您自己的特定于框架的代码无缝使用的类似自动混合精度的 API。TE 还包括一个与框架无关的 C++ API,它可以与其他深度学习库集成,以实现对 Transformers 的 FP8 支持。”抄自te官方文档https://github.com/NVIDIA/Tra...https://github.com/NVIDIA/Tra...)的一段话。总体来说,te可以实现fp8数制在训练中需要的scale的维护,以及一些层融合的策略。从这些层面来说,te的代码阅读难度还好,大部分工作是基于pytorch的一些数据结构,涉及底层的代码也不算多。

1. example

首先看一段代码

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

loss = out.sum()
loss.backward()

这段是官方给出的一个使用示例,相比于普通的fp32训练,主要有3个修改或新增的部分,te.Linear,fp8_recipe,te.fp8_autocast,其中fp8_recipe是一个记录了fp8量化策略和超参数的对象,另外两个在后面的部分会解释。

2. pipeline

这里我画了一张简单的流程图,用来表示te中的代码整体结构。Linear类继承自TransformerengineBaseModule,其中主要提供了关于scale和amax_history更新的功能,te的默认策略是保留最近1000次迭代的amax(要量化的tensor的最大值)为一个amax_history的队列,在每次量化时,考虑amax_history的最大值用来计算scale,然后采用类似int8的量化函数:x_q=quantize(x * scale)/scale。TransformerengineBaseModule又继承自nn.module,方便用户直接调用。_Linear函数继承自torch.autograd.Function,可以利用torch的自动求导的能力。在_Linear的前向和反向函数中底层又调用了cublas库的gemm函数,和atomicMax两个硬件相关的函数。(图中没有考虑cast和transpose融合,和后面的代码有些区别)

image.png

3. deep dive

这部分我用伪代码的形式,从最外层的te.autocast,分析了一下te.Linear大体完整的forward流程,将处在不同文件的调用到的函数以串行的方式展现,在理解上更方便。主要流程由两个上下文管理器控制,fp8_autocast和prepare_forward。主要的数据结构fp8_meta用来保存当前层的所有量化相关的信息(包括amax_history),_global_fp8_buffer用来保存一次迭代的各个设备上计算出来的amax。

# module的上下文管理器
te.fp8_autocast: 
    # 1. module prologue
    set global var:
    # 2. module yield
    Linear.forward():
        # function的上下文管理器
        prepare_forward: 
            # 2.1 function prologue
            fp8_init():
                # amax_history的shape, need_len=1000, num_gemms=需要量化的tensor个数
                if not initialized: amax_history = [need_len, num_gemms] 
            set_fp8_weights()
            copy_amax_from_global_buffer():
                fp8_meta_tensor_key: “scale_fwd”
                buffer_position_key: “global_fp8_buffer_pos_fwd”
                amax_buffer_key: “FWD_AMAX_{fp8_meta[‘autocast_id_fwd’]}”
                # 获取reduce后的上一次的amax
                fp8_meta[fp8_meta_tensor_key].amax_history[0] = _global_fp8_buffer[amax_buffer_key][fp8_meta[buffer_position_key]]
            amax_and_scale_update():
                # 根据新的history计算scale, 并去除一个旧amax, 留下第一个位置给最新的amax
                fp8_meta[fp8_meta_tensor_key].amax_history, amax =_default_get_amax(amax_history, amax_compute_algo):
                    amax = torch.max(amax_history)
                    # torch.roll会把amax_history中的第一个元素放到最后,其他元素往前滚动一位
                    amax_history = torch.roll(amax_history, -1, 0)
                    # 将滚动后的第一位置零,移除了一个旧的amax,并用来表示当前amax的初始值
                    amax_history[0].fill_(0.0)
                scale = _default_get_sf_compute = (amax, scale, fp8_max, margin):
                    exp = get_exponent(fp8_max / amax) - margin
                    new_scaling_factor = 2.0 ^ exp
            # 赋值给fp8_meta后上一次迭代的amax就无用了, 准备根据上一次迭代的id在_global_fp8_buffer中删除
            set_amax_buffer_key_deletion():
                if "autocast_id_fwd" in fp8_meta:
                    _buffer_delete_key_fwd = amax_buffer_key
            if is_first_fp8_module:
                fp8_meta[“autocast_id_fwd”]=FP8_AUTOCAST_COUNTER
                _FP8_CURRENT_CONTEXT_ID = fp8_meta[“autocast_id_fwd”]
            else:
                # FP8_AUTOCAST_COUNTER+=1 在module prologue, FP8_CURRENT_CONTEXT_ID在后面会改变
                fp8_meta[“autocast_id_fwd”]=FP8_CURRENT_CONTEXT_ID
                # 反向的时候会pop给autocast_id_bwd
                fp8_meta[“autocast_id_fwd_stack”].append(_FP8_CURRENT_CONTEXT_ID)
            add_amax_to_global_buffer():
                # 将所有设备的当前amax的tensor保存到_global_fp8_buffer,这里的amax_buffer_key已改变,是本次迭代的id
                if key not in dict: 
                    _global_fp8_buffer[amax_buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
                else: 
                    _global_fp8_buffer[amax_buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0])
                if key not in dict: 
                    fp8_meta[buffer_position_key] = len(_global_fp8_buffer[amax_buffer_key])-1
            # 2.2 function yield
            _Linear.forward:
                inputmat, inputmat_t = fp8_cast_transpose_fused(inp, ...):
                    # 更新当前amax
                    tex.fused_cast_transpose(input, fp8_meta[fp8_meta_tensor_key].scale[0],  fp8_meta[fp8_meta_tensor_key].amax_history[0][0], ...):
                        fused_cast_transpose():
                            nvte_cast_transpose():
                                cast_transpose():
                                    cast_transpose_kernel():
                                        cast_and_transpose_regs():
                                            # quantize
                                            out = T(scale*in)
                                            max = fmaxf(fabsf(in), max) in vec
                                        # 先统计每个warp的max,因为直接atomicMax太慢了
                                        max = reduce_max(max, warp_id)
                                        # 底层调用atomicMax intrinsic函数
                                        atomicMaxFloat(amax, max)
                # weight量化和上面同样的操作
                weight_fp8, weight_t_fp8 = fp8_cast_transpose_fused(weight, ...)
                out = fp8_gemm(weight_fp8, ... inputmat, ...):
                    torch.ops.tex_ts.te_gemm_ts():
                        te_gemm():
                            nvte_cublas_gemm():
                                cublas_gemm():
                                    # 来自cublas
                                    cublasLtMatmul():
            # 2.3 function epilogue
            # 更新当前迭代id FP8_CURRENT_CONTEXT_ID
            _FP8_CURRENT_CONTEXT_ID = fp8_meta[“autocast_id_fwd”]
            _amax_forward_global_reduce_func = partial(global_amax_reduction, ...)
    # 3. module epilogue
    _amax_forward_global_reduce_func():
        # 对各个设备上的amax做reduce
        reduce_tensor_across_group_op_max(torch.cat(_global_fp8_buffer[amax_buffer_key]), ...):
            torch.distributed.all_reduce(op=MAX)
    # 删除上一次迭代时记录的amax
    del global_fp8_buffer[_buffer_delete_key_fwd]

到这儿Linear层一次完整的forward就结束了,关于backward的过程还没来得及看,大体瞅了一眼差不多。关于Linear和其他层融合的功能也还没看,猜测是在cuda函数里做的融合,外面的流程应该一致。有些细节目前我也还没看懂,中间可能也忽略了一些函数,后面有新的想法再记录。

The End

作者:液态黑洞
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式客栈专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18801
内容数
1347
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息