Skip to content

Commit

Permalink
Quick fix for gradient-accumulation loss scaling and add an option for chunked prefill
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterSH6 committed Jan 27, 2025
1 parent b7b6427 commit c03f0b0
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 2 deletions.
2 changes: 2 additions & 0 deletions tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ actor_rollout_ref:
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
disable_log_stats: True
enable_chunked_prefill: False # enabling chunked prefill may yield higher throughput
# for hf rollout
do_sample: True
# number of responses (i.e. num sample times)
Expand Down
2 changes: 2 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ actor_rollout_ref:
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
disable_log_stats: True
enable_chunked_prefill: False # enabling chunked prefill may yield higher throughput
# for hf rollout
do_sample: True
layer_name_map:
Expand Down
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ actor_rollout_ref:
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
disable_log_stats: True
enable_chunked_prefill: False # enabling chunked prefill may yield higher throughput
# for hf rollout
do_sample: True
# number of responses (i.e. num sample times)
Expand Down
6 changes: 5 additions & 1 deletion verl/workers/actor/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def update_policy(self, data: DataProto):
metrics['actor/kl_loss'] = kl_loss.detach().item()
metrics['actor/kl_coef'] = self.config.kl_loss_coef

loss = policy_loss / self.gradient_accumulation
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
else:
loss = policy_loss / self.gradient_accumulation
loss.backward()

data = {
Expand Down
7 changes: 6 additions & 1 deletion verl/workers/critic/dp_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,12 @@ def update_critic(self, data: DataProto):
returns=returns,
eos_mask=eos_mask,
cliprange_value=self.config.cliprange_value)
loss = vf_loss / self.gradient_accumulation
if self.config.use_dynamic_bsz:
# relative to the dynamic bsz
loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size)
else:
loss = vf_loss / self.gradient_accumulation

loss.backward()

data = {
Expand Down
1 change: 1 addition & 0 deletions verl/workers/rollout/vllm_rollout/vllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model
load_format=config.load_format,
disable_log_stats=config.disable_log_stats,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=config.enable_chunked_prefill,
)

# Offload vllm model to reduce peak memory usage
Expand Down

0 comments on commit c03f0b0

Please sign in to comment.