在开始之前,给大家出几个“高频面试题”,看看你能答上来吗?
- 举例说明 KV Cache 的计算过程
- 为什么要用 KV Cache?它能解决什么问题,代价又是什么?
- vLLM 里 KV Cache 形影不离的搭档是谁?
还记得之前那篇大语言模型推理,用动画一看就懂!的文章吗?是的!我们再次用动画来演示大语言模型的推理过程!几乎所有的大语言模型(LLM)都基于 Transformer 架构,它依赖于之前生成的 token 来预测下一个字符。而自注意力机制(self-attention)则是模型推理的核心:它不仅需要当前 token,还要每次“回顾”之前的所有 token。
动画演示 KV Cache
为了更加形象理解上面提到的自注意力机制的“回顾机制”,下面我画了一张图。它是大语言模型推理,用动画一看就懂!中那个文本生成步骤的第四步,其中计算 self-attention 时所需的 Key 和 Value 的示意图。
注意:Prompt 是 "The future of AI is" 有五个 token,第一步推理时模型输入的是整个 prompt,会计算出每个 prompt token 对应的 key 值和 value 值,为了清晰起见图里仅用 K1 和 V1 来代表它们。
接下来的动画演示了每一步计算自注意力的过程,清晰起见去掉了其他算子。
从图里看到每一步计算时,当前的 Qi 都需要和之前的 Kj 进行矩阵乘法计算,然后再和之前的 Vj 进行矩阵乘法。那么为了节省算力,我们可以把之前的 Kj、Vj 的结果“缓存”起来,这样每次只需要做增量计算。这个缓存机制就是 KV Cache ,简单却非常有效!来看看加上 KV Cache,推理过程变得多轻松吧!
从上面的动画可以看到除了第一步,其他步骤都可以通过缓存复用之前步骤产生的 Ki 和 Vi。这些步骤在计算 self attention 时只有一个 query,因此叫做 single query attention。
KV Cache 有多大?
一条文本所需的 KV Cache 计算公式如下:
KV Cache Bytes = 2 * 2 * Sequence Length *
举个实际的例子,Qwen2 7B 这个国产大模型,在 4 K 序列长度下,KV Cache 大小是 1.6 GB!这是什么概念呢?要知道很多人的显卡也就 8GB 或者 16GB。
KV Cache 的代价
KV Cache 虽然能节省计算,但是显存开销也很显著,随着模型变大(Hidden Size 和 Layer Num 会增大)、序列长度变长,占用的显存迅速膨胀。
假设你实现 KV Cache 时,每次都是预留了一个超大的仓库来存放它,但每次只用了一小部分,这会导致资源浪费,你服务不了太多用户,而且容易出现“撑爆”显存(OOM)的现象。因为你的用户,他们每次推理时的文本长度是变化的!
那么如何解决这些问题呢?vLLM 提出的 PagedAttention 就是聪明地按需分配空间,像是“分隔储物柜”,需要多少就分配多少,避免浪费。
在下一篇文章中,我将继续用动画的方式,深入拆解 KV Cache 的好基友 PagedAttention 的工作原理,带你从源码层面剖析 vLLM 如何用这一技术解决显存瓶颈。敬请期待!
参考资料:
Transformers KV Caching Explained
游凯超(vLLM 核心开发者)知乎上的《一文读懂 KV Cache》
EFFICIENTLY SCALING TRANSFORMER INFERENCE
END
来源:GiantPandaCV
推荐阅读
- Flex Attention API 应用 Notebook 代码速览
- 【翻译】教程:在PyTorch中为CUDA库绑定Python接口
- OpenVINO C++ 部署 YOLO11 对象检测
- 【翻译】FlexAttetion 基于Triton打造灵活度拉满的Attention
欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。