KV Cache优化: 层内和层间KV Cache共享

0x00 前言

微软最近新发了一篇论文,提出了YOCO(You Only Cache Once,和RetNet似乎是相同的作者),这是一个KV Cache层间共享的新思路。同期MIT-IBM Watson AI Lab也发了一篇类似的论文,提出了CLA( Cross-Layer Attention),即KV Cache跨层推理。简直和YOCO不谋而合,因此本文把这两篇论文的阅读笔记放到一起记录了。

由于最近一两年都是在做推理相关的工作(刚毕业的时候做了几年CV算法,偶尔客串一下NLP),对于非推理方向的论文关注已经变少了一些。但是最近在arxiv刷到这篇paper,看完摘要,第一反应就是,“我去,这样做好像还真可以,为啥我没想到(太菜了)”。YOCO的全称是《You Only Cache Once: Decoder-Decoder Architectures for Language Models》,YOCO的名称借鉴了单阶段目标检测始祖YOLO的风格。整篇论文提出的最核心的创新就是,提出了一种新的KV Cache共享方式,即层间共享。并且基于此,将上下文长度扩展到1百万。我们知道,目前最常见的KV Cache共享策略是MQA/GQA,在GQA出来之后,已经很长一段时间没看到类似的研究了。从Layer的视角来看,MQA/GQA可以认为是Intra-Layer KV Cache Shared(层内KV Cache共享),而YOCO提出的想法,则可以认为是Inter-Layer KV Cache Shared(层间KV Cache共享)。层间KV Cache共享,理论上的可以最多把KV Cache的Memory需求降低到1/N(N为Transformer层数),并且,这和层内KV Cache共享的技术,比如MQA和GQA是不冲突的,两者可以一起使用,从而极大地降低KV Cache的显存开销。

image.png

本文将结合MQA/GQA,简单记录一下自己对YOCO的理解。如有错误,欢迎大佬指正。本文内容如下:

  • 0x01 层内KV Cache共享: MQA简析
  • 0x02 层内KV Cache共享: GQA简析
  • 0x03 跨层KV Cache共享: YOCO简析
  • 0x04 跨层KV Cache共享: CLA简析
  • 0x05 总结

0x01 层内KV Cache共享: MQA简析

image.png

首先简单介绍一下MQA和GQA。标准的多头注意力就是MHA(Multi Head Attention),在MHA中,KV Heads的数量和Query Heads的数量相同,每个Query Head持有一个独立的KV Head,在Attention中,对单独的KV Head做计算。但是,当模型层数加深和Heads数变多后,QKV Attention的计算和IO都会快速增加。为了缓解这种情况,有学者提出了MQA和GQA。

MQA (Multi Queries Attention): MQA比较极端,只保留一个KV Head,多个Query Heads共享相同的KV Head。这相当于不同Head的Attention差异,全部都放在了Query上,需要模型仅从不同的Query Heads上就能够关注到输入hidden states不同方面的信息。这样做的好处是,极大地降低了KV Cache的需求,但是会导致模型效果有所下降。

image.png

0x02 层内KV Cache共享: GQA简析

GQA (Group Queries Attention): GQA与MQA不同,而是采取了折中的做法。GQA把Query Heads进行分组,每组Query Heads对应一个KV Head。比如,把8个Query Heads分成4组,每个Grouped Query Head包含2个Query Heads,一个Grouped Query Head对应一个KV Head,此时总共有4个KV Heads。GQA可以在减少计算量和KV Cache同时确保模型效果不受到大的影响。

image.png

在目前大部分主流训推框架或算法,都已经支持MQA/GQA,比如FlashAttention中,也支持MQA和GQA。对于MQA和GQA的情形,FlashAttention采用Indexing的方式,而不是直接复制多份KV Head的内容到显存然后再进行计算。Indexing,即通过传入KV/KV Head索引到Kernel中,然后计算内存地址,直接从内存中读取KV。

image.png

歪个楼,FlashAttention V1/V2/V3系列原理&图解,推荐阅读我的另一篇文章:

DefTruth:[Attention优化][2w字] 原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3

  • 关于GQA一些数值上的理解

GQA最大的作用是节省显存,同时由于LLM推理的一大瓶颈就是memory bound,大模型推理的性能受限于显存带宽。而GPU算力增长是快于显存以及显存带宽的。KV Cache的减少不单可以节省显存,还可以节省需要加载显存所需要的IO时间。我们可以来看一下KV Cache占用的显存占用量,下图来自PageaAttention论文:

image.png

KV Cache显存占用的计算方式如下:

1 token KV Cache = 2[K,V] x hidden_size x layers x 2[bytes per FP16] = 4 x H x N bytes

比如对于LLaMA 13B fp16模型,1个token所需要的KV Cache为:4 x 5120 x 40 = 819200 bytes,即 800KB。那么对于L=seq_len为2048 tokens的请求,需要的KV Cache数量为: 4 x 2048 x 5120 x 40 = 2048 x 800KB = 1.6GB。对于长度为L的请求,需要的KV Cache数量为:

KV Cache = 4 x L x H x N bytes # MHA

上述是在MHA下的KV Cache计算公式,最后,再考虑batch_size,那么公式为:

KV Cache = 4 x B x L x H x N bytes # MHA

如果是GQA,假设Q的组数为G,则GQA下需要的KV Cache为:

KV Cache = 4 x B x L x H x N / G bytes # GQA

最后我们用一个表格来直观感受一下,假设以下为某72B模型的配置:

image.png

我们可以看到,在BS=32(模拟的是高并发的情形)时,对于8K的上下文,如果是MHA,则需要640GB的KV Cache,这已经远远超过目前单卡GPU的显存上限了,我们只是少需要一台8卡的服务才装得下。假设我们有一台这样的服务器,单卡显存带宽为800Gb,互连带宽为800*8Gb,那么对于BS=32的MHA情形,每次decode步骤,大约需要100ms加载KV Cache的时间,而当G=8时,只需要12.5ms。在KV Cache加载的IO耗时上,GQA是MHA的1/8。因此,我们可以推算,当我们对服务压测极限吞吐,并且context比较长时(比如4K以上),GQA对比MHA应该有数倍的吞吐提升。需要注意的是,这里隐含假设了每次decode forward只需要加载一次KV Cache,没有冗余的IO消耗,比如使用FlashDecoding算法。FlashDecoding原理,可以参考我写的另一篇文章:

DefTruth:[Decoding优化] 原理&图解FlashDecoding/FlashDecoding++

0x03 跨层KV Cache共享: YOCO简析

  • YOCO整体架构分析

image.png

YOCO是Decoder-Decoder架构,和Decoder-Only架构非常接近,之所以命名为Decoder-Decoder架构,是因为这两个Decoder的含义不是完全一致的。YOCO整体上包括两部分,一部分是Self-Decoder,这个和常见的Decoder Transformers是一样的;另一部分是Cross-Decoder。Self-Decoder负责产生global KV Cache,这个KV Cache会直接被后续的Cross-Decoder使用。这也是后半部分为啥叫做Cross-Decoder的原因,它使用Self-Decoder产生的KV Cache做交叉注意力机制,Cross-Decoder本身不产生KV Cache。

image.png

  • YOCO的KV Cache计算

在GQA的计算公式基础上,我们可以很快得出YOCO的KV Cache计算公式,假设YOCO中有前Y层用于产生全局KV Cache,则计算公式为:

KV Cache = 4 x B x L x H x N / G x (Y / N) bytes # GQA + YOCO

其中B表示Batch size,L表示seq len,H表示hidden size,N表示层数,G表示GQA组数。

  • YOCO + Efficient Self-Attention

我们可以看到,YOCO在Self-Attention阶段,选取了前L/2层,并且对于Attention具体的实现,选择了Efficient Self-Attention,比如Slide Window Attention和Retention。Slide Window Attention很好理解,就是对于long context选取固定窗口长度的上下文来做attention。

image.png

Retention则比较复杂,这里边会涉及chunk-wise retention的推导以及retention的递归形式和并行形式。并行形式可以提高训练效率,递归形式用于推理阶段则可以达到节省显存的效果。更详细的分析,推荐阅读: 白发小Luke船长:深入解析:Retentive Network (RetNet) —— Transformer 的有力继任者。PS: 似乎YOCO和Retention是同一个作者大佬。

image.png

我们可以看到,YOCO在推理阶段,可以节省大量Prefill的耗时,基于YOCO的特性,Prefill阶段可以跳过首Token的生成,也就是跳过Cross-Decoder。Prefill阶段的KV Cache显存需求从O(LND)下降到O((N+L)D),其中N表示seq_len,L表示transformer层数,D表示hidden size;Prefill耗时则从O(LN^2D)下降为O(LND),我们可以看到,耗时从N平方复杂度变为线性复杂度。

image.png

  • TTFT和Prefill的区别?

歪个楼,论文里边的Prefill耗时,其实把生成首Token需要的部分排除掉了。就目前常用的decoder-only架构,比如LLaMA,Prefill耗时实际上就是TTFT。但是YOCO里边,Prefill耗时其实是小于TTFT的,而实际应用中,我们通常更加关注TTFT,而不仅仅是Prefill(生成context prompt的KV Cache耗时)。因此,如果我们从TTFT的角度去考虑,实际的性能收益,应该没有论文中描述的那么大。

  • 实验结果

YOCO论文给出了一个在3B规模左右的模型的实验结果,可以看到在大部分任务上,YOCO-3B都达到了SOTA。并且,基于YOCO的KV Cache共享策略和ESA,可以将训练上下文扩展到1M规模。不过,没看到更大参数规模的模型实验,比如72B等,盲猜是分布式训练逻辑不太好写?个人可能更想看到在70B以上的参数规模,YOCO是否依然表现优越。

image.png

0x04 跨层KV Cache共享: CLA简析

  • CLA整体架构分析

同期MIT-IBM Watson AI Lab也发了一篇类似的论文,提出了CLA( Cross-Layer Attention),即KV Cache跨层推理。简直和YOCO不谋而合,因此本文把这两篇论文的阅读笔记放到一起记录了。

image.png

CLA同样是一种KV Cache跨层共享的新方式,但和YOCO不一样的是,CLA并不是选择固定的前几层来产生KV Cache(比如YOCO,使用的是前L/2层),而是将产生KV Cache的层交替分布在模型不同深度的层,然后邻近层复用附近层产生的KV Cache进行Attention计算。CLA可以选择多种跨层共享KV Cache的方式,比较灵活。

image.png

  • CLA的KV Cache计算

CLA和GQA属于两种不冲突的KV Cache共享方式,GQA为层内 KV Cache共享,而CLA为层间KV Cache共享。我们知道,GQA的KV Cache计算方式为:

KV Cache = 4 x B x L x H x N / G bytes # GQA

其中B表示Batch size,L表示seq len,H表示hidden size,N表示层数,G表示GQA组数。如果再考虑CLA层间的KV Cache共享,则计算方式变成:

KV Cache = 4 x B x L x H x N / G / C bytes # GQA + CLA

其中C表示多少层共享一份KV Cache,比如C=2表示CLA2模式,相邻的2个层共享一份KV Cache,则KV Cache总量可以节省一半。贴一下论文中的一些数值:

image.png

  • 实验结果

image.png

CLA提供的是一种通用的跨层KV Cache共享思路,本质上和MQA/GQA等层内KV Cache共享的方式是不冲突的,因此,CLA可以和MQA/GQA结合一起使用。我们看到使用了CLA后,模型的精度还是会有一定程度的退化,因此,如何在精度和性能上做tradeoff,也是CLA无法避免的问题。

0x05 总结

本文简单解析了KV Cache共享的几种算法,包括MQA、GQA、YOCO和CLA,并且将其归纳为层间共享和层内共享来两种方式,也总结和整理了GQA、YOCO和CLA的KV Cache计算公式。所谓好记性,不如烂笔头。最后,LLM推理部署各方向新进展,推荐我整理的Awesome-LLM-Inference,传送门:https://https://github.com/DefTruth/Awesome-LLM-Inference

image.png

老样子,先更后改,后续有更新再继续修改......

- The End -

作者:DefTruth
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18812
内容数
1354
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息