From 6ef147e99b13a6e5e071b1309f9063f28445c771 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 21 Jan 2025 09:47:12 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/modules/distributions/discrete.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index bfb7bc48f3c..d2d225ed3db 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -319,6 +319,10 @@ def _mask_logits( logits.masked_fill_(padding_mask, neg_inf) return logits + @property + def deterministic_sample(self): + return self.mode + class MaskedOneHotCategorical(MaskedCategorical): """MaskedCategorical distribution.