Now, LLMRayActor returns logprobs, and we calculate some stats about them vs the trainer logprobs in grpo_fast.py #1041
Conversation
Instead of purely using the vLLM logprobs, which have always seemed pretty unstable, can we use something like truncated importance sampling from https://fengyao.notion.site/off-policy-rl?
If you want to do it token-wise, here's a code snippet from the authors: https://github.com/yaof20/verl/blob/1e413344a2f31aefdbd05457843274f84dff9f2d/verl/trainer/ppo/core_algos.py#L893
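This isn't the linked verl snippet, just a minimal sketch of what token-wise truncated importance sampling computes; the function name and default cap are illustrative:

```python
import torch

def truncated_is_weights(
    trainer_logprobs: torch.Tensor,    # per-token logprobs from the learner policy
    generator_logprobs: torch.Tensor,  # per-token logprobs from the vLLM generator
    cap: float = 2.0,                  # truncation cap C (2.0 matches the runs below)
) -> torch.Tensor:
    # rho_t = pi_learner(a_t | s_t) / pi_generator(a_t | s_t), computed in log space
    log_ratio = trainer_logprobs - generator_logprobs
    rho = torch.exp(log_ratio)
    # Truncate from above only: min(rho_t, C) bounds the variance of the estimator
    return torch.clamp(rho, max=cap).detach()
```

The resulting per-token weights are detached and multiplied into the policy-gradient loss, so the truncation only bounds the off-policy correction; it never flips its sign.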
…aN source
- Added truncated_importance_sampling_ratio_cap parameter (default 0.0)
- Implemented importance sampling with comprehensive assertions
- Added checks for INVALID_LOGPROB values and extreme logprob differences
- Added NaN checks at each step of the calculation

This will help identify exactly where NaNs are introduced when importance sampling is enabled.
- Add torch.nan_to_num to replace NaN values with INVALID_LOGPROB after collation
- Query tokens in packed sequences are set to NaN by pack_sequences in rl_utils2.py
- Apply nan_to_num in both the training loop (line 1031) and the old_logprobs calculation (line 987)
- Implement proper importance sampling with masking:
  * Only apply IS where both logprobs are valid (not INVALID_LOGPROB)
  * Use the response mask to ensure only response tokens are affected
  * Initialize the importance ratio to 1.0 (neutral) for invalid positions
  * Clamp logprob differences to prevent numerical overflow
- Remove all debug assertions that were causing failures
- Ensure importance sampling only affects valid response token positions

This fixes the 'NaN in mb_old_logprobs before IS' assertion error.
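A minimal sketch of the masked ratio described above; the INVALID_LOGPROB sentinel value and the clamp bounds are assumptions, not the repository's actual constants:

```python
import torch

INVALID_LOGPROB = 1.0  # assumed sentinel for query/padding positions; the real value may differ

def masked_is_ratio(new_logprobs, old_logprobs, response_mask, cap):
    # Packed query tokens arrive as NaN; replace them with the sentinel first
    old_logprobs = torch.nan_to_num(old_logprobs, nan=INVALID_LOGPROB)
    valid = (
        response_mask.bool()
        & (old_logprobs != INVALID_LOGPROB)
        & (new_logprobs != INVALID_LOGPROB)
    )
    # Clamp the log-difference before exponentiating to avoid overflow
    log_diff = torch.clamp(new_logprobs - old_logprobs, min=-20.0, max=20.0)
    # Invalid positions keep a neutral ratio of 1.0 so they do not affect the loss
    ratio = torch.where(valid, torch.exp(log_diff), torch.ones_like(new_logprobs))
    return torch.clamp(ratio, max=cap)
```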
Run with truncated importance sampling: Wandb
Just do the check for not setting both args at the same time (as they conflict).
The other changes aren't necessary. Overall, I think our logic for logprobs is a bit convoluted but that's not the point of this PR so no need to fix.
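A minimal sketch of that check, assuming an argparse-style `args` namespace carrying both flags:

```python
# Hypothetical validation at startup: the two modes conflict, so reject setting both.
if args.use_vllm_logprobs and args.truncated_importance_sampling_ratio_cap > 0.0:
    raise ValueError(
        "--use_vllm_logprobs and --truncated_importance_sampling_ratio_cap conflict; "
        "set at most one of them."
    )
```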
…tch at source
This adds logging to check whether a vLLM CompletionOutput has mismatched token_ids and logprobs lengths. According to analysis of the vLLM source, all generated tokens should have logprobs, so N-1 behavior would indicate a bug.
Root cause: we were unconditionally appending the EOS token when finish_reason='stop' but not appending a corresponding logprob, creating a len(response) = N+1 vs. len(logprobs) = N mismatch.

Fix:
- Only append EOS for truly empty responses (the actual edge case)
- When we do append EOS, also append NaN to the logprobs
- Normal responses ending with </answer> no longer get EOS appended
- vLLM returns N logprobs for N tokens, so there is no mismatch

Also added assertions to verify the masks match correctly, and updated the diagnostic logging to treat length mismatches as errors rather than expected behavior.
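A minimal sketch of the alignment rule this commit describes; the function and argument names are illustrative, not the actual grpo_fast.py code:

```python
def finalize_output(token_ids, logprobs, finish_reason, eos_token_id):
    # vLLM returns one logprob per generated token, so the two lists should
    # already be aligned; only the empty-response edge case needs padding.
    if finish_reason == "stop" and len(token_ids) == 0:
        token_ids = [eos_token_id]
        logprobs = [float("nan")]  # no real logprob exists for the injected EOS
    assert len(token_ids) == len(logprobs), (len(token_ids), len(logprobs))
    return token_ids, logprobs
```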
After discussing with Hamish, one of the main issues is that we had code which adds EOS when we use a stop string (for instance, in reasoning, we end when we generate </answer>). I changed this PR to remove the EOS addition. Runs seem unaffected: https://wandb.ai/ai2-llm/open_instruct_internal?nw=ba89hjgyfp6
Added a use_vllm_logprobs flag which uses the vLLM logprobs to train on instead of the local ones from the learner model. Also added the KL divergence from the generator to the trainer, and an option to use truncated importance sampling via the truncated_importance_sampling_ratio_cap flag.

Runs with --truncated_importance_sampling_ratio_cap 2.0:

Runs with --use_vllm_logprobs:
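For reference, a minimal sketch of the kind of generator-vs-trainer logprob comparison described in the summary above; the function name, stat names, and masking convention are illustrative assumptions, not the PR's actual code:

```python
import torch

def logprob_comparison_stats(trainer_logprobs, vllm_logprobs, response_mask):
    """Summary stats comparing generator (vLLM) and trainer logprobs over response tokens."""
    mask = response_mask.bool()
    diff = (vllm_logprobs - trainer_logprobs)[mask]
    return {
        # Monte Carlo estimate of KL(generator || trainer) from the sampled tokens
        "kl_generator_to_trainer": diff.mean().item(),
        "mean_abs_logprob_diff": diff.abs().mean().item(),
        "max_abs_logprob_diff": diff.abs().max().item(),
    }
```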