@@ -101,6 +101,10 @@ class GRPOLoss(LossModule):
101101 - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
102102 - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
103103 recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
104+ kl_mask_threshold (float | None, optional): enable token-wise trust-region filtering (KL-Mask).
105+ When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
106+ This stabilizes updates by skipping tokens that drifted too far from the reference distribution
107+ (masked tokens are simply excluded from the loss, enforcing a per-token trust region around the reference policy).
104108 entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
105109 loss to favour exploratory policies.
106110 samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -189,6 +193,7 @@ def __init__(
189193 actor_network : LLMWrapperBase | None = None ,
190194 * ,
191195 clip_epsilon : float | tuple [float , float ] = 0.2 ,
196+ kl_mask_threshold : float | None = None ,
192197 entropy_bonus : bool = True ,
193198 samples_mc_entropy : int = 1 ,
194199 entropy_coeff : float = 0.01 ,
@@ -208,6 +213,7 @@ def __init__(
208213 self .samples_mc_entropy = samples_mc_entropy
209214 self .entropy_coeff = entropy_coeff
210215 self .reduction = reduction if reduction is not None else "mean"
216+ self .kl_mask_threshold = kl_mask_threshold
211217
212218 # Determine device and register clip epsilon as buffer
213219 if device is None :
@@ -382,6 +388,32 @@ def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
382388 tensordict , adv_shape = advantage .shape [:- 1 ]
383389 )
384390 mask = dist .mask
391+
392+ # Optional per-token trust-region filtering (KL-Mask) vs reference policy
393+ if self .kl_mask_threshold is not None and self .kl_mask_threshold > 0 :
394+ try :
395+ ref_log_prob = tensordict .get (
396+ self .tensor_keys .ref_log_probs ,
397+ as_padded_tensor = True ,
398+ padding_side = "left" ,
399+ padding_value = 0.0 ,
400+ )
401+ except KeyError :
402+ ref_log_prob = None
403+ cur_log_prob = tensordict .get ("_cur_log_prob" , None )
404+ if (ref_log_prob is not None ) and (cur_log_prob is not None ):
405+ # Align to valid tokens only (safety)
406+ cur_log_prob_masked = torch .where (
407+ expand_as_right (mask , cur_log_prob ), cur_log_prob , 0.0
408+ )
409+ ref_log_prob_masked = torch .where (
410+ expand_as_right (mask , ref_log_prob ), ref_log_prob , 0.0
411+ )
412+ log_is_ref = cur_log_prob_masked - ref_log_prob_masked
413+ kl_token = 0.5 * (log_is_ref ** 2 )
414+ tr_mask = kl_token <= self .kl_mask_threshold
415+ # Combine with attention mask
416+ mask = mask & tr_mask
385417 # ESS for logging
386418 with torch .no_grad ():
387419 # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
0 commit comments