From 4f8e3d8d16366e2cfcf39cea451849d3d9d943b2 Mon Sep 17 00:00:00 2001
From: Antonio Serrano Muñoz
Date: Thu, 20 Jun 2024 17:12:30 -0400
Subject: [PATCH] Add method to broadcast and reduce distributed model
 parameters

---
 skrl/models/torch/base.py | 47 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py
index 757a8ba2..bf90ba8a 100644
--- a/skrl/models/torch/base.py
+++ b/skrl/models/torch/base.py
@@ -7,7 +7,7 @@
 import numpy as np
 import torch
 
-from skrl import logger
+from skrl import config, logger
 
 
 class Model(torch.nn.Module):
@@ -743,3 +743,48 @@ def update_parameters(self, model: torch.nn.Module, polyak: float = 1) -> None:
         for parameters, model_parameters in zip(self.parameters(), model.parameters()):
             parameters.data.mul_(1 - polyak)
             parameters.data.add_(polyak * model_parameters.data)
+
+    def broadcast_parameters(self, rank: int = 0) -> None:
+        """Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs
+
+        After calling this method, the distributed model will contain the broadcasted parameters from ``rank``
+
+        :param rank: Worker/process rank from which to broadcast model parameters (default: ``0``)
+        :type rank: int
+
+        Example::
+
+            # broadcast model parameters from worker/process with rank 1
+            >>> if config.torch.is_distributed:
+            ...     model.broadcast_parameters(rank=1)
+        """
+        object_list = [self.state_dict()]
+        torch.distributed.broadcast_object_list(object_list, rank)
+        self.load_state_dict(object_list[0])
+
+    def reduce_parameters(self) -> None:
+        """Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes)
+
+        After calling this method, the distributed model parameters will be bitwise identical for all workers/processes
+
+        Example::
+
+            # reduce model parameters across all workers/processes
+            >>> if config.torch.is_distributed:
+            ...     model.reduce_parameters()
+        """
+        # batch all_reduce ops: https://github.com/entity-neural-network/incubator/pull/220
+        gradients = []
+        for parameters in self.parameters():
+            if parameters.grad is not None:
+                gradients.append(parameters.grad.view(-1))
+        gradients = torch.cat(gradients)
+
+        torch.distributed.all_reduce(gradients, op=torch.distributed.ReduceOp.SUM)
+
+        offset = 0
+        for parameters in self.parameters():
+            if parameters.grad is not None:
+                parameters.grad.data.copy_(gradients[offset:offset + parameters.numel()] \
+                    .view_as(parameters.grad.data) / config.torch.world_size)
+                offset += parameters.numel()
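
A minimal usage sketch of how the two methods are intended to work together in a
distributed run: parameters are broadcast once at startup so every worker starts
from rank 0's weights, and gradients are reduced (summed, then divided by
``config.torch.world_size``) before each optimizer step so the updates stay in
sync. The ``optimizer`` and ``loss`` names below are placeholders and are not
part of the patch:

    from skrl import config

    # one-time synchronization: all workers load the parameters of rank 0
    if config.torch.is_distributed:
        model.broadcast_parameters()

    # per update step
    optimizer.zero_grad()
    loss.backward()
    if config.torch.is_distributed:
        model.reduce_parameters()  # average gradients across all workers
    optimizer.step()               # identical gradients, so each worker applies the same update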