MQA/GQA/YOCO/CLA/MLKV笔记: 层内和层间KV Cache共享

0x00.前言

本人更多的技术笔记以及 CUDA 学习笔记,欢迎来 CUDA-Learn-Notes(CUDA Learn Notes with PyTorch)查阅。CUDA-Learn-Notes 包括了本人的LLM/VLM文章整理,以及对于SGEMM/HGEMM/GEMV等常见CUDA Kernel示例实现,目前已经累计  1.5k+ stars,传送门https://github.com/xlite-dev/...

image.png

CUDA Learn Notes with PyTorch

微软最近新发了一篇论文,提出了 YOCO(You Only Cache Once,和 RetNet 似乎是相同的作者),这是一个 KV Cache 层间共享的新思路。同期 MIT-IBM Watson AI Lab 也发了一篇类似的论文,提出了 CLA( Cross-Layer Attention),即 KV Cache 跨层推理。简直和 YOCO 不谋而合,因此本文把这两篇论文的阅读笔记放到一起记录了。

由于最近一两年都是在做推理相关的工作(刚毕业的时候做了几年 CV 算法,偶尔客串一下 NLP),对于非推理方向的论文关注已经变少了一些。但是最近在 arxiv 刷到这篇 paper,看完摘要,第一反应就是,“我去,这样做好像还真可以,为啥我没想到(太菜了)”。YOCO 的全称是《You Only Cache Once: Decoder-Decoder Architectures for Language Models》,YOCO 的名称借鉴了单阶段目标检测始祖 YOLO 的风格。整篇论文提出的最核心的创新就是,提出了一种新的 KV Cache 共享方式,即层间共享。并且基于此,将上下文长度扩展到 1 百万。我们知道,目前最常见的 KV Cache 共享策略是MQA/GQA,在 GQA 出来之后,已经很长一段时间没看到类似的研究了。从 Layer 的视角来看,MQA/GQA 可以认为是Intra-Layer KV Cache Shared(层内 KV Cache 共享),而 YOCO 提出的想法,则可以认为是Inter-Layer KV Cache Shared(层间 KV Cache 共享)。层间 KV Cache 共享,理论上的可以最多把 KV Cache 的 Memory 需求降低到 1/N(N 为 Transformer 层数),并且,这和层内 KV Cache 共享的技术,比如 MQA 和 GQA 是不冲突的,两者可以一起使用,从而极大地降低 KV Cache 的显存开销。

Image

本文将结合 MQA/GQA,简单记录一下自己对 YOCO 的理解。如有错误,欢迎大佬指正。本文大约 6-7k 字,内容如下:

  • 0x01 前置知识: Prefill 和 Decode 阶段
  • 0x02 层内 KV Cache 共享: MQA 简析
  • 0x03 层内 KV Cache 共享: GQA 简析
  • 0x04 跨层 KV Cache 共享: YOCO 简析
  • 0x05 跨层 KV Cache 共享: CLA 简析
  • 0x06 跨层 KV Cache 共享: MLKV 简析
  • 0x07 总结
  • 参考文献

0x01.前置知识: Prefill 和 Decode 阶段

LLM 推理过程分为 Prefill 和 Decode 两个阶段,其中 Prefill 阶段会对 Prompt 中所有的 token 做并行计算,得到 Prompt 中所有 Tokens 的 KV Cache 以及计算得到首 Token。Prompt 阶段 Token 计算得到的 KV Cache 会保存下来,留给 Decode 阶段复用,Decode 阶段是一个自回归过程,每 decode 一个新的 Token,都需要用到所有之前计算得到的 KV Cache 来计算当前 query token 的 Attention。因此,当输出长度越来越大或者 context 很长时,KV Cache 将会占用大量的显存。如何优化 KV Cache 的显存占用,一直都是 LLM 推理的核心主题之一。

Image

Prefill 阶段

Image

Decode 阶段

0x02.层内 KV Cache 共享: MQA 简析

Image

MHA/GQA/MQA

首先简单介绍一下 MQA 和 GQA。标准的多头注意力就是MHA(Multi Head Attention),在 MHA 中,KV Heads 的数量和 Query Heads 的数量相同,每个 Query Head 持有一个独立的 KV Head,在 Attention 中,对单独的 KV Head 做计算。但是,当模型层数加深和 Heads 数变多后,QKV Attention 的计算和 IO 都会快速增加。为了缓解这种情况,有学者提出了 MQA 和 GQA。

MQA (Multi Queries Attention): MQA 比较极端,只保留一个 KV Head,多个 Query Heads 共享相同的 KV Head。这相当于不同 Head 的 Attention 差异,全部都放在了 Query 上,需要模型仅从不同的 Query Heads 上就能够关注到输入 hidden states 不同方面的信息。这样做的好处是,极大地降低了 KV Cache 的需求,但是会导致模型效果有所下降。

Image

MQA

0x03.层内 KV Cache 共享: GQA 简析

GQA (Group Queries Attention): GQA 与 MQA 不同,而是采取了折中的做法。GQA 把 Query Heads 进行分组,每组 Query Heads 对应一个 KV Head。比如,把 8 个 Query Heads 分成 4 组,每个 Grouped Query Head 包含 2 个 Query Heads,一个 Grouped Query Head 对应一个 KV Head,此时总共有 4 个 KV Heads。GQA 可以在减少计算量和 KV Cache 同时确保模型效果不受到大的影响。

Image

GQA

在目前大部分主流训推框架或算法,都已经支持 MQA/GQA,比如 FlashAttention 中,也支持 MQA 和 GQA。对于 MQA 和 GQA 的情形,FlashAttention 采用 Indexing 的方式,而不是直接复制多份 KV Head 的内容到显存然后再进行计算。Indexing,即通过传入 KV/KV Head 索引到 Kernel 中,然后计算内存地址,直接从内存中读取 KV。

Image

GQA/MQA in FlashAttention V2

歪个楼,FlashAttention V1/V2/V3 系列原理&图解,推荐阅读我的另一篇文章:

DefTruth:[Attention 优化][2w 字] 原理&图解: 从 Online-Softmax 到 FlashAttention V1/V2/V3https://zhuanlan.zhihu.com/p/...

  • 关于 GQA 一些数值上的理解

GQA 最大的作用是节省显存,同时由于 LLM 推理的一大瓶颈就是 memory bound,大模型推理的性能受限于显存带宽。而 GPU 算力增长是快于显存以及显存带宽的。KV Cache 的减少不单可以节省显存,还可以节省需要加载显存所需要的 IO 时间。我们可以来看一下 KV Cache 占用的显存占用量,下图来自 PageaAttention 论文:

Image

KV Cache 占用计算方式

KV Cache 显存占用的计算方式如下:

1 token KV Cache = 2[K,V] x hidden_size x layers x 2[bytes per FP16] = 4 x H x N bytes

比如对于 LLaMA 13B fp16 模型,1 个 token 所需要的 KV Cache 为:4 x 5120 x 40 = 819200 bytes,即 800KB。那么对于 L=seq_len 为 2048 tokens 的请求,需要的 KV Cache 数量为: 4 x 2048 x 5120 x 40 = 2048 x 800KB = 1.6GB。对于长度为 L 的请求,需要的 KV Cache 数量为:

KV Cache = 4 x L x H x N bytes # MHA

上述是在 MHA 下的 KV Cache 计算公式,最后,再考虑 batch_size,那么公式为:

KV Cache = 4 x B x L x H x N bytes # MHA

如果是 GQA,假设 Q 的组数为 G,则 GQA 下需要的 KV Cache 为:

KV Cache = 4 x B x L x H x N / G bytes # GQA

最后我们用一个表格来直观感受一下,假设以下为某 72B 模型的配置:

image.png

我们可以看到,在 BS=32(模拟的是高并发的情形)时,对于 8K 的上下文,如果是 MHA,则需要 640GB 的 KV Cache,这已经远远超过目前单卡 GPU 的显存上限了,我们只是少需要一台 8 卡的服务才装得下。假设我们有一台这样的服务器,单卡显存带宽为 800Gb,互连带宽为 800*8Gb,那么对于 BS=32 的 MHA 情形,每次 decode 步骤,大约需要 100ms 加载 KV Cache 的时间,而当 G=8 时,只需要 12.5ms。在 KV Cache 加载的 IO 耗时上,GQA 是 MHA 的 1/8。因此,我们可以推算,当我们对服务压测极限吞吐,并且 context 比较长时(比如 4K 以上),GQA 对比 MHA 应该有数倍的吞吐提升。需要注意的是,这里隐含假设了每次 decode forward 只需要加载一次 KV Cache,没有冗余的 IO 消耗,比如使用 FlashDecoding 算法。FlashDecoding 原理,可以参考我写的另一篇文章:

DefTruth:[Decoding 优化] 原理&图解 FlashDecoding/FlashDecoding++https://zhuanlan.zhihu.com/p/...

0x04.跨层 KV Cache 共享: YOCO 简析

  • YOCO 整体架构分析

Image

YOCO 架构

YOCO 是 Decoder-Decoder 架构,和 Decoder-Only 架构非常接近,之所以命名为 Decoder-Decoder 架构,是因为这两个 Decoder 的含义不是完全一致的。YOCO 整体上包括两部分,一部分是 Self-Decoder,这个和常见的 Decoder Transformers 是一样的;另一部分是 Cross-Decoder。Self-Decoder 负责产生 global KV Cache,这个 KV Cache 会直接被后续的 Cross-Decoder 使用。这也是后半部分为啥叫做 Cross-Decoder 的原因,它使用 Self-Decoder 产生的 KV Cache 做交叉注意力机制,Cross-Decoder 本身不产生 KV Cache。

Image

Self Attention + Cross Attention

  • YOCO 的 KV Cache 计算

在 GQA 的计算公式基础上,我们可以很快得出 YOCO 的 KV Cache 计算公式,假设 YOCO 中有前 Y 层用于产生全局 KV Cache,则计算公式为:

KV Cache = 4 x B x L x H x N / G x (Y / N) bytes # GQA + YOCO

其中 B 表示 Batch size,L 表示 seq len,H 表示 hidden size,N 表示层数,G 表示 GQA 组数。

  • YOCO + Efficient Self-Attention

我们可以看到,YOCO 在 Self-Attention 阶段,选取了前 L/2 层,并且对于 Attention 具体的实现,选择了 Efficient Self-Attention,比如 Slide Window Attention 和 Retention。Slide Window Attention 很好理解,就是对于 long context 选取固定窗口长度的上下文来做 attention。

Image

Slide Window Attention

Retention 则比较复杂,这里边会涉及 chunk-wise retention 的推导以及 retention 的递归形式和并行形式。并行形式可以提高训练效率,递归形式用于推理阶段则可以达到节省显存的效果。更详细的分析,推荐阅读:  白发小 Luke 船长:深入解析:Retentive Network (RetNet) —— Transformer 的有力继任者(https://zhuanlan.zhihu.com/p/...)。PS: 似乎 YOCO 和 Retention 是同一个作者大佬。

)。PS: 似乎 YOCO 和 Retention 是同一个作者大佬。

Image

Retention

我们可以看到,YOCO 在推理阶段,可以节省大量 Prefill 的耗时,基于 YOCO 的特性,Prefill 阶段可以跳过首 Token 的生成,也就是跳过 Cross-Decoder。Prefill 阶段的 KV Cache 显存需求从 O(LND)下降到 O((N+L)D),其中 N 表示 seq_len,L 表示 transformer 层数,D 表示 hidden size;Prefill 耗时则从 O(LN^2D)下降为 O(LND),我们可以看到,耗时从 N 平方复杂度变为线性复杂度。

Image

Prefill in YOCO

  • TTFT 和 Prefill 的区别?

歪个楼,论文里边的 Prefill 耗时,其实把生成首 Token 需要的部分排除掉了。就目前常用的 decoder-only 架构,比如 LLaMA,Prefill 耗时实际上就是 TTFT。但是 YOCO 里边,Prefill 耗时其实是小于 TTFT 的,而实际应用中,我们通常更加关注 TTFT,而不仅仅是 Prefill(生成 context prompt 的 KV Cache 耗时)。因此,如果我们从 TTFT 的角度去考虑,实际的性能收益,应该没有论文中描述的那么大。

  • 实验结果

YOCO 论文给出了一个在 3B 规模左右的模型的实验结果,可以看到在大部分任务上,YOCO-3B 都达到了 SOTA。并且,基于 YOCO 的 KV Cache 共享策略和 ESA,可以将训练上下文扩展到 1M 规模。不过,没看到更大参数规模的模型实验,比如 72B 等,盲猜是分布式训练逻辑不太好写?个人可能更想看到在 70B 以上的参数规模,YOCO 是否依然表现优越。

Image

YOCO 实验结果

0x05.跨层 KV Cache 共享: CLA 简析

  • CLA 整体架构分析

同期 MIT-IBM Watson AI Lab 也发了一篇类似的论文,提出了 CLA( Cross-Layer Attention),即 KV Cache 跨层推理。简直和 YOCO 不谋而合,因此本文把这两篇论文的阅读笔记放到一起记录了。

Image

Cross-Layer Attention

CLA 同样是一种 KV Cache 跨层共享的新方式,但和 YOCO 不一样的是,CLA 并不是选择固定的前几层来产生 KV Cache(比如 YOCO,使用的是前 L/2 层),而是将产生 KV Cache 的层交替分布在模型不同深度的层,然后邻近层复用附近层产生的 KV Cache 进行 Attention 计算。CLA 可以选择多种跨层共享 KV Cache 的方式,比较灵活。

Image

CLA 的多种跨层共享 KV Cache 方式

  • CLA 的 KV Cache 计算

CLA 和 GQA 属于两种不冲突的 KV Cache 共享方式,GQA 为层内 KV Cache 共享,而 CLA 为层间 KV Cache 共享。我们知道,GQA 的 KV Cache 计算方式为:

KV Cache = 4 x B x L x H x N / G bytes # GQA

其中 B 表示 Batch size,L 表示 seq len,H 表示 hidden size,N 表示层数,G 表示 GQA 组数。如果再考虑 CLA 层间的 KV Cache 共享,则计算方式变成:

KV Cache = 4 x B x L x H x N / G / C bytes # GQA + CLA

其中 C 表示多少层共享一份 KV Cache,比如 C=2 表示 CLA2 模式,相邻的 2 个层共享一份 KV Cache,则 KV Cache 总量可以节省一半。贴一下论文中的一些数值:

Image

CLA KV Cache

  • 实验结果

Image

CLA 实验结果

CLA 提供的是一种通用的跨层 KV Cache 共享思路,本质上和 MQA/GQA 等层内 KV Cache 共享的方式是不冲突的,因此,CLA 可以和 MQA/GQA 结合一起使用。我们看到使用了 CLA 后,模型的精度还是会有一定程度的退化,因此,如何在精度和性能上做 tradeoff,也是 CLA 无法避免的问题。

0x06.跨层 KV Cache 共享: MLKV 简析

  • MLKV 整体框架分析

Image

MLKV

在了解了 YOCO 和 CLA 的整体结构和思路后,MLKV 就很好理解了,从创新点上看,MLKV 和 YOCO 都是 CLA 的特例。YOCO 不再重复了,请看前文的分析。MLKV 则是和 CLA 一样,做跨层的 KV Cache 共享,本质上和 CLA 是一个路子,但是主要是对 MQA 做更加极端的扩展,也就是 MQA+跨层 KV Cache 共享。不过,从论文上看,MLKV 的分析和实验设计不太丰富的样子。论文只给了 Pythia-160M 的对比实验(PS: 是不是实验室卡不够用了)

Image

MLKV 实验结果

  • MLKV 的 KV Cache 计算

不过比较贴心的是,论文给出了 KV Cache 的计算公式:

Image

MLKV KV Cache 计算公式

0x07.总结

本文简单解析了 KV Cache 共享的几种算法,包括 MQA、GQA、YOCO、CLA 和 MLKV,并且将其归纳为层间共享和层内共享来两种方式,也总结和整理了 GQA、YOCO、CLA 和 MLKV 的 KV Cache 计算公式。所谓好记性,不如烂笔头。最后,LLM 推理部署各方向新进展,推荐我整理的 Awesome-LLM-Inference,传送门:https://https://github.com/De...

Image

Awesome-LLM-Inference

本人更多的技术笔记以及 CUDA 学习笔记,欢迎来 CUDA-Learn-Notes(CUDA Learn Notes with PyTorch)查阅。CUDA-Learn-Notes 包括了本人的LLM/VLM文章整理,以及对于SGEMM/HGEMM/GEMV等常见CUDA Kernel示例实现,目前已经累计  1.5k+ stars,传送门https://github.com/xlite-dev/...

Image

CUDA Learn Notes with PyTorch

老样子,先更后改,后续有更新再继续修改......

参考文献

END

作者:DefTruth
来源:GiantPandaLLM

推荐阅读

欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18920
内容数
1430
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息