Merged
Commits (37)
8f2145c
Now, llmrayactor returns logprobs
finbarrtimbers Sep 29, 2025
4632792
Updated code
finbarrtimbers Sep 29, 2025
fb6028d
Cleaned up PR.
finbarrtimbers Sep 29, 2025
c275203
Cleaned up PR.
finbarrtimbers Sep 29, 2025
966c78b
Updated logprob code
finbarrtimbers Sep 29, 2025
1ef90d8
Fixed code
finbarrtimbers Sep 29, 2025
090d104
now uses nan
finbarrtimbers Sep 29, 2025
5ca9d0b
Now, we filter nans
finbarrtimbers Sep 29, 2025
54728e9
Cleaned up code.
finbarrtimbers Sep 29, 2025
bdaf060
Fixed tests
finbarrtimbers Sep 29, 2025
2d4f540
Updated code
finbarrtimbers Sep 29, 2025
280b4f8
Added vllm logprobs
finbarrtimbers Sep 29, 2025
4f871d6
Cleaned up code
finbarrtimbers Sep 29, 2025
a6ea5da
Undo changes to script.
finbarrtimbers Sep 29, 2025
3d7c852
Fixed bug in logprobs
finbarrtimbers Sep 29, 2025
03c4207
fixed failing tests
finbarrtimbers Sep 29, 2025
c8c3afe
Merge branch 'main' into vllm-logprobs
finbarrtimbers Oct 7, 2025
f7c6bca
Added importance sampling ratio
finbarrtimbers Oct 7, 2025
82d9535
Added back comment
finbarrtimbers Oct 7, 2025
deedd09
Removed comment
finbarrtimbers Oct 7, 2025
b493408
Test config
finbarrtimbers Oct 7, 2025
dfd10b8
Add truncated importance sampling with debug assertions to identify N…
finbarrtimbers Oct 7, 2025
9236c9e
Fix NaN handling in truncated importance sampling
finbarrtimbers Oct 7, 2025
4d6fa40
added reverse kl
finbarrtimbers Oct 7, 2025
8eaee10
Merge branch 'main' into vllm-logprobs
finbarrtimbers Oct 7, 2025
4f07ed6
Simplified code
finbarrtimbers Oct 7, 2025
61f1441
changes to debug mask
finbarrtimbers Oct 7, 2025
ddf2425
more logging
finbarrtimbers Oct 7, 2025
8f4448a
more logging
finbarrtimbers Oct 7, 2025
2b7e531
Added comment describing
finbarrtimbers Oct 7, 2025
707cf10
Add diagnostic logging to vllm_utils3 to detect logprobs length misma…
finbarrtimbers Oct 7, 2025
fc036c4
Fix vLLM logprobs N-1 mismatch by only appending EOS for empty responses
finbarrtimbers Oct 7, 2025
88fcfcc
Address review comments.
finbarrtimbers Oct 8, 2025
4afb6dd
Updated scripts
finbarrtimbers Oct 8, 2025
dbd3cb2
Updated scripts
finbarrtimbers Oct 8, 2025
5dea688
Cleaned up PR.
finbarrtimbers Oct 8, 2025
ed7bb9c
Added assert
finbarrtimbers Oct 8, 2025
151 changes: 129 additions & 22 deletions open_instruct/grpo_fast.py
@@ -248,6 +248,8 @@ class Args:
"""the lower clip range"""
clip_higher: float = 0.2
"""the higher clip range. Sometimes we want this to be higher, see DAPO (https://arxiv.org/abs/2503.14476)"""
truncated_importance_sampling_ratio_cap: float = 0.0
"""The maximum cap for truncated importance sampling ratio (0 means disabled)"""
inflight_updates: bool = False
"""Enable immediate stopping of request processing when should_stop is set, allowing for quick pausing and resumption"""
kl_estimator: Literal["kl1", "kl2", "kl3", "kl4"] = "kl3"
@@ -275,6 +277,8 @@ class Args:

record_entropy: bool = False
"""whether to record the entropy of the policy during training. Uses extra memory."""
use_vllm_logprobs: bool = False
"""whether to use vLLM's logprobs for training instead of calculating them via forward pass"""

# Reward
# -- r1 style format reward
@@ -436,6 +440,11 @@ def __post_init__(self):
logger.warning("When using the v0 version of vLLM, caching is broken and will never be invalidated.")
if self.vllm_enable_prefix_caching:
raise ValueError("Prefix caching is currently not supported for v0.")
if self.use_vllm_logprobs and self.truncated_importance_sampling_ratio_cap > 0.0:
raise ValueError(
"Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. "
"use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
)
assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
if self.num_samples_per_prompt_rollout == 1:
logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
@@ -893,6 +902,7 @@ def train(
collated_position_ids,
collated_advantages,
collated_response_masks,
collated_vllm_logprobs,
pad_token_id: int,
num_mini_batches: int,
):
@@ -903,6 +913,7 @@ def train(
to_device_inplace(collated_position_ids, self.device)
to_device_inplace(collated_advantages, self.device)
to_device_inplace(collated_response_masks, self.device)
to_device_inplace(collated_vllm_logprobs, self.device)
# accumulation steps should always be at least 1
accumulation_steps = max(math.ceil(len(collated_query_responses) / num_mini_batches - 0.5), 1)
leftover = len(collated_query_responses) % accumulation_steps
@@ -913,6 +924,7 @@ def train(
collated_position_ids = collated_position_ids[0:-leftover]
collated_advantages = collated_advantages[0:-leftover]
collated_response_masks = collated_response_masks[0:-leftover]
collated_vllm_logprobs = collated_vllm_logprobs[0:-leftover]
logger.warning(f"{leftover} samples are dropped due to batch size {num_mini_batches}")

# recalculate the "real" number of mini-batches
@@ -958,21 +970,31 @@ def train(
attention_mask = collated_attention_masks[i]
position_id = collated_position_ids[i]
response_mask = collated_response_masks[i]
old_logprob, _ = self.forward(
self.model,
query_response,
attention_mask,
position_id,
pad_token_id,
args.temperature,
return_entropy=False,
)
if not args.use_vllm_logprobs:
local_old_logprob, _ = self.forward(
self.model,
query_response,
attention_mask,
position_id,
pad_token_id,
args.temperature,
return_entropy=False,
)
vllm_old_logprob = collated_vllm_logprobs[i][:, 1:]
if args.mask_tool_use and args.tool_use:
response_mask = response_mask.bool() & tool_mask.bool()
else:
response_mask = response_mask.bool()
old_logprob = torch.masked_fill(old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB)
old_logprobs[i] = old_logprob
if not args.use_vllm_logprobs:
local_old_logprob = torch.masked_fill(
local_old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB
)
vllm_old_logprob = torch.masked_fill(vllm_old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB)
vllm_old_logprob = torch.nan_to_num(vllm_old_logprob, nan=INVALID_LOGPROB)
if args.use_vllm_logprobs:
old_logprobs[i] = vllm_old_logprob
else:
old_logprobs[i] = local_old_logprob
torch.cuda.empty_cache()

local_step = 0
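This hunk is where the per-token vLLM logprobs first meet the response mask: non-response positions are filled with the INVALID_LOGPROB sentinel, and NaNs left over from packing (pack_sequences marks query positions with NaN) are replaced the same way. A minimal standalone sketch of that sanitization, assuming the INVALID_LOGPROB = 1.0 sentinel used elsewhere in this file:

import torch

INVALID_LOGPROB = 1.0  # assumed value of the sentinel defined in grpo_fast.py

def sanitize_vllm_logprobs(vllm_logprobs: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
    """Fill non-response positions with the sentinel and drop NaNs introduced by packing."""
    masked = torch.masked_fill(vllm_logprobs, ~response_mask.bool(), INVALID_LOGPROB)
    return torch.nan_to_num(masked, nan=INVALID_LOGPROB)

# Example: two packed query positions (NaN) followed by three response tokens.
logprobs = torch.tensor([[float("nan"), float("nan"), -0.7, -1.2, -0.3]])
mask = torch.tensor([[0, 0, 1, 1, 1]])
print(sanitize_vllm_logprobs(logprobs, mask))  # tensor([[ 1.0000,  1.0000, -0.7000, -1.2000, -0.3000]])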
@@ -1001,7 +1023,7 @@ def train(
mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
mb_attention_mask = collated_attention_masks[i]
mb_position_id = collated_position_ids[i]
mb_new_logprobs, mb_entropy = self.forward(
mb_local_logprobs, mb_entropy = self.forward(
self.model,
mb_query_responses,
mb_attention_mask,
@@ -1010,16 +1032,50 @@ def train(
args.temperature,
return_entropy=args.record_entropy,
)
mb_new_logprobs = torch.masked_fill(mb_new_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
mb_local_logprobs = torch.masked_fill(mb_local_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
mb_vllm_logprobs = collated_vllm_logprobs[i][:, 1:]
mb_vllm_logprobs = torch.masked_fill(mb_vllm_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
# Replace any remaining NaN values (query tokens in packed sequences are set to NaN by pack_sequences in rl_utils2.py)
mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB)

# Compare vLLM logprobs with local logprobs
with torch.no_grad():
valid_mask = mb_response_masks_bool & ~torch.isnan(mb_vllm_logprobs)
logprob_diff = (mb_local_logprobs - mb_vllm_logprobs).abs()
masked_diff = torch.masked_fill(logprob_diff, ~valid_mask, 0.0)
mean_diff = masked_diff.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0
max_diff = masked_diff.max()
std_diff = masked_diff[valid_mask].std() if valid_mask.sum() > 1 else 0.0

self.local_metrics.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff.item())
self.local_metrics.add("debug/vllm_vs_local_logprob_diff_max", max_diff.item())
self.local_metrics.add("debug/vllm_vs_local_logprob_diff_std", std_diff.item())

reverse_kl = torch.exp(mb_vllm_logprobs) * (mb_vllm_logprobs - mb_local_logprobs)
masked_reverse_kl = torch.masked_fill(reverse_kl, ~valid_mask, 0.0)
mean_reverse_kl = masked_reverse_kl.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0
self.local_metrics.add("debug/vllm_local_reverse_kl", mean_reverse_kl.item())

mb_new_logprobs = mb_local_logprobs

# Cache the old logprobs
if num_mini_batches > 1:
mb_old_logprobs = old_logprobs[i]
else:
with torch.no_grad():
if epoch_idx == 0:
old_logprobs[i] = mb_new_logprobs
mb_old_logprobs = old_logprobs[i].detach()
if args.use_vllm_logprobs:
old_logprobs[i] = mb_vllm_logprobs
else:
old_logprobs[i] = mb_local_logprobs.detach()
mb_old_logprobs = old_logprobs[i]

old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB
assert torch.all(old_logprobs_mask == mb_response_masks_bool), (
f"Old logprobs mask should match response mask. "
f"old_mask sum={old_logprobs_mask.sum()}, "
f"response_mask sum={mb_response_masks_bool.sum()}"
)

# Calculate the policy's loss
logprobs_diff = mb_new_logprobs - mb_old_logprobs
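The debug metrics above quantify how far vLLM's sampled-token logprobs drift from the trainer's own forward pass; the reverse-KL term exp(q) * (q - p) is accumulated only over sampled tokens, so it is a cheap drift indicator rather than a full-vocabulary KL. A standalone sketch of the same statistics:

import torch

def logprob_drift_stats(local_lp: torch.Tensor, vllm_lp: torch.Tensor, response_mask: torch.Tensor) -> dict:
    """Mean/max absolute logprob gap and the sampled-token reverse-KL proxy over valid positions."""
    valid = response_mask.bool() & ~torch.isnan(vllm_lp)
    denom = valid.sum().clamp(min=1)
    diff = torch.masked_fill((local_lp - vllm_lp).abs(), ~valid, 0.0)
    reverse_kl = torch.masked_fill(torch.exp(vllm_lp) * (vllm_lp - local_lp), ~valid, 0.0)
    return {
        "diff_mean": (diff.sum() / denom).item(),
        "diff_max": diff.max().item(),
        "reverse_kl": (reverse_kl.sum() / denom).item(),
    }

local = torch.tensor([[-0.5, -1.0, -2.0]])
vllm = torch.tensor([[-0.4, -1.1, float("nan")]])  # NaN where no vLLM logprob exists
print(logprob_drift_stats(local, vllm, torch.tensor([[1, 1, 0]])))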
@@ -1028,6 +1084,46 @@ def train(
pg_losses2 = -mb_advantages[:, 1:] * torch.clamp(
ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher
)

# Apply truncated importance sampling if enabled
if args.truncated_importance_sampling_ratio_cap > 0 and mb_vllm_logprobs is not None:
old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB
vllm_logprobs_mask = mb_vllm_logprobs != INVALID_LOGPROB

assert torch.all(old_logprobs_mask == mb_response_masks_bool), (
f"Old logprobs mask should match response mask. "
f"old_mask sum={old_logprobs_mask.sum()}, "
f"response_mask sum={mb_response_masks_bool.sum()}"
)
assert torch.all(vllm_logprobs_mask == mb_response_masks_bool), (
f"vLLM logprobs mask should match response mask. "
f"vllm_mask sum={vllm_logprobs_mask.sum()}, "
f"response_mask sum={mb_response_masks_bool.sum()}"
)

valid_mask = mb_response_masks_bool

# Initialize importance ratio to 1.0 (no effect) for all positions
tis_imp_ratio = torch.ones_like(mb_old_logprobs)

if valid_mask.any():
# Calculate logprob difference only for valid positions
logprob_diff_is = mb_old_logprobs - mb_vllm_logprobs
# Clamp to prevent numerical overflow in exp
logprob_diff_is = torch.where(
valid_mask, logprob_diff_is.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff_is)
)
# Compute importance ratio only for valid positions
tis_imp_ratio = torch.where(valid_mask, torch.exp(logprob_diff_is), tis_imp_ratio)
# Apply cap
tis_imp_ratio = torch.clamp(
tis_imp_ratio, max=args.truncated_importance_sampling_ratio_cap
)

# Apply importance sampling to losses
pg_losses = pg_losses * tis_imp_ratio
pg_losses2 = pg_losses2 * tis_imp_ratio

pg_loss_max = torch.max(pg_losses, pg_losses2)

# Here we recalculate kl: we want the KL loss to backpropagate through the model
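Truncated importance sampling (the cap > 0 branch above) reweights both the clipped and unclipped policy-gradient losses by exp(old_local - vllm), clamped from above by the cap, to correct for the gap between the sampling distribution (vLLM) and the trainer's snapshot of the old policy. A condensed sketch of the ratio computation, assuming both logprob tensors are already aligned to response positions:

import torch

def truncated_is_ratio(local_old_lp: torch.Tensor,
                       vllm_lp: torch.Tensor,
                       response_mask: torch.Tensor,
                       cap: float) -> torch.Tensor:
    """exp(local_old - vllm) on response tokens, clamped to `cap`; 1.0 (no effect) elsewhere."""
    valid = response_mask.bool()
    # Clamp the log-difference first so exp() cannot overflow.
    log_diff = torch.where(valid, (local_old_lp - vllm_lp).clamp(-10.0, 10.0),
                           torch.zeros_like(local_old_lp))
    ratio = torch.where(valid, torch.exp(log_diff), torch.ones_like(local_old_lp))
    return torch.clamp(ratio, max=cap)

pg_losses = torch.tensor([[0.3, -0.2, 0.1]])
ratio = truncated_is_ratio(torch.tensor([[-0.5, -1.0, -2.0]]),
                           torch.tensor([[-0.6, -0.9, -2.5]]),
                           torch.tensor([[1, 1, 1]]), cap=2.0)
print(pg_losses * ratio)  # each loss reweighted, never by more than the cap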
@@ -1510,6 +1606,7 @@ def accumulate_inference_batches(
combined_tool_outputs = []
combined_tool_runtimes = []
combined_tool_calleds = []
combined_logprobs = []

earliest_start_time = float("inf")
prompt_lengths = []
@@ -1530,6 +1627,8 @@ def accumulate_inference_batches(
combined_tool_runtimes.extend(result.request_info.tool_runtimes)
combined_tool_calleds.extend(result.request_info.tool_calleds)

combined_logprobs.extend(result.logprobs)

earliest_start_time = min(earliest_start_time, result.start_time)

prompt_lengths.append(len(all_queries[i]))
@@ -1570,6 +1669,7 @@ def accumulate_inference_batches(
request_info=combined_request_info,
dataset_index=None, # Not meaningful for combined result
token_statistics=accumulated_stats,
logprobs=combined_logprobs,
)

if actor_manager is not None:
@@ -1636,14 +1736,10 @@ def data_preparation_thread(
for i in range(len(result.request_info.tool_outputs))
]
for i in range(len(result.finish_reasons)):
# edge case: sometimes it outputs eos immediately, and we get an empty response
# in that case, we need to add the eos token to the response
# note that this also adds eos to the end of responses that stopped for other reasons.
if result.finish_reasons[i] == "stop" and (
len(result.responses[i]) == 0 or result.responses[i][-1] != tokenizer.eos_token_id
):
if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0:
result.responses[i].append(tokenizer.eos_token_id)
result.masks[i].append(1) # never mask the eos token for
result.masks[i].append(1)
result.logprobs[i].append(float("nan"))
with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
decoded_queries = batch.raw_queries
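The rewritten edge case above is narrower than before: a synthetic EOS token is appended only when a "stop"-finished response is completely empty, and the appended position gets mask 1 plus a NaN logprob so that responses, masks, and logprobs stay the same length (the NaN later becomes INVALID_LOGPROB during masking). A small sketch of that invariant:

def append_eos_if_empty(response, mask, logprobs, finish_reason, eos_token_id):
    """Keep response/mask/logprobs aligned when an empty 'stop' response gets a synthetic EOS."""
    if finish_reason == "stop" and len(response) == 0:
        response.append(eos_token_id)
        mask.append(1)
        logprobs.append(float("nan"))  # no vLLM logprob exists for the synthetic token
    assert len(response) == len(mask) == len(logprobs)
    return response, mask, logprobs

print(append_eos_if_empty([], [], [], "stop", eos_token_id=2))       # ([2], [1], [nan])
print(append_eos_if_empty([5, 2], [1, 1], [-0.3, -0.1], "stop", 2))  # unchanged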
@@ -1706,6 +1802,7 @@ def data_preparation_thread(
masks = [result.masks[i] for i in non_zero_gradient_index]
batch = batch[non_zero_gradient_index.tolist()]
finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index]
vllm_logprobs = [result.logprobs[i] for i in non_zero_gradient_index]
if args.mask_truncated_completions:
stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"])
num_truncated = len(finish_reasons) - len(stop_idxes)
@@ -1720,6 +1817,7 @@ def data_preparation_thread(
masks = [masks[i] for i in stop_idxes]
batch = batch[stop_idxes.tolist()]
finish_reasons = [finish_reasons[i] for i in stop_idxes]
vllm_logprobs = [vllm_logprobs[i] for i in stop_idxes]

if args.fill_completions:
with Timer("⏱ [Data Preparation Thread] Refill completions"):
@@ -1763,6 +1861,7 @@ def data_preparation_thread(
)

finish_reasons += [finish_reasons[i] for i in sampled_indices]
vllm_logprobs += [vllm_logprobs[i] for i in sampled_indices]

logger.info(
f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses"
@@ -1783,6 +1882,7 @@ def data_preparation_thread(
masks=masks,
pack_length=args.pack_length,
pad_token_id=tokenizer.pad_token_id,
vllm_logprobs=vllm_logprobs,
)
num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
# Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
@@ -1832,6 +1932,7 @@ def data_preparation_thread(
per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]
per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs[B * i : B * (i + 1)]

# Shuffle the batch and collate the data
b_inds = np.random.permutation(len(per_device_packed_query_responses))
@@ -1841,6 +1942,7 @@ def data_preparation_thread(
collated_position_ids = []
collated_response_masks = []
collated_advantages = []
collated_vllm_logprobs = []
for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
micro_range = b_inds[j : j + args.per_device_train_batch_size]
collated_query_responses.append(
@@ -1863,6 +1965,9 @@ def data_preparation_thread(
collated_advantages.append(
collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
)
collated_vllm_logprobs.append(
collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0)
)
collated_data.append(
{
"collated_query_responses": collated_query_responses,
@@ -1871,6 +1976,7 @@ def data_preparation_thread(
"collated_position_ids": collated_position_ids,
"collated_advantages": collated_advantages,
"collated_response_masks": collated_response_masks,
"collated_vllm_logprobs": collated_vllm_logprobs,
}
)

@@ -2175,6 +2281,7 @@ def create_generation_configs(args: Args):
n=args.num_samples_per_prompt_rollout,
stop=args.stop_strings,
seed=args.seed,
logprobs=1, # Enable logprobs to compare with local calculations
# IMPORTANT: Set output_kind to FINAL_ONLY to ensure vLLM V1 properly handles n>1
# With the default CUMULATIVE mode, vLLM V1 returns separate outputs for each
# completion, making it difficult to aggregate them correctly. FINAL_ONLY mode
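Setting logprobs=1 in the generation config asks vLLM to return the sampled token's logprob for every generated token, which is what flows back through GenerationResult.logprobs. A rough sketch of how those values surface in vLLM's offline API (field names follow recent vLLM releases; the exact model and shapes here are assumptions):

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model choice
params = SamplingParams(max_tokens=16, temperature=1.0, logprobs=1)

completion = llm.generate(["The capital of France is"], params)[0].outputs[0]
# completion.logprobs is a list with one dict per generated token, mapping token_id -> Logprob;
# pulling out the sampled token's entry yields the per-token float list used for training.
token_logprobs = [lp_dict[token_id].logprob
                  for token_id, lp_dict in zip(completion.token_ids, completion.logprobs)]
print(token_logprobs)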
1 change: 1 addition & 0 deletions open_instruct/queue_types.py
@@ -36,6 +36,7 @@ class GenerationResult:
training_step: Optional[int] = None
token_statistics: Optional[TokenStatistics] = None
start_time: Optional[float] = None
logprobs: Optional[List[List[float]]] = None # logprobs for each token in each response


@dataclass
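On the queue side, the new logprobs field simply rides along with responses: one list of per-token floats per response, with NaN wherever no vLLM logprob exists. A minimal sketch of the expected alignment, using a hypothetical trimmed stand-in for the real dataclass:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class GenerationResultSketch:  # trimmed stand-in for open_instruct.queue_types.GenerationResult
    responses: List[List[int]]
    logprobs: Optional[List[List[float]]] = None  # one float per generated token

result = GenerationResultSketch(
    responses=[[11, 42, 2], [7, 2]],
    logprobs=[[-0.3, -1.2, -0.05], [-0.9, -0.1]],
)
assert all(len(r) == len(lp) for r, lp in zip(result.responses, result.logprobs))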