
Commit bd26188

Now, LLMRayActor returns logprobs, and we calculate some stats about them vs the trainer logprobs in grpo_fast.py. (#1041)
* Now, LLMRayActor returns logprobs
* Updated code
* Cleaned up PR.
* Cleaned up PR.
* Updated logprob code
* Fixed code
* Now uses NaN
* Now, we filter NaNs
* Cleaned up code.
* Fixed tests
* Updated code
* Added vLLM logprobs
* Cleaned up code
* Undo changes to script.
* Fixed bug in logprobs
* Fixed failing tests
* Added importance sampling ratio
* Added back comment
* Removed comment
* Test config
* Add truncated importance sampling with debug assertions to identify NaN source
  - Added truncated_importance_sampling_ratio_cap parameter (default 0.0)
  - Implemented importance sampling with comprehensive assertions
  - Added checks for INVALID_LOGPROB values and extreme logprob differences
  - Added NaN checks at each step of the calculation
  This will help identify exactly where NaNs are introduced when importance sampling is enabled.
* Fix NaN handling in truncated importance sampling
  - Add torch.nan_to_num to replace NaN values with INVALID_LOGPROB after collation
  - Query tokens in packed sequences are set to NaN by pack_sequences in rl_utils2.py
  - Apply nan_to_num in both the training loop (line 1031) and the old_logprobs calculation (line 987)
  - Implement proper importance sampling with masking:
    * Only apply IS where both logprobs are valid (not INVALID_LOGPROB)
    * Use the response mask to ensure only response tokens are affected
    * Initialize the importance ratio to 1.0 (neutral) for invalid positions
    * Clamp logprob differences to prevent numerical overflow
  - Remove all debug assertions that were causing failures
  - Ensure importance sampling only affects valid response token positions
  This fixes the 'NaN in mb_old_logprobs before IS' assertion error.
* Added reverse KL
* Simplified code
* Changes to debug mask
* More logging
* More logging
* Added comment describing
* Add diagnostic logging to vllm_utils3 to detect logprobs length mismatch at source
  This adds logging to check whether vLLM CompletionOutput has mismatched token_ids and logprobs lengths. According to vLLM source analysis, all generated tokens should have logprobs, so N-1 behavior would indicate a bug.
* Fix vLLM logprobs N-1 mismatch by only appending EOS for empty responses
  Root cause: we were unconditionally appending the EOS token when finish_reason='stop' but not appending a corresponding logprob, creating a len(response) = N+1 vs. len(logprobs) = N mismatch.
  Fix:
  - Only append EOS for truly empty responses (the actual edge case)
  - When we do append EOS, also append NaN to logprobs
  - Normal responses ending with </answer> no longer get EOS appended
  - vLLM returns N logprobs for N tokens, so there is no mismatch
  Also added assertions to verify that the masks match correctly, and updated the diagnostic logging to treat length mismatches as errors rather than expected behavior.
* Address review comments.
* Updated scripts
* Updated scripts
* Cleaned up PR.
* Added assert

Co-authored-by: Claude <[email protected]>
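The commit's core addition is a capped (truncated) importance-sampling correction between the trainer's old logprobs and the logprobs vLLM reported at generation time. Below is a minimal, self-contained sketch of that calculation under the same conventions as the diff further down; the function name, the toy tensors, and the ±10 clamp window are illustrative, not the exact trainer code.

```python
import torch

def truncated_is_ratio(
    old_logprobs: torch.Tensor,   # trainer (or cached) logprobs, shape [batch, seq]
    vllm_logprobs: torch.Tensor,  # logprobs returned by vLLM for the same tokens
    response_mask: torch.Tensor,  # 1 where a position is a real response token
    cap: float,                   # truncated_importance_sampling_ratio_cap
) -> torch.Tensor:
    valid = response_mask.bool()
    # Neutral ratio (1.0) everywhere, so padding/query positions leave the loss unchanged.
    ratio = torch.ones_like(old_logprobs)
    # Clamp the logprob difference so exp() cannot overflow, then exponentiate valid positions only.
    diff = (old_logprobs - vllm_logprobs).clamp(-10.0, 10.0)
    ratio = torch.where(valid, diff.exp(), ratio)
    # Truncate the ratio at the configured cap.
    return ratio.clamp(max=cap)

# Toy usage: the policy-gradient losses are then multiplied elementwise by this ratio.
old = torch.tensor([[-1.2, -0.5, -3.0]])
vllm = torch.tensor([[-1.0, -0.7, -2.5]])
mask = torch.tensor([[1, 1, 0]])
print(truncated_is_ratio(old, vllm, mask, cap=2.0))
```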
1 parent 3a21f22 commit bd26188

File tree

6 files changed: +200 -28 lines

open_instruct/grpo_fast.py

Lines changed: 129 additions & 22 deletions
@@ -248,6 +248,8 @@ class Args:
  """the lower clip range"""
  clip_higher: float = 0.2
  """the higher clip range. Sometimes we want this to be higher, see DAPO (https://arxiv.org/abs/2503.14476)"""
+ truncated_importance_sampling_ratio_cap: float = 0.0
+ """The maximum cap for truncated importance sampling ratio (0 means disabled)"""
  inflight_updates: bool = False
  """Enable immediate stopping of request processing when should_stop is set, allowing for quick pausing and resumption"""
  kl_estimator: Literal["kl1", "kl2", "kl3", "kl4"] = "kl3"
@@ -275,6 +277,8 @@ class Args:

  record_entropy: bool = False
  """whether to record the entropy of the policy during training. Uses extra memory."""
+ use_vllm_logprobs: bool = False
+ """whether to use vLLM's logprobs for training instead of calculating them via forward pass"""

  # Reward
  # -- r1 style format reward
@@ -436,6 +440,11 @@ def __post_init__(self):
  logger.warning("When using the v0 version of vLLM, caching is broken and will never be invalidated.")
  if self.vllm_enable_prefix_caching:
  raise ValueError("Prefix caching is currently not supported for v0.")
+ if self.use_vllm_logprobs and self.truncated_importance_sampling_ratio_cap > 0.0:
+ raise ValueError(
+ "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. "
+ "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
+ )
  assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
  if self.num_samples_per_prompt_rollout == 1:
  logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
@@ -893,6 +902,7 @@ def train(
  collated_position_ids,
  collated_advantages,
  collated_response_masks,
+ collated_vllm_logprobs,
  pad_token_id: int,
  num_mini_batches: int,
  ):
@@ -903,6 +913,7 @@ def train(
  to_device_inplace(collated_position_ids, self.device)
  to_device_inplace(collated_advantages, self.device)
  to_device_inplace(collated_response_masks, self.device)
+ to_device_inplace(collated_vllm_logprobs, self.device)
  # accumulation steps should always be at least 1
  accumulation_steps = max(math.ceil(len(collated_query_responses) / num_mini_batches - 0.5), 1)
  leftover = len(collated_query_responses) % accumulation_steps
@@ -913,6 +924,7 @@ def train(
  collated_position_ids = collated_position_ids[0:-leftover]
  collated_advantages = collated_advantages[0:-leftover]
  collated_response_masks = collated_response_masks[0:-leftover]
+ collated_vllm_logprobs = collated_vllm_logprobs[0:-leftover]
  logger.warning(f"{leftover} samples are dropped due to batch size {num_mini_batches}")

  # recalculate the "real" number of mini-batches
@@ -958,21 +970,31 @@ def train(
  attention_mask = collated_attention_masks[i]
  position_id = collated_position_ids[i]
  response_mask = collated_response_masks[i]
- old_logprob, _ = self.forward(
- self.model,
- query_response,
- attention_mask,
- position_id,
- pad_token_id,
- args.temperature,
- return_entropy=False,
- )
+ if not args.use_vllm_logprobs:
+ local_old_logprob, _ = self.forward(
+ self.model,
+ query_response,
+ attention_mask,
+ position_id,
+ pad_token_id,
+ args.temperature,
+ return_entropy=False,
+ )
+ vllm_old_logprob = collated_vllm_logprobs[i][:, 1:]
  if args.mask_tool_use and args.tool_use:
  response_mask = response_mask.bool() & tool_mask.bool()
  else:
  response_mask = response_mask.bool()
- old_logprob = torch.masked_fill(old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB)
- old_logprobs[i] = old_logprob
+ if not args.use_vllm_logprobs:
+ local_old_logprob = torch.masked_fill(
+ local_old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB
+ )
+ vllm_old_logprob = torch.masked_fill(vllm_old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB)
+ vllm_old_logprob = torch.nan_to_num(vllm_old_logprob, nan=INVALID_LOGPROB)
+ if args.use_vllm_logprobs:
+ old_logprobs[i] = vllm_old_logprob
+ else:
+ old_logprobs[i] = local_old_logprob
  torch.cuda.empty_cache()

  local_step = 0
@@ -1001,7 +1023,7 @@ def train(
  mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
  mb_attention_mask = collated_attention_masks[i]
  mb_position_id = collated_position_ids[i]
- mb_new_logprobs, mb_entropy = self.forward(
+ mb_local_logprobs, mb_entropy = self.forward(
  self.model,
  mb_query_responses,
  mb_attention_mask,
@@ -1010,16 +1032,50 @@ def train(
  args.temperature,
  return_entropy=args.record_entropy,
  )
- mb_new_logprobs = torch.masked_fill(mb_new_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
+ mb_local_logprobs = torch.masked_fill(mb_local_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
+ mb_vllm_logprobs = collated_vllm_logprobs[i][:, 1:]
+ mb_vllm_logprobs = torch.masked_fill(mb_vllm_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
+ # Replace any remaining NaN values (query tokens in packed sequences are set to NaN by pack_sequences in rl_utils2.py)
+ mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB)
+
+ # Compare vLLM logprobs with local logprobs
+ with torch.no_grad():
+ valid_mask = mb_response_masks_bool & ~torch.isnan(mb_vllm_logprobs)
+ logprob_diff = (mb_local_logprobs - mb_vllm_logprobs).abs()
+ masked_diff = torch.masked_fill(logprob_diff, ~valid_mask, 0.0)
+ mean_diff = masked_diff.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0
+ max_diff = masked_diff.max()
+ std_diff = masked_diff[valid_mask].std() if valid_mask.sum() > 1 else 0.0
+
+ self.local_metrics.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff.item())
+ self.local_metrics.add("debug/vllm_vs_local_logprob_diff_max", max_diff.item())
+ self.local_metrics.add("debug/vllm_vs_local_logprob_diff_std", std_diff.item())
+
+ reverse_kl = torch.exp(mb_vllm_logprobs) * (mb_vllm_logprobs - mb_local_logprobs)
+ masked_reverse_kl = torch.masked_fill(reverse_kl, ~valid_mask, 0.0)
+ mean_reverse_kl = masked_reverse_kl.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0
+ self.local_metrics.add("debug/vllm_local_reverse_kl", mean_reverse_kl.item())
+
+ mb_new_logprobs = mb_local_logprobs

  # Cache the old logprobs
  if num_mini_batches > 1:
  mb_old_logprobs = old_logprobs[i]
  else:
  with torch.no_grad():
  if epoch_idx == 0:
- old_logprobs[i] = mb_new_logprobs
- mb_old_logprobs = old_logprobs[i].detach()
+ if args.use_vllm_logprobs:
+ old_logprobs[i] = mb_vllm_logprobs
+ else:
+ old_logprobs[i] = mb_local_logprobs.detach()
+ mb_old_logprobs = old_logprobs[i]
+
+ old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB
+ assert torch.all(old_logprobs_mask == mb_response_masks_bool), (
+ f"Old logprobs mask should match response mask. "
+ f"old_mask sum={old_logprobs_mask.sum()}, "
+ f"response_mask sum={mb_response_masks_bool.sum()}"
+ )

  # Calculate the policy's loss
  logprobs_diff = mb_new_logprobs - mb_old_logprobs
@@ -1028,6 +1084,46 @@ def train(
  pg_losses2 = -mb_advantages[:, 1:] * torch.clamp(
  ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher
  )
+
+ # Apply truncated importance sampling if enabled
+ if args.truncated_importance_sampling_ratio_cap > 0 and mb_vllm_logprobs is not None:
+ old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB
+ vllm_logprobs_mask = mb_vllm_logprobs != INVALID_LOGPROB
+
+ assert torch.all(old_logprobs_mask == mb_response_masks_bool), (
+ f"Old logprobs mask should match response mask. "
+ f"old_mask sum={old_logprobs_mask.sum()}, "
+ f"response_mask sum={mb_response_masks_bool.sum()}"
+ )
+ assert torch.all(vllm_logprobs_mask == mb_response_masks_bool), (
+ f"vLLM logprobs mask should match response mask. "
+ f"vllm_mask sum={vllm_logprobs_mask.sum()}, "
+ f"response_mask sum={mb_response_masks_bool.sum()}"
+ )
+
+ valid_mask = mb_response_masks_bool
+
+ # Initialize importance ratio to 1.0 (no effect) for all positions
+ tis_imp_ratio = torch.ones_like(mb_old_logprobs)
+
+ if valid_mask.any():
+ # Calculate logprob difference only for valid positions
+ logprob_diff_is = mb_old_logprobs - mb_vllm_logprobs
+ # Clamp to prevent numerical overflow in exp
+ logprob_diff_is = torch.where(
+ valid_mask, logprob_diff_is.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff_is)
+ )
+ # Compute importance ratio only for valid positions
+ tis_imp_ratio = torch.where(valid_mask, torch.exp(logprob_diff_is), tis_imp_ratio)
+ # Apply cap
+ tis_imp_ratio = torch.clamp(
+ tis_imp_ratio, max=args.truncated_importance_sampling_ratio_cap
+ )
+
+ # Apply importance sampling to losses
+ pg_losses = pg_losses * tis_imp_ratio
+ pg_losses2 = pg_losses2 * tis_imp_ratio
+
  pg_loss_max = torch.max(pg_losses, pg_losses2)

  # Here we recalculate kl: we want the KL loss to backpropagate through the model
@@ -1510,6 +1606,7 @@ def accumulate_inference_batches(
  combined_tool_outputs = []
  combined_tool_runtimes = []
  combined_tool_calleds = []
+ combined_logprobs = []

  earliest_start_time = float("inf")
  prompt_lengths = []
@@ -1530,6 +1627,8 @@ def accumulate_inference_batches(
  combined_tool_runtimes.extend(result.request_info.tool_runtimes)
  combined_tool_calleds.extend(result.request_info.tool_calleds)

+ combined_logprobs.extend(result.logprobs)
+
  earliest_start_time = min(earliest_start_time, result.start_time)

  prompt_lengths.append(len(all_queries[i]))
@@ -1570,6 +1669,7 @@ def accumulate_inference_batches(
  request_info=combined_request_info,
  dataset_index=None, # Not meaningful for combined result
  token_statistics=accumulated_stats,
+ logprobs=combined_logprobs,
  )

  if actor_manager is not None:
@@ -1636,14 +1736,10 @@ def data_preparation_thread(
  for i in range(len(result.request_info.tool_outputs))
  ]
  for i in range(len(result.finish_reasons)):
- # edge case: sometimes it outputs eos immediately, and we get an empty response
- # in that case, we need to add the eos token to the response
- # note that this also adds eos to the end of reponses that stopped for other reasons.
- if result.finish_reasons[i] == "stop" and (
- len(result.responses[i]) == 0 or result.responses[i][-1] != tokenizer.eos_token_id
- ):
+ if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0:
  result.responses[i].append(tokenizer.eos_token_id)
- result.masks[i].append(1) # never mask the eos token for
+ result.masks[i].append(1)
+ result.logprobs[i].append(float("nan"))
  with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
  decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
  decoded_queries = batch.raw_queries
@@ -1706,6 +1802,7 @@ def data_preparation_thread(
  masks = [result.masks[i] for i in non_zero_gradient_index]
  batch = batch[non_zero_gradient_index.tolist()]
  finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index]
+ vllm_logprobs = [result.logprobs[i] for i in non_zero_gradient_index]
  if args.mask_truncated_completions:
  stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"])
  num_truncated = len(finish_reasons) - len(stop_idxes)
@@ -1720,6 +1817,7 @@ def data_preparation_thread(
  masks = [masks[i] for i in stop_idxes]
  batch = batch[stop_idxes.tolist()]
  finish_reasons = [finish_reasons[i] for i in stop_idxes]
+ vllm_logprobs = [vllm_logprobs[i] for i in stop_idxes]

  if args.fill_completions:
  with Timer("⏱ [Data Preparation Thread] Refill completions"):
@@ -1763,6 +1861,7 @@ def data_preparation_thread(
  )

  finish_reasons += [finish_reasons[i] for i in sampled_indices]
+ vllm_logprobs += [vllm_logprobs[i] for i in sampled_indices]

  logger.info(
  f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses"
@@ -1783,6 +1882,7 @@ def data_preparation_thread(
  masks=masks,
  pack_length=args.pack_length,
  pad_token_id=tokenizer.pad_token_id,
+ vllm_logprobs=vllm_logprobs,
  )
  num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
  # Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
@@ -1832,6 +1932,7 @@ def data_preparation_thread(
  per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
  per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
  per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]
+ per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs[B * i : B * (i + 1)]

  # Shuffle the batch and collate the data
  b_inds = np.random.permutation(len(per_device_packed_query_responses))
@@ -1841,6 +1942,7 @@ def data_preparation_thread(
  collated_position_ids = []
  collated_response_masks = []
  collated_advantages = []
+ collated_vllm_logprobs = []
  for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
  micro_range = b_inds[j : j + args.per_device_train_batch_size]
  collated_query_responses.append(
@@ -1863,6 +1965,9 @@ def data_preparation_thread(
  collated_advantages.append(
  collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
  )
+ collated_vllm_logprobs.append(
+ collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0)
+ )
  collated_data.append(
  {
  "collated_query_responses": collated_query_responses,
@@ -1871,6 +1976,7 @@ def data_preparation_thread(
  "collated_position_ids": collated_position_ids,
  "collated_advantages": collated_advantages,
  "collated_response_masks": collated_response_masks,
+ "collated_vllm_logprobs": collated_vllm_logprobs,
  }
  )

@@ -2175,6 +2281,7 @@ def create_generation_configs(args: Args):
  n=args.num_samples_per_prompt_rollout,
  stop=args.stop_strings,
  seed=args.seed,
+ logprobs=1, # Enable logprobs to compare with local calculations
  # IMPORTANT: Set output_kind to FINAL_ONLY to ensure vLLM V1 properly handles n>1
  # With the default CUMULATIVE mode, vLLM V1 returns separate outputs for each
  # completion, making it difficult to aggregate them correctly. FINAL_ONLY mode
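For reference, the debug metrics registered above (debug/vllm_vs_local_logprob_diff_* and debug/vllm_local_reverse_kl) reduce to masked comparisons between the two sets of logprobs. The standalone sketch below reproduces the same statistics on dummy tensors; the function and variable names are illustrative, and the real code also masks INVALID_LOGPROB positions via the response mask.

```python
import torch

def logprob_comparison_stats(local_lp, vllm_lp, response_mask):
    """Masked stats comparing trainer ("local") logprobs against vLLM's per-token logprobs."""
    valid = response_mask.bool() & ~torch.isnan(vllm_lp)
    n = valid.sum().clamp(min=1)

    abs_diff = (local_lp - vllm_lp).abs().masked_fill(~valid, 0.0)
    # Per-token reverse-KL estimate exp(q) * (q - p), with q = vLLM logprob and p = local logprob.
    reverse_kl = (torch.exp(vllm_lp) * (vllm_lp - local_lp)).masked_fill(~valid, 0.0)

    return {
        "diff_mean": (abs_diff.sum() / n).item(),
        "diff_max": abs_diff.max().item(),
        "reverse_kl_mean": (reverse_kl.sum() / n).item(),
    }

# Toy usage: the NaN position (e.g. a query token in a packed sequence) is ignored.
local = torch.tensor([[-1.0, -2.0, -0.5]])
vllm = torch.tensor([[-1.1, -1.8, float("nan")]])
mask = torch.tensor([[1, 1, 1]])
print(logprob_comparison_stats(local, vllm, mask))
```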

open_instruct/queue_types.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class GenerationResult:
  training_step: Optional[int] = None
  token_statistics: Optional[TokenStatistics] = None
  start_time: Optional[float] = None
+ logprobs: Optional[List[List[float]]] = None # logprobs for each token in each response


  @dataclass
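One invariant the new field is meant to satisfy, once the EOS fix described in the commit message is in place, is one logprob per generated token in every response. A hedged sketch of that check follows, using a stripped-down stand-in rather than the real GenerationResult class.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FakeGenerationResult:
    # Minimal stand-in for open_instruct.queue_types.GenerationResult (illustrative only).
    responses: List[List[int]] = field(default_factory=list)
    logprobs: Optional[List[List[float]]] = None  # one logprob per generated token

def check_logprob_alignment(result: FakeGenerationResult) -> None:
    assert result.logprobs is not None, "generation was run without logprobs enabled"
    for i, (tokens, lps) in enumerate(zip(result.responses, result.logprobs)):
        assert len(tokens) == len(lps), f"response {i}: {len(tokens)} tokens vs {len(lps)} logprobs"

check_logprob_alignment(FakeGenerationResult(responses=[[5, 6, 7]], logprobs=[[-0.1, -2.3, -0.4]]))
```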
