1 前言
Attention 的计算过程中,需要之前的 k 和 v。
但每次计算的时候,把之前的 k,v 重新计算一次成本太高昂,需要找个地方临时存起来,这就是 KV Cache。
llama1 的代码就非常简单
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
把 kv 的值更新到 cache 里,然后再从缓存中读取需要的 cache。
真实业务场景中,用户对于 context 的需求是一直变长,但耗时太久,导致用户体验上不去,真实需求变成了不存在的需求。
举例子,当年我们做智能音箱,到后面发现用户都是听首歌,看电影之类的需求。但用户只有这样的需求么?答案是否定的,是因为之前用户的复杂需求,都是智障回答。
用户试了几次,被“教育”后,就不愿意做进阶尝试了。
用户不请求的需求,不代表真的不存在,可能是你做的太垃圾了。
随着大模型在下游的落地不断深入,未来某些特定 domain,大概率会有每天,32k 往上,百亿-万亿的请求。
我们搞了半年(预计要持续两年)的冯诺依曼大模型,就是要优化这块。(近期一些小分支预计也会开源出来。
KV Cache 就是其中核心模块,很多地方要从零改。
我们跟 SGLang 的朋友,一起从头梳理了一下 SGLang 的 kv cache 部分。他们英文版本的 code walk through 已经开放出来。
https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/sglang/kvcache-code-walk-through
这篇文章会从设计思路的角度去探讨,为什么这样设计 KV Cache,给我们后面的优化提供一定的借鉴和参考。
这篇文章要重点感谢,bruce,zenan,zhongtao,qwang。其中有些代码理解的现场源码 battle 还是让我觉得非常爽的。
2 cache 管理
关于 kv cache,直观上来看,有这么两个需求
- a) kv cache 是高频读写,量级不小,如何高效的管理
- b) kv cache 的实际业务有多种, MHA,GQA, MLA,DoubleSparse,如何做业务的隔离?
2.1 cache 池
内存池的定义,如果每次都是要使用内存的时候,才去申请,效率会很低。容易导致碎片化,带来管理的困难。
但我们可以提前申请一大块内存,需要的时候从这个内存池去拿就好了,SGLang 这里也一样。
2.2 二级 cache 池
kv cache 有这些自定义类型,还在不断增长,MHA,MLA,DoubleSpars cache,管理起来比较麻烦。
使用二级内存池,一级记录 high level 信息,跟具体业务隔离,二级各种类继承,根据需要来调用。
2.3 一级内存池:req_to_token_pool
跟踪每个请求使用的 token 位置,具体的 kv cache 在二级内存池。
key:第 i 个 req 对应 req_to_token 第 i 行,第 i 个 req 的第 j 个 token 对应 req_to_token 第 i 行第 j 列
value:从二级缓存池的 allocater 获取的 token_to_kv 的内存块 id(定位)
2.3.1 代码
初始化代码
sglang/srt/model_executor/model_runner.py
核心函数,init_memory_pool
初始化内存池
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=self.device,
use_records=False,
)
输入的最大的 req 请求量和 max content len
实现代码
sglang\srt\mem_cache\memory_pool.py
2.3.2 提供的功能
分配内存块
self.req_to_token = torch.zeros(
(size, max_context_len), dtype=torch.int32, device=device
)
最大的 req 量级,content len
内存块的 write,alloc,free,内存池的常见操作,这里不做赘述,大家自行看代码就好
2.3.3 实际操作代码
prepare_for_extend()函数非常合适来看一级内存是如何操作。
此函数为 extend 的执行做准备,分配一二级内存,更新一级内存池。 二级内存池,因为此时 kv cache 的值还没有计算好,只是分配,要在 forward 的过程中才会写入。
代码路径
python\sglang\srt\managers\schedule_batch.py
2.4 二级内存池:token_to_kv_pool
SGLang 两级内存池系统中的第二级,定义多种二级内存池类,MHA,MLA,DoubleSparse 等。
key:lay_id+token_cache_loc(不同层
value:真实的 kv cache 值
2.4.1 代码
因为有多个实现,类继承
初始化代码
sglang/srt/model_executor/model_runner.py
实现代码
sglang\srt\mem_cache\memory_pool.py
2.4.2 提供的功能
写入,分配,清理等核心功能。
值得注意的点,读内存,传入 layer id,直接返回某一层的内存。但写内存,传入 layer id loc,这个 loc 的 k v tensor。因为推理的时候,不会有需求说读取某个指定位置的 kv。
2.4.3 实际操作代码
二级内存池的 set(写入)只有各个 backend 的 forward 函数才会 set。保证了二级内存池 set 的规范化,极大降低了内存泄漏的风险。
通过类似多态来实现,这块的代码比较绕,从注册到使用。但整体写的挺干净,给 SGLang 点个赞。
一级内存就满世界 set 了,但本身小,内存泄漏问题不大。
具体 python\sglang\srt\layers\attention\下面的 backend 结尾的 py 文件都可以参考。
3 cache 复用
通过一二级 cache 池,解决了 cache 的管理问题。
但此时观察数据,会发现大部分请求都有大量重复的开头,如 sys prompt。那我们想办法复用这些 kv cache 就好了。
有两种复用方式,radix 和 chunk。
3.1 Radix cache
3.1.1 是什么?
Radix tree 可以理解为前缀树的高自定义版本
https://zhuanlan.zhihu.com/p/693556044
我们大致梳理下这个内存复用的需求:
计算好的 kv cache 存入 radix cache,新的请求匹配 radix cache,如果算过了,直接就拿过来用。
理想情况下,肯定是 radix cache 存的越多越好。但机器是有限的,所以需要对 kv cache 做清理。
但清理就会碰到这么一个问题,我们如何判断一块 radix cache 是可以清理的。跟 java 的引用计数是类似的,引用归零就可以做清理。
加一,减一的时机,当 req 执行完毕,就可以减一。req 还在 forward,就要加一,代表还在用,你不能清理。
以及是按 step 执行,req 有的执行完,有的才刚开始,会导致不同 req 的前缀会有重复执行,要及时做内存优化。
基于上面的需求,核心为下面三个函数。
- match_prefix:匹配命中了哪些 cache。
- cache_finished_req:req 执行完了,把引用-1。告诉大家,这个前缀我不用了。
- cache_unfinished_req:req 没执行完,这个前缀我要用,引用+1。同时自己的 kv cache 也更新到 radix tree(以及一些特殊逻辑的处理
3.1.2 代码
sglang/srt/mem_cache/radix_cache.py
3.1.3 cache_unfinished_req
把未完成的 req 的存入 radix cache。
引用计数加一,self.inc_lock_ref(new_last_node)
有一行代码非常有意思,cursor 也回答不好,我们当时讨论了半小时。
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
req 没有执行完毕,为什么要 free?cache_finished_req 也有这行代码,为什么?
大家也可以自行思考为什么。(给一点提示,如果不加,会导致跟 page attention 类似的内存浪费问题。
3.2 Chunked cache
简单点,可以理解为把输入分成块来 cache
但核心函数跟 radix 一致,并且更简单,不展开。
3.3 token 粒度的内存池的好处
开放问题,对比 vllm 的 page attention,SGLang 做到了 token 粒度。
个人不靠谱看法,SGLang 这种做法更好一些。
page attention 只负责管理连续 token 的 kvcache 尽可能在内存连续。
radix attention 也做到了管理连续 token 的 kvcache 内存连续,但是在此基础上强化前缀缓存的复用。
4 推理过程中和 kv cache 的交互
我们已经设计好了二级内存池,如何复用也想好了,那么在 req 传入到结束的生命周期,也就是推理的过程中,是如何跟二级内存池和 radix(chunk)做交互的?
会从三部分来展开,整体框架,举个例子,细化版本
4.1 整体框架
三步走
- a)预处理
新的 req,查询是否命中,命中复用即可。没有命中,去二级缓存池申请槽位。
- b)模型执行
算好 batch,模型执行,算出 kv cache。主要跟二层 kv cache 池交互,把申请的槽位写入算好的 kv cache。
- c)后处理
如果 req 执行完毕,radix tree 引用减一,一二级内存池做对应清理。
4.2 举例子
这个图举了几个实际的 req 来讲清楚,整体流程是怎么走的。
感谢 zenan 和 zhongtao
4.3 细化版本
这个图要感谢 bruce,非常详细
END
作者:王焱
来源:GiantPandaLLM
推荐阅读
- YOLO LwF 破局持续目标检测 | 自蒸馏+重放记忆双引擎,单阶段检测器告别灾难性遗忘
- AI 能看懂细节了!IDEA 研究院多模态目标检测模型 DINO-XSeek,自然语言精准定位目标
- 轻量化+动态上采样,参数减38%、精度升4.1%,边缘设备实时部署
- 详解 vLLM 和 SGLang awq dequantize kernel 的魔法
欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。