Commit f44ccd8: fix

wsmoses committed Mar 8, 2024
1 parent 992034f commit f44ccd8
Showing 3 changed files with 79 additions and 62 deletions.
@@ -203,27 +203,36 @@ class AutoDiffBroadcastInDimRev
     SmallVector<int64_t> bcastDims(op.getBroadcastDimensions().begin(),
                                    op.getBroadcastDimensions().end());
 
-    SmallVector<int64_t> newDims;
-    SmallVector<int64_t> reduceShape;
+    SmallVector<int64_t> reducedDims;
+    SmallVector<int64_t> iterShape;
     for (auto en : llvm::enumerate(outTy.getShape())) {
-      if (llvm::is_contained(bcastDims, en.index())) {
-        if (en.value() != 1) {
-          newDims.push_back(en.index());
+      ssize_t bcastIdx = -1;
+      for (auto en2 : llvm::enumerate(bcastDims)) {
+        if (en2.value() == en.index()) {
+          bcastIdx = en2.index();
+          break;
+        }
+      }
+      if (bcastIdx != -1) {
+        if (en.value() != inTy.getShape()[bcastIdx]) {
+          reducedDims.push_back(en.index());
+          assert(inTy.getShape()[bcastIdx] == 1);
+        } else {
+          iterShape.push_back(inTy.getShape()[bcastIdx]);
         }
         continue;
       }
-      reduceShape.push_back(en.value());
-      newDims.push_back(en.index());
+      reducedDims.push_back(en.index());
     }
 
-    auto reduceTy = RankedTensorType::get(reduceShape, inTy.getElementType());
+    auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType());
 
     Value zero = gutils->getShadowType(reduceTy)
                      .cast<AutoDiffTypeInterface>()
                      .createNullValue(builder, op.getLoc());
 
     auto red = builder.create<ReduceOp>(op.getLoc(), TypeRange(zero.getType()),
-                                        inDiffe, zero, newDims);
+                                        inDiffe, zero, reducedDims);
     red.getBody().push_back(new Block());
     Block &body = red.getBody().front();
     OpBuilder bodyBuilder(orig->getContext());
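Note: this hunk fixes the reverse-mode (adjoint) rule for broadcast_in_dim. Instead of reducing over every mapped output dimension whose extent is not 1, the new code locates each output dimension in the broadcast map and reduces only where the corresponding input extent is 1 (a genuine size expansion), keeping pass-through dimensions in the iteration shape. A minimal JAX sketch of the same rule, checked against jax.vjp (broadcast_vjp_sketch and its arguments are illustrative, not code from this repository):

import jax
import jax.numpy as jnp

# Adjoint of a broadcast: sum the output cotangent over the dimensions
# that were actually expanded, then reshape back to the input shape.
def broadcast_vjp_sketch(in_shape, out_shape, bcast_dims, cotangent):
    # Output dims not mapped from any input dim are always reduced.
    reduced = [d for d in range(len(out_shape)) if d not in bcast_dims]
    # Mapped dims expanded from extent 1 must also be reduced -- the
    # case this commit fixes (previously keyed off the output extent).
    for i, d in enumerate(bcast_dims):
        if in_shape[i] == 1 and out_shape[d] != 1:
            reduced.append(d)
    return jnp.sum(cotangent, axis=tuple(reduced)).reshape(in_shape)

x = jnp.ones((3, 1))
f = lambda v: jnp.broadcast_to(v, (2, 3, 4))  # bcast_dims = (1, 2)
_, vjp_fn = jax.vjp(f, x)
ct = jnp.ones((2, 3, 4))
print(jnp.allclose(vjp_fn(ct)[0],
                   broadcast_vjp_sketch((3, 1), (2, 3, 4), (1, 2), ct)))  # True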
test/bench_vs_xla.py: 104 changes (53 additions & 51 deletions)
@@ -70,13 +70,9 @@ def harness(self, name, in_fn, ins, dins, douts):
 
         print(
             name + " JaX Primal: ",
-            timeit.Timer(
-                primalstr,
-                globals={
-                    "fn": rfn_jax,
-                }
-                | primalins,
-            ).timeit(number)
+            timeit.Timer(primalstr, globals={"fn": rfn_jax,} | primalins,).timeit(
+                number
+            )
             / number,
         )
 
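Note: every benchmark in this harness follows the same pattern: build a timeit.Timer from a statement string and a globals mapping, run it number times, and divide to get average seconds per call. A standalone sketch of that pattern (fn and number here are placeholders, not harness names):

import timeit

def fn(x):
    return x * x

number = 1000
avg_s = timeit.Timer("fn(3.0)", globals={"fn": fn}).timeit(number) / number
print(f"average per call: {avg_s:.3e} s")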
@@ -97,13 +93,7 @@ def harness(self, name, in_fn, ins, dins, douts):
         fwdins = primalins | {("din" + str(i)): dins[0] for i in range(len(dins))}
         print(
             name + " JaX Fwd: ",
-            timeit.Timer(
-                fwdstr,
-                globals={
-                    "fwd": fwd_jax,
-                }
-                | fwdins,
-            ).timeit(number)
+            timeit.Timer(fwdstr, globals={"fwd": fwd_jax,} | fwdins,).timeit(number)
             / number,
         )
 
@@ -124,13 +114,7 @@ def harness(self, name, in_fn, ins, dins, douts):
 
         print(
             name + " JaX Rev: ",
-            timeit.Timer(
-                revstr,
-                globals={
-                    "rev": rev_jax,
-                }
-                | revins,
-            ).timeit(number)
+            timeit.Timer(revstr, globals={"rev": rev_jax,} | revins,).timeit(number)
             / number,
         )
 
@@ -148,11 +132,7 @@ def harness(self, name, in_fn, ins, dins, douts):
                 name,
                 ") Primal: ",
                 timeit.Timer(
-                    primalstr,
-                    globals={
-                        "fn": rfn_enzyme,
-                    }
-                    | primalins,
+                    primalstr, globals={"fn": rfn_enzyme,} | primalins,
                 ).timeit(number)
                 / number,
             )
@@ -174,13 +154,9 @@ def harness(self, name, in_fn, ins, dins, douts):
                 name + " EnzymeMLIR(",
                 name,
                 ") Fwd: ",
-                timeit.Timer(
-                    fwdstr,
-                    globals={
-                        "fwd": fwd_enzyme,
-                    }
-                    | fwdins,
-                ).timeit(number)
+                timeit.Timer(fwdstr, globals={"fwd": fwd_enzyme,} | fwdins,).timeit(
+                    number
+                )
                 / number,
             )
 
@@ -198,13 +174,9 @@ def harness(self, name, in_fn, ins, dins, douts):
                 name + " EnzymeMLIR(",
                 name,
                 ") Rev: ",
-                timeit.Timer(
-                    revstr,
-                    globals={
-                        "rev": rev_enzyme,
-                    }
-                    | revins,
-                ).timeit(number)
+                timeit.Timer(revstr, globals={"rev": rev_enzyme,} | revins,).timeit(
+                    number
+                )
                 / number,
             )
 
@@ -291,7 +263,9 @@ class Slicing(EnzymeJaxTest):
     def setUp(self):
         dim = 3
         self.ins = [jnp.array(range(dim), dtype=jnp.float32).reshape(1, dim, 1)]
-        self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)]
+        self.dins = [
+            jnp.array([i * i for i in range(dim)], dtype=jnp.float32).reshape(1, dim, 1)
+        ]
         self.douts = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
 
         def nomlir(x):
@@ -311,16 +285,24 @@ def setUp(self):
         dim = 12
         self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
         self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
-        self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
+        self.douts = [
+            jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
+                (2, dim)
+            )
+        ]
 
         def nomlir(x):
-            return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
+            return [
+                (name, a)
+                for (name, a) in x
+                if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
+            ]
 
         self.revfilter = nomlir
 
         def f(x):
             toconv2 = jnp.ones((dim, dim))
-            k = jnp.einsum('jk,k->j', toconv2, x)
+            k = jnp.einsum("jk,k->j", toconv2, x)
             kcl = jnp.zeros((1, dim))
             h = jnp.reshape(k, (1, dim))
             kcl = jnp.append(kcl, h, axis=0)
@@ -329,15 +311,24 @@ def f(x):
         self.fn = f
         self.name = "activitymismatch"
 
+
 class GenDot(EnzymeJaxTest):
     def setUp(self):
         dim = 12
         self.ins = [jnp.array(range(dim), dtype=jnp.float32)]
         self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32)]
-        self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32).reshape((2, dim))]
+        self.douts = [
+            jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32).reshape(
+                (2, dim)
+            )
+        ]
 
         def nomlir(x):
-            return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
+            return [
+                (name, a)
+                for (name, a) in x
+                if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
+            ]
 
         self.revfilter = nomlir
 
@@ -349,7 +340,7 @@ def f(x):
             k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,))
 
             kcl = jnp.zeros((1, dim))
-            
+
             h = jnp.reshape(k, (1, dim))
             kcl = jnp.append(kcl, h, axis=0)
             return kcl
@@ -361,12 +352,22 @@ def f(x):
 class Concat(EnzymeJaxTest):
     def setUp(self):
         dim = 12
-        self.ins = [jnp.array(range(dim), dtype=jnp.float32), 10*jnp.array(range(dim), dtype=jnp.float32)]
-        self.dins = [jnp.array([i * i for i in range(dim)], dtype=jnp.float32), jnp.array([i * i *i / 3. for i in range(dim)], dtype=jnp.float32)]
-        self.douts = [jnp.array([i * i for i in range(2*dim)], dtype=jnp.float32)]
+        self.ins = [
+            jnp.array(range(dim), dtype=jnp.float32),
+            10 * jnp.array(range(dim), dtype=jnp.float32),
+        ]
+        self.dins = [
+            jnp.array([i * i for i in range(dim)], dtype=jnp.float32),
+            jnp.array([i * i * i / 3.0 for i in range(dim)], dtype=jnp.float32),
+        ]
+        self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)]
 
         def nomlir(x):
-            return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"]
+            return [
+                (name, a)
+                for (name, a) in x
+                if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA"
+            ]
 
         self.revfilter = nomlir
 
@@ -376,5 +377,6 @@ def f(x, y):
         self.fn = f
         self.name = "Concat"
 
+
 if __name__ == "__main__":
     absltest.main()
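Note: each test case supplies primal inputs (ins), input tangents (dins) for the forward-mode check, and an output cotangent (douts) for the reverse-mode check. A standalone sketch of the identity these fixtures exercise, using a Concat-like function (all names below are illustrative and independent of the harness):

import jax
import jax.numpy as jnp

f = lambda x, y: jnp.concatenate([x, y])

x, y = jnp.arange(3.0), 10.0 * jnp.arange(3.0)   # ins
dx, dy = x * x, x * x * x / 3.0                  # dins (tangents)
dout = jnp.arange(6.0) ** 2                      # douts (cotangent)

_, jvp_out = jax.jvp(f, (x, y), (dx, dy))        # forward mode
_, vjp_fn = jax.vjp(f, x, y)
cots = vjp_fn(dout)                              # reverse mode

# Dot-product test: <dout, J(dx, dy)> == <(dx, dy), J^T dout>
print(jnp.vdot(dout, jvp_out),
      jnp.vdot(dx, cots[0]) + jnp.vdot(dy, cots[1]))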
test/llama.py: 10 changes (8 additions & 2 deletions)
@@ -349,6 +349,7 @@ def sfn(x, weights, key_cache, value_cache):
    # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
 
    if True:
+
        @jax.jit
        def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
            return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
@@ -411,8 +412,13 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
    jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
    print("Jax rev", jres)
 
-    jrev2 = enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=enzyme_jax.JaXPipeline("inline{default-pipeline=canonicalize max-iterations=4},"
-        + "canonicalize,cse,enzyme-hlo-opt,cse"))(jrev)
+    jrev2 = enzyme_jax.enzyme_jax_ir(
+        argv=argv,
+        pipeline_options=enzyme_jax.JaXPipeline(
+            "inline{default-pipeline=canonicalize max-iterations=4},"
+            + "canonicalize,cse,enzyme-hlo-opt,cse"
+        ),
+    )(jrev)
 
    jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)
    print("Jax2 rev", jres2)
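Note: enzyme_jax_ir is applied as a wrapper around an existing JAX function, so jrev2 takes exactly the same arguments as jrev and the two results can be compared leaf by leaf. A sketch of such a check, assuming jres and jres2 are pytrees of arrays as in the surrounding script (the tolerance is a placeholder):

import jax
import jax.numpy as jnp

for a, b in zip(jax.tree_util.tree_leaves(jres), jax.tree_util.tree_leaves(jres2)):
    assert jnp.allclose(a, b, atol=1e-4), "Enzyme-compiled rev differs from JAX rev"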
