diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD
index 4379f2e5..6f32d78e 100644
--- a/src/enzyme_ad/jax/BUILD
+++ b/src/enzyme_ad/jax/BUILD
@@ -112,6 +112,7 @@ pybind_extension(
         "@llvm-project//llvm:OrcJIT",
         "@llvm-project//llvm:OrcTargetProcess",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AllPassesAndDialects",
         ":clang_compile",
         ":compile_with_xla",
         "@com_google_absl//absl/status:statusor",
diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc
index eef2c8ba..224ef531 100644
--- a/src/enzyme_ad/jax/compile_with_xla.cc
+++ b/src/enzyme_ad/jax/compile_with_xla.cc
@@ -33,7 +33,8 @@
 // Compile an MHLO module given as a string to LLVM IR using XLA.
 std::unique_ptr<xla::LocalExecutable>
 compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
-                              bool xla_runtime) {
+                              bool xla_runtime,
+                              const std::string &pass_pipeline) {
   // Parse MLIR.
   mlir::MLIRContext context;
   context.loadDialect();
@@ -103,6 +104,10 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
   build_options.mutable_debug_options()->set_xla_cpu_use_xla_runtime(
       xla_runtime);
 
+  build_options.mutable_debug_options()
+      ->mutable_xla_backend_extra_options()
+      ->emplace("xla_cpu_experimental_override_pipeline", pass_pipeline);
+
   if (build_options.device_ordinal() == -1) {
     build_options.set_device_ordinal(local_client->default_device_ordinal());
   }
diff --git a/src/enzyme_ad/jax/compile_with_xla.h b/src/enzyme_ad/jax/compile_with_xla.h
index 106ed918..a2e8f3ac 100644
--- a/src/enzyme_ad/jax/compile_with_xla.h
+++ b/src/enzyme_ad/jax/compile_with_xla.h
@@ -5,4 +5,5 @@
 // Compile an MHLO module given as a string to LLVM IR using XLA.
 std::unique_ptr<xla::LocalExecutable>
 compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
-                              bool xla_runtime);
+                              bool xla_runtime,
+                              const std::string &pass_pipeline);
diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc
index 07a09dac..46ac1686 100644
--- a/src/enzyme_ad/jax/enzyme_call.cc
+++ b/src/enzyme_ad/jax/enzyme_call.cc
@@ -38,6 +38,19 @@
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include "mlir/InitAllPasses.h"
+#include "xla/mlir/backends/cpu/transforms/passes.h"
+#include "xla/mlir/math/transforms/passes.h"
+#include "xla/mlir/memref/transforms/passes.h"
+#include "xla/mlir/runtime/transforms/passes.h"
+#include "xla/mlir_hlo/deallocation/transforms/passes.h"
+#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
+#include "xla/mlir_hlo/lhlo/transforms/passes.h"
+#include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h"
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+
+#include "xla/mlir_hlo/transforms/passes.h"
+
 #include "compile_with_xla.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instructions.h"
@@ -81,7 +94,8 @@ class CpuKernel {
                 llvm::ArrayRef<std::string> out_names,
                 llvm::ArrayRef<llvm::SmallVector<int64_t>> in_shapes,
                 llvm::ArrayRef<std::string> in_names, PyObject *pyargv,
-                ABI mode, Language lang, bool xla_runtime) {
+                ABI mode, Language lang, bool xla_runtime,
+                const std::string &pass_pipeline) {
    auto llvm_ctx = std::make_unique<llvm::LLVMContext>();
 
    std::string input;
@@ -102,8 +116,8 @@
      break;
    case Language::MHLO: {
-      local_executable =
-          compile_mhlo_to_llvm_with_xla(source, stringbuf, xla_runtime);
+      local_executable = compile_mhlo_to_llvm_with_xla(
+          source, stringbuf, xla_runtime, pass_pipeline);
      auto *cpu_executable = static_cast<xla::cpu::CpuExecutable *>(
          local_executable->executable());
      auto &assignment = cpu_executable->buffer_assignment();
@@ -830,11 +844,12 @@ class CpuKernel {
                      llvm::ArrayRef<std::string> out_names,
                      llvm::ArrayRef<llvm::SmallVector<int64_t>> in_shapes,
                      llvm::ArrayRef<std::string> in_names, PyObject *pyargv,
-                     Language lang, bool xla_runtime) {
+                     Language lang, bool xla_runtime,
+                     const std::string &pass_pipeline) {
    auto mode = ABI::Tape;
    auto [mod, llvm_ctx, num_out, tmpBuf] =
        createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names,
-                      pyargv, mode, lang, xla_runtime);
+                      pyargv, mode, lang, xla_runtime, pass_pipeline);
    auto lfn = mod->getFunction("entry");
    auto RI =
        llvm::cast<llvm::ReturnInst>(lfn->getEntryBlock().getTerminator());
@@ -846,12 +861,12 @@ class CpuKernel {
   }
 
   static size_t tempSize(llvm::StringRef source, Language lang,
-                         bool xla_runtime) {
+                         bool xla_runtime, const std::string &pass_pipeline) {
    switch (lang) {
    case Language::MHLO: {
      std::string llvm_ir;
-      auto local_executable =
-          compile_mhlo_to_llvm_with_xla(source, llvm_ir, xla_runtime);
+      auto local_executable = compile_mhlo_to_llvm_with_xla(
+          source, llvm_ir, xla_runtime, pass_pipeline);
      auto *cpu_executable = static_cast<xla::cpu::CpuExecutable *>(
          local_executable->executable());
      auto &assignment = cpu_executable->buffer_assignment();
@@ -868,13 +883,13 @@ class CpuKernel {
                        llvm::ArrayRef<std::string> out_names,
                        llvm::ArrayRef<llvm::SmallVector<int64_t>> in_shapes,
                        llvm::ArrayRef<std::string> in_names, PyObject *pyargv, ABI mode,
-                       Language lang, bool xla_runtime) {
+                       Language lang, bool xla_runtime, const std::string &pass_pipeline) {
    llvm::sys::SmartScopedWriter<true> lock(kernel_mutex);
    size_t identifier = last_identifier++;
 
    auto [mod, llvm_ctx, num_out, tmpBuf] =
        createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names,
-                      pyargv, mode, lang, xla_runtime);
+                      pyargv, mode, lang, xla_runtime, pass_pipeline);
 
    if (!JIT) {
      DL = std::make_unique<llvm::DataLayout>(mod.get());
@@ -986,6 +1001,27 @@ PYBIND11_MODULE(enzyme_call, m) {
   llvm::InitializeAllAsmParsers();
   EnzymeAlwaysInlineDiff.setValue(true);
 
+  mlir::registerAllPasses();
+
+  mlir::mhlo::registerAllMhloPasses();
+  xla::cpu::registerCpuTransformsPasses();
+  mlir::hlo::registerLMHLOTransformsPasses();
+  xla::runtime::registerRuntimeTransformsPasses();
+  xla::registerMathTransformsPasses();
+  xla::registerMemrefTransformsPasses();
+
+  mlir::registerShapePasses();
+  mlir::registerConvertShapeToStandardPass();
+  mlir::registerConvertShapeConstraintsPass();
+  mlir::memref::registerResolveShapedTypeResultDims();
+  mlir::registerLinalgPasses();
+  mlir::registerReconcileUnrealizedCastsPass();
+  mlir::registerConversionPasses();
+  mlir::bufferization::registerBufferizationPasses();
+  mlir::registerAsyncPasses();
+  mlir::arith::registerArithPasses();
+  mlir::memref::registerMemRefPasses();
+
   pybind11::enum_<Language>(m, "Language")
       .value("CPP", Language::CPP)
       .value("LLVM", Language::LLVM)
@@ -1002,8 +1038,8 @@ PYBIND11_MODULE(enzyme_call, m) {
        [](const std::string &source, const std::string &fn,
           const pybind11::list &py_out_shapes,
           const pybind11::list &py_in_shapes, pybind11::object pyargv,
-           ABI mode, Language lang,
-           bool xla_runtime) -> std::tuple<size_t, size_t> {
+           ABI mode, Language lang, bool xla_runtime,
+           const std::string &pass_pipeline) -> std::tuple<size_t, size_t> {
          llvm::SmallVector<llvm::SmallVector<int64_t>> out_shapes;
          out_shapes.reserve(pybind11::len(py_out_shapes));
          llvm::SmallVector<llvm::SmallVector<int64_t>> in_shapes;
@@ -1039,20 +1075,22 @@ PYBIND11_MODULE(enzyme_call, m) {
          }
          return CpuKernel::create(fn, source, out_shapes, out_types,
                                   in_shapes, in_types, pyargv.ptr(), mode, (Language)lang,
-                                   xla_runtime);
+                                   xla_runtime, pass_pipeline);
        });
 
-  m.def(
-      "tmp_size",
-      [](const std::string &source, Language lang, bool xla_runtime) -> size_t {
-        return CpuKernel::tempSize(source, (Language)lang, xla_runtime);
-      });
+  m.def("tmp_size",
+        [](const std::string &source, Language lang, bool xla_runtime,
+           const std::string &pass_pipeline) -> size_t {
+          return CpuKernel::tempSize(source, (Language)lang, xla_runtime,
+                                     pass_pipeline);
+        });
 
   m.def("tape_and_tmp_size",
        [](const std::string &source, const std::string &fn,
           const pybind11::list &py_out_shapes,
           const pybind11::list &py_in_shapes, pybind11::object pyargv,
-           Language lang, bool xla_runtime) -> std::pair<size_t, size_t> {
+           Language lang, bool xla_runtime,
+           const std::string &pass_pipeline) -> std::pair<size_t, size_t> {
          llvm::SmallVector<llvm::SmallVector<int64_t>> out_shapes;
          out_shapes.reserve(pybind11::len(py_out_shapes));
          llvm::SmallVector<llvm::SmallVector<int64_t>> in_shapes;
@@ -1086,9 +1124,9 @@ PYBIND11_MODULE(enzyme_call, m) {
              target.push_back(nested_element.cast<int64_t>());
            }
          }
-          return CpuKernel::tapeAndTempSize(fn, source, out_shapes, out_types,
-                                            in_shapes, in_types, pyargv.ptr(),
-                                            (Language)lang, xla_runtime);
+          return CpuKernel::tapeAndTempSize(
+              fn, source, out_shapes, out_types, in_shapes, in_types,
+              pyargv.ptr(), (Language)lang, xla_runtime, pass_pipeline);
        });
 
   m.def("get_cpu_callback", []() {
@@ -1097,9 +1135,11 @@ PYBIND11_MODULE(enzyme_call, m) {
   });
 
   m.def("compile_mhlo_to_llvm_with_xla",
-        [](const std::string &mhlo_text, bool xla_runtime) {
+        [](const std::string &mhlo_text, bool xla_runtime,
+           const std::string &pass_pipeline) {
          std::string llvm_ir;
-          compile_mhlo_to_llvm_with_xla(mhlo_text, llvm_ir, xla_runtime);
+          compile_mhlo_to_llvm_with_xla(mhlo_text, llvm_ir, xla_runtime,
+                                        pass_pipeline);
          return llvm_ir;
        });
 }
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index 416a1be8..8c54a76a 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -22,7 +22,157 @@
 LANG_LLVM = enzyme_call.Language.LLVM
 LANG_MHLO = enzyme_call.Language.MHLO
 
-xla_runtime = True
+
+def xla_runtime(options):
+    return True
+
+
+def pass_pipeline(options):
+    return """
+    inline{default-pipeline=canonicalize max-iterations=4},
+    expand-hlo-tuples{entry-function=main},
+    func.func(mhlo-flatten-tuple),
+    xla-legalize-abi,
+    func.func(mhlo-test-lower-general-dot),
+    func.func(mhlo-broadcast-propagation),
+    cse,
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    func.func(xla-sparse-custom-call-to-pack),
+    func.func(legalize-sparse-ops{legalize-to-custom-calls=false}),
+    func.func(chlo-legalize-to-hlo{
+      expand-compositions=true legalize-broadcasts=true}),
+    func.func(mhlo-sparse-rewriting),
+    func.func(mhlo-legalize-control-flow),
+    func.func(mhlo-legalize-dot-general-to-dot),
+    hlo-legalize-to-arithmetic,
+    func.func(xla-legalize-library-ops),
+    func.func(mhlo-expand-ops-simplifier),
+    func.func(hlo-canonicalize-scatter),
+    func.func(hlo-canonicalize-dot),
+    func.func(group-reduction-dimensions{prefer-columns-reductions=true}),
+    func.func(hlo-legalize-to-linalg{enable-primitive-ops=false}),
+    func.func(lower-index-cast),
+    convert-to-signless,
+    func.func(shape-simplification),
+    func.func(shape-to-shape-lowering),
+    convert-shape-to-std,
+    func.func(convert-shape-constraints),
+    cse,
+    resolve-shaped-type-result-dims,
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    func.func(linalg-fuse-elementwise-ops),
+    reconcile-unrealized-casts,
+    convert-tensor-to-linalg,
+    func.func(detensorize-scf-ops),
+    func.func(linalg-detensorize{aggressive-mode=true}),
+    eliminate-empty-tensors,
+    func.func(empty-tensor-to-alloc-tensor),
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    func.func(linalg-generalize-named-ops),
+    eliminate-empty-tensors,
+    sparsification-and-bufferization,
+    sparse-storage-specifier-to-llvm,
+    func.func(canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true}),
+    func.func(finalizing-bufferize),
+    func.func(xla-rewrite-realloc-to-alloc),
+    func.func(vectorize-copy),
+    func.func(naive-copy-removal),
+    func.func(convert-linalg-to-loops),
+    cse,
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    buffer-results-to-out-params,
+    func.func(promote-buffers-to-stack{
+      max-alloc-size-in-bytes=1024
+      max-rank-of-allocated-memref=1}),
+    func.func(buffer-deallocation),
+    convert-bufferization-to-memref,
+    func.func(xla-remove-copies-to-out-params),
+    cse,
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    func.func(convert-complex-to-standard),
+    cse,
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    func.func(convert-vector-to-scf{
+      full-unroll=false
+      lower-tensors=false
+      target-rank=1}),
+    func.func(xla-legalize-i1-vector-transfers),
+    func.func(xla-convert-memref-element-cast-to-llvm),
+    async-func-to-async-runtime,
+    xla-rt-export-functions,
+    xla-cpu-to-cpu-runtime,
+    xla-rt-convert-custom-calls,
+    xla-rt-convert-asserts,
+    inline{default-pipeline=canonicalize max-iterations=4},
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    cse,
+    func.func(xla-math-approximation{oplist=all}),
+    func.func(convert-linalg-to-parallel-loops),
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    async-to-async-runtime,
+    xla-rt-move-allocas-to-entry-block,
+    async-runtime-policy-based-ref-counting,
+    func.func(arith-expand{include-bf16=false}),
+    func.func(memref-expand),
+    func.func(expand-strided-metadata),
+    lower-affine,
+    func.func(xla-memref-aligned-allocations{alignment=0}),
+    xla-rt-to-llvm,
+    convert-async-to-llvm,
+    generic-host-to-llvm{enable-avx2=false},
+    reconcile-unrealized-casts,
+    canonicalize{
+      max-iterations=10
+      max-num-rewrites=-1
+      region-simplify=true
+      test-convergence=false
+      top-down=true},
+    cse"""
 
 
 def resource_dir():
@@ -69,6 +219,7 @@ def _enzyme_primal_impl(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.Array]:
     del args_flat, source, out_shapes
     raise RuntimeError("must be JIT'ed")
@@ -81,6 +232,7 @@ def _enzyme_fwd_impl(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.Array]:
     del args_flat, source, out_shapes
     raise RuntimeError("must be JIT'ed")
@@ -93,6 +245,7 @@ def _enzyme_aug_impl(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.Array]:
     del args_flat, source, out_shapes
     raise RuntimeError("must be JIT'ed")
@@ -105,6 +258,7 @@ def _enzyme_shadow_aug_impl(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.Array]:
     del args_flat, source, out_shapes
     raise RuntimeError("must be JIT'ed")
@@ -117,6 +271,7 @@ def _enzyme_rev_impl(
     argv: Sequence[str],
     in_shapes,
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.Array]:
     del args_flat, source, out_shapes
     raise RuntimeError("must be JIT'ed")
@@ -129,6 +284,7 @@ def _enzyme_primal_abstract_eval(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.core.ShapedArray]:
     # TODO: we may attempt some lightweight parsing of source to extract the
     # result types instead.
@@ -142,6 +298,7 @@ def _enzyme_fwd_abstract_eval(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.core.ShapedArray]:
     del source, fn, args_flat
     return tuple(o for o in out_shapes for _ in range(2))
@@ -160,6 +317,7 @@ def _enzyme_aug_abstract_eval(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.core.ShapedArray]:
     in_shapes = args_flat
@@ -181,7 +339,14 @@ def _enzyme_aug_abstract_eval(
     argv = argv + ("-resource-dir", resource_dir()) + cflags()
 
     tapeSize, tmpSize = enzyme_call.tape_and_tmp_size(
-        source, fn, out_shapes, in_shapes, argv, lang, xla_runtime
+        source,
+        fn,
+        out_shapes,
+        in_shapes,
+        argv,
+        lang,
+        xla_runtime(pipeline_options),
+        pass_pipeline(pipeline_options),
     )
     res = tuple(prev_out_shapes) + (
         jax.core.ShapedArray((tapeSize,), (jax.numpy.int8)),
@@ -196,6 +361,7 @@ def _enzyme_shadow_aug_abstract_eval(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.core.ShapedArray]:
     return out_shapes
@@ -207,6 +373,7 @@ def _enzyme_rev_abstract_eval(
     argv: Sequence[str],
     in_shapes,
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[jax.core.ShapedArray]:
     return tuple(
         jax.core.ShapedArray(shape, dejaxify(tyid)) for (shape, tyid) in in_shapes
@@ -233,6 +400,7 @@ def _enzyme_primal_lowering(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[ir.Value]:
     del out_shapes
@@ -262,7 +430,8 @@ def _enzyme_primal_lowering(
         argv,
         enzyme_call.ABI.Primal,
         lang,
-        xla_runtime,
+        xla_runtime(pipeline_options),
+        pass_pipeline(pipeline_options),
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -293,6 +462,7 @@ def _enzyme_fwd_lowering(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[ir.Value]:
     del out_shapes
@@ -323,7 +493,8 @@ def _enzyme_fwd_lowering(
         argv,
         enzyme_call.ABI.Forward,
         lang,
-        xla_runtime,
+        xla_runtime(pipeline_options),
+        pass_pipeline(pipeline_options),
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -353,6 +524,7 @@ def _enzyme_aug_lowering(
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[ir.Value]:
     del out_shapes
@@ -383,7 +555,8 @@ def _enzyme_aug_lowering(
         argv,
         enzyme_call.ABI.Augmented,
         lang,
-        xla_runtime,
+        xla_runtime(pipeline_options),
+        pass_pipeline(pipeline_options),
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -411,6 +584,7 @@ def _enzyme_rev_lowering(
     argv: Sequence[str],
     in_shapes: Sequence[jax.core.ShapedArray],
     lang: enzyme_call.Language,
+    pipeline_options
 ) -> Sequence[ir.Value]:
     del in_shapes
@@ -450,7 +624,8 @@ def _enzyme_rev_lowering(
         argv,
         enzyme_call.ABI.Reverse,
         lang,
-        xla_runtime,
+        xla_runtime(pipeline_options),
+        pass_pipeline(pipeline_options),
     )
     identifier_attr = jax_mlir.dense_int_elements([identifier])
     identifier_op = stablehlo.ConstantOp(identifier_attr)
@@ -497,9 +672,16 @@ def ffi_call(
     fn: str = "f",
     argv: tuple[str] = (),
     lang: int = LANG_CPP,
+    pipeline_options=None
 ):
     return _enzyme_primal_p.bind(
-        *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=lang
+        *args,
+        source=source,
+        fn=fn,
+        argv=argv,
+        out_shapes=out_shapes,
+        lang=lang,
+        pipeline_options=pipeline_options
     )
@@ -509,9 +691,16 @@ def cpp_call(
     source: str,
     fn: str = "f",
     argv: tuple[str] = (),
+    pipeline_options=None
 ):
     return ffi_call(
-        *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=LANG_CPP
+        *args,
+        source=source,
+        fn=fn,
+        argv=argv,
+        out_shapes=out_shapes,
+        lang=LANG_CPP,
+        pipeline_options=pipeline_options
     )
@@ -550,6 +739,7 @@ def make_zero(tan, prim):
         argv=kwargs["argv"],
         out_shapes=kwargs["out_shapes"],
         lang=kwargs["lang"],
+        pipeline_options=kwargs["pipeline_options"]
     )
     res = (shadconv[0::2], shadconv[1::2])
     return res
@@ -640,7 +830,7 @@ def enzyme_vjp(shadow_rets, *prim_args, **kwargs):
 ad.primitive_transposes[_enzyme_shadow_aug_p] = enzyme_vjp
 
 
-def enzyme_jax_ir(argv=()):
+def enzyme_jax_ir(argv=(), pipeline_options=None):
     def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
         @jax.jit
         def wrapped(*args: Any):
@@ -657,6 +847,7 @@ def wrapped(*args: Any):
                 out_shapes=out_shape_flat,
                 argv=argv,
                 lang=LANG_MHLO,
+                pipeline_options=pipeline_options
             )
             return jax.tree_util.tree_unflatten(out_tree, out_flat)
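
For reference, a minimal usage sketch of the new hook (not part of the patch
itself), assuming the extension is rebuilt with this change and that
enzyme_jax_ir is importable from enzyme_ad.jax.primitives, the module patched
above. As wired here, pipeline_options is threaded through end to end but not
yet interpreted: xla_runtime(options) always returns True and
pass_pipeline(options) always returns the fixed pipeline string, which XLA
applies via the xla_cpu_experimental_override_pipeline backend option.

    import jax
    import jax.numpy as jnp

    from enzyme_ad.jax.primitives import enzyme_jax_ir

    # Any pipeline_options value currently selects the same override
    # pipeline; the parameter exists so later changes can vary it per call.
    @enzyme_jax_ir(pipeline_options=None)
    def dot(x, y):
        return jnp.sum(x * y)

    x = jnp.arange(4.0)
    y = jnp.ones(4)
    print(dot(x, y))            # primal, compiled through the MHLO path
    print(jax.grad(dot)(x, y))  # reverse mode, differentiated by Enzyme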