Skip to content

Commit

Permalink
Benchmark llama (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 21, 2023
1 parent 5a15457 commit 24eedb4
Showing 1 changed file with 106 additions and 12 deletions.
118 changes: 106 additions & 12 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax.lax
import enzyme_ad.jax as enzyme_jax
import numpy as np
import timeit


def rmsnorm(x, weight):
Expand Down Expand Up @@ -289,13 +290,40 @@ def jfunc(x, weights, key_cache, value_cache):
def efunc(x, weights, key_cache, value_cache):
    """Forward the four arguments to ``func`` and return its result unchanged."""
    result = func(x, weights, key_cache, value_cache)
    return result

# eres = efunc(x, weights, key_cache, value_cache)
# print("Enzyme primal", eres)
# res = func(x, weights, key_cache, value_cache)
# print("Jax primal", res)
# print (" max error", jnp.max(jnp.abs(eres-res)))
# assert (jnp.abs(eres - res) < 1e-3).all()

# Evaluate both implementations once on the same inputs, print each result,
# and require that they agree elementwise to within 1e-3.
eres = efunc(x, weights, key_cache, value_cache)
print("Enzyme primal", eres)
res = jfunc(x, weights, key_cache, value_cache)
print("Jax primal", res)
abs_err = jnp.abs(eres - res)
print(" max error", jnp.max(abs_err))
assert (abs_err < 1e-3).all()

# Number of executions per timing run (also read by the later timing code).
number = 1000

# Names made visible to the timed statements.
primal_env = {
    "jax": jax,
    "efunc": efunc,
    "jfunc": jfunc,
    "x": x,
    "weights": weights,
    "key_cache": key_cache,
    "value_cache": value_cache,
}
# JAX dispatches computations asynchronously, so a call can return before the
# result has been computed. Each timed statement therefore blocks on its
# result; otherwise timeit may only measure the dispatch overhead.
primal_benchmarks = (
    (
        "Enzyme primal",
        "jax.block_until_ready(efunc(x, weights, key_cache, value_cache))",
    ),
    (
        "JaX primal",
        "jax.block_until_ready(jfunc(x, weights, key_cache, value_cache))",
    ),
)
for label, stmt in primal_benchmarks:
    # Prints the label followed by the total seconds for `number` executions.
    print(label, timeit.Timer(stmt, globals=primal_env).timeit(number))
# jfunc = jax.jit(partial(forward, config))
# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")

Expand All @@ -307,11 +335,44 @@ def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
    """Return ``jax.jvp`` of ``efunc`` evaluated at ``(x, weights, kc, vc)``.

    NOTE(review): ``dx`` and ``dweights`` are accepted but never used. The
    tangents handed to ``jax.jvp`` are ``(x, weights, dkc, dvc)``, i.e. the
    primal ``x`` and ``weights`` are reused as their own tangents. Confirm
    this is intentional.
    """
    return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))

# print("pre fwd diff")
# eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
# print("Enzyme fwd", eres)
# jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
# print("Jax fwd", jres)
# Evaluate both forward-mode wrappers once on identical inputs and print the
# results. The caches themselves are passed in the dkc/dvc argument positions.
fwd_inputs = (
    x,
    dx,
    weights,
    dweights,
    key_cache,
    key_cache,
    value_cache,
    value_cache,
)
eres = efwd(*fwd_inputs)
print("Enzyme fwd", eres)
jres = jfwd(*fwd_inputs)
print("Jax fwd", jres)
# Names made visible to the timed forward-mode statements.
fwd_env = {
    "jax": jax,
    "efwd": efwd,
    "jfwd": jfwd,
    "x": x,
    "dx": dx,
    "weights": weights,
    "dweights": dweights,
    "key_cache": key_cache,
    "value_cache": value_cache,
}
# JAX dispatches computations asynchronously, so each timed statement blocks
# on its result (jax.block_until_ready accepts the tuple jax.jvp returns);
# otherwise timeit may only measure the dispatch overhead.
fwd_benchmarks = (
    (
        "Enzyme fwd",
        "jax.block_until_ready("
        "efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)"
        ")",
    ),
    (
        "JaX fwd",
        "jax.block_until_ready("
        "jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)"
        ")",
    ),
)
for label, stmt in fwd_benchmarks:
    # Prints the label followed by the total seconds for `number` executions.
    print(label, timeit.Timer(stmt, globals=fwd_env).timeit(number))

@jax.jit
def jrev(x, weights, kc, vc, dx, dkc, dvc):
Expand All @@ -328,6 +389,39 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
# Evaluate the JAX reverse-mode wrapper once and print its result.
jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
print("Jax rev", jres)

# Names made visible to the timed reverse-mode statements.
rev_env = {
    "jax": jax,
    "erev": erev,
    "jrev": jrev,
    "x": x,
    "weights": weights,
    "key_cache": key_cache,
    "value_cache": value_cache,
    "dx": dx,
    "dkc": dkc,
    "dvc": dvc,
}
# JAX dispatches computations asynchronously, so each timed statement blocks
# on its result; otherwise timeit may only measure the dispatch overhead.
rev_benchmarks = (
    (
        "Enzyme rev",
        "jax.block_until_ready("
        "erev(x, weights, key_cache, value_cache, dx, dkc, dvc)"
        ")",
    ),
    (
        "JaX rev",
        "jax.block_until_ready("
        "jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)"
        ")",
    ),
)
for label, stmt in rev_benchmarks:
    # Prints the label followed by the total seconds for `number` executions.
    print(label, timeit.Timer(stmt, globals=rev_env).timeit(number))


# Invoke absltest's entry point only when this file is executed as a script.
if __name__ == "__main__":
    absltest.main()

0 comments on commit 24eedb4

Please sign in to comment.