Skip to content

Commit

Permalink
Fix drectories structure discrepancy in internal tests and github CI
Browse files Browse the repository at this point in the history
Summary: There is a discrepancy in directories structure between our internal testing environment and github CI (i.e. pearl/ vs Pear/ directories). This results in some tests failing. This diff fixes the discrepancy.

Reviewed By: rodrigodesalvobraz

Differential Revision:
D56585773

Privacy Context Container: L1202097

fbshipit-source-id: 52ec06179298e09f362dedc31a9843663ec0a301
  • Loading branch information
Dmytro Korenkevych authored and facebook-github-bot committed Apr 26, 2024
1 parent e38f01b commit 9a8ea26
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
5 changes: 4 additions & 1 deletion pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import (
UCBExploration,
)
from pearl.test.utils import prefix_dir
from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import (
SLCBEnvironment,
)

DATA_PATH: str = "./utils/instantiations/environments/uci_datasets"

DATA_PATH: str = f"{prefix_dir()}utils/instantiations/environments/uci_datasets"


"""
Experiment config
Expand Down
9 changes: 6 additions & 3 deletions pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
from pearl.replay_buffers.contextual_bandits.discrete_contextual_bandit_replay_buffer import (
DiscreteContextualBanditReplayBuffer,
)
from pearl.test.utils import prefix_dir
from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import (
SLCBEnvironment,
)

from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

from pearl.utils.scripts.cb_benchmark.cb_benchmark_config import (
Expand Down Expand Up @@ -274,15 +274,18 @@ def run_cb_benchmarks(
"""

# Create UCI data directory if it does not already exist
uci_data_path = "./utils/instantiations/environments/uci_datasets"
uci_data_path: str = f"{prefix_dir()}utils/instantiations/environments/uci_datasets"
save_results_path: str = (
f"{prefix_dir()}utils/scripts/cb_benchmark/experiments_results"
)

if not os.path.exists(uci_data_path):
os.makedirs(uci_data_path)

# Download UCI data
download_uci_data(data_path=uci_data_path)

# Create folder for result if it does not already exist
save_results_path: str = "./utils/scripts/cb_benchmark/experiments_results"
if not os.path.exists(save_results_path):
os.makedirs(save_results_path)

Expand Down
12 changes: 10 additions & 2 deletions test/unit/test_tutorials/test_rec_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# pyre-strict


import os
import random
import unittest
from typing import List, Optional, Tuple
Expand Down Expand Up @@ -189,16 +190,23 @@ def setUp(self) -> None:
def test_rec_system(self) -> None:
# load environment
model = SequenceClassificationModel(100).to(device)
if os.path.exists("../Pearl"):
# Github CI
model_dir = "tutorials/single_item_recommender_system_example/"
else:
# internal Meta tests
model_dir = "pearl/tutorials/single_item_recommender_system_example/"

model.load_state_dict(
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
torch.load(
"pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt",
os.path.join(model_dir, "env_model_state_dict.pt"),
weights_only=True,
)
)
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
actions = torch.load(
"pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt",
os.path.join(model_dir, "news_embedding_small.pt"),
weights_only=True,
)
history_length = 8
Expand Down
15 changes: 15 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
This file contains helpers for unittest creation
"""

import os
from typing import Tuple

import torch
Expand Down Expand Up @@ -40,3 +41,17 @@ def create_normal_pdf_training_data(
) # corresponding pdf of mvn
y_corrupted = y + 0.01 * torch.randn(num_data_points) # noise corrupted targets
return x, y_corrupted


def prefix_dir() -> str:
"""
Returns the path needed to go from the current working directory while running
tests to the second-level Pearl packages, depending on the platform being run.
On the GitHub setup, this is "pearl/". In the internal Meta setup, this is "".
"""
if os.path.exists("../Pearl"):
# github CI
return "pearl/"
else:
# internal Meta tests
return ""

0 comments on commit 9a8ea26

Please sign in to comment.