
Scan leads to much higher memory usage than an unrolled loop #26618

Open

marcelroed opened this issue Feb 19, 2025 · 0 comments

Labels
bug Something isn't working
Description

Hi! I'm working on a model that needs to perform a large number of gradient updates inside JIT, and we use a scan to implement these updates. However, when the loop is rolled (unroll=False in the code below) it requires additional memory that is not needed with unroll=True. With unroll=True, increasing the number of iterations from 1 to 2 grows the memory requirement by an expected increment, and each further iteration (2 to 3, 3 to 4, etc.) adds the same amount. With unroll=False, the memory requirement takes a larger step (3x the expected increment) when going from 1 to 2 iterations, but then grows by the expected amount for each additional iteration.

From what I understand, the body of the scan loop must be a single fixed function, so the allocation at every iteration will be the max of what any single iteration requires.
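To restate that constraint with a toy example (simplified, just for illustration): lax.scan traces the body once and reuses the same compiled computation at every step, so the result matches a plain Python loop over the body, but the carry must keep a fixed shape/dtype and the loop's buffers are sized once for all iterations.

```python
import jax
import jax.numpy as jnp

def body(carry, x):
    # The same traced body runs at every step; the carry keeps a fixed
    # shape/dtype, so XLA allocates the loop's buffers once up front.
    return carry * x, None

xs = jnp.arange(1.0, 5.0)
scanned, _ = jax.lax.scan(body, jnp.float32(1.0), xs)

# Reference: an explicit Python loop (what unroll=True effectively traces).
unrolled = jnp.float32(1.0)
for x in xs:
    unrolled, _ = body(unrolled, x)

print(scanned, unrolled)  # identical values; only the compiled form differs
```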

Things I have considered:

  • There might be some behavior at the start or end of the scan loop that requires more memory. I've tried unrolling only the first and last iterations of the scan, but this only delays the jump in memory usage to whenever the rolled part goes from 1 to 2 iterations.
  • The compiler can't optimize the scan loop body with knowledge of what happens externally to it. If this is the case, is it possible to get around it?
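For concreteness, the first-iteration peeling I tried looks roughly like this (reduced to a trivial sum; the variable names are mine): apply the body once outside the scan, then scan over the remaining elements.

```python
import jax
import jax.numpy as jnp

def body(carry, x):
    return carry + x, None

xs = jnp.arange(4.0)
init = jnp.float32(0.0)

# Full scan over all elements.
full, _ = jax.lax.scan(body, init, xs)

# Manually peel the first iteration, then scan over the remainder --
# the shape of the workaround described in the first bullet above.
peeled_carry, _ = body(init, xs[0])
peeled, _ = jax.lax.scan(body, peeled_carry, xs[1:])

print(full, peeled)  # same value
```

As noted above, this moves the memory jump but does not remove it: the rolled remainder still shows the extra allocation once it runs 2 or more iterations.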

Is this a known issue? If so, are there any good ways to get around it?

Below is a minimal reproduction:

import jax
import jax.numpy as jnp
import jax.random as jrandom

def minimal_example(seqlen: int, unroll: bool = False):
    key = jrandom.key(0)
    X = jnp.ones((seqlen, 1), dtype=jnp.float32)  # `seqlen` samples to scan over

    def init_model():
        return {'w1': jrandom.normal(key, shape=(1, 1024)), 'w2': jrandom.normal(key, shape=(1024, 1024 * 1024))}
    
    # Model is two linear layers (single layer is degenerate)
    def model_fn(model, x):
        return x @ model['w1'] @ model['w2']
    
    model = init_model()
    
    def compute_loss(model, X):
        def scan_update_rule(carry, x):
            model, loss = carry

            # Result from model
            z = model_fn(model, x)

            # Modify model based on this result (in reality this is a gradient update, but simplified here for minimal repro)
            zval = jnp.sum(z)
            new_model = jax.tree.map(lambda p: p * zval, model)

            new_loss = loss + zval  # Accumulate loss
            return (new_model, new_loss), None
        
        carry = (model, jnp.zeros(()))
        (_, final_loss), _ = jax.lax.scan(scan_update_rule, carry, X, unroll=unroll)
        return final_loss

    minimal_fun = jax.grad(compute_loss)
    minimal_fun = jax.jit(minimal_fun, donate_argnums=(0,))

    lowered = minimal_fun.lower(model, X)
    compiled = lowered.compile()
    memory_analysis = compiled.memory_analysis()

    print(f"\n==={seqlen=}, {unroll=}===")
    print(f"{memory_analysis}")
    print(f"Estimated memory cost: {(memory_analysis.output_size_in_bytes + memory_analysis.temp_size_in_bytes) / 1024 ** 3} GB")

minimal_example(seqlen=1)                # 4  GB -- Expected memory usage as long as there is only one loop
minimal_example(seqlen=2, unroll=True)   # 8  GB -- With unrolling the memory grows by 4GB, which is expected
minimal_example(seqlen=3, unroll=True)   # 12 GB -- Normal growth by 4 GB per iteration
minimal_example(seqlen=4, unroll=True)   # 16 GB -- Normal growth by 4 GB per iteration

minimal_example(seqlen=2, unroll=False)  # 16 GB -- Without unrolling, the memory usage grows by an additional 8 GB on top of the expected 4
minimal_example(seqlen=3, unroll=False)  # 20 GB -- Further growth is only 4 GB as expected
minimal_example(seqlen=4, unroll=False)  # 24 GB -- Normal growth by 4 GB per iteration

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.3
python: 3.11.11 (main, Jan 14 2025, 22:49:08) [Clang 19.1.6 ]
device info: NVIDIA A100-SXM4-80GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.15.0-112-generic', version='#122-Ubuntu SMP Thu May 23 07:48:21 UTC 2024', machine='x86_64')


$ nvidia-smi
Wed Feb 19 11:55:03 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:07:00.0 Off |                    0 |
| N/A   30C    P0             62W /  400W |       3MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
@marcelroed marcelroed added the bug Something isn't working label Feb 19, 2025
@marcelroed marcelroed changed the title Scan leads to much higher memory usage than unroll Scan leads to much higher memory usage than an unrolled loop Feb 20, 2025