Skip to content

Commit

Permalink
[gccjit] implement alloca operation
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 7, 2024
1 parent 96125ce commit c0392c3
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 38 deletions.
121 changes: 87 additions & 34 deletions src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <llvm-20/llvm/Support/LogicalResult.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/ErrorHandling.h>
#include <llvm/Support/LogicalResult.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Types.h>
#include <mlir/IR/Value.h>

#include "libgccjit.h"
Expand All @@ -32,7 +33,6 @@
#include "mlir-gccjit/IR/GCCJITOpsEnums.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir-gccjit/Passes.h"
#include "mlir/IR/Types.h"

using namespace mlir;
using namespace mlir::gccjit;
Expand Down Expand Up @@ -133,12 +133,13 @@ class AllocationLowering : public GCCJITLoweringPattern<OpType> {

virtual std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
Operation *op) const = 0;
OpType op) const = 0;

private:
static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

public:
using GCCJITLoweringPattern<OpType>::GCCJITLoweringPattern;
LogicalResult
matchAndRewrite(OpType op,
typename OpConversionPattern<OpType>::OpAdaptor adaptor,
Expand Down Expand Up @@ -191,33 +192,6 @@ class StoreOpLowering : public GCCJITLoweringPattern<memref::StoreOp> {
}
};

void ConvertMemrefToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
auto typeConverter = GCCJITTypeConverter();
auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
mlir::RewritePatternSet patterns(&getContext());
patterns.insert<LoadOpLowering, StoreOpLowering>(typeConverter,
&getContext());
mlir::ConversionTarget target(getContext());
target.addLegalDialect<gccjit::GCCJITDialect>();
target.addIllegalDialect<memref::MemRefDialect>();
llvm::SmallVector<Operation *> ops;
for (auto func : moduleOp.getOps<func::FuncOp>())
ops.push_back(func);
if (failed(applyPartialConversion(ops, target, std::move(patterns))))
signalPassFailure();
}

template <typename T> IntType GCCJITLoweringPattern<T>::getIndexType() const {
return IntType::get(this->getContext(), GCC_JIT_TYPE_SIZE_T);
}
Expand Down Expand Up @@ -472,6 +446,7 @@ Value AllocationLowering<OpType>::allocateBufferAutoAlign(
return rewriter.create<gccjit::BitCastOp>(loc, elementPtrType, result);
}

[[gnu::used]]
bool isConvertibleAndHasIdentityMaps(MemRefType type,
const GCCJITTypeConverter &typeConverter) {
if (!typeConverter.convertType(type.getElementType()))
Expand All @@ -485,7 +460,7 @@ void GCCJITLoweringPattern<OpType>::getMemRefDescriptorSizes(
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes,
SmallVectorImpl<Value> &strides, Value &size, bool sizeInBytes) const {
assert(
isConvertibleAndHasIdentityMaps(memRefType, this->getTypeConverter()) &&
isConvertibleAndHasIdentityMaps(memRefType, *this->getTypeConverter()) &&
"layout maps must have been normalized away");
assert(count(memRefType.getShape(), ShapedType::kDynamic) ==
static_cast<ssize_t>(dynamicSizes.size()) &&
Expand Down Expand Up @@ -528,6 +503,8 @@ void GCCJITLoweringPattern<OpType>::getMemRefDescriptorSizes(
Type elementType =
this->getTypeConverter()->convertType(memRefType.getElementType());
size = rewriter.create<gccjit::SizeOfOp>(loc, indexType, elementType);
size = rewriter.create<gccjit::BinaryOp>(loc, indexType, BOp::Mult, size,
runningStride);
} else {
size = runningStride;
}
Expand Down Expand Up @@ -565,18 +542,94 @@ LogicalResult AllocationLowering<OpType>::matchAndRewrite(
return rewriter.notifyMatchFailure(loc,
"underlying buffer allocation failed");

auto arrayTy = ArrayType::get(rewriter.getContext(), this->getIndexType(),
memRefType.getRank());
auto sizeArr = rewriter.create<gccjit::NewArrayOp>(loc, arrayTy, sizes);
auto strideArr = rewriter.create<gccjit::NewArrayOp>(loc, arrayTy, strides);
auto zero =
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 0);
// Create the MemRef descriptor.
auto memRefDescriptor = rewriter.create<gccjit::NewStructOp>(
loc, convertedType, ArrayRef<int32_t>{0, 1, 2, 3, 4},
ValueRange{alignedPtr, allocatedPtr, size});
ValueRange{alignedPtr, allocatedPtr, zero, sizeArr, strideArr});

// Return the final value of the descriptor.
rewriter.create<ReturnOp>(loc, memRefDescriptor);
}
// Return the final value of the descriptor.
rewriter.replaceOp(op, exprBundle);
return success();
}

struct AllocaOpLowering : public AllocationLowering<memref::AllocaOp> {
using AllocationLowering<memref::AllocaOp>::AllocationLowering;
std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
memref::AllocaOp op) const override final {
auto allocaOp = cast<memref::AllocaOp>(op);
auto elementType =
typeConverter->convertType(allocaOp.getType().getElementType());

if (allocaOp.getType().getMemorySpace())
return std::make_tuple(Value(), Value());

auto elementPtrType = PointerType::get(rewriter.getContext(), elementType);

Value alloca;

if (auto align = op.getAlignment()) {
auto alignment =
createIndexAttrConstant(rewriter, loc, getIndexType(), *align);
alloca = rewriter
.create<CallOp>(loc, getVoidPtrType(),
SymbolRefAttr::get(rewriter.getContext(),
"alloca_with_align"),
ValueRange{size, alignment},
/* tailcall */ nullptr,
/* builtin */ rewriter.getUnitAttr())
.getResult();
} else {
alloca = rewriter
.create<CallOp>(
loc, getVoidPtrType(),
SymbolRefAttr::get(rewriter.getContext(), "alloca"),
ValueRange{size},
/* tailcall */ nullptr,
/* builtin */ rewriter.getUnitAttr())
.getResult();
}

alloca = rewriter.create<BitCastOp>(loc, elementPtrType, alloca);

return std::make_tuple(alloca, alloca);
}
};

void ConvertMemrefToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
auto typeConverter = GCCJITTypeConverter();
auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
mlir::RewritePatternSet patterns(&getContext());
patterns.insert<LoadOpLowering, StoreOpLowering, AllocaOpLowering>(
typeConverter, &getContext());
mlir::ConversionTarget target(getContext());
target.addLegalDialect<gccjit::GCCJITDialect>();
target.addIllegalDialect<memref::MemRefDialect>();
llvm::SmallVector<Operation *> ops;
for (auto func : moduleOp.getOps<func::FuncOp>())
ops.push_back(func);
if (failed(applyPartialConversion(ops, target, std::move(patterns))))
signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> mlir::gccjit::createConvertMemrefToGCCJITPass() {
Expand Down
35 changes: 31 additions & 4 deletions src/GCCJITOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir-gccjit/IR/GCCJITOps.h"

#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
Expand Down Expand Up @@ -46,6 +44,7 @@
#include <mlir/Support/LogicalResult.h>

#include "mlir-gccjit/IR/GCCJITDialect.h"
#include "mlir-gccjit/IR/GCCJITOps.h"
#include "mlir-gccjit/IR/GCCJITOpsEnums.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"

Expand Down Expand Up @@ -264,13 +263,41 @@ ParseResult parseArrayOrVectorElements(
OpAsmParser &parser, Type expectedType,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &elementValues,
llvm::SmallVectorImpl<Type> &elementTypes) {
llvm_unreachable("Not implemented");
bool mayContinue = true;
auto parseOptionalValueTypePair = [&]() -> ParseResult {
OpAsmParser::UnresolvedOperand elementValue;
Type elementType;
if (!parser.parseOptionalOperand(elementValue).has_value()) {
mayContinue = false;
return success();
}
if (parser.parseColonType(elementType))
return failure();
elementValues.push_back(elementValue);
elementTypes.push_back(elementType);
if (parser.parseOptionalComma().succeeded()) {
mayContinue = true;
return success();
}
mayContinue = false;
return success();
};
while (mayContinue)
if (parseOptionalValueTypePair().failed())
return failure();
return success();
}

void printArrayOrVectorElements(OpAsmPrinter &p, Operation *op,
Type expectedType, OperandRange elementValues,
ValueTypeRange<OperandRange> elementTypes) {
llvm_unreachable("Not implemented");
llvm::interleaveComma(llvm::zip(elementValues, elementTypes), p,
[&](auto pair) {
auto [value, type] = pair;
p.printOperand(value);
p << " : ";
p.printType(type);
});
}

struct ParseNamedUnitAttr {
Expand Down
54 changes: 54 additions & 0 deletions test/lowering/alloca.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// RUN: %gccjit-opt %s -o %t.mlir -convert-memref-to-gccjit
// RUN: %filecheck --input-file=%t.mlir %s
module @test
{

func.func @foo() {
// CHECK: %[[V0:[0-9]+]] = gccjit.expr {
// CHECK: %[[V1:[0-9]+]] = gccjit.const #gccjit.int<100> : !gccjit.int<size_t>
// CHECK: %[[V2:[0-9]+]] = gccjit.const #gccjit.int<100> : !gccjit.int<size_t>
// CHECK: %[[V3:[0-9]+]] = gccjit.const #gccjit.int<1> : !gccjit.int<size_t>
// CHECK: %[[V4:[0-9]+]] = gccjit.const #gccjit.int<10000> : !gccjit.int<size_t>
// CHECK: %[[V5:[0-9]+]] = gccjit.sizeof !gccjit.fp<float> : <size_t>
// CHECK: %[[V6:[0-9]+]] = gccjit.binary mult(%[[V5]] : !gccjit.int<size_t>, %[[V4]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V7:[0-9]+]] = gccjit.call builtin @alloca(%[[V6]]) : (!gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
// CHECK: %[[V8:[0-9]+]] = gccjit.bitcast %[[V7]] : !gccjit.ptr<!gccjit.void> to !gccjit.ptr<!gccjit.fp<float>>
// CHECK: %[[V9:[0-9]+]] = gccjit.new_array <!gccjit.int<size_t>, 2>[%[[V1]] : !gccjit.int<size_t>, %[[V2]] : !gccjit.int<size_t>]
// CHECK: %[[V10:[0-9]+]] = gccjit.new_array <!gccjit.int<size_t>, 2>[%[[V2]] : !gccjit.int<size_t>, %[[V3]] : !gccjit.int<size_t>]
// CHECK: %[[V11:[0-9]+]] = gccjit.const #gccjit.int<0> : !gccjit.int<size_t>
// CHECK: %[[V12:[0-9]+]] = gccjit.new_struct [0, 1, 2, 3, 4][%[[V8]], %[[V8]], %[[V11]], %[[V9]], %[[V10]]] : (!gccjit.ptr<!gccjit.fp<float>>, !gccjit.ptr<!gccjit.fp<float>>, !gccjit.int<size_t>, !gccjit.array<!gccjit.int<size_t>, 2>, !gccjit.array<!gccjit.int<size_t>, 2>) -> !gccjit.struct<"memref<100x100xf32>" {#gccjit.field<"base" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"aligned" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"offset" !gccjit.int<size_t>>, #gccjit.field<"sizes" !gccjit.array<!gccjit.int<size_t>, 2>>, #gccjit.field<"strides" !gccjit.array<!gccjit.int<size_t>, 2>>}>
// CHECK: gccjit.return %[[V12]] : !gccjit.struct<"memref<100x100xf32>" {#gccjit.field<"base" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"aligned" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"offset" !gccjit.int<size_t>>, #gccjit.field<"sizes" !gccjit.array<!gccjit.int<size_t>, 2>>, #gccjit.field<"strides" !gccjit.array<!gccjit.int<size_t>, 2>>}>
// CHECK: } : !gccjit.struct<"memref<100x100xf32>" {#gccjit.field<"base" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"aligned" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"offset" !gccjit.int<size_t>>, #gccjit.field<"sizes" !gccjit.array<!gccjit.int<size_t>, 2>>, #gccjit.field<"strides" !gccjit.array<!gccjit.int<size_t>, 2>>}>
%a = memref.alloca () : memref<100x100xf32>
return
}

func.func @bar(%arg0 : index) {
// CHECK: %[[V0:[0-9]+]] = builtin.unrealized_conversion_cast %{{[0-9a-z]+}} : index to !gccjit.int<size_t>
// CHECK: %[[V1:[0-9]+]] = gccjit.expr {
// CHECK: %[[V2:[0-9]+]] = gccjit.const #gccjit.int<133> : !gccjit.int<size_t>
// CHECK: %[[V3:[0-9]+]] = gccjit.const #gccjit.int<723> : !gccjit.int<size_t>
// CHECK: %[[V4:[0-9]+]] = gccjit.const #gccjit.int<1> : !gccjit.int<size_t>
// CHECK: %[[V5:[0-9]+]] = gccjit.binary mult(%[[V0]] : !gccjit.int<size_t>, %[[V3]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V6:[0-9]+]] = gccjit.binary mult(%[[V5]] : !gccjit.int<size_t>, %[[V2]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V7:[0-9]+]] = gccjit.sizeof !gccjit.fp<float> : <size_t>
// CHECK: %[[V8:[0-9]+]] = gccjit.binary mult(%[[V7]] : !gccjit.int<size_t>, %[[V6]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V9:[0-9]+]] = gccjit.call builtin @alloca(%[[V8]]) : (!gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
// CHECK: %[[V10:[0-9]+]] = gccjit.bitcast %[[V9]] : !gccjit.ptr<!gccjit.void> to !gccjit.ptr<!gccjit.fp<float>>
// CHECK: %[[V11:[0-9]+]] = gccjit.new_array <!gccjit.int<size_t>, 3>[%[[V2]] : !gccjit.int<size_t>, %[[V3]] : !gccjit.int<size_t>, %[[V0]] : !gccjit.int<size_t>]
// CHECK: %[[V12:[0-9]+]] = gccjit.new_array <!gccjit.int<size_t>, 3>[%[[V5]] : !gccjit.int<size_t>, %[[V0]] : !gccjit.int<size_t>, %[[V4]] : !gccjit.int<size_t>]
// CHECK: %[[V13:[0-9]+]] = gccjit.const #gccjit.int<0> : !gccjit.int<size_t>
// CHECK: %[[V14:[0-9]+]] = gccjit.new_struct [0, 1, 2, 3, 4][%[[V10]], %[[V10]], %[[V13]], %[[V11]], %[[V12]]] : (!gccjit.ptr<!gccjit.fp<float>>, !gccjit.ptr<!gccjit.fp<float>>, !gccjit.int<size_t>, !gccjit.array<!gccjit.int<size_t>, 3>, !gccjit.array<!gccjit.int<size_t>, 3>) -> !gccjit.struct<"memref<133x723x?xf32>" {#gccjit.field<"base" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"aligned" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"offset" !gccjit.int<size_t>>, #gccjit.field<"sizes" !gccjit.array<!gccjit.int<size_t>, 3>>, #gccjit.field<"strides" !gccjit.array<!gccjit.int<size_t>, 3>>}>
// CHECK: gccjit.return %[[V14]] : !gccjit.struct<"memref<133x723x?xf32>" {#gccjit.field<"base" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"aligned" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"offset" !gccjit.int<size_t>>, #gccjit.field<"sizes" !gccjit.array<!gccjit.int<size_t>, 3>>, #gccjit.field<"strides" !gccjit.array<!gccjit.int<size_t>, 3>>}>
// CHECK: } : !gccjit.struct<"memref<133x723x?xf32>" {#gccjit.field<"base" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"aligned" !gccjit.ptr<!gccjit.fp<float>>>, #gccjit.field<"offset" !gccjit.int<size_t>>, #gccjit.field<"sizes" !gccjit.array<!gccjit.int<size_t>, 3>>, #gccjit.field<"strides" !gccjit.array<!gccjit.int<size_t>, 3>>}>
%a = memref.alloca (%arg0) : memref<133x723x?xf32>
return
}

func.func @baz(%arg0 : index) {
// CHECK: %[[V:[0-9]+]] = gccjit.const #gccjit.int<128> : !gccjit.int<size_t>
// CHECK: gccjit.call builtin @alloca_with_align(%{{[0-9]+}}, %[[V]]) : (!gccjit.int<size_t>, !gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
%a = memref.alloca (%arg0) {alignment = 128} : memref<133x723x?xf32>
return
}
}

0 comments on commit c0392c3

Please sign in to comment.