
Running into OOM on 256 GB RAM #5

Closed
catchmoosa opened this issue Oct 23, 2024 · 4 comments

@catchmoosa commented Oct 23, 2024

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/mcgill_fiam/0X-Causal_discovery/discovery.py", line 30, in <module>
    g_prob = model(x=x)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/avici/pretrain.py", line 109, in __call__
    out = onp.array(out)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/array.py", line 429, in __array__
    return np.asarray(self._value, dtype=dtype, **kwds)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/array.py", line 628, in _value
    self._npy_value = self._single_device_array_to_np_array()
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Error preparing computation: %sOut of memory allocating 332034480032 bytes.

This is on 10,000 rows with 51 variables. Can you help me with this issue?

@larslorch (Owner)

It's quite possible that 10,000 rows is simply too large for the forward pass. One idea -- though I've never tried it -- could be to split the rows into smaller chunks and form a bootstrapped estimate of the graph by running several forward passes (a rough sketch of this idea follows below).

However, it seems that your error occurs here, after the forward pass is already done -- can you confirm this?
Maybe call jax.block_until_ready before this line to confirm, see here. In that case, I don't currently know what the issue could be and would have to investigate. It would be great if you could provide a minimal example that reproduces this with random synthetic data.
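
As a rough, untested sketch of that chunking idea (the chunk size and the simple averaging of the edge probabilities are arbitrary choices here, not a validated procedure):

import numpy as np

# untested sketch: run one forward pass per chunk of rows and average the
# resulting edge-probability matrices
def bootstrapped_g_prob(model, x, chunk_size=2000, seed=0):
    rng = np.random.default_rng(seed)
    idx = rng.permutation(x.shape[0])  # shuffle rows before chunking
    chunks = np.array_split(idx, max(1, len(idx) // chunk_size))
    probs = [model(x=x[c]) for c in chunks]  # one forward pass per chunk
    return np.mean(probs, axis=0)  # average edge probabilities across chunks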

@syleeheal commented Dec 2, 2024

Hello,

Thank you for the great work and repo!

I am also experiencing a similar memory issue. Just loading the model consumes about 37 GB of GPU VRAM, and running inference on a large dataset (n > 10,000; d = 18) results in the OOM problem. I am not too familiar with JAX. Is the problem I am facing expected? If so, is it possible to use CPUs for AVICI inference? I wrote my code below for your reference.

Again, I very much appreciate this great repo.

import avici
import numpy as np
import pandas as pd

# load observational data and intervention indicators from CSV
x = pd.read_csv('./data/obs.csv')
itv = pd.read_csv('./data/itv.csv')

# load the pretrained model and predict the edge probabilities
model = avici.load_pretrained(download="scm-v0")
g_prob = model(x=np.array(x), interv=np.array(itv))

@larslorch (Owner)

Hi,

thanks for the feedback!
I pushed new features to the main branch that allow running the inference forward pass on CPU and automatically sharding the computation across multiple devices.
This should help with large datasets. Check out the updated docstring of the AVICIModel.__call__ function. You can now simply call


g_prob = model(x=x, interv=interv, devices="cpu")

to run the forward pass on CPU. Note that the runtime will be significantly slower on CPU and grows at least quadratically with n, but CPU memory can of course be cheaper/larger than GPU memory.

"Just loading the model consumes about 37 GB of GPU VRAM"

Depending on your JAX settings and GPU size, this probably occurs because JAX by default preallocates 75% of the available GPU memory when it is first used, see here. Since the model is fairly small in terms of parameters, it cannot take up 37 GB of memory just from being loaded.

If I set os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" in the Colab notebook linked from the README, which avoids the automatic GPU preallocation by JAX, then loading the model with avici.load_pretrained(download="scm-v0") allocates only 111 MiB on GPU. We can check this by running !nvidia-smi in a new notebook cell.
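
In code, that looks roughly like this (the environment variable needs to be set before JAX first initializes the GPU backend):

import os

# disable JAX's default GPU memory preallocation; must be set before the
# first JAX computation touches the GPU
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import avici

model = avici.load_pretrained(download="scm-v0")
# in a notebook, check the actual GPU usage afterwards with: !nvidia-smi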

Let me know if this helps and whether you have any other issues.

@larslorch (Owner) commented Jan 3, 2025

@syleeheal @catchmoosa I pushed an experimental feature to the main branch and the latest PyPI release that helps with memory constraints.
The experimental_chunk_size flag in the forward pass of AVICIModel and BaseModel applies the transformer blocks up to the max-pooling operation to chunks of the input dataset (chunked along the observations axis, i.e., data points). The model thus processes a chunk of data points at a time and combines the chunks with a final max-pooling operation before assembling the graph prediction.

Bear in mind that this idea has not been validated properly, but when testing it in smaller examples, it performed better than staying at the memory limit, i.e., using a single chunk. experimental_chunk_size should probably be set as high as memory allows, and it could make sense to consider smart ways of arranging the data points into chunks (e.g., distributing interventional data equally across chunks); a rough sketch of this follows below.
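
For example, as a minimal, untested sketch, simply shuffling the rows jointly already spreads interventional samples roughly evenly across the chunks:

import numpy as np

# untested sketch: shuffle rows jointly so that interventional samples end up
# distributed roughly evenly across the chunks formed by experimental_chunk_size
def shuffle_rows(x, interv, seed=0):
    rng = np.random.default_rng(seed)
    order = rng.permutation(x.shape[0])
    return x[order], interv[order]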

In the setting of @syleeheal with d=18 and n > 10000, we can for example test this via:

import avici
from avici import simulate_data
from avici.metrics import shd, classification_metrics, threshold_metrics

if __name__ == "__main__":

    model = avici.load_pretrained(download="scm-v0")

    d = 18
    chunk_size = 2000

    for n in [200, 1000, 2000, 4000, 8000, 16000, 32000]:
        # simulate a synthetic dataset with n observations and no interventions
        g, x, interv = simulate_data(d=d, n=n, n_interv=0, domain="rff-gauss", seed=0)
        # chunked forward pass on CPU
        g_prob = model(x=x, interv=interv, devices="cpu", experimental_chunk_size=chunk_size)

        # evaluate the predicted edge probabilities against the ground-truth graph
        shdist = shd(g, (g_prob > 0.5).astype(int))
        f1 = classification_metrics(g, (g_prob > 0.5).astype(int))['f1']
        auprc = threshold_metrics(g, g_prob)['auprc']

        print(f"{n:8d}    shd: {shdist:.1f}    f1: {f1:.3f}    auprc: {auprc:.3f}")

This prints:

     200    shd: 17.0    f1: 0.640    auprc: 0.822
    1000    shd: 10.0    f1: 0.808    auprc: 0.947
    2000    shd: 13.0    f1: 0.731    auprc: 0.884
    4000    shd: 8.0    f1: 0.852    auprc: 0.941
    8000    shd: 8.0    f1: 0.852    auprc: 0.948
   16000    shd: 7.0    f1: 0.873    auprc: 0.939
   32000    shd: 9.0    f1: 0.836    auprc: 0.935
