博客来源:https://pytorch.org/blog/training-using-float8-fsdp2/ 。by IBM and Meta 。这里主要是汇总一下 FSDP2 和 FP8 训练相关的内容,目前的实践主要集中在 TorchTitan(DTensor,Async Tensor Parallelism,FP8 Allgather 等等)和 torchao 上面,包括 torch.compile 编译器也在做对应的支持,PyTorch 对于这个工作其实还没做到很稳定,和 Meagtron-LM 的 FP8 类似处于半成品阶段,例如 API 接口变动就很大,这里可以先简单了解一下他们的进展。以下是 PyTorch 关于 FP8 训练最新进展的博客翻译。
FSDP2 和 FP8 训练 相关前置内容:
- 【翻译】使用 PyTorch FSDP 最大化训练吞吐量
- 【翻译】使用 PyTorch FSDP 和 Torch.compile 最大化训练吞吐量
- 【翻译】在 FSDP2 中开启 Float8 All-Gather
- [分布式训练与 TorchTitan] PyTorch 中的 Async Tensor Parallelism 介绍
使用 float8 和 FSDP2 加速训练
作者:IBM: Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam Meta: Less Wright, Wei Feng, Vasiliy Kuznetsov, Driss Guesseous
在本博客中,我们将展示如何在保持损失和评估基准一致性的同时,相比FSDP1 bf16 训练实现高达 50%的吞吐量提升。我们通过利用 FSDP2、DTensor 和 torch.compile 与 torchao 的 float8 线性层更新(计算)以及 float8 all_gathers 进行权重通信来实现这一提升。我们展示了这些改进在 Meta LLaMa 模型架构的不同规模上的效果,从 1.8B 小型模型一直到 405B 大型模型,使训练速度比以往更快。
我们使用 Meta Llama3 架构展示这些改进,并在两个规模上进行模型质量研究:8B 模型规模的 100B tokens 训练和 70B 模型规模的 50B tokens 训练,这提供了 float8 和 bf16 训练损失曲线的精确比较。我们证明了与bf16
相比,这些模型训练运行的损失曲线达到了相同的损失收敛。此外,我们使用 FineWeb-edu 数据集训练了一个 3B 模型到 1T tokens,并运行标准评估基准以确保模型质量完整且与 bf16 运行相当。
在 IBM 研究院,我们计划采用这些功能进行数据消融实验,以提高在给定 GPU 预算内可以执行的实验数量。从长远来看,我们将通过更大规模的模型运行来展示float8
训练的端到端可行性。
什么是 Float8?
float8
训练格式是由 NVIDIA、ARM 和 Intel 在 2022 年的一篇论文(https://arxiv.org/abs/2209.05433)中提出的,该论文证明了使用更低精度float8进行训练的可行性,且不会牺牲模型质量。随着NVIDIA Hopper 系列等新型 GPU 的推出,由于原生 float8 张量核心支持,FP8 训练变得可行,有望实现超过 2 倍的训练吞吐量提升。实现这一承诺面临一些挑战:(i) 在float8
中启用核心模型操作如matmul
和attention
, (ii) 在分布式框架中启用float8
训练, (iii) 在float8
中启用 GPU 之间的权重通信。虽然 NVIDIA 库启用了float8
matmul
,但后两项是在 FSDP2 和 torchao 的最新更新中提供的。
在本博客中,我们使用 torchtitan(https://github.com/pytorch/to...)作为训练入口点,使用IBM的确定性数据加载器,来自torchao的float8
线性层实现,以及最新PyTorch nightly 版本中的float8 all gather
与 FSDP2 结合。对于这次训练,我们使用的是float8
每张量(tensorwise)缩放粒度而不是行级。我们利用torch.compile
确保获得最大性能提升。我们使用 SDPA 在bf16
中计算attention
,目前正在努力将其也迁移到float8
。
实验
我们进行了各种实验来展示 float8 训练的优势。首先是确保不会牺牲模型质量。为了验证这一点,我们训练了一个 8B 模型和 70B 模型几千步,并比较 float8 和 bf16 训练运行之间的损失曲线。我们的实验在三个不同的 H100 集群上进行,分别配置了 128、256 和 512 个 H100 GPU,环境各不相同,以证明可重复性。第一个集群是 Meta 的 Grand Teton(https://engineering.fb.com/20...)上的定制集群,具有400Gbps定制互连;第二个是IBM研究集群,具有3.2Tbps Infiniband 互连;第三个是 IBM Cloud 集群,具有 3.2Tbps RoCE 互连用于 GPU 到 GPU 通信。
首先,我们在下面的图中绘制了这两个模型的损失曲线比较,以展示几千步的损失一致性。
图 1:(a) 8B 模型 2k 步损失一致性,(b) 70B 模型 1k 步损失一致性
我们观察到,在这些不同的模型和不同的环境中,我们在小规模 tokens 训练中获得了损失一致性。接下来,我们对从 1.8B 到 405B 的四种不同模型规模的吞吐量增益进行了表征。我们探索了 float8 和 bf16 训练运行的最佳批量大小和激活检查点方案,以确定每 GPU 每秒的 tokens 数(wps)指标并报告性能增益。对于 405B 模型,我们利用 DTensor 进行张量并行训练与 FSDP2。我们所有的测量都使用 8K 的序列长度。
表 1:相对于 bf16 的性能增益(bf16 和 float8 都使用 torch.compile)
从表 1 中我们观察到,较大模型(70B 和 405B)的增益达到 50%,较小模型的增益在 20%到 30%之间。在进一步的实验中,我们观察到 float8 all_gather 的添加使性能在 float8 计算本身的基础上提升了约 5%,这与这篇博客(https://aws.amazon.com/cn/blo...)中的观察结果一致。
其次,为了展示 FP8 模型的有效性,我们使用来自 Hugging Face 的 FineWeb-edu 数据集训练了一个遵循 Llama3 架构的 3B 模型,训练量达到 1T tokens。我们使用 lm-eval-harness 框架进行评估,并在下表中展示了部分结果。我们观察到 bf16 的性能略优于 float8 分数(约一个百分点)。虽然某些分数在 bf16 下明显更好(例如,MMLU 高出 3 分),但我们预计当选择正确的超参数和进行更大规模的训练运行时,这些差距会消失(例如,bf16 运行的批量大小是一半,众所周知较小的批量大小运行可以提高评估分数)。
表 2:float8 训练模型在 FP16 下进行评估的基准分数(在 FineWeb 预训练的 1T tokens 处)。
最后,我们将实验扩展到 IBM Cloud 集群的 512 个 H100 GPU 上。我们能够在 512 GPU 规模上重现我们观察到的结果和加速。我们在下表中仅总结了大型模型(70B 和 405B)的这些结果。
表 3:512 GPU 规模下相对于 bf16 的性能增益(bf16 和 float8 都使用 torch.compile)
未来工作
我们还在研究其他形式的并行性,如上下文并行性。我们计划评估所有这些特性,以展示可组合性和为大规模模型训练做出选择的能力。
致谢
我们感谢 IBM Research 的 Davis Wertheimer 为 torchtitan 运行启用数据加载器,使我们能够在多次运行中以相同顺序重放数据。我们还感谢 IBM Cloud 为我们提供 H100 集群的早期测试访问权限。
END
作者:GiantPandaCV
来源:GiantPandaCV
推荐阅读
欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式 AI 专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。