diff --git a/WORKSPACE b/WORKSPACE
index 203f0580b..34d29fadd 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen
 
 pip_install_dependencies()
 
-ENZYME_COMMIT = "f463384c7f3ae601db25acdb9213e03f7a4daaba"
-ENZYME_SHA256 = "45aae6a73009a44d66b422794b394dcc99b45b7648922a2801cbda902ac9bbf7"
+ENZYME_COMMIT = "2c753c97fcb41623e9aca972edfc08202b23e04f"
+ENZYME_SHA256 = "0af3843503e25b973ae82dfa958a843e372708e24b34d195f38b28ced17fcb84"
 
 http_archive(
     name = "enzyme",
diff --git a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td
index 1f1bdac32..ed8736a11 100644
--- a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td
+++ b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td
@@ -6,4 +6,6 @@ class HLOInst<string m> : Inst<m, "mhlo">;
 
 class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnlyIdentityOp<"mhlo", opName_, ptrargs_>;
 
+class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"mhlo", opName_, impl_>;
+
 include "HLODerivatives.td"
diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
index 82f4ef954..815bac0be 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -27,9 +27,118 @@
 
 using namespace mlir;
 using namespace mlir::enzyme;
+using namespace mlir::stablehlo;
 
 namespace {
 #include "src/enzyme_ad/jax/Implementations/StableHLODerivatives.inc"
+
+// From
+// https://github.com/openxla/stablehlo/blob/5d1a9c892500c2e9fecbfedfa66ffe84ff1caf7b/stablehlo/dialect/StablehloOps.cpp#L1498C1-L1532C1
+bool hasSameOperandAndResultTypes(Operation &op) {
+  Type expected;
+  if (op.getNumResults() != 0)
+    expected = op.getResult(0).getType();
+  if (op.getNumOperands() != 0)
+    expected = op.getOperand(0).getType();
+  if (!expected)
+    return false;
+
+  auto typeMatch = [&](Type actual) { return actual == expected; };
+  return llvm::all_of(op.getOperandTypes(), typeMatch) &&
+         llvm::all_of(op.getResultTypes(), typeMatch);
+}
+
+static bool isEligibleForCompactPrint(ReduceOp op) {
+  // Check E1.
+  auto &block = op.getBody().front();
+  if (!hasSingleElement(block.without_terminator()))
+    return false;
+
+  Operation &innerOp = *block.begin();
+
+  // Check E2.
+  if (innerOp.getDialect() != op->getDialect())
+    return false;
+
+  if (innerOp.getNumOperands() != 2 ||
+      !innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
+      !hasSameOperandAndResultTypes(innerOp) ||
+      !innerOp.hasTrait<mlir::hlo::OpTrait::IsCommutative>() ||
+      !innerOp.hasTrait<mlir::OpTrait::ZeroRegions>())
+    return false;
+
+  // Check E3.
+  if (op.getInputs().empty())
+    return false;
+
+  auto elemType =
+      op.getInputs()[0].getType().cast<ShapedType>().getElementType();
+  auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
+  if (innerOp.getOperands()[0].getType() != expectedInnerOpType)
+    return false;
+
+  // Check E4.
+  if (!llvm::equal(block.getArguments(), innerOp.getOperands()))
+    return false;
+
+  // Check E5.
+  auto retOp = dyn_cast<ReturnOp>(block.getTerminator());
+  if (!retOp)
+    return false;
+
+  return llvm::equal(innerOp.getResults(), retOp.getOperands());
+}
+
+template <typename OpTy>
+class AutoDiffReduceFwd
+    : public AutoDiffOpInterface::ExternalModel<AutoDiffReduceFwd<OpTy>, OpTy> {
+public:
+  LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
+                                         MGradientUtils *gutils) const {
+    auto red = cast<OpTy>(orig);
+    if (!isEligibleForCompactPrint(red))
+      return failure();
+
+    Operation &innerOp = red.getBody().front().front();
+    if (!isa<AddOp>(innerOp))
+      return failure();
+
+    Operation *primal = gutils->getNewFromOriginal(orig);
+
+    IRMapping map;
+    for (auto &operand : orig->getOpOperands()) {
+      if (!gutils->isConstantValue(operand.get())) {
+        map.map(operand.get(), gutils->invertPointerM(operand.get(), builder));
+        continue;
+      }
+      if (auto iface =
+              dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
+        if (!iface.requiresShadow()) {
+          // TODO only do if mutable
+          Type retTy = iface.getShadowType();
+          auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
+              builder, operand.get().getLoc());
+          map.map(operand.get(), toret);
+          continue;
+        }
+      }
+      orig->emitWarning() << "Unsupported constant arg to reduce forward "
+                             "handler(opidx="
+                          << operand.getOperandNumber()
+                          << ", op=" << operand.get() << ")\n";
+      return failure();
+    }
+    Operation *shadow = builder.clone(*orig, map);
+
+    Value shadowRes = shadow->getResult(0);
+
+    gutils->setDiffe(orig->getResult(0), shadowRes, builder);
+    gutils->eraseIfUnused(orig);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
@@ -37,5 +146,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
   registry.addExtension(
       +[](MLIRContext *context, stablehlo::StablehloDialect *) {
         registerInterfaces(context);
+        ReduceOp::attachInterface<AutoDiffReduceFwd<ReduceOp>>(*context);
       });
 }
diff --git a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
index 377c94e1d..9eca4bfb9 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
+++ b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
@@ -6,4 +6,7 @@ class HLOInst<string m> : Inst<m, "stablehlo">;
 
 class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnlyIdentityOp<"stablehlo", opName_, ptrargs_>;
 
+class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"stablehlo", opName_, impl_>;
+
+
 include "HLODerivatives.td"
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index 35faa3489..c7287960e 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -487,21 +487,29 @@ def _enzyme_primal_lowering(
     pass_pipeline = pipeline_options.pass_pipeline()
     if lang == LANG_MHLO:
         (in_tree, in_idx_map, mfunc) = source
-        in_idxs = sorted(set(v for _, v in in_idx_map.items()))
-        avals = [ctx.avals_in[i] for i in in_idxs]
+
+        orig_shapes = []
+        seen = {}
+        for i, shape in enumerate(in_shapes):
+            if in_idx_map[i] in seen:
+                continue
+            seen[in_idx_map[i]] = i
+            orig_shapes.append(shape)
+        print("orig_shapes", orig_shapes)
+        print("seen", seen)
+        avals = [ctx.avals_in[seen[i]] for i in seen]
         avals_in = jax.tree_util.tree_unflatten(in_tree, avals)
+        print("avals_in", avals_in)
+        print("in_idx_map", in_idx_map)
+        print("in_shapes", in_shapes)
+
+
         lowered_func = jax.jit(mfunc).lower(*avals_in)
         mhlo = lowered_func.compiler_ir(dialect="stablehlo")
         source = str(mhlo)
         kept = lowered_func.compile()._executable._kept_var_idx
+        print("kept", kept)
         in_args = tuple(arg for (i, arg) in enumerate(in_args) if in_idx_map[i] in kept)
-        orig_shapes = []
-        seen = []
-        for i, shape in enumerate(in_shapes):
-            if in_idx_map[i] in seen:
-                continue
-            seen.append(in_idx_map[i])
-            orig_shapes.append(shape)
         if len(kept) != len(orig_shapes):
             post = ",".join(["enzyme_dup"] * len(kept))
             prev = ",".join(["enzyme_dup"] * len(orig_shapes))
@@ -510,6 +518,7 @@ def _enzyme_primal_lowering(
         in_shapes = [
             shape for (i, shape) in enumerate(in_shapes) if in_idx_map[i] in kept
         ]
+        print("post in_shapes", in_shapes)
 
         if pipeline_options.stablehlo_inject():
             fn = enzyme_call.run_pass_pipeline(source, pass_pipeline)
diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py
index 2ca50b50c..865edad5d 100644
--- a/test/bench_vs_xla.py
+++ b/test/bench_vs_xla.py
@@ -155,8 +155,11 @@ def harness(self, name, in_fn, ins, dins, douts):
 
         self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
 
-        for t, t_p in zip(tangents, tangents_p):
-            self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
+        if len(tangents.shape) == 0:
+            self.assertTrue((jnp.abs(tangents - tangents_p) < 1e-6).all())
+        else:
+            for t, t_p in zip(tangents, tangents_p):
+                self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
 
         print(
             name + " EnzymeMLIR(",
diff --git a/test/llama.py b/test/llama.py
index bd780c09b..a0ffb356b 100644
--- a/test/llama.py
+++ b/test/llama.py
@@ -6,6 +6,7 @@
 import numpy as np
 import timeit
 
+argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
 
 def rmsnorm(x, weight):
     ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5)
@@ -29,6 +30,7 @@ def silu(x):
 
 # Token is token value
 asserts = True
+pipeline = enzyme_jax.NewXLAPipeline(mlirad=True)
 
 def forward(x, config, weights, key_cache, value_cache):
     pos = key_cache.shape[1]
@@ -59,6 +61,8 @@ def forward(x, config, weights, key_cache, value_cache):
 
     wo = weights["wo"]
     if asserts:
+        if wo.shape != (n_layers, dim, dim):
+            print(wo.shape, weights, (n_layers, dim, kv_dim, kv_mul, head_size, hidden_dim, n_kv_heads, vocab_size, n_heads, seq_len, n_layers))
         assert wo.shape == (n_layers, dim, dim)
     rms_ffn_weight = weights["rms_ffn_weight"]
     if asserts:
@@ -282,13 +286,9 @@ def sfn(x, weights, key_cache, value_cache):
 
     func = partial(forward, config)
 
-    @jax.jit
-    def jfunc(x, weights, key_cache, value_cache):
-        return func(x, weights, key_cache, value_cache)
+    jfunc = jax.jit(func)
 
-    @enzyme_jax.enzyme_jax_ir()
-    def efunc(x, weights, key_cache, value_cache):
-        return func(x, weights, key_cache, value_cache)
+    efunc = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=pipeline)(func)
 
     eres = efunc(x, weights, key_cache, value_cache)
     print("Enzyme primal", eres)
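Note (not part of the patch above): a minimal sketch of how the new forward-mode rule for stablehlo.reduce could be exercised, in the style of test/bench_vs_xla.py. The enzyme_jax module name follows test/llama.py; the function sumsq, the inputs, and the use of the default enzyme_jax_ir() options are illustrative assumptions rather than code from this change.

import jax
import jax.numpy as jnp
import enzyme_jax

# jnp.sum lowers to a stablehlo.reduce whose body is a single add,
# which is the case AutoDiffReduceFwd handles in forward mode.
@enzyme_jax.enzyme_jax_ir()
def sumsq(x):
    return jnp.sum(x * x)

x = jnp.arange(8.0)
dx = jnp.ones(8)
primal, tangent = jax.jvp(sumsq, (x,), (dx,))
# d/dt sum((x + t*dx)**2) at t=0 is 2*sum(x*dx), i.e. 56.0 for these inputs.
assert jnp.allclose(tangent, 2 * jnp.sum(x * dx))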