diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index 2587e0f..8dd5fa6 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -25,7 +25,7 @@ def wait_loader(): config['calc_stream'].record_event(config['load_event']) -def sum_loss(loss : torch.Tensor, comm: Optional[nccl.Communicator] = None): +def sum_loss(loss : torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None): """ Sum the loss across all workers.