From 4f0a35a2faa4ddd786c383250c42d623b4988edf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 21 Jun 2024 08:08:24 +0100 Subject: [PATCH] amend --- torchrl/modules/distributions/continuous.py | 10 +++++----- torchrl/modules/distributions/discrete.py | 2 +- torchrl/modules/distributions/truncated_normal.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 2ea9cc3ee29..75dab5264c1 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -91,7 +91,7 @@ def mode(self): return self.base_dist.mean @property - def deterministic(self): + def deterministic_sample(self): return self.mean @@ -294,7 +294,7 @@ def mode(self): return torch.max(torch.stack([m, a], -1), dim=-1)[0] @property - def deterministic(self): + def deterministic_sample(self): return self.mean def log_prob(self, value, **kwargs): @@ -488,7 +488,7 @@ def mode(self): return self.deterministic @property - def deterministic(self): + def deterministic_sample(self): m = self.root_dist.mean for t in self.transforms: m = t(m) @@ -619,7 +619,7 @@ def mode(self) -> torch.Tensor: return self.param @property - def deterministic(self): + def deterministic_sample(self): return self.mean @property @@ -748,7 +748,7 @@ def mode(self) -> torch.Tensor: return mode @property - def deterministic(self): + def deterministic_sample(self): return self.mode @property diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py index bfed6d29bec..c48d8168887 100644 --- a/torchrl/modules/distributions/discrete.py +++ b/torchrl/modules/distributions/discrete.py @@ -107,7 +107,7 @@ def mode(self) -> torch.Tensor: return (self.probs == self.probs.max(-1, True)[0]).to(torch.long) @property - def deterministic(self): + def deterministic_sample(self): return self.mode @_one_hot_wrapper(D.Categorical) diff --git a/torchrl/modules/distributions/truncated_normal.py b/torchrl/modules/distributions/truncated_normal.py index cf8610ceb57..1350aeb2bc3 100644 --- a/torchrl/modules/distributions/truncated_normal.py +++ b/torchrl/modules/distributions/truncated_normal.py @@ -86,7 +86,7 @@ def mean(self): return self._mean @property - def deterministic(self): + def deterministic_sample(self): return self.mean @property