From 05a799d0163a5ce83d2ac7a3a4d7b2e7c5380516 Mon Sep 17 00:00:00 2001 From: Schrodinger ZHU Yifan Date: Thu, 7 Nov 2024 00:13:25 -0500 Subject: [PATCH] [gccjit] fix missed translation of return type --- src/Conversion/TypeConverter.cpp | 2 +- test/lowering/alloc.mlir | 23 ++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/Conversion/TypeConverter.cpp b/src/Conversion/TypeConverter.cpp index 3332ed8..fe8baf8 100644 --- a/src/Conversion/TypeConverter.cpp +++ b/src/Conversion/TypeConverter.cpp @@ -210,7 +210,7 @@ Type GCCJITTypeConverter::convertAndPackTypesIfNonSingleton( if (types.size() == 0) return VoidType::get(func.getContext()); if (types.size() == 1) - return types.front(); + return convertType(types.front()); auto name = Twine("__retpack_") diff --git a/test/lowering/alloc.mlir b/test/lowering/alloc.mlir index edf1e2d..696ac63 100644 --- a/test/lowering/alloc.mlir +++ b/test/lowering/alloc.mlir @@ -1,23 +1,32 @@ // RUN: %gccjit-opt %s -o %t.mlir -convert-memref-to-gccjit // RUN: %filecheck --input-file=%t.mlir %s -module @test +module @test attributes { + gccjit.opt_level = #gccjit.opt_level, + gccjit.debug_info = false +} { - func.func @foo() { + func.func @foo() -> memref<100x100xf32> { // CHECK: gccjit.call builtin @aligned_alloc(%{{[0-9]+}}, %{{[0-9]+}}) : (!gccjit.int, !gccjit.int) -> !gccjit.ptr %a = memref.alloc () : memref<100x100xf32> - return + return %a : memref<100x100xf32> } - func.func @bar(%arg0 : index, %arg1: index) { + func.func @bar(%arg0 : index, %arg1: index) -> memref { // CHECK: gccjit.call builtin @aligned_alloc(%{{[0-9]+}}, %{{[0-9]+}}) : (!gccjit.int, !gccjit.int) -> !gccjit.ptr %a = memref.alloc (%arg0, %arg1) : memref - return + return %a : memref } - func.func @baz() { + func.func @baz() -> memref<133x723x1xi128> { // CHECK: gccjit.call builtin @aligned_alloc(%{{[0-9]+}}, %{{[0-9]+}}) : (!gccjit.int, !gccjit.int) -> !gccjit.ptr %a = memref.alloc () {alignment = 128} : memref<133x723x1xi128> - return + return %a : memref<133x723x1xi128> + } + + gccjit.func exported @qux() -> !gccjit.int { + // CHECK: gccjit.call builtin @aligned_alloc(%{{[0-9]+}}, %{{[0-9]+}}) : (!gccjit.int, !gccjit.int) -> !gccjit.ptr + %a = gccjit.alignof !gccjit.int : !gccjit.int + gccjit.return %a : !gccjit.int } }