Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 8, 2024
1 parent 992034f commit 9ed373a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,27 +203,36 @@ class AutoDiffBroadcastInDimRev
SmallVector<int64_t> bcastDims(op.getBroadcastDimensions().begin(),
op.getBroadcastDimensions().end());

SmallVector<int64_t> newDims;
SmallVector<int64_t> reduceShape;
SmallVector<int64_t> reducedDims;
SmallVector<int64_t> iterShape;
for (auto en : llvm::enumerate(outTy.getShape())) {
if (llvm::is_contained(bcastDims, en.index())) {
if (en.value() != 1) {
newDims.push_back(en.index());
ssize_t bcastIdx = -1;
for (auto en2 : llvm::enumerate(bcastDims)) {
if (en2.value() == en.index()) {
bcastIdx = en2.index();
break;
}
}
if (bcastIdx != -1) {
if (en.value() != inTy.getShape()[bcastIdx]) {
reducedDims.push_back(en.index());
assert(inTy.getShape()[bcastIdx] == 1);
} else {
iterShape.push_back(inTy.getShape()[bcastIdx]);
}
continue;
}
reduceShape.push_back(en.value());
newDims.push_back(en.index());
reducedDims.push_back(en.index());
}

auto reduceTy = RankedTensorType::get(reduceShape, inTy.getElementType());
auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType());

Value zero = gutils->getShadowType(reduceTy)
.cast<AutoDiffTypeInterface>()
.createNullValue(builder, op.getLoc());

auto red = builder.create<ReduceOp>(op.getLoc(), TypeRange(zero.getType()),
inDiffe, zero, newDims);
inDiffe, zero, reducedDims);
red.getBody().push_back(new Block());
Block &body = red.getBody().front();
OpBuilder bodyBuilder(orig->getContext());
Expand Down
52 changes: 41 additions & 11 deletions test/bench_vs_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ class Slicing(EnzymeJaxTest):
def setUp(self):
dim = 3
self.ins = [jnp.array(range(dim), dtype=jnp.float32).reshape(1, dim, 1)]
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)]
self.dins = [
jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)
]
self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]

def nomlir(x):
Expand All @@ -311,16 +313,24 @@ def setUp(self):
dim = 12
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
self.douts = [
jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
(2, dim)
)
]

def nomlir(x):
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
return [
(name, a)
for (name, a) in x
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
]

self.revfilter = nomlir

def f(x):
toconv2 = jnp.ones((dim, dim))
k = jnp.einsum('jk,k->j', toconv2, x)
k = jnp.einsum("jk,k->j", toconv2, x)
kcl = jnp.zeros((1, dim))
h = jnp.reshape(k, (1, dim))
kcl = jnp.append(kcl, h, axis=0)
Expand All @@ -329,15 +339,24 @@ def f(x):
self.fn = f
self.name = "activitymismatch"


class GenDot(EnzymeJaxTest):
def setUp(self):
dim = 12
self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
self.douts = [
jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
(2, dim)
)
]

def nomlir(x):
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
return [
(name, a)
for (name, a) in x
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
]

self.revfilter = nomlir

Expand All @@ -349,7 +368,7 @@ def f(x):
k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,))

kcl = jnp.zeros((1, dim))

h = jnp.reshape(k, (1, dim))
kcl = jnp.append(kcl, h, axis=0)
return kcl
Expand All @@ -361,12 +380,22 @@ def f(x):
class Concat(EnzymeJaxTest):
def setUp(self):
dim = 12
self.ins = [jnp.array(range(dim), dtype=jnp.float32), 10*jnp.array(range(dim), dtype=jnp.float32)]
self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32), jnp.array([i * i *i / 3. for i in range(dim)], dtype=jnp.float32)]
self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32)]
self.ins = [
jnp.array(range(dim), dtype=jnp.float32),
10 * jnp.array(range(dim), dtype=jnp.float32),
]
self.dins = [
jnp.array([i * i for i in range(dim)], dtype=jnp.float32),
jnp.array([i * i * i / 3.0 for i in range(dim)], dtype=jnp.float32),
]
self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)]

def nomlir(x):
return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
return [
(name, a)
for (name, a) in x
if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
]

self.revfilter = nomlir

Expand All @@ -376,5 +405,6 @@ def f(x, y):
self.fn = f
self.name = "Concat"


if __name__ == "__main__":
absltest.main()
10 changes: 8 additions & 2 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def sfn(x, weights, key_cache, value_cache):
# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")

if True:

@jax.jit
def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
Expand Down Expand Up @@ -411,8 +412,13 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
print("Jax rev", jres)

jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4},"
+ "canonicalize,cse,enzyme-hlo-opt,cse"))(jrev)
jrev2 = enzyme_jax.enzyme_jax_ir(
argv=argv,
pipeline_options=enzyme_jax.JaXPipeline(
"inline{default-pipeline=canonicalize max-iterations=4},"
+ "canonicalize,cse,enzyme-hlo-opt,cse"
),
)(jrev)

jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)
print("Jax2 rev", jres2)
Expand Down

0 comments on commit 9ed373a

Please sign in to comment.