From a92fdd9ed26ca7c7a9dd7bd4598d181bca1d5655 Mon Sep 17 00:00:00 2001
From: Yi Wan
Date: Fri, 13 Dec 2024 14:43:25 -0800
Subject: [PATCH] Add CNNActorNetwork

Summary: To apply PPO to Atari games, we add CNNActorNetwork.

Reviewed By: rodrigodesalvobraz

Differential Revision: D66281497

fbshipit-source-id: 26899d83f0fb711b96bd4c3ea0c6bde93295a2ef
---
 .../actor_networks.py                   | 134 +++++++++++++++++-
 pearl/utils/scripts/benchmark.py        |  32 +++++
 pearl/utils/scripts/benchmark_config.py |  40 +++++-
 3 files changed, 203 insertions(+), 3 deletions(-)

diff --git a/pearl/neural_networks/sequential_decision_making/actor_networks.py b/pearl/neural_networks/sequential_decision_making/actor_networks.py
index 072b124..6800d07 100644
--- a/pearl/neural_networks/sequential_decision_making/actor_networks.py
+++ b/pearl/neural_networks/sequential_decision_making/actor_networks.py
@@ -17,7 +17,13 @@ import torch.nn as nn
 
 from pearl.api.action_space import ActionSpace
 
-from pearl.neural_networks.common.utils import mlp_block
+from pearl.neural_networks.common.utils import (
+    compute_output_dim_model_cnn,
+    conv_block,
+    mlp_block,
+)
+from pearl.utils.functional_utils.learning.is_one_hot_tensor import is_one_hot_tensor
+
 from pearl.utils.instantiations.spaces.box_action import BoxActionSpace
 
 from torch import Tensor
@@ -174,6 +180,132 @@ def get_action_prob(
         return action_probs.view(-1)
 
 
+class CNNActorNetwork(ActorNetwork):
+    def __init__(
+        self,
+        input_width: int,
+        input_height: int,
+        input_channels_count: int,
+        kernel_sizes: List[int],
+        output_channels_list: List[int],
+        strides: List[int],
+        paddings: List[int],
+        hidden_dims_fully_connected: Optional[List[int]] = None,
+        output_dim: int = 1,
+        use_batch_norm_conv: bool = False,
+        use_batch_norm_fully_connected: bool = False,
+        action_space: Optional[ActionSpace] = None,
+    ) -> None:
+        """A CNN-based actor network for image observations.
+        For an input state (or a batch of states), it outputs a probability
+        distribution over all the actions.
+ """ + super(CNNActorNetwork, self).__init__( + input_dim=input_width * input_height * input_channels_count, + hidden_dims=hidden_dims_fully_connected, + output_dim=output_dim, + action_space=action_space, + ) + self._input_channels = input_channels_count + self._input_height = input_height + self._input_width = input_width + self._output_channels = output_channels_list + self._kernel_sizes = kernel_sizes + self._strides = strides + self._paddings = paddings + if hidden_dims_fully_connected is None: + self._hidden_dims_fully_connected: List[int] = [] + else: + self._hidden_dims_fully_connected: List[int] = hidden_dims_fully_connected + + self._use_batch_norm_conv = use_batch_norm_conv + self._use_batch_norm_fully_connected = use_batch_norm_fully_connected + self._output_dim = output_dim + + self._model_cnn: nn.Module = conv_block( + input_channels_count=self._input_channels, + output_channels_list=self._output_channels, + kernel_sizes=self._kernel_sizes, + strides=self._strides, + paddings=self._paddings, + use_batch_norm=self._use_batch_norm_conv, + ) + # we concatenate actions to state representations in the mlp block of the Q-value network + self._mlp_input_dims: int = compute_output_dim_model_cnn( + input_channels=input_channels_count, + input_width=input_width, + input_height=input_height, + model_cnn=self._model_cnn, + ) + self._model_fc: nn.Module = mlp_block( + input_dim=self._mlp_input_dims, + hidden_dims=self._hidden_dims_fully_connected, + output_dim=self._output_dim, + use_batch_norm=self._use_batch_norm_fully_connected, + last_activation="softmax", + ) + self._state_dim: int = input_channels_count * input_height * input_width + + def forward( + self, + state_batch: torch.Tensor, # shape: (batch_size, input_channels, input_height, input_width) + ) -> torch.Tensor: + batch_size = state_batch.shape[0] + state_representation_batch = self._model_cnn( + state_batch / 255.0 + ) # (batch_size x output_channels[-1] x output_height x output_width) + state_representation_batch = state_representation_batch.view( + batch_size, -1 + ) # (batch_size x state dim) + policy = self._model_fc( + state_representation_batch + ) # (batch_size x num actions) + return policy + + def get_policy_distribution( + self, + state_batch: torch.Tensor, # shape: (batch_size, input_channels, input_height, input_width) + available_actions: Optional[torch.Tensor] = None, + unavailable_actions_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Gets a policy distribution from a discrete actor network. + available_actions and unavailable_actions_mask are not used. + """ + if len(state_batch.shape) == 3: + state_batch = state_batch.unsqueeze(0) + reshape_state_batch = True + else: + reshape_state_batch = False + policy_distribution = self.forward( + state_batch + ) # shape (batch_size, available_actions) + if reshape_state_batch: + policy_distribution = policy_distribution.squeeze(0) + return policy_distribution + + def get_action_prob( + self, + state_batch: torch.Tensor, # shape: (batch_size, input_channels, input_height, input_width) + action_batch: torch.Tensor, # shape: (batch_size, action_dim) + available_actions: Optional[torch.Tensor] = None, + unavailable_actions_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Gets probabilities of different actions from a discrete actor network. + Assumes that the input batch of actions is one-hot encoded + (generalize it later). 
+ """ + assert is_one_hot_tensor(action_batch) + assert self._output_dim == action_batch.shape[-1] # action_dim = num actions + assert len(state_batch.shape) == 4 + assert len(action_batch.shape) == 2 + all_action_probs = self.forward(state_batch) # shape: (batch_size, output_dim) + action_probs = torch.sum(all_action_probs * action_batch, dim=1, keepdim=True) + + return action_probs.view(-1) + + class DynamicActionActorNetwork(VanillaActorNetwork): def __init__( self, diff --git a/pearl/utils/scripts/benchmark.py b/pearl/utils/scripts/benchmark.py index 047c821..6d54f41 100644 --- a/pearl/utils/scripts/benchmark.py +++ b/pearl/utils/scripts/benchmark.py @@ -26,6 +26,10 @@ from pearl.action_representation_modules.identity_action_representation_module import ( IdentityActionRepresentationModule, ) +from pearl.neural_networks.common.value_networks import CNNValueNetwork +from pearl.neural_networks.sequential_decision_making.actor_networks import ( + CNNActorNetwork, +) from pearl.neural_networks.sequential_decision_making.q_value_networks import ( CNNQValueMultiHeadNetwork, CNNQValueNetwork, @@ -241,6 +245,34 @@ def evaluate_single( **method["network_args"], ) + if ( + "critic_network_module" in method + and method["critic_network_module"] is CNNValueNetwork + ): + policy_learner_args["critic_network_instance"] = method[ + "critic_network_module" + ]( + input_width=env.observation_space.shape[2], + input_height=env.observation_space.shape[1], + input_channels_count=env.observation_space.shape[0], + output_dim=1, + **method["critic_network_args"], + ) + + if ( + "actor_network_module" in method + and method["actor_network_module"] is CNNActorNetwork + ): + policy_learner_args["actor_network_instance"] = method["actor_network_module"]( + input_width=env.observation_space.shape[2], + input_height=env.observation_space.shape[1], + input_channels_count=env.observation_space.shape[0], + output_dim=policy_learner_args[ + "action_representation_module" + ].representation_dim, + **method["actor_network_args"], + ) + if method["name"] == "DuelingDQN": # only for Dueling DQN assert "network_module" in method and "network_args" in method policy_learner_args["network_instance"] = method["network_module"]( diff --git a/pearl/utils/scripts/benchmark_config.py b/pearl/utils/scripts/benchmark_config.py index 134f2f6..75e2d44 100644 --- a/pearl/utils/scripts/benchmark_config.py +++ b/pearl/utils/scripts/benchmark_config.py @@ -15,8 +15,12 @@ from pearl.history_summarization_modules.lstm_history_summarization_module import ( LSTMHistorySummarizationModule, ) -from pearl.neural_networks.common.value_networks import VanillaValueNetwork +from pearl.neural_networks.common.value_networks import ( + CNNValueNetwork, + VanillaValueNetwork, +) from pearl.neural_networks.sequential_decision_making.actor_networks import ( + CNNActorNetwork, DynamicActionActorNetwork, GaussianActorNetwork, VanillaActorNetwork, @@ -392,7 +396,38 @@ "training_rounds": 50, "batch_size": 32, "epsilon": 0.1, - "use_critic": True, + }, + "replay_buffer": PPOReplayBuffer, + "replay_buffer_args": {"capacity": 50000}, + "action_representation_module": OneHotActionTensorRepresentationModule, + "action_representation_module_args": {}, + "learn_every_k_steps": 200, +} +PPO_Atari_method = { + "name": "PPO", + "policy_learner": ProximalPolicyOptimization, + "policy_learner_args": { + "actor_hidden_dims": [64, 64], + "critic_hidden_dims": [64, 64], + "training_rounds": 50, + "batch_size": 32, + "epsilon": 0.1, + }, + "actor_network_module": 
+    "actor_network_args": {
+        "hidden_dims_fully_connected": [512],
+        "kernel_sizes": [8, 4, 3],
+        "output_channels_list": [32, 64, 64],
+        "strides": [4, 2, 1],
+        "paddings": [0, 0, 0],
+    },
+    "critic_network_module": CNNValueNetwork,
+    "critic_network_args": {
+        "hidden_dims_fully_connected": [512],
+        "kernel_sizes": [8, 4, 3],
+        "output_channels_list": [32, 64, 64],
+        "strides": [4, 2, 1],
+        "paddings": [0, 0, 0],
     },
     "replay_buffer": PPOReplayBuffer,
     "replay_buffer_args": {"capacity": 50000},
@@ -1372,6 +1407,7 @@
     "methods": [
         DQN_Atari_method,
         DQN_multi_head_Atari_method,
+        PPO_Atari_method,
     ],
     "device_id": 0,
 }
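
Usage sketch: a minimal example of exercising the new CNNActorNetwork on its own with
the Atari-style hyperparameters from PPO_Atari_method. The 4x84x84 stacked-frame
observation and the 6-action discrete action space are illustrative assumptions, not
values taken from the patch; in the benchmark script the network is instead built from
"actor_network_args" and passed to the policy learner as "actor_network_instance".

    import torch

    from pearl.neural_networks.sequential_decision_making.actor_networks import (
        CNNActorNetwork,
    )

    actor = CNNActorNetwork(
        input_width=84,
        input_height=84,
        input_channels_count=4,
        kernel_sizes=[8, 4, 3],
        output_channels_list=[32, 64, 64],
        strides=[4, 2, 1],
        paddings=[0, 0, 0],
        hidden_dims_fully_connected=[512],
        output_dim=6,  # number of discrete actions (assumed for illustration)
    )

    # forward() rescales pixel values by 255 and returns a softmax distribution
    # over the output_dim actions for each state in the batch.
    frames = torch.randint(0, 256, (2, 4, 84, 84)).float()
    action_probs = actor(frames)  # shape: (2, 6); each row sums to 1

    # get_action_prob() expects one-hot encoded actions and returns the
    # probability the policy assigns to each chosen action.
    actions = torch.nn.functional.one_hot(torch.tensor([0, 3]), num_classes=6).float()
    chosen_probs = actor.get_action_prob(frames, actions)  # shape: (2,)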