diff --git a/WORKSPACE b/WORKSPACE index 99086d31..08c33578 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -33,13 +33,18 @@ llvm_configure(name = "llvm-project", targets = LLVM_TARGETS) XLA_COMMIT = "c5163ff997d8be8fd32136e25050fa32c67c989f" XLA_SHA256 = "" -http_archive( - name = "xla", - sha256 = XLA_SHA256, - strip_prefix = "xla-" + XLA_COMMIT, - urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], - patch_args = ["-p1"], - patches = ["//:patches/xla.patch"], +# http_archive( +# name = "xla", +# sha256 = XLA_SHA256, +# strip_prefix = "xla-" + XLA_COMMIT, +# urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)], +# patch_args = ["-p1"], +# patches = ["//:patches/xla.patch"], +# ) + +local_repository( + name = "xlae", + path = "./xla" ) PYRULES_COMMIT = "fe33a4582c37499f3caeb49a07a78fc7948a8949" @@ -60,16 +65,22 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen pip_install_dependencies() -ENZYME_COMMIT = "ed28cb68ccf47b5ff2594421ad62f878be562b03" -ENZYME_SHA256 = "" - -http_archive( +# ENZYME_COMMIT = "ed28cb68ccf47b5ff2594421ad62f878be562b03" +# ENZYME_SHA256 = "" +# +# http_archive( +# name = "enzyme", +# sha256 = ENZYME_SHA256, +# strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", +# urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], +# ) + +local_repository( name = "enzyme", - sha256 = ENZYME_SHA256, - strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme", - urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], + path = "../Enzyme/enzyme", ) + JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248" JAX_SHA256 = "" diff --git a/src/enzyme_ad/jax/.primitives.py.swp b/src/enzyme_ad/jax/.primitives.py.swp new file mode 100644 index 00000000..59871542 Binary files /dev/null and b/src/enzyme_ad/jax/.primitives.py.swp differ diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index 224ef531..0052863d 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -30,13 +30,17 @@ #include "compile_with_xla.h" +#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h" + // Compile an MHLO module given as a string to LLVM IR using XLA. std::unique_ptr compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, bool xla_runtime, const std::string &pass_pipeline) { // Parse MLIR. - mlir::MLIRContext context; + mlir::DialectRegistry registry; + mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry); + mlir::MLIRContext context(registry); context.loadDialect(); context.loadDialect(); context.loadDialect(); @@ -132,7 +136,7 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, xla_computation.proto(), std::move(module_config_or_error.value()), local_client->mutable_backend(), executor.value(), {build_options.device_allocator(), build_options.compile_thread_pool(), - build_options.layout_canonicalization_callback()}, + build_options.layout_canonicalization_callback(), ®istry}, build_options.run_backend_only()); if (!executable.ok()) { throw pybind11::value_error(executable.status().ToString()); diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc index 13eb423c..64a0cdfe 100644 --- a/src/enzyme_ad/jax/enzyme_call.cc +++ b/src/enzyme_ad/jax/enzyme_call.cc @@ -57,6 +57,7 @@ #include "xla/service/cpu/cpu_executable.h" #include "Enzyme/FunctionUtils.h" +#include "Enzyme/MLIR/Passes/Passes.h" enum class ABI { Primal, Forward, Augmented, Reverse, Tape }; @@ -204,6 +205,7 @@ class CpuKernel { nullptr); } } + llvm::errs() << "linkMod: " << *linkMod << "\n"; } if (xla_runtime) { ss << " extern \"C\" void " << fn << "(void* exec"; @@ -831,10 +833,12 @@ class CpuKernel { #endif } + llvm::errs() << " str: " << ss.str() << "\n"; auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/ true, pyargv_strs, llvm_ctx.get(), std::move(linkMod)); if (!mod) throw pybind11::value_error("failed to compile C++"); + llvm::errs() << " postmod: " << *mod << "\n"; return std::make_tuple(std::move(mod), std::move(llvm_ctx), out_off, tmpBuf); } @@ -1022,6 +1026,7 @@ PYBIND11_MODULE(enzyme_call, m) { mlir::registerAsyncPasses(); mlir::arith::registerArithPasses(); mlir::memref::registerMemRefPasses(); + mlir::registerenzymePasses(); pybind11::enum_(m, "Language") .value("CPP", Language::CPP) diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 1a335522..66f1ab75 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -39,7 +39,7 @@ def mlir_ad(self): class OldXLAPipeline: def xla_runtime(self): - raise False + return False def pass_pipeline(self): return "" @@ -195,11 +195,12 @@ def __init__(self, passes=None, mlirad=False): test-convergence=false top-down=true}, cse""" + assert len(passes) != 0 self.passes = passes self.mlirad = mlirad def xla_runtime(self): - raise False + return True def pass_pipeline(self): return self.passes @@ -207,7 +208,7 @@ def pass_pipeline(self): def mlir_ad(self): return self.mlirad -DefaultPipeline = NewXLAPipeline("", True) +DefaultPipeline = NewXLAPipeline(None, True) def pass_pipeline(options): if type(options) == type(""): @@ -369,7 +370,7 @@ def _enzyme_aug_abstract_eval( in_shapes = [absmaketup(a) for a in in_shapes] if lang == LANG_MHLO: - (in_tree, func) = source + (in_tree, _, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") @@ -453,15 +454,31 @@ def _enzyme_primal_lowering( in_args = (*args_flat,) if lang == LANG_MHLO: - (in_tree, func) = source - avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) + (in_tree, in_idx_map, func) = source + in_idxs = sorted(set(v for _, v in in_idx_map.items())) + avals = [ctx.avals_in[i] for i in in_idxs] + avals_in = jax.tree_util.tree_unflatten(in_tree, avals) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") source = str(mhlo) + print(in_idx_map) + print("source", source) kept = lowered_func.compile()._executable._kept_var_idx - in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) - in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] - + in_args = tuple(arg for (i, arg) in enumerate(in_args) if in_idx_map[i] in kept) + orig_shapes = [] + seen = [] + for (i, shape) in enumerate(in_shapes): + if in_idx_map[i] in seen: + continue + seen.append(in_idx_map[i]) + orig_shapes.append(shape) + in_shapes = [shape for (i, shape) in enumerate(orig_shapes) if i in kept] + print("in args", in_args) + print("in shapes", in_shapes) + + print(pipeline_options) + print(pipeline_options.xla_runtime()) + print(pipeline_options.pass_pipeline()) argv = argv + ("-resource-dir", resource_dir()) + cflags() identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( source, @@ -486,6 +503,7 @@ def _enzyme_primal_lowering( custom_call = stablehlo.CustomCallOp( out_types, mlir_args, call_target_name="jaxzyme.primal" ) + print(custom_call) results = custom_call.results @@ -516,7 +534,7 @@ def _enzyme_fwd_lowering( in_args = (*args_flat,) if lang == LANG_MHLO: - (in_tree, func) = source + (in_tree, _, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2]) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") @@ -578,7 +596,7 @@ def _enzyme_aug_lowering( in_args = (*args_flat,) if lang == LANG_MHLO: - (in_tree, func) = source + (in_tree, _, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") @@ -644,7 +662,7 @@ def _enzyme_rev_lowering( kept = None if lang == LANG_MHLO: - (in_tree, func) = source + (in_tree, _, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") @@ -785,7 +803,10 @@ def make_zero(tan, prim): for o in kwargs["out_shapes"]: outshapes2.append(o) outshapes2.append(o) - shadconv = ffi_call(*args, out_shapes=outshapes2, source=kwargs["source"], fn=kwargs["fn"], argv=kwargs["argv"], lang=kwargs["lang"], pipeline_options=pipeline_options) + (in_tree, in_idx_map, func) = kwargs["source"] + avals = {2*k:v for k, v in in_idx_map.items()} | {2*k+1:v for k, v in in_idx_map.items()} + source = (in_tree, avals, func) + shadconv = ffi_call(*args, out_shapes=outshapes2, source=source, fn=kwargs["fn"], argv=kwargs["argv"], lang=kwargs["lang"], pipeline_options=pipeline_options) else: shadconv = _enzyme_fwd_p.bind( *args, @@ -891,13 +912,14 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: def wrapped(*args: Any): args_flat, in_tree = jax.tree_util.tree_flatten(args) out_shape = jax.eval_shape(func, *args) + in_idxs = {i:i for i in range(len(args_flat))} out_shape_flat, out_tree = jax.tree_util.tree_flatten(out_shape) out_shape_flat = [ jax.core.ShapedArray(o.shape, o.dtype) for o in out_shape_flat ] out_flat = ffi_call( *args_flat, - source=(in_tree, func), + source=(in_tree, in_idxs, func), fn="", out_shapes=out_shape_flat, argv=argv,