Description:
We should implement model sharding in neural-lam to allow training with larger batch sizes (and larger models) without exhausting GPU VRAM. High-resolution datasets with many input feature channels quickly exhaust even GPUs with 100 GB of VRAM, and on many systems the batch size currently has to be reduced to 1-4 for such datasets.

Proposed implementation:
Expose model sharding through a flag to train_model.

Benefits:
Lets users scale to larger models and batch sizes, improving training efficiency.

Technical considerations:
DistributedDataParallel should play well with bipartite_graph.

This feature will significantly enhance neural-lam's capabilities for large-scale atmospheric modeling.
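A rough sketch of how this could be wired into train_model, assuming the existing PyTorch Lightning Trainer setup and Lightning >= 2.0; the flag name and the choice of FSDP as the sharding backend are illustrative assumptions, not decided API:

```python
# Illustrative only: a hypothetical --model-sharding flag for train_model.
# Assumes PyTorch Lightning >= 2.0; FSDP is one possible sharding backend.
import argparse

import pytorch_lightning as pl
from pytorch_lightning.strategies import FSDPStrategy

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-sharding",
    action="store_true",
    help="Shard model parameters, gradients and optimizer state across GPUs",
)
args = parser.parse_args()

if args.model_sharding:
    # Fully Sharded Data Parallel: each rank holds only a shard of the
    # parameters/gradients/optimizer state, gathering them on demand.
    strategy = FSDPStrategy()
else:
    # Default: plain DDP with a full model replica on every GPU.
    strategy = "ddp"

trainer = pl.Trainer(accelerator="gpu", devices=-1, strategy=strategy)
# trainer.fit(model, datamodule=data_module)  # as in the existing training entry point
```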
A few thoughts about sharding (and also about saving VRAM in general):
We would always want to go down to batch size 1 and rely on DDP first, before doing model sharding. Sharding should only be necessary once training the model with a single sample does not fit on one GPU.
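To make that escalation order concrete, a minimal sketch, assuming a Lightning Trainer as in neural-lam's training script (numbers illustrative): reduce the per-GPU batch size first, let DDP scale the effective batch size across ranks, and only consider sharding if a single sample still overflows one GPU.

```python
# Minimal sketch of "DDP before sharding"; Trainer setup and numbers illustrative.
import pytorch_lightning as pl

per_gpu_batch_size = 1  # smallest batch that still fits in VRAM
n_gpus = 4

# DDP keeps a full model replica per GPU and all-reduces gradients, so the
# effective batch size grows with the number of ranks at no extra VRAM per GPU.
effective_batch_size = per_gpu_batch_size * n_gpus

trainer = pl.Trainer(accelerator="gpu", devices=n_gpus, strategy="ddp")
# Only if a single sample (batch size 1) still does not fit on one GPU
# would sharding the model itself across devices become necessary.
```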
I agree that if there is a way to utilize Anemoi and connect this with it, that would be the best approach.
What could be a challenge here is that I expect the cross-device communication overhead to be dependent on the exact graph structure. So things might get really complex if we want to allow this for arbitrary graphs. Not sure, but something to look into when diving into the technical details.
While this is one way to save VRAM, I think something that should have even higher priority (and is much easier to implement) would be gradient checkpointing in-between roll-out steps. This has been used to train most large-scale global models, but we have yet to add it (as an option) to neural-lam.
While this will save VRAM when fine-tuning on roll-outs, it does increase the total amount of computation, so it has its drawbacks as well.
This still won't save you if your model is so big that one single unroll step with batch size 1 does not fit on the GPU.
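For reference, a rough sketch of what checkpointing between roll-out steps could look like; the unroll loop and the predict_step / forcing names are hypothetical stand-ins, not neural-lam's actual API:

```python
# Sketch only: gradient checkpointing between roll-out steps. Function and
# argument names (predict_step, forcing, ...) are hypothetical stand-ins.
import torch
from torch.utils.checkpoint import checkpoint

def unroll_with_checkpointing(model, init_state, forcing, n_steps):
    """Autoregressive unroll where each step's intermediate activations are
    recomputed during backward instead of being kept in VRAM."""
    state = init_state
    predictions = []
    for t in range(n_steps):
        # checkpoint() only stores the inputs to model.predict_step; the
        # activations inside the step are recomputed in the backward pass,
        # trading extra compute for memory that would otherwise grow
        # linearly with the number of roll-out steps.
        state = checkpoint(model.predict_step, state, forcing[t], use_reentrant=False)
        predictions.append(state)
    return torch.stack(predictions, dim=1)  # (batch, time, ...)
```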