使用 Triton 加速 2D 动态块量化 Float8 GEMM 简介

博客来源:https://pytorch.org/blog/accelerating-gemms-triton/ 这里做了翻译。这篇博客主要讲了如何用 Triton 来优化 Float8 格式的矩阵乘法(GEMM)运算。文章提出了一个叫 GridQuant 的方法,通过把大矩阵分成 256x256 的小块,然后再把每个小块分成更小的 32x32 的格子来处理数据。这种方法比之前的方案快了将近两倍。另外,文章还介绍了三个新技术:Warp 专门化、TMA(张量内存加速器)和持久化 kernel,这些技术让不同的计算任务可以更好地并行执行,充分利用 GPU 的硬件特性。通过这些优化,在某些特定场景下比之前最好的方案又快了约 1.2 倍,特别适合用在大语言模型的推理阶段。不过这里的 Triton 代码还没有开源。

使用 Triton 加速 2D 动态块量化 Float8 GEMM

Float8 (FP8)的 2D 块量化有望提高 Float8 量化的精度,同时加速推理和训练中的 GEMM 运算。在这篇博客中,我们展示了使用 Triton 进行块量化 Float8 GEMM 的两个主要阶段的进展。

对于从高精度(BFloat16)到 Float8 的 A 和 B 张量的输入量化,我们展示了 GridQuant,它利用 mini-grid stride loop 风格的处理方式,相比当前的 2D 块量化 kernel 实现了近 2 倍的加速(99.31%)。

对于 Float8 GEMM,我们展示了 Triton 的 3 个新发展 - Warp Specialization、TMA 和 persistent kernel,有效地创建了一个协作式 kernel(作为 Ping-Pong 调度的替代方案 PyTorch 博客 CUTLASS Ping-Pong GEMM Kernel 简介)。因此,我们比去年最好的 SplitK kernel 实现了约 1.2 倍的加速。

image.png

图 1:在不同大小下,2D 量化相对于当前基准的加速比较。(越低越好)

为什么选择 FP8 的 2D 块量化?

一般来说,当我们从张量级缩放,到行级缩放,再到 2D 块级缩放,最后到列级缩放时,fp8 量化的精度会逐步提高。这是因为给定 token 的特征存储在每一列中,因此该张量中的每一列都有更相似的缩放。

为了最小化给定数值集合中的异常值数量,我们希望找到共性,使得数字以相似的方式进行缩放。对于 transformer 来说,这意味着基于列的量化可能是最优的...然而,由于数据在内存中是按行连续布局的,列式内存访问效率极低。因此,列式加载需要在内存中进行大跨度访问来提取孤立的值,这与高效内存访问的核心原则相违背。

然而,2D 是次优选择,因为它包含了一些列式的特点,同时由于我们可以使用 2D 向量化这些加载,所以内存效率更高。因此,我们希望找到提高 2D 块量化速度的方法,这就是我们开发 GridQuant kernel 的原因。

对于量化过程,我们需要对高精度 BF16 输入张量(A = 输入激活,B = 权重)进行 2D 块量化,然后使用量化张量及其 2D 块缩放值进行 Float8 矩阵乘法,并返回 BF16 格式的输出 C 张量。

GridQuant 如何提高 2D 块量化效率?

GridQuant kernel 相比最初基于标准 tile 的基准量化实现有几项改进。GridQuant kernel 对整个输入张量进行两次完整的遍历,工作方式如下:

阶段 1 - 确定来自高精度张量的每个 256x256 子块的最大绝对值。

1 - 我们将 BF16 张量分成 256 x 256 的子块。这个量化大小是可配置的,但 256x256 是默认值,因为它在量化精度和处理效率之间提供了良好的平衡。

2 - 每个 256x256 子块被细分为 8x8 模式排列的 64 个子块,每个子块处理 32x32 元素块。一个 warp(32 个线程)处理其分配的 32x32 块内的所有元素计算。

3 - 我们在共享内存中声明一个 32x32 的 max_vals 数组。这将存储 2d 向量块在整个 256x256 子块中移动时每个位置 i,j 的当前最大值。

这是一个重要的改进,因为这意味着我们可以对 max vals 评分系统进行向量化更新,而不是标量更新,从而实现更高效的更新。

Image

图 2:输入张量的分块布局 - 在张量上创建 256x256 的网格,在每个 256x256 块内,进一步细分为 32x32 子块。为每个 256x256 块创建 32x32 max_vals。

4 - 每个 warp 处理一个 32x32 块,因为我们使用 4 个 warp,我们确保 Triton 编译器可以将下一个 32x32 块的内存加载与当前块的 absmax 计算流水线化。这确保了 warp 调度器能够在加载数据的 warp 和处理数据的 warp 之间切换,使 SM 持续忙碌。

5 - 32x32 2D 向量块处理以网格步进循环的方式在整个 256x256 子块中移动,每个 warp 根据其当前 32x32 子块更新共享内存 32x32 max_vals。因此 max_vals[i,j]在处理每个子块时保持最新的最大值。

完成 256x256 块网格步进循环后,maxvals 矩阵然后自身被归约以找到整个 256 块的绝对单一最大值。

这给出了这个 2D 256 x 256 块的最终缩放因子值。

阶段 2 - 使用阶段 1 中找到的单一最大值缩放因子,将 256x256 块值量化为 Float8。

接下来,我们对整个 256x256 块进行第二次遍历,使用阶段 1 中找到的最大值来重新缩放所有数字,将它们转换为 float 8 格式。

因为我们知道需要进行 2 次完整的遍历,所以在阶段 1 部分的加载期间,我们指示 triton 编译器以更高优先级将这些值保持在缓存中(evict policy = last)。

这意味着在第二次遍历期间,我们可以从 L2 缓存获得高命中率,这比直接访问 HBM 提供更快的内存访问。

当所有 256 x 256 块处理完成后,2D 块量化处理完成,我们可以返回新的 Float8 量化张量及其缩放因子矩阵,这将在 GEMM 处理的下一阶段使用。这个输入量化对第二个输入张量也重复进行,这意味着我们最终得到 A_Float 8、A_scaling_matrix 和 B_Float8 以及 B_scaling matrix。

GridQuant - GEMM Kernel

GridQuant-GEMM kernel 接收上述量化的四个输出进行处理。我们的高性能 GEMM kernel 具有几个新的 Triton 开发特性,以在 LLM 推理解码阶段相关的矩阵形状配置中实现 SOTA 性能。

这些新特性常见于使用 CUTLASS 3.x 构建的 Hopper 优化 kernel,如 FlashAttention-3(https://arxiv.org/abs/2407.08608)和Machete(https://neuralmagic.com/blog/introducing-machete-a-mixed-input-gemm-kernel-optimized-for-nvidia-hopper-gpus/)。在这里,我们讨论这些方法并展示使用Triton实现它们可以获得的性能优势。

张量内存加速器(TMA)

NVIDIA Hopper GPU 上的 TMA 单元是一个专用的硬件单元,用于处理 AI 工作负载中常见的多维张量的加载/存储操作。这有几个重要的好处。

从全局内存和共享内存传输数据可以在不涉及 GPU SM 上其他资源的情况下进行,释放寄存器和 CUDA 核心。此外,当在 warp 专用 kernel 中使用时,轻量级 TMA 操作可以分配给生产者 warp,允许内存传输和计算高度重叠。

关于 TMA 在 Triton 中的使用详情,请参见我们的前一篇博客

Warp 专用化(协作式 Persistent Kernel 设计)

Warp 专用化是一种利用 GPU 流水线并行性的技术。这个实验性特性通过tl.async_task API(https://github.com/facebookexperimental/triton/tree/ws)实现了专用线程的表达,允许用户指定Triton程序中的操作应该如何在warp之间"分割"。协作式Triton kernel 执行不同类型的计算和加载,每种操作都在其专用硬件上进行。为每个专用任务提供专用硬件使得对于没有数据依赖的操作能够高效地实现并行性。

Image

图 3. NVIDIA H100 SM 中专用硬件单元的逻辑视图

我们的 kernel 中创建流水线的操作是:

A - 从 GMEM 加载每块缩放到 SMEM (cp.async 引擎)

B - 从 GMEM 加载激活(A)和权重(B)tile 到 SMEM (TMA)

C - A tile 和 B tile 的矩阵乘法 = C tile (Tensor Core)

D - 用 A 的每块缩放和 B 的每块缩放来缩放 C tile (CUDA core)

这些步骤可以分配给 threadblock 中专用 warp 组执行的"任务"。协作策略有三个 warp 组。一个负责给计算单元提供数据的生产者 warp 组和 2 个执行计算的消费者 warp 组。两个消费者 warp 组各自处理同一输出 tile 的一半。

Image

图 4. Warp 专用化 Persistent 协作式 kernel (来源:NVIDIA(https://drive.google.com/file/d/18sthk6IUOKbdtFphpm_jZNXoJenbWR8m/view))

这与我们在之前博客中讨论的 ping-pong 调度不同,在 ping-pong 调度中,每个消费者 warp 组处理不同的输出 tile。我们注意到 Tensor Core 操作与 epilogue 计算不重叠。在计算的 epilogue 阶段减少 Tensor Core 流水线的利用率将减少消费者 warp 组的寄存器压力,相比 ping-pong 总是保持 Tensor Core 忙碌,这允许更大的 tile 大小。

最后,当网格大小超过 H100 GPU 上可用计算单元数量(132)时,我们的 kernel 被设计为 persistent。Persistent kernel 在 GPU 上保持活跃较长时间,在其生命周期内计算多个输出 tile。我们的 kernel 利用 TMA 异步共享到全局内存存储,同时继续处理下一个输出 tile,而不是承担调度多个 threadblock 的成本。

微基准测试

Image

图 5:在小批量范围和 Llama3 8192 N,K 大小下,Gridquant-GEMM 与我们最佳性能 SplitK kernel 的延迟比较(微秒)。(越低越好)

Warp 专用化 Triton kernel 在上述小 M 和方阵形状下实现了 SOTA 性能,相比 SplitK Triton kernel(这是 Triton GEMM 在这个低算术强度范围内之前最佳性能的策略)实现了近 1.2 倍的加速。对于未来的工作,我们计划调优我们的 kernel 在中到大 M 范围和非方阵上的性能。

结论和未来工作

未来工作包括对端到端工作流进行 gridquant 基准测试。此外,我们计划对非方阵(矩形)矩阵以及中到大 M 大小进行更广泛的基准测试。最后,我们计划探索 Triton 中的 ping-pong 风格 warp 专用化与当前协作式实现的对比。

END

作者:GiantPandaCV
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式 AI 专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18853
内容数
1391
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息