Running into OOM on 256 GB RAM #5
It's quite possible that 10,000 rows is simply too large for the forward pass. One idea, though I've never tried it, could be to split the rows into smaller chunks and create a bootstrapped estimate of the graph by running several forward passes. However, it seems that your error occurs here, after the forward pass is already done. Can you confirm this?
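The chunk-and-bootstrap idea above could be sketched as follows. This is a hypothetical illustration, not a validated feature: `model` stands in for any callable mapping an `(n, d)` data matrix (plus an optional intervention mask) to a `(d, d)` matrix of edge probabilities, as discussed in this thread.

```python
import numpy as np

def bootstrapped_graph_estimate(model, x, interv=None, chunk_size=2000, n_boot=10, seed=0):
    """Estimate edge probabilities by averaging forward passes over random row subsets.

    Unvalidated sketch: `model` is assumed to map an (n_chunk, d) matrix to a
    (d, d) matrix of edge probabilities.
    """
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    estimates = []
    for _ in range(n_boot):
        # Sample a subset of rows small enough for one forward pass.
        idx = rng.choice(n, size=min(chunk_size, n), replace=False)
        sub_interv = interv[idx] if interv is not None else None
        estimates.append(model(x=x[idx], interv=sub_interv))
    # Average the per-chunk edge probabilities over bootstrap replicates.
    return np.mean(estimates, axis=0)
```

With a real model, this trades one huge forward pass for several small ones, at the cost of some extra estimator variance.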
Hello, Thank you for the great work and repo! I am also experiencing a similar memory issue. Just loading the model consumes about 37GB of GPU vRAM, and inferring from a large data (n > 10,000; d = 18) results in the OOM problem. I am not too familiar with JAX. Is the problem I am facing expected? If so, is it possible to use CPUs for AVICI inference? I wrote my code below for your reference. Again, I very much appreciate this great repo.
Hi, thanks for the feedback! You can pass `devices="cpu"`, i.e., `g_prob = model(x=x, interv=interv, devices="cpu")`, to run the forward pass on CPU. Note that the runtime will be significantly slower on CPU and grows at least quadratically with the number of observations n.
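Independently of the `devices` argument mentioned above, JAX can also be forced onto CPU globally through a documented environment variable, set before JAX is first imported. A minimal sketch:

```python
import os

# Must be set before `import jax`; JAX then ignores any GPU/TPU backends
# and places all computation on CPU.
os.environ["JAX_PLATFORMS"] = "cpu"
```

This is handy when a library does not expose device placement directly, since it affects every JAX computation in the process.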
Depending on your JAX settings and GPU size, this probably occurs because JAX by default preallocates 75% of the available GPU memory when first used, see here. Since the model is fairly small in terms of parameters, it cannot take up 37 GB in memory just from loading. If you disable this preallocation, the memory usage you observe should drop accordingly. Let me know if this helps and whether you have any other issues.
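JAX's default GPU preallocation can be disabled via its documented environment variables; a minimal sketch, which must run before JAX is first imported:

```python
import os

# Must be set before `import jax` so XLA reads it at initialization.
# Allocate GPU memory on demand instead of grabbing 75% up front:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, cap the preallocated fraction instead of disabling it:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
```

With preallocation off, tools like `nvidia-smi` report the memory the model actually uses, at the cost of somewhat more fragmentation-prone allocation.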
@syleeheal @catchmoosa I pushed an experimental chunking feature, exposed through the `experimental_chunk_size` argument used below. Bear in mind that this idea was not validated properly, but when testing it on smaller examples, it performs better than staying at the memory limit, i.e., using a single chunk. In the setting of @syleeheal with d=18 and n > 10,000, we can for example test this via:

```python
import avici
from avici import simulate_data
from avici.metrics import shd, classification_metrics, threshold_metrics

if __name__ == "__main__":
    model = avici.load_pretrained(download="scm-v0")
    d = 18
    chunk_size = 2000
    for n in [200, 1000, 2000, 4000, 8000, 16000, 32000]:
        g, x, interv = simulate_data(d=d, n=n, n_interv=0, domain="rff-gauss", seed=0)
        g_prob = model(x=x, interv=interv, devices="cpu", experimental_chunk_size=chunk_size)
        shdist = shd(g, (g_prob > 0.5).astype(int))
        f1 = classification_metrics(g, (g_prob > 0.5).astype(int))['f1']
        auprc = threshold_metrics(g, g_prob)['auprc']
        print(f"{n:8d} shd: {shdist:.1f} f1: {f1:.3f} auprc: {auprc:.3f}")
```

This prints:
```
Traceback (most recent call last):
  File "/teamspace/studios/this_studio/mcgill_fiam/0X-Causal_discovery/discovery.py", line 30, in <module>
    g_prob = model(x=x)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/avici/pretrain.py", line 109, in __call__
    out = onp.array(out)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/array.py", line 429, in __array__
    return np.asarray(self._value, dtype=dtype, **kwds)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jax/_src/array.py", line 628, in _value
    self._npy_value = self._single_device_array_to_np_array()
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Error preparing computation: %sOut of memory allocating 332034480032 bytes.
```
This is on 10,000 rows with 51 variables. Can you help me with this issue?
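For scale, the single buffer that XLA failed to allocate in the traceback above is larger than the machine's entire 256 GB of RAM, so this allocation would fail regardless of what else is resident. A quick check of the byte count:

```python
# Size of the single buffer XLA failed to allocate (from the traceback above).
failed_alloc_bytes = 332_034_480_032
print(f"{failed_alloc_bytes / 1e9:.0f} GB")  # prints "332 GB", more than the 256 GB of RAM
```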