From d1d4dadedb27e5341b17bc0832686c023c1c3731 Mon Sep 17 00:00:00 2001 From: Rodrigo de Salvo Braz Date: Wed, 1 May 2024 16:52:08 -0700 Subject: [PATCH] Add observation space to single item recommendation system tutorial Summary: Add observation space to single item recommendation system tutorial Reviewed By: jb3618columbia Differential Revision: D56854463 fbshipit-source-id: 4eb24338f1b1cb881c20ceda6e39d6a944a74903 --- pearl/api/environment.py | 1 - test/unit/test_tutorials/test_rec_system.py | 7 +++++++ .../single_item_recommender_system.ipynb | 21 ++++++++++++------- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pearl/api/environment.py b/pearl/api/environment.py index 6c26d245..be70dae8 100644 --- a/pearl/api/environment.py +++ b/pearl/api/environment.py @@ -32,7 +32,6 @@ def action_space(self) -> ActionSpace: """Returns the action space of the environment.""" pass - # FIXME: add this and in implement in all concrete subclasses @property @abstractmethod def observation_space(self) -> Space: diff --git a/test/unit/test_tutorials/test_rec_system.py b/test/unit/test_tutorials/test_rec_system.py index 65b5ad84..89b0c698 100644 --- a/test/unit/test_tutorials/test_rec_system.py +++ b/test/unit/test_tutorials/test_rec_system.py @@ -19,6 +19,7 @@ from pearl.api.action_space import ActionSpace from pearl.api.environment import Environment from pearl.api.observation import Observation +from pearl.api.space import Space from pearl.history_summarization_modules.lstm_history_summarization_module import ( LSTMHistorySummarizationModule, ) @@ -40,6 +41,7 @@ ) from pearl.utils.functional_utils.experimentation.set_seed import set_seed from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning +from pearl.utils.instantiations.spaces.box import BoxSpace from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace set_seed(0) @@ -147,9 +149,14 @@ def __init__( self.state: torch.Tensor = torch.zeros((self.history_length, 100)).to(device) self._action_space: DiscreteActionSpace = DiscreteActionSpace(self.actions[0]) + @property def action_space(self) -> ActionSpace: return DiscreteActionSpace(self.actions[0]) + @property + def observation_space(self) -> Space: + return BoxSpace(low=torch.zeros((1,)), high=torch.ones((1,))) + def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]: self.state: torch.Tensor = torch.zeros((self.history_length, 100)) self.t = 0 diff --git a/tutorials/single_item_recommender_system_example/single_item_recommender_system.ipynb b/tutorials/single_item_recommender_system_example/single_item_recommender_system.ipynb index beae15a1..8a3d5f49 100644 --- a/tutorials/single_item_recommender_system_example/single_item_recommender_system.ipynb +++ b/tutorials/single_item_recommender_system_example/single_item_recommender_system.ipynb @@ -36,7 +36,7 @@ "metadata": { "id": "nFomZD4OjZLK" }, - "execution_count": 2, + "execution_count": null, "outputs": [] }, { @@ -66,7 +66,7 @@ "id": "5i2jE98RjhK1", "outputId": "ad80e72d-51cb-4594-a97e-2f4cf5466667" }, - "execution_count": 3, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -220,6 +220,7 @@ "from pearl.api.action_space import ActionSpace\n", "from pearl.api.environment import Environment\n", "from pearl.api.observation import Observation\n", + "from pearl.api.space import Space\n", "from pearl.history_summarization_modules.lstm_history_summarization_module import (\n", " LSTMHistorySummarizationModule,\n", ")\n", @@ -241,6 +242,7 @@ ")\n", "from pearl.utils.functional_utils.experimentation.set_seed import set_seed\n", "from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning\n", + "from pearl.utils.instantiations.spaces.box import BoxSpace\n", "from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace\n", "import matplotlib.pyplot as plt\n", "\n", @@ -249,7 +251,7 @@ "metadata": { "id": "Lp6pRjTDjpDo" }, - "execution_count": 4, + "execution_count": null, "outputs": [] }, { @@ -345,9 +347,14 @@ " self.state: torch.Tensor = torch.zeros((self.history_length, 100)).to(device)\n", " self._action_space: DiscreteActionSpace = DiscreteActionSpace(self.actions[0])\n", "\n", + " @property\n", " def action_space(self) -> ActionSpace:\n", " return DiscreteActionSpace(self.actions[0])\n", "\n", + " @property\n", + " def observation_space(self) -> Space:\n", + " return BoxSpace(low=torch.zeros((1,)), high=torch.ones((1,)))\n", + "\n", " def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:\n", " self.state: torch.Tensor = torch.zeros((self.history_length, 100))\n", " self.t = 0\n", @@ -402,7 +409,7 @@ "id": "BYdPfGgZp8HN", "outputId": "6be7ab71-02d3-4d82-c66e-2ca60655efd1" }, - "execution_count": 5, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -469,7 +476,7 @@ "id": "GDnAlQQNqC7z", "outputId": "e55dc40f-f5ad-48ba-f5ac-e1efb3f63a1c" }, - "execution_count": 6, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -1549,7 +1556,7 @@ "id": "hewvpLU_qHhO", "outputId": "64d05cd0-c71b-4337-def6-441f9fbd59f5" }, - "execution_count": 6, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -3635,7 +3642,7 @@ "id": "xuuCmTfoqMg9", "outputId": "7fe1ee42-6697-4443-9f9a-953a1cbac5fa" }, - "execution_count": 7, + "execution_count": null, "outputs": [ { "output_type": "stream",