diff --git a/test/unit/test_tutorials/test_rec_system.py b/test/unit/test_tutorials/test_rec_system.py index 65b5ad84..a813a78f 100644 --- a/test/unit/test_tutorials/test_rec_system.py +++ b/test/unit/test_tutorials/test_rec_system.py @@ -6,6 +6,7 @@ import random import unittest from typing import List, Optional, Tuple +import os import numpy as np @@ -189,16 +190,25 @@ def setUp(self) -> None: def test_rec_system(self) -> None: # load environment model = SequenceClassificationModel(100).to(device) + print("current dir", os.getcwd()) + if os.path.exists("pearl"): + # Meta internal tests + print("pearl directory exists) + model_dir = "pearl/tutorials/single_item_recommender_system_example/" + else: + # Github CI tests + print("Pearl directory exists) + model_dir = "Pearl/tutorials/single_item_recommender_system_example/" model.load_state_dict( # Note: in the tutorial the directory "pearl" must be replaced by "Pearl" torch.load( - "pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt", + os.path.join("model_dir", "env_model_state_dict.pt"), weights_only=True, ) ) # Note: in the tutorial the directory "pearl" must be replaced by "Pearl" actions = torch.load( - "pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt", + os.path.join(model_dir, "news_embedding_small.pt"), weights_only=True, ) history_length = 8