Added MultiInputPolicy support to CrossQ #268

Draft · wants to merge 2 commits into master
15 changes: 12 additions & 3 deletions sb3_contrib/crossq/crossq.py
@@ -10,7 +10,7 @@
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from torch.nn import functional as F

-from sb3_contrib.crossq.policies import Actor, CrossQCritic, CrossQPolicy, MlpPolicy
+from sb3_contrib.crossq.policies import Actor, CrossQCritic, CrossQPolicy, MlpPolicy, MultiInputPolicy

SelfCrossQ = TypeVar("SelfCrossQ", bound="CrossQ")

@@ -67,7 +67,8 @@ class CrossQ(OffPolicyAlgorithm):

policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"MlpPolicy": MlpPolicy,
-# TODO: Implement CnnPolicy and MultiInputPolicy
+"MultiInputPolicy": MultiInputPolicy,
+# TODO: Implement CnnPolicy
}
policy: CrossQPolicy
actor: Actor
@@ -235,7 +236,14 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
#
# 2. From a computational perspective a single forward pass is simply more efficient than
# two sequential forward passes.
-all_obs = th.cat([replay_data.observations, replay_data.next_observations], dim=0)
+
+if isinstance(replay_data.observations, dict):
+    all_obs = {
+        key: th.cat([replay_data.observations[key], replay_data.next_observations[key]], dim=0)
+        for key in replay_data.observations.keys()
+    }
+else:
+    all_obs = th.cat([replay_data.observations, replay_data.next_observations], dim=0)
all_actions = th.cat([replay_data.actions, next_actions], dim=0)
# Update critic BN stats
self.critic.set_bn_training_mode(True)
@@ -331,3 +339,4 @@ def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
else:
saved_pytorch_variables = ["ent_coef_tensor"]
return state_dicts, saved_pytorch_variables

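The per-key concatenation above is the core of the change: for Dict observation spaces, observations and next observations are joined along the batch dimension key by key, so the critic still performs a single joint forward pass and the batch norm statistics are computed over both batches at once. A minimal standalone sketch of that behavior (key names and shapes are made up for illustration):

```python
import torch as th

# Hypothetical dict observations: batch of 32 with two differently shaped keys.
observations = {"vec": th.randn(32, 4), "img": th.randn(32, 3, 8, 8)}
next_observations = {"vec": th.randn(32, 4), "img": th.randn(32, 3, 8, 8)}

# Concatenate along the batch dimension, key by key, mirroring the diff above.
all_obs = {
    key: th.cat([observations[key], next_observations[key]], dim=0)
    for key in observations
}

assert all_obs["vec"].shape == (64, 4)
assert all_obs["img"].shape == (64, 3, 8, 8)
```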
75 changes: 75 additions & 0 deletions sb3_contrib/crossq/policies.py
@@ -9,6 +9,7 @@
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
FlattenExtractor,
+CombinedExtractor,
create_mlp,
get_actor_critic_arch,
)
@@ -529,3 +530,77 @@ def set_training_mode(self, mode: bool) -> None:


MlpPolicy = CrossQPolicy

class MultiInputPolicy(CrossQPolicy):
"""
Policy class (with both actor and critic) for CrossQ, to be used with Dict observation spaces.

:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param batch_norm: Whether to apply batch normalization to the actor and critic networks
:param batch_norm_momentum: Momentum of the batch norm running statistics
(note: the Jax implementation uses 1 - momentum = 0.99)
:param batch_norm_eps: Epsilon of the batch norm layers
:param renorm_warmup_steps: Number of warmup steps for BatchRenorm,
during which it behaves like regular BatchNorm
:param use_sde: Whether to use State Dependent Exploration or not
:param log_std_init: Initial value for the log standard deviation
:param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows keeping the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Additional keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
:param n_critics: Number of critic networks to create.
:param share_features_extractor: Whether or not to share the features extractor
between the actor and the critic (sharing saves computation time)
"""

def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
activation_fn: type[nn.Module] = nn.ReLU,
batch_norm: bool = True,
batch_norm_momentum: float = 0.01, # Note: Jax implementation is 1 - momentum = 0.99
batch_norm_eps: float = 0.001,
renorm_warmup_steps: int = 100_000,
use_sde: bool = False,
log_std_init: float = -3,
use_expln: bool = False,
clip_mean: float = 2.0,
features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
features_extractor_kwargs: Optional[dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
batch_norm,
batch_norm_momentum,
batch_norm_eps,
renorm_warmup_steps,
use_sde,
log_std_init,
use_expln,
clip_mean,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
n_critics,
share_features_extractor,
)
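With the alias registered in `policy_aliases`, CrossQ can be constructed with the `"MultiInputPolicy"` string on any Dict-observation environment. A hedged usage sketch (`TinyDictEnv` is a hypothetical stand-in for any gymnasium env with a `spaces.Dict` observation space and a `Box` action space):

```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces

from sb3_contrib import CrossQ


class TinyDictEnv(gym.Env):
    """Hypothetical env with a Dict observation space, for illustration only."""

    def __init__(self):
        self.observation_space = spaces.Dict(
            {
                "vec": spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
                "extra": spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32),
            }
        )
        self.action_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)
        self._steps = 0

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._steps = 0
        return self.observation_space.sample(), {}

    def step(self, action):
        self._steps += 1
        # Zero reward, truncate after 50 steps: just enough for a smoke test.
        return self.observation_space.sample(), 0.0, False, self._steps >= 50, {}


# CombinedExtractor (the default set above) flattens and concatenates each key.
model = CrossQ("MultiInputPolicy", TinyDictEnv(), buffer_size=5_000, learning_starts=100)
model.learn(total_timesteps=300)
```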
10 changes: 5 additions & 5 deletions tests/test_dict_env.py
@@ -9,7 +9,7 @@
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize

-from sb3_contrib import QRDQN, TQC, TRPO
+from sb3_contrib import QRDQN, TQC, TRPO, CrossQ


class DummyDictEnv(gym.Env):
@@ -89,7 +89,7 @@ def test_env(use_discrete_actions, channel_last, nested_dict_obs, vec_only):
check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only))


-@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
+@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, CrossQ])
def test_consistency(model_class):
"""
Make sure that dict obs with vector only vs using flatten obs is equivalent.
@@ -134,7 +134,7 @@ def test_consistency(model_class):
assert np.allclose(action_1, action_2)


-@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
+@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, CrossQ])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_spaces(model_class, channel_last):
"""
@@ -179,7 +179,7 @@ def test_dict_spaces(model_class, channel_last):
evaluate_policy(model, env, n_eval_episodes=5, warn=False)


-@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
+@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, CrossQ])
@pytest.mark.parametrize("channel_last", [False, True])
def test_dict_vec_framestack(model_class, channel_last):
"""
@@ -228,7 +228,7 @@ def test_dict_vec_framestack(model_class, channel_last):
evaluate_policy(model, env, n_eval_episodes=5, warn=False)


-@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO])
+@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO, CrossQ])
def test_vec_normalize(model_class):
"""
Additional tests to check observation space support
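Beyond the parametrized training tests above, prediction with a dict observation needs no special handling. Continuing the hypothetical `TinyDictEnv` sketch from earlier:

```python
# predict() accepts a dict observation (one numpy array per key) and returns
# an action matching the Box action space.
obs, _ = TinyDictEnv().reset(seed=0)
action, _states = model.predict(obs, deterministic=True)
assert action.shape == (2,)
```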