Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ci failing tests fix #83

Closed
wants to merge 12 commits into from
6 changes: 4 additions & 2 deletions pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import (
SLCBEnvironment,
)

DATA_PATH: str = "./utils/instantiations/environments/uci_datasets"
if os.path.exists("../Pearl"):
DATA_PATH: str = "pearl/utils/instantiations/environments/uci_datasets"
else:
DATA_PATH: str = "./utils/instantiations/environments/uci_datasets"

"""
Experiment config
Expand Down
12 changes: 11 additions & 1 deletion pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,24 @@ def run_cb_benchmarks(

# Create UCI data directory if it does not already exist
uci_data_path = "./utils/instantiations/environments/uci_datasets"
print("current_dir", os.getcwd())
print("dirs list", os.listdir())

if os.path.exists("../Pearl"):
uci_data_path = "pearl/utils/instantiations/environments/uci_datasets"
save_results_path: str = "pearl/utils/scripts/cb_benchmark/experiments_results"
print("dir exists", os.path.exists("pearl/utils/instantiations/environments/"))
else:
uci_data_path = "utils/instantiations/environments/uci_datasets"
save_results_path: str = "utils/scripts/cb_benchmark/experiments_results"
if not os.path.exists(uci_data_path):
os.makedirs(uci_data_path)

# Download UCI data
download_uci_data(data_path=uci_data_path)

# Create folder for result if it does not already exist
save_results_path: str = "./utils/scripts/cb_benchmark/experiments_results"

if not os.path.exists(save_results_path):
os.makedirs(save_results_path)

Expand Down
17 changes: 14 additions & 3 deletions test/unit/test_tutorials/test_rec_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import unittest
from typing import List, Optional, Tuple
import os

import numpy as np

Expand Down Expand Up @@ -66,7 +67,7 @@
This environment's underlying model was pre-trained using the MIND dataset (Wu et al. 2020).
The model is defined by class `SequenceClassificationModel` below.
The model's state dict is saved in
tutorials/single_item_recommender_system_example/env_model_state_dict.pt
tutorials/single_item_recommender_system_example/env_model_state_dict

Each data point is:
- A history of impressions clicked by a user
Expand Down Expand Up @@ -189,16 +190,26 @@ def setUp(self) -> None:
def test_rec_system(self) -> None:
# load environment
model = SequenceClassificationModel(100).to(device)
print("current dir", os.getcwd())
print("dirs list", os.listdir())
if os.path.exists("../Pearl"):
# Github CI tests
print("Pearl directory exists")
model_dir = "tutorials/single_item_recommender_system_example/"
else:
# Meta internal tests
print("pearl directory exists")
model_dir = "pearl/tutorials/single_item_recommender_system_example/"
model.load_state_dict(
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
torch.load(
"pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt",
os.path.join(model_dir, "env_model_state_dict.pt"),
weights_only=True,
)
)
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
actions = torch.load(
"pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt",
os.path.join(model_dir, "news_embedding_small.pt"),
weights_only=True,
)
history_length = 8
Expand Down
Loading