可学习AttentionPredictor:实现16倍KV缓存压缩与Attention加速

以大模型百万分之一参数量的预测小模型,在 LongBench 数据集上使用 4%的 KV Cache 达到 Full Cache 99%的模型准确率。

近日,中科大王杰教授团队(MIRA Lab)和华为诺亚方舟实验室(Huawei Noah's Ark Lab)联合提出了基于时序预测的可学习稀疏注意力计算方法 AttentionPredictor,使用仅大模型百万分之一大小的小模型辅助大模型推理,在十六倍的压缩比下取得近乎无损的推理准确率,推理速度提升 1.4 倍,这为长序列推理任务的高显存占用带来了新的解决方案。代码已开源!

image.png

论文:https://arxiv.org/abs/2502.04077
代码:https://github.com/MIRALab-US...

太长不看版

随着大语言模型(LLM)的广泛应用,长上下文生成场景下通过键值缓存(KV Cache)压缩实现高效推理成为研究热点。现有方法基于启发式注意力评分筛选关键 token,但忽略了注意力分数的时间模式特性,导致大模型性能显著下降。

本文提出 AttentionPredictor,一种可学习的注意力预测方法,通过轻量级卷积模型捕捉注意力时空特征,使用 LLM 百万分之一的参数量即可精准预测下一 token 的注意力分数。结合跨 token 的 KV 缓存预取框架,AttentionPredictor 在 16 倍压缩率下仍保持大模型性能近乎无损,推理速度提升 1.4 倍,显著超越现有最优方法。

Image

与近期大热的 DeepSeek 的 NSA、Kimi 的 MoBA、微软的 SeerAttention 比较,我们工作的相同点都是分块压缩和检索近似注意力分数,并且都使用可学习的小模型加速注意力计算。我们工作与近期工作的不同点有:

  • 我们分块并压缩的是计算出的注意力分数,而这些工作分块的是 Key。
  • 我们估算 Attention 方式是从历史注意力分数中预测下一步的注意力分数,而近期工作是使用压缩后的 Key 与 Query 相乘得到估算的注意力分数。
  • 我们的可学习部分是预测下一步 attention 的模型,而近期工作是学习一个模型来表征压缩后的 key。
  • 我们的方法是 post-training 的,不涉及 LLM 训练阶段,而近期工作都在训练阶段就加入了稀疏注意力。

引言

随着大语言模型(LLM)在长上下文推理任务中的广泛应用,键值缓存(KV Cache)的内存与计算开销成为制约其部署效率的核心瓶颈。传统方法通过启发式规则(如历史注意力累加、局部窗口筛选)或近似检索技术压缩 KV Cache,但这些方法普遍面临两大挑战:

其一,静态启发式评分难以捕捉注意力动态演化的时间模式,导致关键 token 识别不准确,模型性能显著下降;

其二,现有检索方法依赖当前步骤的 Query 信息,无法通过异步预取机制隐藏计算与传输延迟,限制了实际加速效果。

针对上述问题,本文提出 AttentionPredictor——首个学习 Attention 动态时空模式的 KV Cache 压缩框架。通过系统分析注意力分数的演化规律,我们发现注意力分布呈现强时间特征, 如重复访问(Re-access)、顺序访问(Sequential)、周期性(Seasonal)特征。

基于此,本文将注意力序列建模为时空信号,利用轻量级卷积网络预测下一 token 的注意力分数,突破传统静态评分的局限性。同时,设计跨 token 的 KV Cache 预取框架,通过异步加载与并行化调度,将评估与传输时间隐藏于模型推理过程中,显著提升解码效率。

与现有方法相比,本文的创新性主要体现在三个方面:

  1. 动态时间模式建模:首次将注意力分数建模为时空序列,通过卷积网络学习重复访问(Re-access)、顺序访问(Sequential)、周期性(Seasonal)等模式,实现高精度预测下一步注意力。
  2. 跨令牌预取框架:提出异步加载下一 token 关键缓存的机制,将 token 评估与传输时间隐藏于推理过程中,显著降低解码延迟。
  3. 高效压缩与校准:引入分块注意力压缩和分布误差校准技术,在减少计算量的同时保持预测准确性,16 倍压缩率下模型性能损失小于 1%。。

在长序列任务上,AttentionPredictor 在 16 倍压缩率下平均性能损失<1%。

在长 CoT 任务上,在 16k 上下文长度下,AttentionPredictor 仅损失 2.05%准确率,显著优于 Quest 的 16.91%下降。

Image

图 1. H2O、Quest 和 AttentionPredictor 使用历史注意力分数识别下一步的关键 token 的方法比较。我们基于学习的时空预测器可以捕捉动态注意力模式,并准确预测下一步的注意力得分。

1. 背景与问题介绍

当前高效 LLM 推理与 KV 缓存压缩方法主要分为四类:

  • 缓存驱逐方法:基于启发式规则筛选历史关键 token,如 StreamingLLM(保留初始与近期 token)、H2O(历史注意力累加)、SnapKV(窗口内注意力筛选)、MInference(垂直-斜线、block 等模式)。此类方法依赖静态评分,难以捕捉动态时间模式,导致长上下文场景下性能显著下降。
  • 缓存检索方法:通过近似 Query-Key 交互检索关键 token,如 Quest(分页键近似计算)、PQCache(键值量化)。这类方法的计算开销较大,且准确率随分页大学的增加而大幅下降(如 Quest 在 page size 从 16 增加到 64 时,精度下降 11%)。且这依赖当前步骤 Query,无法通过预取隐藏延迟。
  • 可学习的稀疏注意力:Kimi 的 MoBA 与 Quest 类似地计算分块 attention,再取 top-K 作为稀疏 mask。进一步地,MoBA 将稀疏 Attention 加入了模型训练,使模型在稀疏注意力上的性能得以提高。微软的 SeerAttention 将按块 pooling 后的 Keys 再经过一个可学习的 Linear 层,以对压缩后的 Keys 编码,使计算出的近似 Attention 接近原始分布。DeepSeek 的 NSA 使用可学习的块编码模型代替常见的 Pooling 来建模每个分块,并结合了多种缓存压缩策略。这些工作都需要在模型训练阶段就加入稀疏注意力的使用,以达到更好的模型效果。
  • 跨层预取方法:结合缓存检索与跨层预取(如 InfiniGen),但单层推理时间不足以覆盖长序列传输延迟和估算时延,扩展性受限。

现有方法均面临动态模式建模不足与计算-传输延迟耦合两大瓶颈,制约了高压缩率下的模型性能与推理速度。

2. 动机实验——注意力具有时序模式

为揭示注意力演化的内在规律,本文通过大量实验分析发现解码过程中注意力分布呈现三类可预测的模式(见图 2):

  • 重复访问(Re-access):特定 token 在多步骤中被反复关注(垂直带状分布);
  • 顺序访问(Sequential):注意力沿 token 序列逐步推移(对角线分布);
  • 周期性(Seasonal):关键 token 周期性出现(交替带状分布)。

我们发现 Query 具有很强的连续性,相邻解码步骤的查询向量余弦相似度高达 87%。推导表明注意力分数的差异主要由微小增量 Δq 主导,使得相邻步骤关键 token 高度重叠,支持跨 token 预取。详细推导见论文第 2 章节。

Image

图 2. 三种时序注意力模式的可视化。Re-access 显示对特定标记的重复关注。Sequential 注意力向下一个标记推移。Seasonal 显示出周期性的模式,如较为集中的高注意力分数和较为均匀分布的注意力交替出现。

3. 方法介绍

本文方法包含注意力预测小模型(AttentionPredictor)与跨令牌 KV 缓存预取框架两部分(如图 3),前者通过动态时空建模精准筛选关键 token,后者通过异步加载机制隐藏计算与传输延迟,共同实现高效长上下文推理。

Image

图 3. 我们提出的 KV Cache 压缩方法 AttentionPredictor 和跨 token 预取框架。(a)AttentionPredictor 将历史注意力分数建模为时空序列,并借助预训练模型预测下一步的注意力。为了提升效率,在每个 decoding 步骤中,历史注意力分数会以压缩形式进行更新。(b)跨 token 预取框架。在 LLM 推理过程中,异步评估关键 token,并为下一个 token 获取 KV,从而有效加速解码阶段。

3.1 AttentionPredictor

问题建模: 在 LLM 解码阶段,KV 缓存压缩的目标是选择预算 B 个的关键令牌位置 pi 最大化注意力恢复率:

Image

即保留的注意力分数占比。传统方法依赖静态启发式评分,而 AttentionPredictor 通过预测下一 token 的注意力分数动态筛选关键位置。

时空序列建模:本文将历史注意力分数Image建模为时空信号,利用轻量级卷积网络(2 层 2D 卷积+1 层 1D 卷积)捕捉多尺度时空特征。模型输入为分块压缩后的注意力Image (块大小为 b),输出为下一 token 的预测注意力Image,再通过 Top-K 筛选确定关键 token 位置。

模型训练:仅需 3%的注意力数据(如 LongBench 中每个任务选取 5 个样本)即可完成训练,并支持跨任务泛化(如从 LongBench 迁移至 GSM8K)。模型参数量极小,仅为 LLM 的百万分之一,对存储资源的占用可忽略。

误差抑制技术:本文使用了分块压缩来减少计算量,即对注意力矩阵执行 Max pooling,将计算量降低至 1/b,同时保留局部关键信息。本文还使用了误差校准技术,即每隔 M 步计算完整注意力分数,修正因稀疏计算累积的分布偏差,确保长期预测稳定性。

Image

算法 1. AttentionPredictor 识别关键 KV Cache 的算法流程。 更多细节可见原文第 4 章节。

跨 token 预取框架

在解码阶段,本文通过异步并行化机制隐藏关键 token 评估与传输延迟。具体来说,在 GPU 执行 LLM 推理时,利用 AttentionPredictor 预测下一步的关键 token 索引,并异步从 CPU 加载 p 对应的 KV 缓存至 GPU。通过掩盖数据传输和估算时延(如图 4),与现有跨层预取方法(如 InfiniGen)相比,本方法通过跨 token 粒度,在 32k 上下文下实现 1.4 倍解码加速。

Image

图 4. 我们提出的跨 token 预取的流程图。通过异步加载下一个 token 的关键 KV 缓存,我们的框架隐藏了 token 评估和传输延迟,从而加速了 LLM 推理的解码。

4. 实验介绍

我们实验的数据集包括长序列任务 LongBench 和数学推理任务 GSM8K,平均输入 token 数达 13K。实验包含 4 个部分:

  1. 评估 AttentionPrdictor 在不同缓存预算下的注意力重建率。
  2. 在两大类任务上评估使用我们方法后大模型的性能。
  3. 通过消融实验展示我们方法各部分的效用。
  4. 评估我们方法的推理效率。

我们在此文章中详细介绍实验 2,其余实验请参见原论文的第 5 章节。

LongBench 数据集

部分实验结果见表 1。实验结果显示,我们的方法在不同 budget 下模型精度损失均<1%。

Image

表 1. 我们提出的 AttentionPredictor 相较于其他方法,模型精度损失大幅降低。其中 H2O 方法使用 64 步注意力分数,与我们方法参数对齐,记为 H2O+。

数学推理数据集

部分实验结果见表 2。我们通过调整 few-shot 的个数模拟长输入 CoT 任务。实验结果表明,我们的方法显著优于现有方法,在 16k 长度下,AttentionPredictor 仅损失 2.05%准确率,显著优于 Quest 的 16.91%下降。

Image

表 2. 我们方法在长 CoT 任务上的模型精度表现优于其他方法。

相关文章

END

作者:USTC & Noah & TU
来源:NeuralTalk

推荐阅读

欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18879
内容数
1412
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息