Implement Model Sharding #70

Open
sadamov opened this issue Aug 19, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@sadamov
Collaborator

sadamov commented Aug 19, 2024

Description:

We should implement model sharding in neural-lam to allow training with larger batch sizes without exhausting GPU VRAM. This feature will let users scale to larger models and improve training efficiency. High-resolution datasets with many input feature channels quickly exhaust even GPUs with 100 GB of VRAM; currently, the batch size must be reduced to 1-4 on many systems for such datasets.

Proposed implementation:

  1. Add sharding logic to the model definition (see here for bipartite_subgraph & here); a rough graph-partition sketch follows this list
  2. Provide a configuration option for the sharding strategy as a train_model flag
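
A minimal sketch of what point 1 could look like, assuming the grid-to-mesh connectivity is available as a PyG-style edge_index; the function name, the contiguous split of mesh nodes over ranks, and all variable names are illustrative assumptions rather than neural-lam's actual interfaces:

```python
# Sketch only: partition a bipartite grid-to-mesh graph across ranks using
# torch_geometric.utils.bipartite_subgraph. All names here are illustrative.
import torch
from torch_geometric.utils import bipartite_subgraph


def shard_bipartite_graph(edge_index, edge_attr, n_grid, n_mesh, rank, world_size):
    """Keep only edges whose destination (mesh) node is assigned to this rank."""
    # Naive contiguous split of mesh nodes over ranks; a real implementation
    # would likely use a graph partitioner to minimise cross-device edges.
    mesh_per_rank = (n_mesh + world_size - 1) // world_size
    mesh_subset = torch.arange(
        rank * mesh_per_rank, min((rank + 1) * mesh_per_rank, n_mesh)
    )
    grid_subset = torch.arange(n_grid)  # every rank sees all grid nodes

    sub_edge_index, sub_edge_attr = bipartite_subgraph(
        (grid_subset, mesh_subset),
        edge_index,
        edge_attr,
        relabel_nodes=True,
        size=(n_grid, n_mesh),
    )
    return sub_edge_index, sub_edge_attr
```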

Benefits:

  • Train larger models with limited GPU resources
  • Increase batch sizes for improved training speed
  • Enable future compatibility with Anemoi-Models (exchange models between frameworks)
  • Allow training on high-res datasets with many features
  • Allow training with large boundary areas

Technical considerations:

  • PyTorch's DistributedDataParallel should play well with bipartite_graph; a sketch of exposing the sharding strategy via the train_model flag follows this list
  • Ensure compatibility with existing model architectures (graph_lam and hi_lam)
  • Add documentation and examples for using sharded models
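
As a rough sketch of point 2 above, the sharding strategy could be exposed on the train_model entry point and mapped onto a PyTorch Lightning strategy, keeping plain DDP as the default. The flag name --sharding_strategy and the trainer wiring below are assumptions, not the current neural-lam CLI:

```python
# Sketch only: a hypothetical --sharding_strategy flag for train_model that
# keeps DDP as the default and opts into FSDP-style parameter sharding on
# request. A graph-partition strategy (as in the sketch above) could be added
# as a further choice.
import argparse

import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sharding_strategy",
    type=str,
    default="none",
    choices=["none", "fsdp"],
    help="Model sharding strategy ('none' keeps plain DDP)",
)
args = parser.parse_args()

# DDP remains the default; sharding is only enabled when a single sample
# no longer fits on one GPU.
strategy = "fsdp" if args.sharding_strategy == "fsdp" else "ddp"
trainer = pl.Trainer(accelerator="gpu", devices=-1, strategy=strategy)
```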

This feature will significantly enhance neural-lam's capabilities for large-scale atmospheric modeling.

@sadamov sadamov added the enhancement New feature or request label Aug 19, 2024
@joeloskarsson
Collaborator

A few thoughts about sharding (and also about saving VRAM in general):

  • We would always want to go down to batch size 1 and rely on DDP first before doing model sharding. Sharding should only be necessary once training the model with a single sample does not fit on the GPU.
  • I agree that if there is a way to utilize and connect this with Anemoi that would be the best approach.
    • What could be a challenge here is that I expect the cross-device communication overhead to be dependent on the exact graph structure. So things might get really complex if we want to allow this for arbitrary graphs. Not sure, but something to look into when diving into the technical details.
  • While this is one way to save VRAM, I think something that should have even higher priority (and is much easier to implement) would be gradient checkpointing between roll-out steps (a minimal sketch follows at the end of this comment). This has been used to train most large-scale global models, but we have yet to add it (as an option) to neural-lam.
    • While this will save VRAM when fine-tuning on rollout, it does increase the total amount of computation, so it has its drawbacks as well.
    • This still won't save you if your model is so big that one single unroll step with batch size 1 does not fit on the GPU.
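
A minimal sketch of gradient checkpointing between roll-out steps, assuming a single-step forward of the form model(state, forcing); the actual step interface in neural-lam differs:

```python
# Sketch only: recompute each unroll step's activations during backward instead
# of keeping them in memory, trading extra compute for lower VRAM use.
import torch
from torch.utils.checkpoint import checkpoint


def rollout(model, initial_state, forcing, n_steps, use_checkpointing=True):
    """Unroll the model for n_steps, optionally checkpointing each step."""
    state = initial_state
    predictions = []
    for t in range(n_steps):
        if use_checkpointing:
            # Activations of this step are freed after the forward pass and
            # recomputed during backward.
            state = checkpoint(model, state, forcing[t], use_reentrant=False)
        else:
            state = model(state, forcing[t])
        predictions.append(state)
    return torch.stack(predictions, dim=1)
```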
