本文以SVE2 Matrix Multiply指令实现Int8和FP32矩阵乘加速。
Int8的Matrix Multiply矩阵乘
Int8的Matrix Multiply(矩阵乘)指令的操作是:
这条指令将第一个SVE2源向量中每128-bit看作2x8有符号8位整数矩阵, 第二个SVE2源向量中每128-bit的8x2有符号8位整数矩阵,然后将第一个SVE2源向量中的2x8矩阵与第二个SVE2源向量中的对应的8x2矩阵进行矩阵乘,生成的2x2 32位整数矩阵乘积累加到目标向量中的32位整数矩阵累加器中。
使用SVE2 Int8 Matrix Multiply进行矩阵乘基本和SVE2 Int8 Dot product类似。
以SVE2 Int8 Matrix Multiply为例,假设矩阵的行数和列数可以与向量长度VL匹配,A矩阵和B矩阵的矩阵乘(得到C矩阵)可以表达为:
其中标识出来的部分的功能可以使用SVE2 Matrix Multiply来实现,这条指令的第一个向量中的其他128-bit部分都重复其第一个128-bit部分。为了优化内存访问和数据重用,我们进行back to back的Matrix Multiply, 如果CPU的硬件pipeline实现支持同时多条Matrix Multiply指令的执行,这样可以更好的利用这些硬件资源。
我们以进行3个vector back to back的SVE2 Matrix Multiply,VL=256-bit 为例,其过程为:
A矩阵和B矩阵重排
首先,需要对A矩阵和B矩阵进行分割处理(为了内存访问更友好,需要对A,B矩阵在内存中进行重排):
A矩阵可以处理为:
A矩阵块0由
[ a0_0, a0_1 , a0_2, a0_3, a0_4, a0_5 , a0_6, a0_7,
a1_0, a1_1 , a1_2, a1_3, a1_4, a1_5 , a1_6, a1_7,
…….
a7_0, a7_1 , a7_2, a7_3, a7_4, a7_5 , a7_6, a7_7
]。块1,2,3, 4以此类推。A矩阵以这样的方式在K维度继续分块,为了简略,图中没有画出所有的块。
B矩阵可以处理为:
B矩阵块0由
[
b0_0, b1_0, b2_0, b3_0, b4_0, b5_0, b6_0, b7_0,
b0_1, b1_1, b2_1, b3_1, b4_1, b5_1, b6_1, b7_1,
...
b0_15, b1_15, b2_15, b3_15, b4_15, b5_15, b6_15,b7_15
]
这些元素(总共为3 x VL byte)组成。B矩阵以这样的方式在K维度继续分块,为了简略,图中没有画出所有的块。
Int 8 Matrix Multiply运算
接下来进行运算:
将经重排过的A矩阵的块0(0),0(1) load到一个SVE2向量寄存器的每一128-bit,(经重排之后的A矩阵的块在内存中会地址连续存储),可以使用LD1RQB指令
LD1RQB { <Zt>.B }, <Pg>/Z, [<Xn|SP>, <Xm>]
这条指令从内存中load 16 byte(128-bit)到一个SVE2向量寄存器的每一个128-bit段(重复这16 byte),如下图中vector中红色部分,0(0)表示A 矩阵块0的第0个8 byte(即 [a0_0, a0_1 , a0_2, a0_3, a0_4, a0_5 , a0_6, a0_7]), 0(1)表示A 矩阵块0的第1个8 byte(即 [a1_0, a1_1 , a1_2, a1_3, a1_4, a1_5 , a1_6, a1_7]),以此类推。- 将经重排过的B矩阵的块0(0),0(1),0(2),0(3)到一个SVE2向量寄存器,(经重排之后的A矩阵的块在内存中会地址连续存储),可以使用LD1B指令
LD1B { <Zt>.B }, <Pg>/Z, [<Xn|SP>{, #<imm>, MUL VL}]
这条指令从内存中load VL byte到一个SVE2向量寄存器,如下图中vector中蓝色部分,0(0)表示B 矩阵块0的第0个8 byte(即 [b0_0, b1_0, b2_0, b3_0, b4_0, b5_0, b6_0, b7_0]), 0(1)表示B 矩阵块0的第1个8 byte(即 [b0_1, b1_1, b2_1, b3_1, b4_1, b5_1, b6_1, b7_1]),以此类推。 - 然后使用SVE2 Matrix Multiply指令,
UMMLA <Zda>.S, <Zn>.B, <Zm>.B
来执行如下图所示的操作:
这会计算出8个32-bit结果(VL=256-bit),这些结果是如下图所示的C矩阵的这些元素的部分结果(中间结果,不是最终结果): - 再利用1,2,3类似的方式计算:
得到如下图所示的C矩阵的这些元素的部分结果(中间结果,不是最终结果)。
重复这样的过程,计算:
• A块的0(4), 0(5)与B块0(0),0(1),0(2),0(3)做UMMLA运算
• A块的0(6),0(7) 与B块0(0),0(1),0(2),0(3)做UMMLA运算
经过以上步骤,会计算出C矩阵以下元素的中间结果:
- 计算:
• A块的0(0),0(1)与B块0(4),0(5),0(6),0(7)做UMMLA运算
• A块的0(2),0(3)与B块0(4),0(5),0(6),0(7)做UMMLA运算
• A块的0(4),0(5)与B块0(4),0(5),0(6),0(7)做UMMLA运算
• A块的0(6),0(7)与B块0(4),0(5),0(6),0(7)做UMMLA运算
• A块的0(0),0(1)与B块0(8),0(9),0(10),0(11)做UMMLA运算
• A块的0(2),0(3)与B块0(8),0(9),0(10),0(11)做UMMLA运算
• ………
• A块的0(6),0(7)与B块0(12),0(13),0(14),0(15)做UMMLA运算
经过以上步骤,会计算出C矩阵以下元素的中间结果:
- 计算:
• A块的1(0),1(1)与B块1(0),1(1),1(2),1(3)UMMLA运算(此过程会累加上面运算的C中间结果)
这会将C矩阵中的以下元素更新为:
• A块的1(2),1(3)与B块1(0),1(1),1(2),1(3)UMMLA运算(此过程会累加上面运算的C中间结果)
• A块的1(4),1(5)与B块1(0),1(1),1(2),1(3)UMMLA运算(此过程会累加上面运算的C中间结果)
• A块的1(6),1(7)与B块1(0),1(1),1(2),1(3)UMMLA运算(此过程会累加上面运算的C中间结果)
• A块的1(0),1(1)与B块1(4),1(5),1(6),1(7)UMMLA运算(此过程会累加上面运算的C中间结果)
• A块的1(2),1(3)与B块1(4),1(5),1(6),1(7)UMMLA运算(此过程会累加上面运算的C中间结果)
• …………
• A块的1(6),1(7)与B块1(12),1(13),1(14),1(15)UMMLA运算(此过程会累加上面运算的C中间结果)
经上面步骤,C矩阵的元素更新为:
- 继续在K维度对A矩阵的块与B矩阵的块进行类似的迭代运算,完成K维度迭代之后,可以得到C矩阵的以下元素的最终结果:
- 以同样的方式计算C矩阵的其他元素。
代码实现如下:
"ld1rqb { z6.b }, p0/Z, [%x[Apanel]]\n"
".inst 0x45c49808 // ummla z8.s, z0.b, z4.b\n"
".inst 0x45c5980b // ummla z11.s, z0.b, z5.b\n"
".inst 0x45c4982e // ummla z14.s, z1.b, z4.b\n"
".inst 0x45c59831 // ummla z17.s, z1.b, z5.b\n"
"ld1b { z3.b }, p0/Z, [x22]\n"
".inst 0x45c49854 // ummla z20.s, z2.b, z4.b\n"
".inst 0x45c59857 // ummla z23.s, z2.b, z5.b\n"
"ld1b { z7.b }, p0/Z, [x22, #1, MUL VL]\n"
".inst 0x45c498da // ummla z26.s, z6.b, z4.b\n"
".inst 0x45c598dd // ummla z29.s, z6.b, z5.b\n"
"ld1b { z4.b }, p0/Z, [x22, #2, MUL VL]\n"
"ld1b { z5.b }, p0/Z, [x22, #3, MUL VL]\n"
".inst 0x45c39809 // ummla z9.s, z0.b, z3.b\n"
"sub x20, x20, #0x2\n"
".inst 0x45c7980c // ummla z12.s, z0.b, z7.b\n"
".inst 0x45c3982f // ummla z15.s, z1.b, z3.b\n"
"cmp x20, #0x2\n"
".inst 0x45c79832 // ummla z18.s, z1.b, z7.b\n"
".inst 0x45c39855 // ummla z21.s, z2.b, z3.b\n"
".inst 0x45c79858 // ummla z24.s, z2.b, z7.b\n"
".inst 0x45c398db // ummla z27.s, z6.b, z3.b\n"
"ld1b { z3.b }, p0/Z, [x22, #4, MUL VL]\n"
".inst 0x45c798de // ummla z30.s, z6.b, z7.b\n"
".inst 0x45c4980a // ummla z10.s, z0.b, z4.b\n"
"ld1b { z7.b }, p0/Z, [x22, #5, MUL VL]\n"
".inst 0x45c5980d // ummla z13.s, z0.b, z5.b\n"
".inst 0x45c49830 // ummla z16.s, z1.b, z4.b\n"
"ld1rqb { z0.b }, p0/Z, [%x[Apanel], #16]\n"
".inst 0x45c59833 // ummla z19.s, z1.b, z5.b\n"
".inst 0x45c49856 // ummla z22.s, z2.b, z4.b\n"
"ld1rqb { z1.b }, p0/Z, [%x[Apanel], #32]\n"
".inst 0x45c59859 // ummla z25.s, z2.b, z5.b\n"
".inst 0x45c498dc // ummla z28.s, z6.b, z4.b\n"
"ld1rqb { z2.b }, p0/Z, [%x[Apanel], #48]\n"
".inst 0x45c598df // ummla z31.s, z6.b, z5.b\n"
"ld1rqb { z6.b }, p0/Z, [%x[Apanel], #64]\n"
"ld1b { z4.b }, p0/Z, [x22, #6, MUL VL]\n"
"ld1b { z5.b }, p0/Z, [x22, #7, MUL VL]\n"
"addvl x22, x22, #16\n"
".inst 0x45c39808 // ummla z8.s, z0.b, z3.b\n"
".inst 0x45c7980b // ummla z11.s, z0.b, z7.b\n"
".inst 0x45c3982e // ummla z14.s, z1.b, z3.b\n"
".inst 0x45c79831 // ummla z17.s, z1.b, z7.b\n"
".inst 0x45c39854 // ummla z20.s, z2.b, z3.b\n"
".inst 0x45c79857 // ummla z23.s, z2.b, z7.b\n"
".inst 0x45c398da // ummla z26.s, z6.b, z3.b\n"
"ld1b { z3.b }, p0/Z, [x22, #-8, MUL VL]\n"
".inst 0x45c798dd // ummla z29.s, z6.b, z7.b\n"
"ld1b { z7.b }, p0/Z, [x22, #-7, MUL VL]\n"
".inst 0x45c49809 // ummla z9.s, z0.b, z4.b\n"
".inst 0x45c5980c // ummla z12.s, z0.b, z5.b\n"
".inst 0x45c4982f // ummla z15.s, z1.b, z4.b\n"
".inst 0x45c59832 // ummla z18.s, z1.b, z5.b\n"
".inst 0x45c49855 // ummla z21.s, z2.b, z4.b\n"
".inst 0x45c59858 // ummla z24.s, z2.b, z5.b\n"
".inst 0x45c498db // ummla z27.s, z6.b, z4.b\n"
"ld1b { z4.b }, p0/Z, [x22, #-6, MUL VL]\n"
".inst 0x45c598de // ummla z30.s, z6.b, z5.b\n"
".inst 0x45c3980a // ummla z10.s, z0.b, z3.b\n"
"ld1b { z5.b }, p0/Z, [x22, #-5, MUL VL]\n"
".inst 0x45c7980d // ummla z13.s, z0.b, z7.b\n"
".inst 0x45c39830 // ummla z16.s, z1.b, z3.b\n"
"ld1rqb { z0.b }, p0/Z, [%x[Apanel], #80]\n"
".inst 0x45c79833 // ummla z19.s, z1.b, z7.b\n"
".inst 0x45c39856 // ummla z22.s, z2.b, z3.b\n"
"ld1rqb { z1.b }, p0/Z, [%x[Apanel], #96]\n"
".inst 0x45c79859 // ummla z25.s, z2.b, z7.b\n"
".inst 0x45c398dc // ummla z28.s, z6.b, z3.b\n"
"ld1rqb { z2.b }, p0/Z, [%x[Apanel], #112]\n"
".inst 0x45c798df // ummla z31.s, z6.b, z7.b\n"
"add %x[Apanel], %x[Apanel], #0x80\n"
"addvl x22, x22, #-4\n"
"bge 3b\n"
可以参见Arm compute library和KleidiAI中的代码:
https://github.com/ARM-softwa...
https://gitlab.arm.com/kleidi...
KleidiAI中的NEON实现来处理了量化的一些操作,会和介绍的实现稍有差别。
Arm Compute Library还包含了比较丰富的量化,较小的K值处理, NEON,等不同的kernel实现。
Int8的Matrix Multiply矩阵乘
FP32的Matrix Multiply(矩阵乘)指令的操作是:
这条指令将第一个SVE2源向量中每128-bit看作2x2 FP32矩阵, 第二个SVE2源向量中每128-bit的2x2 FP32矩阵,然后将第一个SVE2源向量中的2x2矩阵与第二个SVE2源向量中的对应的2x2矩阵进行矩阵乘,生成的2x2 FP32矩阵乘积累加到目标向量中的FP32矩阵累加器中。
A矩阵和B矩阵重排
FP32 A矩阵可以重排为:
FP32 B矩阵可以重排为:
FP32 Matrix Multiply运算
FP32的A矩阵和B矩阵的矩阵乘运算可以采用Int8矩阵乘类似的方式:
- 重排后的矩阵A块的0(0),0(1)与B块0(0),0(1),0(2),0(3)FMMLA运算。其中A块0(0)为[a0_0, a0_1], A块0(1)为[a1_0, a1_1], B块0(0)为[b0_0, b1_0], B块0(1)为[b0_1, b1_1], B块0(2)为[b0_2, b1_2], B块0(3)为[b0_3, b1_3].
- 经上面计算的到C矩阵的下面这些运算的中间值:
后面的迭代计算和Int8类似,不再赘述。
代码如下:
"2:\n"
".inst 0x64a4e408 // fmmla z8.s, z0.s, z4.s\n"
"ld1w z7.s, p0/z, [%[b_ptr], #-1, MUL VL]\n"
".inst 0x64a4e42e // fmmla z14.s, z1.s, z4.s\n"
"ld1rqw z3.s, p0/z, [%[a_ptr], #-0x10]\n"
".inst 0x64a4e454 // fmmla z20.s, z2.s, z4.s\n"
"subs %[loops], %[loops], #0x1\n"
".inst 0x64a5e409 // fmmla z9.s, z0.s, z5.s\n"
".inst 0x64a4e47a // fmmla z26.s, z3.s, z4.s\n"
"ld1w z4.s, p0/z, [%[b_ptr]]\n"
".inst 0x64a5e42f // fmmla z15.s, z1.s, z5.s\n"
".inst 0x64a5e455 // fmmla z21.s, z2.s, z5.s\n"
".inst 0x64a5e47b // fmmla z27.s, z3.s, z5.s\n"
"ld1w z5.s, p0/z, [%[b_ptr], #1, MUL VL]\n"
".inst 0x64a6e40a // fmmla z10.s, z0.s, z6.s\n"
".inst 0x64a6e430 // fmmla z16.s, z1.s, z6.s\n"
".inst 0x64a6e456 // fmmla z22.s, z2.s, z6.s\n"
".inst 0x64a6e47c // fmmla z28.s, z3.s, z6.s\n"
"ld1w z6.s, p0/z, [%[b_ptr], #2, MUL VL]\n"
".inst 0x64a7e40b // fmmla z11.s, z0.s, z7.s\n"
".inst 0x64a7e431 // fmmla z17.s, z1.s, z7.s\n"
".inst 0x64a7e457 // fmmla z23.s, z2.s, z7.s\n"
".inst 0x64a7e47d // fmmla z29.s, z3.s, z7.s\n"
"ld1w z7.s, p0/z, [%[b_ptr], #3, MUL VL]\n"
".inst 0x64a4e40c // fmmla z12.s, z0.s, z4.s\n"
".inst 0x64a4e432 // fmmla z18.s, z1.s, z4.s\n"
".inst 0x64a4e458 // fmmla z24.s, z2.s, z4.s\n"
".inst 0x64a4e47e // fmmla z30.s, z3.s, z4.s\n"
"ld1w z4.s, p0/z, [%[b_ptr], #4, MUL VL]\n"
".inst 0x64a5e40d // fmmla z13.s, z0.s, z5.s\n"
"ld1rqw z0.s, p0/z, [%[a_ptr]]\n"
".inst 0x64a5e433 // fmmla z19.s, z1.s, z5.s\n"
"ld1rqw z1.s, p0/z, [%[a_ptr], #0x10]\n"
".inst 0x64a5e459 // fmmla z25.s, z2.s, z5.s\n"
"ld1rqw z2.s, p0/z, [%[a_ptr], #0x20]\n"
".inst 0x64a5e47f // fmmla z31.s, z3.s, z5.s\n"
"ld1w z5.s, p0/z, [%[b_ptr], #5, MUL VL]\n"
".inst 0x64a6e408 // fmmla z8.s, z0.s, z6.s\n"
"ld1rqw z3.s, p0/z, [%[a_ptr], #0x30]\n"
".inst 0x64a6e42e // fmmla z14.s, z1.s, z6.s\n"
"add %[a_ptr], %[a_ptr], #0x80\n"
".inst 0x64a6e454 // fmmla z20.s, z2.s, z6.s\n"
"addvl %[b_ptr], %[b_ptr], #12\n"
".inst 0x64a6e47a // fmmla z26.s, z3.s, z6.s\n"
".inst 0x64a7e409 // fmmla z9.s, z0.s, z7.s\n"
".inst 0x64a7e42f // fmmla z15.s, z1.s, z7.s\n"
"ld1w z6.s, p0/z, [%[b_ptr], #-6, MUL VL]\n"
".inst 0x64a7e455 // fmmla z21.s, z2.s, z7.s\n"
".inst 0x64a7e47b // fmmla z27.s, z3.s, z7.s\n"
"ld1w z7.s, p0/z, [%[b_ptr], #-5, MUL VL]\n"
".inst 0x64a4e40a // fmmla z10.s, z0.s, z4.s\n"
".inst 0x64a4e430 // fmmla z16.s, z1.s, z4.s\n"
".inst 0x64a4e456 // fmmla z22.s, z2.s, z4.s\n"
".inst 0x64a4e47c // fmmla z28.s, z3.s, z4.s\n"
"ld1w z4.s, p0/z, [%[b_ptr], #-4, MUL VL]\n"
".inst 0x64a5e40b // fmmla z11.s, z0.s, z5.s\n"
".inst 0x64a5e431 // fmmla z17.s, z1.s, z5.s\n"
".inst 0x64a5e457 // fmmla z23.s, z2.s, z5.s\n"
".inst 0x64a5e47d // fmmla z29.s, z3.s, z5.s\n"
"ld1w z5.s, p0/z, [%[b_ptr], #-3, MUL VL]\n"
".inst 0x64a6e40c // fmmla z12.s, z0.s, z6.s\n"
".inst 0x64a6e432 // fmmla z18.s, z1.s, z6.s\n"
".inst 0x64a6e458 // fmmla z24.s, z2.s, z6.s\n"
".inst 0x64a6e47e // fmmla z30.s, z3.s, z6.s\n"
"ld1w z6.s, p0/z, [%[b_ptr], #-2, MUL VL]\n"
".inst 0x64a7e40d // fmmla z13.s, z0.s, z7.s\n"
"ld1rqw z0.s, p0/z, [%[a_ptr], #-0x40]\n"
".inst 0x64a7e433 // fmmla z19.s, z1.s, z7.s\n"
"ld1rqw z1.s, p0/z, [%[a_ptr], #-0x30]\n"
".inst 0x64a7e459 // fmmla z25.s, z2.s, z7.s\n"
"ld1rqw z2.s, p0/z, [%[a_ptr], #-0x20]\n"
".inst 0x64a7e47f // fmmla z31.s, z3.s, z7.s\n"
"b.ne 2b\n
可以参见Arm compute library中的代码:
https://github.com/ARM-softwa...
如何快速重排矩阵?
与前文类似,可以借助arm向量指令的ZIP1和ZIP2交织指令来实现快速矩阵重排。存回最终结果时,需要对放在向量寄存器中的各个2x2结果矩阵进行解交织,以便利用ST1W指令进行连续地址存储,这可以借助UZP1, UZP2解交织指令来完成。
结论
利用SVE2的ZIP交织指令可以快速重排矩阵,Matrix Multiply指令可以快速迭代计算重排之后的矩阵乘结果,相对Dot product,单条Matrix Multiply的乘加的数量是其2倍,也更大程度上利用了加载的数据。