@@ -105,6 +105,10 @@ class GRPOLoss(LossModule):
             When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
             This stabilizes updates by skipping tokens that drifted too far from the reference distribution
             (see table and description; enables per-token trust region).
+        aggregation (str, optional): loss aggregation strategy for the policy objective.
+            - "token_mean": global masked token mean (weights long sequences more). Default.
+            - "prompt_mean": per-sample masked mean over tokens, then mean across samples (equal sample weight).
+            - "none": return per-token loss (mask applied, no aggregation). Useful for downstream custom reductions.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
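The three aggregation options described above only differ when samples in the batch have different numbers of valid tokens. A minimal standalone sketch (plain PyTorch on a toy tensor, not the module's own code path) of how "token_mean" and "prompt_mean" weigh the same per-token loss:

import torch

# Toy per-token loss for 2 samples with 4 token slots each; the mask marks valid tokens.
loss = torch.tensor([[1.0, 1.0, 1.0, 1.0],   # long sample: 4 valid tokens
                     [4.0, 0.0, 0.0, 0.0]])  # short sample: 1 valid token
mask = torch.tensor([[True, True, True, True],
                     [True, False, False, False]])

# "token_mean": one global mean over all valid tokens -> long samples carry more weight.
token_mean = loss[mask].mean()                              # (1+1+1+1+4) / 5 = 1.6

# "prompt_mean": masked mean per sample first, then mean across samples -> equal sample weight.
per_sample = (loss * mask).sum(-1) / mask.sum(-1).clamp_min(1)
prompt_mean = per_sample.mean()                             # (1.0 + 4.0) / 2 = 2.5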
@@ -194,6 +198,7 @@ def __init__(
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
         kl_mask_threshold: float | None = None,
+        aggregation: str | None = "token_mean",
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -214,6 +219,7 @@ def __init__(
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction if reduction is not None else "mean"
         self.kl_mask_threshold = kl_mask_threshold
+        self.aggregation = aggregation or "token_mean"
 
         # Determine device and register clip epsilon as buffer
         if device is None:
@@ -444,13 +450,13 @@ def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
 
         td_out.set("ESS", _reduce(ess / batch, self.reduction))
-        td_out = td_out.named_apply(
-            lambda name, value: _reduce(
-                value, reduction=self.reduction, mask=mask
-            ).squeeze(-1)
-            if name.startswith("loss_")
-            else value,
-        )
+        # Aggregate loss terms according to aggregation strategy
+        for key in list(td_out.keys()):
+            if isinstance(key, tuple) or not isinstance(key, str):
+                continue
+            if key.startswith("loss_"):
+                val = td_out.get(key)
+                td_out.set(key, self._aggregate_loss_value(val, mask))
         if self.kl_to_ref_coeff is not None and self.kl_to_ref_coeff > 0:
             # FIXME: parameterize this
             loss_kl, kl_penalty = self._kl_to_ref(
@@ -494,6 +500,34 @@ def _compute_policy_objective(
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         return -gain, clip_fraction
 
+    def _aggregate_loss_value(
+        self, value: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Aggregate a per-token loss tensor using the configured strategy.
+
+        Supports:
+        - token_mean: masked mean across all tokens (default)
+        - prompt_mean: per-sample masked mean over tokens, then mean across batch
+        - none: return per-token loss with masked-out tokens set to 0
+
+        The input `value` is expected to have shape [..., T, 1] where T is the token dimension,
+        and `mask` has shape [..., T].
+        """
+        if self.aggregation == "none" or self.reduction == "none":
+            mask_exp = expand_as_right(mask, value)
+            return torch.where(mask_exp, value, value.new_zeros(()).expand_as(value))
+
+        if self.aggregation == "prompt_mean":
+            # Mean over valid tokens per sample, then mean across batch
+            mask_exp = expand_as_right(mask, value).to(value.dtype)
+            token_sum = (value * mask_exp).sum(dim=-2, keepdim=False)
+            token_count = mask_exp.sum(dim=-2, keepdim=False).clamp_min(1.0)
+            sample_mean = token_sum / token_count
+            return sample_mean.mean(dim=0, keepdim=False)
+
+        # token_mean (global masked mean)
+        return _reduce(value, reduction="mean", mask=mask).squeeze(-1)
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
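With aggregation="none" (or reduction="none"), the loss entries stay per-token with masked-out positions zeroed, leaving the reduction to the caller. A hedged sketch of one possible downstream reduction, assuming a per-token objective of shape [B, T, 1] and a boolean mask of shape [B, T] as documented in the docstring above (the tensors here are made up for illustration):

import torch

# Stand-in for a per-token loss as returned with aggregation="none":
# masked-out tokens already zero, shape [B, T, 1].
mask = torch.tensor([[True, True, True, True],
                     [True, True, False, False]])
per_token_loss = torch.randn(2, 4, 1) * mask.unsqueeze(-1)

# Example custom reduction: normalize each sample by its own valid-token count,
# then average across the batch (similar in spirit to "prompt_mean").
lengths = mask.sum(-1).clamp_min(1)                         # [B]
sample_loss = per_token_loss.squeeze(-1).sum(-1) / lengths  # [B]
loss = sample_loss.mean()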