0x00 前言
微软最近新发了一篇论文,提出了YOCO(You Only Cache Once,和RetNet似乎是相同的作者),这是一个KV Cache层间共享的新思路。同期MIT-IBM Watson AI Lab也发了一篇类似的论文,提出了CLA( Cross-Layer Attention),即KV Cache跨层推理。简直和YOCO不谋而合,因此本文把这两篇论文的阅读笔记放到一起记录了。
由于最近一两年都是在做推理相关的工作(刚毕业的时候做了几年CV算法,偶尔客串一下NLP),对于非推理方向的论文关注已经变少了一些。但是最近在arxiv刷到这篇paper,看完摘要,第一反应就是,“我去,这样做好像还真可以,为啥我没想到(太菜了)”。YOCO的全称是《You Only Cache Once: Decoder-Decoder Architectures for Language Models》,YOCO的名称借鉴了单阶段目标检测始祖YOLO的风格。整篇论文提出的最核心的创新就是,提出了一种新的KV Cache共享方式,即层间共享。并且基于此,将上下文长度扩展到1百万。我们知道,目前最常见的KV Cache共享策略是MQA/GQA,在GQA出来之后,已经很长一段时间没看到类似的研究了。从Layer的视角来看,MQA/GQA可以认为是Intra-Layer KV Cache Shared(层内KV Cache共享),而YOCO提出的想法,则可以认为是Inter-Layer KV Cache Shared(层间KV Cache共享)。层间KV Cache共享,理论上的可以最多把KV Cache的Memory需求降低到1/N(N为Transformer层数),并且,这和层内KV Cache共享的技术,比如MQA和GQA是不冲突的,两者可以一起使用,从而极大地降低KV Cache的显存开销。
本文将结合MQA/GQA,简单记录一下自己对YOCO的理解。如有错误,欢迎大佬指正。本文内容如下:
- 0x01 层内KV Cache共享: MQA简析
- 0x02 层内KV Cache共享: GQA简析
- 0x03 跨层KV Cache共享: YOCO简析
- 0x04 跨层KV Cache共享: CLA简析
- 0x05 总结
0x01 层内KV Cache共享: MQA简析
首先简单介绍一下MQA和GQA。标准的多头注意力就是MHA(Multi Head Attention),在MHA中,KV Heads的数量和Query Heads的数量相同,每个Query Head持有一个独立的KV Head,在Attention中,对单独的KV Head做计算。但是,当模型层数加深和Heads数变多后,QKV Attention的计算和IO都会快速增加。为了缓解这种情况,有学者提出了MQA和GQA。
MQA (Multi Queries Attention): MQA比较极端,只保留一个KV Head,多个Query Heads共享相同的KV Head。这相当于不同Head的Attention差异,全部都放在了Query上,需要模型仅从不同的Query Heads上就能够关注到输入hidden states不同方面的信息。这样做的好处是,极大地降低了KV Cache的需求,但是会导致模型效果有所下降。
0x02 层内KV Cache共享: GQA简析
GQA (Group Queries Attention): GQA与MQA不同,而是采取了折中的做法。GQA把Query Heads进行分组,每组Query Heads对应一个KV Head。比如,把8个Query Heads分成4组,每个Grouped Query Head包含2个Query Heads,一个Grouped Query Head对应一个KV Head,此时总共有4个KV Heads。GQA可以在减少计算量和KV Cache同时确保模型效果不受到大的影响。
在目前大部分主流训推框架或算法,都已经支持MQA/GQA,比如FlashAttention中,也支持MQA和GQA。对于MQA和GQA的情形,FlashAttention采用Indexing的方式,而不是直接复制多份KV Head的内容到显存然后再进行计算。Indexing,即通过传入KV/KV Head索引到Kernel中,然后计算内存地址,直接从内存中读取KV。
歪个楼,FlashAttention V1/V2/V3系列原理&图解,推荐阅读我的另一篇文章:
DefTruth:[Attention优化][2w字] 原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3
- 关于GQA一些数值上的理解
GQA最大的作用是节省显存,同时由于LLM推理的一大瓶颈就是memory bound,大模型推理的性能受限于显存带宽。而GPU算力增长是快于显存以及显存带宽的。KV Cache的减少不单可以节省显存,还可以节省需要加载显存所需要的IO时间。我们可以来看一下KV Cache占用的显存占用量,下图来自PageaAttention论文:
KV Cache显存占用的计算方式如下:
1 token KV Cache = 2[K,V] x hidden_size x layers x 2[bytes per FP16] = 4 x H x N bytes
比如对于LLaMA 13B fp16模型,1个token所需要的KV Cache为:4 x 5120 x 40 = 819200 bytes,即 800KB。那么对于L=seq_len为2048 tokens的请求,需要的KV Cache数量为: 4 x 2048 x 5120 x 40 = 2048 x 800KB = 1.6GB。对于长度为L的请求,需要的KV Cache数量为:
KV Cache = 4 x L x H x N bytes # MHA
上述是在MHA下的KV Cache计算公式,最后,再考虑batch_size,那么公式为:
KV Cache = 4 x B x L x H x N bytes # MHA
如果是GQA,假设Q的组数为G,则GQA下需要的KV Cache为:
KV Cache = 4 x B x L x H x N / G bytes # GQA
最后我们用一个表格来直观感受一下,假设以下为某72B模型的配置:
我们可以看到,在BS=32(模拟的是高并发的情形)时,对于8K的上下文,如果是MHA,则需要640GB的KV Cache,这已经远远超过目前单卡GPU的显存上限了,我们只是少需要一台8卡的服务才装得下。假设我们有一台这样的服务器,单卡显存带宽为800Gb,互连带宽为800*8Gb,那么对于BS=32的MHA情形,每次decode步骤,大约需要100ms加载KV Cache的时间,而当G=8时,只需要12.5ms。在KV Cache加载的IO耗时上,GQA是MHA的1/8。因此,我们可以推算,当我们对服务压测极限吞吐,并且context比较长时(比如4K以上),GQA对比MHA应该有数倍的吞吐提升。需要注意的是,这里隐含假设了每次decode forward只需要加载一次KV Cache,没有冗余的IO消耗,比如使用FlashDecoding算法。FlashDecoding原理,可以参考我写的另一篇文章:
DefTruth:[Decoding优化] 原理&图解FlashDecoding/FlashDecoding++
0x03 跨层KV Cache共享: YOCO简析
- YOCO整体架构分析
YOCO是Decoder-Decoder架构,和Decoder-Only架构非常接近,之所以命名为Decoder-Decoder架构,是因为这两个Decoder的含义不是完全一致的。YOCO整体上包括两部分,一部分是Self-Decoder,这个和常见的Decoder Transformers是一样的;另一部分是Cross-Decoder。Self-Decoder负责产生global KV Cache,这个KV Cache会直接被后续的Cross-Decoder使用。这也是后半部分为啥叫做Cross-Decoder的原因,它使用Self-Decoder产生的KV Cache做交叉注意力机制,Cross-Decoder本身不产生KV Cache。
- YOCO的KV Cache计算
在GQA的计算公式基础上,我们可以很快得出YOCO的KV Cache计算公式,假设YOCO中有前Y层用于产生全局KV Cache,则计算公式为:
KV Cache = 4 x B x L x H x N / G x (Y / N) bytes # GQA + YOCO
其中B表示Batch size,L表示seq len,H表示hidden size,N表示层数,G表示GQA组数。
- YOCO + Efficient Self-Attention
我们可以看到,YOCO在Self-Attention阶段,选取了前L/2层,并且对于Attention具体的实现,选择了Efficient Self-Attention,比如Slide Window Attention和Retention。Slide Window Attention很好理解,就是对于long context选取固定窗口长度的上下文来做attention。
Retention则比较复杂,这里边会涉及chunk-wise retention的推导以及retention的递归形式和并行形式。并行形式可以提高训练效率,递归形式用于推理阶段则可以达到节省显存的效果。更详细的分析,推荐阅读: 白发小Luke船长:深入解析:Retentive Network (RetNet) —— Transformer 的有力继任者。PS: 似乎YOCO和Retention是同一个作者大佬。
我们可以看到,YOCO在推理阶段,可以节省大量Prefill的耗时,基于YOCO的特性,Prefill阶段可以跳过首Token的生成,也就是跳过Cross-Decoder。Prefill阶段的KV Cache显存需求从O(LND)下降到O((N+L)D),其中N表示seq_len,L表示transformer层数,D表示hidden size;Prefill耗时则从O(LN^2D)下降为O(LND),我们可以看到,耗时从N平方复杂度变为线性复杂度。
- TTFT和Prefill的区别?
歪个楼,论文里边的Prefill耗时,其实把生成首Token需要的部分排除掉了。就目前常用的decoder-only架构,比如LLaMA,Prefill耗时实际上就是TTFT。但是YOCO里边,Prefill耗时其实是小于TTFT的,而实际应用中,我们通常更加关注TTFT,而不仅仅是Prefill(生成context prompt的KV Cache耗时)。因此,如果我们从TTFT的角度去考虑,实际的性能收益,应该没有论文中描述的那么大。
- 实验结果
YOCO论文给出了一个在3B规模左右的模型的实验结果,可以看到在大部分任务上,YOCO-3B都达到了SOTA。并且,基于YOCO的KV Cache共享策略和ESA,可以将训练上下文扩展到1M规模。不过,没看到更大参数规模的模型实验,比如72B等,盲猜是分布式训练逻辑不太好写?个人可能更想看到在70B以上的参数规模,YOCO是否依然表现优越。
0x04 跨层KV Cache共享: CLA简析
- CLA整体架构分析
同期MIT-IBM Watson AI Lab也发了一篇类似的论文,提出了CLA( Cross-Layer Attention),即KV Cache跨层推理。简直和YOCO不谋而合,因此本文把这两篇论文的阅读笔记放到一起记录了。
CLA同样是一种KV Cache跨层共享的新方式,但和YOCO不一样的是,CLA并不是选择固定的前几层来产生KV Cache(比如YOCO,使用的是前L/2层),而是将产生KV Cache的层交替分布在模型不同深度的层,然后邻近层复用附近层产生的KV Cache进行Attention计算。CLA可以选择多种跨层共享KV Cache的方式,比较灵活。
- CLA的KV Cache计算
CLA和GQA属于两种不冲突的KV Cache共享方式,GQA为层内 KV Cache共享,而CLA为层间KV Cache共享。我们知道,GQA的KV Cache计算方式为:
KV Cache = 4 x B x L x H x N / G bytes # GQA
其中B表示Batch size,L表示seq len,H表示hidden size,N表示层数,G表示GQA组数。如果再考虑CLA层间的KV Cache共享,则计算方式变成:
KV Cache = 4 x B x L x H x N / G / C bytes # GQA + CLA
其中C表示多少层共享一份KV Cache,比如C=2表示CLA2模式,相邻的2个层共享一份KV Cache,则KV Cache总量可以节省一半。贴一下论文中的一些数值:
- 实验结果
CLA提供的是一种通用的跨层KV Cache共享思路,本质上和MQA/GQA等层内KV Cache共享的方式是不冲突的,因此,CLA可以和MQA/GQA结合一起使用。我们看到使用了CLA后,模型的精度还是会有一定程度的退化,因此,如何在精度和性能上做tradeoff,也是CLA无法避免的问题。
0x05 总结
本文简单解析了KV Cache共享的几种算法,包括MQA、GQA、YOCO和CLA,并且将其归纳为层间共享和层内共享来两种方式,也总结和整理了GQA、YOCO和CLA的KV Cache计算公式。所谓好记性,不如烂笔头。最后,LLM推理部署各方向新进展,推荐我整理的Awesome-LLM-Inference,传送门:https://https://github.com/DefTruth/Awesome-LLM-Inference
老样子,先更后改,后续有更新再继续修改......
- The End -
作者:DefTruth
来源:GiantPandaCV
推荐阅读
- 图解大模型计算加速系列:分离式推理架构1,从DistServe谈起
- 窥探Triton的lower(二)
- Llama也能做图像生成!港大字节推出开源自回归文生图模型,在线体验已开放
- 如何在 PyTorch 中 profile CUDA kernels
欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。