Peeking into Triton's Lowering (Part 3)

GiantPandaCV · WeChat official account · 2024-06-29 23:45




Author: 液态黑洞
Source: https://zhuanlan.zhihu.com/p/696133729
Editor: GiantPandaCV


In the previous part we finished analyzing the ttir->ttgir conversion, focusing on the data structures and the overall flow involved. With that foundation, the rest of this article should be easy to follow. At the end of that stage our running example still contains nodes such as arith.addi, tt.load and tt.store; in this stage we will watch them change. So let's jump straight into the final make_llir stage.
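For context, the running example from the earlier parts is a tiny kernel that loads a single int32, adds 1, and stores the result. A minimal reconstruction looks roughly like the sketch below; the kernel and tensor names are assumptions inferred from the IR dumps in this series, not copied from the original toy.py.

import torch
import triton
import triton.language as tl

@triton.jit
def addi_kernel(x_ptr, y_ptr):
    # scalar load -> integer add (arith.addi) -> scalar store,
    # matching the arith.addi / tt.load / tt.store nodes discussed in the text
    x = tl.load(x_ptr)
    tl.store(y_ptr, x + 1)

x = torch.tensor([41], dtype=torch.int32, device="cuda")
y = torch.empty_like(x)
addi_kernel[(1,)](x, y)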

  • make_llir

According to the comments in the source, this step can itself be split into two smaller steps: TritonGPU -> LLVM-IR (MLIR) and LLVM-IR (MLIR) -> LLVM-IR (LLVM). The difference is that the first step still happens at the MLIR level, i.e. it is a conversion within dialect space whose result is the LLVM dialect, while the second step converts that LLVM dialect into real LLVM IR.
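The end product of make_llir is a string of LLVM IR. As a quick way to see it from the Python side, the textual IR of each stage can usually be inspected on the compiled-kernel handle, along the lines of the sketch below (the exact attributes and keys depend on the Triton version; addi_kernel is the reconstructed kernel above):

import torch

# Launching a @triton.jit kernel returns a compiled-kernel handle whose
# .asm dict holds the textual IR of each stage ("ttir", "ttgir", "llir",
# "ptx", ...). Keys and attribute layout may differ across Triton versions.
x = torch.tensor([41], dtype=torch.int32, device="cuda")
y = torch.empty_like(x)
handle = addi_kernel[(1,)](x, y)
print(handle.asm["llir"])  # the LLVM IR produced by make_llir
print(handle.asm["ptx"])   # the final PTX, for comparison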

    @staticmethod
    def make_llir(src, metadata, options, capability):
        # warp-specialization mutates num_warps
        num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
        if num_warp_groups is not None:
            metadata["num_warps"] *= num_warp_groups
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)  # Decompose conversions that are not supported by TritonGPU -> LLVM
        passes.convert.add_scf_to_cf(pm)  # Convert SCF dialect to ControlFlow dialect
        passes.convert.add_index_to_llvmir(pm)  # Lower the `index` dialect to the `llvm` dialect
        passes.ttgpuir.add_allocate_shared_memory(pm)  # Add metadata for shared memory allocation
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)  # Handle NVGPUDialect ops; most are replaced with inline assembly
        passes.convert.add_arith_to_llvmir(pm)  # Convert Arith dialect to LLVM dialect
        passes.common.add_canonicalizer(pm)  # Convert operations into their canonical forms by folding constants, identity transformations etc.
        passes.common.add_cse(pm)  # Eliminate common sub-expressions
        passes.common.add_symbol_dce(pm)  # Eliminate dead symbols
        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
            passes.llvmir.add_di_scope(pm)  # Materialize LLVM line info
        pm.run(mod)
        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        llvm_mod = llvm.to_module(mod, context)  # Translate the LLVM dialect module into LLVM IR
        nvidia.set_nvvm_reflect_ftz(llvm_mod)  # enable fast math path in libdevice
        if options.extern_libs:
            for name, path in options.extern_libs:
                llvm.link_extern_lib(llvm_mod, path)  # link libdevice, which some math functions rely on
        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)  # O3 optimization
        metadata["shared"] = src.get_int_attr("triton_gpu.shared")
        ret = str(llvm_mod)
        del llvm_mod
        del context
        return ret
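If you want to watch this pipeline run on your own kernel, two Triton debug switches are relevant to the code above. The snippet below is a sketch: MLIR_ENABLE_DUMP dumps the module around each MLIR pass, and TRITON_DISABLE_LINE_INFO is exactly the flag checked before add_di_scope above. Both must be set before the kernel is first compiled in the process.

import os

# Dump the MLIR module for every pass Triton runs (set before compilation).
os.environ["MLIR_ENABLE_DUMP"] = "1"
# Skip the add_di_scope pass (see the TRITON_DISABLE_LINE_INFO check above).
os.environ["TRITON_DISABLE_LINE_INFO"] = "1"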

Since the second step is fairly fixed, we focus on the first: converting the various dialects into the LLVM dialect. The pass that matters most is add_to_llvmir, because the arith.addi, tt.load and tt.store in our case are all rewritten there. Jumping to its implementation (triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp), we find the following:

void runOnOperation() override {
  MLIRContext *context = &getContext();
  ModuleOp mod = getOperation();
  mlir::LowerToLLVMOptions option(context);
  option.overrideIndexBitwidth(32);
  TritonGPUToLLVMTypeConverter typeConverter(context, option);
  TritonLLVMConversionTarget convTarget(*context);
  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
  int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

  // Allocate shared memory and set barrier
  ModuleAllocation allocation(mod);
  ModuleMembarAnalysis membarPass(&allocation);
  membarPass.run();

  // Lower functions
  {
    mlir::LowerToLLVMOptions option(context);
    TritonGPUToLLVMTypeConverter typeConverter(context, option);
    TritonLLVMFunctionConversionTarget funcTarget(*context);
    RewritePatternSet funcPatterns(context);
    funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
                                       patternBenefitDefault);
    mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                          funcPatterns);
    if (failed(
            applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
      return signalPassFailure();
  }

  // initSharedMemory is run before the conversion of call and ret ops,
  // because the call op has to know the shared memory base address of each
  // function
  initSharedMemory(typeConverter);
  ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
  OpBuilder::InsertPoint indexInsertPoint;

  RewritePatternSet patterns(context);
  TargetInfo targetInfo(computeCapability);
  int benefit = patternBenefitPrioritizeOverLLVMConversions;
  ......
  populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
                                    benefit);
  // Internally registers the two patterns below (LoadOpConversion and
  // StoreOpConversion; the template arguments were lost in the original excerpt):
  //   patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  //   patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
  mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
  // Internally calls patterns.add<...> for the arith-to-LLVM patterns
  // (e.g. the AddIOpLowering discussed below)
  ......
  if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
    return signalPassFailure();
  ......
}

We can see that the code structure here is again populate#OpName#Patterns followed by applyPartialConversion, much like the add_convert_to_ttgpuir conversion in the previous part. Careful readers will notice that the difference lies in the conversion target: here it is TritonLLVMConversionTarget (with a separate TritonLLVMFunctionConversionTarget for the function lowering), whereas before it was TritonGPUConversionTarget. The new targets additionally mark dialects such as Index, LLVM, NVVM and NVGPU as legal; anything that remains illegal is lowered at this stage.

(triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp)

class TritonLLVMConversionTarget : public ConversionTarget {
public:
  explicit TritonLLVMConversionTarget(MLIRContext &ctx)
      : ConversionTarget(ctx) {
    addLegalDialect<LLVM::LLVMDialect>();
    addLegalDialect<NVVM::NVVMDialect>();
    addLegalDialect<mlir::triton::nvgpu::NVGPUDialect>();
    addIllegalDialect<triton::TritonDialect>();
    addIllegalDialect<triton::gpu::TritonGPUDialect>();
    addIllegalDialect<triton::nvidia_gpu::TritonNvidiaGPUDialect>();
    addIllegalDialect<mlir::gpu::GPUDialect>();
    addLegalOp<mlir::UnrealizedConversionCastOp>();
  }
};

The type converter also changes, from TritonGPUTypeConverter to TritonGPUToLLVMTypeConverter, which adds conversions for more types, for example the FP8 data types newly supported on NVIDIA hardware:

(triton/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp)

TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
    MLIRContext *ctx, LowerToLLVMOptions &option,
    const DataLayoutAnalysis *analysis)
    : LLVMTypeConverter(ctx, option, analysis) {
  addConversion([&](triton::PointerType type) -> std::optional<Type> {
    return convertTritonPointerType(type);
  });
  addConversion([&](RankedTensorType type) -> std::optional<Type> {
    return convertTritonTensorType(type);
  });
  addConversion([&](MemDescType type) -> std::optional<Type> {
    return convertMemDescType(type);
  });
  addConversion([&](triton::gpu::AsyncTokenType type) -> std::optional<Type> {
    return convertAsyncToken(type);
  });
  // Internally store float8 as int8
  addConversion([&](mlir::Float8E4M3B11FNUZType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E4M3FNType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E4M3FNUZType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 8);
  });
  // Internally store bfloat16 as int16
  addConversion([&](BFloat16Type type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 16);
  });
}

In addition, in the previous part's populate#OpName#Patterns step, arith.addi, tt.load and tt.store were all converted through the generic GenericOpPattern, whereas here each of the three gets its own treatment. For load, its RewritePattern is LoadOpConversion; let's look straight at its conversion function matchAndRewrite():

matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  ......
  // Define the instruction opcode
  auto &ld = ptxBuilder.create<>("ld")
                 ->o("volatile", op.getIsVolatile())
                 .global()
                 .o("ca", op.getCache() == triton::CacheModifier::CA)
                 .o("cg", op.getCache() == triton::CacheModifier::CG)
                 .o("L1::evict_first",
                    op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
                 .o("L1::evict_last",
                    op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
                 .o("L1::cache_hint", hasL2EvictPolicy)
                 .v(nWords)
                 .b(width);
  ......
}

As you can see, the NVIDIA backend handles tt.load rather bluntly by emitting inline PTX assembly (probably to make cache control easier; the AMD backend lowers it to LLVM::LoadOp instead). Beyond this, the load/store lowering involves many more steps, such as vectorization and computing each thread's access mask, and it also creates triton::nvgpu::ClusterCTAIdOp to bind with the index computation, which is how the SIMT programming model is realized (a lot of code is omitted here; I have not finished reading it and will try to fill it in later). tt.store is likewise turned into inline assembly by StoreOpConversion.
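The per-thread mask machinery is easiest to see with a masked load in the Triton DSL. The kernel below is a hypothetical example, not part of the original case: the mask feeds the predicate that guards the generated ld.global/st.global, and other is the default value the lowering has to materialize when the predicate is false.

import triton
import triton.language as tl

@triton.jit
def masked_copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # `mask` becomes the predicate guarding the generated ld.global;
    # `other=0` is the default value used when the predicate is false.
    x = tl.load(x_ptr + offs, mask=mask, other=0)
    tl.store(y_ptr + offs, x, mask=mask)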

For arith.addi, populateArithToLLVMConversionPatterns registers MLIR's own AddIOpLowering, which also derives from RewritePattern and is used to rewrite the addi op:

using AddIOpLowering =
    VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                               arith::AttrConvertOverflowToLLVM>;

Its implementation ends up directly calling

return rewriter.replaceOp(op, newOp->getResult(0)), success();

which replaces arith::AddIOp with LLVM::AddOp. At this point we have the target, the type converter and a RewritePattern for each op; what follows repeats the flow from the previous part: check legality first, then apply the conversion, until everything is converted into the LLVM dialect. The IR at this point looks like this (for brevity, the optimized IR is printed):

#loc = loc("toy.py":28:0)
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 0 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> loc(#loc)
llvm.func @addi_kernel_01(%arg0: !llvm.ptr<1> loc("toy.py":28:0), %arg1: !llvm.ptr<1> loc("toy.py":28:0)) attributes {noinline = false, nvvm.kernel = 1 : ui1, nvvm.maxntid = array} {
%0 = llvm.mlir.constant(0 : i32) : i32 loc(#loc1)
%1 = llvm.mlir.constant(0 : index) : i32 loc(#loc1)
%2 = llvm.mlir.constant(true) : i1 loc(#loc1)
%3 = llvm.mlir.constant(1 : i32) : i32 loc(#loc1)
%4 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %arg0, %2 : (!llvm.ptr<1>, i1) -> i32 loc(#loc2)
%5 = llvm.bitcast %4 : i32 to vector<1xi32> loc(#loc2)
%6 = llvm.extractelement %5[%1 : i32] : vector<1xi32> loc(#loc2)
%7 = llvm.add %6, %3 : i32 loc(#loc3)
%8 = nvvm.read.ptx.sreg.tid.x : i32 loc(#loc4)
%9 = llvm.and %2, %2 : i1 loc(#loc4)
%10 = llvm.icmp "eq" %8, %0 : i32 loc(#loc4)
%11 = llvm.and %9, %10 : i1 loc(#loc4)
%12 = llvm.mlir.undef : vector<1xi32> loc(#loc4)
%13 = llvm.insertelement %7, %12[%0 : i32] : vector<1xi32> loc(#loc4)
%14 = llvm.bitcast %13 : vector<1xi32> to i32 loc(#loc4)
%15 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %14, %arg1, %11 : (i32, !llvm.ptr<1>, i1) -> !llvm.void loc(#loc4)
llvm.return loc(#loc5)
} loc(#loc6)
} loc(#loc)
#di_file = #llvm.di_file
#di_subroutine_type = #llvm.di_subroutine_type
#loc1 = loc(unknown)
#loc2 = loc("toy.py":38:16)
#loc3 = loc("toy.py":39:17)
#loc4 = loc("toy.py":40:25)
#loc5 = loc("toy.py":40:4)
#di_compile_unit = #llvm.di_compile_unit, sourceLanguage = DW_LANG_C, file = #di_file, producer = "triton", isOptimized = true, emissionKind = LineTablesOnly>
#di_subprogram = #llvm.di_subprogram, compileUnit = #di_compile_unit, scope = #di_file, name = "addi_kernel_01", linkageName = "addi_kernel_01", file = #di_file, line = 28, scopeLine = 28, subprogramFlags = "Definition|Optimized", type = #di_subroutine_type>
#loc6 = loc(fused[#loc])

Here arith.addi has become llvm.add, tt.load/tt.store have become inline assembly, plus a few instructions for computing the index/address. Because the inline assembly NVIDIA emits for the load is not easy to read, the line %4 = ... deserves an explanation. It is made up of two instructions: mov.u32 $0, 0x0 and @$2 ld.global.b32 { $0 }, [ $1 + 0 ], where $0/$1/$2 correspond to %4/%arg0/%2. The first instruction gives %4 an initial value of 0. The second is the actual load, reading 32 bits from the address in %arg0 into %4; it is controlled by $2 (the register after "@" is a predicate register), executing only when $2 is true and doing nothing otherwise. This is also why the leading mov is needed: when the load does not execute, %4 still needs a default value.
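Written out as plain Python, the semantics of this two-instruction sequence are roughly the following (a sketch for illustration only; pred plays the role of %2 and read_global stands in for the 32-bit global load):

def predicated_load(pred: bool, read_global) -> int:
    dst = 0                  # mov.u32 $0, 0x0 : default value for %4
    if pred:                 # "@$2": the ld only executes when %2 is true
        dst = read_global()  # ld.global.b32 { $0 }, [ $1 + 0 ]
    return dst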






