Skip to content

Commit

Permalink
continuing
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 13, 2024
1 parent 21e5d84 commit 8782988
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 30 deletions.
39 changes: 25 additions & 14 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,18 @@ llvm_configure(name = "llvm-project", targets = LLVM_TARGETS)
XLA_COMMIT = "c5163ff997d8be8fd32136e25050fa32c67c989f"
XLA_SHA256 = ""

http_archive(
name = "xla",
sha256 = XLA_SHA256,
strip_prefix = "xla-" + XLA_COMMIT,
urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
patch_args = ["-p1"],
patches = ["//:patches/xla.patch"],
# http_archive(
# name = "xla",
# sha256 = XLA_SHA256,
# strip_prefix = "xla-" + XLA_COMMIT,
# urls = ["https://github.com/wsmoses/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)],
# patch_args = ["-p1"],
# patches = ["//:patches/xla.patch"],
# )

local_repository(
name = "xlae",
path = "./xla"
)

PYRULES_COMMIT = "fe33a4582c37499f3caeb49a07a78fc7948a8949"
Expand All @@ -60,16 +65,22 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

ENZYME_COMMIT = "ed28cb68ccf47b5ff2594421ad62f878be562b03"
ENZYME_SHA256 = ""

http_archive(
# ENZYME_COMMIT = "ed28cb68ccf47b5ff2594421ad62f878be562b03"
# ENZYME_SHA256 = ""
#
# http_archive(
# name = "enzyme",
# sha256 = ENZYME_SHA256,
# strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme",
# urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
# )

local_repository(
name = "enzyme",
sha256 = ENZYME_SHA256,
strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme",
urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
path = "../Enzyme/enzyme",
)


JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248"
JAX_SHA256 = ""

Expand Down
Binary file added src/enzyme_ad/jax/.primitives.py.swp
Binary file not shown.
8 changes: 6 additions & 2 deletions src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,17 @@

#include "compile_with_xla.h"

#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"

// Compile an MHLO module given as a string to LLVM IR using XLA.
std::unique_ptr<xla::LocalExecutable>
compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
bool xla_runtime,
const std::string &pass_pipeline) {
// Parse MLIR.
mlir::MLIRContext context;
mlir::DialectRegistry registry;
mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);
mlir::MLIRContext context(registry);
context.loadDialect<mlir::arith::ArithDialect>();
context.loadDialect<mlir::func::FuncDialect>();
context.loadDialect<mlir::mhlo::MhloDialect>();
Expand Down Expand Up @@ -132,7 +136,7 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
xla_computation.proto(), std::move(module_config_or_error.value()),
local_client->mutable_backend(), executor.value(),
{build_options.device_allocator(), build_options.compile_thread_pool(),
build_options.layout_canonicalization_callback()},
build_options.layout_canonicalization_callback(), &registry},
build_options.run_backend_only());
if (!executable.ok()) {
throw pybind11::value_error(executable.status().ToString());
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "xla/service/cpu/cpu_executable.h"

#include "Enzyme/FunctionUtils.h"
#include "Enzyme/MLIR/Passes/Passes.h"

enum class ABI { Primal, Forward, Augmented, Reverse, Tape };

Expand Down Expand Up @@ -204,6 +205,7 @@ class CpuKernel {
nullptr);
}
}
llvm::errs() << "linkMod: " << *linkMod << "\n";
}
if (xla_runtime) {
ss << " extern \"C\" void " << fn << "(void* exec";
Expand Down Expand Up @@ -831,10 +833,12 @@ class CpuKernel {
#endif
}

llvm::errs() << " str: " << ss.str() << "\n";
auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/ true,
pyargv_strs, llvm_ctx.get(), std::move(linkMod));
if (!mod)
throw pybind11::value_error("failed to compile C++");
llvm::errs() << " postmod: " << *mod << "\n";
return std::make_tuple(std::move(mod), std::move(llvm_ctx), out_off,
tmpBuf);
}
Expand Down Expand Up @@ -1022,6 +1026,7 @@ PYBIND11_MODULE(enzyme_call, m) {
mlir::registerAsyncPasses();
mlir::arith::registerArithPasses();
mlir::memref::registerMemRefPasses();
mlir::registerenzymePasses();

pybind11::enum_<Language>(m, "Language")
.value("CPP", Language::CPP)
Expand Down
50 changes: 36 additions & 14 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def mlir_ad(self):

class OldXLAPipeline:
def xla_runtime(self):
raise False
return False

def pass_pipeline(self):
return ""
Expand Down Expand Up @@ -195,19 +195,20 @@ def __init__(self, passes=None, mlirad=False):
test-convergence=false
top-down=true},
cse"""
assert len(passes) != 0
self.passes = passes
self.mlirad = mlirad

def xla_runtime(self):
raise False
return True

def pass_pipeline(self):
return self.passes

def mlir_ad(self):
return self.mlirad

DefaultPipeline = NewXLAPipeline("", True)
DefaultPipeline = NewXLAPipeline(None, True)

def pass_pipeline(options):
if type(options) == type(""):
Expand Down Expand Up @@ -369,7 +370,7 @@ def _enzyme_aug_abstract_eval(
in_shapes = [absmaketup(a) for a in in_shapes]

if lang == LANG_MHLO:
(in_tree, func) = source
(in_tree, _, func) = source
avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat)
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
Expand Down Expand Up @@ -453,15 +454,31 @@ def _enzyme_primal_lowering(
in_args = (*args_flat,)

if lang == LANG_MHLO:
(in_tree, func) = source
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in)
(in_tree, in_idx_map, func) = source
in_idxs = sorted(set(v for _, v in in_idx_map.items()))
avals = [ctx.avals_in[i] for i in in_idxs]
avals_in = jax.tree_util.tree_unflatten(in_tree, avals)
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
source = str(mhlo)
print(in_idx_map)
print("source", source)
kept = lowered_func.compile()._executable._kept_var_idx
in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]

in_args = tuple(arg for (i, arg) in enumerate(in_args) if in_idx_map[i] in kept)
orig_shapes = []
seen = []
for (i, shape) in enumerate(in_shapes):
if in_idx_map[i] in seen:
continue
seen.append(in_idx_map[i])
orig_shapes.append(shape)
in_shapes = [shape for (i, shape) in enumerate(orig_shapes) if i in kept]
print("in args", in_args)
print("in shapes", in_shapes)

print(pipeline_options)
print(pipeline_options.xla_runtime())
print(pipeline_options.pass_pipeline())
argv = argv + ("-resource-dir", resource_dir()) + cflags()
identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(
source,
Expand All @@ -486,6 +503,7 @@ def _enzyme_primal_lowering(
custom_call = stablehlo.CustomCallOp(
out_types, mlir_args, call_target_name="jaxzyme.primal"
)
print(custom_call)

results = custom_call.results

Expand Down Expand Up @@ -516,7 +534,7 @@ def _enzyme_fwd_lowering(
in_args = (*args_flat,)

if lang == LANG_MHLO:
(in_tree, func) = source
(in_tree, _, func) = source
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2])
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
Expand Down Expand Up @@ -578,7 +596,7 @@ def _enzyme_aug_lowering(
in_args = (*args_flat,)

if lang == LANG_MHLO:
(in_tree, func) = source
(in_tree, _, func) = source
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in)
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
Expand Down Expand Up @@ -644,7 +662,7 @@ def _enzyme_rev_lowering(

kept = None
if lang == LANG_MHLO:
(in_tree, func) = source
(in_tree, _, func) = source
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out)
lowered_func = jax.jit(func).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
Expand Down Expand Up @@ -785,7 +803,10 @@ def make_zero(tan, prim):
for o in kwargs["out_shapes"]:
outshapes2.append(o)
outshapes2.append(o)
shadconv = ffi_call(*args, out_shapes=outshapes2, source=kwargs["source"], fn=kwargs["fn"], argv=kwargs["argv"], lang=kwargs["lang"], pipeline_options=pipeline_options)
(in_tree, in_idx_map, func) = kwargs["source"]
avals = {2*k:v for k, v in in_idx_map.items()} | {2*k+1:v for k, v in in_idx_map.items()}
source = (in_tree, avals, func)
shadconv = ffi_call(*args, out_shapes=outshapes2, source=source, fn=kwargs["fn"], argv=kwargs["argv"], lang=kwargs["lang"], pipeline_options=pipeline_options)
else:
shadconv = _enzyme_fwd_p.bind(
*args,
Expand Down Expand Up @@ -891,13 +912,14 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(*args: Any):
args_flat, in_tree = jax.tree_util.tree_flatten(args)
out_shape = jax.eval_shape(func, *args)
in_idxs = {i:i for i in range(len(args_flat))}
out_shape_flat, out_tree = jax.tree_util.tree_flatten(out_shape)
out_shape_flat = [
jax.core.ShapedArray(o.shape, o.dtype) for o in out_shape_flat
]
out_flat = ffi_call(
*args_flat,
source=(in_tree, func),
source=(in_tree, in_idxs, func),
fn="",
out_shapes=out_shape_flat,
argv=argv,
Expand Down

0 comments on commit 8782988

Please sign in to comment.