大模型训练加速之FlashAttention系列:爆款工作背后的产品观

来源:https://zhuanlan.zhihu.com/p/664061672

FlashAttention(FA)是一系列针对Transformer模型训练和推理加速方案。自从去年(2022年)五月发布以来,历经了多次迭代,并借着其节省显存、加速计算、使用简单的特性,目前已经是大模型训练必备之良药。FA的论文,一年多就有242个引用,作为一个做底层优化的论文,也是前无古人了。

不过仔细研究背后的历史,你会发现FA没有任何没有独特的创新,甚至某G字大厂还更早公开过一模一样的工作,从某种角度来说FA有抄袭洗稿的嫌疑。FA的爆红出圈的背后,究竟隐藏着什么不为人知的秘密?是人性的扭曲,还是社区的沦丧。本文带你走进FlashAttention的世界,一起解密爆款AI Infra工具背后的故事。

一、产品思维成就爆款

本文先分析一下FA成功的原因,然后介绍一下技术细节。

首先,FA的发展线大概经历如下几个重要里程碑:

  • 2022.6 FA最初版本arXiv文章(https://arxiv.org/abs/2205.14135v1)公开。
  • 2022.7 使用OpenAI Triton实现的FA发布。
  • 2022.11 MeurIPS 2022收录了 FA v1的论文,这个工作也出现在ICML Workshop中。

FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • [2022.12 Megatron-LM支持FA v1 [PR link]](https://github.com/NVIDIA/Meg...
  • 2022.3 PyTorch 2.0集成FA。
  • 2023.7 FA v2版本发布,更高的FLOPs利用率。

Flashattention-2: Faster attention with better parallelism and work partitioning

  • 2023.10 Flash Decoding发布,针对长序列的推理。

Flash-Decoding for long-context inference

去年我和身边同事说过,FlashAttention是我个人评选的2022年Infra类最佳的工作。在电影《死亡诗社》中,传统教科书中评判诗歌好坏的方式是画一个坐标轴,X轴是艺术性,Y是主题重要性,一首诗歌多优秀是它覆盖的面。类似方法,如果评价一个AI场景中”加速“类的工作,X轴是它的使用广泛性,Y轴是使用它带来的加速比,我想FA是一个非常优秀的工作。金杯银杯不如百姓的口碑,FA之所以被大家广泛使用,其本质原因是产品能力的胜利

image.png

从优化技术角度讲,FA采用的方法并不新颖,本质是对矩阵连乘问题进行定制化的的Tiling+重计算。同类的工作其实很多:一方面,这个经典并行问题是各种AI Compiler和JIT技术,比如Halide、TVM、torch.compile等,的射程范围之内的,甚至它们可以做到对矩阵维度参数泛化性更强。另一方面,xformers等精细的Kernel Fusion的技术也早已有之,和它们相比FA v1版本效率也并不是绝对最高的。

FA为何脱颖而出,成为训练场景下Multihead Attention(MHA)实现的事实标准,有以下几个原因。

1、简单易用:微信之父张小龙说过,把用户体验做到极致就是创新。FA使用非常简单,用户不需要了解任何GPU有关的细节,只需要pip install安装,用python替换掉原始的Multihead Attention代码,半个小时就能拿到显存和速度收益。这种好事,何乐而不为呢?假设FA要求用户本地编译CUDA代码,我相信会劝退90%用户,比如我无法理解Megatron-LM安装时候花五分钟本地编译cuda kernels的行为。

2、聚焦头部需求:与其一开始就充分考虑各种长尾需求,不如先把头部需求做到极致。FA把自己定位在面向特定型号GPU大模型训练框架的插件,只做MHA的正反向的实现。在2022年时间点,随着序列变长,MHA训练的内存和计算增长问题首当其冲。尽管算子优化在推理场景做是最顺手的,FA发布却只针对训练,从现在来看,这是很必要的,如果长序列模型训练不出来,哪有推理可言。而且,FA初始版本对于参数范围、GPU型号都有限制,不过作为一个MVP把最主流需求满足,长尾需求慢慢解决都来得及。

3、拥抱开源生态:Online Softmax+Tiling+Recompute几个核心技术都不是FA原创的。Google Research在2021年12月就有一个和FA一模一样的工作叫memory-efficient-attention(论文链接贴在下面)。看完这篇论文,你甚至说FA抄袭都不为过。可为什么后来大家只知FA,而不知MEA?我认为原因在拥抱开源力度的差距,MEA论文放出来时候没有一个官方开源实现,论文也是汇报的Jax实现结果,让大家觉得只是Google内部黑科技而已。而FA的官方PyTorch实现代码和论文一起放出来,和现有生态Megatron-LM完美兼容,易用性配合出色的效果自带传播属性。值得玩味的是,2022年早些时候MAE也被集成到PyTorch xformers项目里了,但大众认知里MHA优化还是FA先做的,MEA论文引用只有区区33个,而FA却有242个,这或许来源于开源那一刻的蝴蝶效应。最后,一个好名字对产品很重要,MEA听起来就很难记忆,而且没有体现它fast的特征,而Flash这个词就更能抓住人的记忆点。值得一提的是,在2023年3月和FA同时作为在2.0版本的功能发布。

Self-attention Does Not Need O(n^2) Memory 

4、保持更新:开源如逆水行舟,不进则退,FA没有发布即巅峰,不信可以看一下Dao-AILab(https://github.com/Dao-AILab)/flash-attention(https://github.com/Dao-AILab/flash-attention)的Contributors变化曲线,作者一年多时间不断在commit新的代码。FA V2里也体现了作者对社区的改进的关注,比如对MQA、MGA、Mask等需求的支持相对比较及时,让用户感知到这个Library一直与时俱进很重要。

成功的产品都是相似的,失败的产品确实各有各的失败。聚焦主流需求、极致使用体验、出色的实用效果、兼容开源生态,让FA在大模型训练Infra这个极卷的红海领域里面成为爆品工作。

最后聊一聊FA的核心开发者Tri Dao,他去年刚博士毕业于Standford,之前一直做核方法之类的ML算法研究,发的都是NeurPS、ICLR会议,并不是Infra背景,FA似乎是他的第一个跨界工作。他博士毕业后作为togethercompute的首席科学家,最近要加入普林斯顿计算机系做助理教授。影响力如此大的AI Infra工作,是算法研究者跨界的结果,其背后原因值得做MLSys从业者深思。

二、FA学习路线

对FA技术细节解读很多,在知乎上有很多解读,比如这个问题下有很多非常好的答案:

FlashAttention 的速度优化原理是怎样的?

用户并不需要了解FA细节也可以很容易把FA当黑盒使用,这再次说明了FA产品设计的很好,把复杂度留给了自己,把简单交给了用户。但对于关注底层同学,还是需要好好掌握FA原理,从而进行二次开发和进一步优化工作。

我建议把FA设计和实现分开学习,FA迭代了三个版本了,而且除了Tri Dao维护的官方版本,还有Triton、PyTorch等版本实现。对于FA的设计思路,2023年我的建议是不要去读FA V1论文了,推荐学习路线:

  1. 学习Online Softmax

Online normalizer calculation for softmax

  1. 看FlashAttention-2论文。因为FlashAttention V1的关键设计在V2中被推翻了,FA V2论文也介绍了FA V1。
  2. 看Flash Decoding(不在本文介绍范围内)

我下面按照这个顺序,介绍FA的技术要点。这里尽量避免过分陷入细节中,因此不会过度出现复杂的伪代码和公式。

三、FA技术细节

3.1 问题定义

image.png
image.png

如果没有Softmax,那TVM、Halide之类的编译优化优化工具解决这个问题是是驾轻就熟的,如何对矩阵进行Tiling可以优化的明明白白。

可是麻烦出现在Softmax!Softmax需要对完整的QK^T结果矩阵(图上虚线部分正方形)沿着Inner Loop维度进行归一化。Softmax需要全局的max和sum结果才能scale每一个元素,因此本地算出一块QK^T的结果还不能立刻和V进行运算,还要等同一行的后面的QK^T都算完才能开始,这就造成依赖关系,影响计算的并行。

3.2 Online Softmax

为了能够并行计算Softmax,文章使用了Online Softmax的技巧。这个技巧simple but effective,值得我们单开一个小节,花5min时间理解其精髓。

Online normalizer calculation for softmax

Naive的Softmax公式

image.png

算法流程如下,需要两个循环,先算分母,再scale每一个元素。

image.png

但是,line 3指数相加容易浮点溢出,所以Safe Softmax需要把每个元素减去所有元素最大值

image.png

算法流程如下,需要三层循环,分别求max,sum,然后scale每一个元素。

image.png

三个循环需要从内存读写4次次向量。因此,下面算法用一个循环同时计算了最大值m和归一化项d,它将Softmax函数评估的内存访问从4减少到每个向量元素的3次。它用前一个迭代的d和当前m,计算当前迭代的m。

image.png

3.3 FA的online softmax应用

Online softmax可以打破之前必须先算完一整行的QK^T结果,再和V相乘的依赖关系。算出local softmax结果立刻和V的分块运算,后面再通过乘系数矫正即可。

image.png

有了online softmax + tiling,我们就可以按照outer loop切分,row-wise地并行地执行图1的流程。

image.png

3.4 反向重计算

回顾之前反向计算的一组公式,第一行计算dV需要P矩阵,而正向计算时FA不会存储P,在反向时重计算P的值来节省内存。这个方法就是MXNet提出的经典activation checkpointing技术。

3.5 FA V2其他优化

今年7月发布的V2版本是在V1基础上融合了社区的一些改进,实现了FLOPs利用率的飞跃。

我觉得V2最重要的提升点是参考Phil Tillet的Tirton版本,更改了Tiling循环的顺序,也就是笔者本文图1的效果。V1版本循环顺序设计是outer和inner loop和图1反过来,在outer loop扫描时做softmax的规约,这导致outer loop必须在一个thread block里才能共享softmax计算中间结果的信息,从而只能对batch head维度上以thread block为粒度并行切分。V2中调换了循环顺序,使outer loop每个迭代计算没有依赖,可以发送给不同的thread block并行执行,也就是可以对batch head * sequence三层循环以thread block为粒度并行切分,从而显著增加GPU的吞吐。反向遵循同样的原理,不要把inner loop放在softmax规约的维度,因此正向反向的循环顺序是不同的。

说实话,我看V1论文中原始图1就觉得循环顺序很不合理,我盯着循环执行的zigzag顺序图看了一下午,百思不得其解。现在FA github readme里还不把这个图改过来,有点说不过去。另外,个人推测V2性能最大的提升就是来源是来自这个循环顺序的调换,Tri Dao大佬把OpenAI Triton的作者Phil Tillet挂在coauthor位置一点也不过分。所以,看到很多报道FA V2的新闻都用“斯坦福博士一己之力让Attention提速9倍!XXX“的标题,细心的读者会觉得比较扎心。

随着循环顺序调换之后,一个thread block内warps粒度的划分也需要改进,作者把V1版本沿着K切分共享Q,改为沿着Q切分共享K。

另外作者做了提取公因式数学变换,减少了一些non-matmul FLOPs,调优了thread block size。最终更达到了最优情况下72% A100 FLOPs利用率的效果。

写到这,文章有点长了,笔者已经写不动了,对于推理的Flash Decoding优化我今后再单开一篇文章解读。

总结

FlashAttention是近一年来影响力最大的训练框架层加速工作了。它爆红出圈,算法工程师几乎人尽皆知,人皆可用。可是,FA的优化方法并没有任何创新性可言,它爆红的背后是产品设计上的胜利。聚焦主流需求、极致使用体验、出色的实用效果、兼容开源生态,让FA在大模型训练Infra这个极卷的红海领域里面成为爆品工作。

The End

作者:方佳瑞
文章来源:GiantPandaCV

推荐阅读

更多嵌入式AI干货请关注嵌入式AI专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
16657
内容数
1232
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息