diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index bfb7bc48f3c..d2d225ed3db 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -319,6 +319,10 @@ def _mask_logits( logits.masked_fill_(padding_mask, neg_inf) return logits + @property + def deterministic_sample(self): + return self.mode + class MaskedOneHotCategorical(MaskedCategorical): """MaskedCategorical distribution.