[REVIEW]: Qualitatively Different Performance-Related Plots #74
Hi @ClaudMor, I'm able to reproduce the issue. Regarding the performance claims: they still hold for the higher-order case, but for some reason they no longer hold for the entropy. I'll let you know once I figure it out! Thanks!
Hi @ClaudMor, I fixed the issue. I believe it was coming from a bad way of timing the computations. Here's the revised code:

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from hoi.core import get_entropy

plt.style.use("ggplot")

n = 5
n_repeat = 5

# jitted, vectorized Gaussian-copula entropy over the first axis
entropy = jax.jit(jax.vmap(get_entropy(method="gc"), in_axes=(0,)))
entropy(np.random.rand(2, 2, 10))  # dry run to trigger compilation

n_samples = np.linspace(10, 10e2, n).astype(int)
n_features = np.linspace(1, 10, n).astype(int)
n_variables = np.linspace(1, 10e2, n).astype(int)

data_size, timings_gpu, timings_cpu = [], [], []
for n_s, n_f, n_v in zip(n_samples, n_features, n_variables):
    x = np.random.rand(n_v, n_f, n_s)
    x = jnp.asarray(x)

    # time on CPU; block_until_ready() waits for JAX's asynchronous
    # dispatch to finish so the full computation is measured
    with jax.default_device(jax.devices("cpu")[0]):
        result_cpu = timeit.timeit(
            "entropy(x).block_until_ready()", number=n_repeat, globals=globals()
        )
        timings_cpu.append(result_cpu / n_repeat)

    # same measurement on GPU
    with jax.default_device(jax.devices("gpu")[0]):
        result_gpu = timeit.timeit(
            "entropy(x).block_until_ready()", number=n_repeat, globals=globals()
        )
        timings_gpu.append(result_gpu / n_repeat)

    data_size.append(n_s * n_f * n_v)

plt.plot(data_size, timings_cpu, label="CPU")
plt.plot(data_size, timings_gpu, label="GPU")
plt.xlabel("Data size")
plt.ylabel("Time (s)")
plt.title("CPU vs. GPU for computing entropy", fontweight="bold")
plt.legend()
```

I get very similar results on Google Colab and Kaggle. I'll let you try this code snippet, and if you get similar results, I'll update the online documentation accordingly.
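(Editor's note: for readers following along, the usual pitfall with naive timing in JAX is asynchronous dispatch: a jitted call returns before the device has finished computing, and the first call additionally pays compilation cost. A minimal sketch of the contrast, reusing the `gc` estimator and array layout from the snippet above; the "naive" variant is illustrative, not the authors' original code:)

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp

from hoi.core import get_entropy

# same estimator and (n_variables, n_features, n_samples) layout as above
entropy = jax.jit(jax.vmap(get_entropy(method="gc"), in_axes=(0,)))
x = jnp.asarray(np.random.rand(100, 10, 100))

# Naive timing: the first call includes JIT compilation, and because JAX
# dispatches asynchronously, entropy(x) may return before the device is done.
t_naive = timeit.timeit(lambda: entropy(x), number=5) / 5

# Proper timing: warm up once to exclude compilation, then block until the
# result is actually materialized before stopping the clock.
entropy(x).block_until_ready()  # warm-up / dry run
t_proper = timeit.timeit(lambda: entropy(x).block_until_ready(), number=5) / 5

print(f"naive: {t_naive:.6f} s, proper: {t_proper:.6f} s")
```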
Hi @EtienneCmb, I confirm I now get similar results.
Great! I'll update the doc soon.
Description of the problem
Dear Authors,
When running the code listed in the "Jax: linear algebra backend" tutorial on Google Colab as prescribed (using a T4 GPU), I consistently get qualitatively different plots for the "Computing entropy on large multi-dimensional arrays" case.
Steps to reproduce
Expected results
Plots qualitatively consistent with those in the tutorial.
Actual results
You may find below a few plots I get (resulting from different runs). The GPU is consistently doing worse than the CPU.
Additional information
I understand that it may depend on JAX and/or the actual GPU used. Besides, I get correct plots in the next section. Therefore, resolving this issue is not mandatory for the review, but please feel free to provide an explanation should you have one.
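(Editor's note: one sanity check worth running in this situation is to confirm that JAX actually sees the accelerator; this is a generic JAX check, not part of the tutorial. With a CPU-only jaxlib install, requesting the GPU backend fails outright, which separates "the GPU is slow" from "the GPU is not being used at all":)

```python
import jax

# Devices JAX can dispatch to; on a Colab T4 runtime with a CUDA-enabled
# jaxlib, this list should contain a CUDA device, not only CpuDevice.
print(jax.devices())

# Requesting the GPU backend explicitly raises if it is unavailable.
try:
    print(jax.devices("gpu"))
except RuntimeError as err:
    print("No GPU backend available:", err)
```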
REVIEW ISSUE: openjournals/joss-reviews#7360