Arm KleidiAI是一个利用arm CPU向量扩展指令(包括NEON MLA, dot product, i8mm矩阵乘,SME2 outer product, SME2 multi-vector等)加速AI应用中的GEMM, GEMV,矩阵转置,量化运算的uKernel (micro-kernel)软件库。
KleidiAI 概况
KleidiAI是一个轻量级的运算加速库。它实现基本的运算加速,KleidiAI 不 提供以下功能:
- AI model解释,构建运算网络
- 内存管理,动态内存分配
- 运算网络的分拆合并
- 运算的调度,如调度不同的算子到不同的运算单元(如CPU,NPU,GPU等)
KleidiAI专注于基于arm CPU指令集的AI应用基本算子的优化,采用C语言编写这些加速函数,它不依赖于其他库和ML框架,使得KleidiAI可以被容易地集成到各种ML框架中,加速在Arm CPU上推理的性能。KleidiAI能够与PyTorch、TensorFlow、MediaPipe、Angel等主流AI框架无缝集成,从而加速Meta Llama 3、Phi-3、混元大模型等核心模型的性能,为生成式AI工作负载带来显著提升。
以PyTorch为例,Kleidi 在 Torch ATen 层中提供了一个新算子以加载模型。该层将模型权重以特定格式打包在内存中,使用 KleidiAI GEMM 内核提高性能。同样地,针对模型执行的优化使用了 ATen 层中的另一个算子。该算子对先前打包的模型权重进行 matmul 运算的量化。
KleidiAI 算子
KleidiAI库主要由矩阵乘(matmul)算子实现和其需要的量化(Quantization)和数据打包(packing)算子实现为主。Matmul算子的数据类型包括f32, f16, bf16, 8-bit量化Q8, 4-bit量化Q4。通过arm CPU向量指令的汇编指令来实现这些算子。例如,
量化和打包(Quantizing/Packing)支持per-block, per-channel量化为Q8, Q4类型。KleidiAI的优化做了量化后模型内存占用大小,模型精度和运算速度的平衡。Arm向量运算指令本身并不支持INT4数据(Q4)数据类型,为什么KleidiAI还提供可将weight量化为Q4类型的算子呢?这是基于模型内存占用大小考虑,减少weight对内存的占用。在计算的时候,可以将Q4转成INT8来做向量运算。
实际上很多模型本身就提供已经量化好的Q4格式。通常, 将weight量化为Q4在推理的时候,不会带来很大的模型输出误差。 但是,神经网络的activations (即输入)通常变化范围很大,而且不是在数据范围里均匀分布。将输入数据进行INT4的量化(仅有16个数据值来表示)可能会带来大的模式输出误差。 因此,KleidiAI可以策略性地将weight量化为Q4, 而将输入量化为Q8.
再来聊一聊packing(打包),为什么需要packing呢?
在采用不同Arm vector指令实现矩阵乘运算加速时,为了内存高效的内存访问(主要是减少cache miss,和多次重用已经load的数据),我们通常需要对输入矩阵进行重排。这点在我之前有关如果使用arm向量指令加速矩阵运算中有所描述。
例如,采用SMMLA向量指令加速的矩阵乘算子,它操作INT8数值。对weight INT4量化之后Packing的作用是将2个INT4数值排在1个 8-bit的内存空间。
KleidiAI matmul算子的命名规则
根据矩阵分块运算,一个matmul运算可以表示为:
// RHS(N) LOOP
for(n_idx = 0; n_idx < n; n_idx+=nr){
// LHS(M) LOOP
for(m_idx = 0; m_idx < m; m_idx+=mr){
// K LOOP
for(k_idx = 0; k_idx < k; k_idx += kr){
//Block Loop
计算mr x kr 和 kr x nr 大小子矩阵的矩阵乘
}
在k维度累加,得到结果矩阵中mr x nr大小的子矩阵的结果
}
KleidiAI matmul算子的名字规则为:
kai_<op>_<fused_ops>_<dst_info>_<lhs_info>_<rhs_info>_<mr x nr x kacc>_<technology>_<feature>_<instruction>
目前, KleidiAI支持的向量指令为NEON和SME2。
举一个例子:
kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm
其中,
- f32表示输出矩阵是FP32格式的
- qsi8d32p4x8表示输入左矩阵(left hand source, LHS)的格式为Quantized symmetric Signed Integer 8-bit Per dimension quantization with block length multiple of 32。
- qsi4c32p4x8表示输入右矩阵的格式为Quantized symmetric Signed Integer 4-bit Per channel quantization with block length multiple of 32。
- 8x4x32 表示一次outer loop进行的运算得出结果矩阵中的一个8x4(8行4列)子矩阵的中间结果, 32表示一次Block Loop进行32次I8MM SMMLA向量运算次数。
利用KleidiAI进行矩阵乘运算
每个KleidiAI matmul算子都有对应的LHS和RHS格式要求,LHS和RHS的格式与使用NEON还是SME2加速有关,对于SME2还与硬件实现向量长度有关。除此之外,格式还与matmul算子block loop实现多少次向量运算相关。因此,利用KleidiAI matmul算子加速,一般需要先对LHS,RHS矩阵进行相应的量化和打包操作,除非原始的LHS和RHS矩阵格式已经满足对应matmul算子的格式要求。
使用KleidiAI进行matmul运算时,主要有3个步骤:
1. Quantizing & Packing LHS矩阵
2. Quantizing & Packing RHS矩阵
3. 进行matmul算子运算
因为一个网络的weight (RHS)矩阵是不变的,对于同一matmul算子,只需要进行一次Quantizing & Packing。
我们通过一个kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子进行矩阵运算来演示这个过程。
这个算子通过NEON I8MM MMLA(矩阵乘累加指令),每次Block loop进行32次MMLA运算来实现matmul操作。
kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子有相应的函数给出matmul运算过程使用的mr, nr, kr值(由matmul算子决定这些值):
- kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm
- kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm
- kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm
通过mr, nr, kr值,可以通过以下函数进一步得到m_step, n_step 的信息
- kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm
- kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm
在进行输入矩阵的Quantizing & Packing时,需要利用得到的mr, nr, kr值。
假设LHS和RHS输入矩阵是原始的F32类型的M x K和K x N矩阵,利用kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子进行matmul运算的过程为:
- 通过kai_run_lhs_quant_pack_qsi8d32p_f32函数将F32 LHS 进行Quantizing & Packing, 得到kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子要求的qsi8d32p4x8 LHS格式。
使用bl(block length)为32, 这个算子给出的mr, kr, sr分别为4, 16, 2。kai_run_lhs_quant_pack_qsi8d32p_f32函数利用这个信息进行LHS量化和重排,它将LHS按32倍数大小的per-block量化为qsi8类型,量化采用F16 scale,并重排矩阵。这个过程图示如下:
得到包含F16 scale的qsi8量化并打包的LHS。
2. 将原始F32 RHS矩阵按32倍数大小的per-channel量化为qs4类型。这个过程的一个参考实现为quant_qs4c32_f32。可以图示为:
此过程得到包含量化scale的qs4类型的矩阵,它还需要进一步packing为kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子要求的qsi4c32p4x8 RHS。这个过程可以利用kleidiAI提供的kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0函数来实现。此过程可以图示为:
这个过程的目前主要是重排矩阵,让后续的matmul运算的内存访问连续,减少cache miss。这个过程也有一些简单的qsu4转qsi4数据类型的操作。
3. 现在符合格式的LHS和RHS矩阵都有了,终于可以利用kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm函数进行matmul运算了。这个过程的一个block loop运算可以图示为:
其对应代码为:
"1:" // Row loop
"mov x10, %x[rhs_packed]\n"
"mov x9, %x[n]\n"
"add x28, %x[dst], %x[dst_stride_row], LSL #3\n"
"2:" // Column loop
"mov x22, %x[lhs_packed]\n"
"movi v1.16b, #0x0\n"
"movi v22.16b, #0x0\n"
"mov x21, %x[num_blocks]\n"
"movi v14.16b, #0x0\n"
"movi v12.16b, #0x0\n"
"movi v15.16b, #0x0\n"
"movi v19.16b, #0x0\n"
"movi v3.16b, #0x0\n"
"movi v2.16b, #0x0\n"
"add x20, x22, x11\n"
"3:" // Block loop
"ldr d11, [x10, #0x0]\n"
"ldr d10, [x22, #0x0]\n"
"add x10, x10, #0x8\n"
"add x22, x22, #0x8\n"
"ldr q25, [x10, #0x0]\n"
"ldr q30, [x10, #0x10]\n"
"movi v6.4s, #0x0\n"
"movi v21.4s, #0x0\n"
"ldr d24, [x20, #0x0]\n"
"ldr q28, [x22, #0x0]\n"
"add x20, x20, #0x8\n"
"movi v9.4s, #0x0\n"
"ldr q4, [x22, #0x10]\n"
"ldr q23, [x20, #0x0]\n"
"movi v0.4s, #0x0\n"
"movi v31.4s, #0x0\n"
"ldr q17, [x20, #0x10]\n"
"ldr q18, [x10, #0x20]\n"
"shl v20.16b, v25.16b, #0x4\n"
"shl v29.16b, v30.16b, #0x4\n"
"ldr q16, [x10, #0x30]\n"
"ldr q26, [x22, #0x20]\n"
"movi v7.4s, #0x0\n"
"movi v27.4s, #0x0\n"
"ldr q8, [x22, #0x30]\n"
"ldr q5, [x20, #0x20]\n"
"and v25.16b, v25.16b, v13.16b\n"
"and v30.16b, v30.16b, v13.16b\n"
".inst 0x4e94a786 // smmla v6.4s, v28.16b, v20.16b\n"
".inst 0x4e9da795 // smmla v21.4s, v28.16b, v29.16b\n"
"ldr q28, [x20, #0x30]\n"
"fcvtl v11.4s, v11.4h\n"
".inst 0x4e94a489 // smmla v9.4s, v4.16b, v20.16b\n"
".inst 0x4e9da480 // smmla v0.4s, v4.16b, v29.16b\n"
"ldr q4, [x22, #0x40]\n"
"fcvtl v10.4s, v10.4h\n"
".inst 0x4e94a6ff // smmla v31.4s, v23.16b, v20.16b\n"
".inst 0x4e9da6e7 // smmla v7.4s, v23.16b, v29.16b\n"
"ldr q23, [x22, #0x50]\n"
"fcvtl v24.4s, v24.4h\n"
".inst 0x4e94a63b // smmla v27.4s, v17.16b, v20.16b\n"
"movi v20.4s, #0x0\n"
"subs x21, x21, #0x1\n"
"add x10, x10, #0x40\n"
".inst 0x4e9da634 // smmla v20.4s, v17.16b, v29.16b\n"
"ldr q17, [x20, #0x40]\n"
"shl v29.16b, v18.16b, #0x4\n"
"and v18.16b, v18.16b, v13.16b\n"
".inst 0x4e9da746 // smmla v6.4s, v26.16b, v29.16b\n"
".inst 0x4e9da509 // smmla v9.4s, v8.16b, v29.16b\n"
".inst 0x4e9da4bf // smmla v31.4s, v5.16b, v29.16b\n"
".inst 0x4e9da79b // smmla v27.4s, v28.16b, v29.16b\n"
"ldr q29, [x20, #0x50]\n"
".inst 0x4e99a486 // smmla v6.4s, v4.16b, v25.16b\n"
".inst 0x4e99a6e9 // smmla v9.4s, v23.16b, v25.16b\n"
".inst 0x4e99a63f // smmla v31.4s, v17.16b, v25.16b\n"
".inst 0x4e99a7bb // smmla v27.4s, v29.16b, v25.16b\n"
"shl v25.16b, v16.16b, #0x4\n"
"and v16.16b, v16.16b, v13.16b\n"
".inst 0x4e99a755 // smmla v21.4s, v26.16b, v25.16b\n"
"ldr q26, [x22, #0x60]\n"
".inst 0x4e99a500 // smmla v0.4s, v8.16b, v25.16b\n"
"ldr q8, [x22, #0x70]\n"
"add x22, x22, #0x80\n"
".inst 0x4e99a4a7 // smmla v7.4s, v5.16b, v25.16b\n"
"ldr q5, [x20, #0x60]\n"
".inst 0x4e99a794 // smmla v20.4s, v28.16b, v25.16b\n"
"ldr q25, [x20, #0x70]\n"
"fmul v28.4s, v11.4s, v10.s[0]\n"
"add x20, x20, #0x80\n"
".inst 0x4e92a746 // smmla v6.4s, v26.16b, v18.16b\n"
".inst 0x4e9ea495 // smmla v21.4s, v4.16b, v30.16b\n"
"fmul v4.4s, v11.4s, v10.s[1]\n"
".inst 0x4e9ea6e0 // smmla v0.4s, v23.16b, v30.16b\n"
".inst 0x4e92a509 // smmla v9.4s, v8.16b, v18.16b\n"
"fmul v23.4s, v11.4s, v10.s[2]\n"
".inst 0x4e9ea627 // smmla v7.4s, v17.16b, v30.16b\n"
".inst 0x4e92a4bf // smmla v31.4s, v5.16b, v18.16b\n"
"fmul v17.4s, v11.4s, v10.s[3]\n"
".inst 0x4e9ea7b4 // smmla v20.4s, v29.16b, v30.16b\n"
".inst 0x4e92a73b // smmla v27.4s, v25.16b, v18.16b\n"
"fmul v30.4s, v11.4s, v24.s[0]\n"
".inst 0x4e90a755 // smmla v21.4s, v26.16b, v16.16b\n"
"fmul v29.4s, v11.4s, v24.s[1]\n"
".inst 0x4e90a500 // smmla v0.4s, v8.16b, v16.16b\n"
"fmul v18.4s, v11.4s, v24.s[2]\n"
"fmul v10.4s, v11.4s, v24.s[3]\n"
".inst 0x4e90a4a7 // smmla v7.4s, v5.16b, v16.16b\n"
".inst 0x4e90a734 // smmla v20.4s, v25.16b, v16.16b\n"
"uzp1 v26.2d, v6.2d, v21.2d\n"
"uzp2 v6.2d, v6.2d, v21.2d\n"
"uzp1 v24.2d, v9.2d, v0.2d\n"
"uzp2 v16.2d, v9.2d, v0.2d\n"
"uzp1 v8.2d, v31.2d, v7.2d\n"
"uzp2 v11.2d, v31.2d, v7.2d\n"
"scvtf v26.4s, v26.4s, #0x4\n"
"uzp1 v31.2d, v27.2d, v20.2d\n"
"uzp2 v7.2d, v27.2d, v20.2d\n"
"scvtf v6.4s, v6.4s, #0x4\n"
"scvtf v24.4s, v24.4s, #0x4\n"
"scvtf v16.4s, v16.4s, #0x4\n"
"scvtf v8.4s, v8.4s, #0x4\n"
"fmla v1.4s, v26.4s, v28.4s\n"
"scvtf v11.4s, v11.4s, #0x4\n"
"scvtf v31.4s, v31.4s, #0x4\n"
"scvtf v7.4s, v7.4s, #0x4\n"
"fmla v22.4s, v6.4s, v4.4s\n"
"fmla v14.4s, v24.4s, v23.4s\n"
"fmla v12.4s, v16.4s, v17.4s\n"
"fmla v15.4s, v8.4s, v30.4s\n"
"fmla v19.4s, v11.4s, v29.4s\n"
"fmla v3.4s, v31.4s, v18.4s\n"
"fmla v2.4s, v7.4s, v10.4s\n"
"bgt 3b\n"
"ld1r { v17.4s }, [%x[clamp_vals]]\n"
"add x20, %x[clamp_vals], #0x4\n"
"cmp x9, #0x4\n"
"ld1r { v10.4s }, [x20]\n"
"fmax v1.4s, v1.4s, v17.4s\n"
"fmax v22.4s, v22.4s, v17.4s\n"
"fmax v14.4s, v14.4s, v17.4s\n"
"fmax v12.4s, v12.4s, v17.4s\n"
"fmax v15.4s, v15.4s, v17.4s\n"
"fmax v19.4s, v19.4s, v17.4s\n"
"fmax v3.4s, v3.4s, v17.4s\n"
"fmax v2.4s, v2.4s, v17.4s\n"
"fmin v1.4s, v1.4s, v10.4s\n"
"fmin v22.4s, v22.4s, v10.4s\n"
"fmin v14.4s, v14.4s, v10.4s\n"
"fmin v12.4s, v12.4s, v10.4s\n"
"fmin v15.4s, v15.4s, v10.4s\n"
"fmin v19.4s, v19.4s, v10.4s\n"
"fmin v3.4s, v3.4s, v10.4s\n"
"fmin v2.4s, v2.4s, v10.4s\n"
"blt 4f\n"
"mov x20, %x[dst]\n"
"str q1, [x20, #0x0]\n"
"add x20, x20, %x[dst_stride_row]\n"
"str q22, [x20, #0x0]\n"
"add x20, x20, %x[dst_stride_row]\n"
"str q14, [x20, #0x0]\n"
"add x20, x20, %x[dst_stride_row]\n"
"str q12, [x20, #0x0]\n"
"add x20, x20, %x[dst_stride_row]\n"
"str q15, [x20, #0x0]\n"
"add x20, x20, %x[dst_stride_row]\n"
"str q19, [x20, #0x0]\n"
"add x20, x20, %x[dst_stride_row]\n"
"str q3, [x20, #0x0]\n"
"add x20, x20, %x[dst_stride_row]\n"
"str q2, [x20, #0x0]\n"
"b 7f\n"
上图的过程可以算出一个8行4列子矩阵的部分中间结果,在K维度持续做这样的运算,并累加结果,可以得到这个8行4列子矩阵的最终结果。
kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子通过类似的方式算出所有子矩阵的结果,从而得到全部的矩阵乘结果。
集成KleidiAI到llama.cpp
Llama.cpp默认使用ggml-cpu作为CPU推理的backend。集成KleidiAI到llama.cpp时,可以使用KleidiAI的算子替换ggml-cpu中的算子。其中一个实现就是:当算子的类型为GGML_OP_MUL_MAT的GEMM,并且数据类型为GGML_TYPE_Q4_0时,可以使用KleidiaAI kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子。
举一个llama.cpp+KleidiAI跑Qwen1.5 0.5B chat模型的例子 (
qwen1_5-0_5b-chat-q4_0.gguf)。
因为模型中的weight已经时Q4模型,但是为了利用kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子,还是需要在load model时,使用kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0算子对model矩阵(LHS)进行qsu4到qsi4类型转换,并为matmul算子打包(packing)。其函数调用的过程为:
在运行model时,需要对F32 LHS矩阵进行量化和打包, 然后使用kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm算子进行矩阵乘运算。其函数调用的过程为:
结语
此文介绍了KleidiAI的工作方式,对于KleidiAI的用户,需要在了解各个Matmul算子的加速方式。在选择合适的算子之后,需要对LHS和RHS矩阵进行适当的量化和打包,才能使Matmul算子正确工作。
Arm KleidiAI轻量级AI运算加速库可以被容易地集成到主流AI框架,让运行在arm CPU上的广泛生态软件,高效地利用arm CPU构架向量指令带来的性能加速。