Fix single item recommendation system tutorial
Summary:
This diff improves the single item recommendation system tutorial in several ways:
* corrects it to properly work with agents returning actions rather than action indices. The current tutorial is broken and this fixes it (see the sketch after this list).
* makes it reflect the corresponding recently introduced unit test. It is important that they are as similar as possible so that code changes breaking the tutorial are caught by the unit test.
* adds type hints (because the unit test has them).
* includes auxiliary classes in the notebook itself, so users can more readily examine them.
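
A minimal, hypothetical sketch of the action-vs-index mismatch described above (this code is not part of the diff; the helper name and signature are illustrative only). The agent now returns the action itself, e.g. an item-embedding tensor, so any code that previously subscripted the action list with the return value must instead match the tensor against the available actions:

from typing import List

import torch


def index_of_action(available_actions: List[torch.Tensor], action: torch.Tensor) -> int:
    # Hypothetical helper for illustration: map an action tensor returned by
    # agent.act() back to its position in the environment's action list.
    # The old tutorial assumed agent.act() returned this integer index directly.
    for i, candidate in enumerate(available_actions):
        if torch.equal(candidate, action):
            return i
    raise ValueError("action not found in the available action space")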

Reviewed By: Yonathae

Differential Revision: D56238638

fbshipit-source-id: b4ac39b8314379a2ef4bcaec0011ebc70fc5a941
rodrigodesalvobraz authored and facebook-github-bot committed Apr 22, 2024
1 parent ddd6ac4 commit 4175432
Showing 4 changed files with 10,604 additions and 10,553 deletions.
29 changes: 18 additions & 11 deletions test/unit/test_tutorials/test_rec_system.py
@@ -38,9 +38,11 @@
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import (
FIFOOffPolicyReplayBuffer,
)
from pearl.utils.functional_utils.experimentation.set_seed import set_seed
from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

set_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_id = 0 if torch.cuda.is_available() else -1
@@ -128,12 +130,6 @@ def forward(self, x: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
)
out, (_, _) = self.lstm(x, (h0, c0))
mlp_input = out[:, -1, :].view((batch_size, -1))
print("Action shape: ", action.shape)
print("mlp_input.shape: ", mlp_input.shape)
print(
"torch.cat([mlp_input, action], dim=-1).shape:, ",
torch.cat([mlp_input, action], dim=-1).shape,
)
return torch.sigmoid(self.mlp(torch.cat([mlp_input, action], dim=-1)))


Expand Down Expand Up @@ -182,6 +178,9 @@ def step(self, action: Action) -> ActionResult:
available_action_space=self._action_space,
)

def __str__(self) -> str:
return self.__class__.__name__


class TestTutorials(unittest.TestCase):
def setUp(self) -> None:
@@ -193,12 +192,14 @@ def test_rec_system(self) -> None:
model.load_state_dict(
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
torch.load(
"pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt"
"pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt",
weights_only=True,
)
)
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
actions = torch.load(
"pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt"
"pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt",
weights_only=True,
)
history_length = 8
env = RecEnv(list(actions.values())[:100], model, history_length)
@@ -210,8 +211,8 @@ def test_rec_system(self) -> None:

"""
## Vanilla DQN Agent
Able to handle dynamic action space but not able to handle partial observabilit
and sparse reward.
Able to handle dynamic action space but not able to handle partial observability
and sparse reward.
"""

# create a pearl agent
@@ -243,6 +244,8 @@ def test_rec_system(self) -> None:
record_period=min(record_period, number_of_steps),
learn_after_episode=True,
)

# Keep the commented out code below as a reference for the notebook
# torch.save(info["return"], "DQN-return.pt")
# plt.plot(
# record_period * np.arange(len(info["return"])), info["return"], label="DQN"
@@ -253,7 +256,7 @@
"""
## DQN Agent with LSTM history summarization module
Now the DQN agent can handle partially observable environments with history summarization
Now the DQN agent can handle partially observable environments with history summarization.
"""

# Add a LSTM history summarization module
@@ -285,6 +288,8 @@ def test_rec_system(self) -> None:
record_period=min(record_period, number_of_steps),
learn_after_episode=True,
)

# Keep the commented out code below as a reference for the notebook
# torch.save(info["return"], "DQN-LSTM-return.pt")
# plt.plot(
# record_period * np.arange(len(info["return"])),
Expand Down Expand Up @@ -339,6 +344,8 @@ def test_rec_system(self) -> None:
record_period=min(record_period, number_of_steps),
learn_after_episode=True,
)

# Keep the commented out code below as a reference for the notebook
# torch.save(info["return"], "BootstrappedDQN-LSTM-return.pt")
# plt.plot(
# record_period * np.arange(len(info["return"])),
60 changes: 0 additions & 60 deletions tutorials/single_item_recommender_system_example/env.py

This file was deleted.

56 changes: 0 additions & 56 deletions tutorials/single_item_recommender_system_example/env_model.py

This file was deleted.
