Skip to content

Commit

Permalink
[gccjit] add an alloca sum test
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 7, 2024
1 parent 8fee246 commit 4c565fe
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 21 deletions.
4 changes: 2 additions & 2 deletions include/mlir-gccjit/Conversion/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "mlir-gccjit/IR/GCCJITAttrs.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir/IR/MLIRContext.h"

namespace mlir::gccjit {
class GCCJITTypeConverter : public TypeConverter {
Expand Down Expand Up @@ -59,8 +60,7 @@ class GCCJITTypeConverter : public TypeConverter {
gccjit::StructType
getUnrankedMemrefDescriptorType(mlir::UnrankedMemRefType type) const;

Type convertAndPackTypesIfNonSingleton(TypeRange types,
FunctionType name) const;
Type convertAndPackTypesIfNonSingleton(TypeRange types, MLIRContext *) const;
};
} // namespace mlir::gccjit
#endif // MLIR_GCCJIT_CONVERSION_TYPECONVERTER_H
2 changes: 1 addition & 1 deletion include/mlir-gccjit/IR/GCCJITOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def SwitchOp : GCCJIT_Op<"switch", [Terminator, Pure, ParentOneOf<["FuncOp"]>]>
//===----------------------------------------------------------------------===//
// LocalOp
//===----------------------------------------------------------------------===//
def LocalOp : GCCJIT_Op<"local", [ParentOneOf<["FuncOp"]>]> {
def LocalOp : GCCJIT_Op<"local"> {
let summary = "Declare a local variable";
let description = [{
The "local_var" operation declares a local variable.
Expand Down
24 changes: 17 additions & 7 deletions src/Conversion/ConvertFuncToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,17 @@ void ConvertFuncToGCCJITPass::runOnOperation() {
auto moduleOp = getOperation();
SymbolTable symbolTable(moduleOp);
auto typeConverter = GCCJITTypeConverter();
auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
};
typeConverter.addTargetMaterialization(materializeAsUnrealizedCast);
typeConverter.addSourceMaterialization(materializeAsUnrealizedCast);
mlir::RewritePatternSet patterns(&getContext());
populateFuncToGCCJITPatterns(&getContext(), typeConverter, patterns,
symbolTable);
Expand Down Expand Up @@ -265,8 +276,8 @@ NewStructOp packValues(mlir::Location loc, mlir::ValueRange values,
mlir::TypeRange types,
mlir::ConversionPatternRewriter &rewriter,
FunctionType func) {
auto packedType =
typeConverter.convertAndPackTypesIfNonSingleton(types, func);
auto packedType = typeConverter.convertAndPackTypesIfNonSingleton(
types, rewriter.getContext());
auto structType = cast<gccjit::StructType>(packedType);
auto indices =
llvm::to_vector(llvm::seq<int>(0, structType.getFields().size()));
Expand Down Expand Up @@ -318,11 +329,10 @@ class CallOpLowering : public GCCJITLoweringPattern<func::CallOp> {
matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto callee = op.getCalleeAttr();
auto funcOp = dyn_cast<func::FuncOp>(symbolTable.lookup(callee.getValue()));
if (!funcOp)
return mlir::failure();
Type resultTy = getTypeConverter()->convertAndPackTypesIfNonSingleton(
op->getResultTypes(), funcOp.getFunctionType());
op->getResultTypes(), getContext());
if (isa<VoidType>(resultTy))
resultTy = {};
auto callOp = rewriter.create<gccjit::CallOp>(op.getLoc(), resultTy, callee,
adaptor.getOperands());
if (op->getNumResults() <= 1)
Expand All @@ -349,7 +359,7 @@ class CallIndirectOpLowering
matchAndRewrite(func::CallIndirectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto resultTy = getTypeConverter()->convertAndPackTypesIfNonSingleton(
op->getResultTypes(), op.getCallee().getType());
op->getResultTypes(), getContext());
auto callOp = rewriter.create<gccjit::PtrCallOp>(
op.getLoc(), resultTy, adaptor.getCallee(), adaptor.getOperands());
if (op->getNumResults() <= 1)
Expand Down
21 changes: 10 additions & 11 deletions src/Conversion/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "mlir-gccjit/IR/GCCJITAttrs.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <mlir/IR/BuiltinTypes.h>

using namespace mlir;
Expand Down Expand Up @@ -145,7 +146,8 @@ GCCJITTypeConverter::convertFunctionType(mlir::FunctionType type,
argTypes.reserve(type.getNumInputs());
if (convertTypes(type.getInputs(), argTypes).failed())
return {};
auto resultType = convertAndPackTypesIfNonSingleton(type.getResults(), type);
auto resultType =
convertAndPackTypesIfNonSingleton(type.getResults(), type.getContext());
return FuncType::get(type.getContext(), argTypes, resultType, isVarArg);
}

Expand Down Expand Up @@ -206,25 +208,22 @@ gccjit::StructType GCCJITTypeConverter::getUnrankedMemrefDescriptorType(
}

Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton(
TypeRange types, FunctionType func) const {
TypeRange types, MLIRContext *ctx) const {
if (types.size() == 0)
return VoidType::get(func.getContext());
return VoidType::get(ctx);
if (types.size() == 1)
return convertType(types.front());

auto name =
Twine("__retpack_")
.concat(Twine(reinterpret_cast<uintptr_t>(func.getAsOpaquePointer())))
.str();
auto *name = "__return_pack";
SmallVector<Attribute> fields;
for (auto [idx, type] : llvm::enumerate(types)) {
auto name = Twine("__field_").concat(Twine(idx)).str();
auto nameAttr = StringAttr::get(func.getContext(), name);
auto nameAttr = StringAttr::get(ctx, name);
fields.push_back(FieldAttr::get(type.getContext(), nameAttr, type));
}
auto nameAttr = StringAttr::get(func.getContext(), name);
auto fieldsAttr = ArrayAttr::get(func.getContext(), fields);
return StructType::get(func.getContext(), nameAttr, fieldsAttr);
auto nameAttr = StringAttr::get(ctx, name);
auto fieldsAttr = ArrayAttr::get(ctx, fields);
return StructType::get(ctx, nameAttr, fieldsAttr);
}

bool GCCJITTypeConverter::isSigned(gccjit::IntType type) const {
Expand Down
13 changes: 13 additions & 0 deletions test/lowering/alloca_sum.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

#include <stdint.h>
#include <stdio.h>

void print(int32_t x) {
printf("%d\n", x);
}

int32_t read() {
int32_t x;
scanf("%d", &x);
return x;
}
52 changes: 52 additions & 0 deletions test/lowering/alloca_sum.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: %gccjit-opt %s \
// RUN: -lower-affine \
// RUN: -convert-scf-to-cf \
// RUN: -convert-arith-to-gccjit \
// RUN: -convert-memref-to-gccjit \
// RUN: -convert-func-to-gccjit \
// RUN: -reconcile-unrealized-casts -mlir-print-debuginfo -o %t.mlir

// RUN: %gccjit-translate %t.mlir -mlir-to-gccjit-gimple | %filecheck %s --check-prefix=CHECK-GIMPLE
// RUN: %gccjit-translate %t.mlir -mlir-to-gccjit-dylib -o %t.so
// RUN: cc -O3 %p/alloca_sum.c %t.so -Wl,-rpath,%T -o %t.exe
// RUN: seq 1 100 | %t.exe | %filecheck %s --check-prefix=CHECK-OUTPUT

// CHECK-OUTPUT: 5050
module attributes { gccjit.opt_level = #gccjit.opt_level<O3>, gccjit.debug_info = false } {
// Import C standard library functions for I/O
func.func private @read() -> i32
func.func private @print(%val: i32)

func.func @main() -> i32 {
// Allocate memory for the 100 integers array and initialize sum to 0
// CHECK-GIMPLE: %{{[0-9]+}} = bitcast(alloca (%{{[0-9]+}}), __uint32_t *);
%array = memref.alloca() : memref<100xi32>
%sum_init = arith.constant 0 : i32

// Loop to read 100 integers into the array
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c100 = arith.constant 100 : index

scf.for %i = %c0 to %c100 step %c1 {
// Call scanf to read an integer
// CHECK-GIMPLE: %{{[0-9]+}} = read ();
%read = func.call @read() : () -> i32
memref.store %read, %array[%i] : memref<100xi32>
}

// Loop to calculate the sum using iter_args
%final_sum = scf.for %i = %c0 to %c100 step %c1 iter_args(%acc = %sum_init) -> (i32) {
%elem = memref.load %array[%i] : memref<100xi32>
%new_sum = arith.addi %acc, %elem : i32
scf.yield %new_sum : i32
}

// Print the result using printf
// CHECK-GIMPLE: (void)print (%{{[0-9]+}});
func.call @print(%final_sum) : (i32) -> ()

%c0_1 = arith.constant 0 : i32
return %c0_1 : i32
}
}

0 comments on commit 4c565fe

Please sign in to comment.