
Conversation

@finbarrtimbers (Collaborator) commented Sep 29, 2025

Added a use_vllm_logprobs flag which trains on the vLLM logprobs instead of the local ones recomputed by the learner model. Also added the KL divergence from the generator to the trainer, and an option to use truncated importance sampling via the truncated_importance_sampling_ratio_cap flag.

Runs with --truncated_importance_sampling_ratio_cap 2.0:

Runs with --use_vllm_logprobs:
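
A minimal sketch of how the two new pieces could fit together, assuming [batch, seq_len] logprob tensors and a 0/1 response_mask; the helper name and the single-sample KL estimator are illustrative, not the PR's actual implementation:

```python
import torch

def select_logprobs_and_kl(local_logprobs, vllm_logprobs, response_mask,
                           use_vllm_logprobs: bool):
    """Pick which logprobs to train on and report KL(generator || trainer).

    All tensors are [batch, seq_len]; response_mask is 1.0 on response tokens.
    The KL is the standard sampled-token estimate E[log pi_gen - log pi_train],
    averaged over response tokens only.
    """
    kl_per_token = (vllm_logprobs - local_logprobs) * response_mask
    mean_kl = kl_per_token.sum() / response_mask.sum().clamp(min=1)

    train_logprobs = vllm_logprobs if use_vllm_logprobs else local_logprobs
    return train_logprobs, mean_kl
```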

@finbarrtimbers marked this pull request as ready for review September 29, 2025 16:54
@mnoukhov (Contributor) left a comment


Instead of purely using the vLLM logprobs, which always seemed pretty unstable, can we use something like truncated importance sampling from https://fengyao.notion.site/off-policy-rl?

If you want to do it token-wise, here's a code snippet from the authors: https://github.com/yaof20/verl/blob/1e413344a2f31aefdbd05457843274f84dff9f2d/verl/trainer/ppo/core_algos.py#L893
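
For reference, a token-wise version in the spirit of that snippet might look like the following; the function name is made up, and the overflow clamp value is an arbitrary safety margin:

```python
import torch

def truncated_importance_weights(trainer_logprobs, vllm_logprobs, cap: float):
    """Token-wise truncated importance sampling.

    w_t = min(pi_trainer(a_t) / pi_vllm(a_t), cap), detached so it only
    rescales the per-token policy-gradient loss and is not differentiated through.
    """
    log_ratio = trainer_logprobs - vllm_logprobs
    ratio = torch.exp(torch.clamp(log_ratio, max=20.0))  # avoid overflow in exp
    return torch.clamp(ratio, max=cap).detach()

# e.g. with --truncated_importance_sampling_ratio_cap 2.0:
# loss = (truncated_importance_weights(lp, vllm_lp, cap=2.0) * pg_loss_per_token).mean()
```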

finbarrtimbers and others added 3 commits October 7, 2025 08:23
…aN source

- Added truncated_importance_sampling_ratio_cap parameter (default 0.0)
- Implemented importance sampling with comprehensive assertions
- Added checks for INVALID_LOGPROB values and extreme logprob differences
- Added NaN checks at each step of the calculation

This will help identify exactly where NaNs are introduced when importance sampling is enabled.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
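
A sketch of the kind of per-step NaN diagnostics this commit describes; the helper name is made up, and the INVALID_LOGPROB value of 1.0 is an assumption that should match the repo's constant:

```python
import torch

INVALID_LOGPROB = 1.0  # assumed sentinel value

def check_logprobs(name: str, logprobs: torch.Tensor) -> None:
    """Fail loudly (with a location) if NaNs or infinities appear, and report
    how many positions still carry the INVALID_LOGPROB sentinel."""
    assert not torch.isnan(logprobs).any(), f"NaN in {name}"
    assert not torch.isinf(logprobs).any(), f"Inf in {name}"
    n_invalid = (logprobs == INVALID_LOGPROB).sum().item()
    if n_invalid:
        print(f"{name}: {n_invalid} positions are INVALID_LOGPROB")
```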
- Add torch.nan_to_num to replace NaN values with INVALID_LOGPROB after collation
- Query tokens in packed sequences are set to NaN by pack_sequences in rl_utils2.py
- Apply nan_to_num in both training loop (line 1031) and old_logprobs calculation (line 987)
- Implement proper importance sampling with masking:
  * Only apply IS where both logprobs are valid (not INVALID_LOGPROB)
  * Use response mask to ensure only response tokens are affected
  * Initialize importance ratio to 1.0 (neutral) for invalid positions
  * Clamp logprob differences to prevent numerical overflow
- Remove all debug assertions that were causing failures
- Ensure importance sampling only affects valid response token positions

This fixes the 'NaN in mb_old_logprobs before IS' assertion error.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
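
A sketch of the masked importance-ratio computation the bullets above describe; the function name, the INVALID_LOGPROB value, and the clamp range are assumptions, and this mirrors the listed steps rather than the exact code in the PR:

```python
import torch

INVALID_LOGPROB = 1.0  # assumed sentinel value

def masked_importance_ratio(old_logprobs, vllm_logprobs, response_mask, cap: float):
    """Importance ratio that stays neutral (1.0) wherever logprobs are invalid."""
    # Query tokens in packed sequences arrive as NaN; replace them with the sentinel.
    old_logprobs = torch.nan_to_num(old_logprobs, nan=INVALID_LOGPROB)
    vllm_logprobs = torch.nan_to_num(vllm_logprobs, nan=INVALID_LOGPROB)

    # Only apply IS where both logprobs are valid and the token is a response token.
    valid = (old_logprobs != INVALID_LOGPROB) & (vllm_logprobs != INVALID_LOGPROB)
    valid = valid & response_mask.bool()

    # Clamp the logprob difference to prevent numerical overflow in exp().
    log_ratio = torch.clamp(old_logprobs - vllm_logprobs, min=-10.0, max=10.0)
    ratio = torch.where(valid, torch.exp(log_ratio), torch.ones_like(log_ratio))
    return torch.clamp(ratio, max=cap)
```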
@finbarrtimbers (Collaborator, Author) commented

Run with truncated importance sampling: Wandb

@mnoukhov (Contributor) left a comment


Just do the check for not setting both args at the same time (as they conflict)

The other changes aren't necessary. Overall, I think our logic for logprobs is a bit convoluted but that's not the point of this PR so no need to fix.

finbarrtimbers and others added 7 commits October 7, 2025 11:04
…tch at source

This adds logging to check if vLLM CompletionOutput has mismatched token_ids
and logprobs lengths. According to vLLM source analysis, all generated tokens
should have logprobs, so N-1 behavior would indicate a bug.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
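
A sketch of the length check being added, assuming a vLLM CompletionOutput with token_ids and logprobs fields; the function name and the logging style are illustrative:

```python
def check_completion_output(output) -> None:
    """Treat a token/logprob length mismatch in a vLLM CompletionOutput as an
    error: per the source analysis, every generated token should have a logprob."""
    n_tokens = len(output.token_ids)
    n_logprobs = len(output.logprobs) if output.logprobs is not None else 0
    if n_tokens != n_logprobs:
        print(f"ERROR: CompletionOutput mismatch: {n_tokens} tokens vs {n_logprobs} logprobs")
```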
Root cause: We were unconditionally appending the EOS token when finish_reason='stop',
but not appending a corresponding logprob. This created a len(response) = N+1 vs.
len(logprobs) = N mismatch.

Fix:
- Only append EOS for truly empty responses (the actual edge case)
- When we do append EOS, also append NaN to logprobs
- Normal responses ending with </answer> no longer get EOS appended
- vLLM returns N logprobs for N tokens, so no mismatch

Also added assertions to verify masks match correctly. Updated diagnostic logging
to treat length mismatches as errors rather than expected behavior.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
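
A sketch of the fixed EOS handling described above; the helper name and signature are made up for illustration:

```python
def maybe_append_eos(response_tokens, response_logprobs, finish_reason, eos_token_id):
    """Append EOS only for truly empty responses, and keep tokens and logprobs
    aligned by appending a NaN logprob alongside the synthetic EOS token."""
    if finish_reason == "stop" and len(response_tokens) == 0:
        response_tokens.append(eos_token_id)
        response_logprobs.append(float("nan"))
    assert len(response_tokens) == len(response_logprobs), "token/logprob length mismatch"
    return response_tokens, response_logprobs
```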
@finbarrtimbers (Collaborator, Author) commented

After discussing with Hamish: one of the main issues is that we had code which appends EOS when we use a stop string (for instance, in reasoning, generation ends when we produce </answer>, not <eos>). Because we're manually adding EOS instead of having vLLM generate it, we obviously don't have a logprob for it, which is why the masks end up with different lengths.

I changed this PR to remove the EOS addition. Runs seem unaffected:

https://wandb.ai/ai2-llm/open_instruct_internal?nw=ba89hjgyfp6

@finbarrtimbers added this pull request to the merge queue Oct 8, 2025
Merged via the queue into main with commit bd26188 Oct 8, 2025
3 checks passed