Skip to content

Commit

Permalink
Make e3nn patch faster using torch.index_select (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Apr 26, 2023
1 parent 6af5e9a commit d79910f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
20 changes: 13 additions & 7 deletions benchmarks/pytorch/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def sphericart_benchmark(
sh_calculator = sphericart.torch.SphericalHarmonics(l_max, normalized=normalized)
omp_threads = sh_calculator.omp_num_threads()
print(
f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries},"
f"**** Timings for l_max={l_max}, n_samples={n_samples}, n_tries={n_tries}, "
+ f"dtype={dtype}, device={device}, omp_num_threads={omp_threads} ****"
)

Expand Down Expand Up @@ -144,15 +144,17 @@ def sphericart_benchmark(
f" E3NN-FW: {mean_time * 1e9: 10.1f} ns/sample ± "
+ f"{std_time * 1e9: 10.1f} (std)"
)
print("Warm-up timings / sec.: \n", time_fw[:warmup])
if verbose:
print("Warm-up timings / sec.: \n", time_fw[:warmup])

mean_time = time_bw[warmup:].mean() / n_samples
std_time = time_bw[warmup:].std() / n_samples
print(
f" E3NN-BW: {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
print("Warm-up timings / sec.: \n", time_bw[:warmup])
if verbose:
print("Warm-up timings / sec.: \n", time_bw[:warmup])

# check the timing with the patch
sphericart.torch.patch_e3nn(e3nn)
Expand Down Expand Up @@ -182,14 +184,16 @@ def sphericart_benchmark(
f" PATCH-FW: {mean_time*1e9: 10.1f} ns/sample ± "
+ f"{std_time*1e9: 10.1f} (std)"
)
print("Warm-up timings / sec.: \n", time_fw[:warmup])
if verbose:
print("Warm-up timings / sec.: \n", time_fw[:warmup])
mean_time = time_bw[warmup:].mean() / n_samples
std_time = time_bw[warmup:].std() / n_samples
print(
f" PATCH-BW: {mean_time * 1e9:10.1f} ns/sample ± "
+ f"{std_time * 1e9:10.1f} (std)"
)
print("Warm-up timings / sec.: \n", time_bw[:warmup])
if verbose:
print("Warm-up timings / sec.: \n", time_bw[:warmup])
sphericart.torch.unpatch_e3nn(e3nn)

if compare and _HAS_E3NN_JAX:
Expand Down Expand Up @@ -230,14 +234,16 @@ def loss_fn(xyz_tensor):
f" E3NN-JAX-FW: {mean_time*1e9: 10.1f} ns/sample ± "
+ f"{std_time*1e9: 10.1f} (std)"
)
print("Warm-up timings / sec.: \n", time_fw[:warmup])
if verbose:
print("Warm-up timings / sec.: \n", time_fw[:warmup])
mean_time = time_bw[warmup:].mean() / n_samples
std_time = time_bw[warmup:].std() / n_samples
print(
f" E3NN-JAX-BW: {mean_time*1e9: 10.1f} ns/sample ± "
+ f"{std_time*1e9: 10.1f} (std)"
)
print("Warm-up timings / sec.: \n", time_bw[:warmup])
if verbose:
print("Warm-up timings / sec.: \n", time_bw[:warmup])
print(
"******************************************************************************"
)
Expand Down
6 changes: 5 additions & 1 deletion sphericart-torch/python/sphericart/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,11 @@ def e3nn_spherical_harmonics(
l_max = max(l_list)
is_range_lmax = list(l_list) == list(range(l_max + 1))

sh = SphericalHarmonics(l_max, normalized=normalize).compute(x[:, [2, 0, 1]])
sh = SphericalHarmonics(l_max, normalized=normalize).compute(
torch.index_select(
x, 1, torch.tensor([2, 0, 1], dtype=torch.long, device=x.device)
)
)
assert normalization in ["integral", "norm", "component"]
if normalization != "integral":
sh *= math.sqrt(4 * math.pi)
Expand Down

0 comments on commit d79910f

Please sign in to comment.