Skip to content

Commit

Permalink
Move UCI data package away froms scripts
Browse files Browse the repository at this point in the history
Summary:
This diff moves the file `utils/scripts/cb_benchmark/cb_download_benchmarks.py` to `utils`.

This was needed because the `scripts` directory is not included in the `pearl` library per se, so it was also not included in the `pearl` Bento kernel, therefore not being available on the Bento execution of the Contextual Bandits tutorial notebook that depends on that kernel.

Because the function to download UCI data can be thought of as a general utility and not part of a script, it's been moved to `util`. It has also been renamed to reflect the fact that it is not part of a benchmark script.

Reviewed By: Yonathae

Differential Revision: D54049997

fbshipit-source-id: 62403dfe156f246e246a29b774eb7de197d9dc8b
rodrigodesalvobraz authored and facebook-github-bot committed Feb 22, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent df96edc commit af330a6
Showing 5 changed files with 32 additions and 20 deletions.
3 changes: 2 additions & 1 deletion pearl/utils/scripts/cb_benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -4,14 +4,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pearl.utils.uci_data import download_uci_data

from .cb_benchmark_config import (
return_neural_fastcb_config,
return_neural_lin_ts_config,
return_neural_lin_ucb_config,
return_neural_squarecb_config,
return_offline_eval_config,
)
from .cb_download_benchmarks import download_uci_data

from .run_cb_benchmarks import (
online_evaluation,
2 changes: 1 addition & 1 deletion pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
FastCBExploration,
SquareCBExploration,
)
from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import (
from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import ( # noqa E501
ThompsonSamplingExplorationLinear,
)
from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import (
2 changes: 1 addition & 1 deletion pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@
satimage_uci_dict,
yeast_uci_dict,
)
from pearl.utils.scripts.cb_benchmark.cb_download_benchmarks import download_uci_data
from pearl.utils.uci_data import download_uci_data


def online_evaluation(
File renamed without changes.
45 changes: 28 additions & 17 deletions tutorials/contextual_bandits/contextual_bandits_tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -182,7 +182,7 @@
"from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import FIFOOffPolicyReplayBuffer\n",
"from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning\n",
"from pearl.pearl_agent import PearlAgent\n",
"from pearl.utils.scripts.cb_benchmark.cb_download_benchmarks import download_uci_data\n",
"from pearl.utils.uci_data import download_uci_data\n",
"from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import (\n",
" SLCBEnvironment,\n",
")\n",
@@ -245,9 +245,6 @@
}
],
"source": [
"# load environment\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Download UCI dataset if doesn't exist\n",
"uci_data_path = \"./utils/instantiations/environments/uci_datasets\"\n",
"if not os.path.exists(uci_data_path):\n",
@@ -1176,29 +1173,43 @@
"custom": {
"cells": [],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"include_colab_link": true,
"provenance": []
"custom": {
"cells": [],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"include_colab_link": true,
"provenance": []
},
"fileHeader": "",
"fileUid": "4316417e-7688-45f2-a94f-24148bfc425e",
"isAdHoc": false,
"kernelspec": {
"display_name": "pearl (local)",
"language": "python",
"name": "pearl_local"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
},
"fileHeader": "",
"fileUid": "4316417e-7688-45f2-a94f-24148bfc425e",
"fileUid": "1158a851-91bb-437e-a391-aba92448f600",
"indentAmount": 2,
"isAdHoc": false,
"kernelspec": {
"display_name": "pearl (local)",
"language": "python",
"name": "pearl_local"
},
"language_info": {
"name": "python"
"name": "plaintext"
}
},
"nbformat": 4,
"nbformat_minor": 2
},
"fileHeader": "",
"fileUid": "1158a851-91bb-437e-a391-aba92448f600",
"fileUid": "06710d6d-2a6b-4a80-a1f7-31b8d3b7c146",
"indentAmount": 2,
"isAdHoc": false,
"language_info": {

0 comments on commit af330a6

Please sign in to comment.