Add CNNActorNetwork
Summary: To apply PPO to Atari games, we add CNNActorNetwork.

Reviewed By: rodrigodesalvobraz

Differential Revision: D66281497

fbshipit-source-id: 26899d83f0fb711b96bd4c3ea0c6bde93295a2ef
yiwan-rl authored and facebook-github-bot committed Dec 13, 2024
1 parent 1b0702e commit a92fdd9
Showing 3 changed files with 203 additions and 3 deletions.
134 changes: 133 additions & 1 deletion pearl/neural_networks/sequential_decision_making/actor_networks.py
@@ -17,7 +17,13 @@
import torch.nn as nn

from pearl.api.action_space import ActionSpace
from pearl.neural_networks.common.utils import mlp_block
from pearl.neural_networks.common.utils import (
    compute_output_dim_model_cnn,
    conv_block,
    mlp_block,
)
from pearl.utils.functional_utils.learning.is_one_hot_tensor import is_one_hot_tensor

from pearl.utils.instantiations.spaces.box_action import BoxActionSpace

from torch import Tensor
@@ -174,6 +180,132 @@ def get_action_prob(
        return action_probs.view(-1)


class CNNActorNetwork(ActorNetwork):
    def __init__(
        self,
        input_width: int,
        input_height: int,
        input_channels_count: int,
        kernel_sizes: List[int],
        output_channels_list: List[int],
        strides: List[int],
        paddings: List[int],
        hidden_dims_fully_connected: Optional[List[int]] = None,
        output_dim: int = 1,
        use_batch_norm_conv: bool = False,
        use_batch_norm_fully_connected: bool = False,
        action_space: Optional[ActionSpace] = None,
    ) -> None:
"""A CNN Actor Network is meant to be used with CNN to deal with images.
For an input state (batch of states), it outputs a probability distribution over
all the actions.
"""
        super(CNNActorNetwork, self).__init__(
            input_dim=input_width * input_height * input_channels_count,
            hidden_dims=hidden_dims_fully_connected,
            output_dim=output_dim,
            action_space=action_space,
        )
        self._input_channels = input_channels_count
        self._input_height = input_height
        self._input_width = input_width
        self._output_channels = output_channels_list
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._paddings = paddings
        if hidden_dims_fully_connected is None:
            self._hidden_dims_fully_connected: List[int] = []
        else:
            self._hidden_dims_fully_connected: List[int] = hidden_dims_fully_connected

        self._use_batch_norm_conv = use_batch_norm_conv
        self._use_batch_norm_fully_connected = use_batch_norm_fully_connected
        self._output_dim = output_dim

        self._model_cnn: nn.Module = conv_block(
            input_channels_count=self._input_channels,
            output_channels_list=self._output_channels,
            kernel_sizes=self._kernel_sizes,
            strides=self._strides,
            paddings=self._paddings,
            use_batch_norm=self._use_batch_norm_conv,
        )
        # compute the flattened dimension of the CNN output, which is the
        # input dimension of the fully connected layers
        self._mlp_input_dims: int = compute_output_dim_model_cnn(
            input_channels=input_channels_count,
            input_width=input_width,
            input_height=input_height,
            model_cnn=self._model_cnn,
        )
        self._model_fc: nn.Module = mlp_block(
            input_dim=self._mlp_input_dims,
            hidden_dims=self._hidden_dims_fully_connected,
            output_dim=self._output_dim,
            use_batch_norm=self._use_batch_norm_fully_connected,
            last_activation="softmax",
        )
        self._state_dim: int = input_channels_count * input_height * input_width

    def forward(
        self,
        state_batch: torch.Tensor,  # shape: (batch_size, input_channels, input_height, input_width)
    ) -> torch.Tensor:
        batch_size = state_batch.shape[0]
        state_representation_batch = self._model_cnn(
            state_batch / 255.0
        )  # (batch_size x output_channels[-1] x output_height x output_width)
        state_representation_batch = state_representation_batch.view(
            batch_size, -1
        )  # (batch_size x state dim)
        policy = self._model_fc(
            state_representation_batch
        )  # (batch_size x num actions)
        return policy

    def get_policy_distribution(
        self,
        state_batch: torch.Tensor,  # shape: (batch_size, input_channels, input_height, input_width)
        available_actions: Optional[torch.Tensor] = None,
        unavailable_actions_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Gets a policy distribution from a discrete actor network.
        available_actions and unavailable_actions_mask are not used.
        """
        # allow a single unbatched state as input
        if len(state_batch.shape) == 3:
            state_batch = state_batch.unsqueeze(0)
            reshape_state_batch = True
        else:
            reshape_state_batch = False
        policy_distribution = self.forward(state_batch)  # shape (batch_size, num actions)
        if reshape_state_batch:
            policy_distribution = policy_distribution.squeeze(0)
        return policy_distribution

    def get_action_prob(
        self,
        state_batch: torch.Tensor,  # shape: (batch_size, input_channels, input_height, input_width)
        action_batch: torch.Tensor,  # shape: (batch_size, action_dim)
        available_actions: Optional[torch.Tensor] = None,
        unavailable_actions_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Gets probabilities of different actions from a discrete actor network.
        Assumes that the input batch of actions is one-hot encoded
        (generalize it later).
        """
        assert is_one_hot_tensor(action_batch)
        assert self._output_dim == action_batch.shape[-1]  # action_dim = num actions
        assert len(state_batch.shape) == 4
        assert len(action_batch.shape) == 2
        all_action_probs = self.forward(state_batch)  # shape: (batch_size, output_dim)
        action_probs = torch.sum(all_action_probs * action_batch, dim=1, keepdim=True)

        return action_probs.view(-1)
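As an aside, compute_output_dim_model_cnn (imported at the top of this file) supplies the flattened feature count consumed by _model_fc. A minimal sketch of how such a helper can work, assuming it traces a dummy tensor through the conv stack; Pearl's actual implementation may differ:

import torch
import torch.nn as nn

def compute_output_dim_model_cnn_sketch(
    input_channels: int,
    input_width: int,
    input_height: int,
    model_cnn: nn.Module,
) -> int:
    # Pass one zero-filled observation through the conv stack and count
    # the features that remain after flattening; this is the input size
    # expected by the fully connected head.
    dummy = torch.zeros(1, input_channels, input_height, input_width)
    with torch.no_grad():
        out = model_cnn(dummy)
    return int(out.flatten(start_dim=1).shape[1])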


class DynamicActionActorNetwork(VanillaActorNetwork):
    def __init__(
        self,
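To make the new network concrete, here is a usage sketch with the Atari-style hyperparameters configured in benchmark_config.py below; the batch size and the six-action game are illustrative assumptions:

import torch

from pearl.neural_networks.sequential_decision_making.actor_networks import (
    CNNActorNetwork,
)

# 4 stacked 84x84 grayscale frames, as in the Atari benchmark config.
actor = CNNActorNetwork(
    input_width=84,
    input_height=84,
    input_channels_count=4,
    kernel_sizes=[8, 4, 3],
    output_channels_list=[32, 64, 64],
    strides=[4, 2, 1],
    paddings=[0, 0, 0],
    hidden_dims_fully_connected=[512],
    output_dim=6,  # number of discrete actions (illustrative)
)

# forward() rescales raw pixels by 255 and ends in a softmax, so each
# row of probs is a probability distribution over the 6 actions.
states = torch.randint(0, 256, (32, 4, 84, 84)).float()
probs = actor(states)
assert probs.shape == (32, 6)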
32 changes: 32 additions & 0 deletions pearl/utils/scripts/benchmark.py
@@ -26,6 +26,10 @@
from pearl.action_representation_modules.identity_action_representation_module import (
    IdentityActionRepresentationModule,
)
from pearl.neural_networks.common.value_networks import CNNValueNetwork
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    CNNActorNetwork,
)
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
    CNNQValueMultiHeadNetwork,
    CNNQValueNetwork,
@@ -241,6 +245,34 @@ def evaluate_single(
            **method["network_args"],
        )

    if (
        "critic_network_module" in method
        and method["critic_network_module"] is CNNValueNetwork
    ):
        policy_learner_args["critic_network_instance"] = method[
            "critic_network_module"
        ](
            input_width=env.observation_space.shape[2],
            input_height=env.observation_space.shape[1],
            input_channels_count=env.observation_space.shape[0],
            output_dim=1,
            **method["critic_network_args"],
        )

    if (
        "actor_network_module" in method
        and method["actor_network_module"] is CNNActorNetwork
    ):
        policy_learner_args["actor_network_instance"] = method["actor_network_module"](
            input_width=env.observation_space.shape[2],
            input_height=env.observation_space.shape[1],
            input_channels_count=env.observation_space.shape[0],
            output_dim=policy_learner_args[
                "action_representation_module"
            ].representation_dim,
            **method["actor_network_args"],
        )

if method["name"] == "DuelingDQN": # only for Dueling DQN
assert "network_module" in method and "network_args" in method
policy_learner_args["network_instance"] = method["network_module"](
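Note that this wiring reads env.observation_space.shape as channel-first, the usual layout for stacked Atari frames; a quick illustration (the env object here is hypothetical):

# e.g. shape == (4, 84, 84) for 4 stacked 84x84 frames
channels, height, width = env.observation_space.shape
# which matches the keyword arguments above:
#   input_channels_count=shape[0], input_height=shape[1], input_width=shape[2]

Also, for OneHotActionTensorRepresentationModule the representation_dim should equal the number of discrete actions, so the actor's output_dim ends up matching the environment's action count.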
40 changes: 38 additions & 2 deletions pearl/utils/scripts/benchmark_config.py
@@ -15,8 +15,12 @@
from pearl.history_summarization_modules.lstm_history_summarization_module import (
    LSTMHistorySummarizationModule,
)
from pearl.neural_networks.common.value_networks import (
    CNNValueNetwork,
    VanillaValueNetwork,
)
from pearl.neural_networks.sequential_decision_making.actor_networks import (
    CNNActorNetwork,
    DynamicActionActorNetwork,
    GaussianActorNetwork,
    VanillaActorNetwork,
@@ -392,7 +396,38 @@
"training_rounds": 50,
"batch_size": 32,
"epsilon": 0.1,
"use_critic": True,
},
"replay_buffer": PPOReplayBuffer,
"replay_buffer_args": {"capacity": 50000},
"action_representation_module": OneHotActionTensorRepresentationModule,
"action_representation_module_args": {},
"learn_every_k_steps": 200,
}
PPO_Atari_method = {
    "name": "PPO",
    "policy_learner": ProximalPolicyOptimization,
    "policy_learner_args": {
        "actor_hidden_dims": [64, 64],
        "critic_hidden_dims": [64, 64],
        "training_rounds": 50,
        "batch_size": 32,
        "epsilon": 0.1,
    },
    "actor_network_module": CNNActorNetwork,
    "actor_network_args": {
        "hidden_dims_fully_connected": [512],
        "kernel_sizes": [8, 4, 3],
        "output_channels_list": [32, 64, 64],
        "strides": [4, 2, 1],
        "paddings": [0, 0, 0],
    },
    "critic_network_module": CNNValueNetwork,
    "critic_network_args": {
        "hidden_dims_fully_connected": [512],
        "kernel_sizes": [8, 4, 3],
        "output_channels_list": [32, 64, 64],
        "strides": [4, 2, 1],
        "paddings": [0, 0, 0],
    },
    "replay_buffer": PPOReplayBuffer,
    "replay_buffer_args": {"capacity": 50000},
@@ -1372,6 +1407,7 @@
"methods": [
DQN_Atari_method,
DQN_multi_head_Atari_method,
PPO_Atari_method,
],
"device_id": 0,
}
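With these entries, the Atari benchmark experiment runs PPO alongside the DQN variants, with actor and critic each using the classic Nature-DQN convolutional torso (8/4/3 kernels, 32/64/64 channels, 4/2/1 strides, and a 512-unit fully connected layer).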
