爱笑的小姐姐 · 2021年11月09日

Tengine推理引擎 | 切图 源码解读

Tengine-lite 源码阅读已经出了三期的文章了,在对每个部分解读的过程中也对tengine的运行机制有了一些认识。本期 切图 的源码解读应该是Tengine的非常核心的机制之一。

代码比前几期都相对短了很多,但是我反反复复看了好多遍,才大致理解了split运行的时候在做什么,以及为什么要这么做。期间还和闲来大佬打了25分钟的语音电话,闲来大佬也讲了讲怎么去读源码,并不一定每一行都要读(其实我也是挑着重点看的)。

我在这里传递一下闲来大佬的读码经验:在关键位置打上断点,开 debug 模式把每个流程过一遍,包括 init、prerun、split、run 等,然后再逐个分解研究(重点是不懂的随时 @闲来大佬)。

总的来说,切图就是看每一个子图节点的op支不支持当前的device,如果支持,就把子图放在device上,如果不支持就放在cpu上,然后把前后相连的且在同一device的子图合并。(貌似商业版的切图机制更复杂一些,支持混合精度,这样在推理的时候应该是更快了)

下面进入正题,不想看全部的可以直接跳到最后一个函数add_sub_graph_to_ir_graph,我主要是向闲来大佬请教的这个函数的问题。

// split.c
#define MODEL_COMPLEX_COUNT 3
​
// Check whether a tensor's data type is one of the allowed precisions.
// Returns 0 when tensor->data_type appears in allowed_precision, or when the
// tensor carries quantization parameters (quant_param_num > 0); -1 otherwise.
int tensor_in_precision(const struct tensor* tensor, struct vector* allowed_precision)
{
    // A quantized tensor is accepted regardless of the allowed list.  This
    // check is loop-invariant, so hoist it out of the loop; the original
    // evaluated it per iteration, which also silently rejected quantized
    // tensors whenever allowed_precision was empty (zero iterations).
    if (tensor->quant_param_num > 0)
    {
        return 0;
    }

    const int count = get_vector_num(allowed_precision);
    for (int i = 0; i < count; i++)
    {
        const int* const precision = (const int* const)get_vector_data(allowed_precision, i);
        if (*precision == (int)tensor->data_type)
        {
            return 0;
        }
    }

    return -1;
}
​
​
int node_in_precision(const struct graph* ir_graph, uint16_t node_id, struct vector* allowed_precision)
{
    if (node_id > ir_graph->node_num)
    {
        return -1;
    }
​
    const ir_node_t* ir_node = ir_graph->node_list[node_id];
​
    for (int8_t i = 0; i < ir_node->output_num; i++)
    {
        // 获取node的当前output tensor
        uint16_t index = ir_node->output_tensors[i];
        const struct tensor* tensor = ir_graph->tensor_list[index];
​
        if (TENSOR_TYPE_VAR == tensor->tensor_type || TENSOR_TYPE_INPUT == tensor->tensor_type)
        {
            // 判断可变tensor或者输入tensor的数据类型合法性,【只要有一个合法就返回0??? 难道不应该每一个都判断吗】
            // 这边请教了闲来大佬,说是理论上不应该有node支持混合精度,这里只是个潜在的约束,所以不用太在意这个问题
            const int in_precision = tensor_in_precision(tensor, allowed_precision);
            if (0 == in_precision)
            {
                return 0;
            }
        }
        else
        {
            return 0;
        }
    }
​
    return -1;
}
​
// 判断指定node的op是否是在ops_list中
// return node.op.type in [op.type for op in ops_list]
int node_in_list(const struct graph* ir_graph, struct vector* ops_list, const uint16_t node_id)
{
    
    if (NULL == ir_graph || NULL == ops_list)
    {
        return -1;
    }
​
    const uint16_t node_op_type = ir_graph->node_list[node_id]->op.type;
​
    for (int i = 0; i < get_vector_num(ops_list); i++)
    {
        int* loop_op = (int*)get_vector_data(ops_list, i);
        if (node_op_type == *loop_op)
        {
            return 0;
        }
    }
​
    return -1;
}
​
// 获取graph中所有的是阻塞op的node或者不是允许的精度的node的id列表 (就是得到不支持的节点)
// return [i for i in range(len(node_num)) 
//          if (nodes[i].op.type in blocked_ops) or (nodes[i].datatype not in allowed_precision)]
struct vector* get_graph_blocked_nodes(const struct graph* ir_graph, struct vector* blocked_ops, struct vector* allowed_precision)
{
    struct vector* blocked_nodes_list = create_vector(sizeof(uint16_t), NULL);
​
    for (uint16_t i = 0; i < ir_graph->node_num; i++)
    {
        // 判断第i个node的op是否为阻塞op
        int is_blocked_op = node_in_list(ir_graph, blocked_ops, i);
        // 判断第i个node的精度是否为允许的精度
        int is_allowed_precision = node_in_precision(ir_graph, i, allowed_precision);
        
        if (0 == is_blocked_op || 0 != is_allowed_precision)
        {
            push_vector_data(blocked_nodes_list, &i);
            continue;
        }
    }
​
    return blocked_nodes_list;
}
​
// policy has some issue, must be fixed  (note kept from upstream)
//
// Partition the graph's node list into sub graphs around the blocked nodes.
// Windows of device-supported nodes are offloaded to the context device when
// their estimated "complexity" makes it worthwhile; everything else stays on
// the default CPU device.  Finally, neighbouring sub graphs that landed on
// the same device are merged back together.
void split_graph_node_to_sub_graph(struct graph* ir_graph, struct vector* allowed_ops, struct vector* blocked_ops, struct vector* allowed_precision)
{
    // First collect every node the target device cannot run.
    struct vector* blocked_nodes_list = get_graph_blocked_nodes(ir_graph, blocked_ops, allowed_precision);
    const int blocked_nodes_count = get_vector_num(blocked_nodes_list);

    // If there are unsupported nodes, carve sub graphs around them.
    if (blocked_nodes_count != 0)
    {
        // from the last unsupported node to collecting all sub graphs
        // scan from back to front
        for (int i = blocked_nodes_count - 1; i >= 0; i--)
        {
            // start node id (the blocked one)
            // Two-pointer style window scan.  Example: blocked_nodes_list =
            // [1,2,3] and the graph has 5 nodes.  At the last index the
            // window is nodes [3,5); on the next iteration `first` moves
            // left to 2, and since everything from node 3 onwards was
            // already covered, the window shrinks to [2,3), and so on.
            uint16_t first_node_id = *((uint16_t*)get_vector_data(blocked_nodes_list, i));
            uint16_t last_node_id = ir_graph->node_num; // exclusive end of the window

            if (i < blocked_nodes_count - 1)
            {
                last_node_id = *((uint16_t*)get_vector_data(blocked_nodes_list, i + 1));
            }

            int children_nodes_is_complicated = 0;

            // scan if these nodes is complicated to be solved
            // Accumulate a rough complexity score over the allowed nodes in
            // the window: FC counts MODEL_COMPLEX_COUNT, every other op 1.
            for (uint16_t j = first_node_id; j < last_node_id; j++)
            {
                if (0 == node_in_list(ir_graph, allowed_ops, j))
                {
                    const uint16_t node_op_type = ir_graph->node_list[j]->op.type;

                    if (OP_FC == node_op_type)
                    {
                        children_nodes_is_complicated += MODEL_COMPLEX_COUNT;
                    }
                    else
                    {
                        children_nodes_is_complicated++;
                    }
                }
            }

            // Cheap window: not worth offloading, keep the whole window
            // (including the blocked node itself) on the default CPU device.
            if (children_nodes_is_complicated < MODEL_COMPLEX_COUNT) // directly add these nodes to sub graph list
            {
                struct subgraph* sub_graph = (struct subgraph*)sys_malloc(sizeof(struct subgraph));
                init_ir_subgraph((struct graph*)ir_graph, sub_graph, 0);

                // not including the last one
                sub_graph->node_num = last_node_id - first_node_id;
                sub_graph->node_list = (uint16_t*)sys_malloc(sizeof(uint16_t) * sub_graph->node_num);

                for (uint16_t j = 0; j < sub_graph->node_num; j++)
                {
                    sub_graph->node_list[j] = j + first_node_id;
                }

                sub_graph->device = find_default_device();

                push_vector_data(ir_graph->subgraph_list, &sub_graph);
            }
            else
            {
                // Complex window: create a sub graph on the configured
                // accelerator device for nodes (first, last); the blocked
                // first node itself must still run on the CPU (below).
                struct subgraph* sub_device_graph = (struct subgraph*)sys_malloc(sizeof(struct subgraph));
                init_ir_subgraph((struct graph*)ir_graph, sub_device_graph, 0);

                sub_device_graph->node_num = last_node_id - (first_node_id + 1);
                sub_device_graph->node_list = (uint16_t*)sys_malloc(sizeof(uint16_t) * sub_device_graph->node_num);

                for (uint16_t j = 0; j < sub_device_graph->node_num; j++)
                {
                    sub_device_graph->node_list[j] = j + first_node_id + 1;
                }

                struct device* nn_dev = ir_graph->attribute->context->device;
                sub_device_graph->device = nn_dev;

                push_vector_data(ir_graph->subgraph_list, &sub_device_graph);

                // ---------------

                // add cpu running nodes
                // Single-node CPU sub graph holding just the blocked node.
                struct subgraph* sub_cpu_graph = (struct subgraph*)sys_malloc(sizeof(struct subgraph));
                init_ir_subgraph((struct graph*)ir_graph, sub_cpu_graph, 0);

                sub_cpu_graph->node_num = 1;
                sub_cpu_graph->node_list = (uint16_t*)sys_malloc(sizeof(uint16_t) * sub_cpu_graph->node_num);
                sub_cpu_graph->node_list[0] = first_node_id;

                sub_cpu_graph->device = find_default_device();

                push_vector_data(ir_graph->subgraph_list, &sub_cpu_graph);
            }
        }
    }

    // add main sub graph
    // The windows above only covered nodes from the FIRST blocked node on.
    // E.g. with blocked nodes [3,4,5] the smallest window started at node 3,
    // so nodes 0..2 are still unassigned — and none of them is blocked.
    // Give them their own leading sub graph on the preferred device.
    struct subgraph* sub_graph = (struct subgraph*)sys_malloc(sizeof(struct subgraph));

    // init_ir_subgraph points subgraph->graph back at ir_graph
    init_ir_subgraph((struct graph*)ir_graph, sub_graph, 0);

    uint16_t stop_node_id;
    if (blocked_nodes_count == 0)
    {
        stop_node_id = ir_graph->node_num;
    }
    else
    {
        stop_node_id = *((uint16_t*)get_vector_data((struct vector*)blocked_nodes_list, 0));
    }

    sub_graph->node_num = stop_node_id;
    sub_graph->node_list = (uint16_t*)sys_malloc(sizeof(uint16_t) * sub_graph->node_num);

    for (uint16_t i = 0; i < stop_node_id; i++)
    {
        sub_graph->node_list[i] = i;
    }

    // Pick the device for the leading sub graph: the context's configured
    // device when set, the default (CPU) device otherwise.
    struct device* nn_dev = NULL;
    if (NULL != ir_graph->attribute->context->device)
    {
        nn_dev = ir_graph->attribute->context->device;
    }
    else
    {
        nn_dev = find_default_device();
    }
    sub_graph->device = nn_dev;

    push_vector_data(ir_graph->subgraph_list, &sub_graph);

    release_vector(blocked_nodes_list);

    // optimize the sub graphs
    // Repeatedly merge adjacent sub graphs that run on the same device.
    while (1)
    {
        int same_sub_graph_found = 0;
        int sub_graphs_count = get_vector_num(ir_graph->subgraph_list);
        for (int i = 1; i < sub_graphs_count; i++)
        {
            // Sub graphs were pushed back-to-front, so the LAST list entry
            // is the one closest to the graph input (smallest node ids).
            struct subgraph* last_sub_graph = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, (sub_graphs_count - 1) - (i - 1));
            // "current" is the entry just before it in the list, i.e. the
            // neighbour closer to the graph output.
            struct subgraph* current_sub_graph = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, (sub_graphs_count - 1) - i);
            // If the two neighbours share a device, splice them into one;
            // order matters: last's nodes come first, current's follow.
            if (current_sub_graph->device == last_sub_graph->device)
            {
                uint16_t* node_list = (uint16_t*)sys_malloc(sizeof(uint16_t) * (last_sub_graph->node_num + current_sub_graph->node_num));

                for (int j = 0; j < last_sub_graph->node_num; j++)
                {
                    node_list[j] = last_sub_graph->node_list[j];
                }

                for (int j = 0; j < current_sub_graph->node_num; j++)
                {
                    node_list[j + last_sub_graph->node_num] = current_sub_graph->node_list[j];
                }

                last_sub_graph->node_num += current_sub_graph->node_num;
                sys_free(last_sub_graph->node_list);
                last_sub_graph->node_list = node_list;

                remove_vector_via_index(ir_graph->subgraph_list, (sub_graphs_count - 1) - i);

                same_sub_graph_found = 1;
                break;
            }
        }

        if (!same_sub_graph_found)
            break;
    }
}
​
// 为graph的每一个subgraph中的input tensor列表和output tensor list设置输入输出tensor id
void generate_sub_graph_io(struct graph* ir_graph)
{
    int sub_graph_count = get_vector_num(ir_graph->subgraph_list);
    for (int index = 0; index < sub_graph_count; index++)
    {
        // 遍历每一个subgraph
        struct subgraph* sub_graph = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, index);
​
        uint16_t random_input_id = 0;
        uint16_t random_output_id = 0;
​
        // 找到subgraph的第一个input tensor id,赋给random_input_id
        // 这边查找id的方法并不是严格的,所以是random,
        // 因为每个node有多个input tensor和output tensor,代码只找第一个tensor
        // random的意义相当于分界点的标志位,因为min<=random,而max>=random
        for (int i = 0; i < sub_graph->node_num; i++)
        {  
            uint16_t node_id = sub_graph->node_list[i];
            struct node* ir_node = ir_graph->node_list[node_id];
            if (ir_node->input_num > 0)
            {
                struct tensor* tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[0]);
​
                if (tensor->tensor_type == TENSOR_TYPE_INPUT || tensor->tensor_type == TENSOR_TYPE_VAR)
                {
                    random_input_id = tensor->index;
                    break;
                }
            }
        }
        // 找到subgraph的第一个output tensor id,赋给random_output_id
        for (int i = 0; i < sub_graph->node_num; i++)
        {
            uint16_t node_id = sub_graph->node_list[i];
            struct node* ir_node = ir_graph->node_list[node_id];
            if (ir_node->output_num > 0)
            {
                struct tensor* tensor = get_ir_graph_tensor(ir_graph, ir_node->output_tensors[0]);
                random_output_id = tensor->index;
                break;
            }
        }
​
        uint16_t min_input_tensor_id = random_input_id;
        uint16_t max_input_tensor_id = random_input_id;
        uint16_t min_output_tensor_id = random_output_id;
        uint16_t max_output_tensor_id = random_output_id;
​
        // 这里是严格查找最大最小
        for (int i = 0; i < sub_graph->node_num; i++)
        {
            struct node* ir_node = ir_graph->node_list[sub_graph->node_list[i]];
            // 这里 查找了所有node 的所有input tensor
            for (int k = 0; k < ir_node->input_num; k++)
            {
                struct tensor* tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[k]);
​
                if (tensor->tensor_type == TENSOR_TYPE_INPUT || tensor->tensor_type == TENSOR_TYPE_VAR)
                {
                    // 查找tensor index的最小值
                    if (tensor->index < min_input_tensor_id)
                        min_input_tensor_id = tensor->index;
                // 查找最大值
                    if (tensor->index > max_input_tensor_id)
                        max_input_tensor_id = tensor->index;
                }
            }
            // output部分和input基本一致
            for (int k = 0; k < ir_node->output_num; k++)
            {
                struct tensor* tensor = get_ir_graph_tensor(ir_graph, ir_node->output_tensors[k]);
​
                if (tensor->tensor_type != TENSOR_TYPE_INPUT)
                {
                    if (tensor->index < min_output_tensor_id)
                        min_output_tensor_id = tensor->index;
​
                    if (tensor->index > max_output_tensor_id)
                        max_output_tensor_id = tensor->index;
                }
                else
                {
                    if (tensor->index < min_input_tensor_id)
                        min_input_tensor_id = tensor->index;
​
                    if (tensor->index > max_input_tensor_id)
                        max_input_tensor_id = tensor->index;
                }
            }
        }
        
        // 为input和output的tensors开辟空间
        uint16_t* input_tensors = (uint16_t*)malloc(sizeof(uint16_t) * (max_input_tensor_id - min_input_tensor_id + 1));
        uint16_t* output_tensors = (uint16_t*)malloc(sizeof(uint16_t) * (max_output_tensor_id - min_output_tensor_id + 1));
​
        memset(input_tensors, 0, sizeof(uint16_t) * (max_input_tensor_id - min_input_tensor_id + 1));
        memset(output_tensors, 0, sizeof(uint16_t) * (max_output_tensor_id - min_output_tensor_id + 1));
​
        for (int j = 0; j < sub_graph->node_num; j++)
        {
            struct node* ir_node = ir_graph->node_list[sub_graph->node_list[j]];
​
            for (int k = 0; k < ir_node->input_num; k++)
            {
                struct tensor* tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[k]);
​
                if (tensor->tensor_type == TENSOR_TYPE_INPUT || tensor->tensor_type == TENSOR_TYPE_VAR)
                {
                    // 这里是一个计数器
                    input_tensors[tensor->index - min_input_tensor_id]++;
                }
            }
            // 同样output tensor也有类似的计数器
            for (int k = 0; k < ir_node->output_num; k++)
            {
                struct tensor* tensor = get_ir_graph_tensor(ir_graph, ir_node->output_tensors[k]);
​
                if (tensor->tensor_type != TENSOR_TYPE_INPUT)
                {
                    if (tensor->tensor_type != TENSOR_TYPE_CONST)
                    {
                        output_tensors[tensor->index - min_output_tensor_id]++;
                    }
                }
                else
                {
                    input_tensors[tensor->index - min_input_tensor_id]++;
                }
            }
        }
        
        // 搜索范围定义成input 和output的区间的交集
        // 比如input的id区间是[1, 4], output的id区间是[3, 5], 那搜索范围就是[3,4]
        // 这部分的作用就是 把input的id区间和output的id区间交集部分的计数器置零,
        // 为什么要置零呢?因为交集的意思是 当前node的tensor既是input又是output,即tensor是中间层的tensor,
        // 我们要把这部分的过滤掉,不纳入后续统计
        uint16_t search_start = min_input_tensor_id > min_output_tensor_id ? min_input_tensor_id : min_output_tensor_id;
        uint16_t search_end = max_input_tensor_id < max_output_tensor_id ? max_input_tensor_id : max_output_tensor_id;
​
        for (int i = 0; i < (search_end - search_start) + 1; i++)
        {
            // 当i=0时,input offset=2, output offset是0
            int input_offset = (search_start - min_input_tensor_id) + i;
            int output_offset = (search_start - min_output_tensor_id) + i;
            // 这两个flag指向的实际上都是id为3的tensor,也就是搜索范围的第0个tensor
            int input_flag = input_tensors[input_offset];
            int output_flag = output_tensors[output_offset];
            
            if (input_flag > 0 && output_flag > 0)
            {
                input_tensors[input_offset] = 0;
                output_tensors[output_offset] = 0;
            }
        }
        
        // 统计subgraph的input tensor数
        sub_graph->input_num = 0;
        for (int j = 0; j < max_input_tensor_id - min_input_tensor_id + 1; j++)
        {
            if (input_tensors[j] > 0)
            {
                sub_graph->input_num++;
            }
        }
        // 统计subgraph的output tensor数
        sub_graph->output_num = 0;
        for (int j = 0; j < max_output_tensor_id - min_output_tensor_id + 1; j++)
        {
            if (output_tensors[j] > 0)
            {
                sub_graph->output_num++;
            }
        }
        // 为subgraph的输入输出开辟空间
        sub_graph->input_tensor_list = (uint16_t*)sys_malloc(sizeof(uint16_t) * sub_graph->input_num);
        sub_graph->output_tensor_list = (uint16_t*)sys_malloc(sizeof(uint16_t) * sub_graph->output_num);
        
        // 为subgraph的input tensor列表赋值
        uint16_t input_tensor_count = 0;
        for (int j = 0; j < max_input_tensor_id - min_input_tensor_id + 1; j++)
        {
            if (input_tensors[j] > 0)
            {
                sub_graph->input_tensor_list[input_tensor_count] = min_input_tensor_id + j;
                input_tensor_count++;
            }
        }
        // 为subgraph的output tensor列表赋值
        uint16_t output_tensor_count = 0;
        for (int j = 0; j < max_output_tensor_id - min_output_tensor_id + 1; j++)
        {
            if (output_tensors[j] > 0)
            {
                sub_graph->output_tensor_list[output_tensor_count] = min_output_tensor_id + j;
                output_tensor_count++;
            }
        }
​
        sys_free(input_tensors);
        sys_free(output_tensors);
    }
}
​
void add_sub_graph_to_ir_graph(struct graph* ir_graph)
{
    const int sub_graphs_count = get_vector_num(ir_graph->subgraph_list);
​
    // subgraph列表倒序排列 因为之前切图的时候靠近输出的子图放在subgraph列表的前面,现在要重新排序
    for (int i = 0; i < sub_graphs_count / 2; i++)
    {
        struct subgraph* sub_graph_front = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, i);
        struct subgraph* sub_graph_back = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, (sub_graphs_count - 1) - i);
​
        struct subgraph* mid_temp = (struct subgraph*)sys_malloc(sizeof(struct subgraph));
​
        memcpy(mid_temp, sub_graph_back, sizeof(struct subgraph));
        memcpy(sub_graph_back, sub_graph_front, sizeof(struct subgraph));
        memcpy(sub_graph_front, mid_temp, sizeof(struct subgraph));
​
        sys_free(mid_temp);
    }
​
    // 重置subgraph的index,同时重置graph中的node指向的subgraph的index
    for (int i = 0; i < sub_graphs_count; i++)
    {
        struct subgraph* sub_graph = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, i);
        sub_graph->index = i;
​
        for (int j = 0; j < sub_graph->node_num; j++)
        {
            ir_graph->node_list[sub_graph->node_list[j]]->subgraph_idx = i;
        }
    }
​
    // find no-output input in current sub graph
    for (int i = 1; i < sub_graphs_count; i++)
    {
        struct subgraph* sub_graph = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, i);
        for (int j = 0; j < sub_graph->input_num; j++)
        {
            // 遍历subgraph的不是INPUT类型的input tensor
            struct tensor* ir_tensor = ir_graph->tensor_list[sub_graph->input_tensor_list[j]];
​
            if (ir_tensor->tensor_type != TENSOR_TYPE_INPUT)
            {
                // 获取到tensor的生产者node所对应的subgraph
                uint16_t node_id = ir_tensor->producer;
                uint8_t sub_graph_id = ir_graph->node_list[node_id]->subgraph_idx;
                struct subgraph* target_sub_graph = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, sub_graph_id);
​
                // 如果当前的tensor属于target subgraph的输出tensor,标志位置1,
                // 表示当前的tensor是可以由其生产者node所对应的target subgraph输出得到
                int tensor_mask_as_out_flag = 0;
                for (int k = 0; k < target_sub_graph->output_num; k++)
                {
                    if (target_sub_graph->output_tensor_list[k] == ir_tensor->index)
                    {
                        tensor_mask_as_out_flag = 1;
                        break;
                    }
                }
                // 如果当前input tensor不在target subgraph的output tensor列表中,
                // (个人认为 不太可能出现这种情况,因为这是有遗漏的情况)
                // 我错了我错了,和闲来大佬学到了,比如SSD网络,
                // 中间层有输出的情况,输出的tensor确实把输入tensor遗漏掉了           
                // 因此 要新开辟内存空间,把当前input tensor添加到target subgraph的输出tensor列表中
                if (!tensor_mask_as_out_flag)
                {
                    uint16_t* new_output_tensor_list = (uint16_t*)sys_malloc(sizeof(uint16_t) * (target_sub_graph->output_num + 1));
​
                    memcpy(new_output_tensor_list, target_sub_graph->output_tensor_list, sizeof(uint16_t) * target_sub_graph->output_num);
                    new_output_tensor_list[target_sub_graph->output_num] = ir_tensor->index;
​
                    sys_free(target_sub_graph->output_tensor_list);
                    target_sub_graph->output_tensor_list = new_output_tensor_list;
                    target_sub_graph->output_num += 1;
                }
            }
        }
    }
​
    // fill the wait count
    // 遍历每一个subgraph的每一个输入tensor,判断tensor的类型如果是VAR型,那么当前的subgraph的需要等待的input tensor数要加1
    for (int i = 0; i < sub_graphs_count; i++)
    {
        struct subgraph* sub_graph = *(struct subgraph**)get_vector_data(ir_graph->subgraph_list, i);
        sub_graph->input_wait_count = 0;
​
        for (int j = 0; j < sub_graph->input_num; j++)
        {
            struct tensor* tensor = ir_graph->tensor_list[sub_graph->input_tensor_list[j]];
​
            if (tensor->tensor_type == TENSOR_TYPE_VAR)
                sub_graph->input_wait_count++;
        }
    }
}
原文链接:https://zhuanlan.zhihu.com/p/403372541
作者:闪电侠的右手

推荐阅读

推荐阅读
关注数
3393
内容数
68
Tengine是一款轻量级模块化高性能的神经网络推理引擎 ;欢迎体验Tengine,[链接] 《Tengine开发者入门资料包》[链接]
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息