In the previous chapter we finished analyzing the ttir -> ttgir stage, focusing on the data structures and the workflow involved. With that foundation, the content that follows should feel quite easy. At the end of that stage our case still contained nodes such as arith::addi, tt.load and tt.store; in this stage we will watch them change. So let's go straight into the final make_llir stage.
- make_llir
According to the comments, this step actually breaks down into two smaller steps: TritonGPU -> LLVM-IR (MLIR) and LLVM-IR (MLIR) -> LLVM-IR (LLVM). The difference is that the first step still operates at the MLIR level, i.e. it is a conversion within the dialect space whose result is the LLVM dialect (LLVMDialect), while the second step converts that LLVM dialect into real LLVM IR.
@staticmethod
def make_llir(src, metadata, options, capability):
    # warp-specialization mutates num_warps
    num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
    if num_warp_groups is not None:
        metadata["num_warps"] *= num_warp_groups
    mod = src
    # TritonGPU -> LLVM-IR (MLIR)
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)  # Decompose conversions that are not supported by TritonGPU -> LLVM
    passes.convert.add_scf_to_cf(pm)  # Convert SCF dialect to ControlFlow dialect
    passes.convert.add_index_to_llvmir(pm)  # Lower the `index` dialect to the `llvm` dialect
    passes.ttgpuir.add_allocate_shared_memory(pm)  # Add metadata for shared memory allocation
    nvidia.passes.ttgpuir.add_to_llvmir(pm, capability)
    nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)  # Handle NVGPUDialect ops; most are replaced with inline assembly
    passes.convert.add_arith_to_llvmir(pm)  # Convert Arith dialect to LLVM dialect
    passes.common.add_canonicalizer(pm)  # converts operations into their canonical forms by folding constants, identity transformations etc.
    passes.common.add_cse(pm)  # Eliminate common sub-expressions
    passes.common.add_symbol_dce(pm)  # Eliminate dead symbols
    if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
        passes.llvmir.add_di_scope(pm)  # Materialize LLVM line info
    pm.run(mod)
    # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
    llvm.init_targets()
    context = llvm.context()
    llvm_mod = llvm.to_module(mod, context)  # convert the LLVM dialect into LLVM IR
    nvidia.set_nvvm_reflect_ftz(llvm_mod)  # enable fast math path in libdevice
    if options.extern_libs:
        for name, path in options.extern_libs:
            llvm.link_extern_lib(llvm_mod, path)  # link libdevice, needed by some library functions
    llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)  # run O3 optimizations
    metadata["shared"] = src.get_int_attr("triton_gpu.shared")
    ret = str(llvm_mod)
    del llvm_mod
    del context
    return ret
Since the second step is fairly fixed, we focus on the first one: converting the various dialects into the LLVM dialect. The pass to watch is add_to_llvmir, because the arith.addi, tt.load and tt.store in our case are all rewritten in this pass. Jumping to its implementation, we find the following (triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp):
void runOnOperation() override {
  MLIRContext *context = &getContext();
  ModuleOp mod = getOperation();
  mlir::LowerToLLVMOptions option(context);
  option.overrideIndexBitwidth(32);
  TritonGPUToLLVMTypeConverter typeConverter(context, option);
  TritonLLVMConversionTarget convTarget(*context);
  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
  int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

  // Allocate shared memory and set barrier
  ModuleAllocation allocation(mod);
  ModuleMembarAnalysis membarPass(&allocation);
  membarPass.run();

  // Lower functions
  {
    mlir::LowerToLLVMOptions option(context);
    TritonGPUToLLVMTypeConverter typeConverter(context, option);
    TritonLLVMFunctionConversionTarget funcTarget(*context);
    RewritePatternSet funcPatterns(context);
    funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
                                       patternBenefitDefault);
    mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                          funcPatterns);
    if (failed(
            applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
      return signalPassFailure();
  }

  // initSharedMemory is run before the conversion of call and ret ops,
  // because the call op has to know the shared memory base address of each
  // function
  initSharedMemory(typeConverter);
  ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
  OpBuilder::InsertPoint indexInsertPoint;

  RewritePatternSet patterns(context);
  TargetInfo targetInfo(computeCapability);
  int benefit = patternBenefitPrioritizeOverLLVMConversions;
  ......
  populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
                                    benefit);
  // this ends up registering the following two patterns:
  // patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  // patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
  // this ends up calling patterns.add<AddIOpLowering...>
  ......
  if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
    return signalPassFailure();
  ......
}
We can see that the code structure here is again populate#OpName#Patterns followed by applyPartialConversion, much like the add_convert_to_ttgpuir conversion in the previous chapter. Careful readers will notice that the difference lies in the conversion target: here it is TritonLLVMConversionTarget (with TritonLLVMFunctionConversionTarget used for lowering the functions themselves), whereas previously it was TritonGPUConversionTarget. The new targets additionally mark dialects such as LLVMDialect, NVVMDialect and NVGPUDialect as legal, and the dialects that remain illegal are the ones lowered away at this stage.
(triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp)
class TritonLLVMConversionTarget : public ConversionTarget {
public:
  explicit TritonLLVMConversionTarget(MLIRContext &ctx)
      : ConversionTarget(ctx) {
    addLegalDialect<LLVM::LLVMDialect>();
    addLegalDialect<NVVM::NVVMDialect>();
    addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
    addIllegalDialect<triton::TritonDialect>();
    addIllegalDialect<triton::gpu::TritonGPUDialect>();
    addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
    addIllegalDialect<mlir::gpu::GPUDialect>();
    addLegalOp<mlir::UnrealizedConversionCastOp>();
  }
};
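To make the mechanics concrete, here is a minimal, hedged sketch (not Triton's code; names and header paths follow upstream MLIR and may vary across versions) of how a ConversionTarget, a RewritePatternSet and applyPartialConversion fit together:

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Lower just the arith dialect inside `mod` into the LLVM dialect.
LogicalResult lowerArithToLLVMDialect(ModuleOp mod) {
  MLIRContext *ctx = mod.getContext();

  LowerToLLVMOptions option(ctx);
  LLVMTypeConverter typeConverter(ctx, option);

  // Legality: LLVM-dialect ops are the goal, arith ops must disappear.
  ConversionTarget target(*ctx);
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addIllegalDialect<arith::ArithDialect>();

  // Patterns describe how each illegal op is rewritten into legal ones.
  RewritePatternSet patterns(ctx);
  arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);

  // Partial conversion: an op that is still illegal and matched by no pattern
  // makes the conversion fail; ops from dialects we did not mention are left alone.
  return applyPartialConversion(mod, target, std::move(patterns));
}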
The type converter also changes, from TritonGPUTypeConverter to TritonGPUToLLVMTypeConverter, which adds conversion rules for more types, for example the fp8 data types newly supported on NVIDIA hardware.
(triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp)
TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
    MLIRContext *ctx, LowerToLLVMOptions &option,
    const DataLayoutAnalysis *analysis)
    : LLVMTypeConverter(ctx, option, analysis) {
  addConversion([&](triton::PointerType type) -> std::optional<Type> {
    return convertTritonPointerType(type);
  });
  addConversion([&](RankedTensorType type) -> std::optional<Type> {
    return convertTritonTensorType(type);
  });
  addConversion([&](MemDescType type) -> std::optional<Type> {
    return convertMemDescType(type);
  });
  addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional<Type> {
    return convertAsyncToken(type);
  });
  // Internally store float8 as int8
  addConversion([&](mlir::Float8E4M3B11FNUZType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  // Internally store bfloat16 as int16
  addConversion([&](BFloat16Type type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 16);
  });
}
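As a quick illustration (a hedged sketch, not code from the repository), these addConversion callbacks are what get consulted whenever a pattern asks the converter for the lowered form of a type:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Pass in a TritonGPUToLLVMTypeConverter built as above; convertType walks the
// registered addConversion callbacks (most recently added first) until one matches.
void printLoweredTypes(TypeConverter &converter, MLIRContext *ctx) {
  // bf16 is stored internally as i16, per the callback registered above.
  if (Type loweredBf16 = converter.convertType(BFloat16Type::get(ctx)))
    loweredBf16.dump(); // expected to print: i16

  // The fp8 variants are all stored internally as i8.
  if (Type loweredFp8 = converter.convertType(Float8E5M2Type::get(ctx)))
    loweredFp8.dump(); // expected to print: i8
}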
Also note that in the previous chapter, when populate#OpName#Patterns ran, our arith.addi, tt.load and tt.store were all converted through the generic GenericOpPattern, whereas here each of the three gets its own treatment. For load, its RewritePattern is LoadOpConversion; let's look straight at its conversion function matchAndRewrite():
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  ......
  // Define the instruction opcode
  auto &ld = ptxBuilder.create<>("ld")
                 ->o("volatile", op.getIsVolatile())
                 .global()
                 .o("ca", op.getCache() == triton::CacheModifier::CA)
                 .o("cg", op.getCache() == triton::CacheModifier::CG)
                 .o("L1::evict_first",
                    op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
                 .o("L1::evict_last",
                    op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
                 .o("L1::cache_hint", hasL2EvictPolicy)
                 .v(nWords)
                 .b(width);
  ......
}
We can see that NVIDIA lowers tt.load quite bluntly into inline assembly (probably to make cache control easier; the AMD backend lowers it into LLVM::LoadOp instead). Beyond this, the handling of load and store involves many more stages, such as vectorization and computing the per-thread access masks; it also creates triton::nvgpu::ClusterCTAIdOp to bind the access together with the indexing and realize the SIMT programming model (a lot of code is omitted here; I have not finished reading it and will try to fill it in later). store is likewise turned into inline assembly via StoreOpConversion.
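To see the overall shape of such a pattern without the PTX plumbing, here is a deliberately over-simplified, hypothetical sketch: it lowers a scalar tt.load into a plain llvm.load, which is closer to what the AMD backend does; the real LoadOpConversion instead vectorizes, computes masks and builds the predicated inline assembly shown above.

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Hypothetical, simplified pattern: scalar tt.load -> llvm.load.
struct SimpleLoadLowering : public ConvertOpToLLVMPattern<triton::LoadOp> {
  using ConvertOpToLLVMPattern<triton::LoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Ask the type converter for the LLVM-level result type.
    Type resultTy = getTypeConverter()->convertType(op.getType());
    if (!resultTy)
      return failure();
    // adaptor.getPtr() is the already-converted pointer operand.
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, resultTy, adaptor.getPtr());
    return success();
  }
};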
For arith.addi, populateArithToLLVMConversionPatterns registers MLIR's own AddIOpLowering, which also derives from RewritePattern and is used to rewrite the addi op:
using AddIOpLowering =
    VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                               arith::AttrConvertOverflowToLLVM>;
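Before looking at the exact call it makes, here is a hedged, hypothetical sketch of what such a one-to-one lowering boils down to (the real VectorConvertToLLVMPattern additionally handles vector operands and converts the overflow-flag attribute):

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace mlir;

// Hypothetical, simplified pattern: scalar arith.addi -> llvm.add.
struct SimpleAddILowering : public ConvertOpToLLVMPattern<arith::AddIOp> {
  using ConvertOpToLLVMPattern<arith::AddIOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Build one new LLVM-dialect op from the converted operands, then swap
    // every use of the old op's result over to the new one.
    auto newOp = rewriter.create<LLVM::AddOp>(
        op.getLoc(), adaptor.getLhs().getType(), adaptor.getLhs(),
        adaptor.getRhs());
    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }
};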
In its actual implementation, the rewrite ultimately comes down to a call of the form

rewriter.replaceOp(op, newOp->getResult(0));

which replaces arith::AddIOp with LLVM::AddOp. At this point we have the conversion target, the type converter and a RewritePattern for each op, so the rest repeats the process from the previous chapter: check legality first, then perform the conversions, and finally everything is lowered into the LLVM dialect. The IR now looks like this (for brevity, the optimized IR is printed):
#loc = loc("toy.py":28:0)
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> loc(#loc)
  llvm.func @addi_kernel_01(%arg0: !llvm.ptr<1> loc("toy.py":28:0), %arg1: !llvm.ptr<1> loc("toy.py":28:0)) attributes {noinline = false, nvvm.kernel = 1 : ui1, nvvm.maxntid = array<i32: 128>} {
    %0 = llvm.mlir.constant(0 : i32) : i32 loc(#loc1)
    %1 = llvm.mlir.constant(0 : index) : i32 loc(#loc1)
    %2 = llvm.mlir.constant(true) : i1 loc(#loc1)
    %3 = llvm.mlir.constant(1 : i32) : i32 loc(#loc1)
    %4 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %arg0, %2 : (!llvm.ptr<1>, i1) -> i32 loc(#loc2)
    %5 = llvm.bitcast %4 : i32 to vector<1xi32> loc(#loc2)
    %6 = llvm.extractelement %5[%1 : i32] : vector<1xi32> loc(#loc2)
    %7 = llvm.add %6, %3 : i32 loc(#loc3)
    %8 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc4)
    %9 = llvm.and %2, %2 : i1 loc(#loc4)
    %10 = llvm.icmp "eq" %8, %0 : i32 loc(#loc4)
    %11 = llvm.and %9, %10 : i1 loc(#loc4)
    %12 = llvm.mlir.undef : vector<1xi32> loc(#loc4)
    %13 = llvm.insertelement %7, %12[%0 : i32] : vector<1xi32> loc(#loc4)
    %14 = llvm.bitcast %13 : vector<1xi32> to i32 loc(#loc4)
    %15 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %14, %arg1, %11 : (i32, !llvm.ptr<1>, i1) -> !llvm.void loc(#loc4)
    llvm.return loc(#loc5)
  } loc(#loc6)
} loc(#loc)
#di_file = #llvm.di_file<"toy.py" in "">
#di_subroutine_type = #llvm.di_subroutine_type<callingConvention = DW_CC_normal>
#loc1 = loc(unknown)
#loc2 = loc("toy.py":38:16)
#loc3 = loc("toy.py":39:17)
#loc4 = loc("toy.py":40:25)
#loc5 = loc("toy.py":40:4)
#di_compile_unit = #llvm.di_compile_unit<id = distinct[0]<>, sourceLanguage = DW_LANG_C, file = #di_file, producer = "triton", isOptimized = true, emissionKind = LineTablesOnly>
#di_subprogram = #llvm.di_subprogram<id = distinct[0]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "addi_kernel_01", linkageName = "addi_kernel_01", file = #di_file, line = 28, scopeLine = 28, subprogramFlags = "Definition|Optimized", type = #di_subroutine_type>
#loc6 = loc(fused<#di_subprogram>[#loc])
Here arith.addi has become llvm.add, tt.load/tt.store have become inline assembly, plus a handful of instructions used to compute the index/address. Since the inline assembly NVIDIA uses for the load is not that easy to read, let's still walk through the line %4 = ... . It consists of two assembly instructions, mov.u32 $0, 0x0 and @$2 ld.global.b32 { $0 }, [ $1 + 0 ], where $0/$1/$2 correspond to %4/%arg0/%2. The first instruction gives %4 an initial value of 0; the second is the load, which reads 32 bits of data from the address in %arg0 into %4, and it is guarded by $2 (what follows "@" is the predicate register): it executes when $2 is true and is skipped otherwise (which is also why the first mov is needed: %4 must get a default value when the load does not execute).
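In plain C-like terms (a hedged illustration of the semantics only, not anything the compiler emits), those two instructions behave like this:

#include <cstdint>

// Semantics of:  mov.u32 $0, 0x0;   @$2 ld.global.b32 { $0 }, [ $1 + 0 ];
// where $0 is the destination (%4), $1 the address (%arg0), $2 the predicate (%2).
uint32_t predicated_load(const uint32_t *addr /* $1 */, bool pred /* $2 */) {
  uint32_t dst = 0;   // mov.u32 $0, 0x0 : default value in case the load is skipped
  if (pred)           // @$2 : the load only executes when the predicate is true
    dst = *addr;      // ld.global.b32 { $0 }, [ $1 + 0 ]
  return dst;
}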
Apart from llvm and nvvm, no other dialects remain. The nvvm.read.ptx.sreg.tid.x here (the counterpart of CUDA's threadIdx.x) is produced when triton::nvgpu::ClusterCTAIdOp is lowered in add_nvgpu_to_llvm, and LLVM's NVPTX backend can handle it as an intrinsic function.
The final LLVM-IR (MLIR) -> LLVM-IR (LLVM) step converts the LLVM dialect into real LLVM IR. In essence it uses llvm::IRBuilder to translate each op and symbol one by one and to create the metadata. The implementation lives in translateModuleToLLVMIR in mlir/lib/Target/LLVMIR/ModuleTranslation.cpp, which we will not expand on here.
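For reference, a hedged sketch of how this step is typically driven from C++ using upstream MLIR's public API (Triton reaches the same logic through its llvm.to_module binding; header paths may differ across versions):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"

#include <memory>

std::unique_ptr<llvm::Module> lowerToLLVMIR(mlir::ModuleOp mod,
                                            llvm::LLVMContext &llvmCtx) {
  // Register how each dialect (builtin, llvm, nvvm) maps onto LLVM IR constructs.
  mlir::DialectRegistry registry;
  mlir::registerBuiltinDialectTranslation(registry);
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerNVVMDialectTranslation(registry);
  mod.getContext()->appendDialectRegistry(registry);

  // Walks every op/symbol, rebuilds it with llvm::IRBuilder and emits metadata.
  return mlir::translateModuleToLLVMIR(mod, llvmCtx, "LLVMDialectModule");
}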
At this point we have the final LLVM IR. Just before this, llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) applied O3 optimizations and removed many redundant instructions. From here the backend compiler can translate it into hardware instructions and run it happily.
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
define void @addi_kernel_01(ptr addrspace(1) %0, ptr addrspace(1) %1) local_unnamed_addr !dbg !7 {
  %3 = tail call i32 asm sideeffect "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b"(ptr addrspace(1) %0, i1 true) #1, !dbg !10
  %4 = add i32 %3, 1, !dbg !11
  %5 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12
  %6 = icmp eq i32 %5, 0, !dbg !12
  tail call void asm sideeffect "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b"(i32 %4, ptr addrspace(1) %1, i1 %6) #1, !dbg !12
  ret void, !dbg !13
}
; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
attributes #0 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }
attributes #1 = { nounwind }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2}
!nvvm.annotations = !{!4, !5}
!llvm.ident = !{!6}
!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{i32 4, !"nvvm-reflect-ftz", i32 1}
!2 = distinct !DICompileUnit(language: DW_LANG_C, file: !3, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!3 = !DIFile(filename: "toy.py", directory: "")
!4 = !{ptr @addi_kernel_01, !"kernel", i32 1}
!5 = !{ptr @addi_kernel_01, !"maxntidx", i32 128}
!6 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"}
!7 = distinct !DISubprogram(name: "addi_kernel_01", linkageName: "addi_kernel_01", scope: !3, file: !3, line: 28, type: !8, scopeLine: 28, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2)
!8 = !DISubroutineType(cc: DW_CC_normal, types: !9)
!9 = !{}
!10 = !DILocation(line: 38, column: 16, scope: !7)
!11 = !DILocation(line: 39, column: 17, scope: !7)
!12 = !DILocation(line: 40, column: 25, scope: !7)
!13 = !DILocation(line: 40, column: 4, scope: !7)
This wraps up the three chapters. Returning to our original case, the addition: a single add in the source code was lowered step by step into ast.BinOp(op.name="__add__") -> arith::addi -> llvm.add -> add. By comparison, clang takes an addition directly through the ast -> ir stage, so what stands out in triton is the extra dialect-level representation. The benefit it brings is more leeway in programming: flexible dialect representations support source code at a higher level of abstraction, letting us ignore more details while writing kernels, such as the threadIdx handling inside the kernel, inserting synchronization, and so on.
Finally, my understanding of triton as a whole is still quite shallow; many parts such as the driver and the optimizations were not covered, and the writing surely contains inaccuracies. Corrections are very welcome, and thanks for reading.
The End
Author: 液态黑洞
Source: GiantPandaCV