Skip to content

Commit

Permalink
[gccjit] fix faulty allocation code
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 7, 2024
1 parent ca63a97 commit 1314f5b
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 155 deletions.
151 changes: 46 additions & 105 deletions src/Conversion/ConvertMemrefToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,13 @@ class AllocationLowering : public GCCJITLoweringPattern<OpType> {
int64_t alignedAllocationGetAlignment(ConversionPatternRewriter &rewriter,
Location loc, OpType op) const;

std::tuple<Value, Value>
allocateBufferManuallyAlign(ConversionPatternRewriter &rewriter, Location loc,
Value sizeBytes, OpType op,
Value alignment) const;

/// Allocates a memory buffer using an aligned allocation method.
Value allocateBufferAutoAlign(ConversionPatternRewriter &rewriter,
Location loc, Value sizeBytes, OpType op,
int64_t alignment) const;
Value alignment) const;

virtual std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
OpType op) const = 0;
virtual void allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
Value size, OpType op) const = 0;

private:
static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
Expand Down Expand Up @@ -366,7 +360,7 @@ Value AllocationLowering<OpType>::getAlignment(
Type indexType = this->getIndexType();
alignment =
this->createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
} else {
alignment =
this->getAlignInBytes(loc, memRefType.getElementType(), rewriter);
}
Expand All @@ -390,63 +384,21 @@ Value AllocationLowering<OpType>::createAligned(
}

template <typename OpType>
std::tuple<Value, Value>
AllocationLowering<OpType>::allocateBufferManuallyAlign(
Value AllocationLowering<OpType>::allocateBufferAutoAlign(
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
OpType op, Value alignment) const {
if (alignment) {
// Adjust the allocation size to consider alignment.
sizeBytes = rewriter.create<gccjit::BinaryOp>(
loc, sizeBytes.getType(), BOp::Plus, sizeBytes, alignment);
}

OpType op, Value allocAlignment) const {
MemRefType memRefType = getMemRefResultType(op);
// Allocate the underlying buffer.
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
Type elementPtrType = this->getElementPtrType(memRefType);
Value allocatedPtr =
auto result =
rewriter
.create<gccjit::CallOp>(
loc, this->getVoidPtrType(),
SymbolRefAttr::get(this->getContext(), "malloc"),
ValueRange{sizeBytes},
SymbolRefAttr::get(this->getContext(), "aligned_alloc"),
ValueRange{allocAlignment, sizeBytes},
/* tailcall */ nullptr, /* builtin */ rewriter.getUnitAttr())
.getResult();

if (!allocatedPtr)
return std::make_tuple(Value(), Value());
Value alignedPtr = allocatedPtr;
if (alignment) {
// Compute the aligned pointer.
Value allocatedInt = rewriter.create<gccjit::BitCastOp>(
loc, this->getIndexType(), allocatedPtr);
Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment);
alignedPtr =
rewriter.create<gccjit::BitCastOp>(loc, elementPtrType, alignmentInt);
} else {
alignedPtr =
rewriter.create<gccjit::BitCastOp>(loc, elementPtrType, allocatedPtr);
}

return std::make_tuple(allocatedPtr, alignedPtr);
}

template <typename OpType>
Value AllocationLowering<OpType>::allocateBufferAutoAlign(
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
OpType op, int64_t alignment) const {
Value allocAlignment =
createIndexAttrConstant(rewriter, loc, this->getIndexType(), alignment);

MemRefType memRefType = getMemRefResultType(op);
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

Type elementPtrType = this->getElementPtrType(memRefType);
auto result = rewriter.create<gccjit::CallOp>(
loc, this->getVoidPtrType(),
SymbolRefAttr::get(this->getContext(), "aligned_alloc"),
ValueRange{allocAlignment, sizeBytes},
/* tailcall */ nullptr, /* builtin */ rewriter.getUnitAttr());

return rewriter.create<gccjit::BitCastOp>(loc, elementPtrType, result);
}

Expand Down Expand Up @@ -523,58 +475,48 @@ LogicalResult AllocationLowering<OpType>::matchAndRewrite(
return rewriter.notifyMatchFailure(op, "incompatible memref type");
auto loc = op->getLoc();
auto convertedType = this->getTypeConverter()->convertType(memRefType);
auto exprBundle = rewriter.replaceOpWithNewOp<ExprOp>(op, convertedType);
auto *block = rewriter.createBlock(&exprBundle.getBody());

// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value, 4> sizes;
SmallVector<Value, 4> strides;
Value size;

this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
rewriter, sizes, strides, size, true);
auto elementPtrType = this->getElementPtrType(memRefType);
auto exprBundle = rewriter.create<ExprOp>(op.getLoc(), elementPtrType);
{
OpBuilder::InsertionGuard guard(rewriter);
auto *block = rewriter.createBlock(&exprBundle.getBody());
rewriter.setInsertionPointToStart(block);
// Get actual sizes of the memref as values: static sizes are constant
// values and dynamic sizes are passed to 'alloc' as operands. In case of
// zero-dimensional memref, assume a scalar (size 1).
SmallVector<Value, 4> sizes;
SmallVector<Value, 4> strides;
Value size;

this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
rewriter, sizes, strides, size, true);

// Allocate the underlying buffer.
auto [allocatedPtr, alignedPtr] =
this->allocateBuffer(rewriter, loc, size, op);

if (!allocatedPtr || !alignedPtr)
return rewriter.notifyMatchFailure(loc,
"underlying buffer allocation failed");

auto arrayTy = ArrayType::get(rewriter.getContext(), this->getIndexType(),
memRefType.getRank());
auto sizeArr = rewriter.create<gccjit::NewArrayOp>(loc, arrayTy, sizes);
auto strideArr = rewriter.create<gccjit::NewArrayOp>(loc, arrayTy, strides);
auto zero =
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 0);
// Create the MemRef descriptor.
auto memRefDescriptor = rewriter.create<gccjit::NewStructOp>(
loc, convertedType, ArrayRef<int32_t>{0, 1, 2, 3, 4},
ValueRange{alignedPtr, allocatedPtr, zero, sizeArr, strideArr});

// Return the final value of the descriptor.
rewriter.create<ReturnOp>(loc, memRefDescriptor);
this->allocateBuffer(rewriter, loc, size, op);
}
rewriter.setInsertionPoint(op);

auto arrayTy = ArrayType::get(rewriter.getContext(), this->getIndexType(),
memRefType.getRank());
auto sizeArr = rewriter.create<gccjit::NewArrayOp>(loc, arrayTy, sizes);
auto strideArr = rewriter.create<gccjit::NewArrayOp>(loc, arrayTy, strides);
auto zero =
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 0);
// Create the MemRef descriptor.
rewriter.replaceOpWithNewOp<gccjit::NewStructOp>(
op, convertedType, ArrayRef<int32_t>{0, 1, 2, 3, 4},
ValueRange{exprBundle, exprBundle, zero, sizeArr, strideArr});

return success();
}

struct AllocaOpLowering : public AllocationLowering<memref::AllocaOp> {
using AllocationLowering<memref::AllocaOp>::AllocationLowering;
std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size,
memref::AllocaOp op) const override final {
void allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
Value size, memref::AllocaOp op) const override final {
auto allocaOp = cast<memref::AllocaOp>(op);
auto elementType =
typeConverter->convertType(allocaOp.getType().getElementType());

if (allocaOp.getType().getMemorySpace())
return std::make_tuple(Value(), Value());

auto elementPtrType = PointerType::get(rewriter.getContext(), elementType);

Value alloca;
Expand All @@ -600,19 +542,18 @@ struct AllocaOpLowering : public AllocationLowering<memref::AllocaOp> {
/* builtin */ rewriter.getUnitAttr())
.getResult();
}

alloca = rewriter.create<BitCastOp>(loc, elementPtrType, alloca);

return std::make_tuple(alloca, alloca);
rewriter.create<ReturnOp>(loc, alloca);
}
};

struct AllocOpLowering : public AllocationLowering<memref::AllocOp> {
std::tuple<Value, Value>
allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
Value sizeBytes, memref::AllocOp op) const override final {
return allocateBufferManuallyAlign(rewriter, loc, sizeBytes, op,
getAlignment(rewriter, loc, op));
void allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
Value sizeBytes,
memref::AllocOp op) const override final {
auto result = allocateBufferAutoAlign(rewriter, loc, sizeBytes, op,
getAlignment(rewriter, loc, op));
rewriter.create<ReturnOp>(loc, result);
}
using AllocationLowering<memref::AllocOp>::AllocationLowering;
};
Expand Down
66 changes: 64 additions & 2 deletions src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir-gccjit/Translation/TranslateToGCCJIT.h"

#include <algorithm>
#include <cstddef>
#include <utility>

#include <llvm/ADT/SmallVector.h>
Expand Down Expand Up @@ -106,6 +107,8 @@ class RegionVisitor {
gcc_jit_rvalue *visitExprWithoutCache(PtrCallOp op);
gcc_jit_rvalue *visitExprWithoutCache(AddrOp op);
gcc_jit_rvalue *visitExprWithoutCache(FnAddrOp op);
gcc_jit_rvalue *visitExprWithoutCache(NewStructOp op);
gcc_jit_rvalue *visitExprWithoutCache(NewArrayOp op);
gcc_jit_lvalue *visitExprWithoutCache(GetGlobalOp op);
Expr visitExprWithoutCache(ExprOp op);
gcc_jit_lvalue *visitExprWithoutCache(DerefOp op);
Expand Down Expand Up @@ -571,6 +574,8 @@ Expr RegionVisitor::visitExpr(Value value, bool toplevel) {
.Case([&](ExprOp op) { return visitExprWithoutCache(op); })
.Case([&](DerefOp op) { return visitExprWithoutCache(op); })
.Case([&](AccessFieldOp op) { return visitExprWithoutCache(op); })
.Case([&](NewStructOp op) { return visitExprWithoutCache(op); })
.Case([&](NewArrayOp op) { return visitExprWithoutCache(op); })
.Default([](Operation *op) -> Expr {
llvm::report_fatal_error("unknown expression type");
});
Expand Down Expand Up @@ -606,6 +611,32 @@ Expr RegionVisitor::visitExprWithoutCache(AccessFieldOp op) {
return gcc_jit_rvalue_access_field(composite, loc, field);
}

/// Translate a `NewStructOp` into a libgccjit struct-constructor rvalue.
///
/// The op's result type is converted to a libgccjit type, which must be a
/// struct type. Each entry of the op's index list is resolved to the struct
/// field at that position; the element operands are translated to rvalues
/// and handed to libgccjit together with the resolved fields.
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(NewStructOp op) {
  auto *convertedTy = getTranslator().convertType(op.getType());
  auto *asStruct = gcc_jit_type_is_struct(convertedTy);
  if (asStruct == nullptr)
    llvm_unreachable("expected struct type");

  // Resolve every requested field index to its libgccjit field handle.
  llvm::SmallVector<gcc_jit_field *> fieldHandles;
  for (auto index : op.getIndices()) {
    auto *handle =
        gcc_jit_struct_get_field(asStruct, static_cast<size_t>(index));
    fieldHandles.push_back(handle);
  }

  // Translate the element operands into rvalues.
  llvm::SmallVector<gcc_jit_rvalue *> elementValues;
  visitExprAsRValue(op.getElements(), elementValues);

  auto *where = getTranslator().getLocation(op.getLoc());
  // The number of initialized fields is taken from the translated values.
  return gcc_jit_context_new_struct_constructor(
      getContext(), where, convertedTy, elementValues.size(),
      fieldHandles.data(), elementValues.data());
}

/// Translate a `NewArrayOp` into a libgccjit array-constructor rvalue.
///
/// The op's result type is converted to the libgccjit array type and the
/// element operands are translated to rvalues that initialize the array.
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(NewArrayOp op) {
  auto *convertedTy = getTranslator().convertType(op.getType());
  auto *where = getTranslator().getLocation(op.getLoc());

  // Translate the element operands into rvalues.
  llvm::SmallVector<gcc_jit_rvalue *> elementValues;
  visitExprAsRValue(op.getElements(), elementValues);

  const auto count = elementValues.size();
  auto **firstValue = elementValues.data();
  return gcc_jit_context_new_array_constructor(getContext(), where,
                                               convertedTy, count, firstValue);
}

Expr RegionVisitor::visitExprWithoutCache(ExprOp op) {
RegionVisitor visitor(getTranslator(), op.getRegion(), this);
return visitor.translateIntoContext();
Expand Down Expand Up @@ -663,11 +694,42 @@ gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(LiteralOp op) {
/// Translate a `SizeOfOp` into a libgccjit rvalue holding the size of the
/// op's type attribute, converted to the op's result type.
///
/// The raw `gcc_jit_context_new_sizeof` rvalue is cast to the result type
/// whenever the result kind is not `GCC_JIT_TYPE_INT`. The previous text
/// returned the raw sizeof rvalue first, which left that cast unreachable.
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(SizeOfOp op) {
  auto type = op.getType();
  auto *typeHandle = getTranslator().convertType(type);
  auto *size = gcc_jit_context_new_sizeof(getContext(), typeHandle);
  auto *loc = getTranslator().getLocation(op.getLoc());
  auto resTy = op.getResult().getType();
  auto *resTyHandle = getTranslator().convertType(resTy);
  // NOTE(review): this assumes the sizeof rvalue has libgccjit type `int`,
  // so a cast is only needed for other result kinds -- confirm against the
  // libgccjit documentation.
  if (resTy.getKind() != GCC_JIT_TYPE_INT)
    size = gcc_jit_context_new_cast(getContext(), loc, size, resTyHandle);
  return size;
}

/// Translate an `AlignOfOp` into a libgccjit rvalue of the op's result type.
///
/// When `LIBGCCJIT_ABI_28` is defined, `gcc_jit_context_new_alignof` is used
/// and its result is cast to the result type if that type's kind is not
/// `GCC_JIT_TYPE_INT`. Otherwise a fallback derives a value from pointer
/// arithmetic on a null pointer.
///
/// The two paths are separated with `#else`, so exactly one is compiled; the
/// previous text had no `#else` and redeclared the same locals after the
/// first `return`. The stale leading `llvm_unreachable` is removed so the
/// implementation is actually reachable.
gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(AlignOfOp op) {
  auto type = op.getType();
  auto *typeHandle = getTranslator().convertType(type);
  auto *loc = getTranslator().getLocation(op.getLoc());
  auto resTy = op.getResult().getType();
  auto *resTyHandle = getTranslator().convertType(resTy);
  // NOTE(review): `LIBGCCJIT_ABI_28` is assumed to be provided by the build
  // configuration -- confirm it is defined wherever alignof is available.
#ifdef LIBGCCJIT_ABI_28
  auto *align = gcc_jit_context_new_alignof(getContext(), typeHandle);
  if (resTy.getKind() != GCC_JIT_TYPE_INT)
    align = gcc_jit_context_new_cast(getContext(), loc, align, resTyHandle);
  return align;
#else
  // Fallback: take the address of element 1 of a null `T*` and reinterpret
  // it as an integer.
  // NOTE(review): `&((T *)0)[1]` is the element stride, i.e. the size of the
  // type, which is a multiple of the alignment but not necessarily equal to
  // it -- confirm this approximation is acceptable for callers.
  auto *typePtrHandle = gcc_jit_type_get_pointer(typeHandle);
  auto *nullPtr = gcc_jit_context_null(getContext(), typePtrHandle);
  auto *indexTy = gcc_jit_context_get_type(getContext(), GCC_JIT_TYPE_SIZE_T);
  auto *one = gcc_jit_context_one(getContext(), indexTy);
  auto *access =
      gcc_jit_context_new_array_access(getContext(), loc, nullPtr, one);
  auto *addr = gcc_jit_lvalue_get_address(access, loc);
  auto *addrInt = gcc_jit_context_new_bitcast(getContext(), loc, addr, indexTy);
  auto *align =
      gcc_jit_context_new_cast(getContext(), loc, addrInt, resTyHandle);
  return align;
#endif
}

gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(AsRValueOp op) {
Expand Down
17 changes: 3 additions & 14 deletions test/lowering/alloc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,19 @@ module @test
{

func.func @foo() {
// CHECK: gccjit.call builtin @malloc(%{{[0-9]+}}) : (!gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
// CHECK: gccjit.call builtin @aligned_alloc(%{{[0-9]+}}, %{{[0-9]+}}) : (!gccjit.int<size_t>, !gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
%a = memref.alloc () : memref<100x100xf32>
return
}

func.func @bar(%arg0 : index, %arg1: index) {
// CHECK: gccjit.call builtin @malloc(%{{[0-9]+}}) : (!gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
// CHECK: gccjit.call builtin @aligned_alloc(%{{[0-9]+}}, %{{[0-9]+}}) : (!gccjit.int<size_t>, !gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
%a = memref.alloc (%arg0, %arg1) : memref<?x133x723x?xf32>
return
}

func.func @baz() {
// CHECK: %[[V6:[0-9]+]] = gccjit.sizeof !gccjit.int<uint128_t> : <size_t>
// CHECK: %[[V7:[0-9]+]] = gccjit.binary mult(%[[V6]] : !gccjit.int<size_t>, %{{[0-9]+}} : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V8:[0-9]+]] = gccjit.const #gccjit.int<128> : !gccjit.int<size_t>
// CHECK: %[[V9:[0-9]+]] = gccjit.binary plus(%[[V7]] : !gccjit.int<size_t>, %[[V8]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V10:[0-9]+]] = gccjit.call builtin @malloc(%[[V9]]) : (!gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
// CHECK: %[[V11:[0-9]+]] = gccjit.bitcast %[[V10]] : !gccjit.ptr<!gccjit.void> to !gccjit.int<size_t>
// CHECK: %[[V12:[0-9]+]] = gccjit.const #gccjit.int<1> : !gccjit.int<size_t>
// CHECK: %[[V13:[0-9]+]] = gccjit.binary minus(%[[V8]] : !gccjit.int<size_t>, %[[V12]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V14:[0-9]+]] = gccjit.binary plus(%[[V11]] : !gccjit.int<size_t>, %[[V13]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V15:[0-9]+]] = gccjit.binary modulo(%[[V14]] : !gccjit.int<size_t>, %[[V8]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V16:[0-9]+]] = gccjit.binary minus(%[[V14]] : !gccjit.int<size_t>, %[[V15]] : !gccjit.int<size_t>) : !gccjit.int<size_t>
// CHECK: %[[V17:[0-9]+]] = gccjit.bitcast %[[V16]] : !gccjit.int<size_t> to !gccjit.ptr<!gccjit.int<uint128_t>>
// CHECK: gccjit.call builtin @aligned_alloc(%{{[0-9]+}}, %{{[0-9]+}}) : (!gccjit.int<size_t>, !gccjit.int<size_t>) -> !gccjit.ptr<!gccjit.void>
%a = memref.alloc () {alignment = 128} : memref<133x723x1xi128>
return
}
Expand Down
Loading

0 comments on commit 1314f5b

Please sign in to comment.