From 1457b51d3140dd2b23f6058595d0f8fad4390025 Mon Sep 17 00:00:00 2001 From: Schrodinger ZHU Yifan Date: Mon, 4 Nov 2024 20:29:04 -0500 Subject: [PATCH] fixes --- include/mlir-gccjit/IR/GCCJITAttrs.td | 7 +++--- src/Conversion/ConvertMemrefToGCCJIT.cpp | 16 ++++++------ src/Conversion/TypeConverter.cpp | 6 ++--- src/Translation/TranslateToGCCJIT.cpp | 31 ++++++++++++++++++------ test/lowering/gemm.mlir | 15 +++++++++++- test/syntax/record.mlir | 20 +++++++-------- 6 files changed, 62 insertions(+), 33 deletions(-) diff --git a/include/mlir-gccjit/IR/GCCJITAttrs.td b/include/mlir-gccjit/IR/GCCJITAttrs.td index 9b4c16e..edaeed6 100644 --- a/include/mlir-gccjit/IR/GCCJITAttrs.td +++ b/include/mlir-gccjit/IR/GCCJITAttrs.td @@ -180,7 +180,7 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> { let builders = [ AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type), [{ - return get($_ctxt, name, type, 0, std::nullopt); + return get($_ctxt, name, type, std::nullopt, std::nullopt); }]>, AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type, "unsigned":$bitWidth), [{ @@ -188,12 +188,13 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> { }]>, AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type, "mlir::gccjit::SourceLocAttr":$loc), [{ - return get($_ctxt, name, type, 0, loc); + return get($_ctxt, name, type, std::nullopt, loc); }]>, ]; + // attribute can eat up the `:` separator, so we need to move the name to the front let assemblyFormat = [{ - `<` $type $name (`:` $bitWidth^)? ($loc^)? `>` + `<` $name $type (`:` $bitWidth^)? ($loc^)? `>` }]; } diff --git a/src/Conversion/ConvertMemrefToGCCJIT.cpp b/src/Conversion/ConvertMemrefToGCCJIT.cpp index 7816318..389ea6d 100644 --- a/src/Conversion/ConvertMemrefToGCCJIT.cpp +++ b/src/Conversion/ConvertMemrefToGCCJIT.cpp @@ -12,10 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include +#include +#include #include #include +#include +#include +#include +#include +#include +#include #include "libgccjit.h" #include "mlir-gccjit/Conversion/Conversions.h" @@ -25,12 +31,6 @@ #include "mlir-gccjit/IR/GCCJITOpsEnums.h" #include "mlir-gccjit/IR/GCCJITTypes.h" #include "mlir-gccjit/Passes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Value.h" using namespace mlir; using namespace mlir::gccjit; diff --git a/src/Conversion/TypeConverter.cpp b/src/Conversion/TypeConverter.cpp index 17a7cb0..3332ed8 100644 --- a/src/Conversion/TypeConverter.cpp +++ b/src/Conversion/TypeConverter.cpp @@ -177,7 +177,7 @@ GCCJITTypeConverter::getMemrefDescriptorType(mlir::MemRefType type) const { llvm::enumerate(ArrayRef{elementPtrType, elementPtrType, indexType, dimOrStrideType, dimOrStrideType})) { auto nameAttr = StringAttr::get(type.getContext(), names[idx]); - fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0)); + fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field)); } auto fieldsAttr = ArrayAttr::get(type.getContext(), fields); return StructType::get(type.getContext(), nameAttr, fieldsAttr); @@ -199,7 +199,7 @@ gccjit::StructType GCCJITTypeConverter::getUnrankedMemrefDescriptorType( llvm::enumerate(ArrayRef{indexType, opaquePtrType})) { auto name = Twine("__field_").concat(Twine(idx)).str(); auto nameAttr = StringAttr::get(type.getContext(), name); - fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0)); + fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field)); } auto fieldsAttr = ArrayAttr::get(type.getContext(), fields); return StructType::get(type.getContext(), nameAttr, fieldsAttr); @@ -220,7 +220,7 @@ Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton( for (auto [idx, type] : llvm::enumerate(types)) { auto name = Twine("__field_").concat(Twine(idx)).str(); auto nameAttr = StringAttr::get(func.getContext(), name); - fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type, 0)); + fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type)); } auto nameAttr = StringAttr::get(func.getContext(), name); auto fieldsAttr = ArrayAttr::get(func.getContext(), fields); diff --git a/src/Translation/TranslateToGCCJIT.cpp b/src/Translation/TranslateToGCCJIT.cpp index c19ced3..c71bede 100644 --- a/src/Translation/TranslateToGCCJIT.cpp +++ b/src/Translation/TranslateToGCCJIT.cpp @@ -107,8 +107,9 @@ class RegionVisitor { gcc_jit_rvalue *visitExprWithoutCache(AddrOp op); gcc_jit_rvalue *visitExprWithoutCache(FnAddrOp op); gcc_jit_lvalue *visitExprWithoutCache(GetGlobalOp op); - gcc_jit_rvalue *visitExprWithoutCache(ExprOp op); + Expr visitExprWithoutCache(ExprOp op); gcc_jit_lvalue *visitExprWithoutCache(DerefOp op); + Expr visitExprWithoutCache(AccessFieldOp op); /// The following operations are entrypoints for real codegen. void visitAssignOp(gcc_jit_block *blk, AssignOp op); @@ -515,18 +516,17 @@ Expr RegionVisitor::translateIntoContext() { Block &block = region.getBlocks().front(); auto terminator = cast(block.getTerminator()); auto value = terminator->getOperand(0); - auto rvalue = visitExpr(value, true); + auto expr = visitExpr(value, true); if (auto globalOp = dyn_cast(parent)) { auto symName = SymbolRefAttr::get(getMLIRContext(), globalOp.getSymName()); auto *lvalue = getTranslator().getGlobalLValue(symName); - gcc_jit_global_set_initializer_rvalue(lvalue, rvalue); + gcc_jit_global_set_initializer_rvalue(lvalue, expr); return {}; } - if (auto exprOp = dyn_cast(parent)) { - return rvalue; - } + if (auto exprOp = dyn_cast(parent)) + return expr; llvm_unreachable("unknown region parent"); } @@ -558,8 +558,8 @@ Expr RegionVisitor::visitExpr(Value value, bool toplevel) { .Case([&](GetGlobalOp op) { return visitExprWithoutCache(op); }) .Case([&](ExprOp op) { return visitExprWithoutCache(op); }) .Case([&](DerefOp op) { return visitExprWithoutCache(op); }) + .Case([&](AccessFieldOp op) { return visitExprWithoutCache(op); }) .Default([](Operation *op) -> Expr { - op->dump(); llvm::report_fatal_error("unknown expression type"); }); @@ -579,7 +579,22 @@ gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(DerefOp op) { return gcc_jit_context_new_array_access(getContext(), loc, ptr, offset); } -gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(ExprOp op) { +Expr RegionVisitor::visitExprWithoutCache(AccessFieldOp op) { + auto composite = visitExpr(op.getComposite()); + auto *loc = getTranslator().getLocation(op.getLoc()); + auto *compositeTy = getTranslator().convertType(op.getComposite().getType()); + auto index = op.getField().getZExtValue(); + // TODO: support union and query from cache instead + auto *structure = gcc_jit_type_is_struct(compositeTy); + if (!structure) + llvm_unreachable("expected struct type"); + auto *field = gcc_jit_struct_get_field(structure, index); + if (isa(op.getType())) + return gcc_jit_lvalue_access_field(composite, loc, field); + return gcc_jit_rvalue_access_field(composite, loc, field); +} + +Expr RegionVisitor::visitExprWithoutCache(ExprOp op) { RegionVisitor visitor(getTranslator(), op.getRegion(), this); return visitor.translateIntoContext(); } diff --git a/test/lowering/gemm.mlir b/test/lowering/gemm.mlir index 73aaf64..0146075 100644 --- a/test/lowering/gemm.mlir +++ b/test/lowering/gemm.mlir @@ -1,4 +1,12 @@ -// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-memref-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s +// RUN: %gccjit-opt %s \ +// RUN: -lower-affine \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-arith-to-gccjit \ +// RUN: -convert-memref-to-gccjit \ +// RUN: -convert-func-to-gccjit \ +// RUN: -reconcile-unrealized-casts -mlir-print-debuginfo -o %t.mlir +// RUN: %filecheck --input-file=%t.mlir %s +// RUN: %gccjit-translate %t.mlir -mlir-to-gccjit-gimple | %filecheck %s --check-prefix=CHECK-GIMPLE module { // CHECK-NOT: func.func // CHECK-NOT: func.return @@ -15,11 +23,15 @@ module { %acc0 = arith.constant 0.0 : f32 %sum = affine.for %k = 0 to 100 iter_args(%acc = %acc0) -> f32 { // Load values from A and B + // CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] %a_val = affine.load %A[%i, %k] : memref<100x100xf32> + // CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] %b_val = affine.load %B[%k, %j] : memref<100x100xf32> // Multiply and accumulate + // CHECK-GIMPLE: %[[V:[0-9]+]] = %{{[0-9]+}} * %{{[0-9]+}} %prod = arith.mulf %a_val, %b_val : f32 + // CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9]+}} + %[[V]] %new_acc = arith.addf %acc, %prod : f32 // Yield the new accumulated value @@ -33,6 +45,7 @@ module { %final_val = arith.addf %c_scaled, %result : f32 // Store the final result back to matrix C + // CHECK-GIMPLE: %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] = %{{[0-9]+}} affine.store %final_val, %C[%i, %j] : memref<100x100xf32> } } diff --git a/test/syntax/record.mlir b/test/syntax/record.mlir index f98f701..71f29a7 100644 --- a/test/syntax/record.mlir +++ b/test/syntax/record.mlir @@ -4,19 +4,19 @@ module @test { gccjit.func imported @gemm ( !gccjit.struct<"__memref_188510220862752" { - #gccjit.field> "base">, - #gccjit.field> "aligned">, - #gccjit.field "offset">, - #gccjit.field, 2> "sizes">, - #gccjit.field, 2> "strides"> + #gccjit.field<"base" !gccjit.ptr>>, + #gccjit.field<"aligned" !gccjit.ptr>>, + #gccjit.field<"offset" !gccjit.int : 32>, + #gccjit.field<"sizes" !gccjit.array, 2>>, + #gccjit.field<"strides" !gccjit.array, 2>> }> ) // CHECK: @gemm // CHECK-SAME: !gccjit.struct<"__memref_188510220862752" { - // CHECK-SAME: #gccjit.field> "base"> - // CHECK-SAME: #gccjit.field> "aligned"> - // CHECK-SAME: #gccjit.field "offset"> - // CHECK-SAME: #gccjit.field, 2> "sizes"> - // CHECK-SAME: #gccjit.field, 2> "strides"> + // CHECK-SAME: #gccjit.field<"base" !gccjit.ptr>> + // CHECK-SAME: #gccjit.field<"aligned" !gccjit.ptr>> + // CHECK-SAME: #gccjit.field<"offset" !gccjit.int : 32> + // CHECK-SAME: #gccjit.field<"sizes" !gccjit.array, 2>> + // CHECK-SAME: #gccjit.field<"strides" !gccjit.array, 2>> // CHECK-SAME: } }