diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1372a3211..7a93e074b 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -625,6 +625,15 @@ def load(self, path: str, map_location=None): np.random.seed(worker_seed) random.seed(worker_seed) + torch.distributed.init_process_group( + backend="nccl", + init_method="env://", + world_size=self.world_size, + rank=self.rank, + timeout=timedelta(minutes=args.backend_timeout), + device_id=torch.device("cuda", self.local_rank), + ) + deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout)) ds_config = get_train_ds_config(offload=False, adam_offload=False, stage=args.deepspeed_stage, bf16=True)