高效 CUDA 算子编写指南:结合 NCU 与 Cursor Claude-sonnet-3.5

我的课程笔记,欢迎关注: https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/cuda-mode

0x0. 预览版

上周 MiniMax 开源了他们 4560 亿参数的 MoE 大模型,其中一个亮点是这个模型是一个 Lightning Attention 和 Softmax Attention 的混合架构,技术报告链接见:https://filecdn.minimax.chat/_Arxiv_MiniMax_01_Report.pdf 。关于这个模型更多的细节推荐感兴趣的朋友读 @sonta 的回答:https://www.zhihu.com/question/9630107500/answer/79882585725

提到 Linear Attention 我也不困了,去年就对 RWKV 架构产生过兴趣也做过开源贡献,同时也了解了 Linear Attention 架构的一些算法原理和做推理的优势,具体可以参考我之前的几篇 blog:

如果要在 SGLang 推理框架中去支持 MiniMax Text01 模型,首先就需要实现 https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py 中的 MiniMaxText01LightningAttention 模块,这个正是我所擅长的。所以几乎用了一个完整的周末在 SGLang 中建立了 MiniMaxText01LightningAttention 这个模块的 Prefill 和 Decode 过程的优化算子和 Benchmark,对于 Prefiil 来说我只建立了一个 Benchmark ,使用了 OpenNLPLab 提供的 lightning_attn2 的 Triton 算子 https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py 。这个 Triton 算子相比于 HuggingFace 的原始实现把 Prefill 端到端耗时提升了数倍,可以参考下面的截图:

image.png

而对于 Decode 阶段来说,这是一个典型的 Memory Bound 的算子,这个算子的 Python 代码单独抽出来非常简单。也是我这篇文章的起点,就是把这个算子的性能优化一下,提升带宽利用率和降低执行时间。然后我展示了一下如何正确的使用 Cursor 结合 NCU 来尝试做 CUDA 优化。

首先,这个算子的 PyTorch 代码可以写成下面这几行:

def lightning_attention_decode_naive(q, k, v, past_kv, slope):
    """Naive implementation of lightning attention decode"""
    original_dtype = q.dtype
    ratio = torch.exp(-slope)  # [h, 1, 1]

kv = past_kv
    b, h, n, d = q.shape

output = []
    for i in range(n):
        kv = ratio * kv.to(torch.float32) + torch.einsum(
            "... n d, ... n e -> ... d e",
            k[:, :, i : i + 1],
            v[:, :, i : i + 1],
        )
        qkv = torch.einsum(
            "... n e, ... e d -> ... n d",
            q[:, :, i : i + 1].to(torch.float32),
            kv.to(torch.float32),
        )
        output.append(qkv)
    output = torch.concat(output, dim=-2)

return output.to(original_dtype), kv

其中,输入 Tensor 的形状如下截图:

Image

其次,我这里的目标就是优化一下这个 Kernel,尽可能的提升带宽利用率并且降低 kernel 的耗时。总的来说,我在 Cursor 的协助下写了 2 个版本的 Triton Kernel,以及几个版本的 CUDA Kernel,最后无论是在 lightning_attention_decode 这个算子的 Micro Benchmark 还是端到端的 Lightning Attention 模块的耗时相比于原始的 PyTorch 实现都实现了加速,对于算子来说在 batch 较小时可达到 2 倍加速。

Image

Image

详细数据可以参考 https://github.com/sgl-project/sglang/pull/3030

最后,这个 kernel 还有非常大的可提升空间,不过这不是本文重点,本文的重点是我将在下一节演示一下我是如何使用 Cursor+NCU 来联合优化 CUDA Kernel 的,如果你想在 Cursor 中使用最先进的 Claude-3.5-sonnet-20241022 来直接给你写出性能不错的 CUDA kernel,根据我的使用记录来看是非常困难的。大模型既不会给你避免 Bank Conflict,也不会给你合并内存访问,并且大多数时候还会给你写出效率非常低的 Python 直译 cuda 代码。然而 Cursor 下的 Claude-3.5-sonnet-2024102 有多模态功能是可以看懂图片的,所以我们可以把 NCU 的一些关键 Profile 信息给他,手工强化学习,我稍后会演示如何利用 NCU 的结果让 Cursor 更聪明,从而写出我们想要的优化代码。

0x1. 实操版

0x1.1 Triton naive 版本

kernel 代码:https://github.com/sgl-project/sglang/pull/2920/files#diff-16ed66afc4b7f52545a3fffd55c9fd6daaf87189d9a0d252fccba42951c1cc40R14-R105

Image

首先是一个最 Naive 的版本,对于 q,k,v 的每个头使用一个 Block 来计算,也就是一共有个 Block,然后每个头的维度都从 92 padding 到 128 来满足 Triton kernel 的计算需求。

从上面的性能结果来看,和原始的 PyTorch 实现几乎没有区别。

0x1.2 Triton 优化版本

https://github.com/sgl-project/sglang/pull/2966

Image

把上面的 naive 版本的 Triton kernel 之前的手动 Padding 到 128 移除了,然后在 kernel 中使用 Mask 的方式来解决 dim 维度没有对齐到 2 的幂次的问题。从上面的结果可以看到,Lightning Attention 模块的端到端时间确实是下降了一些。

0x1.3 CUDA 版本

把上面那几行 Lighting Attention Decode Python 代码直接扔给 Cursor Sonnet 3.5 20241022 模型,然后它很快就产生了一份 cuda kernel。

#define THREADS_PER_BLOCK 128

template<typename T>
**global** void lightning_attention_decode_kernel(
    const T* **restrict** q,      // [b, h, 1, d]
    const T* **restrict** k,      // [b, h, 1, d]
    const T* **restrict** v,      // [b, h, 1, e]
    const float* **restrict** past_kv, // [b, h, d, e]
    const float* **restrict** slope,   // [h, 1, 1]
    T* **restrict** output,       // [b, h, 1, e]
    float* **restrict** new_kv,   // [b, h, d, e]
    const int batch_size,
    const int num_heads,
    const int dim,
    const int embed_dim) {


    const int32_t tid = threadIdx.x;
    const int32_t current_head = blockIdx.x;
    const int32_t b = current_head / num_heads;
    const int32_t h = current_head % num_heads;


    if (b >= batch_size) return;


    const int32_t qk_offset = b * num_heads * dim + h * dim;
    const int32_t v_offset = b * num_heads * embed_dim + h * embed_dim;
    const int32_t kv_offset = b * num_heads * dim * embed_dim + h * dim * embed_dim;


    // 1.  计算新的 kv: new_kv = ratio * past_kv + k * v^T
    const float ratio = expf(-1.0f * slope[h]);
    for (int d = tid; d < dim; d += THREADS_PER_BLOCK) {
        T k_value = k[qk_offset + d];
        for (int e = 0; e < embed_dim; e++) {
            const int32_t kv_index = kv_offset + d * embed_dim + e;
            new_kv[kv_index] = ratio * past_kv[kv_index] + k_value * v[v_offset + e];
        }
    }


    __syncthreads();  //  确保所有线程完成 new_kv 的计算


    // 2.  计算 qkv attention 输出: output = q * new_kv
    for (int e = tid; e < embed_dim; e += THREADS_PER_BLOCK) {
        float sum = 0.0f;

但是测试 Benchmark 之后可以发现这个版本的 kernel 性能相比于 Triton 算子的耗时会慢 5 倍左右。

想找出性能差异的原因,最靠谱的方法就是分析下 nuc 的结果,我写了下面的 profile 脚本:

import math
import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode

def next_power_of_2(n):
    return 2 ** (int(math.ceil(math.log(n, 2))))

@triton.jit
def _decode_kernel(
    Q,
    K,
    V,
    KV,
    Out,
    S,
    b: tl.constexpr,
    h: tl.constexpr,
    n: tl.constexpr,
    d: tl.constexpr,
    d_original: tl.constexpr,
    e: tl.constexpr,
    e_original: tl.constexpr,
):
    off_bh = tl.program_id(0)
    off_h = off_bh % h

qk_offset = off_bh * n * d
    v_offset = off_bh * n * e
    o_offset = off_bh * n * e
    kv_offset = off_bh * d * e

s = tl.load(S + off_h)
    ratio = tl.exp(-s)

d_idx = tl.arange(0, d)
    e_idx = tl.arange(0, e)

# Create masks for original dimensions
    d_mask = d_idx < d_original
    e_mask = e_idx < e_original

# Load with masking
    q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
    k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
    v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)

# Load KV with 2D masking
    kv = tl.load(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        mask=(d_mask[:, None] & e_mask[None, :]),
        other=0.0,
    )

# Compute outer product using element-wise operations
    k_v_prod = k[:, None] * v[None, :]
    kv = ratio * kv + k_v_prod

# Store KV with 2D masking
    tl.store(
        KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
        kv.to(KV.dtype.element_ty),
        mask=(d_mask[:, None] & e_mask[None, :]),
    )

# Compute matrix-vector multiplication using element-wise operations and reduction
    o = tl.sum(q[:, None] * kv, axis=0)

# Store output with masking
    tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)

def triton_lightning_attn_decode(q, k, v, kv, s):
    """Triton implementation of Lightning Attention decode operation"""
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

# Get padded dimensions (power of 2)
    d_padded = next_power_of_2(d)
    e_padded = next_power_of_2(e)

# Create output tensor (padded)
    o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)

# Create padded tensors without actually padding the data
    q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
    k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
    v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
    kv_padded = torch.empty(
        b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
    )

# Copy data to padded tensors
    q_padded[..., :d] = q
    k_padded[..., :d] = k
    v_padded[..., :e] = v
    kv_padded[..., :d, :e] = kv

# Launch kernel
    grid = (b * h, 1)
    _decode_kernel[grid](
        q_padded,
        k_padded,
        v_padded,
        kv_padded,
        o_padded,
        s,
        b=b,
        h=h,
        n=n,
        d=d_padded,
        d_original=d,
        e=e_padded,
        e_original=e,
    )

# Get unpadded outputs
    o = o_padded[..., :e]
    kv_out = kv_padded[..., :d, :e]

return o, kv_out

dtype = torch.bfloat16
device = torch.device("cuda")
num_heads = 64
head_dim = 96
seq_len = 1
batch_size = 1

q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
slope = torch.randn(num_heads, 1, 1, device=device)

output_triton, new_kv_triton = triton_lightning_attn_decode(q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone())

output_kernel = torch.empty_like(output_triton)
new_kv_kernel = torch.empty_like(new_kv_triton)
lightning_attention_decode(
    q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone(),
    output_kernel, new_kv_kernel
)

print('end')

然后执行 /usr/local/NVIDIA-Nsight-Compute-2024.3/ncu --set full -o lightning_attention_decode_bs=1 python3 test_lighting_attention.py

得到 ncu 文件之后可以重点关注一下 Memory Wordload Analysis 这一列:

Triton 版本:

Image

CUDA 版本:

Image

有两个主要区别,首先 CUDA 版本没有使用 Shared Memory 加速读取和写入,第二个区别是 Triton 版本写回全局内存的数据量要小得多。

接下来可以让 Cursor 辅助我们写一个 Shared Memory 的版本,把 q,k,v,new_kv 的计算都放在 Shared Memory 里面。Cursor 确实可以写,但是如果我们不指出是否存在 Bank Conflict,它是不会管的。这就导致它实现的第一个 Shared Memory 版本的 kernel 执行时间比最开始的全局内存读写版本还要慢 4 倍,这里我就不贴代码了。接下来需要给 Cursor 手动解释一下它存在 Bank Conflict,主要是计算 new_kv_shared 的时候存在大量 Bank Conflict,我们要求他执行一个 padding 来避免 Bank Conflict,这样 Cursor 就可以写出看起来正常的代码了。

#define THREADS_PER_BLOCK 128

template<typename T>
**global** void lightning_attention_decode_kernel(
    const T* **restrict** q,      // [b, h, 1, d]
    const T* **restrict** k,      // [b, h, 1, d]
    const T* **restrict** v,      // [b, h, 1, e]
    const float* **restrict** past_kv, // [b, h, d, e]
    const float* **restrict** slope,   // [h, 1, 1]
    T* **restrict** output,       // [b, h, 1, e]
    float* **restrict** new_kv,   // [b, h, d, e]
    const int batch_size,
    const int num_heads,
    const int dim,
    const int embed_dim) {


    extern **shared** char smem[]; //  动态共享内存声明
    //  为所有数组在共享内存中分配空间
    T* q_shared = reinterpret_cast<T*>(smem);
    T* k_shared = reinterpret_cast<T*>(smem + dim * sizeof(T));
    T* v_shared = reinterpret_cast<T*>(smem + 2 * dim * sizeof(T));
    float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * dim + embed_dim) * sizeof(T));
    T* output_shared = reinterpret_cast<T*>(smem + (2 * dim + embed_dim) * sizeof(T) + dim * (embed_dim + 1) * sizeof(float));


    const int32_t tid = threadIdx.x;
    const int32_t current_head = blockIdx.x;
    const int32_t b = current_head / num_heads;
    const int32_t h = current_head % num_heads;


    if (b >= batch_size) return;


    const int32_t qk_offset = b * num_heads * dim + h * dim;
    const int32_t v_offset = b * num_heads * embed_dim + h * embed_dim;
    const int32_t kv_offset = b * num_heads * dim * embed_dim + h * dim * embed_dim;

for (int d = tid; d < dim; d += blockDim.x) {
        q_shared[d] = q[qk_offset + d];
        k_shared[d] = k[qk_offset + d];
    }
    for (int e = tid; e < embed_dim; e += blockDim.x) {
        v_shared[e] = v[v_offset + e];
    }


    __syncthreads();


    const float ratio = expf(-1.0f * slope[h]);

for (int d = tid; d < dim; d += blockDim.x) {
        T k_val = k_shared[d];
        for (int e = 0; e < embed_dim; ++e) {
            int past_kv_idx = kv_offset + d * embed_dim + e;
            T v_val = v_shared[e];
            float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
            int shared_idx = d * (embed_dim + 1) + e;
            new_kv_shared[shared_idx] = new_val;
        }
    }


    __syncthreads();


    for (int idx = tid; idx < dim * embed_dim; idx += blockDim.x) {
        int d = idx / embed_dim;
        int e = idx % embed_dim;
        int shared_idx = d * (embed_dim + 1) + e;
        int global_idx = kv_offset + idx;
        new_kv[global_idx] = new_kv_shared[shared_idx];
    }


    __syncthreads();


    for (int e = tid; e < embed_dim; e += blockDim.x) {
        float sum = 0.0f;
        for (int d = 0; d < dim; ++d) {
            int shared_idx = d * (embed_dim + 1) + e;
            sum += q_shared[d] * new_kv_shared[shared_idx];
        }
        output_shared[e] = static_cast<T>(sum);
    }


    __syncthreads();


    if (tid == 0) {
        for (int e = 0; e < embed_dim; ++e) {
            output[v_offset + e] = output_shared[e];
        }
    }
}

但是当我们测试 Benchmark 的时候发现这个版本的速度虽然在 bs<=4 的时候比 Triton 快不少,但是当继续增大 bs 的时候速度越来越慢,是 Triton 是 2-3 倍执行时间。

继续打开 NCU 的 Memory Wordload Analysis,我们发现这次它抛出了一个写 Global Memory 不连续导致性能降低的问题。

Image

把这个结果反馈给 Cursor,Cursor 现在可以知道主要问题是写 new_kv 的时候内部循环·for (int e = 0; e < embed_dim; ++e)·导致线程在访问全局内存时 stride 太大,然后内存没有合并访问,且每个线程需要写入多次全局内存,增加了内存事务数。这也是我们看到这个 kernel 写全局内存的时候比 Triton 多了几倍的原因。知道原因之后 Cursor 就可以改成正确的代码了。代码如下:

#define THREADS_PER_BLOCK 128

template<typename T>
**global** void lightning_attention_decode_kernel(
    const T* **restrict** q,      // [b, h, 1, d]
    const T* **restrict** k,      // [b, h, 1, d]
    const T* **restrict** v,      // [b, h, 1, e]
    const float* **restrict** past_kv, // [b, h, d, e]
    const float* **restrict** slope,   // [h, 1, 1]
    T* **restrict** output,       // [b, h, 1, e]
    float* **restrict** new_kv,   // [b, h, d, e]
    const int batch_size,
    const int num_heads,
    const int dim,
    const int embed_dim) {


    extern **shared** char smem[]; //  动态共享内存声明
    //  为所有数组在共享内存中分配空间
    T* q_shared = reinterpret_cast<T*>(smem);
    T* k_shared = reinterpret_cast<T*>(smem + dim * sizeof(T));
    T* v_shared = reinterpret_cast<T*>(smem + 2 * dim * sizeof(T));
    float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * dim + embed_dim) * sizeof(T));
    T* output_shared = reinterpret_cast<T*>(smem + (2 * dim + embed_dim) * sizeof(T) + dim * (embed_dim + 1) * sizeof(float));


    const int32_t tid = threadIdx.x;
    const int32_t current_head = blockIdx.x;
    const int32_t b = current_head / num_heads;
    const int32_t h = current_head % num_heads;


    if (b >= batch_size) return;


    const int32_t qk_offset = b * num_heads * dim + h * dim;
    const int32_t v_offset = b * num_heads * embed_dim + h * embed_dim;
    const int32_t kv_offset = b * num_heads * dim * embed_dim + h * dim * embed_dim;

for (int d = tid; d < dim; d += blockDim.x) {
        q_shared[d] = q[qk_offset + d];
        k_shared[d] = k[qk_offset + d];
    }
    for (int e = tid; e < embed_dim; e += blockDim.x) {
        v_shared[e] = v[v_offset + e];
    }


    __syncthreads();


    const float ratio = expf(-1.0f * slope[h]);

for (int d = tid; d < dim; d += blockDim.x) {
        T k_val = k_shared[d];
        for (int e = 0; e < embed_dim; ++e) {
            int past_kv_idx = kv_offset + d * embed_dim + e;
            T v_val = v_shared[e];
            float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
            int shared_idx = d * (embed_dim + 1) + e;
            new_kv_shared[shared_idx] = new_val;
        }
    }


    __syncthreads();


    for (int idx = tid; idx < dim * embed_dim; idx += blockDim.x) {
        int d = idx / embed_dim;
        int e = idx % embed_dim;
        int shared_idx = d * (embed_dim + 1) + e;
        int global_idx = kv_offset + idx;
        new_kv[global_idx] = new_kv_shared[shared_idx];
    }


    __syncthreads();


    for (int e = tid; e < embed_dim; e += blockDim.x) {
        float sum = 0.0f;
        for (int d = 0; d < dim; ++d) {
            int shared_idx = d * (embed_dim + 1) + e;
            sum += q_shared[d] * new_kv_shared[shared_idx];
        }
        output_shared[e] = static_cast<T>(sum);
    }


    __syncthreads();


    if (tid == 0) {
        for (int e = 0; e < embed_dim; ++e) {
            output[v_offset + e] = output_shared[e];
        }
    }
}

这里重构了 new_kv 的内存访问模式,让相邻线程访问连续的内存地址,达到内存合并访问的目的。

这个 kernel 还有很多优化空间,例如一个 Block 中实际上还有一个 warp 没有工作,因为一个 Block 是 128 个线程,但是 dim=96,所以可以优化成一个 warp 处理一行这种版本。此外,我们没有使用向量化读取进一步降低内存事务等等。

不过从我 kernel Micro Benchmark 以及 end2end 的 Lighting Attention 模块 Benchmark 结果来看,它已经超越了 Triton 的优化版本,在各个 Batch 下都取得了优势。

0x2. 总结

基于 MiniMax Lighting Attention Decode 算子演示了下 Cursor Claude-sonnet-3.5-20241022 这种最先进的大模型目前写 CUDA 底层优化的限制,以及我们如果要使用这种工具应该怎么人工给他一些反馈,让它可以真正的正确工作起来。不要轻易相信 AI 生成的任何代码,特别是涉及到优化的代码。

END

作者:BBuf
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式 AI 专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18855
内容数
1399
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息