From 3200afabcf90d6f375de482524df56538a3816ab Mon Sep 17 00:00:00 2001 From: Yonathan Efroni Date: Fri, 26 Apr 2024 10:10:40 -0700 Subject: [PATCH] Fix CB errors Summary: Fix CB errors and add unitest. -) add cb unitests that replicates the cb tutorial. removed previous broken uci test. -) fixed a bug in joint CB. Reviewed By: rodrigodesalvobraz Differential Revision: D56336066 fbshipit-source-id: 7d2cd9c9d8201aaefc701cb5e3b6facbbb31a48b --- .../binary_action_representation_module.py | 2 + .../one_hot_action_representation_module.py | 2 + .../cb_benchmark/cb_benchmark_config.py | 2 + test/unit/test_tutorials/test_cb_tutorial.py | 166 ++++++++++++++++++ test/unit/with_pytorch/test_agent.py | 38 ---- 5 files changed, 172 insertions(+), 38 deletions(-) create mode 100644 test/unit/test_tutorials/test_cb_tutorial.py diff --git a/pearl/action_representation_modules/binary_action_representation_module.py b/pearl/action_representation_modules/binary_action_representation_module.py index 87082cd1..9bb0f1a7 100644 --- a/pearl/action_representation_modules/binary_action_representation_module.py +++ b/pearl/action_representation_modules/binary_action_representation_module.py @@ -25,6 +25,8 @@ def __init__(self, bits_num: int) -> None: self._max_number_actions: int = 2**bits_num def forward(self, x: torch.Tensor) -> torch.Tensor: + if len(x.shape) == 1: + x = x.unsqueeze(-1) return self.binary(x) # (batch_size x action_dim) diff --git a/pearl/action_representation_modules/one_hot_action_representation_module.py b/pearl/action_representation_modules/one_hot_action_representation_module.py index dbf92d74..c9a1be76 100644 --- a/pearl/action_representation_modules/one_hot_action_representation_module.py +++ b/pearl/action_representation_modules/one_hot_action_representation_module.py @@ -25,6 +25,8 @@ def __init__(self, max_number_actions: int) -> None: self._max_number_actions = max_number_actions def forward(self, x: torch.Tensor) -> torch.Tensor: + if len(x.shape) == 1: + x = x.unsqueeze(-1) return F.one_hot(x.long(), num_classes=self._max_number_actions).squeeze(dim=-2) # (batch_size x action_dim) diff --git a/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py b/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py index db394466..4fe59ef0 100644 --- a/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py +++ b/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py @@ -143,6 +143,7 @@ def return_neural_lin_ucb_config( "hidden_dims": [64, 16], "learning_rate": 0.01, "batch_size": 128, + "state_features_only": False, "training_rounds": run_config["training_rounds"], "action_representation_module": BinaryActionTensorRepresentationModule( bits_num=dim_actions @@ -176,6 +177,7 @@ def return_neural_lin_ts_config( "hidden_dims": [64, 16], "learning_rate": 0.01, "batch_size": 128, + "state_features_only": False, "training_rounds": run_config["training_rounds"], "action_representation_module": BinaryActionTensorRepresentationModule( bits_num=dim_actions diff --git a/test/unit/test_tutorials/test_cb_tutorial.py b/test/unit/test_tutorials/test_cb_tutorial.py new file mode 100644 index 00000000..2b155077 --- /dev/null +++ b/test/unit/test_tutorials/test_cb_tutorial.py @@ -0,0 +1,166 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + + +import os +import unittest + +import torch +from pearl.action_representation_modules.one_hot_action_representation_module import ( + OneHotActionTensorRepresentationModule, +) +from pearl.pearl_agent import PearlAgent +from pearl.policy_learners.contextual_bandits.neural_bandit import NeuralBandit +from pearl.policy_learners.contextual_bandits.neural_linear_bandit import ( + NeuralLinearBandit, +) +from pearl.policy_learners.exploration_modules.contextual_bandits.squarecb_exploration import ( + SquareCBExploration, +) +from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import ( + ThompsonSamplingExplorationLinear, +) +from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import ( + UCBExploration, +) +from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import ( + FIFOOffPolicyReplayBuffer, +) +from pearl.utils.functional_utils.experimentation.set_seed import set_seed +from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning +from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import ( + SLCBEnvironment, +) +from pearl.utils.uci_data import download_uci_data + +set_seed(0) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device_id = 0 if torch.cuda.is_available() else -1 + +""" +This is a unit test version of the CB tutorial. +It is meant to check whether code changes break the tutorial. +It is therefore important that the tutorial and the code here are kept in sync. +As part of that synchronization, the markdown cells in the tutorial are +kept here as multi-line strings. + +For it to run quickly, the number of steps used for training is reduced. +""" + + +class TestCBTutorials(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + + def test_cb_tutorials(self) -> None: + # load environment + device = -1 + + # Download UCI dataset if doesn't exist + uci_data_path = "./utils/instantiations/environments/uci_datasets" + if not os.path.exists(uci_data_path): + os.makedirs(uci_data_path) + download_uci_data(data_path=uci_data_path) + + # Built CB environment using the pendigits UCI dataset + pendigits_uci_dict = { + "path_filename": os.path.join(uci_data_path, "pendigits/pendigits.tra"), + "action_embeddings": "discrete", + "delim_whitespace": False, + "ind_to_drop": [], + "target_column": 16, + } + env = SLCBEnvironment(**pendigits_uci_dict) # pyre-ignore + + # experiment code + number_of_steps = 200 + record_period = 400 + + """ + SquareCB + """ + # Create a Neural SquareCB pearl agent with 1-hot action representation + action_representation_module = OneHotActionTensorRepresentationModule( + max_number_actions=env.unique_labels_num, + ) + + agent = PearlAgent( + policy_learner=NeuralBandit( + feature_dim=env.observation_dim + env.unique_labels_num, + hidden_dims=[64, 16], + training_rounds=10, + learning_rate=0.01, + action_representation_module=action_representation_module, + exploration_module=SquareCBExploration( + gamma=env.observation_dim * env.unique_labels_num * number_of_steps + ), + ), + replay_buffer=FIFOOffPolicyReplayBuffer(100_000), + device_id=device, + ) + + _ = online_learning( + agent=agent, + env=env, + number_of_steps=number_of_steps, + print_every_x_steps=100, + record_period=record_period, + learn_after_episode=True, + ) + + # Neural LinUCB + action_representation_module = OneHotActionTensorRepresentationModule( + max_number_actions=env.unique_labels_num, + ) + + agent = PearlAgent( + policy_learner=NeuralLinearBandit( + feature_dim=env.observation_dim + env.unique_labels_num, + hidden_dims=[64, 16], + state_features_only=False, + training_rounds=10, + learning_rate=0.01, + action_representation_module=action_representation_module, + exploration_module=UCBExploration(alpha=1.0), + ), + replay_buffer=FIFOOffPolicyReplayBuffer(100_000), + device_id=device, + ) + + _ = online_learning( + agent=agent, + env=env, + number_of_steps=number_of_steps, + print_every_x_steps=100, + record_period=record_period, + learn_after_episode=True, + ) + + # Neural LinTS + + action_representation_module = OneHotActionTensorRepresentationModule( + max_number_actions=env.unique_labels_num, + ) + + agent = PearlAgent( + policy_learner=NeuralLinearBandit( + feature_dim=env.observation_dim + env.unique_labels_num, + hidden_dims=[64, 16], + state_features_only=False, + training_rounds=10, + learning_rate=0.01, + action_representation_module=action_representation_module, + exploration_module=ThompsonSamplingExplorationLinear(), + ), + replay_buffer=FIFOOffPolicyReplayBuffer(100_000), + device_id=-1, + ) + + _ = online_learning( + agent=agent, + env=env, + number_of_steps=number_of_steps, + print_every_x_steps=100, + record_period=record_period, + learn_after_episode=True, + ) diff --git a/test/unit/with_pytorch/test_agent.py b/test/unit/with_pytorch/test_agent.py index 07eee824..9b4805ac 100644 --- a/test/unit/with_pytorch/test_agent.py +++ b/test/unit/with_pytorch/test_agent.py @@ -8,7 +8,6 @@ # pyre-strict import unittest -from typing import Any, Dict import torch from pearl.action_representation_modules.one_hot_action_representation_module import ( @@ -58,14 +57,6 @@ RewardIsEqualToTenTimesActionContextualBanditEnvironment, ) from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace -from pearl.utils.scripts.cb_benchmark.cb_benchmark_config import ( - pendigits_uci_dict, - return_neural_lin_ts_config, - return_neural_lin_ucb_config, - return_neural_squarecb_config, -) - -from pearl.utils.scripts.cb_benchmark.run_cb_benchmarks import run_cb_benchmarks class TestAgentWithPyTorch(unittest.TestCase): @@ -272,32 +263,3 @@ def test_contextual_bandit_with_tabular_q_learning_online_rl(self) -> None: agent, env, learn=False, exploit=True ) assert episode_info["return"] == max_action * 10 - - def test_contextual_bandit_on_uci_datasets(self) -> None: - # Tests that neural versions of CB algorithms train on a UCI dataset - # CB Algorithms are the neural versions of LinUCB, LinTS, and SquareCB with shared models. - - # set number of time steps to be small, just for unit testing purposes - run_config_test: Dict[str, Any] = { - "T": 300, - "training_rounds": 1, - "num_of_experiments": 1, - } - - # load configs of neural versions of SquareCB, LinUCB, and LinTS - cb_algorithms_config: Dict[str, Any] = { - "NeuralSquareCB": return_neural_squarecb_config, - "NeuralLinUCB": return_neural_lin_ucb_config, - "NeuralLinTS": return_neural_lin_ts_config, - } - - # load only pendigits UCI dataset - test_environments_config: Dict[str, Any] = { - "pendigits": pendigits_uci_dict, - } - - run_cb_benchmarks( - cb_algorithms_config=cb_algorithms_config, - test_environments_config=test_environments_config, - run_config=run_config_test, - )