@@ -248,6 +248,8 @@ class Args:
248248 """the lower clip range"""
249249 clip_higher : float = 0.2
250250 """the higher clip range. Sometimes we want this to be higher, see DAPO (https://arxiv.org/abs/2503.14476)"""
251+ truncated_importance_sampling_ratio_cap : float = 0.0
252+ """The maximum cap for truncated importance sampling ratio (0 means disabled)"""
251253 inflight_updates : bool = False
252254 """Enable immediate stopping of request processing when should_stop is set, allowing for quick pausing and resumption"""
253255 kl_estimator : Literal ["kl1" , "kl2" , "kl3" , "kl4" ] = "kl3"
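Editor's note: the cap bounds the per-token ratio exp(logprob_trainer - logprob_vllm) that corrects for drift between the vLLM sampler and the training policy. A minimal sketch of the capping with hypothetical values and a hypothetical cap of 2.0:

    import torch

    old_logprob = torch.tensor([-1.0, -2.0, -0.5])   # trainer's logprobs (hypothetical)
    vllm_logprob = torch.tensor([-1.1, -3.0, -0.5])  # sampler's logprobs (hypothetical)
    ratio = torch.exp(old_logprob - vllm_logprob)    # tensor([1.1052, 2.7183, 1.0000])
    capped = torch.clamp(ratio, max=2.0)             # tensor([1.1052, 2.0000, 1.0000])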
@@ -275,6 +277,8 @@ class Args:
 
     record_entropy: bool = False
     """whether to record the entropy of the policy during training. Uses extra memory."""
+    use_vllm_logprobs: bool = False
+    """whether to use vLLM's logprobs for training instead of calculating them via a forward pass"""
 
     # Reward
     # -- r1 style format reward
@@ -436,6 +440,11 @@ def __post_init__(self):
             logger.warning("When using the v0 version of vLLM, caching is broken and will never be invalidated.")
             if self.vllm_enable_prefix_caching:
                 raise ValueError("Prefix caching is currently not supported for v0.")
+        if self.use_vllm_logprobs and self.truncated_importance_sampling_ratio_cap > 0.0:
+            raise ValueError(
+                "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. "
+                "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
+            )
         assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
         if self.num_samples_per_prompt_rollout == 1:
             logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
@@ -893,6 +902,7 @@ def train(
         collated_position_ids,
         collated_advantages,
         collated_response_masks,
+        collated_vllm_logprobs,
         pad_token_id: int,
         num_mini_batches: int,
     ):
@@ -903,6 +913,7 @@ def train(
         to_device_inplace(collated_position_ids, self.device)
         to_device_inplace(collated_advantages, self.device)
         to_device_inplace(collated_response_masks, self.device)
+        to_device_inplace(collated_vllm_logprobs, self.device)
         # accumulation steps should always be at least 1
         accumulation_steps = max(math.ceil(len(collated_query_responses) / num_mini_batches - 0.5), 1)
         leftover = len(collated_query_responses) % accumulation_steps
@@ -913,6 +924,7 @@ def train(
             collated_position_ids = collated_position_ids[0:-leftover]
             collated_advantages = collated_advantages[0:-leftover]
             collated_response_masks = collated_response_masks[0:-leftover]
+            collated_vllm_logprobs = collated_vllm_logprobs[0:-leftover]
             logger.warning(f"{leftover} samples are dropped due to batch size {num_mini_batches}")
 
         # recalculate the "real" number of mini-batches
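Editor's note: a worked example of the accumulation arithmetic above, with hypothetical sizes:

    import math

    n, num_mini_batches = 11, 4                                         # 11 collated micro-batches
    accumulation_steps = max(math.ceil(n / num_mini_batches - 0.5), 1)  # ceil(2.25) = 3
    leftover = n % accumulation_steps                                   # 11 % 3 = 2
    # the last 2 micro-batches are dropped so each mini-batch accumulates exactly 3 steps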
@@ -958,21 +970,31 @@ def train(
             attention_mask = collated_attention_masks[i]
             position_id = collated_position_ids[i]
             response_mask = collated_response_masks[i]
-            old_logprob, _ = self.forward(
-                self.model,
-                query_response,
-                attention_mask,
-                position_id,
-                pad_token_id,
-                args.temperature,
-                return_entropy=False,
-            )
+            if not args.use_vllm_logprobs:
+                local_old_logprob, _ = self.forward(
+                    self.model,
+                    query_response,
+                    attention_mask,
+                    position_id,
+                    pad_token_id,
+                    args.temperature,
+                    return_entropy=False,
+                )
+            vllm_old_logprob = collated_vllm_logprobs[i][:, 1:]
             if args.mask_tool_use and args.tool_use:
                 response_mask = response_mask.bool() & tool_mask.bool()
             else:
                 response_mask = response_mask.bool()
-            old_logprob = torch.masked_fill(old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB)
-            old_logprobs[i] = old_logprob
+            if not args.use_vllm_logprobs:
+                local_old_logprob = torch.masked_fill(
+                    local_old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB
+                )
+            vllm_old_logprob = torch.masked_fill(vllm_old_logprob, ~response_mask[:, 1:], INVALID_LOGPROB)
+            vllm_old_logprob = torch.nan_to_num(vllm_old_logprob, nan=INVALID_LOGPROB)
+            if args.use_vllm_logprobs:
+                old_logprobs[i] = vllm_old_logprob
+            else:
+                old_logprobs[i] = local_old_logprob
             torch.cuda.empty_cache()
 
         local_step = 0
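Editor's note: the `[:, 1:]` slice aligns vLLM's per-token logprobs with the trainer's, which score token t from the logits at position t-1, so position 0 has no logprob. A toy sketch of the masking pipeline (values hypothetical; the sentinel value of INVALID_LOGPROB is an assumption):

    import torch

    INVALID_LOGPROB = 1.0  # assumed sentinel used elsewhere in this file
    vllm_logprobs = torch.tensor([[float("nan"), -0.3, -1.2, -0.7]])  # NaN on the prompt token
    response_mask = torch.tensor([[0, 0, 1, 1]]).bool()
    out = vllm_logprobs[:, 1:]                                        # drop position 0
    out = torch.masked_fill(out, ~response_mask[:, 1:], INVALID_LOGPROB)
    out = torch.nan_to_num(out, nan=INVALID_LOGPROB)
    # out: tensor([[1.0000, -1.2000, -0.7000]])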
@@ -1001,7 +1023,7 @@ def train(
                     mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
                 mb_attention_mask = collated_attention_masks[i]
                 mb_position_id = collated_position_ids[i]
-                mb_new_logprobs, mb_entropy = self.forward(
+                mb_local_logprobs, mb_entropy = self.forward(
                     self.model,
                     mb_query_responses,
                     mb_attention_mask,
@@ -1010,16 +1032,50 @@ def train(
                     args.temperature,
                     return_entropy=args.record_entropy,
                 )
-                mb_new_logprobs = torch.masked_fill(mb_new_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
+                mb_local_logprobs = torch.masked_fill(mb_local_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
+                mb_vllm_logprobs = collated_vllm_logprobs[i][:, 1:]
+                mb_vllm_logprobs = torch.masked_fill(mb_vllm_logprobs, ~mb_response_masks_bool, INVALID_LOGPROB)
+                # Replace any remaining NaN values (query tokens in packed sequences are set to NaN by pack_sequences in rl_utils2.py)
+                mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB)
+
+                # Compare vLLM logprobs with local logprobs
+                with torch.no_grad():
+                    valid_mask = mb_response_masks_bool & ~torch.isnan(mb_vllm_logprobs)
+                    logprob_diff = (mb_local_logprobs - mb_vllm_logprobs).abs()
+                    masked_diff = torch.masked_fill(logprob_diff, ~valid_mask, 0.0)
+                    mean_diff = masked_diff.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0
+                    max_diff = masked_diff.max()
+                    std_diff = masked_diff[valid_mask].std() if valid_mask.sum() > 1 else 0.0
+
+                    self.local_metrics.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff.item())
+                    self.local_metrics.add("debug/vllm_vs_local_logprob_diff_max", max_diff.item())
+                    self.local_metrics.add("debug/vllm_vs_local_logprob_diff_std", std_diff.item())
+
+                    reverse_kl = torch.exp(mb_vllm_logprobs) * (mb_vllm_logprobs - mb_local_logprobs)
+                    masked_reverse_kl = torch.masked_fill(reverse_kl, ~valid_mask, 0.0)
+                    mean_reverse_kl = masked_reverse_kl.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0
+                    self.local_metrics.add("debug/vllm_local_reverse_kl", mean_reverse_kl.item())
+
+                mb_new_logprobs = mb_local_logprobs
 
                 # Cache the old logprobs
                 if num_mini_batches > 1:
                     mb_old_logprobs = old_logprobs[i]
                 else:
                     with torch.no_grad():
                         if epoch_idx == 0:
-                            old_logprobs[i] = mb_new_logprobs
-                        mb_old_logprobs = old_logprobs[i].detach()
+                            if args.use_vllm_logprobs:
+                                old_logprobs[i] = mb_vllm_logprobs
+                            else:
+                                old_logprobs[i] = mb_local_logprobs.detach()
+                        mb_old_logprobs = old_logprobs[i]
+
+                old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB
+                assert torch.all(old_logprobs_mask == mb_response_masks_bool), (
+                    f"Old logprobs mask should match response mask. "
+                    f"old_mask sum={old_logprobs_mask.sum()}, "
+                    f"response_mask sum={mb_response_masks_bool.sum()}"
+                )
 
                 # Calculate the policy's loss
                 logprobs_diff = mb_new_logprobs - mb_old_logprobs
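Editor's aside on the reverse-KL debug metric above: over the full vocabulary it would be

    KL(p_vllm || p_local) = sum_v p_vllm(v) * (log p_vllm(v) - log p_local(v))

but the code only has logprobs for the sampled token at each position, so exp(mb_vllm_logprobs) * (mb_vllm_logprobs - mb_local_logprobs) evaluates just that one summand per token. The logged value is a cheap single-term proxy that tracks sampler/trainer divergence, not an exact KL.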
@@ -1028,6 +1084,46 @@ def train(
                 pg_losses2 = -mb_advantages[:, 1:] * torch.clamp(
                     ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher
                 )
+
+                # Apply truncated importance sampling if enabled
+                if args.truncated_importance_sampling_ratio_cap > 0 and mb_vllm_logprobs is not None:
+                    old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB
+                    vllm_logprobs_mask = mb_vllm_logprobs != INVALID_LOGPROB
+
+                    assert torch.all(old_logprobs_mask == mb_response_masks_bool), (
+                        f"Old logprobs mask should match response mask. "
+                        f"old_mask sum={old_logprobs_mask.sum()}, "
+                        f"response_mask sum={mb_response_masks_bool.sum()}"
+                    )
+                    assert torch.all(vllm_logprobs_mask == mb_response_masks_bool), (
+                        f"vLLM logprobs mask should match response mask. "
+                        f"vllm_mask sum={vllm_logprobs_mask.sum()}, "
+                        f"response_mask sum={mb_response_masks_bool.sum()}"
+                    )
+
+                    valid_mask = mb_response_masks_bool
+
+                    # Initialize importance ratio to 1.0 (no effect) for all positions
+                    tis_imp_ratio = torch.ones_like(mb_old_logprobs)
+
+                    if valid_mask.any():
+                        # Calculate logprob difference only for valid positions
+                        logprob_diff_is = mb_old_logprobs - mb_vllm_logprobs
+                        # Clamp to prevent numerical overflow in exp
+                        logprob_diff_is = torch.where(
+                            valid_mask, logprob_diff_is.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff_is)
+                        )
+                        # Compute importance ratio only for valid positions
+                        tis_imp_ratio = torch.where(valid_mask, torch.exp(logprob_diff_is), tis_imp_ratio)
+                        # Apply cap
+                        tis_imp_ratio = torch.clamp(
+                            tis_imp_ratio, max=args.truncated_importance_sampling_ratio_cap
+                        )
+
+                    # Apply importance sampling to losses
+                    pg_losses = pg_losses * tis_imp_ratio
+                    pg_losses2 = pg_losses2 * tis_imp_ratio
+
                 pg_loss_max = torch.max(pg_losses, pg_losses2)
 
                 # Here we recalculate kl: we want the KL loss to backpropagate through the model
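Editor's note: how the cap interacts with the PPO clip, on toy numbers (all values hypothetical):

    import torch

    adv = 1.0
    ratio = torch.tensor(1.3)                                   # PPO ratio new/old (hypothetical)
    pg1 = -adv * ratio                                          # -1.3
    pg2 = -adv * torch.clamp(ratio, 0.8, 1.2)                   # -1.2, hypothetical clip range
    tis = torch.clamp(torch.exp(torch.tensor(-0.5)), max=2.0)   # exp(old - vllm) ~= 0.6065
    loss = torch.max(pg1 * tis, pg2 * tis)                      # ~= -0.728

Because the TIS ratio is positive and multiplies both surrogate terms before the max, it rescales the token's gradient contribution without changing which branch of the clip wins.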
@@ -1510,6 +1606,7 @@ def accumulate_inference_batches(
     combined_tool_outputs = []
     combined_tool_runtimes = []
     combined_tool_calleds = []
+    combined_logprobs = []
 
     earliest_start_time = float("inf")
     prompt_lengths = []
@@ -1530,6 +1627,8 @@ def accumulate_inference_batches(
         combined_tool_runtimes.extend(result.request_info.tool_runtimes)
         combined_tool_calleds.extend(result.request_info.tool_calleds)
 
+        combined_logprobs.extend(result.logprobs)
+
         earliest_start_time = min(earliest_start_time, result.start_time)
 
         prompt_lengths.append(len(all_queries[i]))
@@ -1570,6 +1669,7 @@ def accumulate_inference_batches(
         request_info=combined_request_info,
         dataset_index=None,  # Not meaningful for combined result
         token_statistics=accumulated_stats,
+        logprobs=combined_logprobs,
     )
 
     if actor_manager is not None:
@@ -1636,14 +1736,10 @@ def data_preparation_thread(
                 for i in range(len(result.request_info.tool_outputs))
             ]
         for i in range(len(result.finish_reasons)):
-            # edge case: sometimes it outputs eos immediately, and we get an empty response
-            # in that case, we need to add the eos token to the response
-            # note that this also adds eos to the end of reponses that stopped for other reasons.
-            if result.finish_reasons[i] == "stop" and (
-                len(result.responses[i]) == 0 or result.responses[i][-1] != tokenizer.eos_token_id
-            ):
+            if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0:
                 result.responses[i].append(tokenizer.eos_token_id)
-                result.masks[i].append(1)  # never mask the eos token for
+                result.masks[i].append(1)
+                result.logprobs[i].append(float("nan"))
     with Timer("🔥 [Data Preparation Thread] Decoding responses", noop=True):
         decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
         decoded_queries = batch.raw_queries
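Editor's note on the NaN appended above: the eos token is inserted by the trainer, not sampled by vLLM, so there is no sampler logprob for it. The float("nan") placeholder keeps result.logprobs aligned one-to-one with result.responses, and the torch.nan_to_num(..., nan=INVALID_LOGPROB) calls in the train step later neutralize it. The invariant, as a sketch:

    assert all(len(r) == len(lp) for r, lp in zip(result.responses, result.logprobs))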
@@ -1706,6 +1802,7 @@ def data_preparation_thread(
             masks = [result.masks[i] for i in non_zero_gradient_index]
             batch = batch[non_zero_gradient_index.tolist()]
             finish_reasons = [result.finish_reasons[i] for i in non_zero_gradient_index]
+            vllm_logprobs = [result.logprobs[i] for i in non_zero_gradient_index]
             if args.mask_truncated_completions:
                 stop_idxes = torch.tensor([i for i in range(len(finish_reasons)) if finish_reasons[i] == "stop"])
                 num_truncated = len(finish_reasons) - len(stop_idxes)
@@ -1720,6 +1817,7 @@ def data_preparation_thread(
                 masks = [masks[i] for i in stop_idxes]
                 batch = batch[stop_idxes.tolist()]
                 finish_reasons = [finish_reasons[i] for i in stop_idxes]
+                vllm_logprobs = [vllm_logprobs[i] for i in stop_idxes]
 
         if args.fill_completions:
             with Timer("⏱ [Data Preparation Thread] Refill completions"):
@@ -1763,6 +1861,7 @@ def data_preparation_thread(
                 )
 
                 finish_reasons += [finish_reasons[i] for i in sampled_indices]
+                vllm_logprobs += [vllm_logprobs[i] for i in sampled_indices]
 
                 logger.info(
                     f"📊 Duplicated {need_to_fill_prompt} prompts from {len(sampled_indices)} total responses"
@@ -1783,6 +1882,7 @@ def data_preparation_thread(
                 masks=masks,
                 pack_length=args.pack_length,
                 pad_token_id=tokenizer.pad_token_id,
+                vllm_logprobs=vllm_logprobs,
             )
             num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses)
         # Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value
@@ -1832,6 +1932,7 @@ def data_preparation_thread(
             per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)]
             per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)]
             per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)]
+            per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs[B * i : B * (i + 1)]
 
             # Shuffle the batch and collate the data
             b_inds = np.random.permutation(len(per_device_packed_query_responses))
@@ -1841,6 +1942,7 @@ def data_preparation_thread(
             collated_position_ids = []
             collated_response_masks = []
             collated_advantages = []
+            collated_vllm_logprobs = []
             for j in range(0, len(per_device_packed_query_responses), args.per_device_train_batch_size):
                 micro_range = b_inds[j : j + args.per_device_train_batch_size]
                 collated_query_responses.append(
@@ -1863,6 +1965,9 @@ def data_preparation_thread(
                 collated_advantages.append(
                     collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0)
                 )
+                collated_vllm_logprobs.append(
+                    collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0)
+                )
             collated_data.append(
                 {
                     "collated_query_responses": collated_query_responses,
@@ -1871,6 +1976,7 @@ def data_preparation_thread(
18711976 "collated_position_ids" : collated_position_ids ,
18721977 "collated_advantages" : collated_advantages ,
18731978 "collated_response_masks" : collated_response_masks ,
1979+ "collated_vllm_logprobs" : collated_vllm_logprobs ,
18741980 }
18751981 )
18761982
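Editor's note: the logprobs are collated with a fill value of 0, the same as the advantages (assuming collate_fn takes a list of tensors plus a pad value). Those padded positions fall outside the response mask, so the masked_fill with INVALID_LOGPROB in the train step overwrites them before any loss or debug metric touches them.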
@@ -2175,6 +2281,7 @@ def create_generation_configs(args: Args):
         n=args.num_samples_per_prompt_rollout,
         stop=args.stop_strings,
         seed=args.seed,
+        logprobs=1,  # Enable logprobs to compare with local calculations
         # IMPORTANT: Set output_kind to FINAL_ONLY to ensure vLLM V1 properly handles n>1
         # With the default CUMULATIVE mode, vLLM V1 returns separate outputs for each
         # completion, making it difficult to aggregate them correctly. FINAL_ONLY mode
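Editor's note: with SamplingParams(logprobs=1), each vLLM CompletionOutput carries a per-token list of {token_id: Logprob} dicts that includes the sampled token. A minimal extraction sketch (assuming vLLM's output structure; the helper name is hypothetical):

    def sampled_token_logprobs(completion_output):
        # one logprob per generated token, read from the sampled token's entry
        return [
            pos_logprobs[token_id].logprob
            for token_id, pos_logprobs in zip(completion_output.token_ids, completion_output.logprobs)
        ]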