
Commit 624b675

[Feature] kl_mask_threshold
ghstack-source-id: eb37730 Pull-Request: #3208
1 parent 9754b25 commit 624b675

File tree: 1 file changed (+32, -0 lines)


torchrl/objectives/llm/grpo.py

Lines changed: 32 additions & 0 deletions
@@ -101,6 +101,10 @@ class GRPOLoss(LossModule):
             - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
             - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher;
               recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
+        kl_mask_threshold (float | None, optional): enables token-wise trust-region filtering (KL-Mask).
+            When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out of the loss.
+            This stabilizes updates by skipping tokens that have drifted too far from the reference distribution,
+            i.e. it enforces a per-token trust region.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
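
For intuition, here is a small worked example of the masking rule described in the new docstring entry: a token is dropped from the loss when 0.5 * (log(pi_theta / pi_ref))^2 exceeds kl_mask_threshold. The probabilities and the threshold below are invented purely for illustration, not taken from the commit.

import math

kl_mask_threshold = 0.1          # illustrative value, not a recommended default

pi_theta = 0.30                  # current policy's probability of a token
pi_ref = 0.05                    # reference policy's probability of the same token

kl_term = 0.5 * math.log(pi_theta / pi_ref) ** 2
print(round(kl_term, 3))                 # 1.605
print(kl_term > kl_mask_threshold)       # True -> this token is masked out of the loss
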
@@ -189,6 +193,7 @@ def __init__(
         actor_network: LLMWrapperBase | None = None,
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
+        kl_mask_threshold: float | None = None,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -208,6 +213,7 @@ def __init__(
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction if reduction is not None else "mean"
+        self.kl_mask_threshold = kl_mask_threshold

         # Determine device and register clip epsilon as buffer
         if device is None:
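
The two hunks above only add the keyword-only argument and store it on the module. A quick runnable check of the new API surface, assuming a torchrl build that includes this change:

import inspect

from torchrl.objectives.llm.grpo import GRPOLoss

param = inspect.signature(GRPOLoss.__init__).parameters["kl_mask_threshold"]
print(param.kind is inspect.Parameter.KEYWORD_ONLY)   # True
print(param.default)                                  # None -> KL-Mask disabled by default
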
@@ -382,6 +388,32 @@ def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
             tensordict, adv_shape=advantage.shape[:-1]
         )
         mask = dist.mask
+
+        # Optional per-token trust-region filtering (KL-Mask) vs reference policy
+        if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
+            try:
+                ref_log_prob = tensordict.get(
+                    self.tensor_keys.ref_log_probs,
+                    as_padded_tensor=True,
+                    padding_side="left",
+                    padding_value=0.0,
+                )
+            except KeyError:
+                ref_log_prob = None
+            cur_log_prob = tensordict.get("_cur_log_prob", None)
+            if (ref_log_prob is not None) and (cur_log_prob is not None):
+                # Align to valid tokens only (safety)
+                cur_log_prob_masked = torch.where(
+                    expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
+                )
+                ref_log_prob_masked = torch.where(
+                    expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
+                )
+                log_is_ref = cur_log_prob_masked - ref_log_prob_masked
+                kl_token = 0.5 * (log_is_ref**2)
+                tr_mask = kl_token <= self.kl_mask_threshold
+                # Combine with attention mask
+                mask = mask & tr_mask
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
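
Below is a minimal, self-contained sketch of the token-wise trust-region filter added in the hunk above, using plain tensors in place of the tensordict and expand_as_right plumbing. The log-probabilities, attention mask, and threshold are made up for illustration; only the masking logic mirrors the diff.

import torch

kl_mask_threshold = 0.5                                       # illustrative value

cur_log_prob = torch.tensor([[-0.9, -2.5, -0.1, -4.0]])       # log pi_theta per token
ref_log_prob = torch.tensor([[-1.0, -1.2, -0.1, -1.5]])       # log pi_ref per token
attention_mask = torch.tensor([[True, True, True, False]])    # last position is padding

# Zero out padded positions before differencing, mirroring the masked log-probs above.
cur = torch.where(attention_mask, cur_log_prob, 0.0)
ref = torch.where(attention_mask, ref_log_prob, 0.0)

# Per-token trust-region statistic: 0.5 * (log pi_theta - log pi_ref)^2.
kl_token = 0.5 * (cur - ref) ** 2

# Keep only tokens inside the trust region, combined with the attention mask.
tr_mask = kl_token <= kl_mask_threshold
mask = attention_mask & tr_mask
print(mask)   # tensor([[ True, False,  True, False]])

Tokens whose statistic exceeds the threshold (the second token here) simply drop out of the surrogate loss for that update, as described in the docstring.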
