Add method to broadcast and reduce distributed model parameters
Toni-SM committed Jun 20, 2024
1 parent 5abed23 commit 4f8e3d8
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion skrl/models/torch/base.py
@@ -7,7 +7,7 @@
import numpy as np
import torch

-from skrl import logger
+from skrl import config, logger


class Model(torch.nn.Module):
@@ -743,3 +743,48 @@ def update_parameters(self, model: torch.nn.Module, polyak: float = 1) -> None:
        for parameters, model_parameters in zip(self.parameters(), model.parameters()):
            parameters.data.mul_(1 - polyak)
            parameters.data.add_(polyak * model_parameters.data)

    def broadcast_parameters(self, rank: int = 0):
        """Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs

        After calling this method, the distributed model will contain the broadcasted parameters from ``rank``

        :param rank: Worker/process rank from which to broadcast model parameters (default: ``0``)
        :type rank: int

        Example::

            # broadcast model parameters from worker/process with rank 1
            >>> if config.torch.is_distributed:
            ...     model.broadcast_parameters(rank=1)
        """
        object_list = [self.state_dict()]
        torch.distributed.broadcast_object_list(object_list, rank)
        self.load_state_dict(object_list[0])

def reduce_parameters(self):
"""Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes)
After calling this method, the distributed model parameters will be bitwise identical for all workers/processes
Example::
# reduce model parameter across all workers/processes
>>> if config.torch.is_distributed:
... model.reduce_parameters()
"""
# batch all_reduce ops: https://github.com/entity-neural-network/incubator/pull/220
gradients = []
for parameters in self.parameters():
if parameters.grad is not None:
gradients.append(parameters.grad.view(-1))
gradients = torch.cat(gradients)

torch.distributed.all_reduce(gradients, op=torch.distributed.ReduceOp.SUM)

offset = 0
for parameters in self.parameters():
if parameters.grad is not None:
parameters.grad.data.copy_(gradients[offset:offset + parameters.numel()] \
.view_as(parameters.grad.data) / config.torch.world_size)
offset += parameters.numel()
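For context, a minimal usage sketch (not part of the commit) of how the two new methods would typically be paired in a distributed run: parameters are broadcast once at startup so every worker starts from the same weights, and gradients are reduced after each backward pass so every replica applies the same update. Summing with all_reduce and dividing by config.torch.world_size is equivalent to averaging the gradients across workers. Here `model`, `optimizer` and `loss` are placeholder objects standing in for the user's own model, optimizer and loss tensor.

# hypothetical usage sketch -- `model`, `optimizer` and `loss` are placeholders
from skrl import config

# at startup: synchronize initial weights from the worker with rank 0
if config.torch.is_distributed:
    model.broadcast_parameters(rank=0)

# ... inside the training loop ...
optimizer.zero_grad()
loss.backward()

# average gradients across all workers before stepping the optimizer
if config.torch.is_distributed:
    model.reduce_parameters()

optimizer.step()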
