Skip to content
9 changes: 9 additions & 0 deletions open_instruct/grpo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,15 @@ def load(self, path: str, map_location=None):
# Seed NumPy's and Python's global RNGs with the same per-worker seed.
np.random.seed(worker_seed)
random.seed(worker_seed)

# Explicitly create the default torch.distributed process group before
# DeepSpeed is initialised. The group uses the NCCL backend, the "env://"
# init method, and this worker's own world_size / rank.
# NOTE(review): "env://" presumably relies on MASTER_ADDR / MASTER_PORT being
# set in the environment by the launcher — confirm against the caller.
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
world_size=self.world_size,
rank=self.rank,
# args.backend_timeout is interpreted as minutes here.
timeout=timedelta(minutes=args.backend_timeout),
# Pass the CUDA device indexed by this worker's local rank as device_id.
# NOTE(review): presumably this binds the process group to that GPU up
# front — confirm against the PyTorch docs for the installed version.
device_id=torch.device("cuda", self.local_rank),
)

# Initialise DeepSpeed's distributed layer with the same timeout value.
# NOTE(review): this presumably detects and reuses the process group created
# above rather than creating a second one — confirm against DeepSpeed docs.
deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout))

# Build the DeepSpeed training config: both offload flags disabled, ZeRO stage
# taken from args.deepspeed_stage, bf16 enabled.
ds_config = get_train_ds_config(offload=False, adam_offload=False, stage=args.deepspeed_stage, bf16=True)
Expand Down