最近在整理先前实习做的一些工作,主要是对AI compiler做基于mlir的重构,以下是之前写的compiler frontend的一个比较基础的pass,针对自定义的IR Dialect做bufferization。
一、bufferization概念
Bufferization 是MLIR中一个重要的过程,它主要负责将具有tensor(张量)语义的操作转换为具有memref(内存引用)语义的操作。
- Tensor在MLIR中代表抽象值类型的数据序列,它们并不直接对应于内存中的位置。
- MemRef(Memory Reference)则代表对内存区域的具体引用,提供了更低级别的缓冲区访问能力。
- Bufferization将tensor的语义转换为memref的语义,memref提供了更直接、更具体的内存访问方式,减少了编译器需要处理的抽象层次。
二、实现
以下是在XPU上自定义TIR的一个conv2d mlir的示意
pass的功能就是实现将func和op的tensor type转为memref type(TIR->MTIR),实现共包含两个pass,六个pattern!
module {
func.func @XPUFunc(%arg0: tensor<1x8x8x256xf32>) -> tensor<1x4x4x256xf32> attributes {input_names = ["data0"], input_num = 1 : i64, output_names = ["conv0_fix"]} {
%0 = "tir.const"() {value = dense_resource<__elided__> : tensor<256x2x2x256xi8>} : () -> tensor<256x2x2x256xi8>
%1 = "tir.const"() {value = dense_resource<__elided__> : tensor<256xi8>} : () -> tensor<256xi8>
%2 = "tir.float2fix"(%arg0) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "data0_fix", round_mode = "XPU_ROUND"} : (tensor<1x8x8x256xf32>) -> tensor<1x8x8x256xi8>
%3 = "tir.conv2d-fix"(%2, %0, %1) {dilation = [1 : i32, 1 : i32], group = 1 : i32, hsigmoid_in = -128 : i32, kernel = [2 : i32, 2 : i32], nonlinear = "NONE", op_name = "conv0", pad = [0 : i32, 0 : i32, 0 : i32, 0 : i32], pad_mode = "FLOOR", shift_hsigmoid = -128 : i32, shift_hswish = -128 : i32, stride = [2 : i32, 2 : i32]} : (tensor<1x8x8x256xi8>, tensor<256x2x2x256xi8>, tensor<256xi8>) -> tensor<1x4x4x256xi8>
%4 = "tir.fix2float"(%3) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "conv0_fix", round_mode = "XPU_ROUND"} : (tensor<1x4x4x256xi8>) -> tensor<1x4x4x256xf32>
return %4 : tensor<1x4x4x256xf32>
}
}
ODS自定义OP .td写法示例
include "Tir_op_base.td"
// ODS definition of the `tir.const` operation: no operands, an optional
// `value` attribute, and a single tensor result.
def Tir_ConstOp :
Tir_Op<"const", [ConstantLike, Pure, FirstAttrDerivedResultType]> {
let summary = "Represent a constant tensor with values";
let description = [{
The constant operator providing initialized values for tensors.
The initial values come either in `DenseElementsAttr` `value`, or from an
external binary file specified in `path`.
}];
// NOTE(review): `OptionalAttr` normally takes an attribute constraint as a
// template argument (e.g. `OptionalAttr<...>`); it appears to have been lost
// in this listing - confirm against the real .td file.
// NOTE(review): the description mentions a `path` attribute, but only
// `value` is declared here - confirm whether `path` was elided.
let arguments = (ins
OptionalAttr:$value
);
let results = (outs Tir_Tensor:$output);
// Requests a C++ `fold` hook for this op.
let hasFolder = 1;
}
...
...
2.1global_bufferize pass
实现分为两步pass,第一步为global_bufferize pass,即将func的argument和return的tensor type转为memref。代码和注释如下所示
/// @brief Early bufferization on global input/output and constants
class GlobalBufferize : public impl::GlobalBufferizeBase {
public:
void runOnOperation() override { // 重写基类的runOnOperation函数
auto *ctx = &getContext();
//获取上下文,FuncOp的成员函数,用于后续创建新的Op、添加转换规则
ConversionTarget target(*ctx);
//ConversionTarget 用于指定在转换过程中哪些Op是合法的,哪些是需要动态检查的。
target.addDynamicallyLegalOp<:constop>([](Operation *op) {
auto ttype = op->getResult(0).getType().cast();
return ttype.getRank() == 0;
}); //tir.ConstOp返回维度数(秩)是0的时候也就是标量,才合法转换 //不然就转为memex.const
target.addLegalOp<:constop>(); //静态合法,不需要转换
target.addLegalOp<:uploadop>();
target.addLegalOp<:downloadop>();
target.addDynamicallyLegalOp<:func::returnop>(
[](Operation *op) { return op->getNumOperands() == 0; });
//ReturnOp返回数为0时合法。
//因为后续用到了upload和download将func里面的argu2进行结果copy,所以不需要return结果了
mlir::func::FuncOp func = getOperation(); //获取funcOp
updateFuncOp(func); //更新Op的操作
RewritePatternSet convertPatterns(ctx); //存Pattern的集合
convertPatterns.insert(ctx);
//将ConstOp、ReturnOp的ConvertPattern加入set
(void)applyPartialConversion(func, target, std::move(convertPatterns));
//根据target中定义的规则进行convertpatternset中的转换
}
};
} //
//创建返回pass对象
std::unique_ptr<:pass> tir::createGlobalBufferizePass() {
return std::make_unique();
}
以上是globalbufferize pass的主要部分,在定义的target合法规则检查上应用了两个转换pattern和updateFuncOp。下面看updateFuncOp
/// Returns the memref type that has the same shape and element type as the
/// given ranked tensor type.
static inline MemRefType tensorToMemRef(RankedTensorType type) {
  auto shape = type.getShape();
  auto elementType = type.getElementType();
  return MemRefType::get(shape, elementType);
}
/// Rewrites the signature of `func` from tensor semantics to memref
/// semantics.
///
/// - Every ranked-tensor input argument is retyped to the matching memref
///   type, and a `tir.upload` (memref -> tensor) is created so that existing
///   users of the argument keep seeing a tensor value.
/// - Every result type is turned into an extra memref argument appended to
///   the entry block; the function type ends up with no results.
///
/// All results are expected to be ranked tensors (`cast` asserts otherwise).
///
/// NOTE(review): the template arguments in this listing were reconstructed
/// (`SmallVector<Type, 4>`, `RankedTensorType`, `tir::UpLoadOp`); verify
/// against the real sources.
static void updateFuncOp(mlir::func::FuncOp func) {
  // Builder used to create the tir.upload ops inside the function body.
  mlir::OpBuilder builder(func.getBody());
  auto funcType = func.getFunctionType();

  // Argument types of the rewritten function: converted inputs first, then
  // one memref per former result.
  llvm::SmallVector<Type, 4> argTypes;

  for (auto type : llvm::enumerate(funcType.getInputs())) {
    auto tensorType = type.value().dyn_cast<RankedTensorType>();
    if (!tensorType) {
      // Inputs that are not ranked tensors keep their type.
      argTypes.emplace_back(type.value());
      continue;
    }

    // Retype the block argument from tensor to memref.
    auto argType = tensorToMemRef(tensorType);
    auto arg = func.getArgument(type.index());
    arg.setType(argType);

    // Bridge back to tensor semantics: the upload consumes the memref
    // argument and yields a tensor of the original type.
    auto load = builder.create<tir::UpLoadOp>(func.getLoc(), tensorType, arg);
    // Redirect every user of the argument to the upload result, except the
    // upload itself, which must keep reading the memref argument.
    arg.replaceAllUsesExcept(load->getResult(0), load);
    argTypes.emplace_back(argType);
  }

  for (auto type : funcType.getResults()) {
    auto tensorType = type.cast<RankedTensorType>();
    auto argType = tensorToMemRef(tensorType);
    argTypes.emplace_back(argType);
    // The output buffer becomes a trailing argument of the entry block.
    func.front().addArguments(argType, builder.getUnknownLoc());
  }

  // Install the new function type: all arguments, no results.
  func.setType(FunctionType::get(func.getContext(), argTypes, llvm::None));
}
总结:updateFuncOp 函数的作用是将函数输入参数的类型从 RankedTensorType 转换为 MemRefType,并把每个返回值改为追加在参数列表末尾的 MemRefType 输出参数(函数不再有返回类型)。此外,它还为每个被转换的输入参数创建一个 tir.upload(memref->tensor),把 memref 类型的输入重新转回 tensor,供函数体内原有的 op 继续使用。
再来看两个convert pattern。对于ConstOpConverter,实现上是用自定义的memtx.const(结果为memref)+tir.upload(memref->tensor)替换了原来的tir.const(结果为tensor)。
struct ConstOpConverter : public OpConversionPattern {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ConstOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tensorType = op.getOutput().getType().cast();
auto memRefType = tensorToMemRef(tensorType);
auto mconst =
rewriter.create<:constop>(op.getLoc(), memRefType, *op.getValue())
.getResult();
rewriter.replaceOpWithNewOp<:uploadop>(op, tensorType, mconst);
return success();
}
};
struct ReturnOpConverter : public OpConversionPattern<:func::returnop> {
using OpConversionPattern<:func::returnop>::OpConversionPattern;
LogicalResult
matchAndRewrite(mlir::func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto func = op->getParentOfType<:func::funcop>();
unsigned retArgIndex = func.getNumArguments() - op.getNumOperands();
for (auto opr : llvm::enumerate(adaptor.getOperands())) {
auto outputArg = func.getArgument(retArgIndex + opr.index());
rewriter.create<:downloadop>(op.getLoc(), opr.value(), outputArg);
}
rewriter.replaceOpWithNewOp<:func::returnop>(op);
return success();
}
};
对于ReturnOpConverter,会为return的每个操作数插入一个tir.download,把tensor结果写入对应的memref输出参数,然后把原来的return替换为不带操作数的return。
global_bufferize pass后的结果如下:可以看到func的输入和输出都变成了memref类型的参数;新增的tir.upload负责把输入memref转为tensor,tir.download负责把结果tensor写回输出memref;常量则由memtx.const+tir.upload完成memref到tensor的转换。
module {
func.func @XPUFunc(%arg0: memref<1x8x8x256xf32>, %arg1: memref<1x4x4x256xf32>) attributes {input_names = ["data0"], input_num = 1 : i64, output_names = ["conv0_fix"]} {
%0 = "tir.upload"(%arg0) : (memref<1x8x8x256xf32>) -> tensor<1x8x8x256xf32>
%1 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256x2x2x256xi8>} : () -> memref<256x2x2x256xi8>
%2 = "tir.upload"(%1) : (memref<256x2x2x256xi8>) -> tensor<256x2x2x256xi8>
%3 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256xi8>} : () -> memref<256xi8>
%4 = "tir.upload"(%3) : (memref<256xi8>) -> tensor<256xi8>
%5 = "tir.float2fix"(%0) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "data0_fix", round_mode = "XPU_ROUND"} : (tensor<1x8x8x256xf32>) -> tensor<1x8x8x256xi8>
%6 = "tir.conv2d-fix"(%5, %2, %4) {dilation = [1 : i32, 1 : i32], group = 1 : i32, hsigmoid_in = -128 : i32, kernel = [2 : i32, 2 : i32], nonlinear = "NONE", op_name = "conv0", pad = [0 : i32, 0 : i32, 0 : i32, 0 : i32], pad_mode = "FLOOR", shift_hsigmoid = -128 : i32, shift_hswish = -128 : i32, stride = [2 : i32, 2 : i32]} : (tensor<1x8x8x256xi8>, tensor<256x2x2x256xi8>, tensor<256xi8>) -> tensor<1x4x4x256xi8>
%7 = "tir.fix2float"(%6) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "conv0_fix", round_mode = "XPU_ROUND"} : (tensor<1x4x4x256xi8>) -> tensor<1x4x4x256xf32>
"tir.download"(%7, %arg1) : (tensor<1x4x4x256xf32>, memref<1x4x4x256xf32>) -> ()
return
}
}
下面是新增的ODS自定义Op
include "tir_op_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Memref constraint used by the ops below: a strided memref whose element
// type satisfies `Tir_ElementType`.
def Tir_MemRef : StridedMemRefOf<[Tir_ElementType]>;
// `tir.upload`: takes one memref operand and produces one tensor result.
// It is declared with the `NoMemoryEffect` trait.
def Tir_UpLoadOp : Tir_Op<"upload", [NoMemoryEffect]> {
let arguments = (ins Tir_MemRef:$mem);
let results = (outs Tir_Tensor:$output);
}
// `tir.download`: takes a tensor and a memref operand and has no results.
// No traits are attached here.
def Tir_DownLoadOp : Tir_Op<"download"> {
let arguments = (ins Tir_Tensor:$tensor, Tir_MemRef:$mem);
}
2.2tir2mtir_convert pass
直接上结果。我们的目的是对IR做bufferization,即IR中只能出现memref类型,不能再出现tensor类型。经过前一个pass global_bufferize后,我们得到了上一节所示的IR,在此基础上继续写第二个pass:tir2mtir_convert,其转换结果如下。
module {
func.func @XPUFunc(%arg0: memref<1x8x8x256xf32>, %arg1: memref<1x4x4x256xf32>) attributes {input_names = ["data0"], input_num = 1 : i64, output_names = ["conv0_fix"]} {
%alloc = memref.alloc() : memref<1x8x8x256xf32>
"memtx.copy"(%arg0, %alloc) : (memref<1x8x8x256xf32>, memref<1x8x8x256xf32>) -> ()
%0 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256x2x2x256xi8>} : () -> memref<256x2x2x256xi8>
%alloc_0 = memref.alloc() : memref<256x2x2x256xi8>
"memtx.copy"(%0, %alloc_0) : (memref<256x2x2x256xi8>, memref<256x2x2x256xi8>) -> ()
%1 = "memtx.const"() {value = dense_resource<__elided__> : tensor<256xi8>} : () -> memref<256xi8>
%alloc_1 = memref.alloc() : memref<256xi8>
"memtx.copy"(%1, %alloc_1) : (memref<256xi8>, memref<256xi8>) -> ()
%alloc_2 = memref.alloc() : memref<1x8x8x256xi8>
"mtir.float2fix"(%alloc, %alloc_2) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "data0_fix", round_mode = "XPU_ROUND"} : (memref<1x8x8x256xf32>, memref<1x8x8x256xi8>) -> ()
%alloc_3 = memref.alloc() : memref<1x4x4x256xi8>
"mtir.conv2d-fix"(%alloc_2, %alloc_0, %alloc_1, %alloc_3) {dilation = [1 : i32, 1 : i32], group = 1 : i32, hsigmoid_in = -128 : i32, kernel = [2 : i32, 2 : i32], nonlinear = "NONE", op_name = "conv0", pad = [0 : i32, 0 : i32, 0 : i32, 0 : i32], pad_mode = "FLOOR", shift_hsigmoid = -128 : i32, shift_hswish = -128 : i32, stride = [2 : i32, 2 : i32]} : (memref<1x8x8x256xi8>, memref<256x2x2x256xi8>, memref<256xi8>, memref<1x4x4x256xi8>) -> ()
%alloc_4 = memref.alloc() : memref<1x4x4x256xf32>
"mtir.fix2float"(%alloc_3, %alloc_4) {bit_width = 8 : i32, fix_point = 0 : i32, if_signed = true, op_name = "conv0_fix", round_mode = "XPU_ROUND"} : (memref<1x4x4x256xi8>, memref<1x4x4x256xf32>) -> ()
"memtx.copy"(%alloc_4, %arg1) : (memref<1x4x4x256xf32>, memref<1x4x4x256xf32>) -> ()
return
}
}
pass如下
struct ConvertTirToMTirPass
: public impl::ConvertTirToMTirBase {
void runOnOperation() override {
mlir::func::FuncOp f = getOperation();
auto &context = getContext();
ConversionTarget target(context);
mlir::bufferization::BufferizeTypeConverter typeConverter;
// 设置TirToMTir的legality 和 patterns
setupTirToMTirLegality(typeConverter, target);
RewritePatternSet patterns(&context);
populateTirToMTirPatterns(typeConverter, patterns);
// 使用在target上定义的合法性pattern做conversion转换
if (failed(applyFullConversion(f, target, std::move(patterns)))) {
signalPassFailure();
}
// 设置finalize的legality和patterns
RewritePatternSet finalizePatterns(&context);
ConversionTarget finalizeTarget(context);
finalizeTarget.markUnknownOpDynamicallyLegal(
[&](Operation *op) { return typeConverter.isLegal(op); });
populateEliminateBufferizeMaterializationsPatterns(typeConverter,
finalizePatterns);
// 使用在target上定义的合法性pattern做conversion转换
if (failed(applyFullConversion(f, finalizeTarget,
std::move(finalizePatterns)))) {
signalPassFailure();
}
}
};
} // end anonymous namespace
std::unique_ptr<:operationpass>>
mxir::createConvertTirToMTirPass() {
return std::make_unique();
}
下面来看具体的Legality和pattern
// Registers which dialects and ops are legal or illegal on `target`; the
// rules are consulted when the conversion is applied.
// `typeConverter` is accepted but not used in this function body.
// NOTE(review): the template arguments in this listing are garbled
// (`<:memrefdialect>`, `<:mtirdialect>`, ...). They presumably name the
// memref, mtir, memtx and tir dialect classes and `func::ReturnOp` -
// restore them from the real sources.
void xcompiler::mxir::setupTirToMTirLegality(
mlir::bufferization::BufferizeTypeConverter &typeConverter,
ConversionTarget &target) {
// Dialects whose ops may remain after the conversion.
target.addLegalDialect<:memrefdialect>();
target.addLegalDialect<:mtirdialect>();
target.addLegalDialect<:memtxdialect>();
// NOTE(review): the dialect argument of this call is missing entirely and
// cannot be inferred from this listing - confirm which dialect was meant.
target.addLegalDialect();
target.addLegalOp<:func::returnop>();
// Every op of this dialect (presumably tir) must be converted away.
target.addIllegalDialect<:tirdialect>();
// Author's note: "virtual buffer". Adds the legality rules for the
// bufferization materialization ops to `target`.
mlir::bufferization::populateBufferizeMaterializationLegality(target);
}
// Extends `typeConverter` with an extra conversion and an argument
// materialization, then registers the TIR -> MTIR rewrite patterns in
// `patterns`.
// NOTE(review): several template arguments in this listing are garbled
// (`Optional` without its element type, the pattern list of `patterns.add`);
// restore them from the real sources.
void xcompiler::mxir::populateTirToMTirPatterns(
mlir::bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
auto *context = patterns.getContext();
typeConverter.addConversion(
[](RankedTensorType type) -> Optional { return llvm::None; });
// Author's note: tensor types are not supported.
// NOTE(review): in MLIR's TypeConverter, returning `None` from a conversion
// callback lets the converter try the previously registered conversions
// rather than rejecting the type - confirm this callback has the intended
// effect.
typeConverter.addArgumentMaterialization(
[](OpBuilder &builder, TensorType type, ValueRange inputs,
Location loc) -> Optional {
if (type.getRank() == 0) { // rank-0 (scalar): forward the first input as is
return inputs[0];
}
return llvm::None;
});
// Four patterns are registered here according to the author.
// NOTE(review): only `TirOpConverter` survives in this listing; the other
// pattern names and the opening `<` were lost.
patterns.add TirOpConverter>(typeConverter, context);
}
四个pattern
//alloc op
//为给定的op创建一个内存分配操作memref::AllocOp
static memref::AllocOp createAllocForOp(Operation *op, MemRefType type,
OpBuilder &builder) {
auto alloc = builder.create<:allocop>(op->getLoc(), type);
if (auto attr = op->getAttrOfType("id")) {
auto