In the first chapter we traced the rough path from source code to make_ir, obtaining the initial ttir by processing the AST. In this chapter we continue down the pipeline to the final step, compile_ir. The nvptx backend we compile for further splits this step into five smaller stages: make_ttir, make_ttgir, make_llir, make_ptx, and make_cubin. The last two are carried out with the help of LLVM and NVIDIA's ptxas, so we focus on the first three. Each stage is assembled from a number of passes (we assume the reader already knows what a "pass" in a compiler is and does). Grouped by where they are defined, the passes fall roughly into the following categories:
- common, defined in mlir/include/mlir/Transforms/Passes.td
- ttir, defined in triton/include/triton/Dialect/Triton/Transforms/Passes.td
- ttgpuir, defined in triton/include/triton/Dialect/TritonGPU/Transforms/Passes.td
- ttnvgpuir, defined in triton/include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td
All of them define passes by inheriting from mlir/Pass/PassBase.td. From top to bottom, their semantics become increasingly low-level and carry more hardware-specific information. We will pick a few important passes to analyze. Also, this article shows a fair amount of code; unimportant parts have been cut so that only the pieces we currently care about remain.
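Before diving into the stages, it helps to see what the Python bindings boil down to. Below is a minimal C++ sketch using only standard MLIR APIs (this is not Triton's actual code; the Python-side ir.pass_manager and passes.common.add_* calls wrap the same mechanism): each add_* corresponds to a create*Pass() constructor declared in one of the Passes.td files above, and the passes are queued on a mlir::PassManager and run over the module.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Sketch: what a pipeline like make_ttir reduces to on the C++ side.
mlir::LogicalResult runCommonPasses(mlir::ModuleOp mod) {
  mlir::PassManager pm(mod.getContext());
  pm.addPass(mlir::createInlinerPass());       // passes.common.add_inliner
  pm.addPass(mlir::createCanonicalizerPass()); // passes.common.add_canonicalizer
  pm.addPass(mlir::createCSEPass());           // passes.common.add_cse
  pm.addPass(mlir::createSymbolDCEPass());     // passes.common.add_symbol_dce
  return pm.run(mod);                          // corresponds to pm.run(mod) in Python
}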
- make_ttir
The code for this stage is shown below; the purpose of each pass is given in the comments (triton/python/triton/backends/nvidia/compiler.py):
@staticmethod
def make_ttir(mod, metadata, opt):
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.common.add_inliner(pm)  # try to inline function calls
    passes.ttir.add_rewrite_tensor_pointer(pm)  # rewrite load/stores with tensor pointers into legacy load/stores
    passes.ttir.add_combine(pm)  # combine ops
    passes.common.add_canonicalizer(pm)  # convert operations into their canonical forms by folding constants, identity transformations, etc.
    passes.ttir.add_reorder_broadcast(pm)  # move broadcast and splat after elementwise operations
    passes.common.add_cse(pm)  # eliminate common sub-expressions
    passes.common.add_licm(pm)  # hoist loop-invariant instructions outside of the loop
    passes.common.add_symbol_dce(pm)  # eliminate dead symbols
    pm.run(mod)
    return mod
These are basically all optimization passes. Take add_inliner as an example: calling the add_inliner pass dispatches through pass.cc to the constructor defined in Passes.td, which leads into the pass's implementation file.
def Inliner : Pass<"inline"> {
  let summary = "Inline function calls";
  let constructor = "mlir::createInlinerPass()";
  let options = [
    ......
  ];
}
The functions of the remaining passes are given in the comments above. This stage has almost no effect on our case: the example is simply too small for these optimizations to find anything to do, so we move straight on to the next stage.
- make_ttgir
This stage is much richer. Again, we first note the function of every pass in the comments:
@staticmethod
def make_ttgir(mod, metadata, opt, capability):
    cluster_info = nvidia.ClusterInfo()
    if opt.cluster_dims is not None:
        cluster_info.clusterDimX = opt.cluster_dims[0]
        cluster_info.clusterDimY = opt.cluster_dims[1]
        cluster_info.clusterDimZ = opt.cluster_dims[2]
    # TTIR -> TTGIR
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.ttir.add_convert_to_ttgpuir(pm, opt.num_warps, 32, opt.num_ctas, capability)
    # optimize TTGIR
    passes.ttgpuir.add_coalesce(pm)  # coalesce some memory ops for cache optimization
    nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)  # decide the CTA (thread block) tiling
    passes.ttgpuir.add_remove_layout_conversions(pm)  # remove superfluous layout conversions
    passes.ttgpuir.add_optimize_thread_locality(pm)  # reduce the cost of synchronization between threads in an SM
    passes.ttgpuir.add_accelerate_matmul(pm, capability)  # optimize the input/output layouts of `dot` instructions to make them compatible with hardware accelerators (e.g., NVIDIA tensor cores)
    passes.ttgpuir.add_remove_layout_conversions(pm)  # remove superfluous layout conversions
    passes.ttgpuir.add_optimize_dot_operands(pm)  # re-arrange layouts of tensors used as matrix-multiplication operands
    passes.common.add_cse(pm)  # eliminate common sub-expressions
    if capability // 10 >= 8:
        passes.ttgpuir.add_pipeline(pm, opt.num_stages, opt.num_warps, opt.num_ctas, capability)  # apply software pipelining to loops in the module based on the number of stages
    if capability // 10 <= 8:
        passes.ttgpuir.add_prefetch(pm)  # decompose `DotOp` instructions in loops into several finer-grained `DotOp`s
    passes.ttgpuir.add_optimize_dot_operands(pm)  # re-arrange layouts of tensors used as matrix-multiplication operands
    passes.ttgpuir.add_remove_layout_conversions(pm)  # remove superfluous layout conversions
    passes.ttgpuir.add_reduce_data_duplication(pm)  # reduce data duplication in registers by decomposing
    passes.ttgpuir.add_reorder_instructions(pm)  # reorder instructions
    passes.common.add_cse(pm)  # eliminate common sub-expressions
    passes.common.add_symbol_dce(pm)  # eliminate dead symbols
    if capability // 10 >= 9:
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm)  # insert fences across generic and async proxies
    passes.common.add_canonicalizer(pm)  # convert operations into their canonical forms by folding constants, identity transformations, etc.
    pm.run(mod)
    metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
    return mod
The key thing to look at is the TTIR -> TTGIR conversion itself (essentially a conversion of ops between dialects), i.e. the add_convert_to_ttgpuir pass. Readers familiar with LLVM can think of it as the legalization step of the DAG phase. Jumping from the definition file to its implementation lands us in triton/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp:
void runOnOperation() override {
  MLIRContext *context = &getContext();
  ModuleOp mod = getOperation();
  // type converter
  TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp,
                                       numCTAs);
  TritonGPUConversionTarget target(*context, typeConverter);
  // rewrite patterns
  RewritePatternSet patterns(context); // a fresh pattern set collecting the RewritePatterns for op conversion
  populateArithPatternsAndLegality(typeConverter, patterns, target);
  // adds the arith dialect ops' RewritePatterns to the pattern set;
  // calls patterns.add<GenericOpPattern<arith::AddIOp>>, i.e. addi is handled by GenericOpPattern
  populateMathPatternsAndLegality(typeConverter, patterns, target);
  // adds the math dialect ops' RewritePatterns to the pattern set
  populateTritonPatterns(typeConverter, patterns, numCTAs);
  // adds the triton dialect ops' RewritePatterns to the pattern set;
  // calls patterns.insert<GenericOpPattern<triton::LoadOp>, GenericOpPattern<triton::StoreOp>>, i.e. ld/st are handled by GenericOpPattern
  populateSCFPatterns(typeConverter, patterns);
  // adds the SCF dialect ops' RewritePatterns to the pattern set
  populateCFPatterns(typeConverter, patterns);
  // adds the CF dialect ops' RewritePatterns to the pattern set
  ......
  mod->setAttr( // set some attributes on the module
      ......
  );
  if (failed(applyPartialConversion(mod, target, std::move(patterns))))
    return signalPassFailure();
}
This conversion can also be called the rewrite process. First note the type converter and target used here: TritonGPUTypeConverter and TritonGPUConversionTarget, respectively. These two objects matter a lot: the type converter specifies how certain data types are converted, and the target specifies which ops are legal (analogous to type legalization and op legalization in LLVM); we'll look at their code shortly. The populate*Patterns functions then record the conversion patterns for various dialects' ops in the pattern set, providing each op's conversion recipe; the arith.addi, tt.load, and tt.store in the IR from the end of the previous chapter are all handled by GenericOpPattern (at this point the patterns are only declared, not yet executed; we'll see the implementation below). The pass then calls the applyPartialConversion function in mlir/lib/Transforms/Utils/DialectConversion.cpp to actually convert the ops:
LogicalResult
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
                             const ConversionTarget &target,
                             const FrozenRewritePatternSet &patterns,
                             DenseSet<Operation *> *unconvertedOps) {
  OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
                                 unconvertedOps);
  return opConverter.convertOperations(ops);
}
The conversion function convertOperations is in the same file:
LogicalResult OperationConverter::convertOperations(
    ArrayRef<Operation *> ops,
    function_ref<void(Diagnostic &)> notifyCallback) {
  if (ops.empty())
    return success();
  const ConversionTarget &target = opLegalizer.getTarget();

  // Compute the set of operations and blocks to convert.
  SmallVector<Operation *> toConvert;
  for (auto *op : ops) {
    op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
        [&](Operation *op) {
          toConvert.push_back(op);
          auto legalityInfo = target.isLegal(op); // check whether this op is legal
          if (legalityInfo && legalityInfo->isRecursivelyLegal)
            return WalkResult::skip();
          return WalkResult::advance();
        });
  }

  // Convert each operation and discard rewrites on failure.
  ConversionPatternRewriter rewriter(ops.front()->getContext());
  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
  for (auto *op : toConvert)
    if (failed(convert(rewriter, op))) // convert the op
      return rewriterImpl.discardRewrites(), failure();
  ......
  return success();
}
We only need to care about two functions here, target.isLegal(op) and convert(rewriter, op): that is, which ops are legal at the ttgir level, and what is done with the illegal ones.
For the first question: when TritonGPUConversionTarget (which inherits from ConversionTarget) is constructed, it marks certain dialects and ops as legal or illegal (triton/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp):
TritonGPUConversionTarget::TritonGPUConversionTarget(
    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
    : ConversionTarget(context) {
  // every op of TritonGPUDialect is legal
  addLegalDialect<triton::gpu::TritonGPUDialect>();
  // these ops are illegal
  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
               scf::ReduceReturnOp>();
  // ops of these dialects are legal only if they satisfy the condition below
  addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
                             triton::TritonDialect, cf::ControlFlowDialect,
                             scf::SCFDialect>([&](Operation *op) {
    bool hasLegalRegions = true;
    for (auto &region : op->getRegions()) {
      hasLegalRegions = hasLegalRegions && typeConverter.isLegal(&region);
    }
    // the type converter decides whether the types are legal
    if (hasLegalRegions && typeConverter.isLegal(op)) {
      return true;
    }
    return false;
  });
  ......
}
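As an aside, the same dynamic-legality mechanism also exists at single-op granularity. A minimal sketch (illustrative, not from Triton's code):

// Mark one op as dynamically legal; as with the dialect-level variant
// above, the callback receives a generic Operation*.
target.addDynamicallyLegalOp<arith::AddIOp>([&](Operation *op) {
  return typeConverter.isLegal(op); // legal iff its operand/result types need no conversion
});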
The type converter used above is TritonGPUTypeConverter:
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                               int numWarps, int threadsPerWarp,
                                               int numCTAs)
    : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp),
      numCTAs(numCTAs) {
  // returning the type unchanged means the type is legal
  addConversion([](Type type) { return type; });

  // Add encoding for tensor
  addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
    // types with encoding are already in the right format
    // TODO: check for layout encodings more specifically
    if (tensorType.getEncoding())
      return tensorType;
    ArrayRef<int64_t> shape = tensorType.getShape();
    triton::gpu::BlockedEncodingAttr encoding =
        getDefaultBlockedEncoding(this->context, shape, this->numWarps,
                                  this->threadsPerWarp, this->numCTAs);
    return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
  });

  // Add encoding for tensor pointer
  addConversion([this](triton::PointerType ptrType) -> triton::PointerType {
    // Check whether this is a tensor pointer `tt.ptr<tensor<>>`
    auto pointeeTensorType =
        ptrType.getPointeeType().dyn_cast<RankedTensorType>();
    if (pointeeTensorType == nullptr)
      return ptrType;
    // Add the layout to the tensor
    auto convertedTensorType = convertType(pointeeTensorType);
    return triton::PointerType::get(convertedTensorType,
                                    ptrType.getAddressSpace());
  });

  // Materializations
  ......
}
As you can see, the type converter declares that RankedTensorType and PointerType need conversion. So when target.isLegal(op) is evaluated, the target and the type converter together decide whether the current op is legal. (The type converter has one more job, materialization, which describes how to convert a group of values into a single value of the desired type.)
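Triton's materialization hooks are elided in the snippet above, so here is a hedged sketch of what one looks like (illustrative only; the exact callback signature varies a bit across MLIR versions): it tells the framework how to produce a single value of the desired type from already-converted values, typically by inserting a cast.

// Illustrative, not Triton's actual hook: reconcile a value whose type
// the converter changed with a consumer that still expects the old type,
// via an unrealized_conversion_cast.
typeConverter.addSourceMaterialization(
    [](OpBuilder &builder, Type resultType, ValueRange inputs,
       Location loc) -> std::optional<Value> {
      return builder
          .create<UnrealizedConversionCastOp>(loc, resultType, inputs)
          .getResult(0);
    });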
With legality settled, the second function, convert(rewriter, op), behaves differently depending on the rewrite mode. Our mode here is "Partial", which ignores failed conversions so that illegal operations may remain in the IR:
LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
                                          Operation *op) {
  // Legalize the given operation.
  if (failed(opLegalizer.legalize(op, rewriter))) {
    if (mode == OpConversionMode::Full)
      return op->emitError()
             << "failed to legalize operation '" << op->getName() << "'";
    if (mode == OpConversionMode::Partial) {
      if (opLegalizer.isIllegal(op))
        return op->emitError()
               << "failed to legalize operation '" << op->getName()
               << "' that was explicitly marked illegal";
      if (trackedOps)
        trackedOps->insert(op);
    }
  } else if (mode == OpConversionMode::Analysis) {
    trackedOps->insert(op);
  }
  return success();
}
As the code shows, the conversion calls opLegalizer.legalize(); let's step into legalize:
LogicalResult
OperationLegalizer::legalize(Operation *op,
                             ConversionPatternRewriter &rewriter) {
  ......
  // If the operation isn't legal, try to fold it in-place.
  // Folding computes the result directly from the operands, e.g. for constants.
  if (succeeded(legalizeWithFold(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(logger, "operation was folded");
      logger.startLine() << logLineComment;
    });
    return success();
  }

  // Otherwise, we need to apply a legalization pattern to this operation.
  // Try to replace the current op via one of the registered patterns.
  if (succeeded(legalizeWithPattern(op, rewriter))) {
    LLVM_DEBUG({
      logSuccess(logger, "");
      logger.startLine() << logLineComment;
    });
    return success();
  }
  ......
  return failure();
}
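To make the fold path concrete: folding means the op simplifies itself through its own fold hook, with no pattern involved. A sketch with a hypothetical op name (the addi(x, 0) -> x fold below does exist in upstream arith, but MyAddOp itself is made up for illustration):

// Hypothetical op: legalizeWithFold ultimately invokes the op's fold
// hook. A non-null OpFoldResult means the op folded in place.
OpFoldResult MyAddOp::fold(FoldAdaptor adaptor) {
  // add(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();
  return {}; // empty result: this op could not be folded
}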
So legalization comes in two flavors, folding and replacement. We focus on the second, which ultimately calls applicator.matchAndRewrite() in mlir/lib/Rewrite/PatternApplicator.cpp. The code below elides most of the logic and keeps only the essentials:
LogicalResult PatternApplicator::matchAndRewrite(Operation *op,
                                                 PatternRewriter &rewriter,
                                                 ......) {
  const Pattern *bestPattern = nullptr;
  // Find the next pattern with the highest benefit.
  ......
  const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
  result = pattern->matchAndRewrite(op, rewriter);
  ......
}
A RewritePattern object encodes how a particular kind of op is rewritten, i.e. the conversion patterns added to the pattern set above. Throughout mlir and triton, files under each dialect's or target's directory may define objects inheriting from RewritePattern that handle specific ops with their own rewrite logic. The flow here is: gather all candidate patterns for the current op's name, compute each pattern's benefit as a cost-model-style metric, pick the pattern with the highest benefit, and call its matchAndRewrite method. Take an arith::CeilDivUIOp as an example (this op is not necessarily converted at this stage; it is just for illustration):
/// Expands CeilDivUIOp (n, m) into
///   n == 0 ? 0 : ((n-1) / m) + 1
/// OpRewritePattern inherits from RewritePattern.
struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value a = op.getLhs();
    Value b = op.getRhs();
    Value zero = createConst(loc, a.getType(), 0, rewriter);
    Value compare =
        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
    Value one = createConst(loc, a.getType(), 1, rewriter);
    Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
    Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
    Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
    return success();
  }
};
This rewrites the ceildiv operation into a form built around a select op; its matchAndRewrite would be invoked when an arith::CeilDivUIOp is legalized.
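About the benefit metric mentioned above: every pattern declares a PatternBenefit when it is constructed (the default is 1), and PatternApplicator prefers higher values. A minimal sketch with a hypothetical pattern class and an arbitrary benefit value:

// Hypothetical pattern: the second constructor argument is the
// PatternBenefit that PatternApplicator ranks candidates by.
struct CheapCeilDivLowering : public OpRewritePattern<arith::CeilDivUIOp> {
  CheapCeilDivLowering(MLIRContext *ctx)
      : OpRewritePattern<arith::CeilDivUIOp>(ctx, /*benefit=*/2) {}
  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
                                PatternRewriter &rewriter) const final {
    // Rewrite logic would go here; with benefit 2 this pattern is tried
    // before an otherwise-equal pattern left at the default benefit of 1.
    return failure();
  }
};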
Now back to the arith.addi, tt.load, and tt.store ops in our case. During legalization, per the code above, their dialects arith::ArithDialect and triton::TritonDialect are dynamically legal: an op is legal only if its types are. And as mentioned, all three are rewritten with GenericOpPattern, shown below. It can be seen as a generic OpRewritePattern that only fixes up an op's illegal data types and leaves the op itself untouched, so our IR does not change at this step. (Targets can also define their own RewritePatterns for particular ops, as we'll see in the next chapter.)
template <class Op> struct GenericOpPattern : public OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> retTypes;
if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
retTypes)))
return failure();
rewriter.replaceOpWithNewOp<Op>(op, retTypes, adaptor.getOperands(),
op->getAttrs());
return success();
}
};
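For completeness, a condensed, illustrative sketch of how such a conversion pattern gets into the pattern set (mirroring the populate* helpers quoted earlier; passing the type converter to the constructor is what makes getTypeConverter() work inside matchAndRewrite):

// Condensed registration, mirroring populateTritonPatterns:
RewritePatternSet patterns(context);
patterns.add<GenericOpPattern<triton::LoadOp>,
             GenericOpPattern<triton::StoreOp>>(typeConverter, context);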
Moreover, we don't have to provide patterns that produce target-legal ops directly: the mlir framework automatically builds a conversion graph that turns illegal operations into a set of legal ones. The mlir documentation gives an example: suppose an operation foo.add is defined, and the patterns [bar.add -> baz.add, baz.add -> foo.add] are provided; the framework automatically detects that bar.add can be legalized to foo.add even though no direct conversion exists. This means no direct bar.add -> foo.add legalization pattern has to be defined.
At this point we understand the add_convert_to_ttgpuir process, which is essentially legalization. To summarize the flow of this stage: ttir -> ttgir is a conversion of dialect ops; first decide whether an op is legal (via the interfaces provided by the target and the type converter); if it is legal, leave it alone; if not, convert it (via the RewritePatterns registered in the RewritePatternSet).
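The same flow, restated as pseudo-C++ (the helper names here are made up; the real logic lives in the convertOperations/convert/legalize functions quoted above):

// Pseudo-code summary of the ttir -> ttgir conversion driver.
for (Operation *op : collectOpsPreOrder(mod)) {  // hypothetical helper
  if (target.isLegal(op))                   // target + type converter agree
    continue;                               // legal: leave untouched
  if (failed(foldOrApplyBestPattern(op)))   // hypothetical: fold, else best-benefit pattern
    tolerateOrFail(op);                     // hypothetical: Partial tolerates, Full errors
}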
We'll skip over the optimization passes that run after add_convert_to_ttgpuir; their implementations are comparatively easy to follow. Note that from this stage on there are already hardware-specific passes, such as NVIDIA's add_fence_insertion. For our case the IR barely changes in this stage. It now looks like this:
#loc = loc("toy.py":28:0)
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @addi_kernel_01(%arg0: !tt.ptr<i32, 1> loc("toy.py":28:0), %arg1: !tt.ptr<i32, 1> loc("toy.py":28:0)) attributes {noinline = false} {
    %c1_i32 = arith.constant 1 : i32 loc(#loc1)
    %0 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i32 loc(#loc2)
    %1 = arith.addi %0, %c1_i32 : i32 loc(#loc3)
    tt.store %arg1, %1 {cache = 1 : i32, evict = 1 : i32} : i32 loc(#loc4)
    tt.return loc(#loc5)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("toy.py":39:16)
#loc3 = loc("toy.py":40:17)
#loc4 = loc("toy.py":41:25)
#loc5 = loc("toy.py":41:4)
As you can see, compared with before, the only difference is a handful of added hardware-related module attributes.
I had planned to cover make_llir in this article as well, but it has already grown quite long, so that goes into the next one. The make_llir process and the interfaces it uses closely resemble make_ttgir's; we'll meet many familiar functions in the upcoming analysis, which should make it easier to follow. After make_llir, the IR for our case will change substantially.
The End
Author: 液态黑洞
Source: GiantPandaCV