10

修志龙_ZenonXiu · 2024年10月13日 · 上海

如何使用Arm 向量指令加速矩阵乘 (3) – SME2 MOPA (1)

(国庆之后的仅有的一天周末天气不佳,大部分时间用来码文,尽快完成这些arm CPU对AI/ML序列文章... 写文章太累了)

理解本文需要具备SME2构架预备知识,建议先阅读之前文章。
Arm Scalable Matrix Extension介绍
第二部分: Arm Scalable Matrix Extension (SME)指令

本文先介绍如何利用FP32类型SME2 outer product and accumulate指令实现的矩阵乘。

FP32类型的矩阵乘

FP32类型的SME2 Outer Product and Accumulate 指令执行的操作如下:
Picture2.jpg

以SME2 FP32 Matrix Multiply为例,SME2 Streaming SVE向量长度为SVL,A矩阵和B矩阵的矩阵乘(得到C矩阵)可以表达为:
image.png
其中红色标识部分可以使用SME2的一条outer product and accumulate指令实现。由此表达可以看出C矩阵可以由计算一个个C矩阵中SVL x SVL的子矩阵完成。每个SVL x SVL子矩阵可以由K个SME2 outer product and accumulate操作来完成。

SME2 outer product and accumulate实现矩阵乘

为了演示,把SVL假设为128-bit(可以放4个FP32元素的向量)。

  1. 将A矩阵的[a0_0, a1_0, a2_0, a3_0]取出到一个SME2 SVE向量寄存器,假设它为Z__a0。
    将B矩阵的[b0_0, b0_1, b0_2, b0_3]取出到一个SME2 SVE向量寄存器,假设它为Z__b0。
  2. 将Z__a0与Z__b0进行SME2 outer product and accumulate计算,将得到C矩阵中如下元素们的中间值:
    gemm-fp32 mpoa.jpg

图中C矩阵被更新的元素的颜色为A矩阵向量Z__a0(红色)和B矩阵向量Z__b0(蓝色)的混色,表示它们是这两个颜色的向量的SME2 outer product and accumulate运算更新的。

  1. 将A矩阵的[a0_1, a1_1, a2_1, a3_1]取出到一个SME2 SVE向量寄存器,假设它为Z__a1。
    将B矩阵的[b1_0, b1_1, b1_2, b1_3]取出到一个SME2 SVE向量寄存器,假设它为Z__b1。
  2. 将Z__a1与Z__b1进行SME2 outer product and accumulate计算,将更新C矩阵中如下元素们的中间值:
    gemm-fp32 mpoa_2.jpg
  3. 重复以上步骤,在K维度上迭代。完成这些迭代之后,可以得到这些C矩阵元素的最终结果:
    gemm-fp32 mpoa_3.jpg
  4. 再利用同样的方法计算C矩阵中其他子矩阵的结果, 例如:
    gemm-fp32 mpoa_4.jpg

为了,

  • 提高A矩阵和B矩阵这些向量的利用效率
  • 提高ZA tile的利用效率:对于FP32 数据大小(4 bytes),可用的ZA tile是4个:ZA0.S – ZA3.S
  • 将多条 back to back(背靠背)的SME2 outer product and accumulate指令pipeline起来,可以提供CPU硬件执行这些操作的效率

实际上可以做一些优化调整:将4个SME2 outer product and accumulate操作pipe起来。
gemm-fp32 mpoa_5.jpg

  1. (a)计算矩阵A红色Z_a0和矩阵B蓝色Z_b0, 结果存放在SME2 tile ZA0.S
    (b)计算矩阵A红色Z_a0和矩阵B绿色Z_b0, 结果存放在SME2 tile ZA1.S
    (c)计算矩阵A棕色Z_a0和矩阵B蓝色Z_b0, 结果存放在SME2 tile ZA2.S
    (d)计算矩阵A棕色Z_a0和矩阵B绿色Z_b0, 结果存放在SME2 tile ZA4.S
  2. (a)计算矩阵A红色Z_a1和矩阵B蓝色Z_b1, 结果更新在SME2 tile ZA0.S
    (b)计算矩阵A红色Z_a1和矩阵B绿色Z_b1, 结果更新在SME2 tile ZA1.S
    (c)计算矩阵A棕色Z_a1和矩阵B蓝色Z_b1, 结果更新在SME2 tile ZA2.S
    (d)计算矩阵A棕色Z_a1和矩阵B绿色Z_b1, 结果更新在SME2 tile ZA4.S

    1. 在K维度以上面的方式迭代,最终可以计算C矩阵上面那些元素的最终结果。

    A矩阵的重排

    由上面步骤看出,最好对A矩阵进行一些transpose重排,来提供A矩阵的内存访问效率。A矩阵需要重排为:
    gemm-fp32 mpoa_7.jpg

上章节的SME2 outer product and accumulate运算中使用经重排后的A矩阵,可以达到对A的向量访问在内存中地址连续,提供内存访问效率,减少cache miss。

可以利用SME2的On-the-fly matrix transposition功能,它借助ZA tile slice可以水平(行)和垂直(列)访问的特点:
先将原始A矩阵的几行读取到ZA tile的水平slice,然后将这个ZA tile以垂直slice访问,将垂直slice存到重排的A矩阵内存。

  1. (a). 先从原始A矩阵读取SVL/4 (这个SVL=16 byte, 128-bit的例子里,为16/4=4)行到ZA tile (ZAn.S)的水平slice中。
    (b). 然后将这个ZA tile以垂直slice的方式访问,将这些ZA tile slice存到重排的A矩阵内存。
    gemm-fp32 mpoa_8.jpg
  2. (a). 继续读取原始A矩阵这些行的元素到ZA tile (ZAn.S)的水平slice中。
    (b). 然后将这个ZA tile以垂直slice的方式访问,将这些ZA tile slice存到重排的A矩阵内存。
    gemm-fp32 mpoa_9.jpg
  3. 继续这些行的其他元素K维度处理。如果最后剩下一些不足够SVL向量长度(leftover),则可以借助SME2 内存访问和ZA tile slice访问的的predicate功能:
    gemm-fp32 mpoa_10.jpg
  4. 按照1,2,3的方式继续处理下面4行的transpose。
  5. 如果最后M维度的剩下行不足SVL/4 (这个例子为4)行,那么需要在重排的A矩阵中填0处理。可以这样来实现:
    gemm-fp32 mpoa_11.jpg

    上图中最后剩下2行,不够SVL/4 即 4行,那么只需读取剩下的两行到ZA tile的两个水平slice,然后将剩下的水平slice置0,再将这个ZA tile 的垂直slice存到重排的A矩阵内存。

代码实现

SME2 outer product and accumulate运算代码实现可以参考:


  "6:"  // K loop
      ".inst 0x808702c0  // fmopa za0.s, p0/M, p0/M, z22.s, z7.s\n"
      "subs x21, x21, #0x1\n"
      ".inst 0x808f02c1  // fmopa za1.s, p0/M, p0/M, z22.s, z15.s\n"
      ".inst 0x808702e2  // fmopa za2.s, p0/M, p0/M, z23.s, z7.s\n"
      ".inst 0x808f02e3  // fmopa za3.s, p0/M, p0/M, z23.s, z15.s\n"
      ".inst 0xa0404776  // ld1w { z22.s-z23.s }, pn9.b/Z, [x27]\n"
      ".inst 0x809400c0  // fmopa za0.s, p0/M, p0/M, z6.s, z20.s\n"
      ".inst 0xa14046e7  // ld1w { z7.s, z15.s }, pn9.b/Z, [x23]\n"
      ".inst 0x809500c1  // fmopa za1.s, p0/M, p0/M, z6.s, z21.s\n"
      ".inst 0x809401c2  // fmopa za2.s, p0/M, p0/M, z14.s, z20.s\n"
      ".inst 0x809501c3  // fmopa za3.s, p0/M, p0/M, z14.s, z21.s\n"
      ".inst 0xa1414766  // ld1w { z6.s, z14.s }, pn9.b/Z, [x27, #0x2, MUL VL]\n"
      ".inst 0x80830040  // fmopa za0.s, p0/M, p0/M, z2.s, z3.s\n"
      ".inst 0xa04146f4  // ld1w { z20.s-z21.s }, pn9.b/Z, [x23, #0x2, MUL VL]\n"
      ".inst 0x808b0041  // fmopa za1.s, p0/M, p0/M, z2.s, z11.s\n"
      ".inst 0x80830142  // fmopa za2.s, p0/M, p0/M, z10.s, z3.s\n"
      ".inst 0x808b0143  // fmopa za3.s, p0/M, p0/M, z10.s, z11.s\n"
      ".inst 0xa1424762  // ld1w { z2.s, z10.s }, pn9.b/Z, [x27, #0x4, MUL VL]\n"
      ".inst 0xa14246e3  // ld1w { z3.s, z11.s }, pn9.b/Z, [x23, #0x4, MUL VL]\n"
      ".inst 0x80840020  // fmopa za0.s, p0/M, p0/M, z1.s, z4.s\n"
      ".inst 0x80850021  // fmopa za1.s, p0/M, p0/M, z1.s, z5.s\n"
      ".inst 0x80840122  // fmopa za2.s, p0/M, p0/M, z9.s, z4.s\n"
      ".inst 0x80850123  // fmopa za3.s, p0/M, p0/M, z9.s, z5.s\n"
      ".inst 0xa1434761  // ld1w { z1.s, z9.s }, pn9.b/Z, [x27, #0x6, MUL VL]\n"
      "addvl x27, x27, #8\n"
      ".inst 0xa04346e4  // ld1w { z4.s-z5.s }, pn9.b/Z, [x23, #0x6, MUL VL]\n"
      "addvl x23, x23, #8\n"
      "bgt 6b\n"

https://github.com/ARM-softwa...

A矩阵重排代码可以参考:

.Load_loop:
  psel pn10, pn8, p0.s[w12, 0]
  psel pn11, pn8, p0.s[w12, 1]
  psel pn12, pn8, p0.s[w12, 2]
  psel pn13, pn8, p0.s[w12, 3]
  ld1w {z20.s, z28.s}, pn10/z, [x6]         // matLeft, 将原始矩阵的行load到SME2 SVE寄存器
  ld1w {z21.s, z29.s}, pn11/z, [x6, x1, lsl #2] // matLeft + K
  ld1w {z22.s, z30.s}, pn12/z, [x6, x14, lsl #2] // matLeft + K*2
  ld1w {z23.s, z31.s}, pn13/z, [x6, x15, lsl #2] // matLeft + K*3
  mova za0h.s[w12, 0:3], {z20.s-z23.s}     //将上面SVE寄存器的值mov到ZA tile的水平slice
  mova za1h.s[w12, 0:3], {z28.s-z31.s}
  
  add x6, x6, x1, lsl #4 // matLeft+=4*K FP32 elements (bytes)
  add w12, w12, #4 // Increment counter
  cmp w12, w4
  b.mi .Load_loop
  
  mov w12, #0 // Store_loop counter
  
.Store_loop:
  whilelt pn10.b, x9, x10, vlx4
  whilelt pn11.b, x9, x13, vlx4
  mova {z0.s-z3.s}, za0v.s[w12, 0:3]      //将ZA tile的垂直slice移到SVE寄存器
  mova {z4.s-z7.s}, za1v.s[w12, 0:3]
  st1w {z0.s-z3.s}, pn10, [x9] // Store 4 col vectors to matLeft_mod,将上面SVE寄存器值存入重排的矩阵内存
  st1w {z4.s-z7.s}, pn11, [x9, x16, lsl #2] // matLeft_mod+SVLs*SVLs
  addvl x9, x9, #4 // matLeft_mod += 4*SVLb (bytes)
  add w12, w12, #4 // Increment counter
  cmp w12, w4
  b.mi .Store_loop

下文将介绍Int8类型SME2 outer product and accumulate的矩阵乘实现。

Arm Scalable Matrix Extension介绍
第二部分: Arm Scalable Matrix Extension (SME)指令
Arm构架如何让AI应用高效运行于CPU(1)
Arm构架如何让AI应用高效运行于CPU(2)
如何使用Arm向量指令加速矩阵乘(1)-SVE2 Dot Product
如何使用Arm向量指令加速矩阵乘(2)–SVE2 Matrix Multiply

推荐阅读
关注数
8650
内容数
61
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息