Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 21, 2024
1 parent dcf34e4 commit 4f0a35a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def mode(self):
return self.base_dist.mean

@property
def deterministic(self):
def deterministic_sample(self):
return self.mean


Expand Down Expand Up @@ -294,7 +294,7 @@ def mode(self):
return torch.max(torch.stack([m, a], -1), dim=-1)[0]

@property
def deterministic(self):
def deterministic_sample(self):
return self.mean

def log_prob(self, value, **kwargs):
Expand Down Expand Up @@ -488,7 +488,7 @@ def mode(self):
return self.deterministic

@property
def deterministic(self):
def deterministic_sample(self):
m = self.root_dist.mean
for t in self.transforms:
m = t(m)
Expand Down Expand Up @@ -619,7 +619,7 @@ def mode(self) -> torch.Tensor:
return self.param

@property
def deterministic(self):
def deterministic_sample(self):
return self.mean

@property
Expand Down Expand Up @@ -748,7 +748,7 @@ def mode(self) -> torch.Tensor:
return mode

@property
def deterministic(self):
def deterministic_sample(self):
return self.mode

@property
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def mode(self) -> torch.Tensor:
return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

@property
def deterministic(self):
def deterministic_sample(self):
return self.mode

@_one_hot_wrapper(D.Categorical)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/distributions/truncated_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def mean(self):
return self._mean

@property
def deterministic(self):
def deterministic_sample(self):
return self.mean

@property
Expand Down

0 comments on commit 4f0a35a

Please sign in to comment.