From ba60f5251b8098d0da97f6ea99246a4b6802f0fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 May 2024 17:17:45 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- sheeprl/algos/dreamer_v3/agent.py | 4 +--- sheeprl/data/buffers.py | 20 ++++++++------------ sheeprl/utils/distribution.py | 2 ++ 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 02a40a8a..2a506420 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -83,9 +83,7 @@ def __init__( layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity}, activation=activation, norm_layer=[layer_norm_cls] * stages, - norm_args=[ - {**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages) - ], + norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)], ), nn.Flatten(-3, -1), ) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 8c51c9b1..e51afadd 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -135,12 +135,10 @@ def to_tensor( return buf @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None: """Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten. @@ -617,12 +615,10 @@ def __len__(self) -> int: return self.buffer_size @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add( self, @@ -860,8 +856,9 @@ def __len__(self) -> int: return self._cum_lengths[-1] if len(self._buf) > 0 else 0 @typing.overload - def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None: - ... + def add( + self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False + ) -> None: ... @typing.overload def add( @@ -869,8 +866,7 @@ def add( data: Dict[str, np.ndarray], env_idxes: Sequence[int] | None = None, validate_args: bool = False, - ) -> None: - ... + ) -> None: ... def add( self, diff --git a/sheeprl/utils/distribution.py b/sheeprl/utils/distribution.py index 31765bb6..842a745d 100644 --- a/sheeprl/utils/distribution.py +++ b/sheeprl/utils/distribution.py @@ -307,6 +307,7 @@ class OneHotCategoricalValidateArgs(Distribution): probs (Tensor): event probabilities logits (Tensor): event log probabilities (unnormalized) """ + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = constraints.one_hot has_enumerate_support = True @@ -391,6 +392,7 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation (Bengio et al, 2013) """ + has_rsample = True def rsample(self, sample_shape=torch.Size()):