Skip to content

Commit

Permalink
Benchmark llama
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 18, 2023
1 parent 5a15457 commit 46cd495
Showing 1 changed file with 96 additions and 14 deletions.
110 changes: 96 additions & 14 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.lax
import enzyme_ad.jax as enzyme_jax
import numpy as np

import timeit

def rmsnorm(x, weight):
ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5)
Expand Down Expand Up @@ -289,13 +289,38 @@ def jfunc(x, weights, key_cache, value_cache):
def efunc(x, weights, key_cache, value_cache):
return func(x, weights, key_cache, value_cache)

# eres = efunc(x, weights, key_cache, value_cache)
# print("Enzyme primal", eres)
# res = func(x, weights, key_cache, value_cache)
# print("Jax primal", res)
# print (" max error", jnp.max(jnp.abs(eres-res)))
# assert (jnp.abs(eres - res) < 1e-3).all()

eres = efunc(x, weights, key_cache, value_cache)
print("Enzyme primal", eres)
res = jfunc(x, weights, key_cache, value_cache)
print("Jax primal", res)
print (" max error", jnp.max(jnp.abs(eres-res)))
assert (jnp.abs(eres - res) < 1e-3).all()

number = 1000
print("Enzyme primal",
timeit.Timer(
"efunc(x, weights, key_cache, value_cache)",
globals={
'efunc':efunc,
"x":x,
"weights":weights,
"key_cache":key_cache,
"value_cache":value_cache,
},
).timeit(number)
)
print("JaX primal",
timeit.Timer(
"jfunc(x, weights, key_cache, value_cache)",
globals={
'jfunc':jfunc,
"x":x,
"weights":weights,
"key_cache":key_cache,
"value_cache":value_cache,
},
).timeit(number)
)
# jfunc = jax.jit(partial(forward, config))
# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")

Expand All @@ -307,11 +332,38 @@ def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))

# print("pre fwd diff")
# eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
# print("Enzyme fwd", eres)
# jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
# print("Jax fwd", jres)
eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
print("Enzyme fwd", eres)
jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
print("Jax fwd", jres)
print("Enzyme fwd",
timeit.Timer(
"efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)",
globals={
'efwd':efwd,
"x":x,
"dx":dx,
"weights":weights,
"dweights":dweights,
"key_cache":key_cache,
"value_cache":value_cache,
},
).timeit(number)
)
print("JaX fwd",
timeit.Timer(
"jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)",
globals={
'jfwd':jfwd,
"x":x,
"dx":dx,
"weights":weights,
"dweights":dweights,
"key_cache":key_cache,
"value_cache":value_cache,
},
).timeit(number)
)

@jax.jit
def jrev(x, weights, kc, vc, dx, dkc, dvc):
Expand All @@ -327,7 +379,37 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
print("Enzyme rev", eres)
jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
print("Jax rev", jres)


print("Enzyme rev",
timeit.Timer(
"erev(x, weights, key_cache, value_cache, dx, dkc, dvc)",
globals={
'erev':erev,
"x":x,
"weights":weights,
"key_cache":key_cache,
"value_cache":value_cache,
"dx":dx,
"dkc":dkc,
"dvc":dvc,
},
).timeit(number)
)
print("JaX rev",
timeit.Timer(
"jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)",
globals={
'jrev':jrev,
"x":x,
"weights":weights,
"key_cache":key_cache,
"value_cache":value_cache,
"dx":dx,
"dkc":dkc,
"dvc":dvc,
},
).timeit(number)
)

if __name__ == "__main__":
absltest.main()

0 comments on commit 46cd495

Please sign in to comment.