Add options to kernel lowering (#196)
* Add options to kernel lowering

* fmt
wsmoses authored Dec 17, 2024
1 parent 3ce6c51 commit dea6396
Showing 2 changed files with 173 additions and 12 deletions.
148 changes: 137 additions & 11 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
@@ -55,18 +55,124 @@
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Pipelines/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/Passes.h"

#include "mlir/Target/LLVMIR/Export.h"

#define DEBUG_TYPE "lower-kernel"

using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::gpu;
using namespace enzyme;
using namespace mlir::enzymexla;
using namespace enzymexla;

using namespace stablehlo;

namespace {

void buildCommonPassPipeline(
OpPassManager &pm, const mlir::gpu::GPUToNVVMPipelineOptions &options) {
pm.addPass(createConvertNVGPUToNVVMPass());
pm.addPass(createGpuKernelOutliningPass());
pm.addPass(createConvertVectorToSCFPass());
pm.addPass(createConvertSCFToCFPass());
pm.addPass(createConvertNVVMToLLVMPass());
pm.addPass(createConvertFuncToLLVMPass());
pm.addPass(memref::createExpandStridedMetadataPass());

GpuNVVMAttachTargetOptions nvvmTargetOptions;
nvvmTargetOptions.triple = options.cubinTriple;
nvvmTargetOptions.chip = options.cubinChip;
nvvmTargetOptions.features = options.cubinFeatures;
nvvmTargetOptions.optLevel = options.optLevel;
pm.addPass(createGpuNVVMAttachTarget(nvvmTargetOptions));
pm.addPass(createLowerAffinePass());
pm.addPass(createArithToLLVMConversionPass());
ConvertIndexToLLVMPassOptions convertIndexToLLVMPassOpt;
convertIndexToLLVMPassOpt.indexBitwidth = options.indexBitWidth;
pm.addPass(createConvertIndexToLLVMPass(convertIndexToLLVMPassOpt));
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}

//===----------------------------------------------------------------------===//
// GPUModule-specific stuff.
//===----------------------------------------------------------------------===//
void buildGpuPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToNVVMPipelineOptions &options) {
pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass());
ConvertGpuOpsToNVVMOpsOptions opt;
opt.useBarePtrCallConv = options.kernelUseBarePtrCallConv;
opt.indexBitwidth = options.indexBitWidth;
pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToNVVMOps(opt));
pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass());
}

//===----------------------------------------------------------------------===//
// Host Post-GPU pipeline
//===----------------------------------------------------------------------===//
void buildHostPostPipeline(OpPassManager &pm,
const mlir::gpu::GPUToNVVMPipelineOptions &options,
std::string toolkitPath,
llvm::SmallVectorImpl<std::string> &linkFiles) {
GpuToLLVMConversionPassOptions opt;
opt.hostBarePtrCallConv = options.hostUseBarePtrCallConv;
opt.kernelBarePtrCallConv = options.kernelUseBarePtrCallConv;
pm.addPass(createGpuToLLVMConversionPass(opt));

GpuModuleToBinaryPassOptions gpuModuleToBinaryPassOptions;
gpuModuleToBinaryPassOptions.compilationTarget = options.cubinFormat;
gpuModuleToBinaryPassOptions.toolkitPath = toolkitPath;
gpuModuleToBinaryPassOptions.linkFiles.append(linkFiles);
pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions));
pm.addPass(createConvertMathToLLVMPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addPass(createReconcileUnrealizedCastsPass());
}

void buildLowerToNVVMPassPipeline(
OpPassManager &pm, const GPUToNVVMPipelineOptions &options,
std::string toolkitPath, llvm::SmallVectorImpl<std::string> &linkFiles) {
// Common pipelines
buildCommonPassPipeline(pm, options);

// GPUModule-specific stuff
buildGpuPassPipeline(pm, options);

// Host post-GPUModule-specific stuff
buildHostPostPipeline(pm, options, toolkitPath, linkFiles);
}

} // namespace

typedef void XlaCustomCallStatus;

llvm::StringMap<void *> kernels;
@@ -104,8 +210,7 @@ void *CompileHostModule(std::string &key, mlir::ModuleOp modOp) {
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
auto llvmModule = translateModuleToLLVMIR(modOp, *ctx);
if (!llvmModule) {
llvm::errs() << "could not convert to LLVM IR"
<< "\n";
llvm::errs() << "could not convert to LLVM IR\n";
return nullptr;
}
llvmModule->setDataLayout(JIT->getDataLayout());
@@ -191,7 +296,10 @@ gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) {
void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
FunctionOpInterface op, bool jit, size_t gridx,
size_t gridy, size_t gridz, size_t blockx, size_t blocky,
- size_t blockz, size_t shmem) {
+ size_t blockz, size_t shmem, std::string toolkitPath,
+ llvm::SmallVectorImpl<std::string> &linkFiles,
+ int indexBitWidth, std::string cubinChip,
+ std::string cubinFeatures) {

OpBuilder builder(op);

@@ -332,15 +440,15 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,

PassManager pm(submod.getContext());
mlir::gpu::GPUToNVVMPipelineOptions options;
- options.indexBitWidth = 64;
+ options.indexBitWidth = indexBitWidth;
options.cubinTriple = "nvptx64-nvidia-cuda";
- options.cubinChip = "sm_50";
- options.cubinFeatures = "+ptx60";
+ options.cubinChip = cubinChip;
+ options.cubinFeatures = cubinFeatures;
options.cubinFormat = "fatbin";
options.optLevel = 2;
options.kernelUseBarePtrCallConv = false;
options.hostUseBarePtrCallConv = false;
- mlir::gpu::buildLowerToNVVMPassPipeline(pm, options);
+ buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles);

pm.run(submod);

@@ -491,7 +599,9 @@ struct LowerKernelPass : public LowerKernelPassBase<LowerKernelPass> {
options.optLevel = 2;
options.kernelUseBarePtrCallConv = false;
options.hostUseBarePtrCallConv = false;
- mlir::gpu::buildLowerToNVVMPassPipeline(pm, options);
+ std::string toolkitPath = "";
+ SmallVector<std::string> linkFiles;
+ buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles);
pm.getDependentDialects(registry);

registry.insert<mlir::arith::ArithDialect, mlir::func::FuncDialect,
@@ -501,12 +611,26 @@ struct LowerKernelPass : public LowerKernelPassBase<LowerKernelPass> {
mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
}

SmallVector<std::string> parseLinkFilesString(StringRef inp) {
if (inp.size() == 0)
return {};
SmallVector<StringRef, 1> split;
SmallVector<std::string> out;
StringRef(inp.data(), inp.size()).split(split, ';');
for (auto &str : split) {
out.push_back(str.str());
}
return out;
}

void runOnOperation() override {
auto context = getOperation()->getContext();

SymbolTableCollection symbolTable;
symbolTable.getSymbolTable(getOperation());

llvm::SmallVector<std::string> linkFilesArray =
parseLinkFilesString(linkFiles.getValue());
getOperation()->walk([&](KernelCallOp op) {
mlir::ArrayAttr operand_layouts =
op.getOperandLayouts()
@@ -542,9 +666,11 @@ struct LowerKernelPass : public LowerKernelPassBase<LowerKernelPass> {
}

// Compiled kernel goes here once ready
- data[0] = (size_t)CompileKernel(symbolTable, op.getLoc(), fn, jit,
- data[1], data[2], data[3], data[4],
- data[5], data[6], data[7]);
+ data[0] = (size_t)CompileKernel(
+ symbolTable, op.getLoc(), fn, jit, data[1], data[2], data[3], data[4],
+ data[5], data[6], data[7], toolkitPath.getValue(), linkFilesArray,
+ indexBitWidth.getValue(), cubinChip.getValue(),
+ cubinFeatures.getValue());

std::string backendinfo((char *)&data, sizeof(void *));

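The commit threads five new knobs — toolkitPath, linkFiles, cubinChip, cubinFeatures, and indexBitWidth — from the pass definition into CompileKernel and into a locally defined copy of the GPU-to-NVVM lowering pipeline. For orientation, here is a minimal sketch of how that local pipeline is driven, mirroring the CompileKernel changes above; it would live inside LowerKernel.cpp (so the headers included there are assumed), and the chip, features, toolkit path, and link files are placeholder values, not part of the commit:

    // Sketch only: `submod` is assumed to be a ModuleOp holding the outlined
    // gpu.module, as in CompileKernel above.
    mlir::PassManager pm(submod.getContext());

    mlir::gpu::GPUToNVVMPipelineOptions options;
    options.indexBitWidth = 64;                  // indexBitWidth pass option
    options.cubinTriple = "nvptx64-nvidia-cuda";
    options.cubinChip = "sm_80";                 // placeholder; cubinChip pass option
    options.cubinFeatures = "+ptx71";            // placeholder; cubinFeatures pass option
    options.cubinFormat = "fatbin";
    options.optLevel = 2;
    options.kernelUseBarePtrCallConv = false;
    options.hostUseBarePtrCallConv = false;

    std::string toolkitPath = "/usr/local/cuda"; // placeholder; toolkitPath pass option
    llvm::SmallVector<std::string> linkFiles = {"libdevice.10.bc"}; // placeholder; linkFiles
    buildLowerToNVVMPassPipeline(pm, options, toolkitPath, linkFiles);

    if (mlir::failed(pm.run(submod)))
      llvm::errs() << "lower-kernel pipeline failed\n";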
37 changes: 36 additions & 1 deletion src/enzyme_ad/jax/Passes/Passes.td
@@ -143,7 +143,42 @@ def LowerKernelPass : Pass<"lower-kernel"> {
/*type=*/"bool",
/*default=*/"true",
/*description=*/"Whether to jit the kernel"
- >
+ >,
Option<
/*C++ variable name=*/"toolkitPath",
/*CLI argument=*/"toolkitPath",
/*type=*/"std::string",
/*default=*/"",
/*description=*/"The location of the cuda toolkit"
>,
Option<
/*C++ variable name=*/"linkFiles",
/*CLI argument=*/"linkFiles",
/*type=*/"std::string",
/*default=*/"",
/*description=*/"Semicolon separated list of files to link"
>,
Option<
/*C++ variable name=*/"cubinChip",
/*CLI argument=*/"cubinChip",
/*type=*/"std::string",
/*default=*/"\"sm_50\"",
/*description=*/"cubinChip"
>,
Option<
/*C++ variable name=*/"cubinFeatures",
/*CLI argument=*/"cubinFeatures",
/*type=*/"std::string",
/*default=*/"\"+ptx60\"",
/*description=*/"cubinChip"
>,
Option<
/*C++ variable name=*/"indexBitWidth",
/*CLI argument=*/"indexBitWidth",
/*type=*/"int",
/*default=*/"64",
/*description=*/"indexBitWidth"
>,
];
}
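
With these definitions, the same options can be supplied as textual pass options. Below is a minimal sketch of scheduling the pass from a pipeline string, assuming the pass is registered with MLIR's pass registry; `module`, the option values, and the error handling are placeholders rather than part of the commit, and multiple link files would be separated by ';' per parseLinkFilesString above:

    // Sketch only: parse the lower-kernel options from a pipeline string and run it.
    #include "llvm/Support/raw_ostream.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Pass/PassRegistry.h"

    mlir::PassManager pm(module.getContext()); // `module` is an assumed mlir::ModuleOp
    if (mlir::failed(mlir::parsePassPipeline(
            "lower-kernel{jit=true toolkitPath=/usr/local/cuda "
            "linkFiles=libdevice.10.bc cubinChip=sm_80 cubinFeatures=+ptx71 "
            "indexBitWidth=64}",
            pm)))
      llvm::errs() << "failed to parse lower-kernel options\n";
    else if (mlir::failed(pm.run(module)))
      llvm::errs() << "lower-kernel failed\n";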
