首发:https://zhuanlan.zhihu.com/p/76605363
作者:张新栋
工程的完整链接可以参考Github链接。
Pytorch以其动态图的调用方式,深得许多科研人员的喜爱,是许多人进行科研研究、算法预研的不二之选。本文我们跟大家讨论一下,如何使用Pytorch来进行嵌入式的算法部署。这里我们采用的离线训练框架为Pytorch,嵌入式端的推理框架为阿里巴巴近期开源的高性能推理框架MNN。下面我们将结合MNIST这个简单的分类任务来跟大家一步一步的完成嵌入式端的部署。
Pytorch的模型不能直接被MNN进行解析,所以我们这里需要选定一个媒介。参考之前专栏的一篇文章《整合mxnet和MNN的嵌入式部署流程》,这里也采用ONNX进行pytorch和MNN之间的桥梁。
- 模型的设计
模型的设计与《整合mxnet和MNN的嵌入式部署流程》文中的模型设计基本一样,大家可以看下面代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
import torch.optim as optim
class MNIST(nn.Module):
def __init__(self):
super(MNIST, self).__init__()
self.conv0 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, bias=False)
self.bn0 = nn.BatchNorm2d(num_features=20)
self.relu0 = nn.ReLU()
self.maxp0 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
self.conv1 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, bias=False)
self.bn1 = nn.BatchNorm2d(num_features=50)
self.relu1 = nn.ReLU()
self.maxp1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
self.conv2 = nn.Conv2d(in_channels=50, out_channels=500, kernel_size=4, stride=1, bias=False)
self.bn2 = nn.BatchNorm2d(num_features=500)
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2d(in_channels=500, out_channels=10, kernel_size=1, stride=1, bias=False)
# self.dense2 = nn.Linear(in_features=400, out_features=120, bias=False)
# self.dp2 = nn.Dropout(p=0.5)
# self.relu2 = nn.ReLU()
# self.dense3 = nn.Linear(in_features=120, out_features=10, bias=False)
def forward(self, x):
x = self.conv0(x)
x = self.bn0(x)
x = self.relu0(x)
x = self.maxp0(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.maxp1(x)
# x = x.view(-1, self.num_flat_features(x))
# x = self.dense2(x)
# x = self.dp2(x)
# x = self.relu2(x)
# x = self.dense3(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.conv3(x)
x = torch.squeeze(x)
return x
大家可以注意一下上述代码里的注释部分,这里我们进行一下介绍。由于Pytorch在实现矩阵乘法的时候,需要使用view进行数据的拉平,然后再进行matmul的操作,这几个操作MNN并没有进行支持(我第一次的实现是使用注释部分的代码,然后MNNConvert的时候报的错误,不支持一些op)。所以我们这里采用了一个4*4和1*1的con2d来代替全连接层。
2. 导出ONNX模型
我们需要导出另外一种可以被MNN解析的模型格式,这里我们选择的是ONNX。如下为导出ONNX的脚本文件:
import torch
import torch.nn as nn
import torch.onnx
from train_mnist import MNIST
# A model class instance (class not shown)
model = MNIST()
# Load the weights from a file (.pth usually)
weights_path = './mnist.pth'
state_dict = torch.load(weights_path)
# Load the weights now into a model net architecture defined by our class
model.load_state_dict(state_dict)
# Create the right input shape (e.g. for an image)
input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, input, "mnist.onnx", verbose=True)
import onnx
# Load the ONNX model
model = onnx.load("mnist.onnx")
# Check that the IR is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
onnx.helper.printable_graph(model.graph)
3. 导出MNN模型
MNN提供了转换ONNX到MNN模型的工具,执行如下脚本即可,关于MNN转换工具编译可以参考Model Conversion。下面是转换脚本:
./MNNConvert -f ONNX --modelFile mnist.onnx --MNNModel mnist.mnn --bizCode MNN
输出的结果如下:
Start to Convert Other Model Format To MNN Model...
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/onnx/onnxConverter.cpp:29: ONNX Model ir version: 3
Start to Optimize the MNN Net...
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/optimizer/optimizer.cpp:44: Inputs: 0
[16:09:54] /Users/xindongzhang/MNN/tools/converter/source/optimizer/optimizer.cpp:54: Outputs: 32, Type = Squeeze
Converted Done!
可以看出采用的ONNX IR版本为3,输入的节点名字为0,输出节点名字为32.
4. 在线部署
在线部署流程在这里,为使用MNN加载解析好的mnn模型参数进行inference等一系列业务操作。关于如何在android上面使用mnn进行部署,本专栏已经有好几篇介绍的文章,这里就不进行赘述了。完整的JNI业务代码可以参考如下链接JNI 业务代码。
- 最后
选取的样例为简单的mnist,虽然全连接层的实现在转换过程中有一些小问题,但是我们修改了网络结构,采用一个4x4和一个1x1的conv2d来进行替代,解决了模型转换的问题。另外欢迎大家留言讨论、关注本专栏及公众号,谢谢大家!
- 参考
- https://github.com/alibaba/MNN
- PyTorch
- https://github.com/xindongzhang/MNN-APPLICATIONS/tree/master/applications/mnist
- https://zhuanlan.zhihu.com/p/75742333
推荐阅读
- 使用AutoTVM基于Android-Arm-CPU调优CNN网络
- BlazeFace: 亚毫秒级的人脸检测器(含代码)
- 谈谈MNN的模型量化(一)数学模型
- 实战MNN之Mobilenet SSD部署(含源码)
专注嵌入式端的AI算法实现,欢迎关注作者微信公众号和知乎嵌入式AI算法实现专栏。
更多嵌入式AI相关的技术文章请关注极术嵌入式AI专栏。