FP8 低精度训练:Transformer Engine 简析

原文:https://zhuanlan.zhihu.com/p/700874387

一、背景介绍

业界广泛采用 FP16、BF16 混合精度(AMP)进行模型训练。AMP 能在下游任务不掉点的前提下提升训练效率、减少显存等资源占用,如今也常用于大模型预训练、微调等任务。

  • Pytorch 1.6 版本后原生支持 FP16、BF16 精度的 AMP 训练(torch.amp),过往 AMP 功能由 NVIDIA APEX 库实现。

NVIDIA GPU 自 Hopper 架构起支持 FP8 精度的 Tensor Core 计算,相比于 FP16/BF16 精度,FP8 具有如下优势:

  • 更强的计算性能

对比 A100 BF16 精度训练,H100 FP8 训练速度提升 2-3x。

image.png

对比 FP16/BF16,FP8 的计算吞吐提升至 2x,与 A100 相比提升的吞吐更多。
image.png

  • 更低的训练成本:FP8 能提供 2x 的计算速度提升,节省 50%-75% 内存占用,以及节省 50%-75% 的数据通信量。
  • 更好的模型优化:FP8 的使用促使模型在训练和推理过程中进行量化,这有助于模型的优化和压缩,进一步降低部署成本。

我们仍然先从几个问题出发~

什么是 FP8?

相比于 16bit 精度,FP8 使用了更少的指数 bit 位和尾数 bit 位:

image.png

在 NV、Arm、Intel 公布的 FP8 白皮书中[arXiv:2209.05433] 介绍了 FP8 数据的两种精度:E4M3和E5M2。两种数据格式的具体二进制表示如下表:

image.png

FP8 E4M3的表示范围为[-448, 448],E5M2为[-57334, 57334],根据其数据表示范围和精度需求,一般而言,E4M3 格式更适合 weight、activation 数据,E5M2 格式更适合 grad 数据。

image.png

为什么是 FP8,不是其他精度(比如 int8)?

  • 与 int8 的数值表示相比较, FP8 在 LLM 的训练更有优势。因为 int8 在数值空间是均匀分布的,而 FP8 有更宽的动态范围, 更能精准捕获 LLM 中参数的数值分布。
  • 尽管 FP16/BF16 已成为业界常用的训练精度,但在大模型训练场景下精度损失问题相对不敏感,仍然可以通过降低精度提升效率。FP8 训练能够在控制精度误差的情况下(Per-tensor Scaling),具有比 16bit 精度更快的计算速度和更少的资源占用,从而提升吞吐、降低训练通信量。

FP8 精度训练的效果如何?在下游任务的表现如何?

FP8 在绝大多数训练任务下都能有 FP16 相当的精度,在少部分下游任务(如数学运算)存在一定差距。

各种 CV 模型在 FP8 精度下训练的分类精度【NV测试结果】:

image.png

NLP 预训练任务【NV测试结果】:

image.png

LLM Benchmark【NV测试结果】:

image.png

SFT 微调效果:

image.png

FP8 有哪些应用场景/案例?

  • Inflection AI 推出的 Inflection2 模型中,采用了 FP8 技术对其模型进行训练优化。Inflection-2 采用了 FP8 混合精度在 5000 个 NVIDIA Hopper 架构 GPU 上进行了训练,累计浮点运算次数高达约10^{25}FLOPs。与同属训练计算类别的 Google 旗舰模型 PaLM 2 相比,在包括知名的 MMLU、TriviaQA、HellaSwag 以及 GSM8k 等多项标准人工智能性能基准测试中,Inflection-2 展现出了卓越的性能,成功超越了 PaLM 2,彰显了其在模型训练方面的领先性,同时也印证了 FP8 混合精度训练策略能够保证模型正常收敛并取得良好的性能。

image.png

  • 零一万物基于 NVIDIA 软硬结合的技术栈,在功能开发、调试和性能层面,与 NVIDIA 团队合作优化,完成了在大模型的 FP8 训练和验证。其大模型的训练吞吐相对 BF16 得到了 1.3 倍的性能提升。零一万物的训练框架是基于 NVIDIA Megatron-LM 开发的 Y 训练框架, 其 FP8 训练基于 NVIDIA Transformer Engine。

image.png

  • Google 与 NVIDIA 团队合作,将 TensorRT-LLM 应用于 Gemma 模型,并结合 FP8 技术进行了推理加速。使用 Hopper GPU 进行推理时,FP8 对比 FP16 在吞吐量上能够带来 3 倍以上的收益。FP8 能够在相同的时间限制下使用更大的 batch size,从而有更好的 GPU 利用率,达到更高的吞吐量。
  • 目前,NVIDIA 有专门使用 FP8 的 开源库—— Transformer Engine。

    • Transformer Engine 和 FP8 已经集成到 PyTorch/JAX/Paddle Paddle 等基础深度学习框架中。
    • 在专用于LLM的框架,比如 Megatron/NeMo/DeepSpeed/HuggingFace/Colossal-AI中也已经集成了 Transformer Engine 和 FP8,并有相应的 FP8 示例。

二、FP16/BF16 AMP

回顾 Pytorch AMP 的实现原理:

  • 计算流程
  • 显存分布
  • Loss Scaling

1.计算流程

通常的 FP16 AMP 计算流程为:

  • 数据、模型一开始都是 FP32 精度。
  • 进入 torch.autocast 后,模型开始前向计算:

    • 如果遇到 FP16 算子,则权重和数据都会转化为 FP16(权重一般有 FP16 cache,除非设置了autocast(cache_enabled=False)),然后在 FP16 精度的算子上进行前向计算,输出的结果也是 FP16 精度;
    • 如果遇到 FP32 算子,则计算精度为 FP32,输出的精度也为 FP32。
  • 反向计算,不需要在 torch.autocast 区域中,torch 会根据前向计算精度,自动确定反向计算精度。
  • 优化器更新权重,利用 Tensor Core,可以直接完成 FP16 + FP32 的加法计算,以 FP32 精度更新权重,整个过程不需要精度转化。
FP16 支持算子:https://pytorch.org/docs/stable

image.png

Pytorch 使用 AMP 的样例代码如下(注意此处精度为 BF16):

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = loss_func(outputs, targets)

loss.backward()
optimizer.step()
optimizer.zero_grad()

2.显存分布

FP16 AMP 训练过程中,显存包含如下数据:

  • 用于前向计算的模型权重:FP16
  • 梯度:FP16
  • 优化器:包含 FP32 Master Model Weight + 2*FP32 Adam States,即一阶矩和二阶矩
  • 其余中间计算结果

3.Global Loss Scaling

由于 FP16 能够表示的数值范围更小,因此对于 FP16 精度的 AMP,需要进行 loss scaling。

scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast(dtype=torch.float16):
    outputs = model(inputs)
    loss = loss_func(outputs, targets)

scaler.scaled(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
关于数值范围:FP32: 1-8-23 / BF16: 1-8-7 / FP16: 1-5-10
BF16 的数值范围和 FP32 一致,均有 8 个指数位,不需要 scaler 调整数值范围。
但 FP16 仅有 5 个指数位,当原数值过大或过小时,转换到 FP16 就可能出现 overflow/underflow,对训练造成影响。

image.png

实际上,我们会维护一个全局的 scale 值,并采用 Dynamic Loss Scaling 动态调整这个全局的 scale 值。即每当梯度溢出时候减少损失缩放规模,并且间歇性地尝试增加损失规模,从而实现在不引起溢出的情况下使用最高损失缩放因子,更好地恢复精度。

image.png

三、FP8 技术分析

1.宏观实现框架

FP16 所采用的 Loss Scaling 与量化的思想非常相似,它可以看成是对全局的梯度数据做离线量化(PTQ)。

FP8 的数据范围有更大的限制,单一的全局 Scale 值无法满足众多数据分布的相对精确表示,因此我们可以仿照量化的思路,将量化的基本单位缩小至 tensor(更细致的量化,如 Block-wise quantization,理论上可以用于更低精度的训练上)。

FP8 对每一个 tensor(无论是输入数据、前向计算结果、反向计算结果)都计算一个 Per-tensor Scaling Factor,以此做更加细致的量化,充分利用 FP8 为数不多的格点数。

具体而言,每一次前向的 GEMM 计算需要对 3 个 tensor 记录 scale 值:inputweightoutput;而相对应的反向计算需要记录 2 个 tensor 的 scale 值:grad_outputgrad_input。在 TE 的 Hybrid 模式下,前向 tensor 数据格式为 E4M3,反向 tensor 数据格式为 E5M2。两种 FP8 精度的量化方式基本相同,均采用对称线性量化,我们只需要关心 scale 值。可以参考之前的内容:

image.png

然而,FP8 训练最关键的问题是,如何在训练过程中高效地寻找到这个 scale 值?

NV 在 TE 文档中给出了两种方案:

  • Just-in-time scaling,即先通过计算得到高精度的 output tensor,再在其上计算 amax,然后对 output tensor 做量化,输出 FP8 tensor。

    这些步骤在单个 kernel 层面上是不可能实现的,因为我们需要将完整的 output tensor 搬到 HBM 才能完成量化,因此这会将一个完整的计算过程分散成多次 kernel 调用,增加了数据传输量,拖慢了运行速度。NV 认为这极大削减了 FP8 低精度带来的好处。

image.png

  • Delayed scaling,即假设我们提前知道了 scale 值,那么计算过程就可以在一个 kernel 上完成,而 amax 的计算和 scale 值的更新与其独立,不会中断计算进程,因此这种方式能完全发挥 FP8 的性能。但“提前知道的” scale 值需要额外的空间来记录,同时会引入一定的误差。

image.png

如下图所示,如果我们知道了 scale 值,那么计算的公式和伪代码就比较直接了:

image.png

TE 框架采用 Delayed scaling 方案,即对每个 GEMM 算子用到的 tensor 记录一个 amax history 数组,当我们需要 scale 值时,就从这个数组中取出最近一段时间窗口内 amax 的最大值,以此近似现在这个 tensor 的 amax,并默认用以下方式计算 scale 值:

FP8_MAX = maximum_representable_value(fp8_format)
new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)

用户可以自定义 Delayed scaling 的策略(Recipe),例如:

  • margin:即上面公式的 margin,用于调整 scaling factor
  • interval:经过多少 steps 更新一次 scaling factor
  • fp8_format:指定前向反向的计算精度,默认为前向 E4M3,反向 E5M2
  • amax_history_len:amax history 窗口长度,超过窗口长度的历史记录会被覆盖
  • ……等等

2.TE 及各类框架集成方法

总的来说,任何结合 FP8 能力的框架都只需要做两件事:

  1. 使用 TE 模块搭建 model,因为计算要用到 TE 提供的 FP8 算子;
  2. 用 fp8_autocast 装饰前向计算过程。

在实际场景下,FP8 训练通常需要结合 BF16 混合精度训练。

1)TE 官方案例:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048

# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

loss = out.sum()
loss.backward()

2)Accelerate:支持 DDP 和 FSDP 的 FP8 训练

# We prepare fp8 after, allowing for bf16 autocast to happen first
if getattr(self.fp8_recipe_handler, "backend", None) == "TE":
    if not has_transformer_engine_layers(model):
        with torch.no_grad():
            convert_model(model)
        model._converted_to_transformer_engine = True

    kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {}
    if "fp8_format" in kwargs:
        kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
    fp8_recipe = te_recipe.DelayedScaling(**kwargs)
    # If we are in DDP or FSDP, we delay `autocast` until after FSDP/DDP has been initialized
    # to make use of the process group
    if not self.delayed_fp8_autocast:
        model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)(model.forward)
  • Line5:如果模型不是 TE 结构的,需要转化为 TE 结构的模型
  • Line8-11:获取 fp8_format、fp8_recipe
  • Line15:前向计算(model.forward)在fp8_autocast上下文管理器内完成

3)Megatron Core:支持 Tensor、Sequence、Pipeline 并行与 FP8 训练结合

# define TE model
use_te = args.transformer_impl == "transformer_engine"
if use_te:
    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec()
model = GPTModel(transformer_layer_spec=transformer_layer_spec)

# set autocast context
class TransformerBlock(MegatronModule):
    def forward():
        if self.config.fp8:
            import transformer_engine  # To keep out TE dependency when not training in fp8

            if self.config.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif self.config.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            else:
                raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")

            fp8_recipe = TEDelayedScaling(
                config=self.config,
                fp8_format=fp8_format,
                override_linear_precision=(False, False, not self.config.fp8_wgrad),
            )
            fp8_group = None
            if parallel_state.model_parallel_is_initialized():
                fp8_group = parallel_state.get_amax_reduction_group(with_context_parallel=True)
            fp8_context = transformer_engine.pytorch.fp8_autocast(
                enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
            )
        else:
            fp8_context = nullcontext()

        with fp8_context:
            # Forward pass.

3.FP8 框架 TE 计算流程

首先分析 TE 框架入口fp8_autocast的源代码:

@contextmanager
def fp8_autocast(
    enabled: bool = True,
    calibrating: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    fp8_group: Optional[dist_group_type] = None,
    _graph: bool = False,
) -> None:
    try:
        fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
        FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
                                                 calibrating=calibrating,
                                                 fp8_recipe=fp8_recipe,
                                                 fp8_group=fp8_group,
                                                 _graph=_graph)
        yield
    finally:
        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
        FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)

FP8GlobalStateManager是一个单例,它保存了全局的 fp8 state 和每个 TE Module 的 fp8 scale/amax 信息。

fp8_autocast的主要工作是:

  • 进入 with 语句后,我们即将开始一次前向计算,此时首先保存当前 fp8 state,以便之后恢复状态。同时更新 fp8 训练的新设置;
  • 退出 with 语句时,我们已经完成了一次前向计算,此时恢复原先的 fp8 state,随后在fp8_autocast_exit 函数中,我们 reduce 每个 fp8_group 进程的 amax,并更新 amax history 和 scale 值。

TE 模块都继承于TransformerEngineBaseModule 这个基类。每个实例均有一个 fp8_meta 字典,这个字典包含了 fp8 的关键信息,里面记录的内容有:

  • fp8_checkpoint:save state 时是否保存 fp8_meta 信息
  • num_gemms:前向计算中 GEMM 的计算次数
  • recipe:即 fp8_recipe
  • fp8_group:fp8 通信组
  • fp8_max_fwd:前向计算最大数值(一般是 E4M3 = 448)
  • fp8_max_bwd:反向计算最大数值(一般是 E5M2 = 57344)
  • scaling_fwd:记录前向 GEMM 的 3 个 scale 值,以及相应的 scale_inv、amax_history
  • scaling_bwd:记录反向 GEMM 的 2 个 scale 值,以及相应的 scale_inv、amax_history

经过代码分析,TE 框架的 FP8 计算流程大致如下:

  • 数据、模型的精度开始是什么不重要,它们都需要先经过 BF16 AMP 的处理。由于 TE 的模块也是 torch.nn.Module 类,因此在调用 forward 方法之前,数据和模型权重都会先转换为 BF16 精度。
  • 然后,在遇到 TE FP8 Module 时,在 forwardbackward 方法内,input 和 weight 都会先转化为 FP8 精度(如果已有 FP8 cache,或计算数据已经是 FP8 精度就跳过),并调用 fp8_gemm 方法进行 FP8 精度计算。这个过程中,我们会传入fp8_meta字典,读取 scale 和 scale_inv 值用于量化计算,并将计算得到的 amax 放入其中的 amax_history。
  • 其他非 FP8 Module 的情况,按照通常 AMP 的计算逻辑进行。
  • 优化器更新权重不属于 FP8 的管辖范围,按照通常 AMP 的计算逻辑进行。

4.Tensor Core 如何进行 FP8 训练

FP8 精度计算仅能运行在 Tensor Core 上。Tensor Core 的基本运算单元为 D = AB + C,其中A、B、C、D 均为矩阵。每个 Tensor Core 能在一个时钟周期内完成 44 的 mma 运算,即一次矩阵乘法和一次矩阵加法。Tensor Core wmma::mmasync API 的最小数据单元是 1616 的矩阵,因此 TE 框架要求输入数据的各维度必须是 16 的倍数。

image.png

在 FP8 计算中,输入的两个矩阵可以是 FP8 两种精度的任意组合,并且 FP8 的 FLOPS 是 16bit 的两倍。两个 FP8 矩阵在完成一次 Tensor Core 运算后会输出高精度结果(FP16/FP32),因此这里存在着 FP8->FP16/FP32 以及 FP16/FP32->FP8 的精度转化过程。

image.png

四、总结与展望

FP8 训练的局限性

  • 在小参数规模训练(小于1B参数量)的场景下,FP8 训练带来的 overhead(量化处理逻辑、精度转化等)要大于计算加速和通信加速带来的性能提升。
  • 同时,如果训练的数据量(batch size)太小,FP8 训练的性能反而会不如 BF16。(在我们的训练环境下,当 batch size 仅为 4 时,训练吞吐会下降约 17%)
  • FP8 微调部分下游任务时表现欠佳,例如数学运算、MMLU 中的困难任务等。
  • FP8 训练过程中出现异常情况(loss spike、NaN 等)的调试更具有挑战性。

FP8 及更低精度训练的前景

FP8 训练在大模型场景下已具有明确的应用前景,目前也具有工业界的应用案例,因此它有望成为大模型高效训练的配置之一。

在硬件端,NV 最新的 BlackWell 架构开始支持 FP6、FP4 等更低精度的 Tensor Core 运算,并可能采用 Block-wise 的量化方案。而 Deepspeed 也推出了不依赖于硬件计算条件的 FP6 运算:参考链接https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md

在低精度运算成为常规方案的今天,在保证训练精度不掉点,并采用低精度训练的性能提升幅度,可能还远未到达极限。

image.png

参考资料

FP8 相关论文:

  • 8-BIT NUMERICAL FORMATS FOR DEEP NEURALNETWORKS(Graphcore 2021)
  • Auto-Precision Scaling for Distributed Deep Learning(2021)
  • FP8 Formats For Deep Learning(NV、Intel、Arm 2022)
  • Mixed Precision Training With 8-bit Floating Point(Intel 2019)

TE 代码与文档:

NV 技术博客:

作者: 杨远航
来源: GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18751
内容数
1308
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息