Skip to content

Commit

Permalink
cleanup api type converstion
Browse files Browse the repository at this point in the history
  • Loading branch information
gysit committed Jan 20, 2021
1 parent a4deddc commit 83cb8a7
Showing 1 changed file with 31 additions and 48 deletions.
79 changes: 31 additions & 48 deletions patches/runtime.patch
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ index aa228784e48a..61e4ee8b5498 100644
}

diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index cee1d7ba20e3..7fd9f2cd95d3 100644
index cee1d7ba20e3..29ac805e2a21 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -41,9 +41,17 @@ namespace {
Expand Down Expand Up @@ -100,7 +100,7 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644
ArrayRef<Value> arguments) const {
auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto function = [&] {
@@ -311,9 +323,33 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
@@ -311,9 +323,25 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
return OpBuilder(module.getBody()->getTerminator())
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
Expand All @@ -113,18 +113,10 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644
+ std::back_inserter(castedArguments), [&](const auto &pair) -> Value {
+ auto value = std::get<0>(pair);
+ auto paramType = std::get<1>(pair);
+ if (value.getType().isIndex() &&
+ if ((value.getType().isIndex() ||
+ value.getType().isSignlessInteger(indexBitwidth)) &&
+ paramType.getIntOrFloatBitWidth() > indexBitwidth) {
+ return builder.create<LLVM::SExtOp>(loc, paramType, value);
+ }
+ if (value.getType().isInteger(indexBitwidth) &&
+ paramType.getIntOrFloatBitWidth() >
+ value.getType().getIntOrFloatBitWidth()) {
+ return value.getType().isUnsignedInteger()
+ ? builder.create<LLVM::ZExtOp>(loc, paramType, value)
+ .getResult()
+ : builder.create<LLVM::SExtOp>(loc, paramType, value)
+ .getResult();
+ return builder.create<LLVM::ZExtOp>(loc, paramType, value);
+ }
+ return value;
+ });
Expand All @@ -135,15 +127,15 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644
}

// Returns whether all operands are of LLVM type.
@@ -348,6 +384,7 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -348,6 +376,7 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();

Location loc = op->getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

auto memRefType = hostRegisterOp.value().getType();
auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
@@ -356,7 +393,7 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -356,7 +385,7 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
operands, rewriter);
arguments.push_back(elementSize);
Expand All @@ -152,15 +144,15 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644

rewriter.eraseOp(op);
return success();
@@ -373,6 +410,7 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -373,6 +402,7 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();

auto loc = allocOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());

// Get shape of the memref as values: static sizes are constant
@@ -388,7 +426,8 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -388,7 +418,8 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
Type elementPtrType = this->getElementPtrType(memRefType);
auto stream = adaptor.asyncDependencies().front();
Value allocatedPtr =
Expand All @@ -170,15 +162,15 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644
allocatedPtr =
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

@@ -412,6 +451,7 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -412,6 +443,7 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();

Location loc = deallocOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

auto adaptor =
gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
@@ -419,7 +459,7 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -419,7 +451,7 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
Value stream = adaptor.asyncDependencies().front();
Expand All @@ -187,7 +179,7 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644

rewriter.replaceOp(deallocOp, {stream});
return success();
@@ -437,11 +477,14 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -437,11 +469,14 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

Location loc = waitOp.getLoc();
Expand All @@ -204,15 +196,15 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644

rewriter.eraseOp(waitOp);
return success();
@@ -460,6 +503,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -460,6 +495,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

Location loc = waitOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

auto insertionPoint = rewriter.saveInsertionPoint();
SmallVector<Value, 1> events;
@@ -472,17 +516,21 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -472,17 +508,21 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
// which is late and therefore misses parallelism, but still valid.
rewriter.setInsertionPointToStart(waitOp->getBlock());
}
Expand All @@ -239,15 +231,15 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644
rewriter.replaceOp(waitOp, {stream});

return success();
@@ -601,6 +649,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -601,6 +641,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
launchOp, "Cannot convert non-async op with async dependencies.");

Location loc = launchOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

// Create an LLVM global with CUBIN extracted from the kernel annotation and
// obtain a pointer to the first byte in it.
@@ -622,25 +671,27 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -622,25 +663,27 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
binaryAttr.getValue(), LLVM::Linkage::Internal);

Expand Down Expand Up @@ -279,7 +271,7 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644
{function.getResult(0), launchOp.gridSizeX(),
launchOp.gridSizeY(), launchOp.gridSizeZ(),
launchOp.blockSizeX(), launchOp.blockSizeY(),
@@ -655,11 +706,12 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -655,11 +698,12 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
// Synchronize with host and destroy stream. This must be the stream created
// above (with no other uses) because we check that the synchronous version
// does not have any async dependencies.
Expand All @@ -295,15 +287,15 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644

return success();
}
@@ -675,6 +727,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -675,6 +719,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();

auto loc = memcpyOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

MemRefDescriptor srcDesc(adaptor.src());
@@ -701,7 +754,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -701,7 +746,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

auto stream = adaptor.asyncDependencies().front();
Expand All @@ -313,7 +305,7 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644

rewriter.replaceOp(memcpyOp, {stream});

@@ -709,8 +763,11 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -709,8 +755,11 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
Expand All @@ -328,7 +320,7 @@ index cee1d7ba20e3..7fd9f2cd95d3 100644

void mlir::populateGpuToLLVMConversionPatterns(
diff --git a/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir
index 4169f0e8191d..b22e947d975b 100644
index 4169f0e8191d..b210b9fc521c 100644
--- a/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir
@@ -1,4 +1,5 @@
Expand All @@ -337,18 +329,17 @@ index 4169f0e8191d..b22e947d975b 100644

module attributes {gpu.container_module} {
// CHECK-LABEL: llvm.func @main
@@ -8,6 +9,9 @@ module attributes {gpu.container_module} {
@@ -8,6 +9,8 @@ module attributes {gpu.container_module} {
%0 = gpu.wait async
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}}[%[[size]]]
// CHECK: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]]
+ // CHECK32: %[[size_bytes:.*]] = llvm.ptrtoint
+ // CHECK32: {{%.*}} = llvm.sext %[[size_bytes:.*]] : i32 to i64
+
+ // CHECK32: {{%.*}} = llvm.zext %[[size_bytes:.*]] : i32 to i64
// CHECK: llvm.call @mgpuMemAlloc(%[[size_bytes]], %[[stream]])
%1, %2 = gpu.alloc async [%0] (%size) : memref<?xf32>
// CHECK: %[[float_ptr:.*]] = llvm.extractvalue {{.*}}[0]
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index 634385cf1a64..f3596e74660a 100644
index 634385cf1a64..dc32dec2e093 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -1,4 +1,5 @@
Expand All @@ -357,24 +348,17 @@ index 634385cf1a64..f3596e74660a 100644
// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=rocdl.hsaco" | FileCheck %s --check-prefix=ROCDL

module attributes {gpu.container_module} {
@@ -26,12 +27,14 @@ module attributes {gpu.container_module} {
args(%c32 : i32, %buffer : memref<?xf32>)
return
@@ -28,6 +29,8 @@ module attributes {gpu.container_module} {
}
-
+

// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+ // CHECK32: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i32
+ // CHECK32: {{%.*}} = llvm.zext [[C8]] : i32 to i64
// CHECK: [[ADDRESSOF:%.*]] = llvm.mlir.addressof @[[GLOBAL]]
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index)
// CHECK: [[BINARY:%.*]] = llvm.getelementptr [[ADDRESSOF]]{{\[}}[[C0]], [[C0]]]
// CHECK-SAME: -> !llvm.ptr<i8>
+ // CHECK32: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i32
+ // CHECK32: {{%.*}} = llvm.sext [[C8]] : i32 to i64

// CHECK: [[MODULE:%.*]] = llvm.call @mgpuModuleLoad([[BINARY]])
// CHECK: [[FUNC:%.*]] = llvm.call @mgpuModuleGetFunction([[MODULE]], {{.*}})
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index 162c2f4e838a..8311c21c3ab9 100644
index 162c2f4e838a..e665c825f037 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -1,4 +1,5 @@
Expand All @@ -383,13 +367,12 @@ index 162c2f4e838a..8311c21c3ab9 100644

module attributes {gpu.container_module} {

@@ -7,6 +8,9 @@ module attributes {gpu.container_module} {
@@ -7,6 +8,8 @@ module attributes {gpu.container_module} {
// CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
%t0 = gpu.wait async
// CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+ // CHECK32: %[[size_bytes:.*]] = llvm.ptrtoint
+ // CHECK32: {{%.*}} = llvm.sext %[[size_bytes:.*]] : i32 to i64
+
+ // CHECK32: {{%.*}} = llvm.zext %[[size_bytes:.*]] : i32 to i64
// CHECK: %[[src:.*]] = llvm.bitcast
// CHECK: %[[dst:.*]] = llvm.bitcast
// CHECK: llvm.call @mgpuMemcpy(%[[dst]], %[[src]], %[[size_bytes]], %[[t0]])
Expand Down

0 comments on commit 83cb8a7

Please sign in to comment.