Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gccjit] lower memref #22

Merged
merged 4 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/mlir-gccjit/IR/GCCJITAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,21 @@ def FieldAttr : GCCJIT_Attr<"Field", "field"> {

let builders = [
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type), [{
return get($_ctxt, name, type, 0, std::nullopt);
return get($_ctxt, name, type, std::nullopt, std::nullopt);
}]>,
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type,
"unsigned":$bitWidth), [{
return get($_ctxt, name, type, bitWidth, std::nullopt);
}]>,
AttrBuilder<(ins "mlir::StringAttr":$name, "mlir::Type":$type,
"mlir::gccjit::SourceLocAttr":$loc), [{
return get($_ctxt, name, type, 0, loc);
return get($_ctxt, name, type, std::nullopt, loc);
}]>,
];

// attribute can eat up the `:` separator, so we need to move the name to the front
let assemblyFormat = [{
`<` $type $name (`:` $bitWidth^)? ($loc^)? `>`
`<` $name $type (`:` $bitWidth^)? ($loc^)? `>`
}];
}

Expand Down
2 changes: 1 addition & 1 deletion include/mlir-gccjit/IR/GCCJITOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def AccessFieldOp : GCCJIT_Op<"access_field"> {
```
}];
let arguments = (ins AnyType:$composite, IndexAttr:$field);
let results = (outs GCCJIT_LValueType:$result);
let results = (outs AnyType:$result);
let assemblyFormat = [{
$composite `[` $field `]` `:` functional-type(operands, results) attr-dict
}];
Expand Down
199 changes: 198 additions & 1 deletion src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <llvm/Support/Casting.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Value.h>

#include "libgccjit.h"
#include "mlir-gccjit/Conversion/Conversions.h"
#include "mlir-gccjit/Conversion/TypeConverter.h"
#include "mlir-gccjit/IR/GCCJITAttrs.h"
#include "mlir-gccjit/IR/GCCJITOps.h"
#include "mlir-gccjit/IR/GCCJITOpsEnums.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir-gccjit/Passes.h"

using namespace mlir;
Expand All @@ -29,13 +42,197 @@ struct ConvertMemrefToGCCJITPass
void runOnOperation() override final;
};

template <typename T>
class GCCJITLoweringPattern : public mlir::OpConversionPattern<T> {
protected:
const GCCJITTypeConverter *getTypeConverter() const {
return static_cast<const GCCJITTypeConverter *>(this->typeConverter);
}

public:
using OpConversionPattern<T>::OpConversionPattern;
};

Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType,
int64_t value) {

auto indexTy = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T);
auto intAttr = IntAttr::get(builder.getContext(), indexTy,
{64, static_cast<uint64_t>(value)});
return builder.create<gccjit::ConstantOp>(loc, resultType, intAttr);
}

Value getMemRefDescriptorOffset(OpBuilder &builder, Value descriptor,
Location loc) {
auto indexTy = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T);
return builder.create<gccjit::AccessFieldOp>(loc, indexTy, descriptor,
builder.getIndexAttr(2));
}

Value getMemRefDiscriptorAlignedPtr(OpBuilder &builder, Value descriptor,
const GCCJITTypeConverter &converter,
Location loc, MemRefType type) {
auto elementType = converter.convertType(type.getElementType());
auto ptrTy = PointerType::get(builder.getContext(), elementType);
return builder.create<gccjit::AccessFieldOp>(loc, ptrTy, descriptor,
builder.getIndexAttr(1));
}

Value getMemRefDescriptorBufferPtr(OpBuilder &builder, Location loc,
Value descriptor,
const GCCJITTypeConverter &converter,
MemRefType type) {
auto [strides, offsetCst] = getStridesAndOffset(type);
auto alignedPtr =
getMemRefDiscriptorAlignedPtr(builder, descriptor, converter, loc, type);

// For zero offsets, we already have the base pointer.
if (offsetCst == 0)
return alignedPtr;

// Otherwise add the offset to the aligned base.
Type indexType = IntType::get(builder.getContext(), GCC_JIT_TYPE_SIZE_T);
Value offsetVal =
ShapedType::isDynamic(offsetCst)
? getMemRefDescriptorOffset(builder, descriptor, loc)
: createIndexAttrConstant(builder, loc, indexType, offsetCst);
Type elementType = converter.convertType(type.getElementType());
auto lvalueTy = LValueType::get(builder.getContext(), elementType);
auto lvalue =
builder.create<gccjit::DerefOp>(loc, lvalueTy, alignedPtr, offsetVal);
return builder.create<gccjit::AddrOp>(
loc, PointerType::get(builder.getContext(), elementType), lvalue);
}

Value getStridedElementLValue(Location loc, MemRefType type, Value descriptor,
ExprOp parent, ValueRange indices,
const GCCJITTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
Value materializedMemref = nullptr;
Value ptrToStrideField = nullptr;
auto [strides, offset] = getStridesAndOffset(type);
auto indexTy = IntType::get(rewriter.getContext(), GCC_JIT_TYPE_SIZE_T);
auto elementType = typeConverter.convertType(type.getElementType());
auto doMaterialization = [&]() {
if (materializedMemref)
return;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(parent);
auto lvalueTy =
LValueType::get(rewriter.getContext(), descriptor.getType());
materializedMemref = rewriter.create<gccjit::LocalOp>(
loc, lvalueTy, nullptr, nullptr, nullptr);
rewriter.create<gccjit::AssignOp>(loc, descriptor, materializedMemref);
};
auto generateStride = [&](size_t i) -> Value {
doMaterialization();
if (!ptrToStrideField) {
auto descriptorTy = cast<StructType>(descriptor.getType());
auto fieldTy = cast<ArrayType>(
cast<FieldAttr>(descriptorTy.getRecordFields()[4]).getType());
auto fieldLValueTy = LValueType::get(rewriter.getContext(), fieldTy);
auto strideField = rewriter.create<gccjit::AccessFieldOp>(
loc, fieldLValueTy, materializedMemref, rewriter.getIndexAttr(4));
auto ptrToStrideArray = rewriter.create<gccjit::AddrOp>(
loc, PointerType::get(rewriter.getContext(), fieldTy), strideField);
ptrToStrideField = rewriter.create<gccjit::BitCastOp>(
loc, PointerType::get(rewriter.getContext(), indexTy),
ptrToStrideArray);
}
auto offset = rewriter.create<gccjit::AccessFieldOp>(
loc, indexTy, ptrToStrideField, rewriter.getIndexAttr(i));
auto strideLValue = rewriter.create<gccjit::DerefOp>(
loc, LValueType::get(rewriter.getContext(), indexTy), ptrToStrideField,
offset);
return rewriter.create<gccjit::AsRValueOp>(loc, indexTy, strideLValue);
};

Value base = getMemRefDescriptorBufferPtr(rewriter, loc, descriptor,
typeConverter, type);
Value index;
for (int i = 0, e = indices.size(); i < e; ++i) {
Value increment = indices[i];
if (strides[i] != 1) { // Skip if stride is 1.
Value stride =
ShapedType::isDynamic(strides[i])
? generateStride(i)
: createIndexAttrConstant(rewriter, loc, indexTy, strides[i]);
increment = rewriter.create<gccjit::BinaryOp>(loc, indexTy, BOp::Mult,
increment, stride);
}
index = index ? rewriter.create<gccjit::BinaryOp>(loc, indexTy, BOp::Plus,
index, increment)
: increment;
}

return rewriter.create<gccjit::DerefOp>(
loc, LValueType::get(rewriter.getContext(), elementType), base, index);
}

class LoadOpLowering : public GCCJITLoweringPattern<memref::LoadOp> {
public:
using GCCJITLoweringPattern::GCCJITLoweringPattern;
mlir::LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = op.getMemRefType();
auto retTy = typeConverter->convertType(op.getResult().getType());
auto exprBundle = rewriter.replaceOpWithNewOp<ExprOp>(op, retTy);
auto *block = rewriter.createBlock(&exprBundle.getBody());
rewriter.setInsertionPointToStart(block);
Value dataLValue = getStridedElementLValue(
op.getLoc(), type, adaptor.getMemref(), exprBundle,
adaptor.getIndices(), *getTypeConverter(), rewriter);
auto rvalue = rewriter.create<AsRValueOp>(op.getLoc(), retTy, dataLValue);
rewriter.create<ReturnOp>(op.getLoc(), rvalue);
return success();
}
};

class StoreOpLowering : public GCCJITLoweringPattern<memref::StoreOp> {
public:
using GCCJITLoweringPattern::GCCJITLoweringPattern;
mlir::LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = op.getMemRefType();
auto elemTy = typeConverter->convertType(type.getElementType());
auto elemLValueTy = LValueType::get(rewriter.getContext(), elemTy);
auto expr = rewriter.create<ExprOp>(op->getLoc(), elemLValueTy, true);
auto *block = rewriter.createBlock(&expr.getBody());
{
rewriter.setInsertionPointToStart(block);
Value dataLValue = getStridedElementLValue(
op.getLoc(), type, adaptor.getMemref(), expr, adaptor.getIndices(),
*getTypeConverter(), rewriter);
rewriter.create<ReturnOp>(op.getLoc(), dataLValue);
}
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<AssignOp>(op, adaptor.getValue(), expr);
return success();
}
};

void ConvertMemrefToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
auto typeConverter = GCCJITTypeConverter();
auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
mlir::RewritePatternSet patterns(&getContext());
patterns.insert<LoadOpLowering, StoreOpLowering>(typeConverter,
&getContext());
mlir::ConversionTarget target(getContext());
target.addLegalDialect<gccjit::GCCJITDialect>();
target.addIllegalDialect<mlir::memref::MemRefDialect>();
target.addIllegalDialect<memref::MemRefDialect>();
llvm::SmallVector<Operation *> ops;
for (auto func : moduleOp.getOps<func::FuncOp>())
ops.push_back(func);
Expand Down
6 changes: 3 additions & 3 deletions src/Conversion/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ GCCJITTypeConverter::getMemrefDescriptorType(mlir::MemRefType type) const {
llvm::enumerate(ArrayRef<Type>{elementPtrType, elementPtrType, indexType,
dimOrStrideType, dimOrStrideType})) {
auto nameAttr = StringAttr::get(type.getContext(), names[idx]);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field));
}
auto fieldsAttr = ArrayAttr::get(type.getContext(), fields);
return StructType::get(type.getContext(), nameAttr, fieldsAttr);
Expand All @@ -199,7 +199,7 @@ gccjit::StructType GCCJITTypeConverter::getUnrankedMemrefDescriptorType(
llvm::enumerate(ArrayRef<Type>{indexType, opaquePtrType})) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(type.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, field));
}
auto fieldsAttr = ArrayAttr::get(type.getContext(), fields);
return StructType::get(type.getContext(), nameAttr, fieldsAttr);
Expand All @@ -220,7 +220,7 @@ Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton(
for (auto [idx, type] : llvm::enumerate(types)) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(func.getContext(), name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type, 0));
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type));
}
auto nameAttr = StringAttr::get(func.getContext(), name);
auto fieldsAttr = ArrayAttr::get(func.getContext(), fields);
Expand Down
31 changes: 23 additions & 8 deletions src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ class RegionVisitor {
gcc_jit_rvalue *visitExprWithoutCache(AddrOp op);
gcc_jit_rvalue *visitExprWithoutCache(FnAddrOp op);
gcc_jit_lvalue *visitExprWithoutCache(GetGlobalOp op);
gcc_jit_rvalue *visitExprWithoutCache(ExprOp op);
Expr visitExprWithoutCache(ExprOp op);
gcc_jit_lvalue *visitExprWithoutCache(DerefOp op);
Expr visitExprWithoutCache(AccessFieldOp op);

/// The following operations are entrypoints for real codegen.
void visitAssignOp(gcc_jit_block *blk, AssignOp op);
Expand Down Expand Up @@ -515,18 +516,17 @@ Expr RegionVisitor::translateIntoContext() {
Block &block = region.getBlocks().front();
auto terminator = cast<gccjit::ReturnOp>(block.getTerminator());
auto value = terminator->getOperand(0);
auto rvalue = visitExpr(value, true);
auto expr = visitExpr(value, true);

if (auto globalOp = dyn_cast<gccjit::GlobalOp>(parent)) {
auto symName = SymbolRefAttr::get(getMLIRContext(), globalOp.getSymName());
auto *lvalue = getTranslator().getGlobalLValue(symName);
gcc_jit_global_set_initializer_rvalue(lvalue, rvalue);
gcc_jit_global_set_initializer_rvalue(lvalue, expr);
return {};
}

if (auto exprOp = dyn_cast<ExprOp>(parent)) {
return rvalue;
}
if (auto exprOp = dyn_cast<ExprOp>(parent))
return expr;

llvm_unreachable("unknown region parent");
}
Expand Down Expand Up @@ -558,8 +558,8 @@ Expr RegionVisitor::visitExpr(Value value, bool toplevel) {
.Case([&](GetGlobalOp op) { return visitExprWithoutCache(op); })
.Case([&](ExprOp op) { return visitExprWithoutCache(op); })
.Case([&](DerefOp op) { return visitExprWithoutCache(op); })
.Case([&](AccessFieldOp op) { return visitExprWithoutCache(op); })
.Default([](Operation *op) -> Expr {
op->dump();
llvm::report_fatal_error("unknown expression type");
});

Expand All @@ -579,7 +579,22 @@ gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(DerefOp op) {
return gcc_jit_context_new_array_access(getContext(), loc, ptr, offset);
}

gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(ExprOp op) {
Expr RegionVisitor::visitExprWithoutCache(AccessFieldOp op) {
auto composite = visitExpr(op.getComposite());
auto *loc = getTranslator().getLocation(op.getLoc());
auto *compositeTy = getTranslator().convertType(op.getComposite().getType());
auto index = op.getField().getZExtValue();
// TODO: support union and query from cache instead
auto *structure = gcc_jit_type_is_struct(compositeTy);
if (!structure)
llvm_unreachable("expected struct type");
auto *field = gcc_jit_struct_get_field(structure, index);
if (isa<LValueType>(op.getType()))
return gcc_jit_lvalue_access_field(composite, loc, field);
return gcc_jit_rvalue_access_field(composite, loc, field);
}

Expr RegionVisitor::visitExprWithoutCache(ExprOp op) {
RegionVisitor visitor(getTranslator(), op.getRegion(), this);
return visitor.translateIntoContext();
}
Expand Down
20 changes: 18 additions & 2 deletions test/lowering/gemm.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
// RUN: %gccjit-opt %s -lower-affine -convert-scf-to-cf -convert-arith-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts | %filecheck %s
module {
// RUN: %gccjit-opt %s \
// RUN: -lower-affine \
// RUN: -convert-scf-to-cf \
// RUN: -convert-arith-to-gccjit \
// RUN: -convert-memref-to-gccjit \
// RUN: -convert-func-to-gccjit \
// RUN: -reconcile-unrealized-casts -mlir-print-debuginfo -o %t.mlir
// RUN: %filecheck --input-file=%t.mlir %s
// RUN: %gccjit-translate %t.mlir -mlir-to-gccjit-gimple | %filecheck %s --check-prefix=CHECK-GIMPLE
module @test attributes {
gccjit.opt_level = #gccjit.opt_level<O3>
}
{
// CHECK-NOT: func.func
// CHECK-NOT: func.return
// CHECK-NOT: cf.cond_br
Expand All @@ -15,11 +26,15 @@ module {
%acc0 = arith.constant 0.0 : f32
%sum = affine.for %k = 0 to 100 iter_args(%acc = %acc0) -> f32 {
// Load values from A and B
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})]
%a_val = affine.load %A[%i, %k] : memref<100x100xf32>
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})]
%b_val = affine.load %B[%k, %j] : memref<100x100xf32>

// Multiply and accumulate
// CHECK-GIMPLE: %[[V:[0-9]+]] = %{{[0-9]+}} * %{{[0-9]+}}
%prod = arith.mulf %a_val, %b_val : f32
// CHECK-GIMPLE: %{{[0-9]+}} = %{{[0-9]+}} + %[[V]]
%new_acc = arith.addf %acc, %prod : f32

// Yield the new accumulated value
Expand All @@ -33,6 +48,7 @@ module {
%final_val = arith.addf %c_scaled, %result : f32

// Store the final result back to matrix C
// CHECK-GIMPLE: %{{[0-9\.a-z]+}}[(%{{[0-9]+}} * (size_t)100 + %{{[0-9]+}})] = %{{[0-9]+}}
affine.store %final_val, %C[%i, %j] : memref<100x100xf32>
}
}
Expand Down
Loading