Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 12, 2024
1 parent 9118012 commit f7bffe5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
4 changes: 1 addition & 3 deletions sheeprl/algos/dreamer_v3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def __init__(
layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": layer_norm_cls == nn.Identity},
activation=activation,
norm_layer=[layer_norm_cls] * stages,
norm_args=[
{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)
],
norm_args=[{**layer_norm_kw, "normalized_shape": (2**i) * channels_multiplier} for i in range(stages)],
),
nn.Flatten(-3, -1),
)
Expand Down
20 changes: 8 additions & 12 deletions sheeprl/data/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,10 @@ def to_tensor(
return buf

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...

def add(self, data: "ReplayBuffer" | Dict[str, np.ndarray], validate_args: bool = False) -> None:
"""Add data to the replay buffer. If the replay buffer is full, then the oldest data is overwritten.
Expand Down Expand Up @@ -617,12 +615,10 @@ def __len__(self) -> int:
return self.buffer_size

@typing.overload
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None:
...
def add(self, data: "ReplayBuffer", validate_args: bool = False) -> None: ...

@typing.overload
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None:
...
def add(self, data: Dict[str, np.ndarray], validate_args: bool = False) -> None: ...

def add(
self,
Expand Down Expand Up @@ -860,17 +856,17 @@ def __len__(self) -> int:
return self._cum_lengths[-1] if len(self._buf) > 0 else 0

@typing.overload
def add(self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False) -> None:
...
def add(
self, data: "ReplayBuffer", env_idxes: Sequence[int] | None = None, validate_args: bool = False
) -> None: ...

@typing.overload
def add(
self,
data: Dict[str, np.ndarray],
env_idxes: Sequence[int] | None = None,
validate_args: bool = False,
) -> None:
...
) -> None: ...

def add(
self,
Expand Down

0 comments on commit f7bffe5

Please sign in to comment.