diff --git a/pearl/pearl_agent.py b/pearl/pearl_agent.py
index 1a62aeee..736d9683 100644
--- a/pearl/pearl_agent.py
+++ b/pearl/pearl_agent.py
@@ -158,10 +158,7 @@ def act(self, exploit: bool = False) -> Action:
             subjective_state_to_be_used, safe_action_space, exploit=exploit  # pyre-fixme[6]
         )
 
-        if isinstance(safe_action_space, DiscreteActionSpace):
-            self._latest_action = safe_action_space.actions_batch[int(action.item())]
-        else:
-            self._latest_action = action
+        self._latest_action = action
 
         return action
 
diff --git a/pearl/policy_learners/exploration_modules/common/propensity_exploration.py b/pearl/policy_learners/exploration_modules/common/propensity_exploration.py
index 6386bc0a..0a5d529f 100644
--- a/pearl/policy_learners/exploration_modules/common/propensity_exploration.py
+++ b/pearl/policy_learners/exploration_modules/common/propensity_exploration.py
@@ -27,7 +27,6 @@ class PropensityExploration(ExplorationModule):
     def __init__(self) -> None:
         super(PropensityExploration, self).__init__()
 
-    # TODO: We should make discrete action space itself iterable
     def act(
         self,
         subjective_state: SubjectiveState,
diff --git a/pearl/policy_learners/exploration_modules/sequential_decision_making/deep_exploration.py b/pearl/policy_learners/exploration_modules/sequential_decision_making/deep_exploration.py
index 20f0bb27..fd2d01cb 100644
--- a/pearl/policy_learners/exploration_modules/sequential_decision_making/deep_exploration.py
+++ b/pearl/policy_learners/exploration_modules/sequential_decision_making/deep_exploration.py
@@ -77,7 +77,9 @@ def act(
 
         # this does a forward pass since all available
         # actions are already stacked together
-        return torch.argmax(q_values).view((-1))
+        action_index = torch.argmax(q_values)
+        action = action_space.actions[action_index]
+        return action
 
     def reset(self) -> None:  # noqa: B027
         # sample a new epistemic index (i.e., a Q-network) at the beginning of a
diff --git a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
index 3ad1742a..ad0c8e2b 100644
--- a/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
+++ b/pearl/policy_learners/sequential_decision_making/actor_critic_base.py
@@ -244,7 +244,8 @@ def act(
             available_actions=actions,
         )
         # (action_space_size)
-        exploit_action = torch.argmax(action_probabilities)
+        exploit_action_index = torch.argmax(action_probabilities)
+        exploit_action = available_action_space.actions[exploit_action_index]
 
         # Step 2: return exploit action if no exploration,
         # else pass through the exploration module
diff --git a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
index 58fa8b78..5903d37a 100644
--- a/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/deep_td_learning.py
@@ -221,7 +221,8 @@ def act(
 
         # this does a forward pass since all avaialble
         # actions are already stacked together
-        exploit_action = torch.argmax(q_values).view((-1))
+        exploit_action_index = torch.argmax(q_values)
+        exploit_action = available_action_space.actions[exploit_action_index]
 
         if exploit:
             return exploit_action
diff --git a/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_td_learning.py b/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_td_learning.py
index 944fedaa..96ccb263 100644
--- a/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_td_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/quantile_regression_deep_td_learning.py
@@ -149,7 +149,8 @@ def act(
         q_values = self.safety_module.get_q_values_under_risk_metric(
             states_repeated, actions, self._Q
         )
-        exploit_action = torch.argmax(q_values).view((-1))
+        exploit_action_index = torch.argmax(q_values)
+        exploit_action = available_action_space.actions[exploit_action_index]
 
         if exploit:
             return exploit_action
diff --git a/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py b/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
index 38eec481..07d78ca6 100644
--- a/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/tabular_q_learning.py
@@ -7,9 +7,7 @@
 
 # pyre-strict
 
-from typing import Any, Dict, Iterable, List, Tuple
-
-import torch
+from typing import Any, Dict, Iterable, Tuple
 
 from pearl.api.action import Action
 from pearl.api.action_space import ActionSpace
@@ -45,6 +43,9 @@ def __init__(
     ) -> None:
         """
         Initializes the tabular Q-learning policy learner.
+        Currently, tabular Q-learning assumes
+        a discrete action space, and assumes that for each action
+        int(action.item()) == action's index.
 
         Args:
             learning_rate (float, optional): the learning rate. Defaults to 0.01.
@@ -66,6 +67,13 @@ def __init__(
 
     def reset(self, action_space: ActionSpace) -> None:
         self._action_space = action_space
+        for i, action in enumerate(self._action_space):
+            if int(action.item()) != i:
+                raise ValueError(
+                    f"{self.__class__.__name__} only supports "
+                    f"action spaces that are a DiscreteSpace where for each action "
+                    f"action.item() == action's index. "
+                )
 
     def act(
         self,
@@ -74,21 +82,22 @@ def act(
         exploit: bool = False,
     ) -> Action:
         assert isinstance(available_action_space, DiscreteSpace)
-        # FIXME: this conversion should be eliminated once Action
-        # is no longer constrained to be a Tensor.
-        actions_as_ints: List[int] = [int(a.item()) for a in available_action_space]
+        # TODO: if we substitute DiscreteActionSpace for DiscreteSpace
+        # we get Pyre errors. It would be nice to fix this.
+
         # Choose the action with the highest Q-value for the current state.
-        q_values_for_state = {
-            action: self.q_values.get((subjective_state, action), 0)
-            for action in actions_as_ints
+        action_q_values_for_state = {
+            action_index: self.q_values.get((subjective_state, action_index), 0)
+            for action_index in range(available_action_space.n)
         }
-        max_q_value = max(q_values_for_state.values())
-        exploit_action = first_item(
-            action
-            for action, q_value in q_values_for_state.items()
-            if q_value == max_q_value
+        max_q_value_for_state = max(action_q_values_for_state.values())
+        exploit_action_index = first_item(
+            action_index
+            for action_index, q_value in action_q_values_for_state.items()
+            if q_value == max_q_value_for_state
         )
-        exploit_action = torch.tensor([exploit_action])
+        exploit_action = available_action_space.actions[exploit_action_index]
+
         if exploit:
             return exploit_action
@@ -102,6 +111,7 @@ def learn(
         self,
         replay_buffer: ReplayBuffer,
     ) -> Dict[str, Any]:
+        # We know the sampling result from SingleTransitionReplayBuffer
         # is a list with a single tuple.
         transitions = replay_buffer.sample(1)
diff --git a/pearl/safety_modules/identity_safety_module.py b/pearl/safety_modules/identity_safety_module.py
index ff3f84d8..31aaa588 100644
--- a/pearl/safety_modules/identity_safety_module.py
+++ b/pearl/safety_modules/identity_safety_module.py
@@ -7,8 +7,6 @@
 
 # pyre-strict
 
-from typing import Optional
-
 from pearl.api.action_space import ActionSpace
 from pearl.history_summarization_modules.history_summarization_module import (
     SubjectiveState,
diff --git a/pearl/utils/instantiations/environments/environments.py b/pearl/utils/instantiations/environments/environments.py
index 9f85b268..da75b3c1 100644
--- a/pearl/utils/instantiations/environments/environments.py
+++ b/pearl/utils/instantiations/environments/environments.py
@@ -40,7 +40,7 @@ def __init__(self, number_of_steps: int = 100) -> None:
         self.number_of_steps_so_far = 0
         self.number_of_steps: int = number_of_steps
         self._action_space = DiscreteActionSpace(
-            [torch.tensor(True), torch.tensor(False)]
+            [torch.tensor(False), torch.tensor(True)]
         )
 
     def step(self, action: Action) -> ActionResult:
diff --git a/test/unit/test_tutorials/test_rec_system.py b/test/unit/test_tutorials/test_rec_system.py
new file mode 100644
index 00000000..f8d7b247
--- /dev/null
+++ b/test/unit/test_tutorials/test_rec_system.py
@@ -0,0 +1,354 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+
+import random
+import unittest
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+from pearl.action_representation_modules.identity_action_representation_module import (
+    IdentityActionRepresentationModule,
+)
+from pearl.api.action import Action
+from pearl.api.action_result import ActionResult
+from pearl.api.action_space import ActionSpace
+from pearl.api.environment import Environment
+from pearl.api.observation import Observation
+from pearl.history_summarization_modules.lstm_history_summarization_module import (
+    LSTMHistorySummarizationModule,
+)
+from pearl.neural_networks.sequential_decision_making.q_value_networks import (
+    EnsembleQValueNetwork,
+)
+from pearl.pearl_agent import PearlAgent
+from pearl.policy_learners.sequential_decision_making.bootstrapped_dqn import (
+    BootstrappedDQN,
+)
+from pearl.policy_learners.sequential_decision_making.deep_q_learning import (
+    DeepQLearning,
+)
+from pearl.replay_buffers.sequential_decision_making.bootstrap_replay_buffer import (
+    BootstrapReplayBuffer,
+)
+from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
+    FIFOOffPolicyReplayBuffer,
+)
+from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning
+from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device_id = 0 if torch.cuda.is_available() else -1
+
+"""
+This is a unit test version of the Recommender System tutorial.
+It is meant to check whether code changes break the tutorial.
+It is therefore important that the tutorial and the code here are kept in sync.
+As part of that synchronization, the markdown cells in the tutorial are
+kept here as multi-line strings.
+
+For it to run quickly, the number of steps used for training is reduced.
+"""
+
+
+number_of_steps = 30  # 100000
+
+
+"""
+## Load Environment
+This environment's underlying model was pre-trained using the MIND dataset (Wu et al. 2020).
+The model is defined by class `SequenceClassificationModel` below.
+The model's state dict is saved in
+tutorials/single_item_recommender_system_example/env_model_state_dict.pt
+
+Each data point is:
+- A history of impressions clicked by a user
+- Each impression is represented by a 100-dim vector
+- A list of impressions and whether or not they are clicked
+
+The environment is constructed with the setup below. Note that this is a contrived example
+to illustrate Pearl's usage, agent modularity and a subset of features,
+not to represent a real-world environment or problem.
+- State: a history of impressions by a user (note that we used the history of impressions
+  instead of clicked impressions to speed up learning in this example).
+  Interested Pearl users can change it to history of clicked impressions with much longer
+  episode length and samples to run the following experiments.)
+- Dynamic action space: two randomly picked impressions
+- Action: one of the two impressions
+- Reward: click
+- Reset every 20 steps.
+"""
+
+
+class SequenceClassificationModel(nn.Module):
+    def __init__(
+        self,
+        observation_dim: int,
+        hidden_dim: int = 128,
+        state_dim: int = 128,
+        num_layers: int = 2,
+    ) -> None:
+        super(SequenceClassificationModel, self).__init__()
+        self.lstm = nn.LSTM(
+            num_layers=num_layers,
+            input_size=observation_dim,
+            hidden_size=hidden_dim,
+            batch_first=True,
+        )
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden_dim + observation_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.ReLU(),
+            nn.Linear(hidden_dim // 2, 1),
+        )
+        self.register_buffer(
+            "default_cell_representation", torch.zeros((num_layers, hidden_dim))
+        )
+        self.register_buffer(
+            "default_hidden_representation", torch.zeros((num_layers, hidden_dim))
+        )
+
+    def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
+        batch_size = x.shape[0]
+        h0 = (
+            self.default_hidden_representation.unsqueeze(1)
+            .repeat(1, batch_size, 1)
+            .detach()
+        )
+        c0 = (
+            self.default_cell_representation.unsqueeze(1)
+            .repeat(1, batch_size, 1)
+            .detach()
+        )
+        out, (_, _) = self.lstm(x, (h0, c0))
+        mlp_input = out[:, -1, :].view((batch_size, -1))
+        print("Action shape: ", action.shape)
+        print("mlp_input.shape: ", mlp_input.shape)
+        print(
+            "torch.cat([mlp_input, action], dim=-1).shape:, ",
+            torch.cat([mlp_input, action], dim=-1).shape,
+        )
+        return torch.sigmoid(self.mlp(torch.cat([mlp_input, action], dim=-1)))
+
+
+class RecEnv(Environment):
+    def __init__(
+        self, actions: List[torch.Tensor], model: nn.Module, history_length: int
+    ) -> None:
+        self.model: nn.Module = model.to(device)
+        self.history_length: int = history_length
+        self.t = 0
+        self.T = 20
+        self.actions: List[List[torch.Tensor]] = [
+            [torch.tensor(k) for k in random.sample(actions, 2)] for _ in range(self.T)
+        ]
+        self.state: torch.Tensor = torch.zeros((self.history_length, 100)).to(device)
+        self._action_space: DiscreteActionSpace = DiscreteActionSpace(self.actions[0])
+
+    def action_space(self) -> ActionSpace:
+        return DiscreteActionSpace(self.actions[0])
+
+    def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
+        self.state: torch.Tensor = torch.zeros((self.history_length, 100))
+        self.t = 0
+        self._action_space: DiscreteActionSpace = DiscreteActionSpace(
+            self.actions[self.t]
+        )
+        return [0.0], self._action_space
+
+    def step(self, action: Action) -> ActionResult:
+        action = action.to(device)
+        action_batch = action.unsqueeze(0)
+        state_batch = self.state.unsqueeze(0).to(device)
+        reward = self.model(state_batch, action_batch) * 3  # To speed up learning
+        true_reward = np.random.binomial(1, reward.item())
+        self.state = torch.cat([self.state[1:, :].to(device), action_batch], dim=0)
+
+        self.t += 1
+        if self.t < self.T:
+            self._action_space = DiscreteActionSpace(self.actions[self.t])
+        return ActionResult(
+            observation=[float(true_reward)],
+            reward=float(true_reward),
+            terminated=self.t >= self.T,
+            truncated=False,
+            info={},
+            available_action_space=self._action_space,
+        )
+
+
+class TestTutorials(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    def test_rec_system(self) -> None:
+        # load environment
+        model = SequenceClassificationModel(100).to(device)
+        model.load_state_dict(
+            # Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
+            torch.load(
+                "pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt"
+            )
+        )
+        # Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
+        actions = torch.load(
+            "pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt"
+        )
+        history_length = 8
+        env = RecEnv(list(actions.values())[:100], model, history_length)
+        observation, action_space = env.reset()
+        assert isinstance(action_space, DiscreteActionSpace)
+
+        # experiment code
+        record_period = 400
+
+        """
+        ## Vanilla DQN Agent
+        Able to handle dynamic action space but not able to handle partial observability
+        and sparse reward.
+        """
+
+        # create a pearl agent
+
+        action_representation_module = IdentityActionRepresentationModule(
+            max_number_actions=action_space.n,
+            representation_dim=action_space.action_dim,
+        )
+
+        # DQN-vanilla
+        agent = PearlAgent(
+            policy_learner=DeepQLearning(
+                state_dim=1,
+                action_space=action_space,
+                hidden_dims=[64, 64],
+                training_rounds=50,
+                action_representation_module=action_representation_module,
+            ),
+            replay_buffer=FIFOOffPolicyReplayBuffer(100_000),
+            device_id=device_id,
+        )
+
+        # info = \
+        online_learning(
+            agent=agent,
+            env=env,
+            number_of_steps=number_of_steps,
+            print_every_x_steps=100,
+            record_period=min(record_period, number_of_steps),
+            learn_after_episode=True,
+        )
+        # torch.save(info["return"], "DQN-return.pt")
+        # plt.plot(
+        #     record_period * np.arange(len(info["return"])), info["return"], label="DQN"
+        # )
+        # plt.legend()
+        # plt.show()
+
+        """
+        ## DQN Agent with LSTM history summarization module
+
+        Now the DQN agent can handle partially observable environments with history summarization.
+        """
+
+        # Add a LSTM history summarization module
+
+        agent = PearlAgent(
+            policy_learner=DeepQLearning(
+                state_dim=128,
+                action_space=action_space,
+                hidden_dims=[64, 64],
+                training_rounds=50,
+                action_representation_module=action_representation_module,
+            ),
+            history_summarization_module=LSTMHistorySummarizationModule(
+                observation_dim=1,
+                action_dim=100,
+                hidden_dim=128,
+                history_length=history_length,
+            ),
+            replay_buffer=FIFOOffPolicyReplayBuffer(100_000),
+            device_id=device_id,
+        )
+
+        # info = \
+        online_learning(
+            agent=agent,
+            env=env,
+            number_of_steps=number_of_steps,
+            print_every_x_steps=100,
+            record_period=min(record_period, number_of_steps),
+            learn_after_episode=True,
+        )
+        # torch.save(info["return"], "DQN-LSTM-return.pt")
+        # plt.plot(
+        #     record_period * np.arange(len(info["return"])),
+        #     info["return"],
+        #     label="DQN-LSTM",
+        # )
+        # plt.legend()
+        # plt.show()
+
+        """
+        ## Bootstrapped DQN Agent with LSTM History Summarization
+
+        Leveraging the deep exploration value-based algorithm, the agent now achieves
+        better performance much faster while still leveraging
+        the history summarization capability.
+        Note how top average performance takes around 20,000 steps in the graph above,
+        but only about 5,000 steps in the graph below.
+        """
+
+        # Better exploration with BootstrappedDQN-LSTM
+
+        agent = PearlAgent(
+            policy_learner=BootstrappedDQN(
+                q_ensemble_network=EnsembleQValueNetwork(
+                    state_dim=128,
+                    action_dim=100,
+                    ensemble_size=10,
+                    output_dim=1,
+                    hidden_dims=[64, 64],
+                    prior_scale=0.3,
+                ),
+                action_space=action_space,
+                training_rounds=50,
+                action_representation_module=action_representation_module,
+            ),
+            history_summarization_module=LSTMHistorySummarizationModule(
+                observation_dim=1,
+                action_dim=100,
+                hidden_dim=128,
+                history_length=history_length,
+            ),
+            replay_buffer=BootstrapReplayBuffer(100_000, 1.0, 10),
+            device_id=device_id,
+        )
+
+        # info = \
+        online_learning(
+            agent=agent,
+            env=env,
+            number_of_steps=number_of_steps,
+            print_every_x_steps=100,
+            record_period=min(record_period, number_of_steps),
+            learn_after_episode=True,
+        )
+        # torch.save(info["return"], "BootstrappedDQN-LSTM-return.pt")
+        # plt.plot(
+        #     record_period * np.arange(len(info["return"])),
+        #     info["return"],
+        #     label="BootstrappedDQN-LSTM",
+        # )
+        # plt.legend()
+        # plt.show()
+
+        """
+        ## Summary
+        In this example, we illustrated Pearl's capability of dealing with dynamic action space,
+        standard policy learning, history summarization and intelligent exploration,
+        all in a single agent.
+        """
diff --git a/tutorials/README_meta_only.md b/tutorials/README_meta_only.md
new file mode 100644
index 00000000..998932ca
--- /dev/null
+++ b/tutorials/README_meta_only.md
@@ -0,0 +1,11 @@
+Make sure to reflect any changes to the tutorials in the corresponding
+unit tests in test/unit/test_tutorials,
+so that we are alerted to code changes breaking the tutorial.
+Likewise, make sure to reflect any changes in the tests to the tutorial.
+
+Note that this synchronization is not a copy-and-paste affair;
+one must be careful to replace things as needed:
+- the notebook uses `plt.show()` to show the plot, while the unit test
+does not compute the graphs (the code is kept commented out for reference)
+- the paths to the files (such as .pt files) must use "pearl" in fbcode and "Pearl" in
+  the open-source version.