Skip to content

Commit

Permalink
cleaner runtime patch
Browse files Browse the repository at this point in the history
  • Loading branch information
gysit committed Jan 19, 2021
1 parent 4b7bf75 commit 0160e99
Showing 1 changed file with 243 additions and 88 deletions.
331 changes: 243 additions & 88 deletions patches/runtime.patch
Original file line number Diff line number Diff line change
@@ -1,45 +1,25 @@
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index 6859834de67f..3f34ab7de38a 100644
index cee1d7ba20e3..a148c38ac872 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -102,12 +102,12 @@ protected:
llvmVoidType,
{
llvmPointerType, /* void* f */
- llvmIntPtrType, /* intptr_t gridXDim */
- llvmIntPtrType, /* intptr_t gridyDim */
- llvmIntPtrType, /* intptr_t gridZDim */
- llvmIntPtrType, /* intptr_t blockXDim */
- llvmIntPtrType, /* intptr_t blockYDim */
- llvmIntPtrType, /* intptr_t blockZDim */
+ llvmInt32Type, /* int32_t gridXDim */
+ llvmInt32Type, /* int32_t gridyDim */
+ llvmInt32Type, /* int32_t gridZDim */
+ llvmInt32Type, /* int32_t blockXDim */
+ llvmInt32Type, /* int32_t blockYDim */
+ llvmInt32Type, /* int32_t blockZDim */
llvmInt32Type, /* unsigned int sharedMemBytes */
llvmPointerType, /* void *hstream */
llvmPointerPointerType, /* void **kernelParams */
@@ -146,7 +146,7 @@ protected:
FunctionCallBuilder allocCallBuilder = {
"mgpuMemAlloc",
llvmPointerType /* void * */,
- {llvmIntPtrType /* intptr_t sizeBytes */,
+ {llvmInt32Type /* int32_t sizeBytes */,
llvmPointerType /* void *stream */}};
FunctionCallBuilder deallocCallBuilder = {
"mgpuMemFree",
@@ -156,7 +156,7 @@ protected:
"mgpuMemcpy",
llvmVoidType,
{llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
- llvmIntPtrType /* intptr_t sizeBytes */,
+ llvmInt32Type /* int32_t sizeBytes */,
llvmPointerType /* void *stream */}};
};

@@ -292,7 +292,11 @@ private:
@@ -17,6 +17,7 @@

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
@@ -56,7 +57,7 @@ public:
ArrayRef<Type> argumentTypes)
: functionName(functionName),
functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
- LLVM::CallOp create(Location loc, OpBuilder &builder,
+ LLVM::CallOp create(Location loc, OpBuilder &builder, unsigned indexBitwidth,
ArrayRef<Value> arguments) const;

private:
@@ -291,7 +292,11 @@ private:
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
Expand All @@ -52,36 +32,230 @@ index 6859834de67f..3f34ab7de38a 100644
OwningRewritePatternList patterns;
populateStdToLLVMConversionPatterns(converter, patterns);
populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 72d172889d30..9e0aff602113 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -61,10 +61,10 @@ extern "C" CUfunction mgpuModuleGetFunction(CUmodule module, const char *name) {
// The wrapper uses intptr_t instead of CUDA's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
-extern "C" void mgpuLaunchKernel(CUfunction function, intptr_t gridX,
- intptr_t gridY, intptr_t gridZ,
- intptr_t blockX, intptr_t blockY,
- intptr_t blockZ, int32_t smem, CUstream stream,
+extern "C" void mgpuLaunchKernel(CUfunction function, int32_t gridX,
+ int32_t gridY, int32_t gridZ,
+ int32_t blockX, int32_t blockY,
+ int32_t blockZ, int32_t smem, CUstream stream,
void **params, void **extra) {
CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
blockY, blockZ, smem, stream, params,
@@ -107,7 +107,7 @@ extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
@@ -303,6 +308,7 @@ void GpuToLLVMConversionPass::runOnOperation() {
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
+ unsigned indexBitwidth,
ArrayRef<Value> arguments) const {
auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto function = [&] {
@@ -311,9 +317,34 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
return OpBuilder(module.getBody()->getTerminator())
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
}();
+ // Cast index arguments to the bitwidth of the runtime wrapper functions.
+ SmallVector<Value, 4> castedArguments;
+ castedArguments.reserve(arguments.size());
+ llvm::transform(
+ llvm::zip(arguments,
+ const_cast<LLVM::LLVMFunctionType &>(functionType).getParams()),
+ std::back_inserter(castedArguments),
+ [&](std::tuple<Value, Type> x) -> Value {
+ auto value = std::get<0>(x);
+ auto paramType = std::get<1>(x);
+ if (value.getType().isIndex() &&
+ paramType.getIntOrFloatBitWidth() > indexBitwidth) {
+ return builder.create<LLVM::SExtOp>(loc, paramType, value);
+ }
+ if (value.getType().isInteger(indexBitwidth) &&
+ paramType.getIntOrFloatBitWidth() >
+ value.getType().getIntOrFloatBitWidth()) {
+ return value.getType().isUnsignedInteger()
+ ? builder.create<LLVM::ZExtOp>(loc, paramType, value)
+ .getResult()
+ : builder.create<LLVM::SExtOp>(loc, paramType, value)
+ .getResult();
+ }
+ return value;
+ });
return builder.create<LLVM::CallOp>(
loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
- builder.getSymbolRefAttr(function), arguments);
+ builder.getSymbolRefAttr(function), castedArguments);
}

// Returns whether all operands are of LLVM type.
@@ -348,6 +379,7 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();

Location loc = op->getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

auto memRefType = hostRegisterOp.value().getType();
auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
@@ -356,7 +388,7 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
operands, rewriter);
arguments.push_back(elementSize);
- hostRegisterCallBuilder.create(loc, rewriter, arguments);
+ hostRegisterCallBuilder.create(loc, rewriter, indexBitwidth, arguments);

rewriter.eraseOp(op);
return success();
@@ -373,6 +405,7 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();

auto loc = allocOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());

// Get shape of the memref as values: static sizes are constant
@@ -388,7 +421,8 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
Type elementPtrType = this->getElementPtrType(memRefType);
auto stream = adaptor.asyncDependencies().front();
Value allocatedPtr =
- allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
+ allocCallBuilder.create(loc, rewriter, indexBitwidth, {sizeBytes, stream})
+ .getResult(0);
allocatedPtr =
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

@@ -412,6 +446,7 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();

Location loc = deallocOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

auto adaptor =
gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
@@ -419,7 +454,7 @@ LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
Value stream = adaptor.asyncDependencies().front();
- deallocCallBuilder.create(loc, rewriter, {casted, stream});
+ deallocCallBuilder.create(loc, rewriter, indexBitwidth, {casted, stream});

rewriter.replaceOp(deallocOp, {stream});
return success();
@@ -437,11 +472,14 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

Location loc = waitOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

for (auto asyncDependency : operands)
- streamSynchronizeCallBuilder.create(loc, rewriter, {asyncDependency});
+ streamSynchronizeCallBuilder.create(loc, rewriter, indexBitwidth,
+ {asyncDependency});
for (auto asyncDependency : operands)
- streamDestroyCallBuilder.create(loc, rewriter, {asyncDependency});
+ streamDestroyCallBuilder.create(loc, rewriter, indexBitwidth,
+ {asyncDependency});

rewriter.eraseOp(waitOp);
return success();
@@ -460,6 +498,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

Location loc = waitOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

auto insertionPoint = rewriter.saveInsertionPoint();
SmallVector<Value, 1> events;
@@ -472,17 +511,21 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
// which is late and therefore misses parallelism, but still valid.
rewriter.setInsertionPointToStart(waitOp->getBlock());
}
- auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+ auto event = eventCreateCallBuilder.create(loc, rewriter, indexBitwidth, {})
+ .getResult(0);
auto stream = std::get<1>(pair);
- eventRecordCallBuilder.create(loc, rewriter, {event, stream});
+ eventRecordCallBuilder.create(loc, rewriter, indexBitwidth,
+ {event, stream});
events.push_back(event);
}
rewriter.restoreInsertionPoint(insertionPoint);
- auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+ auto stream = streamCreateCallBuilder.create(loc, rewriter, indexBitwidth, {})
+ .getResult(0);
for (auto event : events)
- streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
+ streamWaitEventCallBuilder.create(loc, rewriter, indexBitwidth,
+ {stream, event});
for (auto event : events)
- eventDestroyCallBuilder.create(loc, rewriter, {event});
+ eventDestroyCallBuilder.create(loc, rewriter, indexBitwidth, {event});
rewriter.replaceOp(waitOp, {stream});

return success();
@@ -601,6 +644,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
launchOp, "Cannot convert non-async op with async dependencies.");

Location loc = launchOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();

// Create an LLVM global with CUBIN extracted from the kernel annotation and
// obtain a pointer to the first byte in it.
@@ -622,25 +666,27 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
binaryAttr.getValue(), LLVM::Linkage::Internal);

- auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
+ auto module =
+ moduleLoadCallBuilder.create(loc, rewriter, indexBitwidth, data);
// Get the function from the module. The name corresponds to the name of
// the kernel function.
auto kernelName = generateKernelNameConstant(
launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
auto function = moduleGetFunctionCallBuilder.create(
- loc, rewriter, {module.getResult(0), kernelName});
+ loc, rewriter, indexBitwidth, {module.getResult(0), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
rewriter.getI32IntegerAttr(0));
auto adaptor =
gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary());
Value stream =
adaptor.asyncDependencies().empty()
- ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
+ ? streamCreateCallBuilder.create(loc, rewriter, indexBitwidth, {})
+ .getResult(0)
: adaptor.asyncDependencies().front();
// Create array of pointers to kernel arguments.
auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
- launchKernelCallBuilder.create(loc, rewriter,
+ launchKernelCallBuilder.create(loc, rewriter, indexBitwidth,
{function.getResult(0), launchOp.gridSizeX(),
launchOp.gridSizeY(), launchOp.gridSizeZ(),
launchOp.blockSizeX(), launchOp.blockSizeY(),
@@ -655,11 +701,12 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
// Synchronize with host and destroy stream. This must be the stream created
// above (with no other uses) because we check that the synchronous version
// does not have any async dependencies.
- streamSynchronizeCallBuilder.create(loc, rewriter, stream);
- streamDestroyCallBuilder.create(loc, rewriter, stream);
+ streamSynchronizeCallBuilder.create(loc, rewriter, indexBitwidth, stream);
+ streamDestroyCallBuilder.create(loc, rewriter, indexBitwidth, stream);
rewriter.eraseOp(launchOp);
}
- moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
+ moduleUnloadCallBuilder.create(loc, rewriter, indexBitwidth,
+ module.getResult(0));

return success();
}
@@ -675,6 +722,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();

auto loc = memcpyOp.getLoc();
+ unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

MemRefDescriptor srcDesc(adaptor.src());
@@ -701,7 +749,8 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

auto stream = adaptor.asyncDependencies().front();
- memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
+ memcpyCallBuilder.create(loc, rewriter, indexBitwidth,
+ {dst, src, sizeBytes, stream});

rewriter.replaceOp(memcpyOp, {stream});

-extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
+extern "C" void *mgpuMemAlloc(uint32_t sizeBytes, CUstream /*stream*/) {
CUdeviceptr ptr;
CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
return reinterpret_cast<void *>(ptr);
diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
index 4f62f204f4a8..da8cbd24628e 100644
index 4f62f204f4a8..b8f55ddd8e83 100644
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -36,7 +36,7 @@ static auto InitializeCtx = [] {
Expand All @@ -93,27 +267,9 @@ index 4f62f204f4a8..da8cbd24628e 100644
HIP_REPORT_IF_ERROR(hipCtxCreate(&context, /*flags=*/0, device));
return 0;
}();
@@ -61,10 +61,10 @@ extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
// The wrapper uses intptr_t instead of ROCM's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
-extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
- intptr_t gridY, intptr_t gridZ,
- intptr_t blockX, intptr_t blockY,
- intptr_t blockZ, int32_t smem,
+extern "C" void mgpuLaunchKernel(hipFunction_t function, int32_t gridX,
+ int32_t gridY, int32_t gridZ,
+ int32_t blockX, int32_t blockY,
+ int32_t blockZ, int32_t smem,
hipStream_t stream, void **params,
void **extra) {
HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
@@ -108,19 +108,19 @@ extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
}
@@ -110,17 +110,17 @@ extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {

-extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
+extern "C" void *mgpuMemAlloc(uint32_t sizeBytes, hipStream_t /*stream*/) {
extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
void *ptr;
- HIP_REPORT_IF_ERROR(hipMemAlloc(&ptr, sizeBytes));
+ HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
Expand All @@ -125,8 +281,7 @@ index 4f62f204f4a8..da8cbd24628e 100644
+ HIP_REPORT_IF_ERROR(hipFree(ptr));
}

-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, uint32_t sizeBytes,
extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
hipStream_t stream) {
- HIP_REPORT_IF_ERROR(hipMemcpyAsync(dst, src, sizeBytes, stream));
+ HIP_REPORT_IF_ERROR(hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
Expand Down

0 comments on commit 0160e99

Please sign in to comment.