diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
index 0ad4f3751..79ea13790 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -203,27 +203,36 @@ class AutoDiffBroadcastInDimRev
     SmallVector<int64_t> bcastDims(op.getBroadcastDimensions().begin(),
                                    op.getBroadcastDimensions().end());
 
-    SmallVector<int64_t> newDims;
-    SmallVector<int64_t> reduceShape;
+    SmallVector<int64_t> reducedDims;
+    SmallVector<int64_t> iterShape;
     for (auto en : llvm::enumerate(outTy.getShape())) {
-      if (llvm::is_contained(bcastDims, en.index())) {
-        if (en.value() != 1) {
-          newDims.push_back(en.index());
+      ssize_t bcastIdx = -1;
+      for (auto en2 : llvm::enumerate(bcastDims)) {
+        if (en2.value() == en.index()) {
+          bcastIdx = en2.index();
+          break;
+        }
+      }
+      if (bcastIdx != -1) {
+        if (en.value() != inTy.getShape()[bcastIdx]) {
+          reducedDims.push_back(en.index());
+          assert(inTy.getShape()[bcastIdx] == 1);
+        } else {
+          iterShape.push_back(inTy.getShape()[bcastIdx]);
         }
         continue;
       }
-      reduceShape.push_back(en.value());
-      newDims.push_back(en.index());
+      reducedDims.push_back(en.index());
     }
 
-    auto reduceTy = RankedTensorType::get(reduceShape, inTy.getElementType());
+    auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType());
 
     Value zero = gutils->getShadowType(reduceTy)
                      .cast<AutoDiffTypeInterface>()
                      .createNullValue(builder, op.getLoc());
 
     auto red = builder.create<ReduceOp>(op.getLoc(), TypeRange(zero.getType()),
-                                        inDiffe, zero, newDims);
+                                        inDiffe, zero, reducedDims);
     red.getBody().push_back(new Block());
     Block &body = red.getBody().front();
     OpBuilder bodyBuilder(orig->getContext());
diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py
index 3c241bb3b..b422069f8 100644
--- a/test/bench_vs_xla.py
+++ b/test/bench_vs_xla.py
@@ -70,13 +70,9 @@ def harness(self, name, in_fn, ins, dins, douts):
 
         print(
             name + " JaX Primal: ",
-            timeit.Timer(
-                primalstr,
-                globals={
-                    "fn": rfn_jax,
-                }
-                | primalins,
-            ).timeit(number)
+            timeit.Timer(primalstr, globals={"fn": rfn_jax,} | primalins,).timeit(
+                number
+            )
             / number,
         )
 
@@ -97,13 +93,7 @@ def harness(self, name, in_fn, ins, dins, douts):
         fwdins = primalins | {("din" + str(i)): dins[0] for i in range(len(dins))}
         print(
             name + " JaX Fwd: ",
-            timeit.Timer(
-                fwdstr,
-                globals={
-                    "fwd": fwd_jax,
-                }
-                | fwdins,
-            ).timeit(number)
+            timeit.Timer(fwdstr, globals={"fwd": fwd_jax,} | fwdins,).timeit(number)
             / number,
         )
 
@@ -124,13 +114,7 @@ def harness(self, name, in_fn, ins, dins, douts):
 
         print(
             name + " JaX Rev: ",
-            timeit.Timer(
-                revstr,
-                globals={
-                    "rev": rev_jax,
-                }
-                | revins,
-            ).timeit(number)
+            timeit.Timer(revstr, globals={"rev": rev_jax,} | revins,).timeit(number)
             / number,
         )
 
@@ -148,11 +132,7 @@ def harness(self, name, in_fn, ins, dins, douts):
             name,
             ") Primal: ",
             timeit.Timer(
-                primalstr,
-                globals={
-                    "fn": rfn_enzyme,
-                }
-                | primalins,
+                primalstr, globals={"fn": rfn_enzyme,} | primalins,
             ).timeit(number)
             / number,
         )
@@ -174,13 +154,9 @@ def harness(self, name, in_fn, ins, dins, douts):
             name + " EnzymeMLIR(",
             name,
             ") Fwd: ",
-            timeit.Timer(
-                fwdstr,
-                globals={
-                    "fwd": fwd_enzyme,
-                }
-                | fwdins,
-            ).timeit(number)
+            timeit.Timer(fwdstr, globals={"fwd": fwd_enzyme,} | fwdins,).timeit(
+                number
+            )
             / number,
         )
 
@@ -198,13 +174,9 @@ def harness(self, name, in_fn, ins, dins, douts):
             name + " EnzymeMLIR(",
             name,
             ") Rev: ",
-            timeit.Timer(
-                revstr,
-                globals={
-                    "rev": rev_enzyme,
-                }
-                | revins,
-            ).timeit(number)
+            timeit.Timer(revstr, globals={"rev": rev_enzyme,} | revins,).timeit(
+                number
+            )
             / number,
         )
 
@@ -291,7 +263,9 @@ class Slicing(EnzymeJaxTest):
     def setUp(self):
         dim = 3
         self.ins = [jnp.array(range(dim), dtype=jnp.float32).reshape(1, dim, 1)]
-        self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)]
+        self.dins = [
+            jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)
+        ]
         self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
 
         def nomlir(x):
@@ -311,16 +285,24 @@ def setUp(self):
         dim = 12
         self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
         self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
-        self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
+        self.douts = [
+            jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
+                (2, dim)
+            )
+        ]
 
         def nomlir(x):
-            return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
+            return [
+                (name, a)
+                for (name, a) in x
+                if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
+            ]
 
         self.revfilter = nomlir
 
         def f(x):
             toconv2 = jnp.ones((dim, dim))
-            k = jnp.einsum('jk,k->j', toconv2, x)
+            k = jnp.einsum("jk,k->j", toconv2, x)
             kcl = jnp.zeros((1, dim))
             h = jnp.reshape(k, (1, dim))
             kcl = jnp.append(kcl, h, axis=0)
@@ -329,15 +311,24 @@ def f(x):
         self.fn = f
         self.name = "activitymismatch"
 
+
 class GenDot(EnzymeJaxTest):
     def setUp(self):
         dim = 12
         self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
         self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
-        self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
+        self.douts = [
+            jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
+                (2, dim)
+            )
+        ]
 
         def nomlir(x):
-            return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
+            return [
+                (name, a)
+                for (name, a) in x
+                if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
+            ]
 
         self.revfilter = nomlir
 
@@ -349,7 +340,7 @@ def f(x):
 
             k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,))
             kcl = jnp.zeros((1, dim))
-            
+
             h = jnp.reshape(k, (1, dim))
             kcl = jnp.append(kcl, h, axis=0)
             return kcl
@@ -361,12 +352,22 @@ def f(x):
 class Concat(EnzymeJaxTest):
     def setUp(self):
         dim = 12
-        self.ins = [jnp.array(range(dim), dtype=jnp.float32), 10*jnp.array(range(dim), dtype=jnp.float32)]
-        self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32), jnp.array([i * i *i / 3. for i in range(dim)], dtype=jnp.float32)]
-        self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32)]
+        self.ins = [
+            jnp.array(range(dim), dtype=jnp.float32),
+            10 * jnp.array(range(dim), dtype=jnp.float32),
+        ]
+        self.dins = [
+            jnp.array([i * i for i in range(dim)], dtype=jnp.float32),
+            jnp.array([i * i * i / 3.0 for i in range(dim)], dtype=jnp.float32),
+        ]
+        self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)]
 
         def nomlir(x):
-            return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
+            return [
+                (name, a)
+                for (name, a) in x
+                if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
+            ]
 
         self.revfilter = nomlir
 
@@ -376,5 +377,6 @@ def f(x, y):
         self.fn = f
         self.name = "Concat"
 
+
 if __name__ == "__main__":
     absltest.main()
diff --git a/test/llama.py b/test/llama.py
index cc588b0a2..f7b8dd028 100644
--- a/test/llama.py
+++ b/test/llama.py
@@ -349,6 +349,7 @@ def sfn(x, weights, key_cache, value_cache):
 
         # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
         if True:
+
             @jax.jit
             def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
                 return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
@@ -411,8 +412,13 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
         jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
         print("Jax rev", jres)
 
-        jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4},"
-        + "canonicalize,cse,enzyme-hlo-opt,cse"))(jrev)
+        jrev2 = enzyme_jax.enzyme_jax_ir(
+            argv=argv,
+            pipeline_options=enzyme_jax.JaXPipeline(
+                "inline{default-pipeline=canonicalize max-iterations=4},"
+                + "canonicalize,cse,enzyme-hlo-opt,cse"
+            ),
+        )(jrev)
 
         jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)
         print("Jax2 rev", jres2)
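
Note (illustrative, not part of the patch): the revised AutoDiffBroadcastInDimRev splits the result dimensions into reducedDims, which feed the stablehlo.reduce over the incoming cotangent, and iterShape, which carries the extents that survive. In JAX terms, the adjoint of broadcast_in_dim sum-reduces the result cotangent over every dimension the broadcast introduced, plus every mapped dimension whose operand extent was 1, then reshapes back to the operand shape. The helper below and its vjp cross-check are a hypothetical sketch of that rule, written only to make the intent of the C++ change concrete.

import jax
import jax.numpy as jnp

def broadcast_in_dim_cotangent(dout, in_shape, out_shape, bcast_dims):
    # dout: cotangent of the broadcast result (shape == out_shape).
    # bcast_dims[i] is the result dimension that operand dimension i maps to.
    reduced_dims = []
    for out_dim, out_extent in enumerate(out_shape):
        if out_dim in bcast_dims:
            in_extent = in_shape[bcast_dims.index(out_dim)]
            # A mapped dimension is summed over only when the operand extent
            # was 1 and the result extent grew (size-1 broadcast).
            if out_extent != in_extent:
                assert in_extent == 1
                reduced_dims.append(out_dim)
        else:
            # Dimensions that broadcast_in_dim introduced are always summed over.
            reduced_dims.append(out_dim)
    return jnp.sum(dout, axis=tuple(reduced_dims)).reshape(in_shape)

# Cross-check against jax.vjp for a (1, 3, 1) -> (2, 3, 4) broadcast.
x = jnp.arange(3.0).reshape(1, 3, 1)
f = lambda v: jax.lax.broadcast_in_dim(v, (2, 3, 4), (0, 1, 2))
dout = jnp.ones((2, 3, 4))
_, pullback = jax.vjp(f, x)
(expected,) = pullback(dout)
got = broadcast_in_dim_cotangent(dout, x.shape, (2, 3, 4), (0, 1, 2))
assert jnp.allclose(expected, got)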