Skip to content

Commit

Permalink
Add a note on deepspeed's gradient accumulation (#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn authored Jan 15, 2025
1 parent 4365dea commit 20f5bb7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions open_instruct/ppo_vllm_thread_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,7 @@ def vllm_generate(
mini_batch_end = mini_batch_start + args.local_mini_batch_size
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
gradient_accumulation_idx = 0
# NOTE: deepspeed handles gradient accumulation automatically; see https://github.com/microsoft/DeepSpeed/issues/758#issuecomment-801580724
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
# print("micro batch start", micro_batch_start, self.rank)
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
Expand Down
1 change: 1 addition & 0 deletions open_instruct/ppo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,7 @@ def vllm_generate(
mini_batch_end = mini_batch_start + args.local_mini_batch_size
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
gradient_accumulation_idx = 0
# NOTE: deepspeed handles gradient accumulation automatically; see https://github.com/microsoft/DeepSpeed/issues/758#issuecomment-801580724
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
# print("micro batch start", micro_batch_start, self.rank)
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
Expand Down

0 comments on commit 20f5bb7

Please sign in to comment.