AI学习者 · 7月1日

窥探Triton的lower(三)

在上一章,我们完成了ttir->ttgir的过程分析,重点在于理解其中用到的数据结构和流程。有了上面的基础,我们理解接下来的内容会非常轻松。在这一阶段结束时我们的case还是包含arith::addi、tt.load、tt.store等节点,在这一阶段我们会看到它们的变化。所以让我们直接进入最后的make_llir阶段。

  • make_llir

根据注释,这一步其实又可以分为两小步,TritonGPU -> LLVM-IR (MLIR) 和 LLVM-IR (MLIR) -> LLVM-IR (LLVM)。这两步的区别在于,第一步是还是MLIR级别的,也就是在dialect空间的转换,转换的结果就是LLVMDialect,而第二步是将LLVMDialect转换为真正的LLVM IR。

    @staticmethod
    def make_llir(src, metadata, options, capability):
        # warp-specialization mutates num_warps
        num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
        if num_warp_groups is not None:
            metadata["num_warps"] *= num_warp_groups
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm) # Decompose conversions that are not supported by TritonGPU -> LLVM
        passes.convert.add_scf_to_cf(pm) # Convert SCF dialect to ControlFlow dialect
        passes.convert.add_index_to_llvmir(pm) # Lower the `index` dialect to the `llvm` dialect
        passes.ttgpuir.add_allocate_shared_memory(pm) # Add metadata for shared memory allocation
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) # 用来处理NVGPUDialect的节点,大部分替换为内嵌汇编
        passes.convert.add_arith_to_llvmir(pm) # Convert Arith dialect to LLVM dialect
        passes.common.add_canonicalizer(pm) # converts operations into their canonical forms by folding constants, identity transformations etc.
        passes.common.add_cse(pm) # Eliminate common sub-expressions
        passes.common.add_symbol_dce(pm) # Eliminate dead symbols
        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
            passes.llvmir.add_di_scope(pm) # Materialize LLVM line info
        pm.run(mod)
        # LLVM-IR (MLIR) -> LLVM-IR (LLVM) 
        llvm.init_targets()
        context = llvm.context()
        llvm_mod = llvm.to_module(mod, context) # 将LLVM dialect转换为LLVMIR
        nvidia.set_nvvm_reflect_ftz(llvm_mod) # enable fast math path in libdevice
        if options.extern_libs:
            for name, path in options.extern_libs:
                llvm.link_extern_lib(llvm_mod, path) # link libdevice,一些函数库会用到
        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) # O3优化
        metadata["shared"] = src.get_int_attr("triton_gpu.shared")
        ret = str(llvm_mod)
        del llvm_mod
        del context
        return ret

由于第二步的转换比较固定,我们重点关注第一步,将各种dialect都转成LLVMDialect。其中主要关注add_to_llvmir这个pass,因为我们case中的arith.addi、tt.load和tt.store都会在这个pass中被重写。跳转到它的实现会发现是这样的(triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp)

void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();
    mlir::LowerToLLVMOptions option(context);
    option.overrideIndexBitwidth(32);
    TritonGPUToLLVMTypeConverter typeConverter(context, option);
    TritonLLVMConversionTarget convTarget(*context);
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
    int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

    // Allocate shared memory and set barrier
    ModuleAllocation allocation(mod);
    ModuleMembarAnalysis membarPass(&allocation);
    membarPass.run();

    // Lower functions
    {
      mlir::LowerToLLVMOptions option(context);
      TritonGPUToLLVMTypeConverter typeConverter(context, option);
      TritonLLVMFunctionConversionTarget funcTarget(*context);
      RewritePatternSet funcPatterns(context);
      funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
                                         patternBenefitDefault);
      mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                            funcPatterns);
      if (failed(
              applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
        return signalPassFailure();
    }

    // initSharedMemory is run before the conversion of call and ret ops,
    // because the call op has to know the shared memory base address of each
    // function
    initSharedMemory(typeConverter);
    ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
    OpBuilder::InsertPoint indexInsertPoint;

    RewritePatternSet patterns(context);
    TargetInfo targetInfo(computeCapability);
    int benefit = patternBenefitPrioritizeOverLLVMConversions;
    ......
    populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
                                      benefit);
    // 会调用下面两条
    // patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
    // patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
    mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
    // 会调用patterns.add<AddIOpLowering...>
    ......
    if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
      return signalPassFailure();
    ......
  }

我们发现,这里的代码结构也是populate#Opname#Pattern,然后再执行applyPartialConversion,好像和上一章中add_convert_to_ttgpuir的转换过程差不多,细心的小伙伴可以观察到区别在于这里我们的target是TritonLLVMFunctionConversionTarget,而前面是TritonGPUConversionTarget,前者又增加了IndexDialect、LLVMDialect、NVMDialect等为合法,其他非法dialect在这一阶段会被lower。

(triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp)

class TritonLLVMConversionTarget : public ConversionTarget {
public:
  explicit TritonLLVMConversionTarget(MLIRContext &ctx)
      : ConversionTarget(ctx) {
    addLegalDialect<LLVM::LLVMDialect>();
    addLegalDialect<NVVM::NVVMDialect>();
    addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
    addIllegalDialect<triton::TritonDialect>();
    addIllegalDialect<triton::gpu::TritonGPUDialect>();
    addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
    addIllegalDialect<mlir::gpu::GPUDialect>();
    addLegalOp<mlir::UnrealizedConversionCastOp>();
  }
};

typeConverter也由TritonGPUTypeConverter变成了TritonGPUToLLVMTypeConverter,增加了更多类型的转换方式,比如nv新支持的fp8数据类型

(triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp)

TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
    MLIRContext *ctx, LowerToLLVMOptions &option,
    const DataLayoutAnalysis *analysis)
    : LLVMTypeConverter(ctx, option, analysis) {
  addConversion([&](triton::PointerType type) -> std::optional<Type> {
    return convertTritonPointerType(type);
  });
  addConversion([&](RankedTensorType type) -> std::optional<Type> {
    return convertTritonTensorType(type);
  });
  addConversion([&](MemDescType type) -> std::optional<Type> {
    return convertMemDescType(type);
  });
  addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional<Type> {
    return convertAsyncToken(type);
  });
  // Internally store float8 as int8
  addConversion([&](mlir::Float8E4M3B11FNUZType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  // Internally store bfloat16 as int16
  addConversion([&](BFloat16Type type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 16);
  });
}

此外,在上一章populate#Opname#Pattern时,我们的arith.addi、tt.load和tt.store都是采用的通用转换模式GenericOpPattern来转换的,而这里会有三种处理方式。对于load,我们看到它的RewritePattern是LoadOpConversion,直接看它的转换函数matchAndRewrite()

matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ......
    // Define the instruction opcode
    auto &ld = ptxBuilder.create<>("ld")
                   ->o("volatile", op.getIsVolatile())
                   .global()
                   .o("ca", op.getCache() == triton::CacheModifier::CA)
                   .o("cg", op.getCache() == triton::CacheModifier::CG)
                   .o("L1::evict_first",
                      op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
                   .o("L1::evict_last",
                      op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
                   .o("L1::cache_hint", hasL2EvictPolicy)
                   .v(nWords)
                   .b(width);
    ......
  }

可以看到nv这里将tt.load简单粗暴地处理成了内嵌汇编(可能是为了方便cache控制,amd是处理成了LLVM::LoadOp)。此外,load和store的处理还包括很多阶段,比如向量化、计算线程访问mask等,还会建立triton::nvgpu::ClusterCTAIdOp来和索引绑定一起,实现SIMT编程(这里省去了很多代码,还没有看完,之后尽量补上)。store同样通过StoreOpConversion处理成了内嵌汇编的形式。

对于arith.addi,在populateArithToLLVMConversionPatterns的时候调用了mlir中的方法AddIOpLowering,它也是继承自RewritePattern用来改写addi op

using AddIOpLowering =
    VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                               arith::AttrConvertOverflowToLLVM>;

在它的实现中,会直接调用

rewriter.replaceOp(op, newOp->getResult(0)), success())

将arith::AddIOp替换为LLVM::AddOp。至此,我们已经有了target、type converter和各个op的RewritePattern,接下来就是重复上一章的运行过程,先判断合法化,再去做转换,最终完成它们到LLVMDialect的转换。此时的IR长这个样(为了简洁,打印了优化后的ir)

#loc = loc("toy.py":28:0)
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> loc(#loc)
  llvm.func @addi_kernel_01(%arg0: !llvm.ptr<1> loc("toy.py":28:0), %arg1: !llvm.ptr<1> loc("toy.py":28:0)) attributes {noinline = false, nvvm.kernel = 1 : ui1, nvvm.maxntid = array<i32: 128>} {
    %0 = llvm.mlir.constant(0 : i32) : i32 loc(#loc1)
    %1 = llvm.mlir.constant(0 : index) : i32 loc(#loc1)
    %2 = llvm.mlir.constant(true) : i1 loc(#loc1)
    %3 = llvm.mlir.constant(1 : i32) : i32 loc(#loc1)
    %4 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %arg0, %2 : (!llvm.ptr<1>, i1) -> i32 loc(#loc2)
    %5 = llvm.bitcast %4 : i32 to vector<1xi32> loc(#loc2)
    %6 = llvm.extractelement %5[%1 : i32] : vector<1xi32> loc(#loc2)
    %7 = llvm.add %6, %3  : i32 loc(#loc3)
    %8 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc4)
    %9 = llvm.and %2, %2  : i1 loc(#loc4)
    %10 = llvm.icmp "eq" %8, %0 : i32 loc(#loc4)
    %11 = llvm.and %9, %10  : i1 loc(#loc4)
    %12 = llvm.mlir.undef : vector<1xi32> loc(#loc4)
    %13 = llvm.insertelement %7, %12[%0 : i32] : vector<1xi32> loc(#loc4)
    %14 = llvm.bitcast %13 : vector<1xi32> to i32 loc(#loc4)
    %15 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %14, %arg1, %11 : (i32, !llvm.ptr<1>, i1) -> !llvm.void loc(#loc4)
    llvm.return loc(#loc5)
  } loc(#loc6)
} loc(#loc)
#di_file = #llvm.di_file<"toy.py" in "">
#di_subroutine_type = #llvm.di_subroutine_type<callingConvention = DW_CC_normal>
#loc1 = loc(unknown)
#loc2 = loc("toy.py":38:16)
#loc3 = loc("toy.py":39:17)
#loc4 = loc("toy.py":40:25)
#loc5 = loc("toy.py":40:4)
#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, sourceLanguage = DW_LANG_C, file = #di_file, producer = "triton", isOptimized = true, emissionKind = LineTablesOnly>
#di_subprogram = #llvm.di_subprogram<id = distinct[0]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "addi_kernel_01", linkageName = "addi_kernel_01", file = #di_file, line = 28, scopeLine = 28, subprogramFlags = "Definition|Optimized", type = #di_subroutine_type>
#loc6 = loc(fused<#di_subprogram>[#loc])

其中arith.addi变成了llvm.add,tt.load/store变成了内嵌汇编,以及用来计算索引地址的若干条指令。load指令由于nv使用的内嵌汇编看起来不是很好理解,这里的%4=...这一条我们还是解释一下。首它是两条汇编组成:mov.u32 $0, 0x0 和 @$2 ld.global.b32 { $0 }, [ $1 + 0 ],这里的$0/$1/$2分别是%4/%arg0/%2。第一条汇编的含义是给%4一个初始值0,第二条load指令是从地址%arg0读取32位的数据到4%,并且第二条指令会受$2("@"后面是谓词寄存器)控制,当$2为true执行,反之不执行(这也是为什么需要第一条mov,当load不执行的时候需要给%4一个默认值)。

除了llvm和nvvm,已经没有了其他dialect,这里的nvvm.read.ptx.sreg.tid.x(对应cuda的threadIdx.x)是在add_nvgpu_to_llvm中lower triton::nvgpu::ClusterCTAIdOp得到,它可以被LLVM中的nvptx后端当作intrinsic函数来处理。

最后的LLVM-IR (MLIR) -> LLVM-IR (LLVM) 过程是将LLVM dialect转换为真正的LLVMIR,实际上是利用llvm::IRBuilder依次对每个op、符号等进行翻译,并创建metadata。它的实现在mlir/lib/Target/LLVMIR/ModuleTranslation.cpp的translateModuleToLLVMIR中,这里不展开介绍了。

到此为止,我们得到最终的LLVM IR,在这之前还通过llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)应用了O3优化,删除了很多冗余指令。由此可以经过后端编译器翻译成硬件指令,愉快地运行了。

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"

define void @addi_kernel_01(ptr addrspace(1) %0, ptr addrspace(1) %1) local_unnamed_addr !dbg !7 {
  %3 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %0, i1 true) #1, !dbg !10
  %4 = add i32 %3, 1, !dbg !11
  %5 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12
  %6 = icmp eq i32 %5, 0, !dbg !12
  tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %4, ptr addrspace(1) %1, i1 %6) #1, !dbg !12
  ret void, !dbg !13
}

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0

attributes #0 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }
attributes #1 = { nounwind }

!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2}
!nvvm.annotations = !{!4, !5}
!llvm.ident = !{!6}

!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{i32 4, !"nvvm-reflect-ftz", i32 1}
!2 = distinct !DICompileUnit(language: DW_LANG_C, file: !3, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!3 = !DIFile(filename: "toy.py", directory: "")
!4 = !{ptr @addi_kernel_01, !"kernel", i32 1}
!5 = !{ptr @addi_kernel_01, !"maxntidx", i32 128}
!6 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"}
!7 = distinct !DISubprogram(name: "addi_kernel_01", linkageName: "addi_kernel_01", scope: !3, file: !3, line: 28, type: !8, scopeLine: 28, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2)
!8 = !DISubroutineType(cc: DW_CC_normal, types: !9)
!9 = !{}
!10 = !DILocation(line: 38, column: 16, scope: !7)
!11 = !DILocation(line: 39, column: 17, scope: !7)
!12 = !DILocation(line: 40, column: 25, scope: !7)
!13 = !DILocation(line: 40, column: 4, scope: !7)

到这里,一共三章的内容就结束了,我们还是回到最初的case,以加法为例,它从源代码中的一个加法逐步lower到了以下形式ast.BinOp(op.name="__add__")->arith::addi->llvm.add->add,其实类比来看,clang中对加法是直接经过了ast->ir的阶段,因此可以看出triton多了dialect级别的表示,带来的好处就是编程上更大的宽容度,灵活的dialect表示支持了更高抽象层次的源代码,使得我们在编程时可以忽略更多细节,比如kernel内部的threadIdx调用、同步函数的添加等等。

最后,在学习过程中自己对整个triton的认知还非常浅显,driver、优化等很多部分都没有涉及到,在表述上也会有不少不准确的地方,希望大家指正,感谢阅读。

The End

作者:液态黑洞
来源:GiantPandaCV

推荐阅读

欢迎大家点赞留言,更多Arm技术文章动态请关注极术社区嵌入式AI专栏欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

推荐阅读
关注数
18777
内容数
1331
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息