Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve JAX benchmark #72

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 124 additions & 59 deletions benchmarks/jax/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import jax
import sphericart.jax

jax.config.update("jax_enable_x64", True) # enable float64 for jax

docstring = """
Benchmarks for the jax implementation of `sphericart`.

Expand All @@ -25,15 +23,21 @@

def sphericart_benchmark(
l_max=10,
n_samples=10000,
n_samples=200,
n_tries=100,
normalized=False,
device="cpu",
dtype=np.float64,
dtype=jnp.float64,
compare=False,
verbose=False,
warmup=16,
):

if dtype == jnp.float64:
jax.config.update("jax_enable_x64", True) # enable float64 for jax
else:
jax.config.update("jax_enable_x64", False) # disenable float64 for jax

key = jax.random.PRNGKey(0)
xyz = jax.random.normal(key, (n_samples, 3), dtype=dtype)

Expand All @@ -47,43 +51,35 @@ def sphericart_benchmark(
)

time_noderi = np.zeros(n_tries + warmup)
time_fw = np.zeros(n_tries + warmup)
time_bw = np.zeros(n_tries + warmup)

for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart = sh_calculator(xyz, l_max, normalized)
elapsed += time.time()
time_noderi[i] = elapsed

mean_time = time_noderi[warmup:].mean() / n_samples
std_time = time_noderi[warmup:].std() / n_samples
print(
f" No derivatives: {mean_time * 1e9: 10.1f} ns/sample ± "
f" No derivatives: {mean_time * 1e9: 10.1f} ns/sample ± "
+ f"{std_time * 1e9: 10.1f} (std)"
)
if verbose:
print("Warm-up timings / sec.:\n", time_noderi[:warmup])
if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup])

time_noderi[:] = 0.0

for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart_jit = sh_calculator_jit(xyz, l_max, normalized)
elapsed += time.time()
time_noderi[i] = elapsed

mean_time = time_noderi[warmup:].mean() / n_samples
std_time = time_noderi[warmup:].std() / n_samples
print(
f" No derivatives, jit: {mean_time * 1e9: 10.1f} ns/sample ± "
f" No derivatives (jit): {mean_time * 1e9: 10.1f} ns/sample ± "
+ f"{std_time * 1e9: 10.1f} (std)"
)
if verbose:
print("Warm-up timings / sec.:\n", time_noderi[:warmup])
if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup])

def scalar_output(xyz, l_max, normalized):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized))
return jnp.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized))

sh_grad = jax.jit(jax.grad(scalar_output), static_argnums=1)

Expand All @@ -97,14 +93,14 @@ def scalar_output(xyz, l_max, normalized):
mean_time = time_deri[warmup:].mean() / n_samples
std_time = time_deri[warmup:].std() / n_samples
print(
f" Gradient (scalar, jit): {mean_time * 1e9:10.1f} ns/sample ± "
f" grad (jit): {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
if verbose:
print("Warm-up timings / sec.:\n", time_deri[:warmup])

def single_scalar_output(x, l_max, normalized):
return jax.numpy.sum(sphericart.jax.spherical_harmonics(x, l_max, normalized))
return jnp.sum(sphericart.jax.spherical_harmonics(x, l_max, normalized))

# Compute the Hessian for a single (3,) input
single_hessian = jax.hessian(single_scalar_output)
Expand All @@ -124,7 +120,7 @@ def single_scalar_output(x, l_max, normalized):
mean_time = time_deri[warmup:].mean() / n_samples
std_time = time_deri[warmup:].std() / n_samples
print(
f" Hessian (scalar, jit): {mean_time * 1e9:10.1f} ns/sample ± "
f" hessian (jit): {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
if verbose:
Expand All @@ -134,7 +130,7 @@ def single_scalar_output(x, l_max, normalized):
# and take its jacobian with respect to the input Cartesian coordinates,
# both in forward mode and in reverse mode
def array_output(xyz, l_max, normalized):
return jax.numpy.sum(
return jnp.sum(
sphericart.jax.spherical_harmonics(xyz, l_max, normalized), axis=0
)

Expand All @@ -150,7 +146,7 @@ def array_output(xyz, l_max, normalized):
mean_time = time_deri[warmup:].mean() / n_samples
std_time = time_deri[warmup:].std() / n_samples
print(
f" jacfwd (jit): {mean_time * 1e9:10.1f} ns/sample ± "
f" jacfwd (jit): {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
if verbose:
Expand All @@ -168,53 +164,122 @@ def array_output(xyz, l_max, normalized):
mean_time = time_deri[warmup:].mean() / n_samples
std_time = time_deri[warmup:].std() / n_samples
print(
f" jacrev (jit): {mean_time * 1e9:10.1f} ns/sample ± "
f" jacrev (jit): {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
if verbose:
print("Warm-up timings / sec.:\n", time_deri[:warmup])

if compare and _HAS_E3NN_JAX:
xyz_tensor = xyz.copy()

# compare to e3nn-jax
irreps = e3nn_jax.Irreps([e3nn_jax.Irrep(l, 1) for l in range(l_max + 1)])

def loss_fn(xyz_tensor):
sh_e3nn = e3nn_jax.spherical_harmonics(
irreps, xyz_tensor, normalize=normalized, normalization="integral"
)
loss = jnp.sum(sh_e3nn.array)
def e3nn_sph(xyz):
return e3nn_jax.spherical_harmonics(
irreps, xyz, normalize=normalized, normalization="integral"
).array

def single_scalar_output(xyz):
sh_e3nn = e3nn_sph(xyz)
loss = jnp.sum(sh_e3nn)
return loss

loss_grad_fn = jax.grad(loss_fn)
def array_output(xyz):
return jnp.sum(
e3nn_sph(xyz), axis=0
)

for i in range(n_tries + warmup):
elapsed = -time.time()
_ = loss_fn(xyz_tensor)
elapsed += time.time()
time_fw[i] = elapsed
jit_e3nn_sph = jax.jit(e3nn_sph)
jit_grad = jax.jit(jax.grad(single_scalar_output))
jit_vmap_hessian = jax.jit(jax.vmap(jax.hessian(single_scalar_output), in_axes=(0,)))
jit_jacfwd = jax.jit(jax.jacfwd(array_output))
jit_jacrev = jax.jit(jax.jacrev(array_output))

elapsed = -time.time()
_ = loss_grad_fn(xyz_tensor)
elapsed += time.time()
time_bw[i] = elapsed
time_noderi = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
sh_sphericart = e3nn_sph(xyz)
elapsed += time.time()
time_noderi[i] = elapsed
mean_time = time_noderi[warmup:].mean() / n_samples
std_time = time_noderi[warmup:].std() / n_samples
print(
f" E3NN no derivatives: {mean_time * 1e9: 10.1f} ns/sample ± "
+ f"{std_time * 1e9: 10.1f} (std)"
)
if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup])

mean_time = time_fw[warmup:].mean() / n_samples
std_time = time_fw[warmup:].std() / n_samples
print(
f" E3NN-JAX-FW: {mean_time*1e9: 10.1f} ns/sample ± "
+ f"{std_time*1e9: 10.1f} (std)"
)
if verbose:
print("Warm-up timings / sec.: \n", time_fw[:warmup])
mean_time = time_bw[warmup:].mean() / n_samples
std_time = time_bw[warmup:].std() / n_samples
print(
f" E3NN-JAX-BW: {mean_time*1e9: 10.1f} ns/sample ± "
+ f"{std_time*1e9: 10.1f} (std)"
)
if verbose:
print("Warm-up timings / sec.: \n", time_bw[:warmup])
time_noderi[:] = 0.0
for i in range(n_tries + warmup):
elapsed = -time.time()
_ = jit_e3nn_sph(xyz)
elapsed += time.time()
time_noderi[i] = elapsed
mean_time = time_noderi[warmup:].mean() / n_samples
std_time = time_noderi[warmup:].std() / n_samples
print(
f" E3NN no der (jit): {mean_time * 1e9: 10.1f} ns/sample ± "
+ f"{std_time * 1e9: 10.1f} (std)"
)
if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup])

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
_ = jit_grad(xyz)
elapsed += time.time()
time_deri[i] = elapsed
mean_time = time_deri[warmup:].mean() / n_samples
std_time = time_deri[warmup:].std() / n_samples
print(
f" E3NN grad (jit): {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
if verbose:
print("Warm-up timings / sec.:\n", time_deri[:warmup])

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
_ = jit_vmap_hessian(xyz)
elapsed += time.time()
time_deri[i] = elapsed
mean_time = time_deri[warmup:].mean() / n_samples
std_time = time_deri[warmup:].std() / n_samples
print(
f" E3NN hessian (jit): {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup])

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
_ = jit_jacfwd(xyz)
elapsed += time.time()
time_deri[i] = elapsed
mean_time = time_deri[warmup:].mean() / n_samples
std_time = time_deri[warmup:].std() / n_samples
print(
f" E3NN jacfwd (jit): {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup])

time_deri = np.zeros(n_tries + warmup)
for i in range(n_tries + warmup):
elapsed = -time.time()
_ = jit_jacrev(xyz)
elapsed += time.time()
time_deri[i] = elapsed
mean_time = time_deri[warmup:].mean() / n_samples
std_time = time_deri[warmup:].std() / n_samples
print(
f" E3NN jacrev (jit): {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup])

print(
"******************************************************************************"
)
Expand All @@ -224,7 +289,7 @@ def loss_fn(xyz_tensor):
parser = argparse.ArgumentParser(description=docstring)

parser.add_argument("-l", type=int, default=10, help="maximum angular momentum")
parser.add_argument("-s", type=int, default=10000, help="number of samples")
parser.add_argument("-s", type=int, default=200, help="number of samples")
parser.add_argument("-t", type=int, default=100, help="number of runs/sample")
parser.add_argument(
"-cpu", type=int, default=1, help="print CPU results (0=False, 1=True)"
Expand Down Expand Up @@ -264,7 +329,7 @@ def loss_fn(xyz_tensor):
args.t,
args.normalized,
device="cpu",
dtype=np.float64,
dtype=jnp.float64,
compare=args.compare,
verbose=args.verbose,
warmup=args.warmup,
Expand All @@ -275,7 +340,7 @@ def loss_fn(xyz_tensor):
args.t,
args.normalized,
device="cpu",
dtype=np.float32,
dtype=jnp.float32,
compare=args.compare,
verbose=args.verbose,
warmup=args.warmup,
Expand Down