Co-authored-by: Jesper Dramsch <[email protected]>
Co-authored-by: Simon <[email protected]>
Co-authored-by: Simon Lang <[email protected]>
1 parent 3ecf92c, commit ca9b1ee. Showing 2 changed files with 130 additions and 0 deletions.
The first changed file is empty; the second is a new file containing the 130 added lines shown below.
import logging
import os

import numpy as np
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.optimizer import _optimizers_to_device
from pytorch_lightning.overrides.distributed import _sync_module_states
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn

from anemoi.training.utils.seeding import get_base_seed

LOGGER = logging.getLogger(__name__)


class DDPGroupStrategy(DDPStrategy):
    """Distributed Data Parallel strategy with group communication."""

    def __init__(self, num_gpus_per_model: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.model_comm_group_size = num_gpus_per_model

    def setup(self, trainer: pl.Trainer) -> None:
        assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy"
        self.accelerator.setup(trainer)

        # determine the model groups that work together:

        assert self.world_size % self.model_comm_group_size == 0, (
            f"Total number of GPUs ({self.world_size}) must be divisible by the number of GPUs "
            f"per model ({self.model_comm_group_size})."
        )

        model_comm_group_ranks = np.split(
            np.arange(self.world_size, dtype=int), int(self.world_size / self.model_comm_group_size)
        )
        model_comm_groups = [
            torch.distributed.new_group(x) for x in model_comm_group_ranks
        ]  # every rank has to create all of these

        model_comm_group_id, model_comm_group_nr, model_comm_group_rank = self.get_my_model_comm_group(
            self.model_comm_group_size
        )
        model_comm_group = model_comm_groups[model_comm_group_id]
        self.model.set_model_comm_group(model_comm_group)
        LOGGER.debug(
            "Rank %d model_comm_group is %s, group number %d, with local group rank %d and comms_group_ranks %s",
            self.global_rank,
            str(model_comm_group_nr),
            model_comm_group_id,
            model_comm_group_rank,
            str(model_comm_group_ranks[model_comm_group_id]),
        )

        # register hooks for correct gradient reduction
        self.register_parameter_hooks()

        # move the model to the correct device
        self.model_to_device()

        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
        trainer_fn = trainer.state.fn

        if trainer_fn == TrainerFn.FITTING and self._layer_sync:
            assert self.model is not None, "Model is not initialized for distributed strategy"
            self.model = self._layer_sync.apply(self.model)

        self.setup_precision_plugin()

        if trainer_fn == TrainerFn.FITTING:
            # do not wrap with DDP if not fitting as there's no gradients to reduce
            self.configure_ddp()

            # set up optimizers after the wrapped module has been moved to the device
            self.setup_optimizers(trainer)
            _optimizers_to_device(self.optimizers, self.root_device)

            import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

            if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
                self._enable_model_averaging()
        else:
            # we need to manually synchronize the module's states since we aren't using the DDP wrapper
            assert self.model is not None, "Model is not initialized for distributed strategy"
            _sync_module_states(self.model)

        # seed ranks
        self.seed_rnd(model_comm_group_id)

    def get_my_model_comm_group(self, num_gpus_per_model):
        """Determine the tasks that work together and form a model group."""
        model_comm_groups = np.arange(0, self.world_size, dtype=np.int32)
        model_comm_groups = np.split(model_comm_groups, self.world_size / num_gpus_per_model)

        model_comm_group_id = None
        for i, model_comm_group in enumerate(model_comm_groups):
            if self.global_rank in model_comm_group:
                model_comm_group_id = i
                model_comm_group_nr = model_comm_group
                model_comm_group_rank = np.ravel(np.asarray(model_comm_group == self.global_rank).nonzero())[0]
        return model_comm_group_id, model_comm_group_nr, model_comm_group_rank
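
    # Illustrative example (not part of the original commit): with a hypothetical world size
    # of 8 and num_gpus_per_model=4, the ranks split into groups [0, 1, 2, 3] and [4, 5, 6, 7];
    # global rank 5 then resolves to model_comm_group_id=1, model_comm_group_nr=[4, 5, 6, 7]
    # and model_comm_group_rank=1.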

    def seed_rnd(self, model_comm_group_id: int) -> None:
        """Seed the random number generators for the rank."""
        base_seed = get_base_seed()
        initial_seed = base_seed * (model_comm_group_id + 1)
        rnd_seed = pl.seed_everything(initial_seed)  # note: workers are seeded independently in dataloader
        np_rng = np.random.default_rng(rnd_seed)
        sanity_rnd = (torch.rand(1), np_rng.random())
        LOGGER.debug(
            "Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, running with random seed: %d, sanity rnd: %s",
            int(os.environ.get("SLURM_PROCID", "0")),
            model_comm_group_id,
            base_seed,
            initial_seed,
            rnd_seed,
            sanity_rnd,
        )
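        # Note: each model communication group derives its own seed, base_seed * (model_comm_group_id + 1),
        # so different groups draw different random streams while all ranks within one group
        # start from the same seed.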

    def register_parameter_hooks(self) -> None:
        """Register parameter hooks for gradient reduction.

        Here, we rescale the gradients of parameters that only see a subset of the input on each rank:
        DDP still divides them by the total number of GPUs, as if each rank had seen a full set of inputs.
        Note: the trainable parameters are added before the split across GPUs and are therefore not rescaled.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad is True and "trainable" not in name:
                param.register_hook(lambda grad: grad * float(self.model_comm_group_size))
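

For orientation, here is a minimal sketch of how this strategy might be handed to a PyTorch Lightning Trainer. The 8-GPU single-node setup, the group size of 4, and the model/datamodule names are illustrative assumptions, not part of this commit.

# Usage sketch (assumed setup: 8 GPUs on one node, 4 GPUs cooperating per model instance).
import pytorch_lightning as pl

strategy = DDPGroupStrategy(num_gpus_per_model=4)  # hypothetical group size

trainer = pl.Trainer(
    accelerator="gpu",
    devices=8,  # total GPU count; must be divisible by num_gpus_per_model
    num_nodes=1,
    strategy=strategy,
)
# trainer.fit(model, datamodule=datamodule)  # model and datamodule defined elsewhere

Because setup() asserts that the world size is divisible by num_gpus_per_model, the ranks partition cleanly into model communication groups, and each group is seeded independently via seed_rnd().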