修志龙_ZenonXiu · 2024年10月07日 · 上海

如何使用Arm 向量指令加速矩阵乘 (2) – SVE2 Matrix Multiply

本文以SVE2 Matrix Multiply指令实现Int8和FP32矩阵乘加速。

Int8的Matrix Multiply矩阵乘

Int8的Matrix Multiply(矩阵乘)指令的操作是:
Picture1.png

这条指令将第一个SVE2源向量中每128-bit看作2x8有符号8位整数矩阵, 第二个SVE2源向量中每128-bit的8x2有符号8位整数矩阵,然后将第一个SVE2源向量中的2x8矩阵与第二个SVE2源向量中的对应的8x2矩阵进行矩阵乘,生成的2x2 32位整数矩阵乘积累加到目标向量中的32位整数矩阵累加器中。
使用SVE2 Int8 Matrix Multiply进行矩阵乘基本和SVE2 Int8 Dot product类似。
以SVE2 Int8 Matrix Multiply为例,假设矩阵的行数和列数可以与向量长度VL匹配,A矩阵和B矩阵的矩阵乘(得到C矩阵)可以表达为:

image.png

其中标识出来的部分的功能可以使用SVE2 Matrix Multiply来实现,这条指令的第一个向量中的其他128-bit部分都重复其第一个128-bit部分。为了优化内存访问和数据重用,我们进行back to back的Matrix Multiply, 如果CPU的硬件pipeline实现支持同时多条Matrix Multiply指令的执行,这样可以更好的利用这些硬件资源。
我们以进行3个vector back to back的SVE2 Matrix Multiply,VL=256-bit 为例,其过程为:

A矩阵和B矩阵重排

首先,需要对A矩阵和B矩阵进行分割处理(为了内存访问更友好,需要对A,B矩阵在内存中进行重排):
A矩阵可以处理为:
gemm-sve_interleaved_u8u32_mmla_3vlx8_1.jpg
A矩阵块0由
[ a0_0, a0_1 , a0_2, a0_3, a0_4, a0_5 , a0_6, a0_7,
a1_0, a1_1 , a1_2, a1_3, a1_4, a1_5 , a1_6, a1_7,
…….
a7_0, a7_1 , a7_2, a7_3, a7_4, a7_5 , a7_6, a7_7
]。块1,2,3, 4以此类推。A矩阵以这样的方式在K维度继续分块,为了简略,图中没有画出所有的块。

B矩阵可以处理为:
gemm-sve_interleaved_u8u32_mmla_3vlx8_2.jpg

B矩阵块0由
[
b0_0, b1_0, b2_0, b3_0, b4_0, b5_0, b6_0, b7_0,
b0_1, b1_1, b2_1, b3_1, b4_1, b5_1, b6_1, b7_1,

...

b0_15, b1_15, b2_15, b3_15, b4_15, b5_15, b6_15,b7_15
]
这些元素(总共为3 x VL byte)组成。B矩阵以这样的方式在K维度继续分块,为了简略,图中没有画出所有的块。

Int 8 Matrix Multiply运算

接下来进行运算:

  1. 将经重排过的A矩阵的块0(0),0(1) load到一个SVE2向量寄存器的每一128-bit,(经重排之后的A矩阵的块在内存中会地址连续存储),可以使用LD1RQB指令
    LD1RQB { <Zt>.B }, <Pg>/Z, [<Xn|SP>, <Xm>]
    这条指令从内存中load 16 byte(128-bit)到一个SVE2向量寄存器的每一个128-bit段(重复这16 byte),如下图中vector中红色部分,0(0)表示A 矩阵块0的第0个8 byte(即 [a0_0, a0_1 , a0_2, a0_3, a0_4, a0_5 , a0_6, a0_7]), 0(1)表示A 矩阵块0的第1个8 byte(即 [a1_0, a1_1 , a1_2, a1_3, a1_4, a1_5 , a1_6, a1_7]),以此类推。

    gemm-sve_interleaved_u8u32_mmla_3vlx8_6.jpg

  2. 将经重排过的B矩阵的块0(0),0(1),0(2),0(3)到一个SVE2向量寄存器,(经重排之后的A矩阵的块在内存中会地址连续存储),可以使用LD1B指令
    LD1B { <Zt>.B }, <Pg>/Z, [<Xn|SP>{, #<imm>, MUL VL}]
    这条指令从内存中load VL byte到一个SVE2向量寄存器,如下图中vector中蓝色部分,0(0)表示B 矩阵块0的第0个8 byte(即 [b0_0, b1_0, b2_0, b3_0, b4_0, b5_0, b6_0, b7_0]), 0(1)表示B 矩阵块0的第1个8 byte(即 [b0_1, b1_1, b2_1, b3_1, b4_1, b5_1, b6_1, b7_1]),以此类推。
  3. 然后使用SVE2 Matrix Multiply指令,
    UMMLA <Zda>.S, <Zn>.B, <Zm>.B
    来执行如下图所示的操作:
    gemm-sve_interleaved_u8u32_mmla_3vlx8_3.jpg
    这会计算出8个32-bit结果(VL=256-bit),这些结果是如下图所示的C矩阵的这些元素的部分结果(中间结果,不是最终结果):
    gemm-sve_interleaved_u8u32_mmla_3vlx8_4.jpg
  4. 再利用1,2,3类似的方式计算:
    gemm-sve_interleaved_u8u32_mmla_3vlx8_4_1.jpg

得到如下图所示的C矩阵的这些元素的部分结果(中间结果,不是最终结果)。
gemm-sve_interleaved_u8u32_mmla_3vlx8_5.jpg

重复这样的过程,计算:
• A块的0(4), 0(5)与B块0(0),0(1),0(2),0(3)做UMMLA运算
gemm-sve_interleaved_u8u32_mmla_3vlx8_7.jpg
• A块的0(6),0(7) 与B块0(0),0(1),0(2),0(3)做UMMLA运算
gemm-sve_interleaved_u8u32_mmla_3vlx8_8.jpg

经过以上步骤,会计算出C矩阵以下元素的中间结果:
gemm-sve_interleaved_u8u32_mmla_3vlx8_9.jpg

  1. 计算:
    • A块的0(0),0(1)与B块0(4),0(5),0(6),0(7)做UMMLA运算
    • A块的0(2),0(3)与B块0(4),0(5),0(6),0(7)做UMMLA运算
    • A块的0(4),0(5)与B块0(4),0(5),0(6),0(7)做UMMLA运算
    • A块的0(6),0(7)与B块0(4),0(5),0(6),0(7)做UMMLA运算
    • A块的0(0),0(1)与B块0(8),0(9),0(10),0(11)做UMMLA运算
    • A块的0(2),0(3)与B块0(8),0(9),0(10),0(11)做UMMLA运算
    • ………
    • A块的0(6),0(7)与B块0(12),0(13),0(14),0(15)做UMMLA运算

经过以上步骤,会计算出C矩阵以下元素的中间结果:
gemm-sve_interleaved_u8u32_mmla_3vlx8_10.jpg

  1. 计算:
    • A块的1(0),1(1)与B块1(0),1(1),1(2),1(3)UMMLA运算(此过程会累加上面运算的C中间结果)
    gemm-sve_interleaved_u8u32_mmla_3vlx8_11.jpg

这会将C矩阵中的以下元素更新为:
gemm-sve_interleaved_u8u32_mmla_3vlx8_12.jpg

• A块的1(2),1(3)与B块1(0),1(1),1(2),1(3)UMMLA运算(此过程会累加上面运算的C中间结果)
• A块的1(4),1(5)与B块1(0),1(1),1(2),1(3)UMMLA运算(此过程会累加上面运算的C中间结果)
• A块的1(6),1(7)与B块1(0),1(1),1(2),1(3)UMMLA运算(此过程会累加上面运算的C中间结果)

• A块的1(0),1(1)与B块1(4),1(5),1(6),1(7)UMMLA运算(此过程会累加上面运算的C中间结果)
• A块的1(2),1(3)与B块1(4),1(5),1(6),1(7)UMMLA运算(此过程会累加上面运算的C中间结果)
• …………
• A块的1(6),1(7)与B块1(12),1(13),1(14),1(15)UMMLA运算(此过程会累加上面运算的C中间结果)

经上面步骤,C矩阵的元素更新为:
gemm-sve_interleaved_u8u32_mmla_3vlx8_13.jpg

  1. 继续在K维度对A矩阵的块与B矩阵的块进行类似的迭代运算,完成K维度迭代之后,可以得到C矩阵的以下元素的最终结果:
    gemm-sve_interleaved_u8u32_mmla_3vlx8_14.jpg
  2. 以同样的方式计算C矩阵的其他元素。

代码实现如下:

      "ld1rqb { z6.b }, p0/Z, [%x[Apanel]]\n"
      ".inst 0x45c49808  // ummla z8.s, z0.b, z4.b\n"
      ".inst 0x45c5980b  // ummla z11.s, z0.b, z5.b\n"
      ".inst 0x45c4982e  // ummla z14.s, z1.b, z4.b\n"
      ".inst 0x45c59831  // ummla z17.s, z1.b, z5.b\n"
      "ld1b { z3.b }, p0/Z, [x22]\n"
      ".inst 0x45c49854  // ummla z20.s, z2.b, z4.b\n"
      ".inst 0x45c59857  // ummla z23.s, z2.b, z5.b\n"
      "ld1b { z7.b }, p0/Z, [x22, #1, MUL VL]\n"
      ".inst 0x45c498da  // ummla z26.s, z6.b, z4.b\n"
      ".inst 0x45c598dd  // ummla z29.s, z6.b, z5.b\n"
      "ld1b { z4.b }, p0/Z, [x22, #2, MUL VL]\n"
      "ld1b { z5.b }, p0/Z, [x22, #3, MUL VL]\n"
      ".inst 0x45c39809  // ummla z9.s, z0.b, z3.b\n"
      "sub x20, x20, #0x2\n"
      ".inst 0x45c7980c  // ummla z12.s, z0.b, z7.b\n"
      ".inst 0x45c3982f  // ummla z15.s, z1.b, z3.b\n"
      "cmp x20, #0x2\n"
      ".inst 0x45c79832  // ummla z18.s, z1.b, z7.b\n"
      ".inst 0x45c39855  // ummla z21.s, z2.b, z3.b\n"
      ".inst 0x45c79858  // ummla z24.s, z2.b, z7.b\n"
      ".inst 0x45c398db  // ummla z27.s, z6.b, z3.b\n"
      "ld1b { z3.b }, p0/Z, [x22, #4, MUL VL]\n"
      ".inst 0x45c798de  // ummla z30.s, z6.b, z7.b\n"
      ".inst 0x45c4980a  // ummla z10.s, z0.b, z4.b\n"
      "ld1b { z7.b }, p0/Z, [x22, #5, MUL VL]\n"
      ".inst 0x45c5980d  // ummla z13.s, z0.b, z5.b\n"
      ".inst 0x45c49830  // ummla z16.s, z1.b, z4.b\n"
      "ld1rqb { z0.b }, p0/Z, [%x[Apanel], #16]\n"
      ".inst 0x45c59833  // ummla z19.s, z1.b, z5.b\n"
      ".inst 0x45c49856  // ummla z22.s, z2.b, z4.b\n"
      "ld1rqb { z1.b }, p0/Z, [%x[Apanel], #32]\n"
      ".inst 0x45c59859  // ummla z25.s, z2.b, z5.b\n"
      ".inst 0x45c498dc  // ummla z28.s, z6.b, z4.b\n"
      "ld1rqb { z2.b }, p0/Z, [%x[Apanel], #48]\n"
      ".inst 0x45c598df  // ummla z31.s, z6.b, z5.b\n"
      "ld1rqb { z6.b }, p0/Z, [%x[Apanel], #64]\n"
      "ld1b { z4.b }, p0/Z, [x22, #6, MUL VL]\n"
      "ld1b { z5.b }, p0/Z, [x22, #7, MUL VL]\n"
      "addvl x22, x22, #16\n"
      ".inst 0x45c39808  // ummla z8.s, z0.b, z3.b\n"
      ".inst 0x45c7980b  // ummla z11.s, z0.b, z7.b\n"
      ".inst 0x45c3982e  // ummla z14.s, z1.b, z3.b\n"
      ".inst 0x45c79831  // ummla z17.s, z1.b, z7.b\n"
      ".inst 0x45c39854  // ummla z20.s, z2.b, z3.b\n"
      ".inst 0x45c79857  // ummla z23.s, z2.b, z7.b\n"
      ".inst 0x45c398da  // ummla z26.s, z6.b, z3.b\n"
      "ld1b { z3.b }, p0/Z, [x22, #-8, MUL VL]\n"
      ".inst 0x45c798dd  // ummla z29.s, z6.b, z7.b\n"
      "ld1b { z7.b }, p0/Z, [x22, #-7, MUL VL]\n"
      ".inst 0x45c49809  // ummla z9.s, z0.b, z4.b\n"
      ".inst 0x45c5980c  // ummla z12.s, z0.b, z5.b\n"
      ".inst 0x45c4982f  // ummla z15.s, z1.b, z4.b\n"
      ".inst 0x45c59832  // ummla z18.s, z1.b, z5.b\n"
      ".inst 0x45c49855  // ummla z21.s, z2.b, z4.b\n"
      ".inst 0x45c59858  // ummla z24.s, z2.b, z5.b\n"
      ".inst 0x45c498db  // ummla z27.s, z6.b, z4.b\n"
      "ld1b { z4.b }, p0/Z, [x22, #-6, MUL VL]\n"
      ".inst 0x45c598de  // ummla z30.s, z6.b, z5.b\n"
      ".inst 0x45c3980a  // ummla z10.s, z0.b, z3.b\n"
      "ld1b { z5.b }, p0/Z, [x22, #-5, MUL VL]\n"
      ".inst 0x45c7980d  // ummla z13.s, z0.b, z7.b\n"
      ".inst 0x45c39830  // ummla z16.s, z1.b, z3.b\n"
      "ld1rqb { z0.b }, p0/Z, [%x[Apanel], #80]\n"
      ".inst 0x45c79833  // ummla z19.s, z1.b, z7.b\n"
      ".inst 0x45c39856  // ummla z22.s, z2.b, z3.b\n"
      "ld1rqb { z1.b }, p0/Z, [%x[Apanel], #96]\n"
      ".inst 0x45c79859  // ummla z25.s, z2.b, z7.b\n"
      ".inst 0x45c398dc  // ummla z28.s, z6.b, z3.b\n"
      "ld1rqb { z2.b }, p0/Z, [%x[Apanel], #112]\n"
      ".inst 0x45c798df  // ummla z31.s, z6.b, z7.b\n"
      "add %x[Apanel], %x[Apanel], #0x80\n"
      "addvl x22, x22, #-4\n"
      "bge 3b\n"

可以参见Arm compute library和KleidiAI中的代码:

https://github.com/ARM-softwa...

https://gitlab.arm.com/kleidi...

KleidiAI中的NEON实现来处理了量化的一些操作,会和介绍的实现稍有差别。
Arm Compute Library还包含了比较丰富的量化,较小的K值处理, NEON,等不同的kernel实现。

Int8的Matrix Multiply矩阵乘

FP32的Matrix Multiply(矩阵乘)指令的操作是:
Picture1.png

这条指令将第一个SVE2源向量中每128-bit看作2x2 FP32矩阵, 第二个SVE2源向量中每128-bit的2x2 FP32矩阵,然后将第一个SVE2源向量中的2x2矩阵与第二个SVE2源向量中的对应的2x2矩阵进行矩阵乘,生成的2x2 FP32矩阵乘积累加到目标向量中的FP32矩阵累加器中。

A矩阵和B矩阵重排

FP32 A矩阵可以重排为:
gemm-gemm_fp32_mmla_1.jpg
FP32 B矩阵可以重排为:
gemm-gemm_fp32_mmla_2.jpg

FP32 Matrix Multiply运算

FP32的A矩阵和B矩阵的矩阵乘运算可以采用Int8矩阵乘类似的方式:

  1. 重排后的矩阵A块的0(0),0(1)与B块0(0),0(1),0(2),0(3)FMMLA运算。其中A块0(0)为[a0_0, a0_1], A块0(1)为[a1_0, a1_1], B块0(0)为[b0_0, b1_0], B块0(1)为[b0_1, b1_1], B块0(2)为[b0_2, b1_2], B块0(3)为[b0_3, b1_3].
    gemm-gemm_fp32_mmla_3.jpg
  2. 经上面计算的到C矩阵的下面这些运算的中间值:
    gemm-gemm_fp32_mmla_4.jpg

后面的迭代计算和Int8类似,不再赘述。

代码如下:

                "2:\n"
                ".inst 0x64a4e408 // fmmla z8.s, z0.s, z4.s\n"
                "ld1w z7.s, p0/z, [%[b_ptr], #-1, MUL VL]\n"
                ".inst 0x64a4e42e // fmmla z14.s, z1.s, z4.s\n"
                "ld1rqw z3.s, p0/z, [%[a_ptr], #-0x10]\n"
                ".inst 0x64a4e454 // fmmla z20.s, z2.s, z4.s\n"
                "subs %[loops], %[loops], #0x1\n"
                ".inst 0x64a5e409 // fmmla z9.s, z0.s, z5.s\n"
                ".inst 0x64a4e47a // fmmla z26.s, z3.s, z4.s\n"
                "ld1w z4.s, p0/z, [%[b_ptr]]\n"
                ".inst 0x64a5e42f // fmmla z15.s, z1.s, z5.s\n"
                ".inst 0x64a5e455 // fmmla z21.s, z2.s, z5.s\n"
                ".inst 0x64a5e47b // fmmla z27.s, z3.s, z5.s\n"
                "ld1w z5.s, p0/z, [%[b_ptr], #1, MUL VL]\n"
                ".inst 0x64a6e40a // fmmla z10.s, z0.s, z6.s\n"
                ".inst 0x64a6e430 // fmmla z16.s, z1.s, z6.s\n"
                ".inst 0x64a6e456 // fmmla z22.s, z2.s, z6.s\n"
                ".inst 0x64a6e47c // fmmla z28.s, z3.s, z6.s\n"
                "ld1w z6.s, p0/z, [%[b_ptr], #2, MUL VL]\n"
                ".inst 0x64a7e40b // fmmla z11.s, z0.s, z7.s\n"
                ".inst 0x64a7e431 // fmmla z17.s, z1.s, z7.s\n"
                ".inst 0x64a7e457 // fmmla z23.s, z2.s, z7.s\n"
                ".inst 0x64a7e47d // fmmla z29.s, z3.s, z7.s\n"
                "ld1w z7.s, p0/z, [%[b_ptr], #3, MUL VL]\n"
                ".inst 0x64a4e40c // fmmla z12.s, z0.s, z4.s\n"
                ".inst 0x64a4e432 // fmmla z18.s, z1.s, z4.s\n"
                ".inst 0x64a4e458 // fmmla z24.s, z2.s, z4.s\n"
                ".inst 0x64a4e47e // fmmla z30.s, z3.s, z4.s\n"
                "ld1w z4.s, p0/z, [%[b_ptr], #4, MUL VL]\n"
                ".inst 0x64a5e40d // fmmla z13.s, z0.s, z5.s\n"
                "ld1rqw z0.s, p0/z, [%[a_ptr]]\n"
                ".inst 0x64a5e433 // fmmla z19.s, z1.s, z5.s\n"
                "ld1rqw z1.s, p0/z, [%[a_ptr], #0x10]\n"
                ".inst 0x64a5e459 // fmmla z25.s, z2.s, z5.s\n"
                "ld1rqw z2.s, p0/z, [%[a_ptr], #0x20]\n"
                ".inst 0x64a5e47f // fmmla z31.s, z3.s, z5.s\n"
                "ld1w z5.s, p0/z, [%[b_ptr], #5, MUL VL]\n"
                ".inst 0x64a6e408 // fmmla z8.s, z0.s, z6.s\n"
                "ld1rqw z3.s, p0/z, [%[a_ptr], #0x30]\n"
                ".inst 0x64a6e42e // fmmla z14.s, z1.s, z6.s\n"
                "add %[a_ptr], %[a_ptr], #0x80\n"
                ".inst 0x64a6e454 // fmmla z20.s, z2.s, z6.s\n"
                "addvl %[b_ptr], %[b_ptr], #12\n"
                ".inst 0x64a6e47a // fmmla z26.s, z3.s, z6.s\n"
                ".inst 0x64a7e409 // fmmla z9.s, z0.s, z7.s\n"
                ".inst 0x64a7e42f // fmmla z15.s, z1.s, z7.s\n"
                "ld1w z6.s, p0/z, [%[b_ptr], #-6, MUL VL]\n"
                ".inst 0x64a7e455 // fmmla z21.s, z2.s, z7.s\n"
                ".inst 0x64a7e47b // fmmla z27.s, z3.s, z7.s\n"
                "ld1w z7.s, p0/z, [%[b_ptr], #-5, MUL VL]\n"
                ".inst 0x64a4e40a // fmmla z10.s, z0.s, z4.s\n"
                ".inst 0x64a4e430 // fmmla z16.s, z1.s, z4.s\n"
                ".inst 0x64a4e456 // fmmla z22.s, z2.s, z4.s\n"
                ".inst 0x64a4e47c // fmmla z28.s, z3.s, z4.s\n"
                "ld1w z4.s, p0/z, [%[b_ptr], #-4, MUL VL]\n"
                ".inst 0x64a5e40b // fmmla z11.s, z0.s, z5.s\n"
                ".inst 0x64a5e431 // fmmla z17.s, z1.s, z5.s\n"
                ".inst 0x64a5e457 // fmmla z23.s, z2.s, z5.s\n"
                ".inst 0x64a5e47d // fmmla z29.s, z3.s, z5.s\n"
                "ld1w z5.s, p0/z, [%[b_ptr], #-3, MUL VL]\n"
                ".inst 0x64a6e40c // fmmla z12.s, z0.s, z6.s\n"
                ".inst 0x64a6e432 // fmmla z18.s, z1.s, z6.s\n"
                ".inst 0x64a6e458 // fmmla z24.s, z2.s, z6.s\n"
                ".inst 0x64a6e47e // fmmla z30.s, z3.s, z6.s\n"
                "ld1w z6.s, p0/z, [%[b_ptr], #-2, MUL VL]\n"
                ".inst 0x64a7e40d // fmmla z13.s, z0.s, z7.s\n"
                "ld1rqw z0.s, p0/z, [%[a_ptr], #-0x40]\n"
                ".inst 0x64a7e433 // fmmla z19.s, z1.s, z7.s\n"
                "ld1rqw z1.s, p0/z, [%[a_ptr], #-0x30]\n"
                ".inst 0x64a7e459 // fmmla z25.s, z2.s, z7.s\n"
                "ld1rqw z2.s, p0/z, [%[a_ptr], #-0x20]\n"
                ".inst 0x64a7e47f // fmmla z31.s, z3.s, z7.s\n"
                "b.ne 2b\n

可以参见Arm compute library中的代码:

https://github.com/ARM-softwa...

如何快速重排矩阵?

与前文类似,可以借助arm向量指令的ZIP1和ZIP2交织指令来实现快速矩阵重排。存回最终结果时,需要对放在向量寄存器中的各个2x2结果矩阵进行解交织,以便利用ST1W指令进行连续地址存储,这可以借助UZP1, UZP2解交织指令来完成。

结论

利用SVE2的ZIP交织指令可以快速重排矩阵,Matrix Multiply指令可以快速迭代计算重排之后的矩阵乘结果,相对Dot product,单条Matrix Multiply的乘加的数量是其2倍,也更大程度上利用了加载的数据。

推荐阅读
关注数
8650
内容数
61
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息