fixup
wsmoses committed Feb 16, 2024
1 parent 7b7aed6 commit 5a246c4
Showing 7 changed files with 146 additions and 19 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
@@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

ENZYME_COMMIT = "f463384c7f3ae601db25acdb9213e03f7a4daaba"
ENZYME_SHA256 = "45aae6a73009a44d66b422794b394dcc99b45b7648922a2801cbda902ac9bbf7"
ENZYME_COMMIT = "2c753c97fcb41623e9aca972edfc08202b23e04f"
ENZYME_SHA256 = "0af3843503e25b973ae82dfa958a843e372708e24b34d195f38b28ced17fcb84"

http_archive(
name = "enzyme",
2 changes: 2 additions & 0 deletions src/enzyme_ad/jax/Implementations/MHLODerivatives.td
@@ -6,4 +6,6 @@ class HLOInst<string m> : Inst<m, "mhlo">;

class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnlyIdentityOp<"mhlo", opName_, ptrargs_>;

class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"mhlo", opName_, impl_>;

include "HLODerivatives.td"
110 changes: 110 additions & 0 deletions src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -27,15 +27,125 @@

using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::stablehlo;

namespace {
#include "src/enzyme_ad/jax/Implementations/StableHLODerivatives.inc"

// From
// https://github.com/openxla/stablehlo/blob/5d1a9c892500c2e9fecbfedfa66ffe84ff1caf7b/stablehlo/dialect/StablehloOps.cpp#L1498C1-L1532C1
bool hasSameOperandAndResultTypes(Operation &op) {
Type expected;
if (op.getNumResults() != 0)
expected = op.getResult(0).getType();
if (op.getNumOperands() != 0)
expected = op.getOperand(0).getType();
if (!expected)
return false;

auto typeMatch = [&](Type actual) { return actual == expected; };
return llvm::all_of(op.getOperandTypes(), typeMatch) &&
llvm::all_of(op.getResultTypes(), typeMatch);
}

static bool isEligibleForCompactPrint(ReduceOp op) {
// Check E1.
auto &block = op.getBody().front();
if (!hasSingleElement(block.without_terminator()))
return false;

Operation &innerOp = *block.begin();

// Check E2.
if (innerOp.getDialect() != op->getDialect())
return false;

if (innerOp.getNumOperands() != 2 ||
!innerOp.hasTrait<mlir::OpTrait::OneResult>() ||
!hasSameOperandAndResultTypes(innerOp) ||
!innerOp.hasTrait<mlir::hlo::OpTrait::IsCommutative>() ||
!innerOp.hasTrait<mlir::OpTrait::ZeroRegions>())
return false;

// Check E3.
if (op.getInputs().empty())
return false;

auto elemType =
op.getInputs()[0].getType().cast<ShapedType>().getElementType();
auto expectedInnerOpType = RankedTensorType::get(/*shape=*/{}, elemType);
if (innerOp.getOperands()[0].getType() != expectedInnerOpType)
return false;

// Check E4.
if (!llvm::equal(block.getArguments(), innerOp.getOperands()))
return false;

// Check E5.
auto retOp = dyn_cast<ReturnOp>(block.getTerminator());
if (!retOp)
return false;

return llvm::equal(innerOp.getResults(), retOp.getOperands());
}

template <typename OpTy>
class AutoDiffReduceFwd
: public AutoDiffOpInterface::ExternalModel<AutoDiffReduceFwd<OpTy>, OpTy> {
public:
LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
MGradientUtils *gutils) const {
auto red = cast<OpTy>(orig);
if (!isEligibleForCompactPrint(red))
return failure();

Operation &innerOp = red.getBody().front().front();
if (!isa<AddOp>(innerOp))
return failure();

Operation *primal = gutils->getNewFromOriginal(orig);

IRMapping map;
for (auto &operand : orig->getOpOperands()) {
if (!gutils->isConstantValue(operand.get())) {
map.map(operand.get(), gutils->invertPointerM(operand.get(), builder));
continue;
}
if (auto iface =
dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
if (!iface.requiresShadow()) {
// TODO only do if mutable
Type retTy = iface.getShadowType();
auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
builder, operand.get().getLoc());
map.map(operand.get(), toret);
continue;
}
}
orig->emitWarning() << "Unsupported constant arg to reduce forward "
"handler(opidx="
<< operand.getOperandNumber()
<< ", op=" << operand.get() << ")\n";
return failure();
}
Operation *shadow = builder.clone(*orig, map);

Value shadowRes = shadow->getResult(0);

gutils->setDiffe(orig->getResult(0), shadowRes, builder);
gutils->eraseIfUnused(orig);

return success();
}
};

} // namespace

void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(
+[](MLIRContext *context, stablehlo::StablehloDialect *) {
registerInterfaces(context);
ReduceOp::attachInterface<AutoDiffReduceFwd<ReduceOp>>(*context);
});
}
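
For context, a minimal JAX-side sketch (not part of this commit) of the pattern the new interface targets: a stablehlo.reduce whose body is a single stablehlo.add over rank-0 tensors, which is what isEligibleForCompactPrint and the AddOp check above accept, and whose forward-mode tangent is the same reduction applied to the tangents. The helper f and the shapes are illustrative assumptions; jax.jvp is standard JAX, and compiler_ir(dialect="stablehlo") is the same call used in primitives.py below.

import jax
import jax.numpy as jnp

def f(x):
    # jnp.sum over an axis lowers to a stablehlo.reduce with a single
    # stablehlo.add in its body, operating on rank-0 tensors.
    return jnp.sum(x, axis=0)

# Print the StableHLO to see the reduce region the eligibility check inspects.
print(jax.jit(f).lower(jnp.ones((3, 4), jnp.float32)).compiler_ir(dialect="stablehlo"))

# Because reduce-with-add is linear, the tangent of the sum is the sum of the
# tangents; the handler realizes this by cloning the reduce onto shadow operands.
x = jnp.ones((3, 4), jnp.float32)
dx = jnp.full((3, 4), 0.5, dtype=jnp.float32)
primal, tangent = jax.jvp(f, (x,), (dx,))
print(primal, tangent)  # tangent == jnp.sum(dx, axis=0)

Wrapping f with enzyme_jax.enzyme_jax_ir (as test/bench_vs_xla.py does) and taking the same jvp is what exercises the new AutoDiffReduceFwd rule.
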
3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
@@ -6,4 +6,7 @@ class HLOInst<string m> : Inst<m, "stablehlo">;

class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnlyIdentityOp<"stablehlo", opName_, ptrargs_>;

class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"stablehlo", opName_, impl_>;


include "HLODerivatives.td"
27 changes: 18 additions & 9 deletions src/enzyme_ad/jax/primitives.py
@@ -487,21 +487,29 @@ def _enzyme_primal_lowering(
    pass_pipeline = pipeline_options.pass_pipeline()
    if lang == LANG_MHLO:
        (in_tree, in_idx_map, mfunc) = source
        in_idxs = sorted(set(v for _, v in in_idx_map.items()))
        avals = [ctx.avals_in[i] for i in in_idxs]

        orig_shapes = []
        seen = {}
        for i, shape in enumerate(in_shapes):
            if in_idx_map[i] in seen:
                continue
            seen[in_idx_map[i]] = i
            orig_shapes.append(shape)
        print("orig_shapes", orig_shapes)
        print("seen", seen)
        avals = [ctx.avals_in[seen[i]] for i in seen]
        avals_in = jax.tree_util.tree_unflatten(in_tree, avals)
        print("avals_in", avals_in)
        print("in_idx_map", in_idx_map)
        print("in_shapes", in_shapes)


        lowered_func = jax.jit(mfunc).lower(*avals_in)
        mhlo = lowered_func.compiler_ir(dialect="stablehlo")
        source = str(mhlo)
        kept = lowered_func.compile()._executable._kept_var_idx
        print("kept", kept)
        in_args = tuple(arg for (i, arg) in enumerate(in_args) if in_idx_map[i] in kept)
        orig_shapes = []
        seen = []
        for i, shape in enumerate(in_shapes):
            if in_idx_map[i] in seen:
                continue
            seen.append(in_idx_map[i])
            orig_shapes.append(shape)
        if len(kept) != len(orig_shapes):
            post = ",".join(["enzyme_dup"] * len(kept))
            prev = ",".join(["enzyme_dup"] * len(orig_shapes))
@@ -510,6 +518,7 @@ def _enzyme_primal_lowering(
        in_shapes = [
            shape for (i, shape) in enumerate(in_shapes) if in_idx_map[i] in kept
        ]
        print("post in_shapes", in_shapes)

if pipeline_options.stablehlo_inject():
fn = enzyme_call.run_pass_pipeline(source, pass_pipeline)
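
A standalone sketch of the de-duplication added above, with hypothetical data (variable names mirror the diff; nothing here is the real lowering context): in_idx_map maps each flat input position to an underlying argument index, and only the first occurrence of each index contributes a shape/aval.

in_shapes = ["tensor<4xf32>", "tensor<4xf32>", "tensor<8xf32>"]  # hypothetical shapes
in_idx_map = {0: 0, 1: 0, 2: 1}  # positions 0 and 1 alias the same argument

seen = {}          # argument index -> first flat position that provides it
orig_shapes = []
for i, shape in enumerate(in_shapes):
    if in_idx_map[i] in seen:
        continue
    seen[in_idx_map[i]] = i
    orig_shapes.append(shape)

assert seen == {0: 0, 1: 2}
assert orig_shapes == ["tensor<4xf32>", "tensor<8xf32>"]
# Later, arguments whose index is not in the compiled executable's kept set
# (lowered_func.compile()._executable._kept_var_idx) are dropped the same way.
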
7 changes: 5 additions & 2 deletions test/bench_vs_xla.py
@@ -155,8 +155,11 @@ def harness(self, name, in_fn, ins, dins, douts):

        self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())

        for t, t_p in zip(tangents, tangents_p):
            self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
        if len(tangents.shape) == 0:
            self.assertTrue((jnp.abs(tangents - tangents_p) < 1e-6).all())
        else:
            for t, t_p in zip(tangents, tangents_p):
                self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())

        print(
            name + " EnzymeMLIR(",
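
The new branch above handles the case where the tangent is a single 0-d array (e.g. for a function with a scalar output): 0-d arrays cannot be iterated, so the old zip-based loop would raise TypeError. A small illustration with arbitrary values:

import jax.numpy as jnp

t = jnp.asarray(3.0)    # 0-d tangent, as produced for a scalar output
t_p = jnp.asarray(3.0)

try:
    for a, b in zip(t, t_p):   # old code path
        pass
except TypeError as e:
    print("0-d arrays are not iterable:", e)

# The scalar branch compares the values directly instead.
assert len(t.shape) == 0
assert (jnp.abs(t - t_p) < 1e-6).all()
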
12 changes: 6 additions & 6 deletions test/llama.py
@@ -6,6 +6,7 @@
import numpy as np
import timeit

argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")

def rmsnorm(x, weight):
    ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5)
@@ -29,6 +30,7 @@ def silu(x):
# Token is token value
asserts = True

pipeline = enzyme_jax.NewXLAPipeline(mlirad=True)

def forward(x, config, weights, key_cache, value_cache):
    pos = key_cache.shape[1]
@@ -59,6 +61,8 @@ def forward(x, config, weights, key_cache, value_cache):

    wo = weights["wo"]
    if asserts:
        if wo.shape != (n_layers, dim, dim):
            print(wo.shape, weights, (n_layers, dim, kv_dim, kv_mul, head_size, hidden_dim, n_kv_heads, vocab_size, n_heads, seq_len, n_layers))
        assert wo.shape == (n_layers, dim, dim)
    rms_ffn_weight = weights["rms_ffn_weight"]
    if asserts:
@@ -282,13 +286,9 @@ def sfn(x, weights, key_cache, value_cache):

func = partial(forward, config)

@jax.jit
def jfunc(x, weights, key_cache, value_cache):
    return func(x, weights, key_cache, value_cache)
jfunc = jax.jit(func)

@enzyme_jax.enzyme_jax_ir()
def efunc(x, weights, key_cache, value_cache):
    return func(x, weights, key_cache, value_cache)
efunc = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=pipeline)(func)

eres = efunc(x, weights, key_cache, value_cache)
print("Enzyme primal", eres)
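
For reference, a hedged sketch of the wrapping pattern the file now uses, reduced to a toy function: jax.jit(func) and enzyme_jax.enzyme_jax_ir(argv=..., pipeline_options=...)(func) replace the earlier decorator form. The toy function, the shapes, and the import path for enzyme_jax are assumptions for illustration; only the two wrapper calls and NewXLAPipeline(mlirad=True) come from this diff.

import jax
import jax.numpy as jnp
import enzyme_ad.jax as enzyme_jax  # import path is an assumption

def toy(x, w):
    return jnp.tanh(x @ w)

argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
pipeline = enzyme_jax.NewXLAPipeline(mlirad=True)

jfunc = jax.jit(toy)                                                          # plain XLA
efunc = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=pipeline)(toy)  # Enzyme-MLIR

x = jnp.ones((2, 3), jnp.float32)
w = jnp.ones((3, 3), jnp.float32)
print("XLA:", jfunc(x, w))
print("Enzyme:", efunc(x, w))
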
