在序列并行系列中,我们将详细介绍下面四种常用的框架/方法:
- Megatron Sequence Parallelism:本质是想通过降低单卡激活值大小的方式,尽可能多保存激活值,少做重计算,以此提升整体训练速度,一般和它家的 tp 配套使用。
- DeepSpeed Ulysses:我们知道 ds 家的 zero 是模型并行的形式,数据并行的本质。在这个情况下,单张卡是完整地做一条序列的 MHA 过程的,序列长度较长时,就会对单卡显存产生压力。所以 Ulysses 的解决办法是,让单张卡只算全部 seq 的某个/某些 head 的结果,具体实践起来就是先通过按 seq 维度切割卡的输入,再通过 all2all 通讯来做。
- Ring Attention:相当于分布式的 Flash Attention V2(我个人的理解),它最终的效果是让每张卡只算自己所维护的那部分 seq_chunk 的 MHA。
- Megatron Context Parallelism:可以看成是增强版的 sp,引入了类 ring-attention 的技术(在 tp-pp-dp rank 相同的位置做 ring-attention),联合 Megatron 的各种混合并行方式进行训练。
今天,我们来讲最后一部分 Megatron Context Parallelism,把它放在最后的原因是:
- Megatron cp 可以看成是在保持 megatron sp 混合并行框架的基础上,引入 cp 维度的并行。而 cp 并行的本质其实是做 attention 部分的优化。所以你可以把 megatron sp 混合并行理解成整体框架,cp 理解成局部优化。
- Megatron cp 在实践上和朴素的 ring attention 非常相似,但是它做了计算上的负载均衡处理,我们在本文中会详细讲解这一点。
- Megatron cp 也有尝试做 deepspeed ulysses + ring attention 的结合,这一点也写在 cp 的核心逻辑中,但这不是本文讲解的重点。
- 综合来看,本文要讲解的重点是 megatron tp + cp + dp + pp 的混合并行,同时重点关注纯 cp 部分的实践方法。
关于 megatron cp,算是一个比较新的还在持续发展的项目,目前官方没有给出具体的论文,只有一个很简短的官网介绍(https://docs.nvidia.com/megat...),从官网介绍中我们可以大致理解上面说的“在保持megatron sp 混合并行框架的基础上,引入 cp 维度并行”的大致含义。但是这篇文章真得太短了(苦笑),所以 cp 的细节只能从源码层面来解读。(然而,请让我再次吐槽一次 😢😢,cp 的实践横跨了 megatron-lm 和 TranformerEngine 两个仓库,代码真得写得太冗余、太杂、太混乱了...所以这真是一篇暗含泪水的解读)。
虽然是从源码的阅读中推测出了 cp 的核心技术,但是本文不打算写成一篇源码解读的文章。本文将把源码运作流程抽象成一张张具体的图例,来说明 cp 主要做了什么事,在每一节的最后配上相关的代码链接,大家可以配合着图例自行阅读。如此一来,尽量让这篇文章变成纯原理式的文章,不让大家被冗长的代码分心。
最后,小小纪念下,历时一年,终于把并行训练的部分写得差不多了,也算是有头有尾的一个系列了,后续依然可能会往这个系列里增添新的东西,谢谢在这个系列中给出宝贵反馈和肯定的所有朋友们!
图解大模型训练系列:流水线并行
图解大模型训练系列:dp 与 ddp
图解大模型训练系列:deepspeed zero
图解大模型训练系列:megatron-LM,张量并行(原理篇)
图解大模型训练系列:megatron-LM,分布式环境初始化(源码篇)
图解大模型训练系列:megatron-LM,模型并行(源码篇)
图解大模型训练系列:megatron-LM,混合精度训练(源码篇)
图解大模型训练系列:deepspeed 与 megatron 的 moe 并行(原理篇)
图解大模型训练系列:deepspeed 与 megatron 的 moe 并行(源码篇)
以及本次更新的 4 篇序列并行。
一、分布式环境初始化
首先,我们来看在引入 cp 的前提下,megatron 是如何做混合并行的,具体情况如下:
tp = 2, cp = 2, dp = 2, pp = 2
。那么有 num_gpu = tp * cp * dp * pp = 222*2 = 16。也就是我们需要 16 张卡。假设我们的一台机器内有 8 张卡,则我们需要 2 台机器。- 我们不考虑 ep 维度(即 ep=1),因为本质上它不影响 cp 维度的并行(cp 维度是对 attention 做优化,ep 可以理解成是 mlp 层的操作)
- 在考虑如何设置并行 group 时,我们采用的顺序是**
tp-cp-ep-dp-pp
,我们认为越靠前的并行组,通讯量越大,所以尽量安排在一台机器内**。例如对于 tp group,它的每一个 sub-tp group 关联的 2 张卡都位于同一台机器中。tp-cp-ep-dp-pp 是 megatron 代码默认的顺序,我们当然可以根据实际情况做修改,但前提就是要考虑通讯量。 - 由于 dp=2,所以我们假设有 2 个 micro-batch,分别是 batch0 和 batch1。由于 cp=2,每个 batch 都被从 seq 维度上切成两份。
在这些前置条件下,我们绘制出了上面的分布式配置图片,我们以 gpu0 为例:
- 首先,对于一个模型,它沿着 layer 层被横向切成 2 份(pp=2),沿着权重被纵向切成 2 份(tp=2)。对应到我们的图里,就是 4 个不同颜色的色块组成一个完整的模型。
- 对于 gpu0 来说,[0,1,8,9]组成了一个 mp group,拥有一个完整的模型
- 对于 gpu0 来说,[0,1]组成了 tp 组,这意味着 0 和 1 将吃相同的输入 X,然后分别计算 X 的不同 head 的结果
- 对于 gpu0 来说,[0,8]组成了 pp 组,这意味着 0 和 8 之间会做层间激活值的传递
- 对于 gpu0 来说,[0,2]组成了 cp 组,0 和 2 上维护着相同的模型权重,但是分别维护着同一个 batch 的 seq_chunk0,seq_chunk1
- 对于 gpu0 来说,[0,4]组成了 dp 组,0 和 4 上维护着相同的模型权重,但是分别维护着不同 batch 的 seq_chunk0。
总结来看,引入 megatron cp,其实就是:
- 我们先假设不对输入 X 做任何序列维度的切分,这时我们就得到了原始的 megatron tp-dp-pp 组。
- 现在引入 cp,意味着我们要把输入 X 切分成 cp_size 份,所以我们只需要把原始的 tp-dp-pp 组拷贝 cp_size 份,就得到了最终的分布式配置。
- 所以我们在前文中才说,相同的 tp-dp-pp rank 位置就是新的 cp 组。例如图中所展示的 tp-dp-pp group 中,0 和 2 都是各自 group 内的 local rank = 0 的元素,所以他们组成一个的 cp 组;1 和 3 都是各自 group 内 local rank = 1 的元素,所以他们组成一个 cp 组,以此类推。
好,现在我们已经知道了如下内容:
- cp 组的设置方式
- 同一个 cp_group 内的各张卡维护着:【相同的模型权重】、【相同 batch 的不同 seq_chunk】
- 一个 cp 组的最终目标是:通过类 ring attention 的方式,计算出自己所维护的这个 seq_chunk 在自己所负责的这个 head 上的结果。
所以接下来,我们马上来看计算细节。分布式初始化的代码在https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py中,大家可以配合上述讲解自行阅读。
二、负载均衡的 Ring Attention
2.1 朴素 Ring Attention
如上图所示,我们在ring attention 篇中讲过一个朴素 ring attention 的运作流程:
- 每张卡上固定维护着某个 seq_chunk 的 Q
- 每张卡上轮转不同 seq_chunk 的 KV 值
- 每张卡上,Q 和当前轮转到的(K, V)数据做 attention 计算,然后通过类似 Flash Attention V2 的方式更新 output(细节这里不赘述,大家可以去看上面链接中的文章)
- 当所有的 KV 值轮转完毕后,每张卡上就得到了最终的 output。
例如以 Q0 为例,整个计算过程如下:
但是,朴素 ring attention 存在一个较大的问题:计算负载不均衡。
假设我们使用的是 causal mask,也就是在 attention 计算对于某个 token,它只和自己及之前的 tokens 做 attn,而不关心后面的 token。但是在当前 ring attention 的划分下:
- 对于 gpu0,它维护着 Q0,这也意味着后面流转过来的(K1, V1)(K2, V2)(K3, V3)都是位于它之后的 tokens 产出的结果,它根本不需要和它们做 attn,这时 gpu0 的计算就被浪费了。
- 对于其余 gpu 也是同理。只有维护着最后一块 Q 分块的 gpu3 能在每次流转中都做好计算,没有浪费计算资源。
- 这就是我们所说的,causal mask 下朴素 ring attention 的计算负载不均问题。
2.2 负载均衡版 Ring Attention
在之前的ring attention 篇和Flash Attention V1/Flash Attntion V2中我们讲过,分块 attention 的计算其实是和计算顺序无关的,核心是只要每次计算时我们都能拿到当前分块的 output,当前未做 softmax 前 attention score 矩阵的 max 和 sum 相关信息,我们就能正常更新最终的 output。(如果对这句话不太理解,可以看下上面链接里给的文章,这里不再展开了)
在理解了这一点的基础上,我们重新设计 Ring Attention 中每块卡上存放的 seq_chunk:
如上图所示,假设 cp_size = 4,也就是我们打算在 4 块 gpu 上做 ring attention。
- 首先,对于原始输入数据 X,我们将其切分为 2*cp_size = 8 块,也就是上图的 0 ~ 7 chunk
- [0,7],[1, 6],[2, 5], [3, 4]分别组成 4 个 seq_chunk,安放在 gpu0~gpu3 上。
- 则在 ring attention 下,每块 gpu 上计算 cp_size 次后,就能得到最终的 output。例如对于 gpu0,计算 4 次后,就能得到[0, 7]这两个位置最终的 attention 结果。
图中接着展示了在不同的 iteration 中,每块卡上的计算情况,可以发现:
- i = 0 时,每张卡上都是 4 个小方块在做 attn 计算
- i = 1/2/3 时,每张卡上都是 3 个小方块在做 attn 计算
- 总结来看,每个 iteration 中,各卡的计算量是相同的。不存在朴素 ring attention 上某些卡空转的情况。
同时注意到,当 i = 1/2/3 时,总有 Q 或者 KV 块不参与计算,如果我们用 rank 表示这是 cp_group 内的第几块 gpu(例如 rank=0 就是上面 cp_group 中的第 0 块 gpu),则对于某张卡,我们有如下规律:
- i = 0,该卡上所有的 QKV 块参与计算
- i <= rank 时,该卡上第 2 个 KV 块不参与计算
- i > rank 时,该卡上第 1 个 Q 块不参与计算
用于分配哪张卡上应该维护哪些 Q 块的代码在:https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training/utils.py#L233
用于处理实际 QKV 计算时应该保留哪些数据块,去掉哪些数据块的代码在:https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L1901
大家可以自行阅读
三、计算和通讯的 overlap
在 ring attention 中我们讲解过,对于一张卡,如果我们能让它在计算 attn 的同时,把自己的 KV 发送给下一张卡,同时从上一张卡中获取新的 KV,这样我们就能实现【计算】和【通讯】的并行,以此来掩盖通讯要带来的额外时间开销。
具体到代码的时间上,我们可以创建不同的 cuda 流(torch.cuda.Stream())来实现这一目标。对于 cuda 流的作用你可以简单理解成:一个 cuda 流中可能包含若干串行的操作,而不同的 cuda 流是可以并行执行的。这样,我们就可以定义一个 cuda 流用于计算 attn,再定义一个 cuda 流用于做通讯。
但在 megatron cp 中,其实一共包含 3 个 cuda 流,我们来简单认识一下它们:
- NCCL stream:定义在 cp_group 内,是用于做 KV 发送和接收的 cuda 流
- Stream0 和 Stream1:都是用于做计算的 cuda 流,这两个流的作用是可以并行执行 attn 的计算和 softmax_lse 的更新。
- 也就是说,megatron cp 中,除了对【计算】和【通讯】做了并行,还对计算中的【attn】和【softmax_lse 更新】做了并行。
【计算】和【通讯】的并行好理解,我们现在来快速解释下【attn】和【softmax_lse 更新】的并行是什么意思。
在之前的系列中,我们已经讲过 ring attention 更新 output 的方式非常近似于 Flash Attention V2,所以我们贴出 Flash Attention V2 的 fwd 过程,来看下 output 是如何更新的:
- 图中第 10 行展示了每次 output 的更新过程
- 图中第 12 行是,对于一块 Q,当我们轮转了所有的(K, V)后,我们使用第 12 行的公式对 output 再做一次性的更新,这才得到了这块 Q 最终的 output。而第 12 行的结果,就是我们说的 softmax_lse。
- 但是,另一种做法是,我们可以把第 12 行的结果放进第 10 行做,也就是对于一块 Q,每轮转 1 次(K,V),计算 attn 时,我们就可以这次轮转算出的 attn score 矩阵算出 max 和 sum,进而更新 softmax_lse,然后用于更新本次轮转的 output,这就是 ring attention 采用的做法,目的应该是尽量减少精度损失。至于 FA2 中为什么把第 10 行和第 12 行拆开做,本质是为了减少非矩阵乘法的计算量,以此提升计算速度(之前的文章讲过,这里不再赘述。)
所以,对于 ring attention,总结来看它的每次计算都分成两块:
- 【attn:算出本次轮转的 output】
- 【softmax_lse 更新:基于本次轮转的结果更新 softmax_lse,用于修正 output】
现在我们已经基本了解【attn】和【softmax_lse 更新】的定义了,那么现在我们就直接来看 megatron cp 中这 3 个 cuda 流的运行过程,然后来解释什么叫【attn】和【softmax_lse 更新】的并行:
上图刻画了在 cp_size = 4 的情况下,某张卡上的流转过程,具体而言:
i = 0时
- 切换到 stream0 上开始执行
- 在 stream0 上开启发送/接收 KV 数据的流程,而这个流程则实际由 NCCL stream 开始执行
- 在 stream0 上做 attn 计算。attn 计算结束后,会得到 output 和 softmax_lse0。由于这是 i=0 阶段,所以我们不用做 softmax_lse 的更新
- 我们要注意区分“算出 softmax_lse_i"和“更新 softmax_lse”的区别。
i = 1时
:- 我们需要等待(wait)本次计算需要的 KV 值到位,这里我特意假设计算无法完美覆盖通讯,所以我们多了等待时间。
- 等数据到位后,我们就开启新的发送/接收 KV 数据的流程,这个这个流程则实际由 NCCL stream 执行。
- 接着我们正常做 attn 计算,得到本次的 output 和 softmax_lse1。
- 同时开启 stream0 和 stream1 流程
- 在 stream0 流程中,我们令 softmax_lse = softmax_lse0,这时我们先不做任何 softmax_lse 的更新。
- 在 stream1 流程中:
- 不难发现,此时我们在 stream0 和 stream1 中已经实现了【attn】和【softmax_lse 更新】的并行,只是这里不是严格的 softmax_lse 更新
i= 2时
:- 同时开启 stream0 和 stream1 流程。
- 在 stream1 流程中,我们开始做真正意义上的【softmax_lse 更新】,即 softmax_lse = correction(softmax_lse, softmax_lse1)
- 在 stream0 流程中,我们做【attn】计算,得到新的 output 和 softmax_lse2。同时开启 NCCL stream 做数据发送
i = 3时
:- 同时开启 stream0 和 stream1 流程。
- 在 stream0 流程中,我们做【softmax_lse 更新】,即 softmax_lse = correction(softmax_lse, softmax_lse2)
- 在 stream1 流程中,我们做【attn】计算,得到新的 output 和 softmax_lse3。此时我们已经无需再做数据通讯了,因为这是最后一轮流转。
i = 4时
:- 只需要启动 stream1,做最后一次【softmax_lse 更新】,即 softmax_lse = correction(softmax_lse, softmax_lse3)即可。
这张通讯图我们还做了一些简化,例如当我们每次做【softmax_lse 更新】时,我们都需要保证上一次更新的结果已经计算完毕,所以这里可能也是需要做 wait 的,为了表达简便,这边略去。
本节以及整个 cp 核心代码在 https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L1867中,大家可以配合上面的图例更好阅读代码。
END
来源:GiantPandaCV
推荐阅读
- YOLOv8+PyQT5 打造安全帽检测预警应用
- 轻量级神经网络模型,嵌入式微小设备也能实时检测 !
- YOLOv8 与 YOLO11 自定义数据集迁移学习效果对比
- OrientedFormer: 基于 Transformer 的定向目标检测新框架 !
欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式 AI 专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。