Skip to content

Commit

Permalink
[gccjit] lower assume alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 8, 2024
1 parent 9290b36 commit e4b066d
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 3 deletions.
90 changes: 87 additions & 3 deletions src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dominance.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/Types.h>
#include <mlir/IR/Value.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/DialectConversion.h>

#include "libgccjit.h"
#include "mlir-gccjit/Conversion/Conversions.h"
Expand Down Expand Up @@ -584,8 +589,88 @@ struct AllocOpLowering : public AllocationLowering<memref::AllocOp> {
using AllocationLowering<memref::AllocOp>::AllocationLowering;
};

/// Eliminates one `memref.assume_alignment` op.
///
/// The memref operand is materialized as its gccjit descriptor struct, and a
/// `gccjit` expression region is emitted in front of `op` that reads the five
/// descriptor fields (indices 0-4), routes field 1 through a call to the
/// `__builtin_assume_aligned` builtin (bit-casting to `void *` and back), and
/// reassembles a descriptor from the results. The rebuilt descriptor is then
/// materialized back to the original memref type, every operand listed in
/// `replacement` is redirected to it, and `op` is erased.
///
/// \param op            the assume-alignment op to remove.
/// \param typeConverter provides the descriptor struct type and the
///                      target/source materializations.
/// \param rewriter      rewriter through which all IR changes are made; its
///                      insertion point is left just before the erased op's
///                      former position.
/// \param replacement   operands that are re-pointed at the rebuilt memref.
void removeAssumeAlignmentOp(memref::AssumeAlignmentOp op,
                             GCCJITTypeConverter *typeConverter,
                             IRRewriter &rewriter,
                             MutableArrayRef<OpOperand *> replacement) {
  auto loc = op.getLoc();
  auto *ctx = rewriter.getContext();

  // Convert the incoming memref into its descriptor struct right before `op`.
  rewriter.setInsertionPoint(op);
  auto sourceType = cast<MemRefType>(op.getMemref().getType());
  auto structType = typeConverter->getMemrefDescriptorType(sourceType);
  auto descriptor = typeConverter->materializeTargetConversion(
      rewriter, loc, structType, op.getMemref());

  // Field 0's type is reused for field 1, and field 3's type for field 4.
  auto pointerTy = cast<FieldAttr>(structType.getFields()[0]).getType();
  auto offsetTy = cast<FieldAttr>(structType.getFields()[2]).getType();
  auto arrayTy = cast<FieldAttr>(structType.getFields()[3]).getType();
  auto voidPtrTy = PointerType::get(ctx, VoidType::get(ctx));

  // Everything below, up to the return, lives inside the expression region.
  auto exprOp = rewriter.create<ExprOp>(loc, structType);
  auto *body = rewriter.createBlock(&exprOp.getBody());
  rewriter.setInsertionPointToStart(body);

  // Reads descriptor field `index`, typed as `fieldType`.
  auto readField = [&](int64_t index, auto fieldType) -> Value {
    return rewriter.create<gccjit::AccessFieldOp>(
        loc, fieldType, descriptor, rewriter.getIndexAttr(index));
  };
  Value basePtr = readField(0, pointerTy);
  Value dataPtr = readField(1, pointerTy);
  Value offset = readField(2, offsetTy);
  Value sizes = readField(3, arrayTy);
  Value strides = readField(4, arrayTy);

  // The builtin works on `void *`, so cast in, call, and cast back.
  Value rawPtr = rewriter.create<gccjit::BitCastOp>(loc, voidPtrTy, dataPtr);
  // NOTE(review): the second builtin argument is the descriptor's offset
  // field; the op's alignment attribute is never read in this function. GCC
  // documents that argument as the alignment — confirm this is intended.
  Value assumedPtr =
      rewriter
          .create<gccjit::CallOp>(
              loc, voidPtrTy,
              SymbolRefAttr::get(ctx, "__builtin_assume_aligned"),
              ValueRange{rawPtr, offset},
              /* tailcall */ nullptr, /* builtin */ rewriter.getUnitAttr())
          .getResult();
  Value typedPtr =
      rewriter.create<gccjit::BitCastOp>(loc, pointerTy, assumedPtr);

  // Reassemble the descriptor with only field 1 replaced.
  auto rebuilt = rewriter.create<gccjit::NewStructOp>(
      loc, structType, ArrayRef<int32_t>{0, 1, 2, 3, 4},
      ValueRange{basePtr, typedPtr, offset, sizes, strides});
  rewriter.create<gccjit::ReturnOp>(loc, rebuilt);

  // Back outside the region: convert to the memref type and rewire the uses.
  rewriter.setInsertionPoint(op);
  auto memrefValue = typeConverter->materializeSourceConversion(
      rewriter, loc, sourceType, exprOp.getResult());
  for (OpOperand *operand : replacement)
    operand->set(memrefValue);
  rewriter.eraseOp(op);
}

/// Strips every `memref.assume_alignment` op from each `func.func` in
/// `moduleOp` and collects the processed functions.
///
/// Per function this runs in two passes: first, for each assume-alignment op,
/// gather the uses of its memref operand whose owner is not itself an
/// assume-alignment op and is properly dominated by the op; second, hand each
/// op and its gathered uses to `removeAssumeAlignmentOp`. The function is then
/// appended to `ops` and the dominance information for its body is
/// invalidated.
///
/// \param moduleOp      module whose functions are processed.
/// \param typeConverter forwarded to `removeAssumeAlignmentOp`.
/// \param domInfo       dominance analysis queried while gathering uses.
/// \param ops           output list; receives every visited function.
void removeAllAssumeAlignmentOps(ModuleOp moduleOp,
                                 GCCJITTypeConverter *typeConverter,
                                 DominanceInfo &domInfo,
                                 llvm::SmallVectorImpl<Operation *> &ops) {
  for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
    // Uses to redirect, keyed by the assume-alignment op that governs them.
    llvm::DenseMap<memref::AssumeAlignmentOp, llvm::SmallVector<OpOperand *>>
        usesByOp;

    // Pass 1: gather uses before any IR is mutated.
    func.walk([&](memref::AssumeAlignmentOp assumeOp) {
      for (auto &operand : assumeOp.getMemref().getUses()) {
        auto *owner = operand.getOwner();
        // Other assume-alignment ops keep their operand; everything else is
        // rewired only when it sits strictly after this op in dominance order.
        if (!isa<memref::AssumeAlignmentOp>(owner) &&
            domInfo.properlyDominates(assumeOp, owner))
          usesByOp[assumeOp].push_back(&operand);
      }
    });

    // Pass 2: rewrite each op, bracketed as a modification of the function.
    IRRewriter rewriter(funcOp.getContext());
    rewriter.startOpModification(funcOp);
    funcOp->walk([&](memref::AssumeAlignmentOp assumeOp) {
      removeAssumeAlignmentOp(assumeOp, typeConverter, rewriter,
                              usesByOp[assumeOp]);
    });
    rewriter.finalizeOpModification(funcOp);

    ops.push_back(funcOp);
    // The body changed, so cached dominance data for it is dropped.
    domInfo.invalidate(&funcOp.getFunctionBody());
  }
}

void ConvertMemrefToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
auto &domInfo = getAnalysis<DominanceInfo>();
auto typeConverter = GCCJITTypeConverter();
auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
ValueRange inputs,
Expand All @@ -603,11 +688,10 @@ void ConvertMemrefToGCCJITPass::runOnOperation() {
AllocOpLowering, DeallocOpLowering>(typeConverter,
&getContext());
mlir::ConversionTarget target(getContext());
target.addLegalDialect<gccjit::GCCJITDialect>();
target.addLegalDialect<gccjit::GCCJITDialect, BuiltinDialect>();
target.addIllegalDialect<memref::MemRefDialect>();
llvm::SmallVector<Operation *> ops;
for (auto func : moduleOp.getOps<func::FuncOp>())
ops.push_back(func);
removeAllAssumeAlignmentOps(moduleOp, &typeConverter, domInfo, ops);
if (failed(applyPartialConversion(ops, target, std::move(patterns))))
signalPassFailure();
}
Expand Down
16 changes: 16 additions & 0 deletions test/lowering/assume_aligned.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Lowering test for memref.assume_alignment: the input is converted through
// the memref and func gccjit conversions, translated to GIMPLE, and the
// GIMPLE text is matched against the expectations inside @foo below.
// NOTE(review): -convert-memref-to-gccjit appears twice on the first command
// line below — confirm whether running the pass twice is intentional.
// RUN: %gccjit-opt %s -convert-memref-to-gccjit -convert-memref-to-gccjit -convert-func-to-gccjit -reconcile-unrealized-casts -o %t.mlir -mlir-print-debuginfo
// RUN: %gccjit-translate %t.mlir -o %t.gimple -mlir-to-gccjit-gimple
// RUN: %filecheck --input-file=%t.gimple %s --check-prefix=CHECK-GIMPLE
module @test attributes {
gccjit.opt_level = #gccjit.opt_level<O3>,
gccjit.debug_info = false
}
{
func.func @foo(%arg0: memref<100x100xf32>) -> memref<100x100xf32> {
// Expected output: the descriptor is copied, rebuilt with its .aligned field
// passed through __builtin_assume_aligned, and the rebuilt value is returned.
// NOTE(review): the expected second builtin argument is %0.offset, while the
// op below specifies an alignment of 128 — confirm this is the intended
// lowering.
// CHECK-GIMPLE: %0 = %arg0;
// CHECK-GIMPLE: %1 = (struct memref<100x100xf32>) {.base=%0.base, .aligned=bitcast(__builtin_assume_aligned ((bitcast(%0.aligned, void *)), %0.offset), float *), .offset=%0.offset, .sizes=%0.sizes, .strides=%0.strides};
// CHECK-GIMPLE: return %1;
memref.assume_alignment %arg0, 128 : memref<100x100xf32>
return %arg0 : memref<100x100xf32>
}
}

0 comments on commit e4b066d

Please sign in to comment.