diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c3ec86c4..27de0eb0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/PyCQA/autoflake - rev: v2.2.1 + rev: v2.3.1 hooks: - id: autoflake name: Remove unused variables and imports @@ -28,7 +28,7 @@ repos: files: \.py$ - repo: https://github.com/psf/black - rev: 23.12.1 + rev: 24.10.0 hooks: - id: black name: (black) Format Python code @@ -43,7 +43,7 @@ repos: types: [jupyter] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.1.11" + rev: "v0.8.2" hooks: - id: ruff args: ["--config", "pyproject.toml", "--fix", "./sheeprl"] diff --git a/notebooks/dreamer_v3_imagination.ipynb b/notebooks/dreamer_v3_imagination.ipynb index 5545c47a..40ebbd74 100644 --- a/notebooks/dreamer_v3_imagination.ipynb +++ b/notebooks/dreamer_v3_imagination.ipynb @@ -60,7 +60,6 @@ "import torchvision\n", "from lightning.fabric import Fabric\n", "from omegaconf import OmegaConf\n", - "from PIL import Image\n", "\n", "from sheeprl.algos.dreamer_v3.agent import build_agent\n", "from sheeprl.data.buffers import SequentialReplayBuffer\n", diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 02a40a8a..2a506420 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -83,9 +83,7 @@ def __init__( layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity}, activation=activation, norm_layer=[layer_norm_cls] * stages, - norm_args=[ - {**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages) - ], + norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)], ), nn.Flatten(-3, -1), ) diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 8c51c9b1..e51afadd 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -135,12 +135,10 @@ def to_tensor( return buf @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None: """Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten. @@ -617,12 +615,10 @@ def __len__(self) -> int: return self.buffer_size @typing.overload - def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: - ... + def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ... @typing.overload - def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: - ... + def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ... def add( self, @@ -860,8 +856,9 @@ def __len__(self) -> int: return self._cum_lengths[-1] if len(self._buf) > 0 else 0 @typing.overload - def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None: - ... + def add( + self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False + ) -> None: ... @typing.overload def add( @@ -869,8 +866,7 @@ def add( data: Dict[str, np.ndarray], env_idxes: Sequence[int] | None = None, validate_args: bool = False, - ) -> None: - ... + ) -> None: ... def add( self,