From 0b6f5df5b3df09aca96db3976feb23b34b0aaac1 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Wed, 15 Jan 2025 09:28:02 -0800
Subject: [PATCH] Add a note on deepspeed's gradient accumulation

---
 open_instruct/ppo_vllm_thread_ray.py      | 1 +
 open_instruct/ppo_vllm_thread_ray_gtrl.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/open_instruct/ppo_vllm_thread_ray.py b/open_instruct/ppo_vllm_thread_ray.py
index 985cc829e..6105014a7 100644
--- a/open_instruct/ppo_vllm_thread_ray.py
+++ b/open_instruct/ppo_vllm_thread_ray.py
@@ -1152,6 +1152,7 @@ def vllm_generate(
                     mini_batch_end = mini_batch_start + args.local_mini_batch_size
                     mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                     gradient_accumulation_idx = 0
+                    # NOTE: deepspeed handles gradient accumulation automatically; see https://github.com/microsoft/DeepSpeed/issues/758#issuecomment-801580724
                     for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                         # print("micro batch start", micro_batch_start, self.rank)
                         micro_batch_end = micro_batch_start + args.per_device_train_batch_size
diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py
index 989b676cc..78979ce99 100644
--- a/open_instruct/ppo_vllm_thread_ray_gtrl.py
+++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -1225,6 +1225,7 @@ def vllm_generate(
                     mini_batch_end = mini_batch_start + args.local_mini_batch_size
                     mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
                     gradient_accumulation_idx = 0
+                    # NOTE: deepspeed handles gradient accumulation automatically; see https://github.com/microsoft/DeepSpeed/issues/758#issuecomment-801580724
                     for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
                         # print("micro batch start", micro_batch_start, self.rank)
                         micro_batch_end = micro_batch_start + args.per_device_train_batch_size
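
For context on the note this patch adds: with DeepSpeed, the training loop calls engine.backward(loss) and engine.step() on every micro-batch, and the engine itself scales the loss and defers the real optimizer step until gradient_accumulation_steps micro-batches have accumulated (per the linked issue comment). Below is a minimal sketch of that pattern. The toy model, the optimizer settings, and the data_loader iterable are illustrative assumptions, not code from this repo; only deepspeed.initialize, engine.backward, and engine.step are actual DeepSpeed APIs.

import deepspeed
import torch

# Illustrative config: global batch = micro batch (4) * grad accum steps (4) * world size (1).
ds_config = {
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 4,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

model = torch.nn.Linear(8, 1)  # toy stand-in for the policy model
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

for micro_batch in data_loader:  # hypothetical iterable of (4, 8)-shaped tensors
    loss = engine(micro_batch).pow(2).mean()  # stand-in for the PPO loss
    # No manual `loss / gradient_accumulation_steps` and no conditional
    # `optimizer.step()` here: engine.backward() scales the loss internally,
    # and engine.step() only runs the optimizer (and zeroes gradients) once
    # every `gradient_accumulation_steps` calls. This is why the loops in the
    # patched files can treat every micro-batch identically.
    engine.backward(loss)
    engine.step()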