diff --git a/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py b/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py index f822005a..cb8d10ed 100644 --- a/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py +++ b/pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py @@ -31,8 +31,10 @@ from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import ( SLCBEnvironment, ) - -DATA_PATH: str = "./utils/instantiations/environments/uci_datasets" +if os.path.exists("../Pearl"): + DATA_PATH: str = "pearl/utils/instantiations/environments/uci_datasets" +else: + DATA_PATH: str = "./utils/instantiations/environments/uci_datasets" """ Experiment config diff --git a/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py b/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py index 5b326423..0049e427 100644 --- a/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py +++ b/pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py @@ -275,6 +275,16 @@ def run_cb_benchmarks( # Create UCI data directory if it does not already exist uci_data_path = "./utils/instantiations/environments/uci_datasets" + print("current_dir", os.getcwd()) + print("dirs list", os.listdir()) + + if os.path.exists("../Pearl"): + uci_data_path = "pearl/utils/instantiations/environments/uci_datasets" + save_results_path: str = "pearl/utils/scripts/cb_benchmark/experiments_results" + print("dir exists", os.path.exists("pearl/utils/instantiations/environments/")) + else: + uci_data_path = "utils/instantiations/environments/uci_datasets" + save_results_path: str = "utils/scripts/cb_benchmark/experiments_results" if not os.path.exists(uci_data_path): os.makedirs(uci_data_path) @@ -282,7 +292,7 @@ def run_cb_benchmarks( download_uci_data(data_path=uci_data_path) # Create folder for result if it does not already exist - save_results_path: str = "./utils/scripts/cb_benchmark/experiments_results" + if not os.path.exists(save_results_path): os.makedirs(save_results_path) diff --git a/test/unit/test_tutorials/test_rec_system.py b/test/unit/test_tutorials/test_rec_system.py index 65b5ad84..4f66efdf 100644 --- a/test/unit/test_tutorials/test_rec_system.py +++ b/test/unit/test_tutorials/test_rec_system.py @@ -6,6 +6,7 @@ import random import unittest from typing import List, Optional, Tuple +import os import numpy as np @@ -66,7 +67,7 @@ This environment's underlying model was pre-trained using the MIND dataset (Wu et al. 2020). The model is defined by class `SequenceClassificationModel` below. The model's state dict is saved in -tutorials/single_item_recommender_system_example/env_model_state_dict.pt +tutorials/single_item_recommender_system_example/env_model_state_dict Each data point is: - A history of impressions clicked by a user @@ -189,16 +190,26 @@ def setUp(self) -> None: def test_rec_system(self) -> None: # load environment model = SequenceClassificationModel(100).to(device) + print("current dir", os.getcwd()) + print("dirs list", os.listdir()) + if os.path.exists("../Pearl"): + # Github CI tests + print("Pearl directory exists") + model_dir = "tutorials/single_item_recommender_system_example/" + else: + # Meta internal tests + print("pearl directory exists") + model_dir = "pearl/tutorials/single_item_recommender_system_example/" model.load_state_dict( # Note: in the tutorial the directory "pearl" must be replaced by "Pearl" torch.load( - "pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt", + os.path.join(model_dir, "env_model_state_dict.pt"), weights_only=True, ) ) # Note: in the tutorial the directory "pearl" must be replaced by "Pearl" actions = torch.load( - "pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt", + os.path.join(model_dir, "news_embedding_small.pt"), weights_only=True, ) history_length = 8