理解本文需要具备SME2构架预备知识,建议先阅读之前文章。
本文先介绍如何利用INT8整型类型SME2 outer product and accumulate指令实现的矩阵乘。
INT8类型的矩阵乘
INT8类型的SME2 Outer Product and Accumulate 指令执行的操作如下:
以SME2 Int8 Matrix Multiply为例,SME2 Streaming SVE向量长度为SVL,A矩阵和B矩阵的矩阵乘(得到C矩阵)可以表达为:
其中红色标识部分可以使用SME2的一条Int8 outer product and accumulate指令实现。每条Int8 SME2 outer product and accumulate可以看作执行一个(SVL_in_bytes/4) x 4的矩阵和另一个 4 x (SVL_in_bytes/4) 的矩阵乘,得到一个(SVL_in_bytes/4) x (SVL_in_bytes/4) 大小的int32矩阵。由此表达可以看出C矩阵可以由计算一个个C矩阵中SVL x SVL的子矩阵((SVL_in_bytes/4) x (SVL_in_bytes/4) 大小的int32矩阵 )完成。每个SVL x SVL子矩阵可以由K/4个SME2 outer product and accumulate操作来完成。
SME2 outer product and accumulate实现矩阵乘
为了演示,把SVL假设为128-bit(可以放16个Int8元素的向量)。
将A矩阵的
[ [a0_0, a0_1, a0_2, a0_3], [a1_0, a1_1, a1_2, a1_3], [a2_0, a2_1, a2_2, a2_3], [a3_0, a3_1, a3_2, a3_3] ]
取出到一个SME2 SVE向量寄存器,假设它为Z__a0。
将B矩阵的[ [b0_0, b1_0, b2_0, b3_0], [b0_1, b1_1, b2_1, b3_1] , [b0_2, b1_2, b2_2, b3_2] , [b0_3, b1_3, b2_3, b3_3] ]
,取出到一个SME2 SVE向量寄存器,假设它为Z__b0。
- 将Z__a0与Z__b0进行SME2 Int8 outer product and accumulate计算,将得到C矩阵中如下元素们的中间值:
图中C矩阵被更新的元素的颜色为A矩阵向量Z__a0(红色)和B矩阵向量Z__b0(蓝色)的混色,表示它们是这两个颜色的向量的SME2 outer product and accumulate运算更新的。注意得到的C矩阵的每个元素类型为Int32。
将A矩阵的
[ [a0_4, a0_5, a0_6, a0_7], [a1_4, a1_5, a1_6, a1_7], [a2_4, a2_5, a2_6, a2_7], [a3_4, a3_5, a3_6, a3_7] ]
取出到一个SME2 SVE向量寄存器,假设它为Z__a1。
将B矩阵[ [b4_0, b5_0, b6_0, b7_0], [b4_1, b5_1, b6_1, b7_1] , [b4_2, b5_2, b6_2, b7_2] , [b4_3, b5_3, b6_3, b7_3] ]
取出到一个SME2 SVE向量寄存器,假设它为Z__b1。
- 将Z__a1与Z__b1进行SME2 outer product and accumulate计算,将更新C矩阵中如下元素们的中间值:
- 重复以上步骤,在K维度上迭代:
完成这些迭代之后,可以得到这些C矩阵元素的最终结果:
- 再利用同样的方法计算C矩阵中其他子矩阵的结果。·
为了,
• 提高A矩阵和B矩阵这些向量的利用效率
• 提高ZA tile的利用效率:如之前FP32 SME2 outer product and accumulate类似,可以同时使用4个ZA tile:ZA0.S – ZA3.S
• 将多条 back to back(背靠背)的SME2 outer product and accumulate指令pipeline起来,可以提供CPU硬件执行这些操作的效率
实际上可以做一些优化调整:将4个SME2 outer product and accumulate操作pipe起来。
- (a)计算矩阵A红色Z_a0和矩阵B蓝色Z_b0, 结果存放在SME2 tile ZA0.S
(b)计算矩阵A红色Z_a0和矩阵B绿色Z_b5, 结果存放在SME2 tile ZA1.S
(c)计算矩阵A棕色Z_a5和矩阵B蓝色Z_b0, 结果存放在SME2 tile ZA2.S
(d)计算矩阵A棕色Z_a5和矩阵B绿色Z_b5, 结果存放在SME2 tile ZA4.S
- (a)计算矩阵A红色Z_a1和矩阵B蓝色Z_b1, 结果更新在SME2 tile ZA0.S
(b)计算矩阵A红色Z_a1和矩阵B绿色Z_b6, 结果更新在SME2 tile ZA1.S
(c)计算矩阵A棕色Z_a6和矩阵B蓝色Z_b1, 结果更新在SME2 tile ZA2.S
(d)计算矩阵A棕色Z_a6和矩阵B绿色Z_b6, 结果更新在SME2 tile ZA4.S
- 在K维度以上面的方式迭代,最终可以计算C矩阵上面那些元素的最终结果。
A矩阵和B矩阵的重排
由上面步骤看出,最好对A矩阵和B矩阵进行一些transpose重排,来提高矩阵的内存访问效率。矩阵需要重排为:
上章节的SME2 outer product and accumulate运算中使用经重排后的A矩阵和B矩阵,可以达到对A和B的向量访问在内存中地址连续,提供内存访问效率,减少cache miss。
对于A矩阵的重排,可以使用FP32 outer product and accumulate矩阵操作一样的对A矩阵重排的方法,只需把4个水平的Int8元素组合看作是一个32-bit的FP32元素,这样可以利用SME2 的tile来实现on-fly transpose的快速重排。
而对于B矩阵的重排,可以利用SME2的Multi-vector的ZIP指令来实现:
ZIP {Z4.B-Z7.B}, {Z0.B-Z3.B}
它可以建Z0.B-Z3.B中的元素
Transpose为
SME2 Int8 outer product and accumulate运算代码实现
SME2 Int8 outer product and accumulate运算代码实现可以参考:
"6:" // K loop
".inst 0xa08200c0 // smopa za0.s, p0/M, p0/M, z6.b, z2.b\n"
"subs x21, x21, #0x1\n"
".inst 0xa08300c1 // smopa za1.s, p0/M, p0/M, z6.b, z3.b\n"
".inst 0xa08201c2 // smopa za2.s, p0/M, p0/M, z14.b, z2.b\n"
".inst 0xa08301c3 // smopa za3.s, p0/M, p0/M, z14.b, z3.b\n"
".inst 0xa1400766 // ld1b { z6.b, z14.b }, pn9.b/Z, [x27]\n"
".inst 0xa0960340 // smopa za0.s, p0/M, p0/M, z26.b, z22.b\n"
".inst 0xa04006e2 // ld1b { z2.b-z3.b }, pn9.b/Z, [x23]\n"
".inst 0xa0970341 // smopa za1.s, p0/M, p0/M, z26.b, z23.b\n"
".inst 0xa0960362 // smopa za2.s, p0/M, p0/M, z27.b, z22.b\n"
".inst 0xa0970363 // smopa za3.s, p0/M, p0/M, z27.b, z23.b\n"
".inst 0xa041077a // ld1b { z26.b-z27.b }, pn9.b/Z, [x27, #0x2, MUL VL]\n"
".inst 0xa09500a0 // smopa za0.s, p0/M, p0/M, z5.b, z21.b\n"
".inst 0xa04106f6 // ld1b { z22.b-z23.b }, pn9.b/Z, [x23, #0x2, MUL VL]\n"
".inst 0xa09d00a1 // smopa za1.s, p0/M, p0/M, z5.b, z29.b\n"
".inst 0xa09501a2 // smopa za2.s, p0/M, p0/M, z13.b, z21.b\n"
".inst 0xa09d01a3 // smopa za3.s, p0/M, p0/M, z13.b, z29.b\n"
".inst 0xa1420765 // ld1b { z5.b, z13.b }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
".inst 0xa14206f5 // ld1b { z21.b, z29.b }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
".inst 0xa0910000 // smopa za0.s, p0/M, p0/M, z0.b, z17.b\n"
".inst 0xa0990001 // smopa za1.s, p0/M, p0/M, z0.b, z25.b\n"
".inst 0xa0910022 // smopa za2.s, p0/M, p0/M, z1.b, z17.b\n"
".inst 0xa0990023 // smopa za3.s, p0/M, p0/M, z1.b, z25.b\n"
".inst 0xa0430760 // ld1b { z0.b-z1.b }, pn9.b/Z, [x27, #0x6, MUL VL]\n"
"addvl x27, x27, #8\n"
".inst 0xa14306f1 // ld1b { z17.b, z25.b }, pn9.b/Z, [x23, #0x6, MUL VL]\n"
"addvl x23, x23, #8\n"
"bgt 6b\n"
https://github.com/ARM-softwa...
结语
其他数据类型的矩阵乘可以参照FP32, Int8的SME2 Int8 outer product and accumulate方式来实现。通过SME2 Int8 outer product and accumulate的快速outer product运算,SME2 on-fly transpose和SME2 multi-vector运算,可以很大程度上提高矩阵乘运算的速度和效率。
Arm Scalable Matrix Extension介绍
第二部分: Arm Scalable Matrix Extension (SME)指令
Arm构架如何让AI应用高效运行于CPU(1)
Arm构架如何让AI应用高效运行于CPU(2)
如何使用Arm向量指令加速矩阵乘(1)-SVE2 Dot Product
如何使用Arm向量指令加速矩阵乘(2)–SVE2 Matrix Multiply
如何使用Arm 向量指令加速矩阵乘 (3) – SME2 MOPA (1)