图解大模型训练系列:序列并行3,Ring Attention

在序列并行系列中,我们将详细介绍下面四种常用的框架/方法:

  1. Megatron Sequence Parallelism:本质是想通过降低单卡激活值大小的方式,尽可能多保存激活值,少做重计算,以此提升整体训练速度,一般和它家的tp配套使用。
  2. DeepSpeed Ulysses:我们知道ds家的zero是模型并行的形式,数据并行的本质。在这个情况下,单张卡是完整地做一条序列的MHA过程的,序列长度较长时,就会对单卡显存产生压力。所以Ulysses的解决办法是,让单张卡只算全部seq的某个/某些head的结果,具体实践起来就是先通过按seq维度切割卡的输入,再通过all2all通讯来做。
  3. Ring Attention近似于分布式的Flash Attention V2(我个人的理解)它最终的效果是让每张卡只算自己所维护的那部分seq_chunk的MHA。
  4. Megatron Context Parallelism:可以看成是增强版的sp,引入了类ring-attention的技术(在tp-pp-dp rank相同的位置做ring-attention),联合Megatron的各种混合并行方式进行训练。

其中1和2我们已经在前面的系列中讲过,今天我们来看【3. Ring attention】,阅读本文前,最好对Flash Attention V2Flash attention V1的分块计算过程有初步了解。

一、单gpu:safe softmax

image.png

一个朴素attention的计算过程如上图所示。这里我们设Q,K,V的尺寸都为(N, d),其中N=seq_len,d = hidden_size。

为了接下来更好表示ring attention的切块计算法,上图中把Q,K,V都切分成了4块,每块大小为(C, d)。但这依然不妨碍以不切块的视角阅读上面这张图。特别注意,图中尺寸为(N, N)的attention矩阵要做完softmax后才能和V矩阵相乘,为了表达简便图中没有画出softmax过程,但我们一定要记得有这个过程,这非常重要。

正常来说,我们计算softmax的方式为:

image.png

image.png

二、单gpu:分块计算

在分块的情况下,基本思路是:

  • 固定住某个Qi,它的尺寸为(C, d)
  • 给这个Qi分别传入不同的Kj, Vj数据块,计算对应的Oij
  • 通过某种方式,每传入一份(Kj, Vj),就更新一次Oij。这样知道最后一份(Kj, Vj)传入并计算完毕后,我们就得到了最终的Oi。你可以把这个过程理解为,我们维护的始终只有一个尺寸为(C, d)的Oi,我们在每次计算完毕后都更新这种Oi。到这里为止我们先不纠结Oi具体的更新方式,只要知道它是【滚动更新】的即可。
  • 不难理解,在写成代码的情况下,遍历Q分块的过程可以表示成outer loop,遍历KV分块的过程可以表示成inner loop

整个计算过程如下图所示:

image.png

不难得知:

  • 如果我们不对Attention score做softmax,那么O0的更新方式就可以变成简单的累加形式,即本次计算出的O0+ 上一步骤后的O0结果。
  • 但当我们对Attention score做softmax(默认都指safe softmax时),情况就大不相同了,举例来说:

    • 计算safe softmax,我们需要知道分数矩阵每行的max和每行的sum,我们记其为global max和global sum。
    • 当我们使用S00的结果计算O0时,我们用的是这个矩阵的local max和global sum。
    • 所以,在使用softmax的情况下,我们无法对O0做简单的累加。

那么,在分块的情况下,我们到底要采取什么方式更新Oi呢?本质上来说,ring attention采用的是和flash attention V2非常相近的Oi更新方式,具体可以参见我之前对Flash Attention V2的解读:

  • Flash Attention V2:这篇文章的1.2(1)部分详细介绍了Oi的更新方式
  • Flash Attention V1:这篇文章第四部分详细介绍了朴素attention->safe softmax -> 分块safe softmax的整个过程,并用递归法证明了Oi的更新方式,这个证明方法同样可以类推到V2的Oi上。

注意,这里不了解Oi的具体更新方式并不影响下文的阅读。所以,本文不再对Oi的更新细节和数学推导做更多论述。

但这边我们额外再关注一点:Ring Attention和Flash Attention V2的Oi的更新方式非常相近,但不完全相同。为了更好阐述这一点,我们先来看Flash Attention V2中Oi的更新方式:

  • Flash Attention V2:这篇文章的1.2(1)部分详细介绍了Oi的更新方式
  • Flash Attention V1:这篇文章第四部分详细介绍了朴素attention->safe softmax -> 分块safe softmax的整个过程,并用递归法证明了Oi的更新方式,这个证明方法同样可以类推到V2的Oi上。

注意,这里不了解Oi的具体更新方式并不影响下文的阅读。所以,本文不再对Oi的更新细节和数学推导做更多论述。

但这边我们额外再关注一点:Ring Attention和Flash Attention V2的Oi的更新方式非常相近,但不完全相同。为了更好阐述这一点,我们先来看Flash Attention V2中Oi的更新方式:

image.png

上图展示的是Flash Attention V2的fwd算法过程,第10行展示了Oi的更新方式。同时注意到,当我们把outer loop和inner loop全部做完后,在第12行我们又对Oi做了一次更新,且这个更新是一次性的,同时更新公式中的l和global sum相关。Flash Attention V2为什么要这么做呢?因为:

  • 首先,你当然可以把第12行的更新放到第10行中去做。也就是对于某个分块Qi,我们在逐步更新它对应的Oi时,我们要考虑到目前为止得到的global sum信息。什么叫“目前为止得到global sum”信息呢?例如,当你计算出S00时,你会根据它得到一个sum;当你算出S01时,你会根据它和S00再次得到一个sum;当你算完全部的S分块时,你得到的sum就是真正的global sum了。所以尽管在这里我们没有给出详细的数学推导,从直觉上也不难理解,我们可以选择在第10行内用“目前为止得到的global sum”做迭代更新,也可以选择在第12行用最终的global sum做一个一次性的更新。Flash Attention V1和ring Attention选择把第12行放入第10行中做,而Flash Attention V2选择把两者拆开。
  • 而把第12行更新从第10行中拆出来的主要原因,是为了在gpu中尽量减少非矩阵乘法的计算量。这是因为在现代gpu中(比如NV GPU)非矩阵乘法的计算比矩阵乘法慢约16倍。以NV A100来说,fp16/bf16的矩阵乘法计算理论上的最大吞吐是312 TFLOPs/s,但是非矩阵乘法运算仅为19.5TFLOPs/s

好,到目前为止,我们已经知道在分块的情况下,如何在单GPU上进行Attention计算了,接下来,我们就把这个计算过程拆分到多gpu上,来看看ring attention中的ring是如何运作的。

三、多gpu:环状通信

让我们先来重新端详一下Q0的分块计算流程:

image.png

从这个计算流程中,我们不难看出下面这几点:

image.png

受到这两点的启发,Ring Attention就诞生了,它的整体运作如下:

image.png

  • 首先,我们把Q分块放到各卡上,然后固定住。也就是各卡上保存的Q分块始终不变。
  • 接着,每块卡上只放一块(K,V)对。也就是每次计算时,哪个(K, V)对需要被这块卡用到,哪个(K, V)对就在这块卡上放着。初始化的状态如图iter0所示。
  • 接着,在每块卡使用当前(K, V)对做Attention计算时:
  • 它接收来自前一块卡的(K,V)对
  • 它把自己当前维护的(K,V)对发给下一张卡
  • 例如,当gpu0正在使用(K0, V0)进行计算时,它接收来自gpu3的(K3, V3),同时把自己的(K0, V0)发送给gpu1。其余卡也是类推,整体形成一个环状的通信拓扑
  • 由于在传输(K, V)对的同时,每张卡也在进行Attention计算,因此只要我们设计得当,让【传输时间<=计算时间】,那传输数据带来的额外开销就可以被计算时间覆盖住,进而不影响整个系统的计算效率,还能帮助单卡节省显存。至于如何才算“设计得当”,我们将在后文给出分析。
  • 在我们的图中,为了更好表示这一个更新过程,对于某个Qi我们画出了若干个Oi,但正如上文所说,实际运行时我们只维护一个Qi并不断更新它。这里画出多个Oi只是更好帮助大家理解。
  • 所以,当每张卡在做attention的过程中,它的显存占用是
  • 1个q block(2cd bytes)
  • 1个k block + 1个v block用于计算当前的attention(4cd bytes)
  • 1个k block + 1个v block来自环状通讯拓扑中的前一张卡(4cd bytes)
  • 1个o block用于存储和更新最终output(2cd bytes)

现在,我们应该能更好理解为什么在文章开头中说“ring attention其实约等于分布式版本的flash attention2”了:

  • 在Flash Attention V2中,outer loop(Q分块)和inner loop(KV分块)都在单卡上进行。
  • 在Ring Attention中,outer loop(Q分块)首先被分配到若干张卡上,然后inner loop(KV分块)通过环状通讯的方式轮流被发送到各块卡上进行计算。每个Q分块更新对应O分块的方法和Flash Attention V2基本一致。

四、最佳chunk_size

在第三部分中,我们提到一个很重要的点:ring attention是会带来额外的(K, V)对传输时间开销的,因此我们需要让【传输时间 <= 计算时间】,这样才可以让这部分开销被计算覆盖住,进而不影响整个系统的计算效率。而要做到这一点,就需要我们根据所使用的卡的配置,设计好最优的分块大小,也就是上图中所说的C(chunk_size),用于表示一个分块中包含多长的序列。

我们假设:

  • 硬件(例如单张gpu)的算力上限是F,它表示这个硬件倾尽全力每秒所能完成的浮点运算数,单位是FLOPS或者FLOP/s。
  • 硬件的带宽上限是B,它表示这个硬件倾尽全力每秒所能完成的内存交换量,单位是Byte/s。

接下来为了表达简便,我们在做各种指标计算时忽略掉batch_size维度。

我们知道在单卡上当我们对某个QKV chunk计算attention时,有:

image.png

假设我们用bf16/fp16进行训练,则传输的KV数据量大小(单位bytes)为

  • K chunk大小2dcbytes。
  • V chunk大小为2dcbytes。
  • 总结来看,每次传输的KV数据量大小为4dcbytes

那么基于【传输时间<=计算时间】的基本要求,我们有:

image.png

进一步有:

image.png

也即我们可以根据硬件的F和B,来计算最优的切块大小c。

image.png

END

来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18801
内容数
1347
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息