图解大模型训练系列:序列并行 4,Megatron Context Parallel

在序列并行系列中,我们将详细介绍下面四种常用的框架/方法:

  1. Megatron Sequence Parallelism:本质是想通过降低单卡激活值大小的方式,尽可能多保存激活值,少做重计算,以此提升整体训练速度,一般和它家的 tp 配套使用。
  2. DeepSpeed Ulysses:我们知道 ds 家的 zero 是模型并行的形式,数据并行的本质。在这个情况下,单张卡是完整地做一条序列的 MHA 过程的,序列长度较长时,就会对单卡显存产生压力。所以 Ulysses 的解决办法是,让单张卡只算全部 seq 的某个/某些 head 的结果,具体实践起来就是先通过按 seq 维度切割卡的输入,再通过 all2all 通讯来做。
  3. Ring Attention:相当于分布式的 Flash Attention V2(我个人的理解),它最终的效果是让每张卡只算自己所维护的那部分 seq_chunk 的 MHA。
  4. Megatron Context Parallelism:可以看成是增强版的 sp,引入了类 ring-attention 的技术(在 tp-pp-dp rank 相同的位置做 ring-attention),联合 Megatron 的各种混合并行方式进行训练。

今天,我们来讲最后一部分 Megatron Context Parallelism,把它放在最后的原因是:

  • Megatron cp 可以看成是在保持 megatron sp 混合并行框架的基础上,引入 cp 维度的并行。而 cp 并行的本质其实是做 attention 部分的优化。所以你可以把 megatron sp 混合并行理解成整体框架,cp 理解成局部优化。
  • Megatron cp 在实践上和朴素的 ring attention 非常相似,但是它做了计算上的负载均衡处理,我们在本文中会详细讲解这一点。
  • Megatron cp 也有尝试做 deepspeed ulysses + ring attention 的结合,这一点也写在 cp 的核心逻辑中,但这不是本文讲解的重点。
  • 综合来看,本文要讲解的重点是 megatron tp + cp + dp + pp 的混合并行,同时重点关注纯 cp 部分的实践方法。

关于 megatron cp,算是一个比较新的还在持续发展的项目,目前官方没有给出具体的论文,只有一个很简短的官网介绍(https://docs.nvidia.com/megat...),从官网介绍中我们可以大致理解上面说的“在保持megatron sp 混合并行框架的基础上,引入 cp 维度并行”的大致含义。但是这篇文章真得太短了(苦笑),所以 cp 的细节只能从源码层面来解读。(然而,请让我再次吐槽一次 😢😢,cp 的实践横跨了 megatron-lm 和 TranformerEngine 两个仓库,代码真得写得太冗余、太杂、太混乱了...所以这真是一篇暗含泪水的解读)。

虽然是从源码的阅读中推测出了 cp 的核心技术,但是本文不打算写成一篇源码解读的文章。本文将把源码运作流程抽象成一张张具体的图例,来说明 cp 主要做了什么事,在每一节的最后配上相关的代码链接,大家可以配合着图例自行阅读。如此一来,尽量让这篇文章变成纯原理式的文章,不让大家被冗长的代码分心。

最后,小小纪念下,历时一年,终于把并行训练的部分写得差不多了,也算是有头有尾的一个系列了,后续依然可能会往这个系列里增添新的东西,谢谢在这个系列中给出宝贵反馈和肯定的所有朋友们!

图解大模型训练系列:流水线并行
图解大模型训练系列:dp 与 ddp
图解大模型训练系列:deepspeed zero
图解大模型训练系列:megatron-LM,张量并行(原理篇)
图解大模型训练系列:megatron-LM,分布式环境初始化(源码篇)
图解大模型训练系列:megatron-LM,模型并行(源码篇)
图解大模型训练系列:megatron-LM,混合精度训练(源码篇)
图解大模型训练系列:deepspeed 与 megatron 的 moe 并行(原理篇)
图解大模型训练系列:deepspeed 与 megatron 的 moe 并行(源码篇)

以及本次更新的 4 篇序列并行。

一、分布式环境初始化

image.png

首先,我们来看在引入 cp 的前提下,megatron 是如何做混合并行的,具体情况如下:

  • tp = 2, cp = 2, dp = 2, pp = 2。那么有 num_gpu = tp * cp * dp * pp = 222*2 = 16。也就是我们需要 16 张卡。假设我们的一台机器内有 8 张卡,则我们需要 2 台机器。
  • 我们不考虑 ep 维度(即 ep=1),因为本质上它不影响 cp 维度的并行(cp 维度是对 attention 做优化,ep 可以理解成是 mlp 层的操作)
  • 在考虑如何设置并行 group 时,我们采用的顺序是**tp-cp-ep-dp-pp,我们认为越靠前的并行组,通讯量越大,所以尽量安排在一台机器内**。例如对于 tp group,它的每一个 sub-tp group 关联的 2 张卡都位于同一台机器中。tp-cp-ep-dp-pp 是 megatron 代码默认的顺序,我们当然可以根据实际情况做修改,但前提就是要考虑通讯量。
  • 由于 dp=2,所以我们假设有 2 个 micro-batch,分别是 batch0 和 batch1。由于 cp=2,每个 batch 都被从 seq 维度上切成两份。

在这些前置条件下,我们绘制出了上面的分布式配置图片,我们以 gpu0 为例:

  • 首先,对于一个模型,它沿着 layer 层被横向切成 2 份(pp=2),沿着权重被纵向切成 2 份(tp=2)。对应到我们的图里,就是 4 个不同颜色的色块组成一个完整的模型。
  • 对于 gpu0 来说,[0,1,8,9]组成了一个 mp group,拥有一个完整的模型
  • 对于 gpu0 来说,[0,1]组成了 tp 组,这意味着 0 和 1 将吃相同的输入 X,然后分别计算 X 的不同 head 的结果
  • 对于 gpu0 来说,[0,8]组成了 pp 组,这意味着 0 和 8 之间会做层间激活值的传递
  • 对于 gpu0 来说,[0,2]组成了 cp 组,0 和 2 上维护着相同的模型权重,但是分别维护着同一个 batch 的 seq_chunk0,seq_chunk1
  • 对于 gpu0 来说,[0,4]组成了 dp 组,0 和 4 上维护着相同的模型权重,但是分别维护着不同 batch 的 seq_chunk0。

总结来看,引入 megatron cp,其实就是:

  • 我们先假设不对输入 X 做任何序列维度的切分,这时我们就得到了原始的 megatron tp-dp-pp 组。
  • 现在引入 cp,意味着我们要把输入 X 切分成 cp_size 份,所以我们只需要把原始的 tp-dp-pp 组拷贝 cp_size 份,就得到了最终的分布式配置。
  • 所以我们在前文中才说,相同的 tp-dp-pp rank 位置就是新的 cp 组。例如图中所展示的 tp-dp-pp group 中,0 和 2 都是各自 group 内的 local rank = 0 的元素,所以他们组成一个的 cp 组;1 和 3 都是各自 group 内 local rank = 1 的元素,所以他们组成一个 cp 组,以此类推。

好,现在我们已经知道了如下内容:

  • cp 组的设置方式
  • 同一个 cp_group 内的各张卡维护着:【相同的模型权重】、【相同 batch 的不同 seq_chunk】
  • 一个 cp 组的最终目标是:通过类 ring attention 的方式,计算出自己所维护的这个 seq_chunk 在自己所负责的这个 head 上的结果。

所以接下来,我们马上来看计算细节。分布式初始化的代码在https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py中,大家可以配合上述讲解自行阅读。

二、负载均衡的 Ring Attention

2.1 朴素 Ring Attention

Image

如上图所示,我们在ring attention 篇中讲过一个朴素 ring attention 的运作流程:

  • 每张卡上固定维护着某个 seq_chunk 的 Q
  • 每张卡上轮转不同 seq_chunk 的 KV 值
  • 每张卡上,Q 和当前轮转到的(K, V)数据做 attention 计算,然后通过类似 Flash Attention V2 的方式更新 output(细节这里不赘述,大家可以去看上面链接中的文章)
  • 当所有的 KV 值轮转完毕后,每张卡上就得到了最终的 output。

例如以 Q0 为例,整个计算过程如下:

Image

但是,朴素 ring attention 存在一个较大的问题:计算负载不均衡。

假设我们使用的是 causal mask,也就是在 attention 计算对于某个 token,它只和自己及之前的 tokens 做 attn,而不关心后面的 token。但是在当前 ring attention 的划分下:

  • 对于 gpu0,它维护着 Q0,这也意味着后面流转过来的(K1, V1)(K2, V2)(K3, V3)都是位于它之后的 tokens 产出的结果,它根本不需要和它们做 attn,这时 gpu0 的计算就被浪费了。
  • 对于其余 gpu 也是同理。只有维护着最后一块 Q 分块的 gpu3 能在每次流转中都做好计算,没有浪费计算资源。
  • 这就是我们所说的,causal mask 下朴素 ring attention 的计算负载不均问题。

2.2 负载均衡版 Ring Attention

在之前的ring attention 篇Flash Attention V1/Flash Attntion V2中我们讲过,分块 attention 的计算其实是和计算顺序无关的,核心是只要每次计算时我们都能拿到当前分块的 output,当前未做 softmax 前 attention score 矩阵的 max 和 sum 相关信息,我们就能正常更新最终的 output。(如果对这句话不太理解,可以看下上面链接里给的文章,这里不再展开了)

在理解了这一点的基础上,我们重新设计 Ring Attention 中每块卡上存放的 seq_chunk:

Image

如上图所示,假设 cp_size = 4,也就是我们打算在 4 块 gpu 上做 ring attention。

  • 首先,对于原始输入数据 X,我们将其切分为 2*cp_size = 8 块,也就是上图的 0 ~ 7 chunk
  • [0,7],[1, 6],[2, 5],  [3, 4]分别组成 4 个 seq_chunk,安放在 gpu0~gpu3 上。
  • 则在 ring attention 下,每块 gpu 上计算 cp_size 次后,就能得到最终的 output。例如对于 gpu0,计算 4 次后,就能得到[0, 7]这两个位置最终的 attention 结果。
  • 图中接着展示了在不同的 iteration 中,每块卡上的计算情况,可以发现:

    • i = 0 时,每张卡上都是 4 个小方块在做 attn 计算
    • i = 1/2/3 时,每张卡上都是 3 个小方块在做 attn 计算
    • 总结来看,每个 iteration 中,各卡的计算量是相同的。不存在朴素 ring attention 上某些卡空转的情况。


同时注意到,当 i = 1/2/3 时,总有 Q 或者 KV 块不参与计算,如果我们用 rank 表示这是 cp_group 内的第几块 gpu(例如 rank=0 就是上面 cp_group 中的第 0 块 gpu),则对于某张卡,我们有如下规律:

  • i = 0,该卡上所有的 QKV 块参与计算
  • i <= rank 时,该卡上第 2 个 KV 块不参与计算
  • i > rank 时,该卡上第 1 个 Q 块不参与计算

用于分配哪张卡上应该维护哪些 Q 块的代码在:https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training/utils.py#L233

用于处理实际 QKV 计算时应该保留哪些数据块,去掉哪些数据块的代码在:https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L1901

大家可以自行阅读

三、计算和通讯的 overlap

在 ring attention 中我们讲解过,对于一张卡,如果我们能让它在计算 attn 的同时,把自己的 KV 发送给下一张卡,同时从上一张卡中获取新的 KV,这样我们就能实现【计算】和【通讯】的并行,以此来掩盖通讯要带来的额外时间开销。

具体到代码的时间上,我们可以创建不同的 cuda 流(torch.cuda.Stream())来实现这一目标。对于 cuda 流的作用你可以简单理解成:一个 cuda 流中可能包含若干串行的操作,而不同的 cuda 流是可以并行执行的。这样,我们就可以定义一个 cuda 流用于计算 attn,再定义一个 cuda 流用于做通讯。

但在 megatron cp 中,其实一共包含 3 个 cuda 流,我们来简单认识一下它们:

  • NCCL stream:定义在 cp_group 内,是用于做 KV 发送和接收的 cuda 流
  • Stream0 和 Stream1:都是用于做计算的 cuda 流,这两个流的作用是可以并行执行 attn 的计算和 softmax_lse 的更新。
  • 也就是说,megatron cp 中,除了对【计算】和【通讯】做了并行,还对计算中的【attn】和【softmax_lse 更新】做了并行。

【计算】和【通讯】的并行好理解,我们现在来快速解释下【attn】和【softmax_lse 更新】的并行是什么意思。

在之前的系列中,我们已经讲过 ring attention 更新 output 的方式非常近似于 Flash Attention V2,所以我们贴出 Flash Attention V2 的 fwd 过程,来看下 output 是如何更新的:

Image

  • 图中第 10 行展示了每次 output 的更新过程
  • 图中第 12 行是,对于一块 Q,当我们轮转了所有的(K, V)后,我们使用第 12 行的公式对 output 再做一次性的更新,这才得到了这块 Q 最终的 output。而第 12 行的结果,就是我们说的 softmax_lse。
  • 但是,另一种做法是,我们可以把第 12 行的结果放进第 10 行做,也就是对于一块 Q,每轮转 1 次(K,V),计算 attn 时,我们就可以这次轮转算出的 attn score 矩阵算出 max 和 sum,进而更新 softmax_lse然后用于更新本次轮转的 output,这就是 ring attention 采用的做法,目的应该是尽量减少精度损失。至于 FA2 中为什么把第 10 行和第 12 行拆开做,本质是为了减少非矩阵乘法的计算量,以此提升计算速度(之前的文章讲过,这里不再赘述。)
  • 所以,对于 ring attention,总结来看它的每次计算都分成两块:

    • 【attn:算出本次轮转的 output】
    • 【softmax_lse 更新:基于本次轮转的结果更新 softmax_lse,用于修正 output】


现在我们已经基本了解【attn】和【softmax_lse 更新】的定义了,那么现在我们就直接来看 megatron cp 中这 3 个 cuda 流的运行过程,然后来解释什么叫【attn】和【softmax_lse 更新】的并行:

Image

上图刻画了在 cp_size = 4 的情况下,某张卡上的流转过程,具体而言:

  • i = 0时

    • 切换到 stream0 上开始执行
    • 在 stream0 上开启发送/接收 KV 数据的流程,而这个流程则实际由 NCCL stream 开始执行
    • 在 stream0 上做 attn 计算。attn 计算结束后,会得到 output 和 softmax_lse0。由于这是 i=0 阶段,所以我们不用做 softmax_lse 的更新
    • 我们要注意区分“算出 softmax_lse_i"和“更新 softmax_lse”的区别。


  • i = 1时

    • 我们需要等待(wait)本次计算需要的 KV 值到位,这里我特意假设计算无法完美覆盖通讯,所以我们多了等待时间。
    • 等数据到位后,我们就开启新的发送/接收 KV 数据的流程,这个这个流程则实际由 NCCL stream 执行。
    • 接着我们正常做 attn 计算,得到本次的 output 和 softmax_lse1。
    • 同时开启 stream0 和 stream1 流程
    • 在 stream0 流程中,我们令 softmax_lse = softmax_lse0,这时我们先不做任何 softmax_lse 的更新。
    • 在 stream1 流程中:
    • 难发现,此时我们在 stream0 和 stream1 中已经实现了【attn】和【softmax_lse 更新】的并行,只是这里不是严格的 softmax_lse 更新


  • i= 2时

    • 同时开启 stream0 和 stream1 流程。
    • 在 stream1 流程中,我们开始做真正意义上的【softmax_lse 更新】,即 softmax_lse = correction(softmax_lse, softmax_lse1)
    • 在 stream0 流程中,我们做【attn】计算,得到新的 output 和 softmax_lse2。同时开启 NCCL stream 做数据发送


  • i = 3时

    • 同时开启 stream0 和 stream1 流程。
    • 在 stream0 流程中,我们做【softmax_lse 更新】,即 softmax_lse = correction(softmax_lse, softmax_lse2)
    • 在 stream1 流程中,我们做【attn】计算,得到新的 output 和 softmax_lse3。此时我们已经无需再做数据通讯了,因为这是最后一轮流转。


  • i = 4时

    • 只需要启动 stream1,做最后一次【softmax_lse 更新】,即 softmax_lse = correction(softmax_lse, softmax_lse3)即可。


这张通讯图我们还做了一些简化,例如当我们每次做【softmax_lse 更新】时,我们都需要保证上一次更新的结果已经计算完毕,所以这里可能也是需要做 wait 的,为了表达简便,这边略去。

本节以及整个 cp 核心代码在 https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py#L1867中,大家可以配合上面的图例更好阅读代码。

END

来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式 AI 专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18835
内容数
1369
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息