Hi, I'm training a transformer model with Hybrid Sharded Data Parallelism (HSDP). This setup is similar to FSDP/ZeRO-3, where params are all-gathered for each layer's forward/backward pass and dropped afterwards. However, instead of sharding both model params and optimizer state over all GPUs in the cluster, I shard model params only over a subset of devices (usually within a single node, for fast all-gathers over NVLink) and shard optimizer state over all GPUs (similar to FSDP/ZeRO-1/2/3).
Basically, I have a mesh (param_groups, model), and for each param tensor P of shape (X, Y) I shard the param with partition spec (model, None), and the corresponding optimizer state P_o of the same shape (X, Y) with partition spec (model, param_groups) (see the sketch after the list below).
When the mesh (param_groups, model) size is:
(1, N_GPUs) - this is basically FSDP/ZeRO-3.
(N, N_GPUs / N), N > 1 - this is HSDP.
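A minimal sketch of that layout using jax.sharding (the 8-GPU mesh shape and the toy parameter shapes are my assumptions, purely for illustration):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 8-GPU setup: 2 param_groups x 4 model shards (HSDP).
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("param_groups", "model"))

# A param P of shape (X, Y) is sharded only along the "model" axis...
param_sharding = NamedSharding(mesh, P("model", None))
# ...while its optimizer state P_o of the same shape is additionally
# sharded along "param_groups", i.e. over all GPUs.
opt_state_sharding = NamedSharding(mesh, P("model", "param_groups"))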
I also have gradient accumulation implemented, where we split the input batch into chunks, compute the forward/backward pass for each chunk independently, and then sum the gradients.
When using gradient accumulation with a factor of N (the batch is split into N chunks that are processed independently) and sequence length S, peak memory usage should be equal to a setup with gradient accumulation factor 2 * N and sequence length 2 * S. This is because the resulting per-chunk input tensor of shape [B / 2, 2 * S] has the same numel as a tensor of shape [B, S].
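A rough sketch of this kind of accumulation loop, assuming a lax.scan-based implementation (the function names and the use of scan are my assumptions; scan typically lowers to the HLO while loop mentioned in the reply below):

import jax
import jax.numpy as jnp

def accumulated_grads(params, batch, loss_fn, accum_steps):
    # Reshape [B, ...] -> [accum_steps, B // accum_steps, ...] and scan over
    # the chunks, summing the per-chunk gradients.
    chunks = jax.tree_util.tree_map(
        lambda x: x.reshape(accum_steps, -1, *x.shape[1:]), batch)

    def step(grads, chunk):
        chunk_grads = jax.grad(loss_fn)(params, chunk)  # loss_fn(params, chunk) -> scalar
        return jax.tree_util.tree_map(jnp.add, grads, chunk_grads), None

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    grads, _ = jax.lax.scan(step, zero_grads, chunks)
    return grads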
And this holds exactly for the FSDP setup with mesh size (1, N_GPUs): for every gradient accumulation factor I've tested, peak memory usage is identical. But when I try to use HSDP, something weird happens.
With a gradient accumulation factor of N > 1, peak memory usage is exactly as expected, BUT as soon as I set it to 1, peak memory usage greatly increases.
Here I have a toy model with a (2, 4) mesh, a total batch size of 64, and 3 setups:
gradient accumulation factor = 1, seq_len = 512 (per-chunk global input [64, 512])
gradient accumulation factor = 2, seq_len = 1024 (per-chunk global input [32, 1024])
gradient accumulation factor = 4, seq_len = 2048 (per-chunk global input [16, 2048])
The second and third setups consume a practically identical amount of memory (~50 GB on each GPU), while the first one consumes considerably more: 61 GB.
Grep for the number of all-gathers in both text files. The one with the while loop has 271 all-gathers, while the one without has 175. Most likely your all-gathers are being CSE'd away in the accum=1 case, or the GSPMD partitioner partitions differently when a while loop is present.
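For instance, a quick way to count them in the attached HLO dumps (just a stand-in for grep; the exact instruction spelling can vary between dumps):

# Count all-gather instructions in the dumped HLO text files.
for path in ("compiled_train_fn_grad_accum=1.txt",
             "compiled_train_fn_grad_accum=2.txt"):
    with open(path) as f:
        print(path, sum(" all-gather(" in line for line in f))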
import jax
from functools import partial

# adapt the prediction function to gather weights just before their use,
# and to re-gather them on the backward pass (rather than saving them)
@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
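A hedged sketch of how that policy might be wired up with an explicit all_gather inside shard_map (the mesh, the toy layer, and the shard_map usage are my illustration, not code from the issue):

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Hypothetical single-axis mesh, only meant to illustrate the remat policy.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

@partial(jax.remat, policy=lambda op, *_, **__: str(op) != 'all_gather')
def layer(w_shard, x):
    # Gather the sharded weight right before use; because of the policy the
    # gathered copy is not saved for the backward pass, it is re-gathered.
    w = jax.lax.all_gather(w_shard, "model", axis=0, tiled=True)
    return jnp.maximum(x @ w, 0.0)

layer_sharded = shard_map(
    layer, mesh=mesh,
    in_specs=(P("model", None), P(None, None)),
    out_specs=P(None, None),
    check_rep=False)  # skip replication checking to keep the sketch simple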
Here are the HLOs of the first and second setups:
compiled_train_fn_grad_accum=2.txt
compiled_train_fn_grad_accum=1.txt
JAX issue - jax-ml/jax#24208