Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 5, 2024
1 parent 6e8d642 commit 1457b51
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 33 deletions.
7 changes: 4 additions & 3 deletions include/mlir-gccjit/IR/GCCJITAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,21 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> {

let builders = [
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type), [{
return get($_ctxt, name, type, 0, std::nullopt);
return get($_ctxt, name, type, std::nullopt, std::nullopt);
}]>,
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type,
"unsigned":$bitWidth), [{
return get($_ctxt, name, type, bitWidth, std::nullopt);
}]>,
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type,
"mlir::gccjit::SourceLocAttr":$loc), [{
return get($_ctxt, name, type, 0, loc);
return get($_ctxt, name, type, std::nullopt, loc);
}]>,
];

// attribute can eat up the `:` separator, so we need to move the name to the front
let assemblyFormat = [{
`<` $type $name (`:` $bitWidth^)? ($loc^)? `>`
`<` $name $type (`:` $bitWidth^)? ($loc^)? `>`
}];
}

Expand Down
16 changes: 8 additions & 8 deletions src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <llvm-20/llvm/Support/Casting.h>
#include <llvm-20/llvm/Support/ErrorHandling.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Value.h>

#include "libgccjit.h"
#include "mlir-gccjit/Conversion/Conversions.h"
Expand All @@ -25,12 +31,6 @@
#include "mlir-gccjit/IR/GCCJITOpsEnums.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir-gccjit/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"

using namespace mlir;
using namespace mlir::gccjit;
Expand Down
6 changes: 3 additions & 3 deletions src/Conversion/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ GCCJITTypeConverter::getMemrefDescriptorType(mlir::MemRefType type) const {
llvm::enumerate(ArrayRef<Type>{elementPtrType, elementPtrType, indexType,
dimOrStrideType, dimOrStrideType})) {
auto nameAttr = StringAttr::get(type.getContext(), names[idx]);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field));
}
auto fieldsAttr = ArrayAttr::get(type.getContext(), fields);
return StructType::get(type.getContext(), nameAttr, fieldsAttr);
Expand All @@ -199,7 +199,7 @@ gccjit::StructType GCCJITTypeConverter::getUnrankedMemrefDescriptorType(
llvm::enumerate(ArrayRef<Type>{indexType, opaquePtrType})) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(type.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field));
}
auto fieldsAttr = ArrayAttr::get(type.getContext(), fields);
return StructType::get(type.getContext(), nameAttr, fieldsAttr);
Expand All @@ -220,7 +220,7 @@ Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton(
for (auto [idx, type] : llvm::enumerate(types)) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(func.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type));
}
auto nameAttr = StringAttr::get(func.getContext(), name);
auto fieldsAttr = ArrayAttr::get(func.getContext(), fields);
Expand Down
31 changes: 23 additions & 8 deletions src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ class RegionVisitor {
gcc_jit_rvalue *visitExprWithoutCache(AddrOp op);
gcc_jit_rvalue *visitExprWithoutCache(FnAddrOp op);
gcc_jit_lvalue *visitExprWithoutCache(GetGlobalOp op);
gcc_jit_rvalue *visitExprWithoutCache(ExprOp op);
Expr visitExprWithoutCache(ExprOp op);
gcc_jit_lvalue *visitExprWithoutCache(DerefOp op);
Expr visitExprWithoutCache(AccessFieldOp op);

/// The following operations are entrypoints for real codegen.
void visitAssignOp(gcc_jit_block *blk, AssignOp op);
Expand Down Expand Up @@ -515,18 +516,17 @@ Expr RegionVisitor::translateIntoContext() {
Block &block = region.getBlocks().front();
auto terminator = cast<gccjit::ReturnOp>(block.getTerminator());
auto value = terminator->getOperand(0);
auto rvalue = visitExpr(value, true);
auto expr = visitExpr(value, true);

if (auto globalOp = dyn_cast<gccjit::GlobalOp>(parent)) {
auto symName = SymbolRefAttr::get(getMLIRContext(), globalOp.getSymName());
auto *lvalue = getTranslator().getGlobalLValue(symName);
gcc_jit_global_set_initializer_rvalue(lvalue, rvalue);
gcc_jit_global_set_initializer_rvalue(lvalue, expr);
return {};
}

if (auto exprOp = dyn_cast<ExprOp>(parent)) {
return rvalue;
}
if (auto exprOp = dyn_cast<ExprOp>(parent))
return expr;

llvm_unreachable("unknown region parent");
}
Expand Down Expand Up @@ -558,8 +558,8 @@ Expr RegionVisitor::visitExpr(Value value, bool toplevel) {
.Case([&](GetGlobalOp op) { return visitExprWithoutCache(op); })
.Case([&](ExprOp op) { return visitExprWithoutCache(op); })
.Case([&](DerefOp op) { return visitExprWithoutCache(op); })
.Case([&](AccessFieldOp op) { return visitExprWithoutCache(op); })
.Default([](Operation *op) -> Expr {
op->dump();
llvm::report_fatal_error("unknown expression type");
});

Expand All @@ -579,7 +579,22 @@ gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(DerefOp op) {
return gcc_jit_context_new_array_access(getContext(), loc, ptr, offset);
}

gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(ExprOp op) {
Expr RegionVisitor::visitExprWithoutCache(AccessFieldOp op) {
auto composite = visitExpr(op.getComposite());
auto *loc = getTranslator().getLocation(op.getLoc());
auto *compositeTy = getTranslator().convertType(op.getComposite().getType());
auto index = op.getField().getZExtValue();
// TODO: support union and query from cache instead
auto *structure = gcc_jit_type_is_struct(compositeTy);
if (!structure)
llvm_unreachable("expected struct type");
auto *field = gcc_jit_struct_get_field(structure, index);
if (isa<LValueType>(op.getType()))
return gcc_jit_lvalue_access_field(composite, loc, field);
return gcc_jit_rvalue_access_field(composite, loc, field);
}

Expr RegionVisitor::visitExprWithoutCache(ExprOp op) {
RegionVisitor visitor(getTranslator(), op.getRegion(), this);
return visitor.translateIntoContext();
}
Expand Down
15 changes: 14 additions & 1 deletion test/lowering/gemm.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-memref-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s
// RUN: %gccjit-opt %s \
// RUN: -lower-affine \
// RUN: -convert-scf-to-cf \
// RUN: -convert-arith-to-gccjit \
// RUN: -convert-memref-to-gccjit \
// RUN: -convert-func-to-gccjit \
// RUN: -reconcile-unrealized-casts -mlir-print-debuginfo -o %t.mlir
// RUN: %filecheck --input-file=%t.mlir %s
// RUN: %gccjit-translate %t.mlir -mlir-to-gccjit-gimple | %filecheck %s --check-prefix=CHECK-GIMPLE
module {
// CHECK-NOT: func.func
// CHECK-NOT: func.return
Expand All @@ -15,11 +23,15 @@ module {
%acc0 = arith.constant 0.0 : f32
%sum = affine.for %k = 0 to 100 iter_args(%acc = %acc0) -> f32 {
// Load values from A and B
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})]
%a_val = affine.load %A[%i, %k] : memref<100x100xf32>
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})]
%b_val = affine.load %B[%k, %j] : memref<100x100xf32>

// Multiply and accumulate
// CHECK-GIMPLE: %[[V:[0-9]+]] = %{{[0-9]+}} * %{{[0-9]+}}
%prod = arith.mulf %a_val, %b_val : f32
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9]+}} + %[[V]]
%new_acc = arith.addf %acc, %prod : f32

// Yield the new accumulated value
Expand All @@ -33,6 +45,7 @@ module {
%final_val = arith.addf %c_scaled, %result : f32

// Store the final result back to matrix C
// CHECK-GIMPLE: %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] = %{{[0-9]+}}
affine.store %final_val, %C[%i, %j] : memref<100x100xf32>
}
}
Expand Down
20 changes: 10 additions & 10 deletions test/syntax/record.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
module @test {
gccjit.func imported @gemm (
!gccjit.struct<"__memref_188510220862752" {
#gccjit.field<!gccjit.ptr<!gccjit.fp<float>> "base">,
#gccjit.field<!gccjit.ptr<!gccjit.fp<float>> "aligned">,
#gccjit.field<!gccjit.int<size_t> "offset">,
#gccjit.field<!gccjit.array<!gccjit.int<size_t>, 2> "sizes">,
#gccjit.field<!gccjit.array<!gccjit.int<size_t>, 2> "strides">
#gccjit.field<"base" !gccjit.ptr<!gccjit.fp<float>>>,
#gccjit.field<"aligned" !gccjit.ptr<!gccjit.fp<float>>>,
#gccjit.field<"offset" !gccjit.int<size_t> : 32>,
#gccjit.field<"sizes" !gccjit.array<!gccjit.int<size_t>, 2>>,
#gccjit.field<"strides" !gccjit.array<!gccjit.int<size_t>, 2>>
}>
)
// CHECK: @gemm
// CHECK-SAME: !gccjit.struct<"__memref_188510220862752" {
// CHECK-SAME: #gccjit.field<!gccjit.ptr<!gccjit.fp<float>> "base">
// CHECK-SAME: #gccjit.field<!gccjit.ptr<!gccjit.fp<float>> "aligned">
// CHECK-SAME: #gccjit.field<!gccjit.int<size_t> "offset">
// CHECK-SAME: #gccjit.field<!gccjit.array<!gccjit.int<size_t>, 2> "sizes">
// CHECK-SAME: #gccjit.field<!gccjit.array<!gccjit.int<size_t>, 2> "strides">
// CHECK-SAME: #gccjit.field<"base" !gccjit.ptr<!gccjit.fp<float>>>
// CHECK-SAME: #gccjit.field<"aligned" !gccjit.ptr<!gccjit.fp<float>>>
// CHECK-SAME: #gccjit.field<"offset" !gccjit.int<size_t> : 32>
// CHECK-SAME: #gccjit.field<"sizes" !gccjit.array<!gccjit.int<size_t>, 2>>
// CHECK-SAME: #gccjit.field<"strides" !gccjit.array<!gccjit.int<size_t>, 2>>
// CHECK-SAME: }
}

0 comments on commit 1457b51

Please sign in to comment.