Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 5, 2024
1 parent 944c3ec commit 6287d43
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 4 deletions.
3 changes: 1 addition & 2 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,12 @@ def test_tanhnormal_mode(self):
mode = t.mode
assert mode.shape == loc.shape
empirical_mode, empirical_mode_lp = torch.zeros_like(loc), -float("inf")
for v in torch.range(-1, 1, step=0.01):
for v in torch.arange(-1, 1, step=0.01):
lp = t.log_prob(v.expand_as(t.loc))
empirical_mode = torch.where(lp > empirical_mode_lp, v, empirical_mode)
empirical_mode_lp = torch.where(
lp > empirical_mode_lp, lp, empirical_mode_lp
)
print(abs(empirical_mode - mode).max(), abs(empirical_mode - mode).median())
assert abs(empirical_mode - mode).max() < 0.1, abs(empirical_mode - mode).max()
assert mode.shape == loc.shape
assert (mode.std(0).max() < 0.1).all(), mode.std(0)
Expand Down
2 changes: 0 additions & 2 deletions torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import warnings
from numbers import Number
from typing import Dict, Optional, Sequence, Tuple, Union
Expand All @@ -11,7 +10,6 @@
import torch
from torch import distributions as D, nn
from torch.distributions import constraints
from torchrl._utils import logger as torchrl_logger

from torchrl.modules.distributions.truncated_normal import (
TruncatedNormal as _TruncatedNormal,
Expand Down

0 comments on commit 6287d43

Please sign in to comment.