From 5e3167e0f3600c18ba800c6000f5f1c1820d7205 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:25:36 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- notebooks/dreamer_v3_imagination.ipynb | 1 - sheeprl/algos/dreamer_v3/agent.py | 4 +--- sheeprl/data/buffers.py | 20 ++++++++------------ 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/notebooks/dreamer_v3_imagination.ipynb b/notebooks/dreamer_v3_imagination.ipynb index 5545c47a..40ebbd74 100644 --- a/notebooks/dreamer_v3_imagination.ipynb +++ b/notebooks/dreamer_v3_imagination.ipynb @@ -60,7 +60,6 @@ "import torchvision\n", "from lightning.fabric import Fabric\n", "from omegaconf import OmegaConf\n", - "from PIL import Image\n", "\n", "from sheeprl.algos.dreamer_v3.agent import build_agent\n", "from sheeprl.data.buffers import SequentialReplayBuffer\n", diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 02a40a8a..2a506420 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -83,9 +83,7 @@ def __init__( layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity}, activation=activation, norm_layer=[layer_norm_cls] * stages, - norm_args=[ - {**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages) - ], + norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)], ), nn.Flatten(-3, -1), ) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 8c51c9b1..e51afadd 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -135,12 +135,10 @@ def to_tensor( return buf @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None: """Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten. @@ -617,12 +615,10 @@ def __len__(self) -> int: return self.buffer_size @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add( self, @@ -860,8 +856,9 @@ def __len__(self) -> int: return self._cum_lengths[-1] if len(self._buf) > 0 else 0 @typing.overload - def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None: - ... + def add( + self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False + ) -> None: ... @typing.overload def add( @@ -869,8 +866,7 @@ def add( data: Dict[str, np.ndarray], env_idxes: Sequence[int] | None = None, validate_args: bool = False, - ) -> None: - ... + ) -> None: ... def add( self,