SME2在SME的基础上,通过加入multi-vector(多向量)支持更好地平衡之前的向量计算和矩阵乘计算加速,提高向量处理能力和矩阵乘运算逻辑的重用性。
SME2也加入了压缩机器学习的数据格式的支持
- 通过支持查找表指令,可以快速地将压缩的2b或4b的权重转换为INT8, BF16, FP16或FP32格式
- 支持1-bit二进制网络的外积计算: C = C + popcount32(~(a ^ bT ))
SME2 Multi-vector指令
SME2引入了支持multi-vector操作数和multi-vector predication的数据处理指令,multi-vector可以利用ZA storage作为操作对象, multi-vector指令可以重用ZA Array vector和outer product中的运算逻辑,例如:
SME2 Multi-vector 可以是:
由一组2个或4个SVE2 Z 寄存器组成的Z multi-vector
它是连续或是跨步(stride)编号的寄存器,例如{ Z0.S-Z1.S }, { Z0.H, Z4.H, Z8.H, Z12.H }
由一组2个或4个ZA tile slice 组成的ZA multi-slice
它是连续编号的水平(horizontal)或垂直(vertical)的ZA tile slice
- 两个ZA tile slice的例子: ZA0H.B[w12, 0 : 1] 是由ZA0H.B[w12+0] 和ZA0H.B[w12+1]这两个水平tile slice组成的ZA multi-slice。ZA0H.B[w12, 2 : 3] 是由ZA0H.B[w12+2] 和ZA0H.B[w12+3]这两个水平tile slice组成的ZA multi-slice。
- 四个ZA tile slice的例子: ZA0V.B[w12, 0 : 3]是由ZA0V.B[w12+0] ,ZA0V.B[w12+1], ZA0V.B[w12+2] ,ZA0V.B[w12+3]这四个垂直tile slice组成的ZA multi-slice
由一组2个,4个,8个或16个ZA array vector组成的ZA multi-vector,它可以是
- 2组 ZA single-vector(Two ZA single-vector groups), 4组ZA single-vector(Four ZA single-vector groups)
- 1组ZA double-vector(One ZA double-vector group), 2组ZA double-vector(Two ZA double-vector groups), 4组ZA double-vector (Four ZA double-vector groups)
- 1组ZA quad-vector(One ZA quad-vector group), 2组ZA quad-vector(Two ZA quad-vector groups), 4组ZA quad-vector(Four ZA quad-vector groups).
ZA multi-vector相对复杂一点,这里展开简绍一下。
ZA multi-vector
首先,什么是single-vector,double-vector,quad-vector?
• 一个single-vector是由一个ZA Array vector组成。
• 一个double-vector是由2个编号连续的ZA Array vector组成。
• 一个quad-vector是由4个编号连续的ZA Array vector组成。
可以由多个single-vector,double-vector,quad-vector组成groups。但是这些single-vector groups, double-vector groups,quad-vector groups在ZA array里却不是连续的,而是均匀分布在ZA array里。
Two single-vector groups 的表达方式为ZA.T[Wv, imm, VGx2],其中的T表示数据类型,可以是B/H/S/D, Wv是用来存ZA vector index的基W寄存器,imm是表示ZA vector index的offset,VGx2表示它是2 xxx-vector groups。
Four single-vector groups 的表达方式为ZA.T[Wv, imm, VGx4],其中的T表示数据类型,可以是B/H/S/D, Wv是用来存ZA vector index的基W寄存器,imm是表示ZA vector index的offset,VGx4表示它是4 xxx-vector groups.
对于Two single-vector groups,它将整个ZA Array分成两部分.
访问一个Two single-vector groups,ZA.T[Wv, imm, VGx2],其实是访问ZA [n]和ZA [n + SVL_B/2] 这两个ZA Array vector,其中SVL_B为SME2的SVL向量长度中的byte数(例如如果SVL为512-bit,那么SVL_B为512/8=64),n 为(Wv+imm) % (SVL_B/2),即n为(Wv+imm)与(SVL_B/2)的模数。例如,假设W8=1, imm=1,SVL_B=64,那么ZA.B[W8, #1,VGx2]访问ZA[2]和ZA[34].
对于Four single-vector groups,它将整个ZA Array分成四部分.
访问一个Four single-vector groups,ZA.T[Wv, imm, VGx4],其实是访问ZA [n],ZA[n + SVL_B/4],ZA[n + SVL_B/2]和ZA[n + SVL_B x 3/4]这四个ZA Array vector,其中n 为(Wv+imm) % (SVL_B/4),即n为(Wv+imm)与(SVL_B/4)的模数。例如,假设W8=1, imm=1,SVL_B=64,那么ZA.B[W8, #1,VGx4]访问ZA[2],ZA[18],ZA[34]和ZA[50].
Two double-vector groups 的表达方式为ZA.T[Wv, imml:immh, VGx2],其中的T表示数据类型,可以是B/H/S/D, Wv是用来存ZA vector index的基W寄存器,imml:immh是连续的两个数字(一偶一奇,第一个数字是2的整数倍,第二个数字是第一个数字+1,如0:1, 2:3)是表示ZA vector index的offset,VGx2表示它是4 xxx-vector groups。
Four double -vector groups 的表达方式为ZA.T[Wv, imml:immh, VGx4],其中的T表示数据类型,可以是B/H/S/D, Wv是用来存ZA vector index的基W寄存器,imml:immh是连续的两个数字(一偶一奇,第一个数字是2的整数倍,第二个数字是第一个数字+1,如0:1, 2:3)是表示ZA vector index的offset,VGx4表示它是4 xxx-vector groups.
对于Two double-vector groups,它将整个ZA Array分成两部分.
访问一个Two double-vector groups,ZA.T[Wv, imml:immh, VGx2],其实是访问ZA [n], ZA [n+1], ZA [n + SVL_B/2] , ZA [n + SVL_B/2+1]这四个ZA Array vector,其中n 为(Wv+imm) % (SVL_B/2),即n为(Wv+imm)与(SVL_B/2)的模数。例如,假设W8=1, imml:immh =2:3,SVL_B=64,那么ZA.B[W8, 2:3,VGx2]访问ZA[3],ZA[4], ZA[35] , ZA[36].
对于Four double-vector groups,它将整个ZA Array分成四部分:
访问一个Four double-vector groups,ZA.T[Wv, imml:immh, VGx4],其实是访问ZA[n], ZA[n+1], ZA[n + SVL_B/4], ZA [n + SVL_B/4 +1], ZA[n + SVL_B/2], ZA [n + SVL_B/2 +1], ZA[n + SVL_B x 3/4], ZA [n + SVL_B x 3/4 +1]这8个ZA Array vector,其中n 为(Wv+imm) % (SVL_B/4),即n为(Wv+imm)与(SVL_B/4)的模数。例如,假设W8=1, imm=2,SVL_B=64,那么ZA.B[W8, 2:3,VGx4]访问ZA[3],ZA[4],ZA[19],ZA[20],ZA[35],ZA[36],ZA[51],ZA[52].
对于One ZA quad-vector group,Two ZA quad-vector groups, Four ZA quad-vector groups 可以以此类推。
Multi-vector指令的例子
Multi-vector MOV指令可以移动2个或4个SVE2 Z寄存器到ZA multi-vector groups,或者移动ZA multi-vector groups到2个或4个SVE2 Z寄存器。例如,MOVA ZA.D[w8, #0, VGx2], { Z0.D-Z1.D }
这条指令将两个SVE2 Z寄存器移动到一个 Two Single-vector groups.
假设W8=0,那么它的操作为:
Multi-vector加,乘,乘累加,点积,乘累加并扩大结果类型大小(包括FMLA, FADD, FDOT, BFDOT, [SU]DOT, ADD这些指令)可以操作2个或4个SVE2 Z寄存器,将结果存入ZA multi-vector。例如,FMLAL ZA.S[W8, 0:1, VGx2], { Z0.H-Z1.H }, Z2.H
假设W8=0, SVL_B=64那么它的操作为:
它将Z0寄存器中奇数的元素与Z2寄存器中奇数的元素进行乘运算并扩到结果类型,将结果放在Z[0], 将Z0寄存器中偶数的元素与Z2寄存器中奇数的偶素进行乘运算并扩到结果类型,将结果放在Z[1],将Z1寄存器中奇数的元素与Z2寄存器中奇数的元素进行乘运算并扩到结果类型,将结果放在Z[32], 将Z1寄存器中偶数的元素与Z2寄存器中奇数的偶素进行乘运算并扩到结果类型,将结果放在Z[33].
SME2 Multi-vector predication
因为SME2 multi-vector指令支持多个vector,要在一条指令中像Non-streaming SVE mode里面对vector中的每个元素都单独predicate会有点困难,SME2 Multi-vector指令中引入了一种新的predicate方式:predicate-as-count。之前的方式称之为predicate-as-mask。 这种新的方式在SVE2 P寄存器中存放一个counter(计数器),这个counter计数连续有效(Active)或无效(Inactive)元素的数量(从第 0元素 开始)。
一个P寄存器可以用为predicate-as-mask,这时在指令中用Pm来表示(例如P0, P1. ..), 用作为predicate-as-count,这时在指令中用PNg来表示(例如PN0, PN1. ..)。实际上Pg即PNg (如P0即PN0)。
PNg的编码格式为:
只使用PN寄存器的16-bit。
它包含一个‘size’ field,它指示element size或是表示All-FALSE。 它可以占1到4 bit。
紧接着的是‘count‘ field, 如果‘size’不是0b0000, 那么‘count’ field表示从第0个元素开始的连续有效(Active)或无效(Inactive)元素的数量,根据element size和SVL长度,可以决定‘count‘ field的在PN寄存器中的有效bit数。按照最长的SVE长度2048-bit和element size为Byte,count最多只需要8-bit。
PN寄存器中的第15bit为‘invert’ bit。如果invert bit为0,表示‘count‘中计数的是第0个元素开始的连续有效(Active)的数量。如果invert bit为1,表示‘count‘中计数的是第0个元素开始的连续有效(Inactive)的数量。
All-True的编码可以由invert=1, size=0b0000 (All-FALSE)来表示,可以理解为将All-FALSE invert就变为All-True,即PNg=0x8000为All-True, PNg=0x0000为All-FALSE。
举两个例子:
SME2 Multi-vector指令扩展了和增加了一些指令,用于生成PNg寄存器,操作,使用PNg寄存器。
新的WHILE ,PTRUE指令可以生成PNg。PEXT 指令可以将“predicate-as-count” PNg转化为 “predicate-as-mask“ Pn寄存器。
例如,WHILELT <PNd>.<T>, <Xn>, <Xm>, <vl>
其中vl可以是VLx2 或是VLx4, 用于指示这条指令PNd中的count应该包含2个vector还是4个vector中连续有效(Active)或无效(Inactive)元素的数量。其中WHILELT中的LT表示little than (<), 这条指令比较Xn和Xm,结合vl, SVL长度和Elem_Size来生成PNd.
例如WHILELT PN0.T, X0, X1, VLx2,
其中PN0.T的T (Type)可以是B,H,S,D... 。下面以SVL_T来代表SVL/Elem_Size, 即一个vector长度(1xSVL)可以放T类型元素个数。
如果X1>= (X0+2xSVL_T) ,其中的2表示两个vector:
因为WHILELT中的LT(little than)条件, 那么生成的PN0为All-TRUE。
如果X0>= X1:
因为WHILELT中的LT(little than)条件, 那么生成的PN0为All-FALSE。
其他情况:
那么PN0.invert=0, PN0.count=(X1-X0), PN0.Elem_Size=Elem_Size。
Multi-vector LDR/STR 指令可以使用生成的PN寄存器进行predicate, 例如:LD1D { Z0.D– Z3.D }, PN8/Z, [X0]
假设PN8.Invert=0, Count=26, Elem_Size=Double, 那么这条指令从内存中读取26个double word到Z0.D-Z3.D。Z3.D中高double word对应的predication为False,不需要从内存读取而是被指令直接清零。
下面以实现以下函数为例:
void simple_add(long long *x, unsigned long n){
unsigned long i;
for(i=0;i<n;i++)
x[i]= x[i]+1;
}
通过一般SVE2实现的代码为:
void simple_add(long long *x, unsigned long n)
{
unsigned long i;
asm (" mov z1.d, #1 \n"
"whilelo p0.d, %[i], %[n] \n"
"1: \n"
"ld1d z0.d, p0/z, [%[x], %[i], lsl #3] \n"
"add z0.d, p0/m, z0.d, z1.d \n"
"st1d z0.d, p0, [%[x], %[i], lsl #3] \n"
"uqincd %[i] \n"
"whilelo p0.d, %[i], %[n] \n"
"b.any 1b "
: [i] "=&r" (i)
: "[i]" (0), [x] "r" (x), [n] "r" (n)
: "memory", "cc", "p0", "z0", "z1");
}
而通过SME2multi-vector的实现为:
void simple_add_multi_vec(long long *x, unsigned long n)
{
unsigned long i;
__asm__ ("smstart \n");
asm (" mov z1.d, #1 \n"
"whilelo pn8.d, %[i], %[n], vlx4 \n"
"1: \n"
"ld1d {z4.d-z7.d}, pn8/z, [%[x], %[i], lsl #3] \n"
"add {z4.d-z7.d}, {z4.d-z7.d}, z1.d \n"
"st1d {z4.d-z7.d}, pn8, [%[x], %[i], lsl #3] \n"
"inch %[i] \n"
"whilelo pn8.d, %[i], %[n], vlx4 \n"
"b.any 1b "
: [i] "=&r" (i)
: "[i]" (0), [x] "r" (x), [n] "r" (n)
: "memory", "cc", "p0", "z0", "z1","pn8");
__asm__ ("smstop \n");
}
SME2 LUTI2, LUTI4查表指令
SME2引入了LUTI2, LUTI4指令用于查表检索,实现通过查表快速完成2b, 4b类型的数据转换为8b, 16b, 32b类型。SME2新引入了一个固定长度的512-bit寄存器,ZT0, 用于存放查找表。
有两种操作类型:
- LUTI2:源寄存器中的索引为2b,对应到ZT0寄存器中的4个元素。元素始终位于32位的位置中,并且可以缩减为8位或16位。
- LUTI4:源寄存器中的索引为4位,对应到到ZT0寄存器中的16个元素。
LUTI2和LUTI4指令可以生成1个、2个或4个寄存器的结果。
不是源寄存器中所有的index都可以访问。假如SVL为512-bit, 一条产生4个包含32-bit元素的目的寄存器的LUTI2指令可以输出64个元素(每个寄存器16个元素)。因为index是2-bit, 所以只使用输入源寄存器的128-bit。这个指令有额外的index,用来指定使用输入寄存器的哪个段(例如是[127:0], [255:128]还是其他128-bit)。上面指令LUTI2 Zd.B, {ZT0}, Zn[part]
中的part即是用来指定哪个段的。
LUTIx指令可以支持单目的寄存器,两目的寄存器和四目的寄存器的形式。
下面给出几个例子。
LUTI2单寄存器
LUTI2 Zd.T, {ZT0}, Zn[part]
例如 LUTI2 Z1.B, ZT0, Z0[2] 假设 SVL=128-bit,其实现的操作为:
对于例子中SVL=128-bit和数据类型为B,要填满一个SVL长度的目的寄存器需要的检索数目为16 (128除以8),每条检索为2b,这使每个segment在Zn中的bit数为16x2=32b,指令中的'part'用来指示用Zn的哪个segment。在这个例子里面part为2。
另一个LUTI2单寄存器的例子:
LUTI2 Z1.H, ZT0, Z0[1] 假设 SVL=128-bit,其实现的操作为:
对于例子中SVL=128-bit和数据类型为H,要填满一个SVL长度的目的寄存器需要的检索数目为8 (128除以16),每条检索为2b,这使每个segment在Zn中的bit数为8x2=16b,指令中的'part'用来指示用Zn的哪个segment。在这个例子里面part为1。
LUTI4单寄存器
LUTI4 Zd.T, {ZT0}, Zn[part]
例如 LUTI4 Z1.B, ZT0, Z0[1] 假设 SVL=128-bit,其实现的操作为:
对于例子中SVL=128-bit和数据类型为B,要填满一个SVL长度的目的寄存器需要的检索数目为16 (128除以8),每条检索为4b,这使每个segment在Zn中的bit数为16x4=64b,指令中的'part'用来指示用Zn的哪个segment。在这个例子里面part为1。
另一个LUTI4单寄存器的例子:
LUTI4 Z1.H, ZT0, Z0[1] 假设 SVL=128-bit,其实现的操作为:
对于例子中SVL=128-bit和数据类型为H,要填满一个SVL长度的目的寄存器需要的检索数目为8 (128除以16),每条检索为4b,这使每个segment在Zn中的bit数为8x4=32b,指令中的'part'用来指示用Zn的哪个segment。在这个例子里面part为1。
LUTIx 多寄存器
LUTIx 多寄存器支持使用‘part‘和不使用'part'的形式,多目的寄存器可以是连续或者步长(stride)编号的寄存器。
以LUTI2四寄存器为例,
连续寄存器的不使用'part' LUTI2形式为:
LUTI2 { <Zd1>.B-<Zd4>.B }, ZT0, { <Zn1>-<Zn2> }, 以LUTI2 {Z0.B-Z3.B}, ZT0, {Z8-Z9} 为例,假设 SVL=128-bit,其实现的操作为:
连续寄存器的使用'part' LUTI4形式为:
LUTI4 { <Zd1>.<T>-<Zd4>.<T> }, ZT0, <Zn>[<part>]以LUTI4 {Z4.S-Z7.S}, ZT0, Z0[1] 为例,假设 SVL=128-bit,其实现的操作为:
Arm Scalable Matrix Extension介绍
第二部分: Arm Scalable Matrix Extension (SME)指令