Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: Eclectic-Sheep/sheeprl
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 24b6f41ef1676cf550ef47bb3e5ef4fefd2af939
Choose a base ref
..
head repository: Eclectic-Sheep/sheeprl
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: b68e11451d1a45a3c5544a22d32731cff19b251a
Choose a head ref
Showing with 17 additions and 10 deletions.
  1. +1 −1 .pre-commit-config.yaml
  2. +1 −0 notebooks/dreamer_v3_imagination.ipynb
  3. +3 −1 sheeprl/algos/dreamer_v3/agent.py
  4. +12 −8 sheeprl/data/buffers.py
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ repos:
types: [jupyter]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.8.3"
rev: "v0.8.4"
hooks:
- id: ruff
args: ["--config", "pyproject.toml", "--fix", "./sheeprl"]
1 change: 1 addition & 0 deletions notebooks/dreamer_v3_imagination.ipynb
Original file line number Diff line number Diff line change
@@ -60,6 +60,7 @@
"import torchvision\n",
"from lightning.fabric import Fabric\n",
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"\n",
"from sheeprl.algos.dreamer_v3.agent import build_agent\n",
"from sheeprl.data.buffers import SequentialReplayBuffer\n",
4 changes: 3 additions & 1 deletion sheeprl/algos/dreamer_v3/agent.py
Original file line number Diff line number Diff line change
@@ -83,7 +83,9 @@ def __init__(
layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity},
activation=activation,
norm_layer=[layer_norm_cls] * stages,
norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)],
norm_args=[
{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)
],
),
nn.Flatten(-3, -1),
)
20 changes: 12 additions & 8 deletions sheeprl/data/buffers.py
Original file line number Diff line number Diff line change
@@ -135,10 +135,12 @@ def to_tensor(
return buf

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...

def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None:
"""Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten.
@@ -615,10 +617,12 @@ def __len__(self) -> int:
return self.buffer_size

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...

def add(
self,
@@ -856,17 +860,17 @@ def __len__(self) -> int:
return self._cum_lengths[-1] if len(self._buf) > 0 else 0

@typing.overload
def add(
self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False
) -> None: ...
def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None:
...

@typing.overload
def add(
self,
data: Dict[str, np.ndarray],
env_idxes: Sequence[int] | None = None,
validate_args: bool = False,
) -> None: ...
) -> None:
...

def add(
self,