We have encountered unexpectedly high memory consumption in jax.lax.scan when using scan for an inner loop in combination with an unrolled Python for loop. In the minimal working example below, we have a function consisting of an outer loop and an inner loop. The outer loop is implemented using Python's for loop, while the inner loop is implemented using jax.lax.scan.
We have three versions of the inner loop:
Version v0 uses jax.lax.scan for the inner loop, where the return value is only the last entry of the stacked array.
Version v1 uses jax.lax.scan for the inner loop, where the return value is the carry value.
Version v2 uses a Python for loop for the inner loop.
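For orientation, the three variants can be sketched roughly as follows. This is a minimal sketch under assumptions: the update `step`, the helper `make_fn`, and the array shape `(N_res, N_res, N_res)` are illustrative placeholders, not necessarily what the real computation uses.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical elementwise update standing in for the real inner-loop body.
    return 0.5 * x + 1.0


def inner_v0(x, n_inner):
    # scan emits every intermediate result as a stacked output;
    # only the last entry of the stack is returned.
    def body(carry, _):
        new = step(carry)
        return new, new

    _, stacked = jax.lax.scan(body, x, None, length=n_inner)
    return stacked[-1]


def inner_v1(x, n_inner):
    # scan emits no per-step output; the final carry is returned.
    def body(carry, _):
        return step(carry), None

    final, _ = jax.lax.scan(body, x, None, length=n_inner)
    return final


def inner_v2(x, n_inner):
    # Plain Python loop, unrolled at trace time.
    for _ in range(n_inner):
        x = step(x)
    return x


def make_fn(inner, n_outer, n_inner):
    # The outer loop is a Python for loop, so it is unrolled when jit traces it.
    @jax.jit
    def fn(x):
        for _ in range(n_outer):
            x = inner(x, n_inner)
        return x

    return fn


if __name__ == "__main__":
    n_outer, n_inner, n_res = 5, 100, 8  # small n_res so this runs anywhere
    x0 = jnp.ones((n_res, n_res, n_res), dtype=jnp.float32)
    for name, inner in [("v0", inner_v0), ("v1", inner_v1), ("v2", inner_v2)]:
        out = make_fn(inner, n_outer, n_inner)(x0)
        print(name, out.shape, float(out[0, 0, 0]))
```

All three variants compute the same values; they differ only in whether the scan produces a stacked per-step output.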
When we check memory consumption with N_outer = 5, N_inner = 100, and N_res = 200, we observe the following:
GPU memory mode v0: 15289.31 MB
GPU memory mode v1: 30.52 MB
GPU memory mode v2: 30.52 MB
Memory consumption of v0 is higher than for v1 and v2. This is expected, as v0 stacks the intermediate results in the scan.
However, when we increase the outer loop count, the memory consumption of version v0 increases linearly with it.
For N_outer = 10, N_inner = 100, and N_res = 200, we observe the following:
GPU memory mode v0: 30548.10 MB
GPU memory mode v1: 30.52 MB
GPU memory mode v2: 30.52 MB
And for N_outer = 15, N_inner = 100, and N_res = 200:
GPU memory mode v0: 45806.89 MB
GPU memory mode v1: 30.52 MB
GPU memory mode v2: 30.52 MB
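The reported numbers are consistent with one simple reading, which is an assumption and has not been confirmed against the compiled program: each buffer is a float32 array of shape (N_res, N_res, N_res), v1 and v2 hold a single such buffer, and v0 holds one buffer per inner step of every outer iteration plus one more. The helper names below are illustrative.

```python
BYTES_PER_MB = 1024 ** 2


def buffer_mb(n_res, itemsize=4):
    # Size of one (n_res, n_res, n_res) array; itemsize=4 assumes float32.
    return n_res ** 3 * itemsize / BYTES_PER_MB


def predicted_v0_mb(n_outer, n_inner, n_res):
    # Assumed model: every stacked scan output of every outer iteration
    # stays allocated, plus one extra buffer.
    return (n_outer * n_inner + 1) * buffer_mb(n_res)


# Values reported above for v0, keyed by N_outer (N_inner = 100, N_res = 200).
reported_v0_mb = {5: 15289.31, 10: 30548.10, 15: 45806.89}

if __name__ == "__main__":
    print(f"single buffer: {buffer_mb(200):.2f} MB")
    for n_outer, measured in reported_v0_mb.items():
        predicted = predicted_v0_mb(n_outer, 100, 200)
        print(f"N_outer={n_outer}: predicted {predicted:.2f} MB, reported {measured:.2f} MB")
```

Under this model the predictions agree with all three v0 measurements to within about 0.01 MB, which suggests that the stacked outputs of all outer iterations are alive at the same time instead of being freed after each inner scan.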
Is this expected behavior? If so, could you elaborate why this is happening? We would expect the memory consumption for all three inner loops to remain constant when increasing the outer loop count.
Any feedback is much appreciated. Thank you!
System info (python version, jaxlib version, accelerator, etc.)
Python 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax: 0.5.0
jaxlib: 0.5.0
numpy: 1.26.4
python: 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0]
device info: NVIDIA RTX A6000-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='ursell', release='5.15.0-127-generic', version='#137-Ubuntu SMP Fri Nov 8 15:21:01 UTC 2024', machine='x86_64')