Scan leads to much higher memory usage than an unrolled loop
Hi! I'm working on a model that needs to perform a large number of gradient updates within JIT, and we use a scan to implement these updates. However, it seems that when the loop is rolled (`unroll=False` in the code below) it requires additional memory that is not needed with `unroll=True`. When the number of iterations goes from 1 to 2, the unrolled version increases the memory requirement by an expected incremental amount, and the requirement grows by the same amount when going from 2 to 3, 3 to 4, and so on. With `unroll=False`, the memory requirement takes a larger step (3x the expected increment) when increasing the number of iterations from 1 to 2, but then continues to grow by the expected amount for additional iterations.
From what I understand, the body of the scan loop must be a fixed function, so the memory allocated for the loop will be the maximum of what any single iteration requires.
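To illustrate what I mean (just an illustration, not part of the repro below): printing the jaxpr of a toy scan shows the body being traced exactly once, with the iteration count appearing only as the scan's `length` parameter.

```python
import jax
import jax.numpy as jnp

def body(carry, x):
    # Toy body standing in for the update rule.
    return carry * x, None

# The body is traced once into a single fixed jaxpr inside the `scan` primitive;
# the number of iterations only shows up as the scan's `length` parameter.
print(jax.make_jaxpr(lambda xs: jax.lax.scan(body, jnp.float32(1.0), xs))(jnp.ones(8)))
```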
Things I have considered:
- There might be some behavior at the start or end of the scan loop that requires more memory. I've tried unrolling only the first and last iterations of the scan (a rough sketch of this is shown just after this list), but this only delays the jump in memory usage to the point where the rolled part goes from 1 to 2 iterations.
- The compiler can't optimize the scan loop body using knowledge of what happens outside of it. If this is the case, is it possible to get around it?
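For reference, the first/last unrolling mentioned in the first point above looks roughly like the following sketch (a toy `step` function stands in for the real update rule; this is not the repro itself):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Toy stand-in for the real per-sample update rule (the actual rule is in the repro below).
    w, loss = carry
    z = jnp.sum(x @ w)
    return (w * z, loss + z), None

def scan_with_edges_unrolled(carry, xs):
    # Run the first and last iterations outside lax.scan; only the middle is rolled.
    carry, _ = step(carry, xs[0])
    carry, _ = jax.lax.scan(step, carry, xs[1:-1])
    carry, _ = step(carry, xs[-1])
    return carry

w = jnp.ones((1, 8))
print(scan_with_edges_unrolled((w, jnp.zeros(())), jnp.ones((5, 1, 1))))
```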
Is this a known issue? If so, are there any good ways to get around it?
Below is a minimal reproduction:
```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

def minimal_example(seqlen: int, unroll: bool = False):
    key = jrandom.key(0)
    X = jnp.ones((seqlen, 1), dtype=jnp.float32)  # `seqlen` samples to scan over

    def init_model():
        return {'w1': jrandom.normal(key, shape=(1, 1024)),
                'w2': jrandom.normal(key, shape=(1024, 1024 * 1024))}

    # Model is two linear layers (single layer is degenerate)
    def model_fn(model, x):
        return x @ model['w1'] @ model['w2']

    model = init_model()

    def compute_loss(model, X):
        def scan_update_rule(carry, x):
            model, loss = carry
            z = model_fn(model, x)  # Result from model
            # Modify model based on this result (in reality this is a gradient update, but simplified here for minimal repro)
            zval = jnp.sum(z)
            new_model = jax.tree.map(lambda p: p * zval, model)
            new_loss = loss + zval  # Accumulate loss
            return (new_model, new_loss), None

        carry = (model, jnp.zeros(()))
        (_, final_loss), _ = jax.lax.scan(scan_update_rule, carry, X, unroll=unroll)
        return final_loss

    minimal_fun = jax.grad(compute_loss)
    minimal_fun = jax.jit(minimal_fun, donate_argnums=(0,))
    lowered = minimal_fun.lower(model, X)
    compiled = lowered.compile()
    memory_analysis = compiled.memory_analysis()
    print(f"\n==={seqlen=}, {unroll=}===")
    print(f"{memory_analysis}")
    print(f"Estimated memory cost: {(memory_analysis.output_size_in_bytes + memory_analysis.temp_size_in_bytes) / 1024**3} GB")

minimal_example(seqlen=1)                # 4 GB  -- Expected memory usage as long as there is only one loop
minimal_example(seqlen=2, unroll=True)   # 8 GB  -- With unrolling the memory grows by 4 GB, which is expected
minimal_example(seqlen=3, unroll=True)   # 12 GB -- Normal growth by 4 GB per iteration
minimal_example(seqlen=4, unroll=True)   # 16 GB -- Normal growth by 4 GB per iteration
minimal_example(seqlen=2, unroll=False)  # 16 GB -- Without unrolling, the memory usage grows by an additional 8 GB on top of the expected 4
minimal_example(seqlen=3, unroll=False)  # 20 GB -- Further growth is only 4 GB as expected
minimal_example(seqlen=4, unroll=False)  # 24 GB -- Normal growth by 4 GB per iteration
```
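In case it's useful, I've also been pulling the optimized HLO out of the compiled executable to look at the while loop that the rolled scan lowers to. A self-contained sketch of that approach, using a much smaller toy model than the repro above:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, xs):
    def step(carry, x):
        z = jnp.sum(x @ carry)
        return carry * z, None
    carry, _ = jax.lax.scan(step, w, xs)
    return jnp.sum(carry)

# Lower and compile a small rolled scan, then dump the optimized HLO to look at
# the while loop that the scan becomes after compilation.
w = jnp.ones((1, 1024))
xs = jnp.ones((4, 1, 1))
compiled = jax.jit(jax.grad(loss_fn)).lower(w, xs).compile()
print(compiled.memory_analysis())
print(compiled.as_text()[:2000])  # optimized HLO text, truncated for readability
```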
System info (python version, jaxlib version, accelerator, etc.)
```
jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.3
python: 3.11.11 (main, Jan 14 2025, 22:49:08) [Clang 19.1.6 ]
device info: NVIDIA A100-SXM4-80GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.15.0-112-generic', version='#122-Ubuntu SMP Thu May 23 07:48:21 UTC 2024', machine='x86_64')

$ nvidia-smi
Wed Feb 19 11:55:03 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14 Driver Version: 550.54.14 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100-SXM4-80GB Off | 00000000:07:00.0 Off | 0 |
| N/A 30C P0 62W / 400W | 3MiB / 81920MiB | 0% Default |
| | | Disabled |
```