Custom MLIR lowering pipeline (#30)
* Custom MLIR lowering pipeline

* pipeline fix
wsmoses authored Jan 21, 2024
1 parent f95ca8c commit 7c7441a
Showing 5 changed files with 273 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/BUILD
@@ -112,6 +112,7 @@ pybind_extension(
         "@llvm-project//llvm:OrcJIT",
         "@llvm-project//llvm:OrcTargetProcess",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AllPassesAndDialects",
         ":clang_compile",
         ":compile_with_xla",
         "@com_google_absl//absl/status:statusor",
7 changes: 6 additions & 1 deletion src/enzyme_ad/jax/compile_with_xla.cc
@@ -33,7 +33,8 @@
 // Compile an MHLO module given as a string to LLVM IR using XLA.
 std::unique_ptr<xla::LocalExecutable>
 compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
-                              bool xla_runtime) {
+                              bool xla_runtime,
+                              const std::string &pass_pipeline) {
   // Parse MLIR.
   mlir::MLIRContext context;
   context.loadDialect<mlir::arith::ArithDialect>();
@@ -103,6 +104,10 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
   build_options.mutable_debug_options()->set_xla_cpu_use_xla_runtime(
       xla_runtime);
 
+  build_options.mutable_debug_options()
+      ->mutable_xla_backend_extra_options()
+      ->emplace("xla_cpu_experimental_override_pipeline", pass_pipeline);
+
   if (build_options.device_ordinal() == -1) {
     build_options.set_device_ordinal(local_client->default_device_ordinal());
   }
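
The new trailing parameter reaches XLA through the xla_cpu_experimental_override_pipeline backend option, so a non-empty string replaces XLA's default CPU lowering pipeline with a caller-supplied textual MLIR pass pipeline. A minimal sketch of driving this from Python, assuming the extension builds as the enzyme_call module; the MHLO text and pipeline string below are illustrative only, and a real override has to lower MHLO all the way to LLVM for the produced executable to be usable:

    # Hypothetical usage sketch; the pipeline names standard upstream MLIR
    # passes, not XLA's actual CPU pipeline.
    import enzyme_call

    mhlo = """
    func.func @main(%a: tensor<2xf32>) -> tensor<2xf32> {
      %0 = mhlo.add %a, %a : tensor<2xf32>
      return %0 : tensor<2xf32>
    }
    """

    # MLIR textual pipeline syntax: passes nested under their anchor op.
    pipeline = "builtin.module(func.func(canonicalize,cse))"

    llvm_ir = enzyme_call.compile_mhlo_to_llvm_with_xla(mhlo, True, pipeline)
    print(llvm_ir)
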
3 changes: 2 additions & 1 deletion src/enzyme_ad/jax/compile_with_xla.h
@@ -5,4 +5,5 @@
 // Compile an MHLO module given as a string to LLVM IR using XLA.
 std::unique_ptr<xla::LocalExecutable>
 compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
-                              bool xla_runtime);
+                              bool xla_runtime,
+                              const std::string &pass_pipeline);
88 changes: 64 additions & 24 deletions src/enzyme_ad/jax/enzyme_call.cc
@@ -38,6 +38,19 @@
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "mlir/InitAllPasses.h"
+#include "xla/mlir/backends/cpu/transforms/passes.h"
+#include "xla/mlir/math/transforms/passes.h"
+#include "xla/mlir/memref/transforms/passes.h"
+#include "xla/mlir/runtime/transforms/passes.h"
+#include "xla/mlir_hlo/deallocation/transforms/passes.h"
+#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
+#include "xla/mlir_hlo/lhlo/transforms/passes.h"
+#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h"
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+
+#include "xla/mlir_hlo/transforms/passes.h"
+
 #include "compile_with_xla.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instructions.h"
@@ -81,7 +94,8 @@ class CpuKernel {
                 llvm::ArrayRef<std::string> out_names,
                 llvm::ArrayRef<llvm::SmallVector<int64_t>> in_shapes,
                 llvm::ArrayRef<std::string> in_names, PyObject *pyargv,
-                ABI mode, Language lang, bool xla_runtime) {
+                ABI mode, Language lang, bool xla_runtime,
+                const std::string &pass_pipeline) {
     auto llvm_ctx = std::make_unique<llvm::LLVMContext>();
 
     std::string input;
@@ -102,8 +116,8 @@ class CpuKernel {
       break;
 
     case Language::MHLO: {
-      local_executable =
-          compile_mhlo_to_llvm_with_xla(source, stringbuf, xla_runtime);
+      local_executable = compile_mhlo_to_llvm_with_xla(
+          source, stringbuf, xla_runtime, pass_pipeline);
       auto *cpu_executable = static_cast<xla::cpu::CpuExecutable *>(
           local_executable->executable());
       auto &assignment = cpu_executable->buffer_assignment();
@@ -830,11 +844,12 @@ class CpuKernel {
                 llvm::ArrayRef<std::string> out_names,
                 llvm::ArrayRef<llvm::SmallVector<int64_t>> in_shapes,
                 llvm::ArrayRef<std::string> in_names, PyObject *pyargv,
-                Language lang, bool xla_runtime) {
+                Language lang, bool xla_runtime,
+                const std::string &pass_pipeline) {
     auto mode = ABI::Tape;
     auto [mod, llvm_ctx, num_out, tmpBuf] =
         createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names,
-                      pyargv, mode, lang, xla_runtime);
+                      pyargv, mode, lang, xla_runtime, pass_pipeline);
     auto lfn = mod->getFunction("entry");
     auto RI =
         llvm::cast<llvm::ReturnInst>(lfn->getEntryBlock().getTerminator());
@@ -846,12 +861,12 @@ class CpuKernel {
   }
 
   static size_t tempSize(llvm::StringRef source, Language lang,
-                         bool xla_runtime) {
+                         bool xla_runtime, const std::string &pass_pipeline) {
     switch (lang) {
     case Language::MHLO: {
       std::string llvm_ir;
-      auto local_executable =
-          compile_mhlo_to_llvm_with_xla(source, llvm_ir, xla_runtime);
+      auto local_executable = compile_mhlo_to_llvm_with_xla(
+          source, llvm_ir, xla_runtime, pass_pipeline);
       auto *cpu_executable = static_cast<xla::cpu::CpuExecutable *>(
           local_executable->executable());
       auto &assignment = cpu_executable->buffer_assignment();
@@ -868,13 +883,13 @@ class CpuKernel {
                 llvm::ArrayRef<std::string> out_names,
                 llvm::ArrayRef<llvm::SmallVector<int64_t>> in_shapes,
                 llvm::ArrayRef<std::string> in_names, PyObject *pyargv, ABI mode,
-                Language lang, bool xla_runtime) {
+                Language lang, bool xla_runtime, const std::string &pass_pipeline) {
     llvm::sys::SmartScopedWriter<true> lock(kernel_mutex);
     size_t identifier = last_identifier++;
 
     auto [mod, llvm_ctx, num_out, tmpBuf] =
         createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names,
-                      pyargv, mode, lang, xla_runtime);
+                      pyargv, mode, lang, xla_runtime, pass_pipeline);
 
     if (!JIT) {
       DL = std::make_unique<llvm::DataLayout>(mod.get());
@@ -986,6 +1001,27 @@ PYBIND11_MODULE(enzyme_call, m) {
   llvm::InitializeAllAsmParsers();
   EnzymeAlwaysInlineDiff.setValue(true);
 
+  mlir::registerAllPasses();
+
+  mlir::mhlo::registerAllMhloPasses();
+  xla::cpu::registerCpuTransformsPasses();
+  mlir::hlo::registerLMHLOTransformsPasses();
+  xla::runtime::registerRuntimeTransformsPasses();
+  xla::registerMathTransformsPasses();
+  xla::registerMemrefTransformsPasses();
+
+  mlir::registerShapePasses();
+  mlir::registerConvertShapeToStandardPass();
+  mlir::registerConvertShapeConstraintsPass();
+  mlir::memref::registerResolveShapedTypeResultDims();
+  mlir::registerLinalgPasses();
+  mlir::registerReconcileUnrealizedCastsPass();
+  mlir::registerConversionPasses();
+  mlir::bufferization::registerBufferizationPasses();
+  mlir::registerAsyncPasses();
+  mlir::arith::registerArithPasses();
+  mlir::memref::registerMemRefPasses();
+
   pybind11::enum_<Language>(m, "Language")
       .value("CPP", Language::CPP)
       .value("LLVM", Language::LLVM)
@@ -1002,8 +1038,8 @@ PYBIND11_MODULE(enzyme_call, m) {
       [](const std::string &source, const std::string &fn,
          const pybind11::list &py_out_shapes,
          const pybind11::list &py_in_shapes, pybind11::object pyargv,
-         ABI mode, Language lang,
-         bool xla_runtime) -> std::tuple<size_t, size_t> {
+         ABI mode, Language lang, bool xla_runtime,
+         const std::string &pass_pipeline) -> std::tuple<size_t, size_t> {
         llvm::SmallVector<llvm::SmallVector<int64_t>> out_shapes;
         out_shapes.reserve(pybind11::len(py_out_shapes));
         llvm::SmallVector<llvm::SmallVector<int64_t>> in_shapes;
@@ -1039,20 +1075,22 @@ PYBIND11_MODULE(enzyme_call, m) {
         }
         return CpuKernel::create(fn, source, out_shapes, out_types, in_shapes,
                                  in_types, pyargv.ptr(), mode, (Language)lang,
-                                 xla_runtime);
+                                 xla_runtime, pass_pipeline);
       });
 
-  m.def(
-      "tmp_size",
-      [](const std::string &source, Language lang, bool xla_runtime) -> size_t {
-        return CpuKernel::tempSize(source, (Language)lang, xla_runtime);
-      });
+  m.def("tmp_size",
+        [](const std::string &source, Language lang, bool xla_runtime,
+           const std::string &pass_pipeline) -> size_t {
+          return CpuKernel::tempSize(source, (Language)lang, xla_runtime,
+                                     pass_pipeline);
+        });
 
   m.def("tape_and_tmp_size",
         [](const std::string &source, const std::string &fn,
            const pybind11::list &py_out_shapes,
            const pybind11::list &py_in_shapes, pybind11::object pyargv,
-           Language lang, bool xla_runtime) -> std::pair<size_t, size_t> {
+           Language lang, bool xla_runtime,
+           const std::string &pass_pipeline) -> std::pair<size_t, size_t> {
           llvm::SmallVector<llvm::SmallVector<int64_t>> out_shapes;
           out_shapes.reserve(pybind11::len(py_out_shapes));
           llvm::SmallVector<llvm::SmallVector<int64_t>> in_shapes;
@@ -1086,9 +1124,9 @@ PYBIND11_MODULE(enzyme_call, m) {
               target.push_back(nested_element.cast<int64_t>());
             }
           }
-          return CpuKernel::tapeAndTempSize(fn, source, out_shapes, out_types,
-                                            in_shapes, in_types, pyargv.ptr(),
-                                            (Language)lang, xla_runtime);
+          return CpuKernel::tapeAndTempSize(
+              fn, source, out_shapes, out_types, in_shapes, in_types,
+              pyargv.ptr(), (Language)lang, xla_runtime, pass_pipeline);
         });
 
   m.def("get_cpu_callback", []() {
@@ -1097,9 +1135,11 @@ PYBIND11_MODULE(enzyme_call, m) {
   });
 
   m.def("compile_mhlo_to_llvm_with_xla",
-        [](const std::string &mhlo_text, bool xla_runtime) {
+        [](const std::string &mhlo_text, bool xla_runtime,
+           const std::string &pass_pipeline) {
           std::string llvm_ir;
-          compile_mhlo_to_llvm_with_xla(mhlo_text, llvm_ir, xla_runtime);
+          compile_mhlo_to_llvm_with_xla(mhlo_text, llvm_ir, xla_runtime,
+                                        pass_pipeline);
           return llvm_ir;
         });
 }
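
Every binding shown above now threads the pipeline string through as a trailing argument. A short usage sketch for the simplest of them, tmp_size, reusing the hypothetical mhlo and pipeline values from the earlier sketch:

    # tmp_size(source, lang, xla_runtime, pass_pipeline) -> int
    # (size of the temporary buffer the compiled kernel needs)
    tmp_bytes = enzyme_call.tmp_size(mhlo, enzyme_call.Language.MHLO,
                                     True, pipeline)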