Skip to content

Commit

Permalink
Rename VanillaCNN to CNNValueNetwork and divide state values by 255.
Browse files Browse the repository at this point in the history
Summary:
This diff makes two changes:

1) Following our naming convention for q value networks, rename VanillaCNN to CNNValueNetwork.

2) For atari games, raw images pixels values (0-255) are stored in the replay buffer (instead of values normalized to be within 0-1) to save memory. We need to do normalization in our CNN networks.

Reviewed By: rodrigodesalvobraz

Differential Revision: D66280552

fbshipit-source-id: 1346b6eb18cae8a831e7f071467487239723115a
  • Loading branch information
yiwan-rl authored and facebook-github-bot committed Dec 13, 2024
1 parent 5d9316b commit 1b0702e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions pearl/neural_networks/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

from .epistemic_neural_networks import Ensemble, EpistemicNeuralNetwork, MLPWithPrior
from .residual_wrapper import ResidualWrapper
from .value_networks import ValueNetwork, VanillaCNN, VanillaValueNetwork
from .value_networks import CNNValueNetwork, ValueNetwork, VanillaValueNetwork

__all__ = [
"Ensemble",
"EpistemicNeuralNetwork",
"MLPWithPrior",
"ResidualWrapper",
"ValueNetwork",
"VanillaCNN",
"CNNValueNetwork",
"VanillaValueNetwork",
"Epinet",
]
6 changes: 3 additions & 3 deletions pearl/neural_networks/common/value_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def xavier_init(self) -> None:
nn.init.xavier_normal_(layer.weight)


class VanillaCNN(ValueNetwork):
class CNNValueNetwork(ValueNetwork):
"""
Vanilla CNN with a convolutional block followed by an mlp block.
Args:
Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(
== len(strides)
== len(paddings)
)
super().__init__()
super(CNNValueNetwork, self).__init__()

self._input_channels = input_channels_count
self._input_height = input_height
Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(
)

def forward(self, x: Tensor) -> Tensor:
out_cnn = self._model_cnn(x)
out_cnn = self._model_cnn(x / 255.0)
out_flattened = torch.flatten(out_cnn, start_dim=1, end_dim=-1)
out_fc = self._model_fc(out_flattened)
return out_fc
8 changes: 4 additions & 4 deletions test/unit/with_pytorch/test_vanilla_cnns.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@

import torch
import torchvision
from pearl.neural_networks.common.value_networks import VanillaCNN
from pearl.neural_networks.common.value_networks import CNNValueNetwork

from torch import optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms


class TestVanillaCNNs(unittest.TestCase):
class TestCNNValueNetworks(unittest.TestCase):
def setUp(self) -> None:
transform = transforms.Compose([transforms.ToTensor()])
mnist_dataset = torchvision.datasets.MNIST(
Expand All @@ -42,7 +42,7 @@ def setUp(self) -> None:
self.mnist_train_dataset, self.batch_size, shuffle=True
)

def test_vanilla_cnns(self) -> None:
def test_cnns(self) -> None:
"""
a simple cnn should be able to fit the mnist digit dataset and the training
accuracy should be close to 90%
Expand All @@ -58,7 +58,7 @@ def test_vanilla_cnns(self) -> None:
paddings = [2]
hidden_dims_fully_connected = [64]
output_dim = 10
network = VanillaCNN(
network = CNNValueNetwork(
input_width=input_width,
input_height=input_height,
input_channels_count=input_channels,
Expand Down

0 comments on commit 1b0702e

Please sign in to comment.