Skip to content

Commit

Permalink
Add missing ')' to enzyme_call and add tests for old pipeline
Browse files Browse the repository at this point in the history
The missing ')' was causing a bug on the old pipeline, which was not
caught by the current tests. This fixes the issue and adds tests to
prevent a regression.
  • Loading branch information
itf committed Jan 23, 2024
1 parent 7c7441a commit 5429d59
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/enzyme_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ class CpuKernel {
comma = true;
}
}
ss << "};\n";
ss << " " << fn
<< "(nullptr, nullptr, nullptr, buffers, nullptr, nullptr);\n";
}
Expand Down
14 changes: 12 additions & 2 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,23 @@
LANG_LLVM = enzyme_call.Language.LLVM
LANG_MHLO = enzyme_call.Language.MHLO

## options
## true (default) -> new xla pipeline, default passes
## false -> old xla pipeline, internal passes
## string -> new xla pipeline, using passes as specified

def xla_runtime(options):
return True
if type(options) == type(False) and options == False:
return False
else:
return True


def pass_pipeline(options):
return """
if type(options) == type(""):
return options
else:
return """
inline{default-pipeline=canonicalize max-iterations=4},
expand-hlo-tuples{entry-function=main},
func.func(mhlo-flatten-tuple),
Expand Down
157 changes: 139 additions & 18 deletions test/bench_vs_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
def add_one(x: jax.Array, y) -> jax.Array:
return x + 1 + y

@enzyme_jax_ir(pipeline_options=False)
def add_one_old(x: jax.Array, y) -> jax.Array:
return x + 1 + y

@jax.jit
def add_one_plain(x: jax.Array, y) -> jax.Array:
Expand All @@ -19,6 +22,9 @@ def add_one_plain(x: jax.Array, y) -> jax.Array:
def add_two(x: jax.Array, z, y) -> jax.Array:
return x + y

@enzyme_jax_ir(pipeline_options=False)
def add_two_old(x: jax.Array, z, y) -> jax.Array:
return x + y

@jax.jit
def add_two_plain(x: jax.Array, z, y) -> jax.Array:
Expand All @@ -29,6 +35,9 @@ def add_two_plain(x: jax.Array, z, y) -> jax.Array:
def fwd(in0, in1, din0, din1):
return jax.jvp(add_one, (in0, in1), (din0, din1))

@jax.jit
def fwd_old(in0, in1, din0, din1):
return jax.jvp(add_one_old, (in0, in1), (din0, din1))

@jax.jit
def fwd_plain(in0, in1, din0, din1):
Expand All @@ -39,6 +48,10 @@ def fwd_plain(in0, in1, din0, din1):
def fwd2(in0, in1, in2, din0, din1, din2):
return jax.jvp(add_two, (in0, in1, in2), (din0, din1, din2))

@jax.jit
def fwd2_old(in0, in1, in2, din0, din1, din2):
return jax.jvp(add_two_old, (in0, in1, in2), (din0, din1, din2))


@jax.jit
def fwd2_plain(in0, in1, in2, din0, din1, din2):
Expand All @@ -51,6 +64,11 @@ def rev(in0, in1, dout):
grads = f_vjp(dout)
return primals, grads

@jax.jit
def rev_old(in0, in1, dout):
primals, f_vjp = jax.vjp(add_one_old, in0, in1)
grads = f_vjp(dout)
return primals, grads

@jax.jit
def rev_plain(in0, in1, dout):
Expand All @@ -65,6 +83,11 @@ def rev2(in0, in1, in2, dout):
grads = f_vjp(dout)
return primals, grads

@jax.jit
def rev2_old(in0, in1, in2, dout):
primals, f_vjp = jax.vjp(add_two_old, in0, in1, in2)
grads = f_vjp(dout)
return primals, grads

@jax.jit
def rev2_plain(in0, in1, in2, dout):
Expand All @@ -83,9 +106,13 @@ def setUp(self):
self.din2 = jnp.array([1300.0, 1700.0, 1900.0])

def test_add_one_primal(self):
ao = add_one(self.in0, self.in1)
aop = add_one_plain(self.in0, self.in1)

ao = add_one(self.in0, self.in1)
ao_old = add_one(self.in0, self.in1)

self.assertTrue((jnp.abs(ao - aop) < 1e-6).all())
self.assertTrue((jnp.abs(ao_old - aop) < 1e-6).all())

# Benchmark.
print(
Expand All @@ -94,6 +121,12 @@ def test_add_one_primal(self):
globals={"add_one": add_one, "in0": self.in0, "in1": self.in1},
).timeit()
)
print(
timeit.Timer(
"add_one_old(in0, in1)",
globals={"add_one_old": add_one_old, "in0": self.in0, "in1": self.in1},
).timeit()
)
print(
timeit.Timer(
"add_one_plain(in0, in1)",
Expand All @@ -106,17 +139,29 @@ def test_add_one_primal(self):
)

def test_add_two_deadarg(self):
at = add_two(self.in0, self.in1, self.in2)
atp = add_two_plain(self.in0, self.in1, self.in2)

at = add_two(self.in0, self.in1, self.in2)
ato = add_two_old(self.in0, self.in1, self.in2)

self.assertTrue((jnp.abs(at - atp) < 1e-6).all())
self.assertTrue((jnp.abs(ato - atp) < 1e-6).all())

def test_add_one_forward(self):
primals, tangents = fwd(self.in0, self.in1, self.din0, self.din1)
primals_p, tangents_p = fwd_plain(self.in0, self.in1, self.din0, self.din1)

primals, tangents = fwd(self.in0, self.in1, self.din0, self.din1)
primals_old, tangents_old = fwd_old(self.in0, self.in1, self.din0, self.din1)

self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
for t, t_p in zip(tangents, tangents_p):
self.assertTrue((jnp.abs(primals_old - primals_p) < 1e-6).all())
for t, t_old, t_p in zip(tangents, tangents_old, tangents_p):
self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
self.assertTrue((jnp.abs(t_old - t_p) < 1e-6).all())





print(
timeit.Timer(
Expand All @@ -130,6 +175,18 @@ def test_add_one_forward(self):
},
).timeit()
)
print(
timeit.Timer(
"fwd_old(in0, in1, din0, din1)",
globals={
"fwd_old": fwd_old,
"in0": self.in0,
"in1": self.in1,
"din0": self.din0,
"din1": self.din1,
},
).timeit()
)
print(
timeit.Timer(
"fwd_plain(in0, in1, din0, din1)",
Expand All @@ -144,39 +201,57 @@ def test_add_one_forward(self):
)

def test_add_two_deadarg_forward(self):
primals_p, tangents_p = fwd2_plain(
self.in0, self.in1, self.in2, self.din0, self.din1, self.din2
)

primals, tangents = fwd2(
self.in0, self.in1, self.in2, self.din0, self.din1, self.din2
)
primals_p, tangents_p = fwd2_plain(

primals_o, tangents_o = fwd2_old(
self.in0, self.in1, self.in2, self.din0, self.din1, self.din2
)

print(primals, primals_p)
print(primals, primals_o, primals_p)
self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
for i, (t, t_p) in enumerate(zip(tangents, tangents_p)):
print(i, t, t_p)
for i, (t, t_o, t_p) in enumerate(zip(tangents, tangents_o, tangents_p)):
print(i, to t_p)
self.assertTrue((jnp.abs(t - t_p) < 1e-6).all())
self.assertTrue((jnp.abs(t_o - t_p) < 1e-6).all())

def test_add_one_reverse(self):
dout = jnp.array([500.0, 700.0, 110.0])
primals_p, grads_p = rev_plain(self.in0, self.in1, dout)

primals, grads = rev(self.in0, self.in1, dout)
# TODO enzyme will in place 0 the gradient inputs, which may not be expected
print(dout)
# TODO enzyme will in place 0 the gradient inputs, which may not be expected
dout = jnp.array([500.0, 700.0, 110.0])
primals_p, grads_p = rev_plain(self.in0, self.in1, dout)
primals, grads = rev(self.in0, self.in1, dout)

dout = jnp.array([500.0, 700.0, 110.0])
primals_old, grads_old = rev_old(self.in0, self.in1, dout)


self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
for i, (g, g_p) in enumerate(zip(grads, grads_p)):
print(i, g, g_p)
self.assertTrue((jnp.abs(primals_old - primals_p) < 1e-6).all())
for i, (g, g_old, g_p) in enumerate(zip(grads, grads_old, grads_p)):
print(i, g, g_old, g_p)
self.assertTrue((jnp.abs(g - g_p) < 1e-6).all())
self.assertTrue((jnp.abs(g_old - g_p) < 1e-6).all())

print(
timeit.Timer(
"rev(in0, in1, dout)",
globals={"rev": rev, "in0": self.in0, "in1": self.in1, "dout": dout},
).timeit()
)
print(
timeit.Timer(
"rev_old(in0, in1, dout)",
globals={"rev": rev, "in0": self.in0, "in1": self.in1, "dout": dout},
).timeit()
)
print(
timeit.Timer(
"rev_plain(in0, in1, dout)",
Expand All @@ -191,27 +266,37 @@ def test_add_one_reverse(self):

def test_add_two_deadarg_reverse(self):
dout = jnp.array([500.0, 700.0, 110.0])
primals, grads = rev2(self.in0, self.in1, self.in2, dout)
primals_p, grads_p = rev2_plain(self.in0, self.in1, self.in2, dout)
# TODO enzyme will in place 0 the gradient inputs, which may not be expected
print(dout)
dout = jnp.array([500.0, 700.0, 110.0])
primals_p, grads_p = rev2_plain(self.in0, self.in1, self.in2, dout)
primals, grads = rev2(self.in0, self.in1, self.in2, dout)

dout = jnp.array([500.0, 700.0, 110.0])
primals_old, grads_old = rev2_old(self.in0, self.in1, self.in2, dout)

self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all())
for i, (g, g_p) in enumerate(zip(grads, grads_p)):
print(i, g, g_p)
self.assertTrue((jnp.abs(primals_old - primals_p) < 1e-6).all())
for i, (g, g_old, g_p) in enumerate(zip(grads, grads_old, grads_p)):
print(i, g, g_old, g_p)
self.assertTrue((jnp.abs(g - g_p) < 1e-6).all())

self.assertTrue((jnp.abs(g_old - g_p) < 1e-6).all())

@enzyme_jax_ir()
def esum(x):
return jnp.sum(x)

@enzyme_jax_ir(pipeline_options=False)
def esum_old(x):
return jnp.sum(x)

@jax.jit
def sumfwd(in0, din0):
return jax.jvp(esum, (in0,), (din0,))

@jax.jit
def sumfwd_old(in0, din0):
return jax.jvp(esum_old, (in0,), (din0,))

@jax.jit
def sumrev_p(in0):
Expand All @@ -226,6 +311,11 @@ def sumrev(in0):
grads = f_vjp(1.0)
return primals, grads

@jax.jit
def sumrev_old(in0):
primals, f_vjp = jax.vjp(esum_old, in0)
grads = f_vjp(1.0)
return primals, grads

class Sum(absltest.TestCase):
def setUp(self):
Expand All @@ -243,6 +333,12 @@ def test_forward(self):
self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
self.assertTrue(jnp.abs(tangents - 50 * 49 * 99 / 6) < 1e-6)

def test_forward_old(self):
primals, tangents = sumfwd_old(self.x, self.dx)
print(primals, tangents)
self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
self.assertTrue(jnp.abs(tangents - 50 * 49 * 99 / 6) < 1e-6)

def test_reverse_p(self):
primals, grads = sumrev_p(self.x)
print(primals, grads)
Expand All @@ -253,18 +349,32 @@ def test_reverse(self):
self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
self.assertTrue((jnp.abs(grads[0] - 1) < 1e-6).all())

def test_reverse_old(self):
primals, grads = sumrev_old(self.x)
print(primals, grads)
self.assertTrue(jnp.abs(primals - 50 * 49 / 2) < 1e-6)
self.assertTrue((jnp.abs(grads[0] - 1) < 1e-6).all())


@enzyme_jax_ir()
def ecache(x):
return x * x[0]

@enzyme_jax_ir(pipeline_options=False)
def ecache_old(x):
return x * x[0]

@jax.jit
def cacherev(in0, din0):
primals, f_vjp = jax.vjp(ecache, in0)
grads = f_vjp(din0)
return grads

@jax.jit
def cacherev_old(in0, din0):
primals, f_vjp = jax.vjp(ecache_old, in0)
grads = f_vjp(din0)
return grads

class Cache(absltest.TestCase):
def test_reverse(self):
Expand All @@ -279,6 +389,17 @@ def test_reverse(self):
)
self.assertTrue((jnp.abs(grads[0][1:]) < 1e-6).all())

def test_reverse_old(self):
dim = 288

x = jnp.array(range(dim), dtype=jnp.float32)
dx = jnp.array(range(dim), dtype=jnp.float32)

grads = cacherev_old(x, dx)
self.assertTrue(
jnp.abs(grads[0][0] - (dim - 1) * dim * (2 * (dim - 1) + 1) / 6) < 1e-6
)
self.assertTrue((jnp.abs(grads[0][1:]) < 1e-6).all())

if __name__ == "__main__":
absltest.main()

0 comments on commit 5429d59

Please sign in to comment.