
A Peek into Triton's Lowering (Part 2)

In the first part we walked through the rough path from source code to make_ir and obtained the initial ttir by processing the AST. In this part we continue downward and finish the last step, compile_ir. The nvptx backend we compile for splits this step into five smaller stages: make_ttir, make_ttgir, make_llir, make_ptx and make_cubin. The last two are handled by LLVM and NVIDIA's ptxas, so we focus on the first three. Each stage is assembled from a number of passes (we assume the reader already knows what a "pass" in a compiler is and what it does). Judging by where the passes come from, they fall roughly into the following groups:

  1. common, defined in mlir/include/mlir/Transforms/Passes.td
  2. ttir, defined in triton/include/triton/Dialect/Triton/Transforms/Passes.td
  3. ttgpuir, defined in triton/include/triton/Dialect/TritonGPU/Transforms/Passes.td
  4. ttnvgpuir, defined in triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

All of them define their passes by inheriting from mlir/Pass/PassBase.td. Going down the list, the semantics get progressively lower level and carry more hardware-specific information. We will pick out a few important passes to analyze. Also, this article shows quite a lot of code; unimportant parts have been trimmed so that only the pieces we currently care about remain.

  • make_ttir

The code for this stage is shown below; the role of each pass is given in the comments (triton/python/triton/backends/nvidia/compiler.py):

@staticmethod
def make_ttir(mod, metadata, opt):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()             
    passes.common.add_inliner(pm) # try to inline function call
    passes.ttir.add_rewrite_tensor_pointer(pm) # Rewrite load/stores with tensor pointers into legacy load/stores
    passes.ttir.add_combine(pm) # combine ops
    passes.common.add_canonicalizer(pm) # converts operations into their canonical forms by folding constants, identity transformations etc.
    passes.ttir.add_reorder_broadcast(pm) # Moves broadcast and splat after elementwise operations
    passes.common.add_cse(pm) # Eliminate common sub-expressions
    passes.common.add_licm(pm) # Hoist loop invariant instructions outside of the loop
    passes.common.add_symbol_dce(pm) # Eliminate dead symbols
    pm.run(mod)
    return mod

These are mostly standard optimization passes. Take add_inliner as an example: when the add_inliner pass is invoked, the call goes through the Python binding (pass.cc) to the constructor declared in Passes.td, which then leads into the pass's concrete implementation file.

def Inliner : Pass<"inline"> {
  let summary = "Inline function calls";
  let constructor = "mlir::createInlinerPass()";
  let options = [
   ......
  ];
}
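
To make the jump from Python into this constructor concrete, here is a minimal, hypothetical pybind11 sketch of what a binding like passes.common.add_inliner can look like. The names below are illustrative and not Triton's actual binding code; it also assumes the MLIR PassManager is already exposed to Python elsewhere (as ir.pass_manager is):

// Hypothetical sketch: exposing "add_inliner" to Python. The helper simply
// appends the pass built by mlir::createInlinerPass() -- the `constructor`
// named in the Passes.td entry above -- to the given pass manager.
#include <pybind11/pybind11.h>

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

namespace py = pybind11;

void init_common_passes(py::module &m) {
  // Assumes mlir::PassManager is bound to Python elsewhere (ir.pass_manager).
  m.def("add_inliner",
        [](mlir::PassManager &pm) { pm.addPass(mlir::createInlinerPass()); });
  // The other common passes follow the same one-line shape.
  m.def("add_canonicalizer",
        [](mlir::PassManager &pm) { pm.addPass(mlir::createCanonicalizerPass()); });
  m.def("add_cse",
        [](mlir::PassManager &pm) { pm.addPass(mlir::createCSEPass()); });
}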

The functions of the other passes are given in the comments above. This stage has almost no effect on our case: the example is simply too small to leave any room for these optimizations, so we move straight on to the next stage.

  • make_ttgir

This stage has considerably more going on. Let's first note the function of every pass in the comments:

@staticmethod
def make_ttgir(mod, metadata, opt, capability):
    cluster_info = nvidia.ClusterInfo()
    if opt.cluster_dims is not None:
        cluster_info.clusterDimX = opt.cluster_dims[0]
        cluster_info.clusterDimY = opt.cluster_dims[1]
        cluster_info.clusterDimZ = opt.cluster_dims[2]
    # TTIR -> TTGIR
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.ttir.add_convert_to_ttgpuir(pm, opt.num_warps, 32, opt.num_ctas, capability) 
    # optimize TTGIR
    passes.ttgpuir.add_coalesce(pm) # coalesce memory ops to improve memory/cache access efficiency
    nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info) # decide the CTA (thread block) tiling
    passes.ttgpuir.add_remove_layout_conversions(pm) # remove superfluous layout conversions
    passes.ttgpuir.add_optimize_thread_locality(pm) # Reduce the cost of synchronization between threads in an SM
    passes.ttgpuir.add_accelerate_matmul(pm, capability) # Optimize the input/output layouts of `dot` instructions to make them compatible with hardware accelerators (e.g., NVIDIA tensor cores)
    passes.ttgpuir.add_remove_layout_conversions(pm) # remove superfluous layout conversions
    passes.ttgpuir.add_optimize_dot_operands(pm) # Re-arranged layouts of tensors used as matrix multiplication operands
    passes.common.add_cse(pm)  # Eliminate common sub-expressions
    if capability // 10 >= 8:
        passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability) # Applies software pipelining to loops in the module based on number of stages
    if capability // 10 <= 8:
        passes.ttgpuir.add_prefetch(pm) # Decompose `DotOp` instructions in loops into several finer-grained `DotOp`
    passes.ttgpuir.add_optimize_dot_operands(pm)  # Re-arranged layouts of tensors used as matrix multiplication operands
    passes.ttgpuir.add_remove_layout_conversions(pm) # remove superfluous layout conversions
    passes.ttgpuir.add_reduce_data_duplication(pm) # Reduce data duplication in registers by decomposing layout conversions
    passes.ttgpuir.add_reorder_instructions(pm) # Reorder instructions
    passes.common.add_cse(pm) # Eliminate common sub-expressions
    passes.common.add_symbol_dce(pm) # Eliminate dead symbols
    if capability // 10 >= 9:
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm) # Insert fences across generic and async proxy
    passes.common.add_canonicalizer(pm) # converts operations into their canonical forms by folding constants, identity transformations etc.
    pm.run(mod)
    metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
    return mod

The part to focus on is the TTIR -> TTGIR conversion itself (essentially a conversion of dialect ops), i.e. the add_convert_to_ttgpuir pass; readers familiar with LLVM can think of it as the legalization done at the SelectionDAG stage. Following the definition takes us straight to its implementation in triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();
    // type converter
    TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp, numCTAs);
    TritonGPUConversionTarget target(*context, typeConverter);
    // rewrite patterns
    RewritePatternSet patterns(context); // create a pattern set that collects the RewritePatterns used to convert ops
    populateArithPatternsAndLegality(typeConverter, patterns, target); // add RewritePatterns for arith dialect ops to the pattern set
    // calls patterns.add<GenericOpPattern<arith::AddIOp>>, i.e. addi is handled by GenericOpPattern
    populateMathPatternsAndLegality(typeConverter, patterns, target); // add RewritePatterns for math dialect ops
    populateTritonPatterns(typeConverter, patterns, numCTAs); // add RewritePatterns for triton dialect ops
    // calls patterns.insert<GenericOpPattern<triton::LoadOp>, GenericOpPattern<triton::StoreOp>>, i.e. ld/st are handled by GenericOpPattern
    populateSCFPatterns(typeConverter, patterns); // add RewritePatterns for scf dialect ops
    populateCFPatterns(typeConverter, patterns); // add RewritePatterns for cf dialect ops
    ......
    mod->setAttr(  // set some attributes on the module
       ......
    );

    if (failed(applyPartialConversion(mod, target, std::move(patterns))))
      return signalPassFailure();
  }

This conversion can also be called the rewrite process. First note the type converter and the target being used, which are TritonGPUTypeConverter and TritonGPUConversionTarget respectively. These two objects are very important: the type converter specifies how certain data types are converted, and the target specifies which ops are legal (analogous to type legalization and op legalization in LLVM); we will look at their code later. Next, the populate#Opname#Pattern functions record the conversion patterns for ops of several dialects in the pattern set, providing the recipe for converting each op; the arith.addi, tt.load and tt.store in the IR at the end of the previous part are all handled by GenericOpPattern (at this point the patterns are only declared, not yet applied; we will see the implementation later). After that, applyPartialConversion in mlir/lib/Transforms/Utils/DialectConversion.cpp is called to perform the op conversion.
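
To make the pattern-collection step concrete, here is an abridged sketch of what a populate function such as populateTritonPatterns roughly does. It is trimmed heavily for illustration: the real function registers many more ops, and some ops get dedicated conversion patterns rather than GenericOpPattern.

// Abridged sketch of a populate*Patterns helper: it only registers conversion
// patterns into the RewritePatternSet; nothing is rewritten at this point.
// GenericOpPattern is the template shown further below in this article.
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                            RewritePatternSet &patterns, unsigned numCTAs) {
  MLIRContext *context = patterns.getContext();
  // "Plain" ops only need their result types converted (tensor types get a
  // blocked layout encoding), so the generic pattern is enough for them.
  patterns.insert<GenericOpPattern<triton::LoadOp>,  // tt.load
                  GenericOpPattern<triton::StoreOp>  // tt.store
                  /* ...many more ops... */
                  >(typeConverter, context);
  // numCTAs is consumed by some of the omitted, op-specific patterns.
}

With the patterns registered, applyPartialConversion (shown next) drives the actual rewriting.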

LogicalResult
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
                             const ConversionTarget &target,
                             const FrozenRewritePatternSet &patterns,
                             DenseSet<Operation *> *unconvertedOps) {
  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
                                 unconvertedOps);
  return opConverter.convertOperations(ops);
}

The conversion function convertOperations lives in the same file:

LogicalResult OperationConverter::convertOperations(
    ArrayRef<Operation *> ops,
    function_ref<void(Diagnostic &)> notifyCallback) {
  if (ops.empty())
    return success();
  const ConversionTarget &target = opLegalizer.getTarget();

  // Compute the set of operations and blocks to convert.
  SmallVector<Operation *> toConvert;
  for (auto *op : ops) {
    op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
        [&](Operation *op) {
          toConvert.push_back(op);
          auto legalityInfo = target.isLegal(op); // check whether the current op is legal
          if (legalityInfo && legalityInfo->isRecursivelyLegal)
            return WalkResult::skip();
          return WalkResult::advance();
        });
  }

  // Convert each operation and discard rewrites on failure.
  ConversionPatternRewriter rewriter(ops.front()->getContext());
  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
  for (auto *op : toConvert)
    if (failed(convert(rewriter, op))) // convert the op
      return rewriterImpl.discardRewrites(), failure();
  ......
  return success();
}

Here we only need to care about two important functions, target.isLegal(op) and convert(rewriter, op): which ops are legal at the ttgir level, and how illegal ops are handled.

For the first question: when TritonGPUConversionTarget (which inherits from ConversionTarget) is constructed, it specifies the legality of a number of dialects and ops (triton/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp)

TritonGPUConversionTarget::TritonGPUConversionTarget(
    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
    : ConversionTarget(context) {

  addLegalDialect<triton::gpu::TritonGPUDialect>(); // every op of TritonGPUDialect is legal
  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp, // these ops are illegal
               scf::ReduceReturnOp>();
  addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
                             triton::TritonDialect, cf::ControlFlowDialect,
                             scf::SCFDialect>([&](Operation *op) { // ops of these dialects are legal only if the condition below holds
    bool hasLegalRegions = true;
    for (auto &region : op->getRegions()) {
      hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
    }
    if (hasLegalRegions && typeConverter.isLegal(op)) { // the type converter is consulted here to check whether the types are legal
      return true;
    }
    return false;
  });
  ......
}

The type converter used above is TritonGPUTypeConverter

TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                               int numWarps, int threadsPerWarp,
                                               int numCTAs)
    : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp),
      numCTAs(numCTAs) {
  addConversion([](Type type) { return type; }); // returning the type itself marks the type as already legal
  // Add encoding for tensor
  addConversion([this](RankedTensorType tensorType) -> RankedTensorType { // convert tensor types
    // types with encoding are already in the right format
    // TODO: check for layout encodings more specifically
    if (tensorType.getEncoding())
      return tensorType;
    ArrayRef<int64_t> shape = tensorType.getShape();
    triton::gpu::BlockedEncodingAttr encoding =
        getDefaultBlockedEncoding(this->context, shape, this->numWarps,
                                  this->threadsPerWarp, this->numCTAs);
    return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
  });

  // Add encoding for tensor pointer
  addConversion([this](triton::PointerType ptrType) -> triton::PointerType { // convert pointer types
    // Check whether tensor pointer `tt.ptr<tensor<>>`
    auto pointeeTensorType =
        ptrType.getPointeeType().dyn_cast<RankedTensorType>();
    if (pointeeTensorType == nullptr)
      return ptrType;

    // Add layout into the tensor
    auto convertedTensorType = convertType(pointeeTensorType);
    return triton::PointerType::get(convertedTensorType,
                                    ptrType.getAddressSpace());
  });

  //
  // Materializations
  //
  ......
}

As we can see, the type converter declares that RankedTensorType and PointerType need to be converted. So when target.isLegal(op) is evaluated, the target and the type converter together decide whether the current op is legal. (The type converter has one more role, materialization, which describes how to turn a set of values into a single value of the desired type.)
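
The materializations elided above can be pictured roughly as follows. This is only a sketch of the generic MLIR TypeConverter hook, not the omitted Triton code: the exact callback signature differs between MLIR versions, and the builtin unrealized_conversion_cast is used here merely as a placeholder bridge op.

  // Sketch: a source materialization tells the framework how to rebuild a
  // single value of an original (source) type from already-converted values
  // when it needs to patch up a type mismatch by itself.
  addSourceMaterialization([](OpBuilder &builder, Type resultType,
                              ValueRange inputs,
                              Location loc) -> std::optional<Value> {
    // Bridge the type gap with the builtin cast op; later passes (or a
    // dedicated pattern) are expected to resolve or fold it away.
    return builder
        .create<UnrealizedConversionCastOp>(loc, resultType, inputs)
        .getResult(0);
  });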

Once legality has been determined, the second function, convert(rewriter, op), behaves differently depending on the conversion mode. Our mode here is "Partial", which tolerates failed conversions so that illegal operations may remain in the IR:

LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
                                          Operation *op) {
  // Legalize the given operation.
  if (failed(opLegalizer.legalize(op, rewriter))) {
    if (mode == OpConversionMode::Full)
      return op->emitError()
             << "failed to legalize operation '" << op->getName() << "'";
    if (mode == OpConversionMode::Partial) {
      if (opLegalizer.isIllegal(op))
        return op->emitError()
               << "failed to legalize operation '" << op->getName()
               << "' that was explicitly marked illegal";
      if (trackedOps)
        trackedOps->insert(op);
    }
  } else if (mode == OpConversionMode::Analysis) {
    trackedOps->insert(op);
  }
  return success();
}

The code above calls opLegalizer.legalize() to do the conversion, so let's step into legalize

LogicalResult
OperationLegalizer::legalize(Operation *op,
                             ConversionPatternRewriter &rewriter) {
  ......
  // If the operation isn't legal, try to fold it in-place.
  if (succeeded(legalizeWithFold(op, rewriter))) { // folding computes the result directly from the operands, e.g. constant folding
    LLVM_DEBUG({
      logSuccess(logger, "operation was folded");
      logger.startLine() << logLineComment;
    });
    return success();
  }

  // Otherwise, we need to apply a legalization pattern to this operation.
  if (succeeded(legalizeWithPattern(op, rewriter))) { // try to replace the current op via a rewrite pattern
    LLVM_DEBUG({
      logSuccess(logger, "");
      logger.startLine() << logLineComment;
    });
    return success();
  }
  ......
  return failure();
}

Legalization thus comes in two flavors: folding and replacement. We focus on the second, which eventually calls applicator.matchAndRewrite() in mlir/lib/Rewrite/PatternApplicator.cpp. Most of the logic is stripped from the code below; only the essentials are kept

LogicalResult PatternApplicator::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter, ......) {
  const Pattern *bestPattern = nullptr;
  // Find the next pattern with the highest benefit.
  ......
  const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
  result = pattern->matchAndRewrite(op, rewriter);
  ......
}

A RewritePattern object encodes how a particular op gets rewritten, i.e. it is one of the conversion patterns added to the pattern set above. In MLIR and Triton, files under each dialect's or target's directory may contain objects derived from RewritePattern that handle certain ops and define their own rewrite logic. The logic here is: the candidate patterns are looked up by the current op's name, a cost model computes each pattern's benefit as the selection criterion, and the pattern with the highest benefit is chosen and its matchAndRewrite method is called. Take an arith::CeilDivUIOp as an example (this op is not necessarily converted at this stage; it is only for illustration)

/// Expands CeilDivUIOp (n, m) into
///   n == 0 ? 0 : ((n-1) / m) + 1
/// OpRewritePattern inherits from RewritePattern
struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value a = op.getLhs();
    Value b = op.getRhs();
    Value zero = createConst(loc, a.getType(), 0, rewriter);
    Value compare =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
    Value one = createConst(loc, a.getType(), 1, rewriter);
    Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
    Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
    Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
    return success();
  }
};

Here the ceildiv operation is rewritten into an equivalent form built on arith::SelectOp; this pattern's matchAndRewrite may be invoked when an arith::CeilDivUIOp needs to be legalized.
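
As for the benefit used to rank candidate patterns, it is declared by the pattern itself when it is constructed (every RewritePattern carries a PatternBenefit, which defaults to 1). A minimal sketch, with a made-up pattern name and benefit value for illustration:

// Sketch: a pattern advertises its benefit through the base-class constructor;
// PatternApplicator::matchAndRewrite tries higher-benefit patterns first.
struct MyCeilDivPattern : public OpRewritePattern<arith::CeilDivUIOp> {
  // Benefit 2 makes this pattern win over a default benefit-1 pattern that
  // matches the same op.
  MyCeilDivPattern(MLIRContext *context)
      : OpRewritePattern<arith::CeilDivUIOp>(context, /*benefit=*/2) {}

  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
                                PatternRewriter &rewriter) const override {
    // ... rewrite logic, e.g. the expansion shown in CeilDivUIOpConverter ...
    return failure(); // placeholder so this sketch has no effect
  }
};

// The benefit can also be supplied when the pattern is added to the set:
// patterns.add<CeilDivUIOpConverter>(context, /*benefit=*/1);
// patterns.add<MyCeilDivPattern>(context);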

Back to the arith.addi, tt.load and tt.store ops in our case. During legalization, according to the code above, the dialects they belong to, arith::ArithDialect and triton::TritonDialect, are dynamically legal, i.e. legal only when their types are legal; and, as mentioned, they are all rewritten with GenericOpPattern, shown below. It can be seen as a generic, op-agnostic pattern: it only fixes up an op's illegal result types and leaves the op itself untouched, so our IR does not change at this step. (A target can also define its own RewritePatterns for specific ops, as we will see in the next part.)

template <class Op> struct GenericOpPattern : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Type> retTypes;
    if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
                                                      retTypes)))
      return failure();
    rewriter.replaceOpWithNewOp<Op>(op, retTypes, adaptor.getOperands(),
                                    op->getAttrs());

    return success();
  }
};

Also, we do not have to provide patterns that produce target-legal operations directly: the MLIR framework automatically builds a conversion graph and turns an illegal op into a set of legal ones through intermediate steps. The MLIR documentation gives an example: suppose an operation "foo.add" is defined as legal, and the patterns [bar.add -> baz.add, baz.add -> foo.add] are provided; the framework automatically detects that it can legalize bar.add to "foo.add" even though no direct conversion exists, so no direct bar.add -> "foo.add" legalization pattern has to be written.

At this point we have seen the whole add_convert_to_ttgpuir process, i.e. the legalization process. To summarize this stage: ttir -> ttgir is a conversion of dialect ops. First decide whether an op is legal (through the interfaces provided by the target and the type converter); if it is, nothing needs to be done; if not, convert it (through the RewritePatterns added to the RewritePatternSet).

We skip the optimization passes that follow add_convert_to_ttgpuir for now; their implementations are comparatively easy to follow. Note that starting from this stage there are already hardware-specific passes, such as NVIDIA's add_fence_insertion. For our case the IR does not change much in this stage. This is what it looks like now

#loc = loc("toy.py":28:0)
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @addi_kernel_01(%arg0: !tt.ptr<i32, 1> loc("toy.py":28:0), %arg1: !tt.ptr<i32, 1> loc("toy.py":28:0)) attributes {noinline = false} {
    %c1_i32 = arith.constant 1 : i32 loc(#loc1)
    %0 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i32 loc(#loc2)
    %1 = arith.addi %0, %c1_i32 : i32 loc(#loc3)
    tt.store %arg1, %1 {cache = 1 : i32, evict = 1 : i32} : i32 loc(#loc4)
    tt.return loc(#loc5)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("toy.py":39:16)
#loc3 = loc("toy.py":40:17)
#loc4 = loc("toy.py":41:25)
#loc5 = loc("toy.py":41:4)

Compared with before, the only difference is a few additional hardware-related module attributes.

I originally planned to cover make_llir in this part as well, but it has already grown quite long, so that goes into the next one. The make_llir process and the interfaces it uses are in fact very similar to make_ttgir, so we will run into many familiar functions in the upcoming analysis, which should make it easier to follow. After make_llir, the IR of our case will also change substantially.

The End

Author: 液态黑洞
Source: GiantPandaCV
