
Commit 0b7bddd

[Feature] Aggregation strategies
ghstack-source-id: 6508aa3 Pull-Request: #3209
1 parent 624b675 commit 0b7bddd

1 file changed: +41 -7 lines changed

torchrl/objectives/llm/grpo.py

Lines changed: 41 additions & 7 deletions
@@ -105,6 +105,10 @@ class GRPOLoss(LossModule):
             When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
             This stabilizes updates by skipping tokens that drifted too far from the reference distribution
             (see table and description; enables per-token trust region).
+        aggregation (str, optional): loss aggregation strategy for the policy objective.
+            - "token_mean": global masked token mean (weights long sequences more). Default.
+            - "prompt_mean": per-sample masked mean over tokens, then mean across samples (equal sample weight).
+            - "none": return per-token loss (mask applied, no aggregation). Useful for downstream custom reductions.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
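
A quick sketch (not part of this commit) of how the two mean strategies can diverge when response lengths differ; the toy tensors below are purely illustrative:

```python
import torch

# Two prompts: one with 4 valid response tokens, one with a single valid token.
loss = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                     [5.0, 0.0, 0.0, 0.0]])
mask = torch.tensor([[True, True, True, True],
                     [True, False, False, False]])

# "token_mean": one global masked mean -- longer responses contribute more tokens.
token_mean = loss[mask].mean()                      # (1+1+1+1+5) / 5 = 1.8

# "prompt_mean": average inside each prompt first, then across prompts.
per_prompt = (loss * mask).sum(-1) / mask.sum(-1)   # tensor([1., 5.])
prompt_mean = per_prompt.mean()                     # 3.0
```

Averaging per prompt first gives every sample equal weight regardless of its length, which is exactly the distinction the docstring draws.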
@@ -194,6 +198,7 @@ def __init__(
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
         kl_mask_threshold: float | None = None,
+        aggregation: str | None = "token_mean",
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -214,6 +219,7 @@ def __init__(
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction if reduction is not None else "mean"
         self.kl_mask_threshold = kl_mask_threshold
+        self.aggregation = aggregation or "token_mean"
 
         # Determine device and register clip epsilon as buffer
         if device is None:
@@ -444,13 +450,13 @@ def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
 
         td_out.set("ESS", _reduce(ess / batch, self.reduction))
-        td_out = td_out.named_apply(
-            lambda name, value: _reduce(
-                value, reduction=self.reduction, mask=mask
-            ).squeeze(-1)
-            if name.startswith("loss_")
-            else value,
-        )
+        # Aggregate loss terms according to aggregation strategy
+        for key in list(td_out.keys()):
+            if isinstance(key, tuple) or not isinstance(key, str):
+                continue
+            if key.startswith("loss_"):
+                val = td_out.get(key)
+                td_out.set(key, self._aggregate_loss_value(val, mask))
         if self.kl_to_ref_coeff is not None and self.kl_to_ref_coeff > 0:
             # FIXME: parameterize this
             loss_kl, kl_penalty = self._kl_to_ref(
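
A rough stand-in for the new aggregation loop above, run against a toy TensorDict so the key filtering is easy to see; the entry names and shapes here are made up, not the ones GRPOLoss actually emits:

```python
import torch
from tensordict import TensorDict

td_out = TensorDict(
    {
        "loss_objective": torch.rand(2, 4, 1),  # per-token loss terms
        "loss_entropy": torch.rand(2, 4, 1),
        "ESS": torch.rand(()),                  # diagnostic, not a loss
    },
    batch_size=[],
)
mask = torch.ones(2, 4, dtype=torch.bool)

# Only top-level string keys starting with "loss_" get aggregated; everything
# else (here "ESS") passes through untouched.
for key in list(td_out.keys()):
    if not isinstance(key, str) or not key.startswith("loss_"):
        continue
    td_out.set(key, td_out.get(key)[mask].mean())
```

In the actual forward pass the per-key reduction is delegated to `self._aggregate_loss_value`, defined in the next hunk.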
@@ -494,6 +500,34 @@ def _compute_policy_objective(
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         return -gain, clip_fraction
 
+    def _aggregate_loss_value(
+        self, value: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Aggregate a per-token loss tensor using the configured strategy.
+
+        Supports:
+        - token_mean: masked mean across all tokens (default)
+        - prompt_mean: per-sample masked mean over tokens, then mean across batch
+        - none: return per-token loss with masked-out tokens set to 0
+
+        The input `value` is expected to have shape [..., T, 1] where T is the token dimension,
+        and `mask` has shape [..., T].
+        """
+        if self.aggregation == "none" or self.reduction == "none":
+            mask_exp = expand_as_right(mask, value)
+            return torch.where(mask_exp, value, value.new_zeros(()).expand_as(value))
+
+        if self.aggregation == "prompt_mean":
+            # Mean over valid tokens per sample, then mean across batch
+            mask_exp = expand_as_right(mask, value).to(value.dtype)
+            token_sum = (value * mask_exp).sum(dim=-2, keepdim=False)
+            token_count = mask_exp.sum(dim=-2, keepdim=False).clamp_min(1.0)
+            sample_mean = token_sum / token_count
+            return sample_mean.mean(dim=0, keepdim=False)
+
+        # token_mean (global masked mean)
+        return _reduce(value, reduction="mean", mask=mask).squeeze(-1)
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
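
For the shape contract the docstring above describes (value `[..., T, 1]`, mask `[..., T]`), here is a minimal sketch of what the `prompt_mean` branch computes, written outside the class with the same `expand_as_right` helper and illustrative shapes:

```python
import torch
from tensordict.utils import expand_as_right

B, T = 2, 4
value = torch.rand(B, T, 1)           # per-token loss, shape [..., T, 1]
mask = torch.rand(B, T) > 0.3         # valid-token mask, shape [..., T]

mask_f = expand_as_right(mask, value).to(value.dtype)   # -> [B, T, 1]
token_sum = (value * mask_f).sum(dim=-2)                # -> [B, 1]
token_count = mask_f.sum(dim=-2).clamp_min(1.0)         # avoid division by zero
sample_mean = token_sum / token_count                   # one value per prompt
loss = sample_mean.mean(dim=0)                          # -> shape [1]
```

With `aggregation="none"` (or `reduction="none"`) the method instead keeps the per-token shape and zeroes the masked-out positions, as the first branch of the new method shows.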
