爱笑的小姐姐 · May 25, 2021

Deep Learning Performance Optimization 3: Getting Started with TF-TRT

Reposted from: Zhihu
Author: djh

1. Save the model as a pb file at training time

import tensorflow as tf

sess = tf.Session()
matrix_1 = tf.constant([3., 3.], name='input')
add = tf.add(matrix_1, matrix_1, name='output')
sess.run(add)

# Freeze the graph: convert variables to constants, keeping only the 'output' node
output_graph_def = tf.graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_node_names=['output'])
# Save the frozen graph to the ./model directory
with tf.gfile.FastGFile('./model/tensorflow_matrix_graph.pb', mode='wb') as f:
    f.write(output_graph_def.SerializeToString())

sess.close()
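
To check that the pb file was written correctly, it can be loaded back and executed. A minimal sketch (TensorFlow 1.x; the path and node name come from the snippet above):

import tensorflow as tf

# Load the frozen graph and run the 'output' node to verify the pb file
with tf.gfile.GFile('./model/tensorflow_matrix_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=graph) as sess:
        print(sess.run('output:0'))  # expected: [6. 6.]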

2. Save a ckpt model at training time

saver = tf.train.Saver()
last_ckpt = saver.save(sess, 'model/model.ckpt')
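
Restoring the variables from this checkpoint in a later session is the mirror operation. A minimal sketch (assuming the same graph has already been rebuilt in the current session):

saver = tf.train.Saver()
with tf.Session() as sess:
    # Pick up the newest checkpoint in the model/ directory and restore it
    saver.restore(sess, tf.train.latest_checkpoint('model'))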

3. Convert a ckpt to a pb model (freezing the graph)

import tensorflow as tf
from tensorflow.python.framework import graph_util

input_checkpoint = 'model/model.ckpt'   # path to the trained checkpoint (placeholder)
output_graph = 'model/frozen_model.pb'  # where to write the frozen graph (placeholder)
# If there are multiple output nodes, separate their names with commas
output_node_names = "InceptionV3/Logits/SpatialSqueeze"

# Rebuild the graph from the .meta file; clear_devices drops device placements
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph()          # get the default graph
input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph

with tf.Session() as sess:
    saver.restore(sess, input_checkpoint)  # restore the graph and its variable values
    # Freeze: replace variables with constants holding their current values
    output_graph_def = graph_util.convert_variables_to_constants(
        sess=sess,
        input_graph_def=input_graph_def,  # same as sess.graph_def
        output_node_names=output_node_names.split(","))

    with tf.gfile.GFile(output_graph, "wb") as f:  # save the frozen model
        f.write(output_graph_def.SerializeToString())  # serialize and write
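
With the frozen pb in hand, it can be handed to TF-TRT for optimization. A minimal sketch using the TF 1.x contrib converter, reusing output_graph and output_node_names from the snippet above (workspace size and precision are just example values):

import tensorflow as tf
import tensorflow.contrib.tensorrt as trt

# Read back the frozen graph produced by the freezing step
with tf.gfile.GFile(output_graph, 'rb') as f:
    frozen_graph_def = tf.GraphDef()
    frozen_graph_def.ParseFromString(f.read())

# Replace TensorRT-compatible subgraphs with TRT engine ops
trt_graph_def = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,
    outputs=output_node_names.split(","),
    max_batch_size=1,                   # largest batch size used at inference
    max_workspace_size_bytes=1 << 25,   # scratch memory TensorRT may allocate
    precision_mode='FP16')              # or 'FP32' / 'INT8'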

At first I kept getting the following error, which means a name passed to output_node_names does not exist in the graph:

assert d in name_to_node_map, "%s is not in graph" % d
AssertionError: A is not in graph

Use the following for loop to print all node names:

# saved_model_path is the directory containing the checkpoint files
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(saved_model_path))
for op in tf.get_default_graph().get_operations():
    print(op.name)

In the same way, the following code will be used later to print each op's name together with its output tensors:

for op in graph.get_operations():
    print(op.name, op.values())
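
When only the checkpoint files are at hand, the variables stored in them can also be listed directly with tf.train.NewCheckpointReader. Note this lists checkpoint variables rather than graph operations, but it is often enough to spot the right name prefix. A minimal sketch (the checkpoint path is a placeholder):

import tensorflow as tf

# Inspect the variables stored in a checkpoint without rebuilding the graph
reader = tf.train.NewCheckpointReader('model/model.ckpt')
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)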

4. Freeze the graph directly from the ckpt and use it, without saving an intermediate pb

import os
try:
    os.chdir(os.path.join(os.getcwd(), 'examples/classification'))
    print(os.getcwd())
except:
    pass

#%%
from PIL import Image
import sys
import os
import urllib
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
MODEL = 'inception_v1'
CHECKPOINT_PATH = 'inception_v1.ckpt'
NUM_CLASSES = 1001
LABELS_PATH = './data/imagenet_labels_%d.txt' % NUM_CLASSES
IMAGE_PATH = './data/dog-yawning.jpg'

# build_classification_graph is defined further below; it returns the frozen
# GraphDef together with the input and output node names
frozen_graph, input_names, output_names = build_classification_graph(
    model=MODEL,
    checkpoint=CHECKPOINT_PATH,
    num_classes=NUM_CLASSES
)

# Replace TensorRT-compatible subgraphs in the frozen graph with TRT engine ops
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=1,                   # largest batch size used at inference
    max_workspace_size_bytes=1 << 25,   # scratch memory TensorRT may allocate
    precision_mode='FP16',              # FP16 inference; FP32/INT8 also possible
    minimum_segment_size=50             # only convert subgraphs with >= 50 nodes
)

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True

tf_sess = tf.Session(config=tf_config)

tf.import_graph_def(trt_graph, name='')

tf_input = tf_sess.graph.get_tensor_by_name(input_names[0] + ':0')
tf_output = tf_sess.graph.get_tensor_by_name(output_names[0] + ':0')

image = Image.open(IMAGE_PATH)

plt.imshow(image)

# NHWC layout: dimension 1 is height, dimension 2 is width
height = int(tf_input.shape.as_list()[1])
width = int(tf_input.shape.as_list()[2])
image = np.array(image.resize((width, height)))

output = tf_sess.run(tf_output, feed_dict={
    tf_input: image[None, ...]
})

scores = output[0]

with open(LABELS_PATH, 'r') as f:
    labels = f.readlines()

top5_idx = scores.argsort()[::-1][0:5]

for i in top5_idx:
    # strip() removes the trailing newline left by readlines()
    print('(%.3f) %s' % (scores[i], labels[i].strip()))

tf_sess.close()
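
Since the TensorRT conversion itself takes time, it is common to write the optimized graph to disk once and reload it on later runs. A minimal sketch (the file name is just an example):

# Persist the TRT-optimized graph so the conversion does not have to be redone
with tf.gfile.GFile('./model/trt_inception_v1_fp16.pb', 'wb') as f:
    f.write(trt_graph.SerializeToString())

# Later, reload it exactly like any other frozen graph
with tf.gfile.GFile('./model/trt_inception_v1_fp16.pb', 'rb') as f:
    reloaded = tf.GraphDef()
    reloaded.ParseFromString(f.read())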



def build_classification_graph(model, checkpoint, num_classes):
    # NETS, input_name, output_name, slim and convert_relu6 are assumed to be
    # provided by the surrounding project's helper module; NETS maps a model
    # name to its slim network definition, preprocessing and arg_scope
    global NETS, input_name, output_name

    net = NETS[model]
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    with tf.Graph().as_default() as tf_graph:
        with tf.Session(config=tf_config) as tf_sess:

            tf_input = tf.placeholder(tf.float32, [None, net.input_height, net.input_width, 3],
                    name=input_name)
            tf_preprocessed = net.preprocess(tf_input)

            with slim.arg_scope(net.arg_scope()):
                tf_net, tf_end_points = net.model(tf_preprocessed, is_training=False,
                    num_classes=num_classes)

            tf_output = net.postprocess(tf_net, name=output_name)

            # load checkpoint
            tf_saver = tf.train.Saver()
            tf_saver.restore(save_path=checkpoint, sess=tf_sess)

            # freeze graph
            frozen_graph = tf.graph_util.convert_variables_to_constants(
                tf_sess,
                tf_sess.graph_def,
                output_node_names=[output_name]
            )

            # remove relu 6
            frozen_graph = convert_relu6(frozen_graph)

    return frozen_graph, [input_name], [output_name]
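
To see whether the TRT conversion actually helps, the simplest check is to time repeated runs on the original frozen graph and on the TRT graph. A rough sketch, assuming frozen_graph, trt_graph, image, input_names and output_names from the listing above are still in scope (warm-up iterations excluded from timing; numbers vary by GPU):

import time
import tensorflow as tf

def benchmark(graph_def, input_name, output_name, image, runs=50):
    """Average seconds per forward pass for a frozen GraphDef on one image."""
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Graph().as_default():
        tf.import_graph_def(graph_def, name='')
        with tf.Session(config=config) as sess:
            x = sess.graph.get_tensor_by_name(input_name + ':0')
            y = sess.graph.get_tensor_by_name(output_name + ':0')
            feed = {x: image[None, ...]}
            for _ in range(5):            # warm-up, excluded from timing
                sess.run(y, feed_dict=feed)
            start = time.time()
            for _ in range(runs):
                sess.run(y, feed_dict=feed)
            return (time.time() - start) / runs

print('native TF : %.4f s/img' % benchmark(frozen_graph, input_names[0], output_names[0], image))
print('TF-TRT    : %.4f s/img' % benchmark(trt_graph, input_names[0], output_names[0], image))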

Other references

https://github.com/yazone/ai_learning_path
https://github.com/yazone/ai_learning_path/tree/master/04.AI%E8%90%BD%E5%9C%B0%E5%AE%9E%E8%B7%B5/04.%E6%80%A7%E8%83%BD%E4%BC%98%E5%8C%96/nvidia%E6%80%A7%E8%83%BD%E4%BC%98%E5%8C%96
