梁德澎 · 2月26日

详解Winograd变换矩阵生成原理

本文想把有关Winograd这个算法背后所涉及到的数学知识用比较通俗的方式给读者描述一遍,并且在这的过程中也会添加一些我个人的理解。
作者:梁德澎
首发知乎:https://zhuanlan.zhihu.com/p/102351953

0、前言

其实网上已经有不少从数学原理的角度去解说Winograd[1,2,3,4,5,6,10]这个算法的文章了,为什么我还要写这篇文章。

主要是在看完许多相关的文章之后,对于Winograd这个算法背后的数学原理我还是没法完全理解,尤其是Winograd的变换矩阵究竟是如何生成的。而在查阅到的资料里面,在描述到一些相关数学定理的时候,许多细节部分都没有很详细的说明,只能通过额外去查找资料和手推公式来理解。

这也是促成我写这篇文章的主要原因,想把有关Winograd这个算法背后所涉及到的数学知识用比较通俗的方式给读者描述一遍,并且在这的过程中也会添加一些我个人的理解,当然我的理解也不一定正确,如果有误也请读者指出。

总的来说感觉Winograd这个算法真的很巧妙,要理解这个算法,需要懂得前置数学知识挺多的,如果其中一个地方没弄懂,都会对理解这个算法的数学原理造成困难。而且即使已经看懂了整体部分,但是很多细节部分如果仔细去想就会觉得自己还没有完全弄懂。

这里我把收集到的所有相关资料链接都统一放到文末参考资料里面,也方便读者去查阅。

1、卷积与多项式乘法

1.1、Convolution和Correlation的区别

首先卷积其实有两个含义[8,9]:

第一个是指一般数学意义上的的两个离散序列的卷积(Convolution);

第二个是深度学习中所用到的卷积(操作上更像Correlation而不是Convolution);

通俗来说两个离散序列在做Convolution操作的时候,首先需要将其中一个序列做镜像翻转,然后两个序列相向移动,从开始第一个元素重合到最后一个元素重合为止,相向移动步长为1,每次把重合的部分做点乘累加的到新的元素,最后生成新的序列。

而两个离散序列在做Correlation操作的时候,除了不需要翻转序列,操作上和Convolution一致。而深度学习中的Convolution其实和Correlation很像,但是又不完全一样。

不一样的地方在于,计算第一个元素的时候是直接把较短的序列和较长的序列从左首元素开始对齐,然后较短的序列按步长向右移动(假设步长也是1),一直到当前较短序列最右边的元素和较长序列最右边的元素对齐,要求短序列的每个元素都必须要和长序列的元素有重合,然后每次把重合的部分点乘累加得到新的元素。

从这里开始,下文提到的Correlation操作都是指深度学习中的卷积操作。

所以给定同样长度的两个序列,分别做Convolution和Correlation操作得到的结果序列长度是不一样的。假设两个长度分别是 n 和 k 的序列(n>=k),分别做Convolution和Correlation操作(步长都为1)得到的结果序列长度的计算公式分别为:

l_{Convolution}=n+k-1
l_{Correlation}=n-k+1

下面简单画个示意图解释的两者的区别:
1.jpg

左边的Convolution操作还有最后两步没有画出来,不过这已经能足够解释两者的区别了。

然后来看下两者的输出元素个数计算公式之间的联系,比如给定两个序列长度分别为 n 和 k ,Convolution操作的到的序列长度为 n+k-1 ,然后和 n 或者 k 长度的序列做 Correlation 可以得到长度为 k 或 n 的序列。

还有一点要提下,就是Winograd这个算法发明出来其实是用来加速Convolution操作的,所以计算变换矩阵也是从Convolution角度去计算,而计算出来的变换矩阵在做一点小变动之后,也可以直接应用在深度学习的Correlation操作中,这个在下文会讲到。

为什么提这个是因为,之前我在理解Winograd这个算法的是陷入了一个误区,一直是从深度学习卷积(Correlation)应用的这个角度去理解这个算法,然后一直想不明白,后来换成是从Convolution角度去理解很多地方就豁然开朗了。

1.2、卷积与多项式乘法的关系

Convolution操作其实直观上等价于多项式乘法操作[7]。

还是用上面举的例子来说明,假设两个离散序列 [3,1,5,4] 和 [7,9] ,我们可以把这两个序列看成是两个多项式 3+1x+5x²+4x³ 和 7+9x 。这两个多项式相乘的结果是:
21+34x+44x^2+73x^3+36x^4
把其系数从低次冥到高次冥排列,刚好就等价于这两个序列做 Convolution 的结果 [21,34,44,73,36] 。

从多项式相乘结果的最高次冥也可以看出两者的联系,假设两个离散序列长度分别为n 和 k,则对应的多项式的最高次冥分别为 x^{n-1}和 x^{k-1} ,则这两个多项式相乘结果的最高次冥为  x^{n+k-2}  ,再加上一个最低次冥 x^{0} 总共就是 n+k-1个元素,算上系数0的次冥。

下面再举一个例子,看下多项式乘法和Convolution操作的关系:
2.jpg

2、理解Winograd算法需要的数学理论知识

2.1、欧几里得算法

我们先来看下可以求解两个整数最大公约数的欧几里得算法[11,13],可能换成“辗转相除法”这个名字读者会更加熟悉。

我们知道两个整数 a 和 b 的公约数是既能整除 a 又能整除 b 的整数,而两者的最大公约数 g 就是这些数里面最大的数,通常用数学公式表示为 g=GCD(a,b),GCD是Greatest common divisor 的缩写。

简单复习下整除的定义: m(m>0) 整除 n 可以记作 m|n ,即存在一个整数 k 使得km=n 。

然后如果GCD(a,b)=1 ,则称 a 和 b互素,表示它们的最大公约数为1。而两个整数是否互素和它们本身是否是素数无关。

简单复习下素数[12]的定义:素数(Prime number)又称质数,指在大于1的自然数中,除了1和该数自身外,无法被其他自然数整除的数。

引用[11]的一个例子来解释最大公因数的概念,假设现在 a>b ,那么设一个长方形的高是 a ,宽是 b ,因为 a 和 b 的任何公约数 c 都可以整除a 和 b,所以长方形的高和宽都可以等分为长度是 c 的线段。通俗的说法是也就是长方形的内部可以刚好被边长是 c 的正方形填满。而最大公约数 g 是其中最大的一个正方形的边长,下面画个简单的示意图来说明:
1.jpg

接着我们看下如何用欧几里得算法求解最大公约数,先给出欧几里得算法的定义[11,13]。

GCD(a,b)=a, b=0
GCD(a,b)=GCD(b,a % b ), b>0

其中 a%b 也可写作 a mod b ,也就是求 a 除以 b 的余数,上面公式还有个条件就是 a>b ,第一个公式很好理解,因为任何整数都能整除0。

这里首先引入同余式的概念:若正整数 a 和 b 分别对 p 取模的余数相同,则可以记作 a≡b(mod p) ,也就是 a 和 b 模 p 同余。

再继续证明欧几里得算法第二个公式之前,先来看一下求模运算的一些运算规则[15]:

(a + b) % p = (a % p + b % p) % p
(a - b) % p = (a % p - b % p) % p
(a * b) % p = (a % p * b % p) % p

然后对于第二条公式的证明引用自[13],更多细节可以参考资料[11,13,14]:

a除以b 的商是p 余数是 q ,则可以表示为:
 a=b * p + q

考虑到 b 和 q 的最大公约数 GCD(b,q) 可知:

b % GCD(b,q)=0
q % GCD(b,q)=0



然后根据运算规则和上面的两个公式可得:

a%GCD(b,q)
=(b* p+q)% GCD(b,q)
=(b* p%GCD(b,q)+q%GCD(b,q)%GCD(b,q)
=((b% GCD(b,q)* p% GCD(b,q))% GCD(b,q)+q%GCD(b,q))%GCD(b,q)
=((0* p% GCD(b,q))% GCD(b,q)+0)%GCD(b,q)
=(0%GCD(b,q))%GCD(b,q)
=0

所以 b 和 q 的最大公约数可以整除 a ,也即GCD(b,q)|a ,所以 GCD(b,q) 能同时整除 b 、q 和 a ,所以b 和 q 的最大公约数也是 a和b的公约数,也即GCD(b,q) 也可以整除 GCD(a,b) ,也就是GCD(b,q)|GCD(a,b)。

然后对等式做下变换:

q=a-b* p

然后考虑 a 和 b 的最大公约数 GCD(a,b), 可知:

a%GCD(a,b)=0
b%GCD(a,b)=0

接着同样根据取模运算规则和上面的两个公式可得:

q%GCD(a,b)
=(a-b * p)% GCD(a,b)
=(a% GCD(a,b)-b*p% GCD(a,b))% GCD(a,b)
=(a% GCD(a,b)-(b% GCD(a,b)*p% GCD(a,b))% GCD(a,b)) % GCD(a,b)
=(0-(0 * p % GCD(a,b)) % GCD(a,b)) % GCD(a,b)
=0

所以同理可得GCD(a,b) 也可以整除GCD(b,q) ,也就是
GCD(a,b)|GCD(b,q)

所以通过上面的推导可得:

GCD(a,b)=GCD(b,q)= GCD(b, a% b), b>0

然后就可以根据这个求解最大公约数的递归式来写实现代码了:

#include <iostream>

int GCD(int m, int n) {
    int t = 1;
    while(t != 0) {
        t=m%n;
        m=n;
        n=t;
    }
    return m;
}

int GCDRecursive(int m, int n) {
    if (n == 0) return m;
    return GCDRecursive(n, m % n);
}

int main(int argc, char *argv[]) {
    if (argc != 3) {
        return 0;
    }

    int a = atoi(argv[1]);
    int b = atoi(argv[2]);

    if (a < b) {
        int t = a;
        a = b;
        b = t;
    }

    int gcd = GCD(a, b);

    printf("GCD(%d, %d)=%d\n", a, b, gcd);

    gcd = GCDRecursive(a, b);

    printf("GCD(%d, %d)=%d\n", a, b, gcd);

    return 0;
}

运行结果:
1.jpg

2.2、多项式的欧几里得算法

欧几里得算法也可以推广到多项式上,和整数最大公因数类似的多项式上也有最大公因式的概念,一样也有整除和求余的概念。所以求解两个多项式的最大公因式一样也可以应用欧几里得算法[18]。

首先下面通过3个例子来说明多项式除法[16,17]是如何u操作的,这里引用资料[17]对多项式除法规则的定义:

  • 首先把被除式的第一项除以除式的第一项;
  • 然后把除式乘以上面除法得到的结果,然后写在被除式下面;
  • 两者相减得到新的多项式;
  • 然后重复上面3个步骤,重复的时候用得到的新多项式作为被除式,一直相减到多项式的最高次小于除式的最高次或者得到0就停止;

例子一、整除的情况,x²-3x-10 除以 x+2 :
1.jpg

WX20200226-155332.png
1.jpg
WX20200226-155332.png
1.jpg
WX20200226-155332.png

WX20200226-155332.png

但是因式分解看起来就很难用代码实现,而欧几里得算法用代码来实现就容易多了。

2.3、扩展欧几里得算法

在介绍扩展欧几里得算法之前先来看下“裴蜀等式”,下面引用wikipedia上的解释[19]:

在数论中,裴蜀等式或裴蜀定理是一个关于最大公约数(或最大公因式)的定理。说明了对任何整数 a 、b 和 m ,关于未知数 x 和 y 的方程:

ax+by=m

有整数解时当且仅当 m是 a 和 b的最大公约数GCD(a,b)的倍数,也就是要求 GCD(a,b)|m 。裴蜀等式有解时必然有无穷多个整数解,每组解 x、y 都称为裴蜀数,可用扩展欧几里得算法求解。

比如,12和42的最大公约数是6,则方程 12x+46y=6 。事实上有:

$$ (-3) * 12+1 * 42=6 \\\\ 4 * 12+(-1) * 42=6 $$

特别来说,ax+by=1 有整数解当且仅当a 和 b 互素,即 GCD(a,b)=1 。

证明过程有兴趣的读者可以参考[19]。

接着来看下如何用扩展欧几里得[13,20]算法求解裴蜀等式,简单来说扩展欧几里德算法是对欧几里德算法的扩展,它可以用来求解形如 ax+by=c(a,b,c∈Z)的方程的一组整数解。我们可以从欧几里德算法的等式来实现扩展欧几里得算法:

我们先来看下方程 ax+by=GCD(a,b)的边界情况,当b=0的时候,方程可化为 ax=GCD(a,0) ,然后根据最大公约数的性质可知GCD(a,0)=a ,所以可以解得 x=1,y=0 。

然后对于一般情况,也是应用最大公因数的性质GCD(a,b)=GCD(b,a mod b) ,首先设  a'=b,b'=a mod b ,然后同样有方程

a'x'+b'y'=GCD(a',b')=GCD(b,a mod b)

联合需要求解的方程ax+by=GCD(a,b)可得

ax+by=GCD(a,b)=GCD(b,a mod b)=GCD(a',b')=a'x'+b'y'
1.png

下面看下实现代码:

#include <iostream>

int exgcd(int a, int b, int &x, int &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
       int gcd = exgcd(b, a % b, x, y);
       int t = x;
       x = y;
       y = t - a / b * x;
    return gcd;
}

int main(int argc, char *argv[]) {
    if (argc != 3) {
        return 0;
    }

    int a = atoi(argv[1]);
    int b = atoi(argv[2]);

    if (a < b) {
        int t = a;
        a = b;
        b = t;
    }
    int x, y;
    int gcd = exgcd(a, b, x, y);

    printf("%d * %d + %d * %d = %d\n", a, x, b, y, gcd);

    return 0;
}

运行结果:
1.jpg

所以扩展欧几里得算法可以同时求出ax+by=GCD(a,b) 方程的解和最大公因数GCD(a,b) 。

2.4、多项式的扩展欧几里得算法

1.png

1.png

2.5、乘法模逆元

WX20200226-155332.png

代码:

#include <iostream>

int exgcd(int a, int b, int &x, int &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
       int gcd = exgcd(b, a % b, x, y);
       int t = x;
       x = y;
       y = t - a / b * x;
    return gcd;
}

int reverse_unit(int a, int b) {
    int x, y;
    int gcd = exgcd(a, b, x, y);
    if (gcd != 1) {
         printf("reverse unit does not exist.\n");
        return -1;
    }
    return (x % b + b) % b;
}

int main(int argc, char *argv[]) {
    if (argc != 3) {
        return 0;
    }

    int a = atoi(argv[1]);
    int b = atoi(argv[2]);

    int reverse= reverse_unit(a, b);

    if (reverse != -1)
        printf("%d * %d = 1 (mod %d) \n", a, reverse, b);

    return 0;
}

运行结果:
1.jpg

2.6、多项式乘法模逆元

同理也可以应用扩展欧几里得算法求解多项式模的逆元,下面直接举例进行说明。

WX20200226-161623.png

下面给出扩展欧几里得算法每一步的计算过程:

WX20200226-155332.png
验证下
WX20200226-161623.png

2.7、中国剩余定理

有了前面知识点的铺垫,理解中国剩余定理[24,25,26]就容易多了。文章[24]对中国剩余定理的解释非常透彻,下面对中国剩余定理的解释大部分是参考这篇文章。推荐对数学感兴趣读者可以关注该专栏,都是和数学相关的内容。

这里先来看下“孙子算经”[27]里面的第二十六题,原文如下:

今有物,不知其數。三三數之,賸二;五五數之,賸三;七七數之,賸二。
問:物幾何?
答曰:二十三。

術曰:

三三數之,賸二,置一百四十;
五五數之,賸三,置六十三;
七七數之,賸二 ,置三十。
并之,得二百三十三,以二百一十減之,即得。

凡三三數之,賸一,則置七十;
五五數之,賸一,則置二十一;
七七數之,賸一,則置十五。
一百六以上,以一百五減之,即得。

用通俗的语言描述第二十六题就是:

现在有一个整数,该整数除以3余2、除以5余3、除以7余2,求该整数是多少?
答案是:23

解法:

除以3余2,加140;
除以5余3,加63;
除以7余2,加30;
求和140+63+30=233,再减去210,就得到23。

只要是除以3余1,就加70;
只要是除以5余1,就加21;
只要是除以7余1,就加15;
然后累加,如果超过了106就减去105就得到结果了。

首先把这个问题转化为一个求解同余方程组的问题,然后对这个问题的解法就称为中国剩余定理:
WX20200226-161941.png
WX20200226-155332.png
1.png
WX20200226-162209.png

刚好对应了原文“三三數之,賸二,置一百四十;”这一句。

WX20200226-162257.png

刚好对应了原文“五五數之,賸三,置六十三;”这一句。

WX20200226-162328.png

刚好对应原文“七七數之,賸二 ,置三十。”这一句。

然后求得:

WX20200226-155332.png

WX20200226-164431.png

刚好对应了原文“并之,得二百三十三,以二百一十減之,即得。”

WX20200226-155332.png

WX20200226-155332.png

这就是中国剩余定理,如果弄懂了上面孙子算经的题目,应该就很容易理解这个求解公式了。

1.png

代码:

#include <iostream>

int exgcd(int a, int b, int &x, int &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
       int gcd = exgcd(b, a % b, x, y);
       int t = x;
       x = y;
       y = t - a / b * x;
    return gcd;
}

int get_crt(int *a, int *m, int len) {
    int r, y;

    int N = 1;
    for (int i = 0; i < len; ++i) {
        N *= m[i];
    }
    int X = 0;

    for(int i=0; i<len; ++i) {
        int Mi = N / m[i];
        int gcd = exgcd(Mi, m[i], r, y);
        X += a[i] * Mi * r;
    }
    
    return X % N;
}

int main(int argc, char *argv[]) {
    int m[3] = {3, 5, 7};
    int a[3] = {2, 3, 2};

    int X = get_crt(a, m, 3);

    printf("crt = %d \n", X);

    return 0;
}

运行结果:
1.png

2.8、多项式的中国剩余定理

类似的中国剩余定理同样可以应用到多项式上,下面参考[28]给出多项式版本的中国剩余定理的定义:

WX20200226-155332.png

WX20200226-155332.png
WX20200226-155332.png
WX20200226-170246.png

3、多项式的中国剩余定理的应用

3.1、卷积操作与中国剩余定理的联系

终于到了本文最重点的部分了,在开始看本节之前确保已经理解了前面提到的数学知识。通过前面的介绍我们已经知道了卷积操作等价于多项式乘法,下面简要描述下卷积是怎么和中国剩余定理的产生联系的,这也是我理解的Winograd这个算法的核心。

需要注意的是下面的一些结论是我根据实际例子比如F(2,3)和F(4,3)推导得到的结论不一定正确。

WX20200226-155332.png

我们先有个概念就是Winograd是一个构造式的算法,是人为去构造一个计算 s(x) 的等价变换,下面介绍如何构造。

WX20200226-171722.png
WX20200226-171812.png
接着根据取模运算法则有

WX20200226-171918.png

WX20200226-171928.png

WX20200226-172122.png

再套用到具体情况比如2x3, 4x3卷积的时候,如果变换之后等式右边的所需的乘法次数小于h(x)p(x)的乘法次数就能达到加速的目的。

3.2、Winograd F(2,3)变换矩阵推导

现在来看下具体到F(2,3)的变换矩阵是如何得到的。

WX20200226-172336.png

WX20200226-172422.png
WX20200226-172508.png
WX20200226-172538.png
WX20200226-172605.png
WX20200226-172635.png
WX20200226-172726.png
WX20200226-172758.png

所以就是4次乘法和9次加法,除2的操作的开销可以在实际应用的时候把除2操作放到权值变换那里,就可以把运行时的开销去掉了。可以看到比原来的6次乘法和2次加法,少两次乘法,但是加法次数变多了。

WX20200226-172850.png
WX20200226-172918.png
WX20200226-173008.png

然后来看下这个变换是如何应用到深度学习中的卷积(Correlation)里面的,对于F(2,3)的应用,是用在1x3或者3x1卷积里面,长度是3的卷积核连续卷积两次得到两个输出,输入序列长度是4,刚好是把Winograd的变换矩阵反着来用的,为了和上面的公式对应,这里用s,p,h 分别表示,输入,权值和输出:
WX20200226-173205.png

然后验证下
WX20200226-173242.png

结果与直接做Correlation一致。

其实这里有一点没想明白的地方是,卷积操作中的Winograd变换公式是如何变成用在Correlation中变换公式的,直接推导的话推不出来,感觉中间还缺了一环,但是确实结论是正确的,实际推导结果也正确。

终于写完了,真的是第一次写那么长的博客,而且公式也比较多,如果有哪里写的不对或者公式错误的地方,请读者见谅。

4、参考资料

[1] 油管--The Winograd Transformation
[2] 深度加速(一)——概述, Winograd(1)
[3] 深度加速(二)——Winograd(2)
[4] Arxiv--Fast Algorithms for Convolutional Neural Networks
[5] https://github.com/andravin/wincnn/blob/master/2464-supp.pdf
[6] 知乎--源于《孙子算经》的Cudnn
[7] 向量卷积与多项式乘法
[8] Convolution Vs Correlation
[9]The difference between convolution and cross-correlation from a signal-analysis point of view
[10] 卷积神经网络中的Winograd快速卷积算法
[11] 輾轉相除法
[12] 维基百科--质数
[13] 欧几里德算法与扩展欧几里德算法
[14] 我終於頓悟輾轉相除法求最大公約數的原理了
[15] 取模运算涉及的算法
[16] 多项式除法竖式应当如何理解?
[17] 多项式长除法
[18] 用辗转相除法求多项式的最大公因式
[19] 维基百科--裴蜀定理
[20] 数论小结
[21] 扩展欧几里得算法 有限域上多项式求逆
[22] 维基百科--模逆元
[23] 「扩展欧几里得算法」与「模逆元」详解
[24] 知乎--中国剩余定理(CRT )
[25] 扩展欧几里得算法与中国剩余定理
[26] 中国剩余定理算法详解(余数互质和不互质)
[27] 百科故事--《孙子算经》卷下
[28] 知乎--多项式也有CRT么?
[29] Winograd数学原理【卡住了>_<】



推荐文章

更多AI移动端优化的请关注专栏嵌入式AI以及知乎(@梁德澎)。
8 阅读 588
推荐阅读
0 条评论
关注数
12218
内容数
190
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:gg15319381845(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
Arm中国学堂公众号
关注Arm中国学堂
实时获取免费 Arm 教学资源信息
Arm中国招聘公众号
关注Arm中国招聘
实时获取 Arm 中国职位信息