From f349620f316e67287afe5ecdcfb2c25dc73c61ba Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Sat, 21 Oct 2023 16:46:24 +0200 Subject: [PATCH 1/2] Formatting and jitting e3nn --- benchmarks/jax/benchmark.py | 181 ++++++++++++++++++++++++------------ 1 file changed, 123 insertions(+), 58 deletions(-) diff --git a/benchmarks/jax/benchmark.py b/benchmarks/jax/benchmark.py index b969b27d..8c9860aa 100644 --- a/benchmarks/jax/benchmark.py +++ b/benchmarks/jax/benchmark.py @@ -6,8 +6,6 @@ import jax import sphericart.jax -jax.config.update("jax_enable_x64", True) # enable float64 for jax - docstring = """ Benchmarks for the jax implementation of `sphericart`. @@ -25,15 +23,21 @@ def sphericart_benchmark( l_max=10, - n_samples=10000, + n_samples=200, n_tries=100, normalized=False, device="cpu", - dtype=np.float64, + dtype=jnp.float64, compare=False, verbose=False, warmup=16, ): + + if dtype == jnp.float64: + jax.config.update("jax_enable_x64", True) # enable float64 for jax + else: + jax.config.update("jax_enable_x64", False) # disable float64 for jax + key = jax.random.PRNGKey(0) xyz = jax.random.normal(key, (n_samples, 3), dtype=dtype) @@ -47,43 +51,35 @@ def sphericart_benchmark( ) time_noderi = np.zeros(n_tries + warmup) - time_fw = np.zeros(n_tries + warmup) - time_bw = np.zeros(n_tries + warmup) - for i in range(n_tries + warmup): elapsed = -time.time() sh_sphericart = sh_calculator(xyz, l_max, normalized) elapsed += time.time() time_noderi[i] = elapsed - mean_time = time_noderi[warmup:].mean() / n_samples std_time = time_noderi[warmup:].std() / n_samples print( - f" No derivatives: {mean_time * 1e9: 10.1f} ns/sample ± " + f" No derivatives: {mean_time * 1e9: 10.1f} ns/sample ± " + f"{std_time * 1e9: 10.1f} (std)" ) - if verbose: - print("Warm-up timings / sec.:\n", time_noderi[:warmup]) + if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup]) time_noderi[:] = 0.0 - for i in range(n_tries + warmup): elapsed = -time.time() sh_sphericart_jit = 
sh_calculator_jit(xyz, l_max, normalized) elapsed += time.time() time_noderi[i] = elapsed - mean_time = time_noderi[warmup:].mean() / n_samples std_time = time_noderi[warmup:].std() / n_samples print( f" No derivatives, jit: {mean_time * 1e9: 10.1f} ns/sample ± " + f"{std_time * 1e9: 10.1f} (std)" ) - if verbose: - print("Warm-up timings / sec.:\n", time_noderi[:warmup]) + if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup]) def scalar_output(xyz, l_max, normalized): - return jax.numpy.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized)) + return jnp.sum(sphericart.jax.spherical_harmonics(xyz, l_max, normalized)) sh_grad = jax.jit(jax.grad(scalar_output), static_argnums=1) @@ -97,14 +93,14 @@ def scalar_output(xyz, l_max, normalized): mean_time = time_deri[warmup:].mean() / n_samples std_time = time_deri[warmup:].std() / n_samples print( - f" Gradient (scalar, jit): {mean_time * 1e9:10.1f} ns/sample ± " + f" grad (jit): {mean_time * 1e9:10.1f} ns/sample ± " + f"{std_time * 1e9:10.1f} (std)" ) if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup]) def single_scalar_output(x, l_max, normalized): - return jax.numpy.sum(sphericart.jax.spherical_harmonics(x, l_max, normalized)) + return jnp.sum(sphericart.jax.spherical_harmonics(x, l_max, normalized)) # Compute the Hessian for a single (3,) input single_hessian = jax.hessian(single_scalar_output) @@ -124,7 +120,7 @@ def single_scalar_output(x, l_max, normalized): mean_time = time_deri[warmup:].mean() / n_samples std_time = time_deri[warmup:].std() / n_samples print( - f" Hessian (scalar, jit): {mean_time * 1e9:10.1f} ns/sample ± " + f" hessian (jit): {mean_time * 1e9:10.1f} ns/sample ± " + f"{std_time * 1e9:10.1f} (std)" ) if verbose: @@ -134,7 +130,7 @@ def single_scalar_output(x, l_max, normalized): # and take its jacobian with respect to the input Cartesian coordinates, # both in forward mode and in reverse mode def array_output(xyz, l_max, normalized): - return jax.numpy.sum( + 
return jnp.sum( sphericart.jax.spherical_harmonics(xyz, l_max, normalized), axis=0 ) @@ -150,7 +146,7 @@ def array_output(xyz, l_max, normalized): mean_time = time_deri[warmup:].mean() / n_samples std_time = time_deri[warmup:].std() / n_samples print( - f" jacfwd (jit): {mean_time * 1e9:10.1f} ns/sample ± " + f" jacfwd (jit): {mean_time * 1e9:10.1f} ns/sample ± " + f"{std_time * 1e9:10.1f} (std)" ) if verbose: @@ -168,53 +164,122 @@ def array_output(xyz, l_max, normalized): mean_time = time_deri[warmup:].mean() / n_samples std_time = time_deri[warmup:].std() / n_samples print( - f" jacrev (jit): {mean_time * 1e9:10.1f} ns/sample ± " + f" jacrev (jit): {mean_time * 1e9:10.1f} ns/sample ± " + f"{std_time * 1e9:10.1f} (std)" ) if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup]) if compare and _HAS_E3NN_JAX: - xyz_tensor = xyz.copy() - + # compare to e3nn-jax irreps = e3nn_jax.Irreps([e3nn_jax.Irrep(l, 1) for l in range(l_max + 1)]) - def loss_fn(xyz_tensor): - sh_e3nn = e3nn_jax.spherical_harmonics( - irreps, xyz_tensor, normalize=normalized, normalization="integral" - ) - loss = jnp.sum(sh_e3nn.array) + def e3nn_sph(xyz): + return e3nn_jax.spherical_harmonics( + irreps, xyz, normalize=normalized, normalization="integral" + ).array + + def single_scalar_output(xyz): + sh_e3nn = e3nn_sph(xyz) + loss = jnp.sum(sh_e3nn) return loss - loss_grad_fn = jax.grad(loss_fn) + def array_output(xyz): + return jnp.sum( + e3nn_sph(xyz), axis=0 + ) - for i in range(n_tries + warmup): - elapsed = -time.time() - _ = loss_fn(xyz_tensor) - elapsed += time.time() - time_fw[i] = elapsed + jit_e3nn_sph = jax.jit(e3nn_sph) + jit_grad = jax.jit(jax.grad(single_scalar_output)) + jit_vmap_hessian = jax.jit(jax.vmap(jax.hessian(single_scalar_output), in_axes=(0,))) + jit_jacfwd = jax.jit(jax.jacfwd(array_output)) + jit_jacrev = jax.jit(jax.jacrev(array_output)) - elapsed = -time.time() - _ = loss_grad_fn(xyz_tensor) - elapsed += time.time() - time_bw[i] = elapsed + time_noderi = 
np.zeros(n_tries + warmup) + for i in range(n_tries + warmup): + elapsed = -time.time() + sh_sphericart = e3nn_sph(xyz) + elapsed += time.time() + time_noderi[i] = elapsed + mean_time = time_noderi[warmup:].mean() / n_samples + std_time = time_noderi[warmup:].std() / n_samples + print( + f" E3NN no derivatives: {mean_time * 1e9: 10.1f} ns/sample ± " + + f"{std_time * 1e9: 10.1f} (std)" + ) + if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup]) - mean_time = time_fw[warmup:].mean() / n_samples - std_time = time_fw[warmup:].std() / n_samples - print( - f" E3NN-JAX-FW: {mean_time*1e9: 10.1f} ns/sample ± " - + f"{std_time*1e9: 10.1f} (std)" - ) - if verbose: - print("Warm-up timings / sec.: \n", time_fw[:warmup]) - mean_time = time_bw[warmup:].mean() / n_samples - std_time = time_bw[warmup:].std() / n_samples - print( - f" E3NN-JAX-BW: {mean_time*1e9: 10.1f} ns/sample ± " - + f"{std_time*1e9: 10.1f} (std)" - ) - if verbose: - print("Warm-up timings / sec.: \n", time_bw[:warmup]) + time_noderi[:] = 0.0 + for i in range(n_tries + warmup): + elapsed = -time.time() + _ = jit_e3nn_sph(xyz) + elapsed += time.time() + time_noderi[i] = elapsed + mean_time = time_noderi[warmup:].mean() / n_samples + std_time = time_noderi[warmup:].std() / n_samples + print( + f" E3NN no der, jit: {mean_time * 1e9: 10.1f} ns/sample ± " + + f"{std_time * 1e9: 10.1f} (std)" + ) + if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup]) + + time_deri = np.zeros(n_tries + warmup) + for i in range(n_tries + warmup): + elapsed = -time.time() + _ = jit_grad(xyz) + elapsed += time.time() + time_deri[i] = elapsed + mean_time = time_deri[warmup:].mean() / n_samples + std_time = time_deri[warmup:].std() / n_samples + print( + f" E3NN grad (jit): {mean_time * 1e9:10.1f} ns/sample ± " + + f"{std_time * 1e9:10.1f} (std)" + ) + if verbose: + print("Warm-up timings / sec.:\n", time_deri[:warmup]) + + time_deri = np.zeros(n_tries + warmup) + for i in range(n_tries + warmup): + elapsed 
= -time.time() + _ = jit_vmap_hessian(xyz) + elapsed += time.time() + time_deri[i] = elapsed + mean_time = time_deri[warmup:].mean() / n_samples + std_time = time_deri[warmup:].std() / n_samples + print( + f" E3NN hessian (jit): {mean_time * 1e9:10.1f} ns/sample ± " + + f"{std_time * 1e9:10.1f} (std)" + ) + if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup]) + + time_deri = np.zeros(n_tries + warmup) + for i in range(n_tries + warmup): + elapsed = -time.time() + _ = jit_jacfwd(xyz) + elapsed += time.time() + time_deri[i] = elapsed + mean_time = time_deri[warmup:].mean() / n_samples + std_time = time_deri[warmup:].std() / n_samples + print( + f" E3NN jacfwd (jit): {mean_time * 1e9:10.1f} ns/sample ± " + + f"{std_time * 1e9:10.1f} (std)" + ) + if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup]) + + time_deri = np.zeros(n_tries + warmup) + for i in range(n_tries + warmup): + elapsed = -time.time() + _ = jit_jacrev(xyz) + elapsed += time.time() + time_deri[i] = elapsed + mean_time = time_deri[warmup:].mean() / n_samples + std_time = time_deri[warmup:].std() / n_samples + print( + f" E3NN jacrev (jit): {mean_time * 1e9:10.1f} ns/sample ± " + + f"{std_time * 1e9:10.1f} (std)" + ) + if verbose: print("Warm-up timings / sec.:\n", time_deri[:warmup]) + print( "******************************************************************************" ) @@ -224,7 +289,7 @@ def loss_fn(xyz_tensor): parser = argparse.ArgumentParser(description=docstring) parser.add_argument("-l", type=int, default=10, help="maximum angular momentum") - parser.add_argument("-s", type=int, default=10000, help="number of samples") + parser.add_argument("-s", type=int, default=200, help="number of samples") parser.add_argument("-t", type=int, default=100, help="number of runs/sample") parser.add_argument( "-cpu", type=int, default=1, help="print CPU results (0=False, 1=True)" @@ -264,7 +329,7 @@ def loss_fn(xyz_tensor): args.t, args.normalized, device="cpu", - dtype=np.float64, + 
dtype=jnp.float64, compare=args.compare, verbose=args.verbose, warmup=args.warmup, @@ -275,7 +340,7 @@ def loss_fn(xyz_tensor): args.t, args.normalized, device="cpu", - dtype=np.float32, + dtype=jnp.float32, compare=args.compare, verbose=args.verbose, warmup=args.warmup, From 7b411d165e2278315cf5876f6596a77b71a80963 Mon Sep 17 00:00:00 2001 From: Michele Ceriotti Date: Sun, 22 Oct 2023 08:20:10 -0700 Subject: [PATCH 2/2] Just consistency in output --- benchmarks/jax/benchmark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/jax/benchmark.py b/benchmarks/jax/benchmark.py index 8c9860aa..e347242f 100644 --- a/benchmarks/jax/benchmark.py +++ b/benchmarks/jax/benchmark.py @@ -73,7 +73,7 @@ def sphericart_benchmark( mean_time = time_noderi[warmup:].mean() / n_samples std_time = time_noderi[warmup:].std() / n_samples print( - f" No derivatives, jit: {mean_time * 1e9: 10.1f} ns/sample ± " + f" No derivatives (jit): {mean_time * 1e9: 10.1f} ns/sample ± " + f"{std_time * 1e9: 10.1f} (std)" ) if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup]) @@ -218,7 +218,7 @@ def array_output(xyz): mean_time = time_noderi[warmup:].mean() / n_samples std_time = time_noderi[warmup:].std() / n_samples print( - f" E3NN no der, jit: {mean_time * 1e9: 10.1f} ns/sample ± " + f" E3NN no der (jit): {mean_time * 1e9: 10.1f} ns/sample ± " + f"{std_time * 1e9: 10.1f} (std)" ) if verbose: print("Warm-up timings / sec.:\n", time_noderi[:warmup])