# Patch summary (free-text preamble; everything before the "diff --git" line
# is skipped by patch tools).
#
# Purpose: turn test/llama.py into a correctness check plus micro-benchmark.
#   * Adds "import timeit".
#   * Replaces the commented-out primal check with live code: efunc and jfunc
#     are both called, both results are printed, and the element-wise absolute
#     difference is asserted to be below 1e-3. The old comment called func();
#     the new code calls jfunc() instead.
#   * Replaces the commented-out forward-mode calls with live calls to efwd
#     and jfwd and prints both results (no comparison is asserted for them).
#   * Adds timeit.Timer(...).timeit(number) measurements, with number = 1000,
#     for the primal, forward and reverse callables. Each Timer receives the
#     names used by its statement through an explicit globals dict.
#
# NOTE(review): line breaks and indentation below were reconstructed from a
#   whitespace-collapsed copy of this patch. Blank context lines were placed
#   using the @@ hunk counts; the 8-space base indent is an assumption.
#   Verify against the original commit before applying.
# NOTE(review): the timed statements only call the function and discard the
#   result. If these callables dispatch work asynchronously (presumably they
#   are JAX-jitted; confirm), the timings may not include execution time.
#   Consider blocking on the result inside the timed statement.
# NOTE(review): in the efwd context line the tangent tuple is
#   (x, weights, dkc, dvc) although dx and dweights are parameters. This
#   looks unintentional; confirm.
diff --git a/test/llama.py b/test/llama.py
index 53994d19..bd780c09 100644
--- a/test/llama.py
+++ b/test/llama.py
@@ -4,6 +4,7 @@
 import jax.lax
 import enzyme_ad.jax as enzyme_jax
 import numpy as np
+import timeit
 
 
 def rmsnorm(x, weight):
@@ -289,13 +290,40 @@ def jfunc(x, weights, key_cache, value_cache):
         def efunc(x, weights, key_cache, value_cache):
             return func(x, weights, key_cache, value_cache)
 
-        # eres = efunc(x, weights, key_cache, value_cache)
-        # print("Enzyme primal", eres)
-        # res = func(x, weights, key_cache, value_cache)
-        # print("Jax primal", res)
-        # print (" max error", jnp.max(jnp.abs(eres-res)))
-        # assert (jnp.abs(eres - res) < 1e-3).all()
-
+        eres = efunc(x, weights, key_cache, value_cache)
+        print("Enzyme primal", eres)
+        res = jfunc(x, weights, key_cache, value_cache)
+        print("Jax primal", res)
+        print(" max error", jnp.max(jnp.abs(eres - res)))
+        assert (jnp.abs(eres - res) < 1e-3).all()
+
+        number = 1000
+        print(
+            "Enzyme primal",
+            timeit.Timer(
+                "efunc(x, weights, key_cache, value_cache)",
+                globals={
+                    "efunc": efunc,
+                    "x": x,
+                    "weights": weights,
+                    "key_cache": key_cache,
+                    "value_cache": value_cache,
+                },
+            ).timeit(number),
+        )
+        print(
+            "JaX primal",
+            timeit.Timer(
+                "jfunc(x, weights, key_cache, value_cache)",
+                globals={
+                    "jfunc": jfunc,
+                    "x": x,
+                    "weights": weights,
+                    "key_cache": key_cache,
+                    "value_cache": value_cache,
+                },
+            ).timeit(number),
+        )
         # jfunc = jax.jit(partial(forward, config))
         # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")
 
@@ -307,11 +335,44 @@ def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
         def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
             return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))
 
-        # print("pre fwd diff")
-        # eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
-        # print("Enzyme fwd", eres)
-        # jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
-        # print("Jax fwd", jres)
+        eres = efwd(
+            x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache
+        )
+        print("Enzyme fwd", eres)
+        jres = jfwd(
+            x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache
+        )
+        print("Jax fwd", jres)
+        print(
+            "Enzyme fwd",
+            timeit.Timer(
+                "efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)",
+                globals={
+                    "efwd": efwd,
+                    "x": x,
+                    "dx": dx,
+                    "weights": weights,
+                    "dweights": dweights,
+                    "key_cache": key_cache,
+                    "value_cache": value_cache,
+                },
+            ).timeit(number),
+        )
+        print(
+            "JaX fwd",
+            timeit.Timer(
+                "jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)",
+                globals={
+                    "jfwd": jfwd,
+                    "x": x,
+                    "dx": dx,
+                    "weights": weights,
+                    "dweights": dweights,
+                    "key_cache": key_cache,
+                    "value_cache": value_cache,
+                },
+            ).timeit(number),
+        )
 
         @jax.jit
         def jrev(x, weights, kc, vc, dx, dkc, dvc):
@@ -328,6 +389,39 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
         jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
         print("Jax rev", jres)
 
+        print(
+            "Enzyme rev",
+            timeit.Timer(
+                "erev(x, weights, key_cache, value_cache, dx, dkc, dvc)",
+                globals={
+                    "erev": erev,
+                    "x": x,
+                    "weights": weights,
+                    "key_cache": key_cache,
+                    "value_cache": value_cache,
+                    "dx": dx,
+                    "dkc": dkc,
+                    "dvc": dvc,
+                },
+            ).timeit(number),
+        )
+        print(
+            "JaX rev",
+            timeit.Timer(
+                "jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)",
+                globals={
+                    "jrev": jrev,
+                    "x": x,
+                    "weights": weights,
+                    "key_cache": key_cache,
+                    "value_cache": value_cache,
+                    "dx": dx,
+                    "dkc": dkc,
+                    "dvc": dvc,
+                },
+            ).timeit(number),
+        )
+
 
 if __name__ == "__main__":
     absltest.main()