Unexpectedly high memory consumption in jax.lax.scan #26639

Open
dbezgin opened this issue Feb 20, 2025 · 0 comments
Labels
bug Something isn't working


dbezgin commented Feb 20, 2025

Description

Hi everyone,

We have encountered unexpectedly high memory consumption in jax.lax.scan when using scan for an inner loop in combination with an unrolled Python for loop. In the minimal working example below, we have a function consisting of an outer loop and an inner loop. The outer loop is implemented using Python's for loop, while the inner loop is implemented using jax.lax.scan.

```python
import jax
import jax.numpy as jnp

# v0: scan stacks every intermediate result (shape (N, *x.shape))
# and only the last entry is returned
def inner_loop_v0(x, N):
    def body_fun(x1, _):
        x1 = 1 + jnp.cos(x1**2)
        return x1, x1
    _, x_out = jax.lax.scan(body_fun, x, None, N)
    return x_out[-1]

# v1: scan returns only the final carry; no per-step output is stacked
def inner_loop_v1(x, N):
    def body_fun(x1, x2):
        x1 = 1 + jnp.cos(x1**2)
        return x1, None
    x_out, _ = jax.lax.scan(body_fun, x, None, N)
    return x_out

# v2: inner loop written as a Python for loop (unrolled at trace time)
def inner_loop_v2(x, N):
    x_list = []
    for _ in range(N):
        x = 1 + jnp.cos(x**2)
        x_list.append(x)
    return x_list[-1]

# The outer loop is a Python for loop, so it is unrolled N_outer times when traced
def outer_loop(x, N_outer, N_inner, mode):
    for _ in range(N_outer):
        if mode == 0:
            x = inner_loop_v0(x, N_inner)
        elif mode == 1:
            x = inner_loop_v1(x, N_inner)
        elif mode == 2:
            x = inner_loop_v2(x, N_inner)
        else:
            raise NotImplementedError
    return x

outer_loop = jax.jit(outer_loop, static_argnums=(1, 2, 3))

N_outer = 10
N_inner = 100
N_res = 200
A = jnp.arange(N_res**3, dtype=jnp.float32).reshape(N_res,N_res,N_res)

# Compile each variant ahead of time and report output + temporary buffer sizes
for mode in range(3):
    lowered = outer_loop.lower(A, N_outer, N_inner, mode)
    compiled = lowered.compile()
    memory_analysis = compiled.memory_analysis()
    print(f"GPU memory mode v{mode}: {(memory_analysis.output_size_in_bytes + memory_analysis.temp_size_in_bytes) / 1024**2:.2f} MB")
```

We have three different versions of the inner loop:

  • Version v0 uses jax.lax.scan for the inner loop, stacks the intermediate results, and returns only the last entry of the stacked array.
  • Version v1 uses jax.lax.scan for the inner loop and returns the final carry value.
  • Version v2 uses a Python for loop for the inner loop.
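A small sketch of the difference between v0 and v1, using an assumed 4×4 input and 5 steps in place of the 200×200×200 array and N_inner = 100 from the report. As in inner_loop_v0, the scan body returns the new value both as the carry and as the per-step output. The final carry has the input's shape, while the stacked outputs gain a leading axis of length N, and their last entry equals the final carry. So v0 and v1 return the same value, but v0 asks scan to materialise N copies of the array first.

```python
import jax
import jax.numpy as jnp

def step(c, _):
    c = 1 + jnp.cos(c**2)
    return c, c  # (new carry, per-step output that scan stacks)

# Small stand-in for the report's (N_res, N_res, N_res) array A
x = jnp.linspace(0.0, 1.0, 16, dtype=jnp.float32).reshape(4, 4)
N = 5

carry, ys = jax.lax.scan(step, x, None, length=N)

print(carry.shape)  # final carry: same shape as x
print(ys.shape)     # stacked outputs: leading axis of length N
print(bool(jnp.allclose(carry, ys[-1])))  # v0's ys[-1] equals v1's carry
```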

When we check memory consumption with N_outer = 5, N_inner = 100, and N_res = 200, we observe the following:

  • GPU memory mode v0: 15289.31 MB
  • GPU memory mode v1: 30.52 MB
  • GPU memory mode v2: 30.52 MB

The memory consumption of v0 is much higher than that of v1 and v2. This is expected, since v0 stacks all N_inner intermediate results of the scan before taking the last one.

However, when we increase the outer loop count, the memory consumption of v0 increases linearly.
For N_outer = 10, N_inner = 100, and N_res = 200, we observe the following:

  • GPU memory mode v0: 30548.10 MB
  • GPU memory mode v1: 30.52 MB
  • GPU memory mode v2: 30.52 MB

And for N_outer = 15, N_inner = 100, and N_res = 200:

  • GPU memory mode v0: 45806.89 MB
  • GPU memory mode v1: 30.52 MB
  • GPU memory mode v2: 30.52 MB
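The v0 figures fit a simple linear model: every outer iteration appears to keep its own stacked buffer of N_inner × N_res³ float32 values alive, on top of the single output array (the script divides by 1024**2, so the "MB" values are MiB). Each step of five outer iterations adds 15258.79 MiB, i.e. five stacked buffers. The sketch below only checks this arithmetic against the reported numbers; it does not inspect XLA's buffer assignment, so it is a consistency check under that assumption, not an explanation of why the buffers are not reused.

```python
# Assumed model: v0 memory ~= N_outer * (one stacked scan output) + (one output array)
MIB = 1024**2
N_inner, N_res = 100, 200
bytes_per_elem = 4  # float32

array_mib = N_res**3 * bytes_per_elem / MIB  # one (N_res, N_res, N_res) array
stacked_mib = N_inner * array_mib            # stacked outputs of one inner scan

# v0 values copied from the report (MiB)
reported_v0 = {5: 15289.31, 10: 30548.10, 15: 45806.89}

for n_outer, reported in reported_v0.items():
    predicted = n_outer * stacked_mib + array_mib
    print(f"N_outer={n_outer:2d}: predicted {predicted:.2f}, reported {reported:.2f}")
```

The predicted values agree with the reported ones to within 0.01 MiB, while v1 and v2 stay at one array (30.52 MiB) regardless of N_outer.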

Is this expected behavior? If so, could you explain why this is happening? We would expect the memory consumption of all three inner loop variants to remain constant as the outer loop count increases.

Any feedback is much appreciated. Thank you!

System info (python version, jaxlib version, accelerator, etc.)

```
>>> import jax; jax.print_environment_info()
jax:    0.5.0
jaxlib: 0.5.0
numpy:  1.26.4
python: 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0]
device info: NVIDIA RTX A6000-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ursell', release='5.15.0-127-generic', version='#137-Ubuntu SMP Fri Nov 8 15:21:01 UTC 2024', machine='x86_64')
```
@dbezgin dbezgin added the bug Something isn't working label Feb 20, 2025