From af330a62c9a9cf21489422c86039e6e904f3f950 Mon Sep 17 00:00:00 2001 From: Rodrigo de Salvo Braz Date: Thu, 22 Feb 2024 12:13:47 -0800 Subject: [PATCH] Move UCI data package away froms scripts Summary: This diff moves the file `utils/scripts/cb_benchmark/cb_download_benchmarks.py` to `utils`. This was needed because the `scripts` directory is not included in the `pearl` library per se, so it was also not included in the `pearl` Bento kernel, therefore not being available on the Bento execution of the Contextual Bandits tutorial notebook that depends on that kernel. Because the function to download UCI data can be thought of as a general utility and not part of a script, it's been moved to `util`. It has also been renamed to reflect the fact that it is not part of a benchmark script. Reviewed By: Yonathae Differential Revision: D54049997 fbshipit-source-id: 62403dfe156f246e246a29b774eb7de197d9dc8b --- pearl/utils/scripts/cb_benchmark/__init__.py | 3 +- .../cb_benchmark/cb_benchmark_config.py | 2 +- .../scripts/cb_benchmark/run_cb_benchmarks.py | 2 +- .../cb_download_benchmarks.py => uci_data.py} | 0 .../contextual_bandits_tutorial.ipynb | 45 ++++++++++++------- 5 files changed, 32 insertions(+), 20 deletions(-) rename pearl/utils/{scripts/cb_benchmark/cb_download_benchmarks.py => uci_data.py} (100%) diff --git a/pearl/utils/scripts/cb_benchmark/__init__.py b/pearl/utils/scripts/cb_benchmark/__init__.py index fd71d7ba..47867cd2 100644 --- a/pearl/utils/scripts/cb_benchmark/__init__.py +++ b/pearl/utils/scripts/cb_benchmark/__init__.py @@ -4,6 +4,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from pearl.utils.uci_data import download_uci_data + from .cb_benchmark_config import ( return_neural_fastcb_config, return_neural_lin_ts_config, @@ -11,7 +13,6 @@ return_neural_squarecb_config, return_offline_eval_config, ) -from .cb_download_benchmarks import download_uci_data from .run_cb_benchmarks import ( online_evaluation, diff --git a/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py b/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py index 9dcc36c2..4e975c46 100644 --- a/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py +++ b/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py @@ -20,7 +20,7 @@ FastCBExploration, SquareCBExploration, ) -from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import ( +from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import ( # noqa E501 ThompsonSamplingExplorationLinear, ) from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import ( diff --git a/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py b/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py index 863fff1c..74f11a4a 100644 --- a/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py +++ b/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py @@ -44,7 +44,7 @@ satimage_uci_dict, yeast_uci_dict, ) -from pearl.utils.scripts.cb_benchmark.cb_download_benchmarks import download_uci_data +from pearl.utils.uci_data import download_uci_data def online_evaluation( diff --git a/pearl/utils/scripts/cb_benchmark/cb_download_benchmarks.py b/pearl/utils/uci_data.py similarity index 100% rename from pearl/utils/scripts/cb_benchmark/cb_download_benchmarks.py rename to pearl/utils/uci_data.py diff --git a/tutorials/contextual_bandits/contextual_bandits_tutorial.ipynb b/tutorials/contextual_bandits/contextual_bandits_tutorial.ipynb index 56e2c3c7..e2850b97 100644 --- a/tutorials/contextual_bandits/contextual_bandits_tutorial.ipynb +++ b/tutorials/contextual_bandits/contextual_bandits_tutorial.ipynb @@ -182,7 +182,7 @@ "from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import FIFOOffPolicyReplayBuffer\n", "from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning\n", "from pearl.pearl_agent import PearlAgent\n", - "from pearl.utils.scripts.cb_benchmark.cb_download_benchmarks import download_uci_data\n", + "from pearl.utils.uci_data import download_uci_data\n", "from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import (\n", " SLCBEnvironment,\n", ")\n", @@ -245,9 +245,6 @@ } ], "source": [ - "# load environment\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", "# Download UCI dataset if doesn't exist\n", "uci_data_path = \"./utils/instantiations/environments/uci_datasets\"\n", "if not os.path.exists(uci_data_path):\n", @@ -1176,29 +1173,43 @@ "custom": { "cells": [], "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "include_colab_link": true, - "provenance": [] + "custom": { + "cells": [], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "include_colab_link": true, + "provenance": [] + }, + "fileHeader": "", + "fileUid": "4316417e-7688-45f2-a94f-24148bfc425e", + "isAdHoc": false, + "kernelspec": { + "display_name": "pearl (local)", + "language": "python", + "name": "pearl_local" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 }, "fileHeader": "", - "fileUid": "4316417e-7688-45f2-a94f-24148bfc425e", + "fileUid": "1158a851-91bb-437e-a391-aba92448f600", + "indentAmount": 2, "isAdHoc": false, - "kernelspec": { - "display_name": "pearl (local)", - "language": "python", - "name": "pearl_local" - }, "language_info": { - "name": "python" + "name": "plaintext" } }, "nbformat": 4, "nbformat_minor": 2 }, "fileHeader": "", - "fileUid": "1158a851-91bb-437e-a391-aba92448f600", + "fileUid": "06710d6d-2a6b-4a80-a1f7-31b8d3b7c146", "indentAmount": 2, "isAdHoc": false, "language_info": {