[perf] fix: set use_reentrant=False when enabling gradient checkpointing (#114)

- Set use_reentrant=False to avoid duplicate allgather in backward when
gradient checkpointing is enabled.
- Optimize temperature computation by using an in-place op.
- Fix testing logic.
vermouth1992 authored Jan 18, 2025
1 parent e8eb9e4 commit 5a94e14
Showing 3 changed files with 13 additions and 8 deletions.
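
For context on the headline change: `gradient_checkpointing_enable(gradient_checkpointing_kwargs=...)` in HuggingFace Transformers passes those kwargs through to `torch.utils.checkpoint.checkpoint`, so the diffs below switch checkpointing to the non-reentrant implementation. A minimal standalone sketch of that call (the toy block and shapes are illustrative, not from this repo):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy block standing in for a transformer layer; sizes are illustrative.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# use_reentrant=False selects the non-reentrant implementation, which
# recomputes activations via saved-tensor hooks instead of a nested
# autograd pass; per the commit message, the reentrant variant triggers
# a duplicate all-gather in backward under FSDP.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()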
3 changes: 2 additions & 1 deletion verl/trainer/ppo/ray_trainer.py
@@ -611,7 +611,8 @@ def fit(self):
metrics.update(actor_output_metrics)

# validate
- if self.val_reward_fn is not None and self.global_steps % self.config.trainer.test_freq == 0:
+ if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
+     self.global_steps % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
metrics.update(val_metrics)
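
Why the new guard matters: setting `test_freq` to 0 (or a negative value) to disable validation would otherwise hit a modulo-by-zero. A tiny sketch of the failure mode (the config values here are made up):

test_freq = 0      # hypothetical config: periodic validation disabled
global_steps = 10

# Without the short-circuiting `test_freq > 0` check,
# `global_steps % test_freq` would raise ZeroDivisionError.
should_validate = test_freq > 0 and global_steps % test_freq == 0
print(should_validate)  # False -> validation is cleanly skipped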
5 changes: 3 additions & 2 deletions verl/workers/actor/dp_actor.py
@@ -92,7 +92,7 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor,
use_cache=False) # prevent model thinks we are generating
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)

- logits_rmpad /= temperature
+ logits_rmpad.div_(temperature)

# compute entropy
entropy_rmpad = verl_F.entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad)
@@ -127,7 +127,8 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False) # prevent model thinks we are generating
- logits = output.logits / temperature
+ logits = output.logits
+ logits.div_(temperature)
logits = logits[:, -response_length - 1:-1] # (bsz, response_length)
log_probs = logprobs_from_logits(logits, micro_batch['responses'])
entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
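On the second hunk above: `output.logits / temperature` materializes a second logits-sized tensor, while `div_` scales the existing storage in place, which is a real saving when logits span the full vocabulary. A minimal sketch with made-up sizes (in-place ops are only safe when autograd does not need the pre-division values, as is the case for these logits):

import torch

logits = torch.randn(4096, 32_000)  # illustrative (tokens, vocab) shape
temperature = 0.7

scaled = logits / temperature       # out-of-place: allocates a new tensor
logits.div_(temperature)            # in-place: reuses the existing storage

assert torch.allclose(scaled, logits)
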
13 changes: 8 additions & 5 deletions verl/workers/fsdp_workers.py
@@ -164,7 +164,7 @@ def _build_model_optimizer(self,
actor_module.to(torch_dtype)

if enable_gradient_checkpointing:
- actor_module.gradient_checkpointing_enable()
+ actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
torch.distributed.barrier()

if self.rank == 0:
@@ -212,7 +212,8 @@ def _build_model_optimizer(self,
sharding_strategy=sharding_strategy, # zero3
mixed_precision=mixed_precision,
sync_module_states=True,
- device_mesh=self.device_mesh)
+ device_mesh=self.device_mesh,
+ forward_prefetch=False)

log_gpu_memory_usage('After Actor FSDP init', logger=logger)

@@ -575,7 +576,7 @@ def _build_critic_model_optimizer(self, config):
critic_module.to(torch_dtype)

if config.model.get('enable_gradient_checkpointing', False):
- critic_module.gradient_checkpointing_enable()
+ critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
if self.rank == 0:
print_model_size(critic_module)

@@ -603,7 +604,8 @@ def _build_critic_model_optimizer(self, config):
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
- sync_module_states=True)
+ sync_module_states=True,
+ forward_prefetch=False)

log_gpu_memory_usage('After critic FSDP', logger=None)

@@ -806,7 +808,8 @@ def _build_model(self, config):
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD, # zero3
sync_module_states=True,
- cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload))
+ cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload),
+ forward_prefetch=False)

return reward_module

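The `forward_prefetch=False` added in three places is the standard `FullyShardedDataParallel` constructor flag: it disables issuing the next layer's all-gather early during forward, and the commit pins it explicitly rather than relying on the default. A single-process sketch of the wrap (assumes a PyTorch build where FSDP runs on CPU with the gloo backend; in this repo it runs under a real multi-GPU launcher):

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(16, 16)
# forward_prefetch=False: do not prefetch the next all-gather in forward.
fsdp_model = FSDP(model, forward_prefetch=False)
dist.destroy_process_group()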
