diff --git a/src/Conversion/ConvertMemrefToGCCJIT.cpp b/src/Conversion/ConvertMemrefToGCCJIT.cpp index 4e9ebda..68a3113 100644 --- a/src/Conversion/ConvertMemrefToGCCJIT.cpp +++ b/src/Conversion/ConvertMemrefToGCCJIT.cpp @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include #include +#include #include #include #include @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "libgccjit.h" @@ -32,7 +33,6 @@ #include "mlir-gccjit/IR/GCCJITOpsEnums.h" #include "mlir-gccjit/IR/GCCJITTypes.h" #include "mlir-gccjit/Passes.h" -#include "mlir/IR/Types.h" using namespace mlir; using namespace mlir::gccjit; @@ -133,12 +133,13 @@ class AllocationLowering : public GCCJITLoweringPattern { virtual std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size, - Operation *op) const = 0; + OpType op) const = 0; private: static constexpr uint64_t kMinAlignedAllocAlignment = 16UL; public: + using GCCJITLoweringPattern::GCCJITLoweringPattern; LogicalResult matchAndRewrite(OpType op, typename OpConversionPattern::OpAdaptor adaptor, @@ -191,33 +192,6 @@ class StoreOpLowering : public GCCJITLoweringPattern { } }; -void ConvertMemrefToGCCJITPass::runOnOperation() { - auto moduleOp = getOperation(); - auto typeConverter = GCCJITTypeConverter(); - auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> Value { - if (inputs.size() != 1) - return Value(); - - return builder.create(loc, resultType, inputs) - .getResult(0); - }; - typeConverter.addTargetMaterialization(materializeAsUnrealizedCast); - typeConverter.addSourceMaterialization(materializeAsUnrealizedCast); - mlir::RewritePatternSet patterns(&getContext()); - patterns.insert(typeConverter, - &getContext()); - mlir::ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addIllegalDialect(); - llvm::SmallVector ops; - for (auto func : moduleOp.getOps()) - ops.push_back(func); - if (failed(applyPartialConversion(ops, target, std::move(patterns)))) - signalPassFailure(); -} - template IntType GCCJITLoweringPattern::getIndexType() const { return IntType::get(this->getContext(), GCC_JIT_TYPE_SIZE_T); } @@ -472,6 +446,7 @@ Value AllocationLowering::allocateBufferAutoAlign( return rewriter.create(loc, elementPtrType, result); } +[[gnu::used]] bool isConvertibleAndHasIdentityMaps(MemRefType type, const GCCJITTypeConverter &typeConverter) { if (!typeConverter.convertType(type.getElementType())) @@ -485,7 +460,7 @@ void GCCJITLoweringPattern::getMemRefDescriptorSizes( ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, SmallVectorImpl &strides, Value &size, bool sizeInBytes) const { assert( - isConvertibleAndHasIdentityMaps(memRefType, this->getTypeConverter()) && + isConvertibleAndHasIdentityMaps(memRefType, *this->getTypeConverter()) && "layout maps must have been normalized away"); assert(count(memRefType.getShape(), ShapedType::kDynamic) == static_cast(dynamicSizes.size()) && @@ -528,6 +503,8 @@ void GCCJITLoweringPattern::getMemRefDescriptorSizes( Type elementType = this->getTypeConverter()->convertType(memRefType.getElementType()); size = rewriter.create(loc, indexType, elementType); + size = rewriter.create(loc, indexType, BOp::Mult, size, + runningStride); } else { size = runningStride; } @@ -565,18 +542,94 @@ LogicalResult AllocationLowering::matchAndRewrite( return rewriter.notifyMatchFailure(loc, "underlying buffer allocation failed"); + auto arrayTy = ArrayType::get(rewriter.getContext(), this->getIndexType(), + memRefType.getRank()); + auto sizeArr = rewriter.create(loc, arrayTy, sizes); + auto strideArr = rewriter.create(loc, arrayTy, strides); + auto zero = + this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 0); // Create the MemRef descriptor. auto memRefDescriptor = rewriter.create( loc, convertedType, ArrayRef{0, 1, 2, 3, 4}, - ValueRange{alignedPtr, allocatedPtr, size}); + ValueRange{alignedPtr, allocatedPtr, zero, sizeArr, strideArr}); // Return the final value of the descriptor. rewriter.create(loc, memRefDescriptor); } - // Return the final value of the descriptor. - rewriter.replaceOp(op, exprBundle); return success(); } + +struct AllocaOpLowering : public AllocationLowering { + using AllocationLowering::AllocationLowering; + std::tuple + allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value size, + memref::AllocaOp op) const override final { + auto allocaOp = cast(op); + auto elementType = + typeConverter->convertType(allocaOp.getType().getElementType()); + + if (allocaOp.getType().getMemorySpace()) + return std::make_tuple(Value(), Value()); + + auto elementPtrType = PointerType::get(rewriter.getContext(), elementType); + + Value alloca; + + if (auto align = op.getAlignment()) { + auto alignment = + createIndexAttrConstant(rewriter, loc, getIndexType(), *align); + alloca = rewriter + .create(loc, getVoidPtrType(), + SymbolRefAttr::get(rewriter.getContext(), + "alloca_with_align"), + ValueRange{size, alignment}, + /* tailcall */ nullptr, + /* builtin */ rewriter.getUnitAttr()) + .getResult(); + } else { + alloca = rewriter + .create( + loc, getVoidPtrType(), + SymbolRefAttr::get(rewriter.getContext(), "alloca"), + ValueRange{size}, + /* tailcall */ nullptr, + /* builtin */ rewriter.getUnitAttr()) + .getResult(); + } + + alloca = rewriter.create(loc, elementPtrType, alloca); + + return std::make_tuple(alloca, alloca); + } +}; + +void ConvertMemrefToGCCJITPass::runOnOperation() { + auto moduleOp = getOperation(); + auto typeConverter = GCCJITTypeConverter(); + auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return Value(); + + return builder.create(loc, resultType, inputs) + .getResult(0); + }; + typeConverter.addTargetMaterialization(materializeAsUnrealizedCast); + typeConverter.addSourceMaterialization(materializeAsUnrealizedCast); + mlir::RewritePatternSet patterns(&getContext()); + patterns.insert( + typeConverter, &getContext()); + mlir::ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + llvm::SmallVector ops; + for (auto func : moduleOp.getOps()) + ops.push_back(func); + if (failed(applyPartialConversion(ops, target, std::move(patterns)))) + signalPassFailure(); +} + } // namespace std::unique_ptr mlir::gccjit::createConvertMemrefToGCCJITPass() { diff --git a/src/GCCJITOps.cpp b/src/GCCJITOps.cpp index ac6b53a..bd4db86 100644 --- a/src/GCCJITOps.cpp +++ b/src/GCCJITOps.cpp @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "mlir-gccjit/IR/GCCJITOps.h" - #include #include #include @@ -46,6 +44,7 @@ #include #include "mlir-gccjit/IR/GCCJITDialect.h" +#include "mlir-gccjit/IR/GCCJITOps.h" #include "mlir-gccjit/IR/GCCJITOpsEnums.h" #include "mlir-gccjit/IR/GCCJITTypes.h" @@ -264,13 +263,41 @@ ParseResult parseArrayOrVectorElements( OpAsmParser &parser, Type expectedType, llvm::SmallVectorImpl &elementValues, llvm::SmallVectorImpl &elementTypes) { - llvm_unreachable("Not implemented"); + bool mayContinue = true; + auto parseOptionalValueTypePair = [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand elementValue; + Type elementType; + if (!parser.parseOptionalOperand(elementValue).has_value()) { + mayContinue = false; + return success(); + } + if (parser.parseColonType(elementType)) + return failure(); + elementValues.push_back(elementValue); + elementTypes.push_back(elementType); + if (parser.parseOptionalComma().succeeded()) { + mayContinue = true; + return success(); + } + mayContinue = false; + return success(); + }; + while (mayContinue) + if (parseOptionalValueTypePair().failed()) + return failure(); + return success(); } void printArrayOrVectorElements(OpAsmPrinter &p, Operation *op, Type expectedType, OperandRange elementValues, ValueTypeRange elementTypes) { - llvm_unreachable("Not implemented"); + llvm::interleaveComma(llvm::zip(elementValues, elementTypes), p, + [&](auto pair) { + auto [value, type] = pair; + p.printOperand(value); + p << " : "; + p.printType(type); + }); } struct ParseNamedUnitAttr { diff --git a/test/lowering/alloca.mlir b/test/lowering/alloca.mlir new file mode 100644 index 0000000..af74edd --- /dev/null +++ b/test/lowering/alloca.mlir @@ -0,0 +1,54 @@ +// RUN: %gccjit-opt %s -o %t.mlir -convert-memref-to-gccjit +// RUN: %filecheck --input-file=%t.mlir %s +module @test +{ + + func.func @foo() { + // CHECK: %[[V0:[0-9]+]] = gccjit.expr { + // CHECK: %[[V1:[0-9]+]] = gccjit.const #gccjit.int<100> : !gccjit.int + // CHECK: %[[V2:[0-9]+]] = gccjit.const #gccjit.int<100> : !gccjit.int + // CHECK: %[[V3:[0-9]+]] = gccjit.const #gccjit.int<1> : !gccjit.int + // CHECK: %[[V4:[0-9]+]] = gccjit.const #gccjit.int<10000> : !gccjit.int + // CHECK: %[[V5:[0-9]+]] = gccjit.sizeof !gccjit.fp : + // CHECK: %[[V6:[0-9]+]] = gccjit.binary mult(%[[V5]] : !gccjit.int, %[[V4]] : !gccjit.int) : !gccjit.int + // CHECK: %[[V7:[0-9]+]] = gccjit.call builtin @alloca(%[[V6]]) : (!gccjit.int) -> !gccjit.ptr + // CHECK: %[[V8:[0-9]+]] = gccjit.bitcast %[[V7]] : !gccjit.ptr to !gccjit.ptr> + // CHECK: %[[V9:[0-9]+]] = gccjit.new_array , 2>[%[[V1]] : !gccjit.int, %[[V2]] : !gccjit.int] + // CHECK: %[[V10:[0-9]+]] = gccjit.new_array , 2>[%[[V2]] : !gccjit.int, %[[V3]] : !gccjit.int] + // CHECK: %[[V11:[0-9]+]] = gccjit.const #gccjit.int<0> : !gccjit.int + // CHECK: %[[V12:[0-9]+]] = gccjit.new_struct [0, 1, 2, 3, 4][%[[V8]], %[[V8]], %[[V11]], %[[V9]], %[[V10]]] : (!gccjit.ptr>, !gccjit.ptr>, !gccjit.int, !gccjit.array, 2>, !gccjit.array, 2>) -> !gccjit.struct<"memref<100x100xf32>" {#gccjit.field<"base" !gccjit.ptr>>, #gccjit.field<"aligned" !gccjit.ptr>>, #gccjit.field<"offset" !gccjit.int>, #gccjit.field<"sizes" !gccjit.array, 2>>, #gccjit.field<"strides" !gccjit.array, 2>>}> + // CHECK: gccjit.return %[[V12]] : !gccjit.struct<"memref<100x100xf32>" {#gccjit.field<"base" !gccjit.ptr>>, #gccjit.field<"aligned" !gccjit.ptr>>, #gccjit.field<"offset" !gccjit.int>, #gccjit.field<"sizes" !gccjit.array, 2>>, #gccjit.field<"strides" !gccjit.array, 2>>}> + // CHECK: } : !gccjit.struct<"memref<100x100xf32>" {#gccjit.field<"base" !gccjit.ptr>>, #gccjit.field<"aligned" !gccjit.ptr>>, #gccjit.field<"offset" !gccjit.int>, #gccjit.field<"sizes" !gccjit.array, 2>>, #gccjit.field<"strides" !gccjit.array, 2>>}> + %a = memref.alloca () : memref<100x100xf32> + return + } + + func.func @bar(%arg0 : index) { + // CHECK: %[[V0:[0-9]+]] = builtin.unrealized_conversion_cast %{{[0-9a-z]+}} : index to !gccjit.int + // CHECK: %[[V1:[0-9]+]] = gccjit.expr { + // CHECK: %[[V2:[0-9]+]] = gccjit.const #gccjit.int<133> : !gccjit.int + // CHECK: %[[V3:[0-9]+]] = gccjit.const #gccjit.int<723> : !gccjit.int + // CHECK: %[[V4:[0-9]+]] = gccjit.const #gccjit.int<1> : !gccjit.int + // CHECK: %[[V5:[0-9]+]] = gccjit.binary mult(%[[V0]] : !gccjit.int, %[[V3]] : !gccjit.int) : !gccjit.int + // CHECK: %[[V6:[0-9]+]] = gccjit.binary mult(%[[V5]] : !gccjit.int, %[[V2]] : !gccjit.int) : !gccjit.int + // CHECK: %[[V7:[0-9]+]] = gccjit.sizeof !gccjit.fp : + // CHECK: %[[V8:[0-9]+]] = gccjit.binary mult(%[[V7]] : !gccjit.int, %[[V6]] : !gccjit.int) : !gccjit.int + // CHECK: %[[V9:[0-9]+]] = gccjit.call builtin @alloca(%[[V8]]) : (!gccjit.int) -> !gccjit.ptr + // CHECK: %[[V10:[0-9]+]] = gccjit.bitcast %[[V9]] : !gccjit.ptr to !gccjit.ptr> + // CHECK: %[[V11:[0-9]+]] = gccjit.new_array , 3>[%[[V2]] : !gccjit.int, %[[V3]] : !gccjit.int, %[[V0]] : !gccjit.int] + // CHECK: %[[V12:[0-9]+]] = gccjit.new_array , 3>[%[[V5]] : !gccjit.int, %[[V0]] : !gccjit.int, %[[V4]] : !gccjit.int] + // CHECK: %[[V13:[0-9]+]] = gccjit.const #gccjit.int<0> : !gccjit.int + // CHECK: %[[V14:[0-9]+]] = gccjit.new_struct [0, 1, 2, 3, 4][%[[V10]], %[[V10]], %[[V13]], %[[V11]], %[[V12]]] : (!gccjit.ptr>, !gccjit.ptr>, !gccjit.int, !gccjit.array, 3>, !gccjit.array, 3>) -> !gccjit.struct<"memref<133x723x?xf32>" {#gccjit.field<"base" !gccjit.ptr>>, #gccjit.field<"aligned" !gccjit.ptr>>, #gccjit.field<"offset" !gccjit.int>, #gccjit.field<"sizes" !gccjit.array, 3>>, #gccjit.field<"strides" !gccjit.array, 3>>}> + // CHECK: gccjit.return %[[V14]] : !gccjit.struct<"memref<133x723x?xf32>" {#gccjit.field<"base" !gccjit.ptr>>, #gccjit.field<"aligned" !gccjit.ptr>>, #gccjit.field<"offset" !gccjit.int>, #gccjit.field<"sizes" !gccjit.array, 3>>, #gccjit.field<"strides" !gccjit.array, 3>>}> + // CHECK: } : !gccjit.struct<"memref<133x723x?xf32>" {#gccjit.field<"base" !gccjit.ptr>>, #gccjit.field<"aligned" !gccjit.ptr>>, #gccjit.field<"offset" !gccjit.int>, #gccjit.field<"sizes" !gccjit.array, 3>>, #gccjit.field<"strides" !gccjit.array, 3>>}> + %a = memref.alloca (%arg0) : memref<133x723x?xf32> + return + } + + func.func @baz(%arg0 : index) { + // CHECK: %[[V:[0-9]+]] = gccjit.const #gccjit.int<128> : !gccjit.int + // CHECK: gccjit.call builtin @alloca_with_align(%{{[0-9]+}}, %[[V]]) : (!gccjit.int, !gccjit.int) -> !gccjit.ptr + %a = memref.alloca (%arg0) {alignment = 128} : memref<133x723x?xf32> + return + } +}