Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files (browse the repository at this point in the history)
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 12, 2024
1 parent feae0f2 commit 66e63f7
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 106 deletions.
86 changes: 50 additions & 36 deletions sheeprl/algos/dreamer_v2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def __init__(
layer_args={"kernel_size": 4, "stride": 2},
activation=activation,
norm_layer=[LayerNormChannelLast for _ in range(4)] if layer_norm else None,
norm_args=[{"normalized_shape": (2**i) * channels_multiplier} for i in range(4)]
if layer_norm
else None,
norm_args=(
[{"normalized_shape": (2**i) * channels_multiplier} for i in range(4)] if layer_norm else None
),
),
nn.Flatten(-3, -1),
)
Expand Down Expand Up @@ -172,12 +172,12 @@ def __init__(
],
activation=[activation, activation, activation, None],
norm_layer=[LayerNormChannelLast for _ in range(3)] + [None] if layer_norm else None,
norm_args=[
{"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier} for i in range(self.output_dim[0])
]
+ [None]
if layer_norm
else None,
norm_args=(
[{"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier} for i in range(self.output_dim[0])]
+ [None]
if layer_norm
else None
),
),
)

Expand Down Expand Up @@ -943,9 +943,11 @@ def build_agent(
activation=eval(world_model_cfg.representation_model.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None,
norm_args=[{"normalized_shape": world_model_cfg.representation_model.hidden_size}]
if world_model_cfg.representation_model.layer_norm
else None,
norm_args=(
[{"normalized_shape": world_model_cfg.representation_model.hidden_size}]
if world_model_cfg.representation_model.layer_norm
else None
),
)
transition_model = MLP(
input_dims=world_model_cfg.recurrent_model.recurrent_state_size,
Expand All @@ -954,9 +956,11 @@ def build_agent(
activation=eval(world_model_cfg.transition_model.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm] if world_model_cfg.transition_model.layer_norm else None,
norm_args=[{"normalized_shape": world_model_cfg.transition_model.hidden_size}]
if world_model_cfg.transition_model.layer_norm
else None,
norm_args=(
[{"normalized_shape": world_model_cfg.transition_model.hidden_size}]
if world_model_cfg.transition_model.layer_norm
else None
),
)
rssm = RSSM(
recurrent_model=recurrent_model.apply(init_weights),
Expand Down Expand Up @@ -999,15 +1003,19 @@ def build_agent(
hidden_sizes=[world_model_cfg.reward_model.dense_units] * world_model_cfg.reward_model.mlp_layers,
activation=eval(world_model_cfg.reward_model.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)]
if world_model_cfg.reward_model.layer_norm
else None,
norm_args=[
{"normalized_shape": world_model_cfg.reward_model.dense_units}
for _ in range(world_model_cfg.reward_model.mlp_layers)
]
if world_model_cfg.reward_model.layer_norm
else None,
norm_layer=(
[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)]
if world_model_cfg.reward_model.layer_norm
else None
),
norm_args=(
[
{"normalized_shape": world_model_cfg.reward_model.dense_units}
for _ in range(world_model_cfg.reward_model.mlp_layers)
]
if world_model_cfg.reward_model.layer_norm
else None
),
)
if world_model_cfg.use_continues:
continue_model = MLP(
Expand All @@ -1016,15 +1024,19 @@ def build_agent(
hidden_sizes=[world_model_cfg.discount_model.dense_units] * world_model_cfg.discount_model.mlp_layers,
activation=eval(world_model_cfg.discount_model.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)]
if world_model_cfg.discount_model.layer_norm
else None,
norm_args=[
{"normalized_shape": world_model_cfg.discount_model.dense_units}
for _ in range(world_model_cfg.discount_model.mlp_layers)
]
if world_model_cfg.discount_model.layer_norm
else None,
norm_layer=(
[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)]
if world_model_cfg.discount_model.layer_norm
else None
),
norm_args=(
[
{"normalized_shape": world_model_cfg.discount_model.dense_units}
for _ in range(world_model_cfg.discount_model.mlp_layers)
]
if world_model_cfg.discount_model.layer_norm
else None
),
)
world_model = WorldModel(
encoder.apply(init_weights),
Expand Down Expand Up @@ -1053,9 +1065,11 @@ def build_agent(
activation=eval(critic_cfg.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None,
norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None
),
)
actor.apply(init_weights)
critic.apply(init_weights)
Expand Down
1 change: 1 addition & 0 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Dreamer-V2 implementation from [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
Adapted from the original implementation from https://github.com/danijar/dreamerv2
"""

from __future__ import annotations

import copy
Expand Down
110 changes: 64 additions & 46 deletions sheeprl/algos/dreamer_v3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ def __init__(
layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm},
activation=activation,
norm_layer=[LayerNormChannelLast for _ in range(stages)] if layer_norm else None,
norm_args=[{"normalized_shape": (2**i) * channels_multiplier, "eps": 1e-3} for i in range(stages)]
if layer_norm
else None,
norm_args=(
[{"normalized_shape": (2**i) * channels_multiplier, "eps": 1e-3} for i in range(stages)]
if layer_norm
else None
),
),
nn.Flatten(-3, -1),
)
Expand Down Expand Up @@ -123,9 +125,9 @@ def __init__(
activation=activation,
layer_args={"bias": not layer_norm},
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)]
if layer_norm
else None,
norm_args=(
[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None
),
)
self.output_dim = dense_units
self.symlog_inputs = symlog_inputs
Expand Down Expand Up @@ -193,13 +195,15 @@ def __init__(
+ [{"kernel_size": 4, "stride": 2, "padding": 1}],
activation=[activation for _ in range(stages - 1)] + [None],
norm_layer=[LayerNormChannelLast for _ in range(stages - 1)] + [None] if layer_norm else None,
norm_args=[
{"normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier, "eps": 1e-3}
for i in range(stages - 1)
]
+ [None]
if layer_norm
else None,
norm_args=(
[
{"normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier, "eps": 1e-3}
for i in range(stages - 1)
]
+ [None]
if layer_norm
else None
),
),
)

Expand Down Expand Up @@ -248,9 +252,9 @@ def __init__(
activation=activation,
layer_args={"bias": not layer_norm},
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)]
if layer_norm
else None,
norm_args=(
[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None
),
)
self.heads = nn.ModuleList([nn.Linear(dense_units, mlp_dim) for mlp_dim in self.output_dims])

Expand Down Expand Up @@ -654,9 +658,9 @@ def __init__(
flatten_dim=None,
layer_args={"bias": not layer_norm},
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)]
if layer_norm
else None,
norm_args=(
[{"normalized_shape": dense_units, "eps": 1e-3} for _ in range(mlp_layers)] if layer_norm else None
),
)
if is_continuous:
self.mlp_heads = nn.ModuleList([nn.Linear(dense_units, np.sum(actions_dim) * 2)])
Expand Down Expand Up @@ -980,9 +984,11 @@ def build_agent(
layer_args={"bias": not world_model_cfg.representation_model.layer_norm},
flatten_dim=None,
norm_layer=[nn.LayerNorm] if world_model_cfg.representation_model.layer_norm else None,
norm_args=[{"normalized_shape": world_model_cfg.representation_model.hidden_size}]
if world_model_cfg.representation_model.layer_norm
else None,
norm_args=(
[{"normalized_shape": world_model_cfg.representation_model.hidden_size}]
if world_model_cfg.representation_model.layer_norm
else None
),
)
transition_model = MLP(
input_dims=recurrent_state_size,
Expand All @@ -992,9 +998,11 @@ def build_agent(
layer_args={"bias": not world_model_cfg.transition_model.layer_norm},
flatten_dim=None,
norm_layer=[nn.LayerNorm] if world_model_cfg.transition_model.layer_norm else None,
norm_args=[{"normalized_shape": world_model_cfg.transition_model.hidden_size}]
if world_model_cfg.transition_model.layer_norm
else None,
norm_args=(
[{"normalized_shape": world_model_cfg.transition_model.hidden_size}]
if world_model_cfg.transition_model.layer_norm
else None
),
)
rssm = RSSM(
recurrent_model=recurrent_model.apply(init_weights),
Expand Down Expand Up @@ -1040,15 +1048,19 @@ def build_agent(
activation=eval(world_model_cfg.reward_model.dense_act),
layer_args={"bias": not world_model_cfg.reward_model.layer_norm},
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)]
if world_model_cfg.reward_model.layer_norm
else None,
norm_args=[
{"normalized_shape": world_model_cfg.reward_model.dense_units}
for _ in range(world_model_cfg.reward_model.mlp_layers)
]
if world_model_cfg.reward_model.layer_norm
else None,
norm_layer=(
[nn.LayerNorm for _ in range(world_model_cfg.reward_model.mlp_layers)]
if world_model_cfg.reward_model.layer_norm
else None
),
norm_args=(
[
{"normalized_shape": world_model_cfg.reward_model.dense_units}
for _ in range(world_model_cfg.reward_model.mlp_layers)
]
if world_model_cfg.reward_model.layer_norm
else None
),
)
continue_model = MLP(
input_dims=latent_state_size,
Expand All @@ -1057,15 +1069,19 @@ def build_agent(
activation=eval(world_model_cfg.discount_model.dense_act),
layer_args={"bias": not world_model_cfg.discount_model.layer_norm},
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)]
if world_model_cfg.discount_model.layer_norm
else None,
norm_args=[
{"normalized_shape": world_model_cfg.discount_model.dense_units}
for _ in range(world_model_cfg.discount_model.mlp_layers)
]
if world_model_cfg.discount_model.layer_norm
else None,
norm_layer=(
[nn.LayerNorm for _ in range(world_model_cfg.discount_model.mlp_layers)]
if world_model_cfg.discount_model.layer_norm
else None
),
norm_args=(
[
{"normalized_shape": world_model_cfg.discount_model.dense_units}
for _ in range(world_model_cfg.discount_model.mlp_layers)
]
if world_model_cfg.discount_model.layer_norm
else None
),
)
world_model = WorldModel(
encoder.apply(init_weights),
Expand Down Expand Up @@ -1096,9 +1112,11 @@ def build_agent(
layer_args={"bias": not critic_cfg.layer_norm},
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None,
norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None
),
)
actor.apply(init_weights)
critic.apply(init_weights)
Expand Down
1 change: 1 addition & 0 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Dreamer-V3 implementation from [https://arxiv.org/abs/2301.04104](https://arxiv.org/abs/2301.04104)
Adapted from the original implementation from https://github.com/danijar/dreamerv3
"""

from __future__ import annotations

import copy
Expand Down
8 changes: 5 additions & 3 deletions sheeprl/algos/p2e_dv2/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ def build_agent(
activation=eval(critic_cfg.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None,
norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None
),
)
actor_task.apply(init_weights)
critic_task.apply(init_weights)
Expand Down
8 changes: 5 additions & 3 deletions sheeprl/algos/p2e_dv3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ def build_agent(
flatten_dim=None,
layer_args={"bias": not critic_cfg.layer_norm},
norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None,
norm_args=[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None
),
),
}
critics_exploration[k]["module"].apply(init_weights)
Expand Down
16 changes: 10 additions & 6 deletions sheeprl/algos/ppo_recurrent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ def __init__(
activation=eval(pre_rnn_mlp_cfg.activation),
layer_args={"bias": pre_rnn_mlp_cfg.bias},
norm_layer=[nn.LayerNorm] if pre_rnn_mlp_cfg.layer_norm else None,
norm_args=[{"normalized_shape": pre_rnn_mlp_cfg.dense_units, "eps": 1e-3}]
if pre_rnn_mlp_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": pre_rnn_mlp_cfg.dense_units, "eps": 1e-3}]
if pre_rnn_mlp_cfg.layer_norm
else None
),
)
else:
self._pre_mlp = nn.Identity()
Expand All @@ -45,9 +47,11 @@ def __init__(
activation=eval(post_rnn_mlp_cfg.activation),
layer_args={"bias": post_rnn_mlp_cfg.bias},
norm_layer=[nn.LayerNorm] if post_rnn_mlp_cfg.layer_norm else None,
norm_args=[{"normalized_shape": post_rnn_mlp_cfg.dense_units, "eps": 1e-3}]
if post_rnn_mlp_cfg.layer_norm
else None,
norm_args=(
[{"normalized_shape": post_rnn_mlp_cfg.dense_units, "eps": 1e-3}]
if post_rnn_mlp_cfg.layer_norm
else None
),
)
self._output_dim = post_rnn_mlp_cfg.dense_units
else:
Expand Down
Loading

0 comments on commit 66e63f7

Please sign in to comment.