Skip to content

Commit

Permalink
[misc] fix: load and offload in compute log prob (#208)
Browse files — browse the repository at this point in the history
- As titled
- Relevant: #181
  • Loading branch information
PeterSH6 authored Feb 5, 2025
1 parent 89ba48e commit 6872dbe
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ def generate_sequences(self, prompts: DataProto):
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
assert self._is_actor
if self._is_offload_param:
load_fsdp_param_and_grad(module=self.actor_module_fsdp,
device_id=torch.cuda.current_device(),
load_grad=self._is_offload_grad)
data = data.to('cuda')
# we should always recompute old_log_probs when it is HybridEngine
data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
Expand All @@ -491,7 +495,13 @@ def compute_log_prob(self, data: DataProto):
if self.world_size > 1:
self.actor.actor_module._handle.reshard(True)

if self._is_offload_param:
# NOTE(sgm): the grad is already in CPU, only offload param here
offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)

# clear kv cache
torch.cuda.empty_cache()
log_gpu_memory_usage('After compute_log_prob', logger=logger)
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
Expand Down

0 comments on commit 6872dbe

Please sign in to comment.