详解 vLLM 和 SGLang awq dequantize kernel 的魔法

0x0. 前言

本片文章解析一下 vLLM/SGLang 中 awq int4 的反量化 kernel,这个 kernel 触发条件为当输入 x 的 shape 的 tokens<256 时,这个时候会先把 int4 的 awq 权重使用awq_dequantize反量化回 float16,然后调用 PyTorch Matmul 执行 float16 的乘法,代码位置见: https://github.com/vllm-proje...

defapply(self,  
          layer: torch.nn.Module,  
          x: torch.Tensor,  
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:  
    qweight = layer.qweight  
    scales = layer.scales  
    qzeros = layer.qzeros  
    pack_factor = self.quant_config.pack_factor  
    out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))  
    reshaped_x = x.reshape(-1, x.shape[-1])

### num_tokens >= threshold

FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

if FP16*MATMUL_HEURISTIC_CONDITION:  
        out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)  
        out = torch.matmul(reshaped_x, out)  
else:  
        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,  
                           pack_factor)  
if bias isnotNone:  
        out.add*(bias)  
return out.reshape(out_shape)  

本文要解析的就是这里的 vllm ops.awq_dequantize这个 kernel,这个 kernel 的代码单独抽出来只有几十行代码,但是代码中涉及到的魔法和数学有点多,如果不了解这里的原理就会很痛苦,所以我这里来详细解析一下。vllm ops.awq_dequantize这个算子的原始来源是 FasterTransformer 仓库,然后 sglang 的 sgl-kernel 也有一份针对这个算子的干净实现,并通过调整线程块有更快的速度,我这里直接针对这份代码来解析,链接见:https://github.com/sgl-projec...

还需要说明一下,对于 AWQ/GPTQ 来说,权重的量化不是 PerChannel 的而是 GroupWise 的,也就是在 K 方向会有 GS 组 Scales 和 Zeros,例如假设 K/GS=128,那就是在 K 方向有 128 行的 Weight 共享一个 Scales 和 Zeros。因此,它和 PerChannel 的差异就是需要在反量化的时候乘以 Scales 并加上 Zeros。除此之外,AWQ 本身需要在 Activation 计算之前乘以它自己的 ActScale。在下面的 Kernel 中,针对的是 weight,K 方向就是行(row)方向。

0x1. 接口函数

// PyTorch 接口函数,用于 AWQ 权重反量化  
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros){  
// 获取输入张量的维度信息  
int qweight_rows = qweight.size(0);  
int qweight_cols = qweight.size(1);  
int group_size = qweight_rows / scales.size(0); // 计算量化组大小

// 设置 CUDA 网格和块的维度  
int x_num_threads = 16;  
int y_num_threads = 16;  
int x_blocks = qweight_cols / x_num_threads;  
int y_blocks = qweight_rows / y_num_threads;

// 确保在正确的 CUDA 设备上执行  
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));

// 创建输出张量,与 scales 具有相同的数据类型和设备  
auto output_tensor_options = torch::TensorOptions().dtype(scales.dtype()).device(scales.device());  
  at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);

// 获取各个张量的数据指针  
auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());  
auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());  
auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());  
auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());

// 配置 CUDA 核函数的执行参数  
dim3 num_blocks(x_blocks, y_blocks);  
dim3 threads_per_block(x_num_threads, y_num_threads);

// 获取当前 CUDA 流并启动核函数  
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();  
  dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(  
      _qweight, _scales, _zeros, _output, group_size, qweight_cols);

// 返回反量化后的权重张量  
return output;  
}  

需要注意的点是,kernel 的输入是int4类型的,输出是float16类型的,然后输入的 shape 是[qweight_rows, qweight_cols],输出的 shape 是[qweight_rows, qweight_cols * 8]。由此,我们也可以看出输入数据的元素是一个 32 位整数  source,它包含了 8 个 4 位整数(每个 4 位可以表示 0-15 的值)。这 8 个 4 位整数被紧密地打包在一起,如下图所示:

[4bit][4bit][4bit][4bit][4bit][4bit][4bit][4bit]

接下来,在 kernel launch 配置方面,使用二维的线程网格和线程块,并且每个线程处理输入 Tensor 中的一个元素,非常直观:

int x_num_threads = 16;  
int y_num_threads = 16;  
int x_blocks = qweight_cols / x_num_threads;  
int y_blocks = qweight_rows / y_num_threads;  
dim3 num_blocks(x_blocks, y_blocks);  
dim3 threads_per_block(x_num_threads, y_num_threads);  

0x2. dequantize_weights kernel 流程

// 权重反量化的 CUDA kernel,最大线程数为 256  
**global** void **launch_bounds**(256) dequantize_weights(  
int* **restrict** qweight,    // 量化后的权重  
    half* **restrict** scales,    // 量化比例因子  
int* **restrict** qzeros,     // 量化零点  
    half* **restrict** output,    // 输出的反量化权重  
int group_size,               // 量化组大小  
int qweight_cols) {           // 量化权重的列数  
// 计算当前线程处理的列和行索引  
int col = blockIdx.x _ blockDim.x + threadIdx.x;  
int row = blockIdx.y _ blockDim.y + threadIdx.y;

// 获取当前处理位置的零点,并反量化为 fp16x2 格式  
  uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + (row / group_size) * qweight_cols]);  
// 加载对应的缩放因子  
  uint4 loaded_scale = _(uint4_)(scales + 8 * col + (row / group_size) * qweight_cols * 8);

// 将量化权重反量化为 fp16x2 格式  
  uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);

// 对每个 fp16x2 元素执行(weight - zero) * scale 操作  
// 处理第一对 fp16 值  
asmvolatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));  
asmvolatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));  
// 处理第二对 fp16 值  
asmvolatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));  
asmvolatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));  
// 处理第三对 fp16 值  
asmvolatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));  
asmvolatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));  
// 处理第四对 fp16 值  
asmvolatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));  
asmvolatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));

// 计算输出指针位置并存储结果  
  half* output_ptr = output + 8 * col + 8 * row * qweight_cols;  
  _(uint4_)output_ptr = weight_fp16;  
}  

这里整体是非常好理解的,我们根据线程 id 定位到当前线程处理的列和行索引之后分别加载零点 zeros,缩放系数 loaded_scale 和权重 weight_fp16 并对 zeros/weight_fp16 应用dequantize_s4_to_fp16x2反量化 kernel 把当前行列所在的 int32 类型的值(8 个 int4)反量化为 8 个 half 类型的输出值,注意这里是用 4 个 half2 来存储的。然后使用(weight - zero) * scale操作来完成反量化的过程。

这里解析一个asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));指令:

这行代码使用了 CUDA PTX,用于执行半精度浮点数(fp16)的减法操作。它的基本语法为:

asm [volatile] ("汇编指令" : 输出操作数 : 输入操作数 : 可能被修改的寄存器);  

下面是详细解析:

  • asm volatile
  • asm  关键字表示这是内联汇编代码
  • volatile  修饰符告诉编译器不要优化或重排这段汇编代码,确保它按照指定的顺序执行
  • sub.f16x2 %0, %1, %2;\n
  • 这是实际的 CUDA PTX 汇编指令
  • sub.f16x2  是 CUDA 的指令,表示对两个并排的 fp16 值(packed half2)执行减法操作
  • %0, %1, %2  是占位符,分别对应后面定义的输出和输入操作数
  • \n  是换行符,用于格式化汇编代码
  • : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
  • 第一个冒号后的  "=r"(weight_fp16.x)  是输出操作数,=r 表示这是一个输出到通用寄存器的值
  • 第二个冒号后的  "r"(weight_fp16.x)  和  "r"(zeros.x))  是两个输入操作数,r 表示它们来自通用寄存器

通过这个指令就实现了反量化中的减零点的功能,kernel 中其它的 ptx 指令类推。

0x3. dequantize_s4_to_fp16x2 kernel(魔法发生的地方)

这段代码对应的原理在 nvidia 2023 年夏日专场其实简单讲了一下,我这里结合当时的 PPT 复述一下这里的原理,通过这个复述读者稍后就可以知道代码中的那一堆魔术和用于计算的 PTX 指令是做了什么了。注意下面引用的图来 BiliBili NVIDIA 英伟达频道 上传的《TensorRT-LLM 中的 Quantization GEMM(Ampere Mixed GEMM)的 CUTLASS 2.x 实现讲解》。

FasterTransformer 高效的 Int8/Int4 快速 Convert 为 FP16

image.png

这张 slides 展示了 FP16 的 IEEE 754 标准,一个 16bit 的数里面包含 1 个符号位,5 个基码位,10 个尾数。

Image

假设我们有一个 uint8 的数 143,如果我们把它放到实际的 FP16 的尾数位里面去,那么我们是否有办法通过合理的设置基码位把 143 表达出来呢?那我们按照已知的 FP16 的数值计算方法,拿基码位的二进制前面加上一个 1.x,然后去乘以 2 的(基码位的值-15)次方,我们已知 143 对应的实际上对应的是下面的值。假设我们想用这个 FP16 的值来表达 Int8,我们可以发现如果 x=25 的话,我们把上面的 FP16 的值减去 1024 就是下面的 143 了。因此,我们只需要把 int8 的值放到尾数位,然后把它的基码位设置成 25,然后再把 FP16 的数值结果减去 1024 就可以得到 UINT8 转换到 FP16 的值。

Image

总结一下就是直接把 UINT8 的数值放在 FP16 的尾数位,

Image

然后再把 FP16 的基码位设置成 25,这个 25 对应的十六进制表示就是 0x64,

Image

随后再把最终的这个值减去 FP16 形式的 1024,就完成了从 UINT8 到 FP16 的转换。

Image

如果是 Int8 的话,应该怎么做呢?可以注意到 UINT8 和 INT8 只是数值范围的区别,那么我们需要把 INT8 的数据加上 128,就能把它转换成 UINT8 的形式。这样转换出来的 FP16 的结果,只需要在减去 1024 的时候多减去 128,就恢复到了对应的原始 INT8 的数值。

Image

那么我们怎么实际的去用指令完成上面描述的这个操作呢?可以注意到有一种叫作 prmt 的 PTX 指令,这个指令做的事情就是从 2 个 32bit 的寄存器 A,B 中抽出 4 个 8bit 组成最终的 d。而这 4 个 8bit 怎么抽取,就是每个 8bit 对应到 c 寄存器里面的低 4bit,就是说 c 寄存器的低 4bit 每个 bit 都是一个索引,假设 A,B 两个 32 位寄存器里面存放的是上方左图这样的数据形式,即 ABCDEFGH。那么在 c 寄存器中,索引的 4 个数字分别是 1,3,5,7,那么最终这个 D 寄存器里面的 4 个 8bit 数据就是 GECA。通过这种指令就可以实现从 32bit 寄存器里面抽取对应想要的一个字节出来的效果。

Image

对应到 TRT-LLM 的转换代码就是这样的形式,我们可以注意到它用 permute 指令从输入的 UINT8 数据和 magic number 组成的这两个 32 位寄存器中去抽取 4 个 8bit,抽取的索引放在这个 mask_for_elt_01/23 中。这里的两个掩码值  mask_for_elt_01 = 0x5250  和  mask_for_elt_23 = 0x5351  是用于 CUDA 的 PRMT(Permute)指令的控制参数,它们决定了如何重排字节。

--------------------分割线---------------------

这里我感觉比较难理解,所以下面详细拆解一下:

PRMT 指令基础

首先,PRMT 指令的格式是:

prmt.b32 d, a, b, c;  

其中,d  是目标寄存器;a  和  b  是源寄存器;c  是控制码(即我们讨论的掩码)。然后 PRMT 指令将  a  和  b  的字节重新排列,根据控制码  c  中的每个字节决定输出的每个字节。

掩码的二进制表示

将掩码转换为二进制 (我用计算器算的):

Image

Image

掩码的工作原理

在 PRMT 指令中,控制码  c  的每个字节控制输出的一个字节。每个控制字节的格式为:

[7:6] 选择源(00=a的低字, 01=a的高字, 10=b的低字, 11=b的高字)  
[5:3] 保留或用于其他功能  
[2:0] 选择字节索引(0-3)  

mask_for_elt_01 (0x5250)  分析

拆分为 4 个字节:0x520x50

  • 第 1 个字节  0x52 = 0101 0010
  • 01: 选择 a 的高字(即源数据的高 16 位)
  • 010: 选择索引 2 的字节
  • 第 2 个字节 0x50 = 0101 0000
  • 01: 选择 a 的高字
  • 000: 选择索引 0 的字节 这个掩码用于提取源数据中的第 0 和第 2 个字节(即偶数位置的字节),并将它们放入结果的低 16 位。

mask_for_elt_23 (0x5351)  分析

拆分为 4 个字节:0x530x51

  • 第 1 个字节  0x53 = 0101 0011
  • 01: 选择 a 的高字
  • 011: 选择索引 3 的字节
  • 第 2 个字节  0x51 = 0101 0001
  • 01: 选择 a 的高字
  • 001: 选择索引 1 的字节 这个掩码用于提取源数据中的第 1 和第 3 个字节(即奇数位置的字节),并将它们放入结果的低 16 位。
对应到代码
asmvolatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "r"(start_byte_for_fp16), "r"(mask_for_elt_01));  
asmvolatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "r"(start_byte_for_fp16), "r"(mask_for_elt_23));  
  • 第一条指令使用mask_for_elt_01提取源数据i8s中的偶数位置字节(0 和 2),并与start_byte_for_fp16(0x64006400)结合
  • 第二条指令使用mask_for_elt_23提取源数据i8s中的奇数位置字节(1 和 3),并与start_byte_for_fp16结合
staticconstexpruint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;  
asmvolatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));  
asmvolatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));  

之后再像我们刚才描述的那样,在它的基础上减掉(1024+128)就得到了真实的这 4 个 INT8 对应的 FP16 的值。注意这里的 (1024+128)是 dtype=float16 下的 1152 对应的二进制。

----------------------------分割线-----------------------------

我们可能会注意到,这里为什么要分别抽取 01 和 23,而不是抽取 0123 呢?这主要是为了和之后的 INT4 的实现保持一致,在 INT4 的实现里不得不按照 02,13 的方式去抽取。

Image

前面介绍了 INT8 到 FP16 的转换,如果是 INT4 应该怎么转呢?permute 指令只能以 8Bit 为单位进行数据的操作,但是在 4Bit 的转换中,我们知道 4Bit 就是一个 8Bit 里面高 4Bit 存一个数据,低 4Bit 存另外一个数据。那么,我们就需要一种形式能把实际的 8Bit 里面的高低 4 个 Bit 给抽取出来。

Image

抽取出来之后我们应该怎么做呢?先看低 4 个 bit,假设我们以位运算的方式把 8Bit 中的低 4 个 Bit 给抽取出来放到一个 FP16 的尾数里面去,然后前面也在基码位上赋值和 Int8 相同的 25,也就是 16 进制的 64。我们再把这个得到的值减去(1024+8),就得到了最终这个低 4Bit 对应的 FP16 的值。

Image

那如果是高 4 个 Bit 应该怎么做呢?我们注意到低 4 个 Bit 是直接放到最低的 4 个 Bit 位,高 4 个 Bit 同样用位运算抽取出来之后这高 4 个 Bit 是存在于一个 Int8 的高 4Bit 里面,那放到尾数位的话那么它就需要去进行一个额外的除以 16 的操作,相当于右移了 4 位,最后就移到了黄色的位置。移动到这里之后,就可以进行和刚才一样的那些操作了,减去对应的值就得到了实际对应的 FP16 的值。这里减去的值是 1024/16=64,因为移位的原因还要减掉 8。

Image

注意到在提取 Int4 数据的时候是用这张 Slides 的形式去提取的,而刚好有一种叫 lop3 的 PTX 指令可以完成这件事情。lop3 这个 PTX 指令的大概描述就是他会在输入 a, b, c 三个寄存器作为输入,然后有一个 Lut 值,这个 Lut 值是怎么确定的呢?假设 a,b,c 分别对应了 0xF0,0xCC,0xAA,我们把这三个值进行我们想要的操作得到的值作为 Lut 值,把这个 Lut 值放进去之后指令就会自动对 a, b, c 进行相应的操作,把结果写到 d。所以,我们就可以利用这个指令把 Lut 值给它,它就可以帮我们高效完成 Int4 数据的提取了。最后,我们就把 Int4 转成 FP16 的过程转换成了一条 lop3 指令加上一条 fma(或者 sub)指令。

结合我们的 AWQ 的转换代码,LOP3 的应用是:

asmvolatile("lop3.b32 %0, %1, %2, %3, %4;\n"  
               : "=r"(h[0])  
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));  

这里 LOP3 指令实现了类似  (i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM  的操作,但只用一条指令就完成了,大大提高了效率。

Image

这张 Slides 展示了 Int4 到 FP16 的具体代码实现,我们注意到它提取的时候会用到 0x0f 或者 0xf0 来提取 Int4,这样的话假如我们有连续的 Int4 的话,那被提取出来的分别是第 0 个 Int4 和第 4 个 Int4 以及第 1 个 Int4 和第 5 个 Int4。所以它的奇偶被分别提取了出来。实际上我们是用 8 个连续的 Int4 来进行类型转换,因此它每次先把第 0 个 Int4 和第 4 个 Int4 提取出来,放到两个连续的 FP16 里面去,然后再去把第 1 和第 5 个 Int4 提取出来,放到两个连续的 FP16 里面去,以此类推。我们之前在做 Int8 的时候也分奇偶提取就和这里不得不做的这个数据提取动作保持一致。

Image

为了实际计算的时候去逆转这个元素排布的变化,我们需要在计算之前把 Layout 进行相应的调整。就是说以 Int4 位例的话就分别把它的奇偶位元素分别提取出来,这样在我们真正做计算把它从 INT4 转成 FP16 的时候,就会通过上一页 Slides 介绍的操作完成对这个 Layout 的逆运算,还原回了真实的连续排布的 layout。

这就是描述的最后一种快速的 Int4/Int8 转 FP16 的优化的 layout 变化。通过这种优化就把前面提到的一个 convert 指令转换成了一系列lop3或者prmt指令。虽然指令数没有变化,但是指令的 latency 会更低。

dequantize_s4_to_fp16x2 kernel 解析

实际上上面的原理解析的代码就是这个 dequantize_s4_to_fp16x2 kernel,根据上面的原理解析添加了几个注释,现在细节应该都比较清楚了。

**device** uint4 dequantize_s4_to_fp16x2(uint32_tconst& source){  
#if defined(**CUDA_ARCH**) && **CUDA_ARCH** >= 750  
  uint4 result;

uint32_t* h = reinterpret_cast<uint32_t*>(&result);  
uint32_tconst i4s = reinterpret_cast<uint32_tconst&>(source);

// First, we extract the i4s and construct an intermediate fp16 number.  
staticconstexpruint32_t immLut = (0xf0 & 0xcc) | 0xaa;  
staticconstexpruint32_t BOTTOM_MASK = 0x000f000f;  
staticconstexpruint32_t TOP_MASK = 0x00f000f0;  
staticconstexpruint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

// 注释说明了这种实现的优势:  
// 1. 整个序列只需要 1 条移位指令  
// 2. 利用寄存器打包格式和无符号整数表示  
// 3. 利用 sub 和 fma 指令具有相同的吞吐量来优化转换

// 将 i4s 右移 8 位,用于处理第 4-7 个元素  
// 提前发出以隐藏 RAW 依赖关系  
constuint32_t top_i4s = i4s >> 8;

// 提取并转换第 0 和第 1 个元素(低字节的低 4 位)  
// 使用 LOP3 指令实现(i4s & BOTTOM_MASK) | I4s_TO_F16s_MAGIC_NUM  
asmvolatile("lop3.b32 %0, %1, %2, %3, %4;\n"  
               : "=r"(h[0])  
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

// 提取并转换第 2 和第 3 个元素(低字节的高 4 位)  
// 使用 LOP3 指令实现(i4s & TOP_MASK) | I4s_TO_F16s_MAGIC_NUM  
asmvolatile("lop3.b32 %0, %1, %2, %3, %4;\n"  
               : "=r"(h[1])  
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

// 提取并转换第 4 和第 5 个元素(高字节的低 4 位)  
asmvolatile("lop3.b32 %0, %1, %2, %3, %4;\n"  
               : "=r"(h[2])  
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

// 提取并转换第 6 和第 7 个元素(高字节的高 4 位)  
asmvolatile("lop3.b32 %0, %1, %2, %3, %4;\n"  
               : "=r"(h[3])  
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

// 定义用于最终转换的魔数常量  
// 表示 fp16 格式的{1024, 1024}  
staticconstexpruint32_t FP16_TOP_MAGIC_NUM = 0x64006400;  
// 表示 fp16 格式的{1 / 16, 1 / 16},用于缩放高 4 位的值  
staticconstexpruint32_t ONE_SIXTEENTH = 0x2c002c00;  
// 表示 fp16 格式的{-64, -64},用于偏移校正  
staticconstexpruint32_t NEG_64 = 0xd400d400;

// 最终转换步骤:将中间 fp16 值转换为实际的 int4 值  
// 处理第 0 和第 1 个元素:直接减去 1024  
asmvolatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));

// 处理第 2 和第 3 个元素:乘以 1/16 再减去 64  
// 相当于(h[1] * 1/16 - 64),因为高 4 位需要右移 4 位  
asmvolatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

// 处理第 4 和第 5 个元素:直接减去 1024  
asmvolatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));

// 处理第 6 和第 7 个元素:乘以 1/16 再减去 64  
asmvolatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

return result;  // 返回包含 8 个 fp16 值的 uint4 结构  
#else  
  assert(false);  // 如果 CUDA 架构低于 7.5,则断言失败  
return {};  
#endif  
}  

0x4. 总结

本文详细解析了 vLLM/SGLang 中 AWQ int4 反量化 kernel 的实现原理和优化技巧。该 kernel 巧妙利用 IEEE 754 浮点数表示特性,通过 LOP3 和 PRMT 等 PTX 指令高效地将 int4 权重转换为 fp16 格式。通过直接操作尾数位和基码位,避免了传统转换方法中的多次移位和类型转换,实现了高性能的反量化操作。整个过程只需少量高效指令,充分利用了 CUDA 硬件特性,是一种精巧的底层优化技术。因为很底层,所以代码实现虽然简短但引入了大量的 Magic Number 和先验知识,我这里结合 nvidia 的一个 PPT 和自己的理解把它搞清楚了,希望可以帮助到有相同困惑的读者。

END

作者:BBuf
来源:GiantPandaLLM

推荐阅读

欢迎大家点赞留言,更多 Arm 技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18921
内容数
1431
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息