multi-head CNN
Summary: Add multi-head CNN.

Reviewed By: rodrigodesalvobraz

Differential Revision: D66207172

fbshipit-source-id: 9ae5e23050a6c715f61f9a20be3b7fc68464c401
yiwan-rl authored and facebook-github-bot committed Dec 13, 2024
1 parent 1b86382 commit cafb382
Showing 4 changed files with 182 additions and 2 deletions.
pearl/neural_networks/sequential_decision_making/q_value_networks.py (103 additions, 0 deletions)
@@ -815,3 +815,106 @@ def state_dim(self) -> int:
@property
def action_dim(self) -> int:
return self._action_dim


class CNNQValueMultiHeadNetwork(QValueNetwork):
"""
A CNN version of state-action value (Q-value) network.
"""

def __init__(
self,
input_width: int,
input_height: int,
input_channels_count: int,
kernel_sizes: List[int],
output_channels_list: List[int],
strides: List[int],
paddings: List[int],
action_dim: int,
hidden_dims_fully_connected: Optional[List[int]] = None,
output_dim: int = 1,
use_batch_norm_conv: bool = False,
use_batch_norm_fully_connected: bool = False,
) -> None:
super(CNNQValueMultiHeadNetwork, self).__init__()

self._input_channels = input_channels_count
self._input_height = input_height
self._input_width = input_width
self._output_channels = output_channels_list
self._kernel_sizes = kernel_sizes
self._strides = strides
self._paddings = paddings
if hidden_dims_fully_connected is None:
self._hidden_dims_fully_connected: List[int] = []
else:
self._hidden_dims_fully_connected: List[int] = hidden_dims_fully_connected

self._use_batch_norm_conv = use_batch_norm_conv
self._use_batch_norm_fully_connected = use_batch_norm_fully_connected
self._output_dim = output_dim

self._model_cnn: nn.Module = conv_block(
input_channels_count=self._input_channels,
output_channels_list=self._output_channels,
kernel_sizes=self._kernel_sizes,
strides=self._strides,
paddings=self._paddings,
use_batch_norm=self._use_batch_norm_conv,
)
        # unlike CNNQValueNetwork, actions are not concatenated to the state
        # representation; the fully connected head outputs one Q-value per action
self._mlp_input_dims: int = compute_output_dim_model_cnn(
input_channels=input_channels_count,
input_width=input_width,
input_height=input_height,
model_cnn=self._model_cnn,
)
self._model_fc: nn.Module = mlp_block(
input_dim=self._mlp_input_dims,
hidden_dims=self._hidden_dims_fully_connected,
output_dim=self._output_dim,
use_batch_norm=self._use_batch_norm_fully_connected,
)
self._state_dim: int = input_channels_count * input_height * input_width
self._action_dim = action_dim

def get_q_values(
self,
state_batch: Tensor, # shape: (batch_size, input_channels, input_height, input_width)
action_batch: Tensor, # shape: (batch_size, number_of_actions_to_query, action_dim) or (batch_size, action_dim)
curr_available_actions_batch: Optional[Tensor] = None,
) -> Tensor:
        # action representation is assumed to be one-hot
        assert is_one_hot_tensor(action_batch)
        assert self._output_dim == action_batch.shape[-1]  # action_dim == number of actions
        assert len(state_batch.shape) == 4
        assert len(action_batch.shape) == 3 or len(action_batch.shape) == 2
        # remember whether a single action per state was queried, so the
        # query dimension can be squeezed back out before returning
        single_action_query = len(action_batch.shape) == 2
        if single_action_query:
            action_batch = action_batch.unsqueeze(1)  # (batch_size, 1, action_dim)

batch_size = state_batch.shape[0]
        state_representation_batch = self._model_cnn(
            state_batch / 255.0  # scale raw pixel values into [0, 1]
        )  # (batch_size x output_channels[-1] x output_height x output_width)
state_representation_batch = state_representation_batch.view(
batch_size, -1
) # (batch_size x state dim)
q_values = self._model_fc(state_representation_batch).unsqueeze(
-1
) # (batch_size x num actions x 1)

        # each one-hot row of action_batch selects the Q-value of the
        # corresponding action via a batched matrix product
        q_values = torch.bmm(
            action_batch,  # shape: (batch_size, number_of_actions_to_query, action_dim)
            q_values,  # (batch_size x num actions x 1)
        )  # (batch_size x number_of_actions_to_query x 1)
        q_values = q_values.squeeze(-1)  # (batch_size x number_of_actions_to_query)
        # for a 2-D action batch (one action per state), return shape (batch_size,)
        return q_values.squeeze(-1) if single_action_query else q_values

@property
def state_dim(self) -> int:
return self._state_dim

@property
def action_dim(self) -> int:
return self._action_dim
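
For orientation, here is a minimal usage sketch of the new network (not part of the diff; the 84x84x4 observation shape and 6-action space are illustrative, Atari-style assumptions). Because each one-hot action row contains a single 1, the batched matrix product in get_q_values simply picks out the Q-value of the queried action.

import torch

from pearl.neural_networks.sequential_decision_making.q_value_networks import (
    CNNQValueMultiHeadNetwork,
)

# illustrative Atari-style setup: 4 stacked 84x84 frames, 6 discrete actions
network = CNNQValueMultiHeadNetwork(
    input_width=84,
    input_height=84,
    input_channels_count=4,
    kernel_sizes=[8, 4, 3],
    output_channels_list=[32, 64, 64],
    strides=[4, 2, 1],
    paddings=[0, 0, 0],
    action_dim=6,
    hidden_dims_fully_connected=[512],
    output_dim=6,  # one Q-value head per action
)

states = torch.randint(0, 256, (32, 4, 84, 84)).float()  # raw pixel batch
actions = torch.nn.functional.one_hot(
    torch.randint(0, 6, (32,)), num_classes=6
).float()  # one-hot actions, shape (32, 6)
q_values = network.get_q_values(states, actions)  # shape: (32,)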
pearl/utils/scripts/benchmark.py (12 additions, 2 deletions)
@@ -27,6 +27,7 @@
IdentityActionRepresentationModule,
)
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
CNNQValueMultiHeadNetwork,
CNNQValueNetwork,
)
from pearl.pearl_agent import PearlAgent
@@ -219,15 +220,24 @@ def evaluate_single(
"history_summarization_module"
](**method["history_summarization_module_args"])

if "network_module" in method and method["network_module"] is CNNQValueNetwork:
if "network_module" in method and method["network_module"] in [
CNNQValueNetwork,
CNNQValueMultiHeadNetwork,
]:
policy_learner_args["network_instance"] = method["network_module"](
input_width=env.observation_space.shape[2],
input_height=env.observation_space.shape[1],
input_channels_count=env.observation_space.shape[0],
action_dim=policy_learner_args[
"action_representation_module"
].representation_dim,
output_dim=1,
            output_dim=(
                # the single-head network outputs one scalar Q(s, a); the
                # multi-head network outputs one Q-value per available action
                1
                if method["network_module"] is CNNQValueNetwork
                else policy_learner_args[
                    "action_representation_module"
                ].representation_dim
            ),
**method["network_args"],
)

pearl/utils/scripts/benchmark_config.py (25 additions, 0 deletions)
@@ -23,6 +23,7 @@
VanillaContinuousActorNetwork,
)
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
CNNQValueMultiHeadNetwork,
CNNQValueNetwork,
DuelingQValueNetwork,
EnsembleQValueNetwork,
@@ -168,6 +169,29 @@
"action_representation_module": OneHotActionTensorRepresentationModule,
"action_representation_module_args": {},
}
DQN_multi_head_Atari_method = {
"name": "DQN",
"policy_learner": DeepQLearning,
"policy_learner_args": {
"training_rounds": 1,
"target_update_freq": 250,
"batch_size": 32,
},
"network_module": CNNQValueMultiHeadNetwork,
"network_args": {
"hidden_dims_fully_connected": [512],
"kernel_sizes": [8, 4, 3],
"output_channels_list": [32, 64, 64],
"strides": [4, 2, 1],
"paddings": [0, 0, 0],
},
"exploration_module": EGreedyExploration,
"exploration_module_args": {"epsilon": 0.1},
"replay_buffer": BasicReplayBuffer,
"replay_buffer_args": {"capacity": 50000},
"action_representation_module": OneHotActionTensorRepresentationModule,
"action_representation_module_args": {},
}
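
As an aside, these network_args reproduce the classic DQN conv torso (kernels 8/4/3, channels 32/64/64, strides 4/2/1). Assuming standard 84x84 Atari frames (the actual size is read from env.observation_space at run time), the flattened CNN feature size can be checked in a few lines:

# per-dimension conv output size: (size - kernel + 2 * padding) // stride + 1
size = 84
for kernel, stride, padding in zip([8, 4, 3], [4, 2, 1], [0, 0, 0]):
    size = (size - kernel + 2 * padding) // stride + 1  # 84 -> 20 -> 9 -> 7
print(64 * size * size)  # 3136 features feed the [512] hidden layer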
CDQN_method = {
"name": "Conservative DQN",
"policy_learner": DeepQLearning,
@@ -1347,6 +1371,7 @@
"record_period": 10000,
"methods": [
DQN_Atari_method,
DQN_multi_head_Atari_method,
],
"device_id": 0,
}
test/unit/with_pytorch/test_cnn_based_q_value_networks.py (42 additions, 0 deletions)
@@ -14,6 +14,7 @@
import torch
import torchvision
from pearl.neural_networks.sequential_decision_making.q_value_networks import (
CNNQValueMultiHeadNetwork,
CNNQValueNetwork,
)

@@ -78,3 +79,44 @@ def test_forward_pass(self) -> None:
) # test get_q_values method

self.assertEqual(q_values.shape[0], x_batch.shape[0])

def test_multi_head_networks_forward_pass(self) -> None:
"""
test to check if the get_q_values method returns a scalar for a batch of
images (observations) and a random batch of actions
"""
        input_width = 28  # MNIST image width
        input_height = 28  # MNIST image height
        input_channels = 1  # MNIST images are grayscale

# build a cnn based q value network
kernel_sizes = [5]
output_channels = [16]
strides = [1]
paddings = [2]
hidden_dims_fully_connected = [64]
        action_dim = 4  # arbitrary action count for the dummy action batch
network = CNNQValueMultiHeadNetwork(
input_width=input_width,
input_height=input_height,
input_channels_count=input_channels,
kernel_sizes=kernel_sizes,
output_channels_list=output_channels,
strides=strides,
paddings=paddings,
output_dim=action_dim,
action_dim=action_dim,
hidden_dims_fully_connected=hidden_dims_fully_connected,
)

for x_batch, _ in self.train_dl:
indices = torch.randint(0, action_dim, (x_batch.shape[0],))
# Create the one-hot matrix
action_batch = torch.zeros((x_batch.shape[0], action_dim))
action_batch[torch.arange(x_batch.shape[0]), indices] = 1
q_values = network.get_q_values(
x_batch, action_batch
) # test get_q_values method

self.assertEqual(q_values.shape[0], x_batch.shape[0])
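
As a side note, the manual one-hot construction in this test could equivalently use the built-in helper; a sketch with the test's dummy dimensions:

import torch

batch_size, action_dim = 32, 4  # mirror the test's dummy dimensions
indices = torch.randint(0, action_dim, (batch_size,))
action_batch = torch.nn.functional.one_hot(indices, num_classes=action_dim).float()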
