爱笑的小姐姐 · 2021年01月21日

Pytorch转ONNX-实战篇2(实战踩坑总结)

文章转载于:GiantPandaCV
作者:立交桥跳水冠军
编辑:GiantPandaCV

作者丨立交桥跳水冠军

来源丨https://zhuanlan.zhihu.com/p/...

编辑丨GiantPandaCV

前两篇文章分别从理论和ONNX的核心机制描述了Pytorch转ONNX需要注意的事情。接下来这篇文章没有什么核心主旨,只是纯粹记录我当时做项目的时候踩的坑以及应对方案

(1)Pytorch2ONNX不支持对slice对象赋值

下面这段代码是不被Pytorch原生的onnx转换接口支持的,即不能对slice对象赋值

preds[:, :, y1:y2, x1:x2] += crop_seg_logit  

仔细想想其实也比较合理,因为上面的操作也很难在DAG上被表示,因为并不仅仅是把preds中的那个区域取出来弄个新的变量,然后在上面+1,而是直接把preds的一部分改掉了。当时我负责MMSeg的slide inference转换的时候遇到了这个问题,解决方案如下:

preds += F.pad(crop_seg_logit,  
               (int(x1), int(preds.shape[3] - x2), int(y1),  
                int(preds.shape[2] - y2)))  

即我对crop\_seg\_logit做了一个padding,把它变成了和preds一样的大小,这样我就直接变成了矩阵相加,没必要变成slice的操作了

这个方法自然很丑,而且会引出一个新的问题,那就是Pytorch生成的onnx padding的格式,onnx runtime接收的格式以及TensorRT需要的格式都不一样。这个就是之后的问题了(超纲了,不讲了)

这里具体的例子我懒得查了,以二维矩阵的填充为例。只记得一个转出来的是(begin0, begin1, end0, end1),另一个是(begin0, end0, begin1, end1)

这里面begin0代表第0维左边的填充数量,end0代表右边的填充数量

(2)resize

当时做segmentation模型的时候,最重要的就是resize操作。ONNX里面的resize要求output shape必须为常量(即tuple of int),因此不可以用tensor.Size作为输入,因为人家并不是tuple of int

if isinstance(size, torch.Size):  
    size = tuple(int(x) for x in size)  

所以我们必须手动粗暴的把torch.Size变成tuple of int

当时有reviewer吐槽我这个方法丑,要我改成tuple(size),说Pytorch重载了tuple,直接可以把torch.Size变成tuple of int。但是很诡异的是在正常情况下的确可以,但如果一旦进入了ONNX tracining模式,这个方法就失效了。我简单看了看,推测是因为对tuple的重载是在C++层面做的,而ONNX tracing也会涉及到一些C++层面的事情,也就是说ONNX tracing会重载一些C++的部分,可能正好就把tuple给抹掉了

(3) 应对kwargs的约束

pytorch自带的onnx转换api: torch.onnx.export,只支持args参数。一般来说调用这个api只需要提供model(喜闻乐见的nn.Module),调用model的参数args(也就是调用model.forwrd()的参数)以及导出的文件名f。然后这个函数就会内部执行一遍: model(*args),执行的时候做tracining

image.png

但是我们知道一般来说除了args,还需要kwargs,比如model(input, get_loss=False),其中input就是args,False就是kwargs。OpenMMLab里面几乎所有的model都需要kwargs_

为了绕开这个约束,我们需要利用python的partial函数,将model做个封装:

model.forward = partial(model.forward, return_loss=False)  

这样我们可以给model提供需要的kwargs,同时又可以原封不动的调用torch.onnx.export

注意,kwargs不能包括网络的输入,比如如果你想把input image放进args,那么得到的onnx就会是一个没有输入的图(它会把kwargs里面的input image当成一个常量)

(4)Pytorch和ONNX Runtime结果对齐

OpenMMLab系列提供了一个很有用的功能,就是自动比对Pytorch和ONNXRuntime的精度。这个功能可以帮助用户确定转出来的ONNX有没有问题。

然而之前也提到过,ONNXRuntime和Pytorch需要的ONNX格式不一样,而且有些计算也不一样,因此就算结果对不上,也不能代表什么

在某些操作上,ONNXRuntime和Pytorch的行为不一致。比如对一个一维tensor:[0,0,0]调用argmax,那么ONNXRuntime返回的是0,而Pytorch是1(举个例子,具体的差异我记不清了)

当时我在做Detection模型的自动比对的时候就遇到了问题,在经历了nms操作之后,bbox会根据score的大小做排序,但score相同的情况下,ONNXRuntime和Pytorch的结果就会有差异。因此我们最后只选择比对score,而不管bbox的dx,dy这些信息了

- The End -

推荐阅读

更多嵌入式AI技术相关内容请关注嵌入式AI专栏。
推荐阅读
关注数
18839
内容数
1376
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息