Unexpectedly high memory consumption when using HSDP without gradient accumulation #18090

Open
qGentry opened this issue Oct 9, 2024 · 2 comments

qGentry commented Oct 9, 2024

Hi, I'm training a transformer model with Hybrid Sharded Data Parallelism (HSDP). The setup is similar to FSDP/ZeRO-3, where params are all-gathered for each layer's forward/backward pass and dropped afterwards. However, instead of sharding both model params and optimizer state over all GPUs in the cluster, I shard model params only over a subset of devices (usually within a single node, for fast all-gathers over NVLink) and shard optimizer state over all GPUs (similar to FSDP/ZeRO-1/2/3).

Basically, I have a mesh (param_groups, model), and for each param tensor P of shape (X, Y) I shard the param with partition spec (model, None) and the corresponding optimizer state P_o of the same shape (X, Y) with partition spec (model, param_groups).

The mesh (param_groups, model) size is either:

  1. (1, N_GPUs) - this is basically FSDP/ZeRO-3, or
  2. (N, N_GPUs / N) with N > 1 - HSDP.
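
Roughly, the layout looks like this in JAX (a sketch for illustration only; N, the axis names, and the tensor sizes below are made up, not my actual config):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative sizes only; assumes jax.device_count() is divisible by N.
N = 2                                   # number of param groups (HSDP replicas)
devices = np.array(jax.devices()).reshape(N, jax.device_count() // N)
mesh = Mesh(devices, ("param_groups", "model"))

X, Y = 1024, 4096
# Param P: sharded over 'model' only, replicated over 'param_groups'.
param = jax.device_put(np.zeros((X, Y), np.float32),
                       NamedSharding(mesh, P("model", None)))
# Optimizer state P_o: sharded over both 'model' and 'param_groups'.
opt_state = jax.device_put(np.zeros((X, Y), np.float32),
                           NamedSharding(mesh, P("model", "param_groups")))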

I also have gradient accumulation implemented, where we split the input batch into chunks, run the forward/backward pass on each chunk independently, and then sum their gradients.
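
The accumulation itself is roughly like this (a simplified sketch, not my exact training step; loss_fn and the batch layout are placeholders):

import jax
import jax.numpy as jnp

def accumulated_grads(loss_fn, params, batch, num_chunks):
    # Split the leading (batch) dimension of every input array into chunks.
    chunks = jax.tree_util.tree_map(
        lambda x: x.reshape(num_chunks, -1, *x.shape[1:]), batch)

    def step(grad_sum, chunk):
        grads = jax.grad(loss_fn)(params, chunk)
        return jax.tree_util.tree_map(jnp.add, grad_sum, grads), None

    init = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_sum, _ = jax.lax.scan(step, init, chunks)
    return grad_sum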

When using gradient accumulation with a factor of N (the batch is split into N chunks and each chunk is processed independently) and sequence length S, peak memory usage should equal that of a setup with gradient accumulation factor 2 * N and sequence length 2 * S. This is because a chunk of shape [B / 2, 2 * S] has the same numel as a chunk of shape [B, S].

This is indeed the case for the FSDP setup with mesh size (1, N_GPUs): for every gradient accumulation factor I've tested, peak memory usage is identical. But when I try to use HSDP, something weird happens.

When I use a gradient accumulation factor of N > 1, peak memory usage is exactly as expected, BUT as soon as I set it to 1, peak memory usage increases significantly.

Here I have a toy model with mesh (2, 4), a total batch size of 64, and 3 setups:

  1. gradient accumulation factor = 1, seq_len = 512
  2. gradient accumulation factor = 2, seq_len = 1024
  3. gradient accumulation factor = 4, seq_len = 2048

The second and third setups consume practically identical amounts of memory (~50 GB on each GPU), while the first one consumes far more: 61 GB.

Here are the HLO dumps for the first and second setups:
compiled_train_fn_grad_accum=2.txt
compiled_train_fn_grad_accum=1.txt

JAX issue - jax-ml/jax#24208

@ptoulme-aws (Contributor) commented:

Grep for the number of all-gathers in both text files. The one with the while loop has 271 all-gathers, while the one without has 175. Most likely your all-gathers are being CSEed away in the accum=1 case, or the GSPMD partitioner partitions differently when a while loop is present.
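
For example, to count the all-gather ops in the attached dumps (assuming the op shows up as "all-gather" in the HLO text):

# Count all-gather ops in the attached HLO text dumps.
for path in ["compiled_train_fn_grad_accum=1.txt",
             "compiled_train_fn_grad_accum=2.txt"]:
    with open(path) as f:
        print(path, f.read().count("all-gather"))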

@ptoulme-aws (Contributor) commented:

Try adding this in JAX:

from functools import partial
import jax

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
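
For example, applied to a toy sharded loss under shard_map (a sketch, not your model; the policy saves everything except all_gather outputs, so gathered weights are re-gathered on the backward pass instead of being kept as residuals):

from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ("model",))

# Toy loss: weight sharded over 'model' along its first dim, input replicated.
@partial(shard_map, mesh=mesh, in_specs=(P("model", None), P()), out_specs=P())
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def loss_fn(w_shard, x):
    # Gather the full weight right before use; with the policy above it is
    # not saved as a residual and gets re-gathered during the backward pass.
    w_full = jax.lax.all_gather(w_shard, "model", tiled=True)
    return jnp.sum(jnp.tanh(x @ w_full))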
