爱笑的小姐姐 · 2021年01月20日

Pytorch转ONNX-实战篇1(tracing机制)

文章转载于:GiantPandaCV
作者:立交桥跳水冠军
编辑:GiantPandaCV

昨天的文章简单描述了在Pytorch转ONNX中面临的问题和需要注意的事情,今天的文章会重点结合OpenMMlab系列中用到的Pytorch转ONNX的小技巧来介绍实战部分。

(1)tracing的机制

上文提到过,Pytorch转ONNX的方式是基于tracing(追踪),通俗来说,就是ONNX的相关代码在一旁看着Pytorch跑一遍,运行了什么内容就把什么记录下来。但是在这里并不是所有Python的运行内容都会被记录。举个例子,下面的代码中,

c = torch.matmul(a, b)  
print("Blabla")  
e = torch.matmul(c, d)  

其中只有第1,3行相关的内容会被记录,因为只有他们是和Pytorch相关的,而第二行只是普通的python语句。

具体来说,只有ATen操作会被记录下来。ATen可以被理解为一个Pytorch的基本操作库,一切的Pytorch函数都是基于这些零部件构造出来的(比如ATen就是加减乘除,所有Pytorch的其他操作,比如平方,算sigmoid,都可以根据加减乘除构造出来)

*之前说的ONNX无法记录if语句的问题也是因为if并不是Aten中的操作

虽然ONNX可以记录所有Pytorch的执行(即记录所有ATen操作),但是在输出的时候会做一个剪枝,把没用的操作剪掉

举个例子,下面的程序,显而易见第一句话是没有用的。

t1 = torch.matmul(a, b)  
t2 = torch.matmul(c, d)  
return t2  

ONNX会在得到全部的操作以及他们之间的输入输出关系后(以DAG作为表示),根据DAG的输出往前推,做遍历,所有可以被遍历到的节点被保留,其他节点直接扔掉。

在MMDetection(https://github.com/open-mmlab...,在NMS(non-Maximumnon maximum suppression)中有如下代码:

if bboxes.numel() == 0:  
    bboxes = multibboxes.newzeros((0, 5))  
    labels = multibboxes.newzeros((0, ), dtype=torch.long)  
  
    if torch.onnx.isinonnxexport():  
        raise RuntimeError('[ONNX Error] Can not record NMS '  
                           'as it has not been executed this time')  
    return bboxes, labels  
  
dets, keep = batchednms(bboxes, scores, labels, nmscfg)  

代码逻辑很简单,如果之前的网络根本没有输出任何合法的bbox(第一行的分支判断),那么显然nms的结果就是一堆0,所以没必要运行nms直接返回0就可以。

如果我们想将这段代码转换到ONNX,之前我们提到过ONNX不能处理分支逻辑,因此只能选择一条路去走,记录那条路转换得到的模型。很显然,正常情况下我们自然期待会有较多的bbox,并且将这些bbox作为参数调用nms。

所以如果我们发现模型执行的路径触发了if分支,我们必须要进行一个判断,看看是不是在转ONNX,如果是的话我们就需要直接报错,因为显然转出来的ONNX不是我们想要的。

假设什么都不做,在这种情况下我们转出来的模型是什么样呢?思考一下不难发现,假设函数的返回值就是网络的最终输出,那么我们只会得到一个2个节点的DAG,即第2,3行的两个操作。之前说过ONNX拿到所有的DAG之后会做剪枝,在这里ONNX拿到返回值(bboxes, labels)做回溯,发现最头上就是第2,3行的两个操作,就直接停掉了。所有其他的操作,比如backbone,rpn,fpn,都会被扔掉。

因此,在进行MMDet模型的转换的时候,必须用真实的数据和训练好的参数来做转换,否则基本不会得到有效的bbox,于是就会触发第6行的error

(2)利用tracing机制做优化

在MMSeg中有一个很巧妙的利用tracing机制做优化的例子。

在slide inference时,我们需要计算一个count mat矩阵,这个矩阵在h, w以及对应的stride都固定的情况下会是一个常量。

不过在训练时,往往这些都是我们要调的参数,所有MMSeg没有选择把这些常数保存下来,而是每次都算一遍

countmat = img.newzeros((batchsize, 1, himg, wimg))  
        for hidx in range(hgrids):  
            for widx in range(wgrids):  
                y1 = hidx * hstride  
                x1 = widx * wstride  
                y2 = min(y1 + hcrop, himg)  
                x2 = min(x1 + wcrop, wimg)  
                y1 = max(y2 - hcrop, 0)  
                x1 = max(x2 - wcrop, 0)  
                cropimg = img[:, :, y1:y2, x1:x2]  
                cropseglogit = self.encodedecode(cropimg, imgmeta)  
                preds += F.pad(cropseglogit,  
                               (int(x1), int(preds.shape[3] - x2), int(y1),  
                                int(preds.shape[2] - y2)))  
  
                countmat[:, :, y1:y2, x1:x2] += 1  
        assert (countmat == 0).sum() == 0  
        if torch.onnx.isinonnxexport():  
            # cast countmat to constant while exporting to ONNX  
            countmat = torch.fromnumpy(  
                countmat.cpu().detach().numpy()).to(device=img.device)  

不过在部署时,这些参数往往是固定的,因此我们没必要把它算一遍。因此在倒数第4行的if分支里,我们做了一件看似很没用的事

countmat = torch.fromnumpy(countmat.cpu().detach().numpy()).to(device=img.device)  

即我们把算出来的countmat从tensor转换成numpy,再转回tensor。

其实我们的目的是切断tracing。

之前提到过,ONNX只能记录ATen相关的操作,但是很显然,tensor和numpy的互转肯定不是ATen操作。因此在回溯的时候,当访问到count mat,ONNX并不能发现它是被谁运算出来的,所以countmat就会被看作一个常数被保存下来,之前计算countmat的部分都会被扔掉

- The End -

推荐阅读

更多嵌入式AI技术相关内容请关注嵌入式AI专栏。
推荐阅读
关注数
18839
内容数
1376
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息