
[REVIEW]: Qualitatively Different Performance-Related Plots #74

Closed · ClaudMor opened this issue Oct 26, 2024 · 5 comments
Labels: bug (Something isn't working)

@ClaudMor

Description of the problem

Dear Authors,

When running the code listed in the "Jax: linear algebra backend" tutorial on Google Colab as prescribed (using a T4 GPU), I consistently get qualitatively different plots for the "Computing entropy on large multi-dimensional arrays" case.

Steps to reproduce

Open a new Google Colab notebook with a T4 GPU and run:


!pip install hoi

import numpy as np
import jax
import jax.numpy as jnp
from time import time

from hoi.metrics import Oinfo
from hoi.core import get_entropy

import matplotlib.pyplot as plt

plt.style.use("ggplot")

def compute_timings(n=15):
    n_samples = np.linspace(10, 10e2, n).astype(int)
    n_features = np.linspace(1, 10, n).astype(int)
    n_variables = np.linspace(1, 10e2, n).astype(int)

    entropy = jax.vmap(get_entropy(method="gc"), in_axes=(0,))

    # dry run
    entropy(np.random.rand(2, 2, 10))

    timings_cpu = []
    data_size = []
    for n_s, n_f, n_v in zip(n_samples, n_features, n_variables):
        # generate random data
        x = np.random.rand(n_v, n_f, n_s)
        x = jnp.asarray(x)

        # compute entropy
        start = time()
        entropy(x)
        timings_cpu.append(time() - start)
        data_size.append(n_s * n_f * n_v)

    return data_size, timings_cpu

with jax.default_device(jax.devices("gpu")[0]):
    data_size, timings_gpu = compute_timings()

with jax.default_device(jax.devices("cpu")[0]):
    data_size, timings_cpu = compute_timings()

plt.plot(data_size, timings_cpu, label="CPU")
plt.plot(data_size, timings_gpu, label="GPU")
plt.xlabel("Data size")
plt.ylabel("Time (s)")
plt.title("CPU vs. GPU for computing entropy", fontweight="bold")
plt.legend()

Expected results

A plot qualitatively consistent with the one in the tutorial.

Actual results

Below are a few of the plots I get (from different runs). The GPU consistently performs worse than the CPU.

[Three plots from different runs, each showing GPU timings above CPU timings]

Additional information

I understand that this may depend on JAX and/or the actual GPU used. Besides, I get correct plots in the next section. Therefore, resolving this issue is not mandatory for the review, but please feel free to provide an explanation should you have one.
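For anyone reproducing this, a quick sanity check that the GPU backend is actually active (a diagnostic sketch only, not part of the tutorial):

import jax

# list the devices JAX can see; on a T4 runtime this should include
# a CUDA device, otherwise JAX silently falls back to the CPU
print(jax.devices())          # e.g. [CudaDevice(id=0)]
print(jax.default_backend())  # "gpu" when the accelerator is active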

REVIEW ISSUE: openjournals/joss-reviews#7360

@EtienneCmb (Collaborator)

Hi @ClaudMor,

I'm able to reproduce the issue. Regarding the performance claims, they still hold for the higher-order metrics, but for some reason they no longer do for the entropy. I'll let you know once I figure it out! Thanks!

@EtienneCmb EtienneCmb self-assigned this Oct 29, 2024
@EtienneCmb EtienneCmb added the bug Something isn't working label Oct 29, 2024
@EtienneCmb (Collaborator)

Hi @ClaudMor,

I fixed the issue. I believe it came from an incorrect way of timing the computations: JAX dispatches work asynchronously, so the clock must not stop until the result is actually ready. Here's the revised code:

import numpy as np
import jax
import jax.numpy as jnp
import timeit

from hoi.core import get_entropy

import matplotlib.pyplot as plt

plt.style.use("ggplot")

n = 5
n_repeat = 5

# jit-compile the vectorized entropy estimator and trigger
# compilation once with a dry run, outside the timed region
entropy = jax.jit(jax.vmap(get_entropy(method="gc"), in_axes=(0,)))
entropy(np.random.rand(2, 2, 10))  # dry run

n_samples = np.linspace(10, 10e2, n).astype(int)
n_features = np.linspace(1, 10, n).astype(int)
n_variables = np.linspace(1, 10e2, n).astype(int)

data_size, timings_gpu, timings_cpu = [], [], []
for n_s, n_f, n_v in zip(n_samples, n_features, n_variables):
    x = np.random.rand(n_v, n_f, n_s)
    x = jnp.asarray(x)

    # time on CPU; block_until_ready() waits for the asynchronous
    # computation to finish before the clock stops
    with jax.default_device(jax.devices("cpu")[0]):
        result_cpu = timeit.timeit('entropy(x).block_until_ready()', number=n_repeat, globals=globals())
        timings_cpu.append(result_cpu / n_repeat)

    # same measurement on GPU
    with jax.default_device(jax.devices("gpu")[0]):
        result_gpu = timeit.timeit('entropy(x).block_until_ready()', number=n_repeat, globals=globals())
        timings_gpu.append(result_gpu / n_repeat)

    data_size.append(n_s * n_f * n_v)

plt.plot(data_size, timings_cpu, label="CPU")
plt.plot(data_size, timings_gpu, label="GPU")
plt.xlabel("Data size")
plt.ylabel("Time (s)")
plt.title("CPU vs. GPU for computing entropy", fontweight="bold")
plt.legend()
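Two things changed relative to the original snippet: the function is now jit-compiled once, and each timed call is synchronized with block_until_ready() before the clock stops. JAX executes asynchronously, so without synchronization time() can measure dispatch rather than compute. A minimal sketch of that second point (f is a hypothetical stand-in, not hoi code):

import jax
import jax.numpy as jnp
from time import time

f = jax.jit(lambda x: jnp.sum(x ** 2))
x = jnp.ones((4000, 4000))
f(x).block_until_ready()  # warm-up: compile outside the timed region

start = time()
f(x)  # returns as soon as the work is dispatched to the device
print("no sync:", time() - start)  # measures dispatch, not compute

start = time()
f(x).block_until_ready()  # waits until the device has finished
print("synced:", time() - start)  # measures the actual computation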

Very similar results on Google Colab and Kaggle:

[Plot: CPU vs. GPU entropy timings from the revised benchmark]

I'll let you try this code snippet; if you get similar results, I'll update the online documentation accordingly.

@ClaudMor (Author)

Hi @EtienneCmb,

I confirm I now get similar results.

@EtienneCmb (Collaborator)

Great! I'll update the doc soon.

@EtienneCmb (Collaborator)

EtienneCmb commented Oct 31, 2024

Hi @ClaudMor, I fixed the doc as we discussed in commit fdb1301. The new version of the doc is online.
