11

爱笑的小姐姐 · 2022年09月02日

TVM 学习指南(个人版)上

0x0. 前言

最近粗略的看完了天奇大佬的MLC课程(顺便修了一些语法和拼写错误,也算是做了微弱的贡献hh),对TVM的近期发展有了一些新的认识。之前天奇大佬在《新一代深度学习编译技术变革和展望》一文中(链接:https://zhuanlan.zhihu.com/p/446935289)讲解了TVM Unify也即统一多层抽象的概念。这里的统一多层抽象具体包括AutoTensorization用来解决硬件指令声明和张量程序对接,TVM FFI(PackedFunc)机制使得我们可以灵活地引入任意的算子库和运行库函数并且在各个编译模块和自定义模块里面相互调用。TensorIR负责张量级别程序和硬件张量指令的整合。Relax (Relax Next) 引入relay的进一步迭代,直接引入first class symbolic shape的支持 (摘抄自《新一代深度学习编译技术变革和展望》一文)。然后这些抽象可以相互交互和联合优化来构造深度学习模型对应的最终部署形式。我个人感觉TVM Unify类似于MLIR的Dialect,但是这几个抽象的直接交互能力相比于MLIR的逐级lower我感觉是更直观方便的,毕竟是Python First(这个只是我最近看MLC课程的一个感觉)。对这部分内容感兴趣的读者请查看天奇大佬的TVM Unify介绍原文以及MLC课程。

这篇文章我将结合TVM Unify相关的抽象以及之前的一些积累重新梳理一下TVM的整体流程。我会从前端,中端(图优化Pass机制),代码生成(Schedule),Runtime,开发工具几个角度来介绍一遍。我对TVM的代码并没有做到精细的阅读,所以本文将尽量避免涉及到底层C++代码的细枝末节,而是从较为宏观的视角来讲清楚目前TVM的架构。本篇文章的所有参考资料以及idea主要来自我维护的这个仓库(https://github.com/BBuf/tvm_mlir_learn)里面搜集的TVM的相关资料,TVM官方doc以及源码,MLC课程。上面这个仓库基本收集了TVM中文社区里面的大部分高质量博客或者专题,对TVM感兴趣的小伙伴可以自行下载或者收藏,更欢迎点个star。

写作不易,这篇文章对你有用的话也请点个赞👍。文章有错误也请指出,我动态修改。之后的计划应该会学习TVM如何和硬件的指令对接。

0x1. 前端

TVM为了向上兼容所有的机器学习框架如PyTorch,TensorFlow,ONNX等引入了Relay IR,机器学习模型在进入TVM之后首先会被转换为Relay IR。同时TVM为了向下兼容所有的硬件,引入了Tensor IR简称TIR,模型在被编译为指定硬件的源代码之前都会被Lower为TIR。另外,TVM社区正在开发新一代中间表示Relax(也被称为下一代Relay,目前还没有upstream主分支:https://github.com/tlc-pack/relax/tree/relax/python/tvm/relax),Relax是实现前言里面提到的TVM Unify关键的一环。TVM前端的架构可以粗略的表示为:

image.png

TVM前端架构图

接下来我们分别介绍一下 Relay,TIR,Relax这几种不同的前端表示。

0x1.1 Tensor IR(TIR)

由于无论是Relay还是新一代的Relax中间表示,它们最后都会被Lower到TIR(离硬件最近的IR),所以我们这里先介绍一下TIR。TIR的代码被封装在tvm.tir中,一个TIR可以被编译成目标硬件的源代码或者中间表示例如C++源码,CUDA源码,LLVM IR等等。那么TIR是如何被编译为目标硬件的代码呢?这是因为TIR的数据结构其实是一个AST(抽象语法树),然后这个语法树可以表示变量的声明,初始化,变量的计算,函数调用以及控制流(如if-else条件判断,循环等等)等等。所以只要我们遍历一下TIR对应的AST就可以实现一对一的将其翻译到目标硬件了。可以借助这个图来理解:

image.png

原图来自:https://zhuanlan.zhihu.com/p/533161438,侵删

在上图中有几个细节需要解释。首先是IRModule,IRModule 是在机器学习编译中保存元张量函数(也即PrimFunc)集合的容器对象,它是TVM进行编译的最小完整单元。TVM不同的前端表示最终都会被封装到IRModule中进行编译,在Linux下IRModule就是一个.so动态链接库。然后PrimFunc叫作元张量函数,它内部封装了一个完整的TIR AST。当IRModule被编译之后,每个PrimFunc都对应了这个动态库的一个函数入口,因此一个IRModule可以有很多个PrimFunc。然后上面的Codegen实际上就是对TIR AST进行中序遍历然后一对一的将AST Node翻译为相应的TIR Node对应的数据结构并发送给回调函数VisitExpr\_ 和 VisitStmt。VisitExpr\_ 用于处理 Expression Node,而 VisitStmt 用于处理 Statement Node。后续在介绍Codegen的时候我们再仔细探索一下这个转换流程。

这里还需要说明的一点是,在0.8之前的TVM要声明一个TIR AST依赖于对Tensor Expression的编译。现在TVM基于Python AST实现了一种新的特定领域的方言让我们可以直接使用Python来编写TIR AST。我们这里举一个例子:

`@tvm.script.ir_module  
class MyModule:  
    @T.prim_func  
    def mm_relu(A: T.Buffer[(128, 128), "float32"],  
                B: T.Buffer[(128, 128), "float32"],  
                C: T.Buffer[(128, 128), "float32"]):  
        T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})  
        Y = T.alloc_buffer((128, 128), dtype="float32")  
        for i, j, k in T.grid(128, 128, 128):  
            with T.block("Y"):  
                vi = T.axis.spatial(128, i)  
                vj = T.axis.spatial(128, j)  
                vk = T.axis.reduce(128, k)  
                with T.init():  
                    Y[vi, vj] = T.float32(0)  
                Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]  
        for i, j in T.grid(128, 128):  
            with T.block("C"):  
                vi = T.axis.spatial(128, i)  
                vj = T.axis.spatial(128, j)  
                C[vi, vj] = T.max(Y[vi, vj], T.float32(0))  
`

它实现的功能对应的numpy代码为:

`def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):  
    Y = np.empty((128, 128), dtype="float32")  
    for i in range(128):  
        for j in range(128):  
            for k in range(128):  
                if k == 0:  
                    Y[i, j] = 0  
                Y[i, j] = Y[i, j] + A[i, k] * B[k, j]  
    for i in range(128):  
        for j in range(128):  
            C[i, j] = max(Y[i, j], 0)  
`

其中,@tvm.script.ir_module表示被修饰的MyModule是一个待编译的IRModule,而@T.prim_func表示被修饰的main函数是元张量函数(PrimFunc),这个函数内部定义的就是TIR AST。

0x1.2 了解tvm.ir基础设施

继续讲Relay IR以及Relax之前我们先了解一下tvm.ir这个抽象,无论是TIR还是Relay/Relax IR它们都对应了IRModule这个统一的最小编译单元,同时它们也对应的有一套共用的IR基础设置,具体实现在[https://github.com/apache/tvm/tree/main/include/tvm/ir](https://github.com/apache/tvm/tree/main/include/tvm/ir)[https://github.com/apache/tvm/tree/main/src/ir](https://github.com/apache/tvm/tree/main/src/ir)目录下。

image.png

tvm.ir基础设施文件结构

对于IR来说,Type和Expr是尤为关键的两个概念。Type包含基础的数据类型如Int,Float,Double等等,也包含一些自定义的复杂类型比如函数类型,Tensor类型等。而对于Expr来说,既包含可以直接映射到Low-level IR的PrimExpr,又包含RelayExpr。

我们可以在[https://github.com/apache/tvm/blob/main/include/tvm/ir/type.h](https://github.com/apache/tvm/blob/main/include/tvm/ir/type.h)中看到对PrimTypeNode的定义:

/*!  
 * \brief Primitive data types used in the low-level IR.  
 *  
 * PrimType represents POD-values and handles that are  
 * not automatically managed by the runtime.  
 *  
 * \sa PrimType  
 */  
class PrimTypeNode : public TypeNode {  
 public:  
  /*!  
   * \brief The corresponding dtype field.  
   */  
  runtime::DataType dtype;  
 ...  
};  
  

可以看到PrimType可以直接对应到Low-level IR的基础数据类型。我们还可以找到FuncTypeNode的定义:

`/*!  
 * \brief Function type.  
 *  
 * We support polymorphic function type.  
 * This can be roughly viewed as template function in C++.  
 *  
 * \sa FuncType, TypeVar, TypeConstraint  
 */  
class FuncTypeNode : public TypeNode {  
 public:  
  /*! \brief type type of arguments */  
  Array<Type> arg_types;  
  /*! \brief The type of return value. */  
  Type ret_type;  
  // The following fields are used in polymorphic(template) functions  
  // For normal functions, the following two fields will be empty.  
  /*! \brief The type parameters of the function */  
  Array<TypeVar> type_params;  
  /*!  
   * \brief potential constraint the type need to obey  
   * \note this field is reserved for futher purposes.  
   */  
  Array<TypeConstraint> type_constraints;  
  ...  
};  
`

从注释可以看到FuncType类似C++的模板函数,记录了函数的参数类型和返回值类型以及模板参数,约束等信息。然后我们还可以关注一下和深度学习模型结合得很紧密的TensorTypeNode类型。

`/*!  
 * \brief This is the most commonly used type in relay.  
 *  TensorType have a fixed dimension, data type.  
 *  
 *  The elements of shape can be either IntImm(constant integer),  
 *  or any symbolic integer expression.  
 *  The symbolic integer allows generic shape inference in certain cases.  
 * \sa TensorType  
 */  
class TensorTypeNode : public BaseTensorTypeNode {  
 public:  
  /*!  
   * \brief The shape of the tensor,  
   *  represented by PrimExpr(tvm::Expr).  
   */  
  Array<PrimExpr> shape;  
  /*! \brief The content data type */  
  DataType dtype;  
 ...  
}  
`

我们从TensorTypeNode的定义可以看到shape也是TensorType的一部分,所以TVM在做类型推断的时候也包含了Shape的推断。也正是因为在IR中Shape是Type的一部分(比如Tensor[(m, n)]Tensor[(m, 4)]是不同的Type)导致TVM对动态Shape的支持非常困难,因为Expr的类型推断是不支持动态Shape的。这里需要提一下,Relax通过引入一个新的Type叫作DynTensor较好的解决了动态Shape的表示问题,DynTensor包含的信息是Dtype和Shape的纬度,但Shape本身的表达式是独立存储的。也就是Tensor[(m, n)]Tensor[(_, _)]都是同一个Type, 但是Tensor[(_, _)]Tensor[(_, _, _)]是不同的Type,这样就从原生上支持了动态Shape。我们从https://github.com/tlc-pack/relax/blob/95035621177fa0be4adfb55c766f030563e515a5/include/tvm/relax/type.h#L78这里可以看到DynTensor的定义:

`class DynTensorTypeNode : public BaseTensorTypeNode {  
 public:  
  /*!  
   * \brief The number of dimensions of the tensor, use -1 to denote tensor with unknwon number of  
   * dimensions.  
   */  
  int ndim; //现在直接定义ndim而不是shape  
  /*! \brief The content data type, use void to denote the dtype is unknown. */  
  DataType dtype;  
  ...  
};  
`

我们紧接着看一下Expr的定义([https://github.com/apache/tvm/blob/main/include/tvm/ir/expr.h](https://github.com/apache/tvm/blob/main/include/tvm/ir/expr.h)),Expr分成PrimExpr以及RelayExpr。其中PrimExpr保存了一个runtime时候的Dtype,然后

`/*!  
 * \brief Base node of all primitive expressions.  
 *  
 *  A primitive expression deals with low-level  
 *  POD data types and handles without  
 *  doing life-cycle management for objects.  
 *  
 *  PrimExpr is used in the low-level code  
 *  optimizations and integer analysis.  
 *  
 * \sa PrimExpr  
 */  
class PrimExprNode : public BaseExprNode {  
 public:  
  // runtime::DataType(dtype) 在编译时和运行时提供粗粒度类型信息。   
  // 它动态地内置在 PrimExpr 表达式构造中,可用于快速类型检查。  
  // 当 PrimExpr 对应于 i32 等 POD 值类型时,dtype 足以决定 PrimExpr 的 Type。  
  //  当 dtype 为 DataType::Handle() 时,表达式可以对应更细粒度的 Type,我们可以通过lazy类型推断得到类型。  
  DataType dtype;  
  }  
`

例如表示一个整数的Expr就可以通过继承PrimExprNode来实现,IntImm表示的是整数字面值表达式,所以它记录了一个int类型的value成员。

`// PrimExprs that are useful as runtime containers.  
//  
/*!  
 * \brief Constant integer literals in the program.  
 * \sa IntImm  
 */  
class IntImmNode : public PrimExprNode {  
 public:  
  /*! \brief the Internal value. */  
  int64_t value;  
 ...  
};  
`

RelayExpr的定义如下:

`/*!  
 * \brief 所有非Prim Expr的基础节点  
 *  
 * RelayExpr 支持张量类型、函数和 ADT 作为  
 * 一等公民。 对象对应的生命周期  
 * 由语言隐式管理。  
 *  
 * \sa RelayExpr  
 */  
class RelayExprNode : public BaseExprNode {  
 public:  
  /*!  
   * \brief 存储类型推断(类型检查)的结果。  
   *  
   * \note 这可以在类型推断之前未定义。 该值在序列化期间被丢弃。  
   */  
  mutable Type checked_type_ = Type(nullptr);  
  /*!  
   * \return The checked_type  
   */  
  inline const Type& checked_type() const;  
  /*!  
   * \brief 检查 Expr 的推断(检查)类型是否由 TTypeNode 支持并返回。  
   *  
   * \note 如果这个 Expr 的节点类型不是 TTypeNode,这个函数会抛出一个错误。  
   *  
   * \return 对应的 TTypeNode 指针。  
   * \tparam 我们寻找的特定 TypeNode。  
   */  
  template <typename TTypeNode>  
  inline const TTypeNode* type_as() const;  
  
  ...  
};  
`

总的来说,无论是高级别的Relay,Relax还是低级别的TIR,它们最终都是由这里的Expr和Type为基础来表达的。因为对于Relay和TIR来讲,它们的op定义都是继承自RelayExprNode:[https://github.com/apache/tvm/blob/main/include/tvm/ir/op.h#L58](https://github.com/apache/tvm/blob/main/include/tvm/ir/op.h#L58)。除了对Op名字,类型以及参数,属性等定义外还有一个特殊的参数support_level,从注释上看应该是用来解释当前Op的等级,值越小表示这种Op类型等级越高(暂不清楚具体的作用)。

`// TODO(tvm-team): migrate low-level intrinsics to use Op  
/*!  
 * \brief Primitive Op(builtin intrinsics)  
 *  
 * This data structure stores the meta-data  
 * about primitive operators that can be invoked via Call.  
 *  
 * Low-level IR intrinsics(such as libc.expf) are also  
 * implemented via Op.  
 *  
 * \sa Op  
 */  
class OpNode : public RelayExprNode {  
 public:  
  /*! \brief name of the operator */  
  String name;  
  /*! \brief the type of the operator */  
  mutable FuncType op_type;  
  /*!  
   * \brief detailed description of the operator  
   *  This can be used to generate docstring automatically for the operator.  
   */  
  String description;  
  /* \brief Information of input arguments to the operator */  
  Array<AttrFieldInfo> arguments;  
  /*!  
   * \brief The type key of the attribute field  
   *  This can be empty, in which case it defaults to anything.  
   */  
  String attrs_type_key;  
  /*!  
   * \brief attribute type index,  
   * this field varies in each run and is not exposed to frontend.  
   */  
  uint32_t attrs_type_index{0};  
  /*!  
   * \brief number of input arguments to the operator,  
   * -1 means it is variable length  
   */  
  int32_t num_inputs = -1;  
  /*!  
   * \brief support level of the operator,  
   *  The lower the more priority it contains.  
   *  This is in analogies to BLAS levels.  
   */  
  int32_t support_level = 10;  
 ...  
};  
`

最后我们看一下IRModule的定义,[https://github.com/apache/tvm/blob/main/include/tvm/ir/module.h#L56](https://github.com/apache/tvm/blob/main/include/tvm/ir/module.h#L56)。我们说过IRModule是TVM编译的最小单元,我们可以从它的定义中发现它就是一系列BaseFunc(在下一节Relay的介绍中我们会讲到它的实现)的映射。

`/*!  
 * \brief IRModule that holds functions and type definitions.  
 *  
 *  IRModule is the basic unit for all IR transformations across the stack.  
 *  
 *  Many operations require access to the global IRModule.  
 *  We pass the IRModule by value in a functional style as an explicit argument,  
 *  but we mutate the Module while optimizing programs.  
 * \sa IRModule  
 */  
class IRModuleNode : public Object {  
 public:  
  /*! \brief A map from ids to all global functions. */  
  Map<GlobalVar, BaseFunc> functions;  
  /*! \brief A map from global type vars to ADT type data. */  
  Map<GlobalTypeVar, TypeData> type_definitions;  
  /*! \brief The source map for the module. */  
  parser::SourceMap source_map;  
  /* \brief Additional attributes storing meta-data about the module. */  
  DictAttrs attrs;  
  ...  
  }  
`

其中type_definitions是对ADT的定义,本文不关注Relay中函数式编程的概念,所以不展开ADT以及Let Binding部分的概念和源码,感兴趣的朋友可以参考张伟大佬的这篇文章或者官方文档对Relay的介绍学习一下:https://zhuanlan.zhihu.com/p/446976730 。后面在介绍Relax IR的时候我们会看到,实际上Relax相比于Relay就类似于TensorFlow的静态图到PyTorch动态图的过度,更加强调数据流图的概念而非函数式编程的概念,我个人感觉也是为了易用性考虑吧。

0x1.3 Relay IR

接下来我们简单介绍一下Relay IR。首先Relay IR目前仍然是TVM和其它深度学习框架对接的主要方式,我之前在《【从零开始学TVM】三,基于ONNX模型结构了解TVM的前端》文章中以ONNX为例介绍了模型是如何转换为Relay IR的,然后这个Relay IR会被进一步封装为IRModule给TVM编译。

从源码角度来看,Relay的基类Expr就是tvm.ir基础设施中定义的RelayIR([https://github.com/apache/tvm/blob/main/include/tvm/relay/expr.h#L54](https://github.com/apache/tvm/blob/main/include/tvm/relay/expr.h#L54))。

`namespace relay {  
  
using Expr = tvm::RelayExpr;  
using ExprNode = tvm::RelayExprNode;  
using BaseFunc = tvm::BaseFunc;  
using BaseFuncNode = tvm::BaseFuncNode;  
using GlobalVar = tvm::GlobalVar;  
using GlobalVarNode = tvm::GlobalVarNode;  
using tvm::PrettyPrint;  
`

然后Relay还定义了ConstantExpr,TupleExpr,VarExpr,CallNodeExpr,LetNodeExpr,IfNodeExpr等多种Expr。我们可以看一下ConstantExprNode的定义,类定义中声明了数据data并定义了tensor_type方法返回data的类型,然后is_scalar函数用来判断这个常量是否为标量。

`*!  
 * \brief Constant tensor type.  
 */  
class ConstantNode : public ExprNode {  
 public:  
  /*! \brief The data of the tensor */  
  runtime::NDArray data;  
  
  /*! \return The corresponding tensor type of the data */  
  TensorType tensor_type() const;  
  
  /*! \return Whether it is scalar(rank-0 tensor) */  
  bool is_scalar() const { return data->ndim == 0; }  
  
 ...  
};  
`

然后我们再看一下VarNode的定义,Var就是Relay里面的变量,它的定义如下:

`/*! \brief Container for Var */  
class VarNode : public ExprNode {  
 public:  
  /*!  
   * \brief The unique identifier of the Var.  
   *  
   * vid will be preserved for the same Var during type inference  
   * and other rewritings, while the VarNode might be recreated  
   * to attach additional information.  
   * This property can be used to keep track of parameter Var  
   * information across passes.  
   */  
  Id vid;  
  /*!  
   * \brief type annotaion of the variable.  
   * This field records user provided type annotation of the Var.  
   * This field is optional and can be None.  
   */  
  Type type_annotation;  
  
  /*! \return The name hint of the variable */  
  const String& name_hint() const { return vid->name_hint; }  
};  
`

首先Id vid表示的就是变量的名称,可以理解为一个字符串,比如我们在可视化Relay IR时看到的以@开头的全局变量以及%开头的局部变量。这里的type_annotation表示变量的类型注释,这个字段是可选的。接下来我们再看一个FunctionNode的定义,FunctionNode就是IRModule中的BaseFunc在Relay里面的具体实现了:

`/*!  
 * \brief Relay Function container  
 * \sa Function  
 */  
class FunctionNode : public BaseFuncNode {  
 public:  
  /*! \brief Function parameters */  
  tvm::Array<Var> params;  
  /*!  
   * \brief  
   * The expression which represents the computation of the function,  
   * the expression may reference the parameters, and the type of it  
   * or sub-expressions may reference the type variables.  
   */  
  Expr body;  
  /*! \brief User annotated return type of the function. */  
  Type ret_type;  
  /*!  
   * \brief Type parameters of the function.  
   *  Enables the function to vary its type based on these.  
   *  This corresponds to template paramaters in c++'s terminology.  
   *  
   * \note This can be usually empty for non-polymorphic functions.  
   */  
  tvm::Array<TypeVar> type_params;  
}  
`

FunctionNode的定义中有函数参数,函数体以及返回值类型和参数类型。其它类型的Relay表达式定义我们就不看了,感兴趣的读者可以直接在[https://github.com/apache/tvm/tree/main/include/tvm/relay](https://github.com/apache/tvm/tree/main/include/tvm/relay)阅读。

接下来我们解析一下Relay中的Op定义,上一节tvm.ir基础设施中我们已经提到无论是Relay还是TIR的Op都定义为一种RelayExpr,也就是OpNode的定义。我们这里看一个Relay定义的bias\_add Op的例子来加深理解。

首先,我们为BiasAdd Op定一个属性类型记录它所有的属性,[https://github.com/apache/tvm/blob/main/include/tvm/relay/attrs/nn.h#L35-L48](https://github.com/apache/tvm/blob/main/include/tvm/relay/attrs/nn.h#L35-L48),属性定义时我们还可以给属性设置描述和默认值:

`/*!  
 * \brief Add a 1D Tensor to an axis of a data.  
 *  
 * \note bias_add is a special add operator that is in nn  
 *   and enables automatic derivation of bias's shape.  
 *   You can directly use add for more generalized case.  
 */  
struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {  
  int axis;  
  
  TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {  
    TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1);  
  }  
};  
`

第二步,我们给Biass Add Op定义类型推断函数([https://github.com/apache/tvm/blob/main/src/relay/op/nn/nn.cc#L52](https://github.com/apache/tvm/blob/main/src/relay/op/nn/nn.cc#L52)):

`bool BiasAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,  
                const TypeReporter& reporter) {  
  ICHECK_EQ(types.size(), 3);  
  const auto* data = types[0].as<TensorTypeNode>();  
  if (data == nullptr) return false;  
  
  const BiasAddAttrs* param = attrs.as<BiasAddAttrs>();  
  ICHECK(param != nullptr);  
  int axis = param->axis;  
  if (axis < 0) {  
    axis = data->shape.size() + axis;  
  }  
  if (axis >= static_cast<int>(data->shape.size()) || axis < 0) {  
    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())  
                                     << "The axis in bias_add must be in range for the shape; "  
                                     << "attempted to access index " << param->axis << " of "  
                                     << PrettyPrint(data->shape));  
    return false;  
  }  
  
  // assign output type  
  reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));  
  reporter->Assign(types[2], types[0]);  
  return true;  
}  
`

假设这里指定的操作是 c = nn.bias\_add(a , b),这里的逻辑就是根据输入a的类型推断b和c的类型并重写(Assign)。

第三步,我们把nn.BiasAdd Op注册到全局表中([https://github.com/apache/tvm/blob/main/src/relay/op/nn/nn.cc#L88-L103](https://github.com/apache/tvm/blob/main/src/relay/op/nn/nn.cc#L88-L103)):

`RELAY_REGISTER_OP("nn.bias_add")  
    .describe(R"code(Add bias to an axis of the input.  
)code" TVM_ADD_FILELINE)  
    .set_attrs_type<BiasAddAttrs>()  
    .set_num_inputs(2)  
    .add_argument("data", "nD Tensor", "Input data.")  
    .add_argument("bias", "1D Tensor", "Bias.")  
    .set_support_level(1)  
    .add_type_rel("BiasAdd", BiasAddRel)  
    .set_attr<TOpPattern>("TOpPattern", kBroadcast)  
    .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,  
                                             const Type& out_type) {  
      const auto* param = attrs.as<BiasAddAttrs>();  
      return tvm::Array<tvm::te::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};  
    });  
`

注意到这里的op name/describe/num_inputs/arguments/support_level是对应了OpNode类的成员,然后OpNode还有一个attrs_type_key和attrs_type_index成员对应的就是BiasAddAttrs了。然后我们再看一下这个FTVMCompute这个用来描述Op计算逻辑的额外属性,它使用Op的输入,属性参数以及输出类型来确定这个Op的计算逻辑。

到这里可能你还有一个疑问,我们知道TVM的核心是计算和调度分离,Relay Op的调度逻辑是怎么注册的呢

TVM没有为每个Relay OP注册compute和schedule,而是为其注册fcompute和fschedule,然后根据输入和属性参数,输出类型等生成对应的compute和schedul,这种compute和schedule的组合对应了OpImplementation([https://github.com/apache/tvm/blob/main/include/tvm/relay/op_strategy.h#L39](https://github.com/apache/tvm/blob/main/include/tvm/relay/op_strategy.h#L39))。

`/*!  
 * \brief Operator implementation that includes compute and schedule function.  
 */  
class OpImplementationNode : public Object {  
 public:  
  /*! \brief Compute function */  
  FTVMCompute fcompute;  
  /*! \brief Schedule function */  
  FTVMSchedule fschedule;  
  /*! \brief Name of the implementation */  
  String name;  
  /*! \brief Priority level */  
  int plevel;  
  
  void VisitAttrs(tvm::AttrVisitor* v) {  
    v->Visit("name", &name);  
    v->Visit("plevel", &plevel);  
  }  

 

 static constexpr const char* _type_key = "relay.OpImplementation";  
  TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);  
};  
  
/*!  
 * \brief Operator implementation class.  
 */  
class OpImplementation : public ObjectRef {  
 public:  
  /*!  
   * \brief Invoke the operator compute function.  
   * \param attrs The attribute of the primitive  
   * \param inputs The input tensors.  
   * \param out_type The output type information.  
   * \return The output compute description of the operator.  
   */  
  TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs, const Array<te::Tensor>& inputs,  
                                    const Type& out_type);  
  /*!  
   * \brief Build the computation schedule.  
   * \param attrs The attribute of the node.  
   * \param outs The output tensors.  
   * \param target The build target.  
   * \return The computation schedule.  
   */  
  TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array<te::Tensor>& outs,  
                                const Target& target);  
  
  TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);  
};  
`

从OpImplementation类的实现我们看出,它的Compute和Schedule就是根据fcompute和fschedule来生成的。

`Array<te::Tensor> OpImplementation::Compute(const Attrs& attrs, const Array<te::Tensor>& inputs,  
                                            const Type& out_type) {  
  return (*this)->fcompute(attrs, inputs, out_type);  
}  
  
te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array<te::Tensor>& outs,  
                                        const Target& target) {  
  return (*this)->fschedule(attrs, outs, target);  
}  
`

然后由于特定的OpImplementation需要特定的条件,所以又按照这个条件(condition)进行分组,每一组被叫作OpSpecialization

`/*!  
 * \brief Specialized implementations for operators under certain conditions.  
 */  
class OpSpecializationNode : public Object {  
 public:  
  /*! \brief List of implementations. */  
  Array<OpImplementation> implementations;  
  /*! \brief Condition to enable the specialization.  
   *    Could be undefined to represent generic case. */  
  te::SpecializedCondition condition;  
  
  void VisitAttrs(tvm::AttrVisitor* v) {  
    v->Visit("condition", &condition);  
    v->Visit("implementations", &implementations);  
  }  
  
  static constexpr const char* _type_key = "relay.OpSpecialization";  
  TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);  
};  
  
`

最后使用一个OpStrategy类来记录这个Relay Op的所有OpImplementation。([https://github.com/apache/tvm/blob/main/include/tvm/relay/op_strategy.h#L130](https://github.com/apache/tvm/blob/main/include/tvm/relay/op_strategy.h#L130)

`/*!  
 * \brief Operator strategy to choose implementation.  
 */  
class OpStrategyNode : public Object {  
 public:  
  /*! \brief List of operator specializations. */  
  Array<OpSpecialization> specializations;  
  
  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); }  
  
  static constexpr const char* _type_key = "relay.OpStrategy";  
  TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);  
};  
  
/*!  
 * \brief Operator strategy class.  
 */  
class OpStrategy : public ObjectRef {  
 public:  
  /*!  
   * \brief Add an implementation.  
   * \param fcompute Compute function  
   * \param fschedule Schedule function  
   * \param name Name of the implementation  
   * \param plevel Priority level of the implementation  
   */  
  TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,  
                                 int plevel);  
  
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);  
};  
`

其中,AddImplementation函数通过FFI机制在Python层也可以调用,大多数的Relay Op都是在Python端注册它的Strategy。我们以Relay的nn.Softmax Op为例看一下,它的Strategy(包含fcompute+fschedule)注册在[https://github.com/apache/tvm/blob/main/python/tvm/relay/op/strategy/generic.py#L152](https://github.com/apache/tvm/blob/main/python/tvm/relay/op/strategy/generic.py#L152) 和 [https://github.com/apache/tvm/blob/main/python/tvm/relay/op/strategy/cuda.py#L78-L94](https://github.com/apache/tvm/blob/main/python/tvm/relay/op/strategy/cuda.py#L78-L94)

`@override_native_generic_func("softmax_strategy")  
def softmax_strategy(attrs, inputs, out_type, target):  
    """softmax generic strategy"""  
    strategy = _op.OpStrategy()  
    strategy.add_implementation(  
        wrap_compute_softmax(topi.nn.softmax),  
        wrap_topi_schedule(topi.generic.schedule_softmax),  
        name="softmax.generic",  
    )  
    return strategy  
  
@softmax_strategy.register(["cuda", "gpu"])  
def softmax_strategy_cuda(attrs, inputs, out_type, target):  
    """softmax cuda strategy"""  
    strategy = _op.OpStrategy()  
    strategy.add_implementation(  
        wrap_compute_softmax(topi.nn.softmax),  
        wrap_topi_schedule(topi.cuda.schedule_softmax),  
        name="softmax.cuda",  
    )  
    if target.kind.name == "cuda" and "cudnn" in target.libs:  
        strategy.add_implementation(  
            wrap_compute_softmax(topi.cuda.softmax_cudnn),  
            wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),  
            name="softmax.cudnn",  
            plevel=15,  
        )  
    return strategy  
  
`

然后在[https://github.com/apache/tvm/blob/main/python/tvm/relay/op/nn/_nn.py#L40](https://github.com/apache/tvm/blob/main/python/tvm/relay/op/nn/_nn.py#L40)将实现的Strategy注册到nn.softmax op。

`# softmax  
reg.register_strategy("nn.softmax", strategy.softmax_strategy)  
`

其实Relay Op除了Strategy属性之外,还又一些其它的属性,比如我们在[https://github.com/apache/tvm/blob/main/src/relay/op/nn/convolution.cc#L176](https://github.com/apache/tvm/blob/main/src/relay/op/nn/convolution.cc#L176) 这里可以看到Op还可以有FInferCorrectLayout和TOpPattern属性用于后续优化(比如算符融合Pass就依赖了TOpPattern属性,Ansor的data layerout transform依赖FInferCorrectLayout属性)。

``RELAY_REGISTER_OP("nn.conv1d")  
    .describe(R"code(1D convolution layer (e.g. spatial convolution over sequences).  
This layer creates a convolution kernel that is convolved  
with the layer input to produce a tensor of outputs.  
- **data**: This depends on the `layout` parameter. Input is 3D array of shape  
            (batch_size, in_channels, width) if `layout` is `NCW`.  
- **weight**: (channels, in_channels, kernel_size)  
- **out**:  This depends on the `layout` parameter. Output is 3D array of shape  
            (batch_size, channels, out_width) if `layout` is `NCW`.  
)code" TVM_ADD_FILELINE)  
    .set_attrs_type<Conv1DAttrs>()  
    .set_num_inputs(2)  
    .add_argument("data", "Tensor", "The input tensor.")  
    .add_argument("weight", "Tensor", "The weight tensor.")  
    .set_support_level(2)  
    .add_type_rel("Conv1D", Conv1DRel)  
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv1DAttrs>)  
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);  
  
``

Relay就暂时讲到这里,Relay IR做为函数式风格的IR目前是TVM和其它深度学习框架交互的桥梁并且也经历了多年的维护完备性是比较好的(支持TensorFlow,PyTorch,Paddle,OneFlow各种主流深度学习框架)。但Relay的缺点在于由于共用了TVM的 tvm.ir 基础设施没办法支持Dynamic Shape导致Relay IR也无法支持Dynamic Shape,并且Relay IR这种函数式编程的风格相比于数据流图形式的计算图来说不是太直观。

0x1.4 Relax

由于Relax这个前端还没有正式upstream到apache tvm主分支,所以我这里就不从源码的角度来看。我们可以从Relax的wiki发现它不仅原生的支持动态Shape(通过提供DynTensor的抽象并将Shape从Tensor的type中分离出来实现的)还做了一个TVM Unify抽象,也就是天奇在《新一代深度学习编译技术变革和展望》一文中提到的,这个特点可以让不同的抽象之间相互交互和联合优化。这里提到的抽象包含AutoTensorization用来解决硬件指令声明和张量程序对接,TVM FFI(PackedFunc)机制使得我们可以灵活地引入任意的算子库和运行库函数并且在各个编译模块和自定义模块里面相互调用。TensorIR负责张量级别程序和硬件张量指令的整合。还有这里的Relax (Relax Next)。我们可以从下面的例子体会:

`import tvm.script  
from tvm.script import tir as T, relax as R  
  
@tvm.script.ir_module  
class MyIRModule:  
    @T.prim_func  
    def tir_exp_func(x: T.handle, y: T.handle): ## <= D2  
        X = T.match_buffer(x, (n,), "float32")  
        Y = T.match_buffer(y, (n,), "float32")  
        with T.grid(n) as i:  
            Y[i] = T.exp(X[i])   
  
    @R.function  
    def relax_func(x: R.Tensor[(n, k), "f32"], w: R.Tensor[_, "f32"]):  
        # n, k above are implicitly defined by the signature  
        # so we will be able to refer to n, k in the later part of the program  
        with R.dataflow(): ### <= D0  
            lv0 = R.match_shape(w, (k, m)) ## <= D1  
            lv1: R.Tensor[(n, m), "f32"] = R.dot(x, lv0)  
            lv2: R.Tensor[(n * m,), "f32"] = R.flatten(lv1) ## <= D1  
            lv3: R.Shape = (n * m,)  ## <= D1   
            gv0: R.Tensor[lv2, "f32"] = R.call_tir(lv2, tir_exp_func, [lv3])   ## <= D2  
            R.outputs(gv0)  
  
        R.call_packed("custom_inplace_update", gv0)  ## <= D0, D2  
        return gv0   
`

注意这里展示的代码片段是Relax wiki提供的,由于没有upstream主分支,它的用法也许会有微小变化。我们从这个代码中可以看到,Relax把Relax Function和TIR Function放到了同一个IRModule(最小的编译单元)也就是说在任意时刻我们都可以同时拿到这两个不同层次的IR进行修改(或者说联合优化)这就摆脱了编译器范式里因为Lower导致丢失高层语义信息无法联合优化的问题。知乎上思远指出了一个很经典的例子,我这里附上他回答链接([https://www.zhihu.com/question/522101384/answer/2391922144](https://www.zhihu.com/question/522101384/answer/2391922144))并截图来说明一下:

image.png
接下来我们翻译一下Relax的设计关键点来进一步体会Relax相比于Relay的变化(中间插了一些个人理解)。

D0:数据流块作为第一优先级的构造

大部分的relax_func都封装在with R.dataflow()构造里面。数据流块下的所有操作都是没有副作用的,并且不包含高级的控制流(比如if-then-else)或者嵌套区域。

一个数据流块可以有效地视为嵌入在程序里面的计算图。请注意,数据流块里面的大多数绑定变量(上面Relax脚本中的lv0, lv1, lv2, lv3)是local的,这意味着它们仅是块内可见的。这些变量可以被视为计算图的“内部节点”。我们可以将变量标记为输出(gv0),在这种情况下,该变量将在程序的后面部分可见。这些输出变量可以被视为计算图中的输出节点。

请注意, R.call_packed("custom_inplace_update", gv0) 在数据流块之外。数据流块之外的所有内容都可能产生副作用。因此,除非我们进行更仔细的分析,否则我们无法执行优化,例如根据拓扑顺序重新排序这些绑定。我们预计大多数优化将发生在数据流块级别。这些优化可以由熟悉计算图概念的 ML 工程师完成。隔离和表示有效组件的能力还为需要它们的地方提供了更高级别的优化机会。

D1:形状推导作为第一优先级的计算

形状推导对于动态模型工作负载至关重要。在动态形状设置下,我们通常需要在运行计算之前计算中间张量的形状。此外,我们还需要处理形状本身依赖于数据(例如unique op)的情况。最后,大多数动态形状工作负载仍然包含大量(部分)静态形状,理想情况下,我们希望利用这些静态形状信息进行优化。

`from tvm.script import relax as R  
  
@R.function  
def shape_example(x: R.Tensor[(n, 2, 2), "f32"]):  
    with R.dataflow():  
        # symbolic and static shape deduction  
        lv0: R.Tensor[(n, 4), "f32"] = R.reshape(x, (n, 4))   
        lv1: R.Tensor[(n * 4,), "f32"] = R.flatten(lv0)  
        lv2: R.Shape = (n * 4,)  
        # external opaque shape function  
        lv3: R.Shape = R.call_packed("myshape_func", lv2)  
        lv4: R.Tensor[lv3, "f32"] = R.call_tir(lv3, "custom_func", [lv1])   
        # data dependent case  
        lv5: R.Tensor[_, "f32"] = R.unique(lv4)  
        # re-match shape  
        lv6: R.Tensor[(m,), "f32"] = R.match_shape(lv5, (m,))  
        gv0: R.Tensor[(m,), "f32"] = R.exp(lv6)  
        R.outputs(gv0)  
    return gv0  
`

上述程序涵盖了形状推断的典型场景(在注释中标记)。重要的是,形状现在与张量值一起成为计算的一部分。这反映了形状的计算可以在运行时发生的事实。

而文本格式类型注释 lv0: R.Tensor[(n, 4), "f32"] 显示了每个Shape的值。这只是一个语法糖,从 IR 的角度来看,Shape字段 (n, 4) 不是 lv0.checked_type 的一部分。lv0 的类型是 DynTensor(rank=2, dtype="f32"),Shape是附加到每个 Expr 的特殊值字段。我们做出这个显式的选择是为了简化类型推断,这样我们就不需要进入完全依赖类型的领域。

有两个与符号Shape计算相关的关键结构:

D1a: match\_shape

value = match_shape(lhs, pattern)

形状匹配构造接受一个 lhs 值和pattern(整型符号表达式)。它有两个重载语义:

  • 当 lhs 为 Tensor 时,将 lhs.shape 匹配到 pattern 中,如果第一次出现在 pattern 中,则填充对应的整型符号变量,然后返回一个与 lhs 相同但 shape 字段更新为 pattern 的 Tensor。
  • lhs 也可以是直接匹配 pattern 的 Shape。当我们想要分离出不对应于任何张量值的 Shape 函数时,这很有用。

比如:

`from tvm.script import relax as R  
  
@R.function  
def shape_example(x: R.Tensor[_, "f32"], y: R.Tensor[_, "f32"]):  
    with R.dataflow():  
        # the match shape defines n, m because it appears for the first time  
        lv0: R.Tensor[(n, m)] = R.match_shape(x, (n, m))  
        # the second occurance of n, m will translate into an assertion   
        # that y's shape equals (n, m)  
        lv1: R.Tensor[(n, m)] = R.match_shape(y, (n, m))   
        # we can also call match_shape on shape expressions  
        lv2: Shape = R.match_shape(R.shape_of(y), (n, m))   
`

特别注意这里lv2的Shape就被设置为(n, m),并且match_shape的lhs是一个Shape表达式,而不是Tensor。

D1b. 从符号整数元组构造Shape

在我们得到 n 和 m 等符号化整数之后。我们可以将它们重新组合在一起以形成一个 Expr。任何符号整数表达式的元组都可以在 Relax 中被识别为Shape 值。 比如 (n, m) 就是一个表示 Shape 的值。

Shape传播的方法

重要的是,现在Shape是计算过程中值的一部分。编译时Shape推断可以被看作是对发生在Shape上的操作的常量折叠,程序有几种Shape计算的方法:

  • 方法1: 符号化的形状传播。可以将Shape分解为符号整数比如上个脚本中的n和m,然后我们可以使用符号整数的表达式来代表Shape的计算比如(n*4)。值得注意的是,静态形状是符号整数的一种特殊情况,然后我们可以重新组合符号整数来构造一个新的Shape如(n*4)
  • 方法2: 不透明的Shape函数调用。我们还可以实现不透明的Shape函数比如myshape_func(看上上个Relax脚本),这些不透明的Shape函数是快速破解运行时Shape函数的有用fallback(这里应该是说加上手工干预的形状推导?)。
  • 方法3:对于数据相关的Shape(如Unique),我们将简单地推迟到一个运行时的调用 f(inputs)->outpus 它接收一个输入张量,分配并返回输出张量。然后我们可以通过match\_shape构造从Tensor值中获得lv5的形状。(看上上个Relax脚本)
Implications for pass writing

很多优化Pass都需要知道Shape信息。既然很多Shape可以是符号化的比如 (n, 4),那么理想的优化Pass将需要更泛化一点以利用符号信息。比如在上述脚本中,我们知道所有的n都对应同一个值。这种约束很有用。因为符号化的整数(我们之前讲过对应 tir.PrimExpr )动态的执行常量折叠,当输入是静态Shape时计算的结果也应该动态的折叠为整形常数,保留我们执行静态Shape优化依赖的属性。因为我们现在可以在元组(n, 4)表示混合的静态符号Shape,所以我们可以尝试利用静态信息进行额外的优化。

D2:与 TensorIR 和 PackedFunc 直接交互

我们做出的最后一个关键设计决策是允许高层 IR 能够直接交互并调用低层 TensorIR 和 PackedFunc。TensorIR 函数和许多外部库采用目标传递约定(我们需要显式分配输出并作为参数传入函数)。我们使用 dps(destination passing) 来表示这个约定。dps 在低级 ML 优化中非常重要,因为它允许我们在可能的情况下一次性全局分配中间存储,并在没有主动内存分配的情况下执行计算。

调用 dps 函数意味着在调用之后,结果通过函数参数(例如,下面示例中的结果)而不是函数的返回值传回。

`// not destination passing  
int func(int x) {  
  return 1;  
}  
// destination passing  
void func(int x, int *result) {    
  *result = 1;  
}  
`

dps 风格在本质上意味着突变(输出)。我们需要一种将调用桥接到Relax Dataflow的方法(可以观察一下Relax这一节开头那部分的脚本),以便我们可以对一系列 tir 调用执行计算图样式的重写。

D2a. call\_tir

call_tir 是将调用桥接到Relax Dataflow的内嵌函数。它的命名含义是:“调用一个tir转换”

`def call_tir(output_shape: Shape, lowlevel_func: Expr, inputs: Tuple[Expr]) -> Expr:  
    """Example code to demonstrate the semantics of call tir"""  
    out_tensor = alloc_tensor(output_shape, current_expr.dtype)  
    lowlevel_func(*inputs, out_tensor)  
    return out_tensor  
`

call_tir 接受输出形状,lowlevel_func(can be packed func, tir PrimFunc) 和一个输入元组。call_tir 的语义可以通过上面的代码来演示。值得注意的是,当我们lower call_tir 时,我们不需要选择单独的分配输出张量。编译器可以选择创建中间张量的内存计划,并将它们联系在一起以实现有效重用。

值得注意的是,call_tir 内嵌函数的 output_shape 参数可以是不透明的形状值、符号整数元组或常量形状(支持动态Shape)。

lowlevel_func 可以是任何带有签名的函数:fn(input0, input1,... out0, out1...)

最常见的两种情况包括:(1) TIR 函数 (2) 不透明的packed func

实现笔记

call\_tir 可以实现为特殊的内嵌函数 (Op),以最大限度地减少对 IR 更改的影响(而不是独立的 IR 节点)。从 AST 的角度来看,这变为:

`Call(op=Op::Get("relax.call_tir"), shape, lowlevel_func, inputs)
`

这也将允许 call\_tir 的未来迭代而不改变 IR 本身,这可能在特定时间点需要:

  • 在同一个数组上启用多个突变序列(在 concat 相关操作的情况下)
  • 启用将符号化的Shape提示传递给融合操作。
对整合的影响

D2 使我们能够将较低级别的抽象直接嵌入到高级抽象(R.function)中。这释放了很多机会,包括但不限于:

  • 使用不同的策略逐步lower程序的不同部分。
  • 我们可以将call_tir节点作为AST的一部分进行优化,然后将一些关键信息比如data layerout信息带回到high level的IR获得更好的优化结果。
  • 将 BYOC 流作为转换的自然部分(通过将图的一部分转换为不透明打包函数的调用)。

这里的第二点实际上对应了Ansor引入的weight layout rewrite,即在算子auto-tuning之后,我们去分析最高效的weight layout,并且在编译时改写,来提高运行时的效率。那么没有Relax之前是怎么完成这个工作的呢?一个op 更适合的weight layout是要在tuning之后才能够知道的,而这个时候图IR已经被lower,不能修改了。所以Ansor用了一个非常tricky的方法,先lower一遍把tuning做好,再带着这些信息重新lower一遍。所以Relax通过消除lower的边界隔阂可以较好的解决这一问题。

D2b. Packed function calls

我们使用 R.call_packed 来指示对Packed Func的调用。从 AST 的角度来看,我们不需要引入额外的调用节点,而是可以引入一个 ExternFunc 构造,它表示我们可以调用的打包函数。

`Call(op=ExternFunc("my_packed_func"), *args)
`

R.call_packed 仅用作表示上述 AST 节点的语法糖。这使我们能够统一所有调用。值得注意的是,它还允许我们在必要时混合打包函数和 call\_tir。

`lv4: R.Tensor[lv3, "f32"] = R.call_tir(lv3, "custom_func", [lv1]) 
`

对应于下面的 AST。

`Call(op=Op::Get("relax.call_tir"), shape, ExternFunc("my_packed_func"), [lv1])
`

当我们想要将低级库(例如 cudnn)直接集成到高级而不调用内存分配时,外部打包函数上的 CallTIR 会很有用。

关于这一点在MLC课程中也有演示,通过dlpack调用PyTorch的Op来做优化,感兴趣的读者可以看一下,链接:https://mlc.ai/zh/chapter_end_to_end/index.html

这里简单做一个总结,Relax作为下一代Relay不仅原生支持动态Shape且使用体验更加靠近PyTorch这种数据流图的编程方式。尤其重要的是Relax在为TVM Unify而服务,通过和TensorIR抽象,TVMFFI(Packed Func)的交互(通过MLC教程可以知道,也可以和Auto Schedule交互)使得TVM Unify的目标得到实现。

当然我也要说一下我目前看到的Relax的不完善的地方,那就是Relax目前和其它深度学习框架对接还不够完善,如果能实现Relay到Relax的自动转换那将是一个振奋人心的消息,可以最小化我们的迁移成本。

0x3. Tensor Expression(TE)

让我们回到开头的这个图:

image.png
TVM前端架构图

我们可以发现Relay要到TIR有2条路径,第一条就是直接到TIR比如PrimExpr派生的节点比如一个IntImmNode可以直接映射到TIR,另外一条就是Relay里面类似Conv的Op的计算逻辑是用TOPI来表达的,TOPI是TVM自己的一个算子库,这些算子可以通过TE来进行表达。

除此之外,我们在前端介绍Relax的时候已经可以看到要直接编写TIR AST,一种方法是使用TVMScript来表示抽象的计算逻辑,另外一种方法就是要通过TE,TE的代码无法被直接编译成目标硬件的代码,而是需要先Lower为TIR的元张量函数才可以进行编译。其实我之前写过一些Schedule相关的文章比如《【TVM 三代优化巡礼】在X86上将普通的矩阵乘法算子提速90倍》,也都是基于TE的。由此可见,TE不仅提供了另外一种编写TIR AST的方法,还提供了一系列变换TIR AST的Schedule。在0x5节我们会提一下Schedule。

我们先看一下给予TVM Script写的这个向量加法的例子:

`@tvm.script.ir_module  
class MyModule:  
    @T.prim_func  
    def main(a: T.handle, b: T.handle):  
        # We exchange data between function by handles, which are similar to pointer.  
        T.func_attr({"global_symbol": "main", "tir.noalias": True})  
        # Create buffer from handles.  
        A = T.match_buffer(a, (8,), dtype="float32")  
        B = T.match_buffer(b, (8,), dtype="float32")  
        for i in range(8):  
            # A block is an abstraction for computation.  
            with T.block("B"):  
                # Define a spatial block iterator and bind it to value i.  
                vi = T.axis.spatial(8, i)  
                B[vi] = A[vi] + 1.0  
  
  
ir_module = MyModule  
print(type(ir_module))  
print(ir_module.script())  
`

输出:

`<class 'tvm.ir.module.IRModule'>  
# from tvm.script import tir as T  
@tvm.script.ir_module  
class Module:  
    @T.prim_func  
    def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:  
        # function attr dict  
        T.func_attr({"global_symbol": "main", "tir.noalias": True})  
        # body  
        # with T.block("root")  
        for i in T.serial(8):  
            with T.block("B"):  
                vi = T.axis.spatial(8, i)  
                T.reads(A[vi])  
                T.writes(B[vi])  
                B[vi] = A[vi] + T.float32(1)  
`

然后我们使用TE DSL来表达这个向量加法:

`from tvm import te  
  
A = te.placeholder((8,), dtype="float32", name="A")  
B = te.compute((8,), lambda *i: A(*i) + 1.0, name="B")  
func = te.create_prim_func([A, B])  
ir_module_from_te = IRModule({"main": func})  
print(ir_module_from_te.script())  
`

输出:

`# from tvm.script import tir as T  
@tvm.script.ir_module  
class Module:  
    @T.prim_func  
    def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None:  
        # function attr dict  
        T.func_attr({"global_symbol": "main", "tir.noalias": True})  
        # body  
        # with T.block("root")  
        for i0 in T.serial(8):  
            with T.block("B"):  
                i0_1 = T.axis.spatial(8, i0)  
                T.reads(A[i0_1])  
                T.writes(B[i0_1])  
                B[i0_1] = A[i0_1] + T.float32(1)  
`

从两个输出中我们可以看到,最后创建的IRModule其实是完全一样的。然后这个IRModule可以被编译为目标硬件上可以执行的代码。如果你想更加深入的了解TE是如何被编译成TIR的,可以看一下 《TVM 自底向上(三):TE 的概念和编译原理》 这篇文章,我这里借一下作者文章中的核心图简要说明一下:

image.png

来自 :https://zhuanlan.zhihu.com/p/534313816 作者:Kord 侵删

我们从上往下看,这里的List[PrimExpr]就是这个lambda表达式中的PrimExpr集合,第一个PrimExpr是A(*i),第二个PrimExpr是1.0,然后+对应了TIR中的ExprOp([https://github.com/apache/tvm/blob/main/python/tvm/tir/expr.py#L66](https://github.com/apache/tvm/blob/main/python/tvm/tir/expr.py#L66)),Expr作用在1个或者多个PrimExpr上得到的结果仍然是PrimExpr。实际上,这里的List[PrimExpr]就对应了这个lambda表达式的AST表示。接下来我们看一下te.compute的代码([https://github.com/apache/tvm/blob/main/python/tvm/tir/expr.py#L66](https://github.com/apache/tvm/blob/main/python/tvm/tir/expr.py#L66)):

`def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=None):  
    """Construct a new tensor by computing over the shape domain.  
    The compute rule is result[axis] = fcompute(axis)  
    Parameters  
    ----------  
    shape: Tuple of Expr  
        The shape of the tensor  
    fcompute: lambda function of indices-> value  
        Specifies the input source expression  
    name: str, optional  
        The name hint of the tensor  
    tag: str, optional  
        Additional tag information about the compute.  
    attrs: dict, optional  
        The additional auxiliary attributes about the compute.  
    varargs_names: list, optional  
        The names to use for each of the varargs. If not supplied, the varargs  
        will be called i1, i2, ...  
    Returns  
    -------  
    tensor: Tensor  
        The created tensor  
    """  
    if _tag.TagScope.get_current() is not None:  
        if tag != "":  
            raise ValueError("nested tag is not allowed for now")  
        tag = _tag.TagScope.get_current().tag  
    shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape  
    # for python3  
    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])  
    out_ndim = len(shape)  
   # 获取输入给lambda表达式的参数列表   
    argspec = inspect.getfullargspec(fcompute)  
    if len(argspec.args) == 0 and argspec.varargs is None:  
        arg_names = ["i%d" % i for i in range(out_ndim)]  
    elif argspec.varargs is not None:  
        # if there is a varargs, it takes the remaining dimensions of out_ndim  
        num_remaining_args = out_ndim - len(argspec.args)  
        if varargs_names is not None:  
            if len(varargs_names) != num_remaining_args:  
                raise RuntimeError(  
                    f"Number of varargs ({num_remaining_args}) does not match number"  
                    f"of varargs_names ({len(varargs_names)})"  
                )  
            arg_names = argspec.args + varargs_names  
        else:  
            arg_names = argspec.args + [f"i{i}" for i in range(out_ndim - len(argspec.args))]  
    else:  
        arg_names = argspec.args  
        # if there are fewer args than out dimensions, the remaining dimensions  
        # are implicitly broadcast  
        out_ndim = len(arg_names)  
    assert argspec.varkw is None, "Variable keyword arguments not supported in fcompute"  
    assert argspec.defaults is None, "Default arguments not supported in fcompute"  
    assert len(argspec.kwonlyargs) == 0, "Keyword arguments are not supported in fcompute"  
  
    if out_ndim != len(arg_names):  
        raise ValueError(  
            "Number of args to fcompute does not match dimension, "  
            "args=%d, dimension=%d" % (len(arg_names), out_ndim)  
        )  
    
    dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]  
    # 基于lambda表达式创建List[PrimExpr]  
    body = fcompute(*[v.var for v in dim_var])  
     
   # 将List[PrimExpr]传给TensorComputeOp进行计算并返回一个tvm.te.Tensor  
    if isinstance(body, _tensor.TensorIntrinCall):  
        for i, s in enumerate(shape[out_ndim:]):  
            var_name = "ax" + str(i)  
            dim_var.append(tvm.tir.IterVar((0, s), var_name, 4))  
        op_node = _ffi_api.TensorComputeOp(  
            name,  
            tag,  
            dim_var,  
            body.reduce_axis,  
            out_ndim,  
            body.intrin,  
            body.tensors,  
            body.regions,  
            body.scalar_inputs,  
        )  
    else:  
        if not isinstance(body, (list, tuple)):  
            body = [body]  
        body = convert(body)  
        op_node = _ffi_api.ComputeOp(name, tag, attrs, dim_var, body)  
  
    num = op_node.num_outputs  
    outputs = tuple(op_node.output(i) for i in range(num))  
    return outputs[0] if num == 1 else outputs  
`

在compute的实现中最后返回的是TensorComputeOp对象的output()成员(也是一个tvm.te.Tensor), 同时这个tvm.te.Tensor包含这个TensorComputeOp对象(通过.op来访问,在[https://github.com/apache/tvm/blob/main/python/tvm/te/tensor.py#L108](https://github.com/apache/tvm/blob/main/python/tvm/te/tensor.py#L108)可以看到)。

最后func = te.create_prim_func([A, B])这行代码完成了TE到TIR的转换。这个api对应的c++实现在[https://github.com/apache/tvm/blob/v0.8.0/src/te/operation/create_primfunc.cc#L238](https://github.com/apache/tvm/blob/v0.8.0/src/te/operation/create_primfunc.cc#L238)这个文件,感兴趣的读者可以自行查看。基本流程就是将所有Operation对应的PrimExpr AST连在一起构成一个AST Graph,然后使用Post-DFS算法遍历这个AST Graph分别处理每一个Operation创建对应的TIR节点,最后构造一个完整的TIR PrimFunc。

TE除了可以构造TIR之外,另外一个重要的点就是它支持Schedule(tvm.te.Schedule),我在【TVM 三代优化巡礼】在X86上将普通的矩阵乘法算子提速90倍 文章中对GEMM优化的介绍就是基于TE Schedule来做变换进行优化计算的。

作者:BBuf
文章来源: GiantPandaCV量子位

推荐阅读

更多嵌入式AI干货请关注 嵌入式AI 专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。
推荐阅读
关注数
18790
内容数
1342
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息