张新栋 · 3月26日

整合Pytorch和MNN的嵌入式部署流程

首发:https://zhuanlan.zhihu.com/p/76605363
作者:张新栋

工程的完整链接可以参考Github链接

Pytorch以其动态图的调用方式,深得许多科研人员的喜爱,是许多人进行科研研究、算法预研的不二之选。本文我们跟大家讨论一下,如何使用Pytorch来进行嵌入式的算法部署。这里我们采用的离线训练框架为Pytorch,嵌入式端的推理框架为阿里巴巴近期开源的高性能推理框架MNN。下面我们将结合MNIST这个简单的分类任务来跟大家一步一步的完成嵌入式端的部署。

Pytorch的模型不能直接被MNN进行解析,所以我们这里需要选定一个媒介。参考之前专栏的一篇文章《整合mxnet和MNN的嵌入式部署流程》,这里也采用ONNX进行pytorch和MNN之间的桥梁。

  1. 模型的设计

模型的设计与《整合mxnet和MNN的嵌入式部署流程》文中的模型设计基本一样,大家可以看下面代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
import torch.optim as optim

class MNIST(nn.Module):

    def __init__(self):
        super(MNIST, self).__init__()
        self.conv0  = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, bias=False)
        self.bn0    = nn.BatchNorm2d(num_features=20)
        self.relu0  = nn.ReLU()
        self.maxp0  = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv1  = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, bias=False)
        self.bn1    = nn.BatchNorm2d(num_features=50)
        self.relu1  = nn.ReLU()
        self.maxp1  = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv2  = nn.Conv2d(in_channels=50, out_channels=500, kernel_size=4, stride=1, bias=False)
        self.bn2    = nn.BatchNorm2d(num_features=500)
        self.relu2  = nn.ReLU()
        self.conv3  = nn.Conv2d(in_channels=500, out_channels=10, kernel_size=1, stride=1, bias=False)
        # self.dense2 = nn.Linear(in_features=400, out_features=120, bias=False)
        # self.dp2    = nn.Dropout(p=0.5)
        # self.relu2  = nn.ReLU()
        # self.dense3 = nn.Linear(in_features=120, out_features=10, bias=False)

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.maxp0(x)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxp1(x)

        # x = x.view(-1, self.num_flat_features(x))
        # x = self.dense2(x)
        # x = self.dp2(x)
        # x = self.relu2(x)
        # x = self.dense3(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = torch.squeeze(x)
        return x

大家可以注意一下上述代码里的注释部分,这里我们进行一下介绍。由于Pytorch在实现矩阵乘法的时候,需要使用view进行数据的拉平,然后再进行matmul的操作,这几个操作MNN并没有进行支持(我第一次的实现是使用注释部分的代码,然后MNNConvert的时候报的错误,不支持一些op)。所以我们这里采用了一个4*4和1*1的con2d来代替全连接层。

2. 导出ONNX模型

我们需要导出另外一种可以被MNN解析的模型格式,这里我们选择的是ONNX。如下为导出ONNX的脚本文件:

import torch
import torch.nn as nn
import torch.onnx
from train_mnist import MNIST

# A model class instance (class not shown)
model = MNIST()

# Load the weights from a file (.pth usually)
weights_path = './mnist.pth'
state_dict = torch.load(weights_path)

# Load the weights now into a model net architecture defined by our class
model.load_state_dict(state_dict)

# Create the right input shape (e.g. for an image)
input = torch.randn(1, 1, 28, 28)

torch.onnx.export(model, input, "mnist.onnx", verbose=True)


import onnx

# Load the ONNX model
model = onnx.load("mnist.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
onnx.helper.printable_graph(model.graph)

3. 导出MNN模型

MNN提供了转换ONNX到MNN模型的工具,执行如下脚本即可,关于MNN转换工具编译可以参考Model Conversion。下面是转换脚本:

./MNNConvert -f ONNX --modelFile mnist.onnx --MNNModel mnist.mnn --bizCode MNN

输出的结果如下:

Start to Convert Other Model Format To MNN Model...
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/onnx/onnxConverter.cpp:29: ONNX Model ir version: 3
Start to Optimize the MNN Net...
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/optimizer/optimizer.cpp:44: Inputs: 0
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/optimizer/optimizer.cpp:54: Outputs: 32, Type = Squeeze
Converted Done!

可以看出采用的ONNX IR版本为3,输入的节点名字为0,输出节点名字为32.

4. 在线部署

在线部署流程在这里,为使用MNN加载解析好的mnn模型参数进行inference等一系列业务操作。关于如何在android上面使用mnn进行部署,本专栏已经有好几篇介绍的文章,这里就不进行赘述了。完整的JNI业务代码可以参考如下链接JNI 业务代码

  • 最后

选取的样例为简单的mnist,虽然全连接层的实现在转换过程中有一些小问题,但是我们修改了网络结构,采用一个4x4和一个1x1的conv2d来进行替代,解决了模型转换的问题。另外欢迎大家留言讨论、关注本专栏及公众号,谢谢大家!

  • 参考
  1. https://github.com/alibaba/MNN
  2. PyTorch
  3. https://github.com/xindongzhang/MNN-APPLICATIONS/tree/master/applications/mnist
  4. https://zhuanlan.zhihu.com/p/75742333


推荐阅读

专注嵌入式端的AI算法实现,欢迎关注作者微信公众号和知乎嵌入式AI算法实现专栏

WX20200305-192544.png

更多嵌入式AI相关的技术文章请关注极术嵌入式AI专栏

5 阅读 161
推荐阅读
0 条评论
关注数
174
文章数
41
嵌入式 AI,让AI无处不在。欢迎加入微信交流群,微信号:gg15319381845(备注:嵌入式)
目录
qrcode
关注微信服务号
实时接收回答提醒和评论通知