Skip to content

Commit

Permalink
Update test_rec_system.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dkorenkevych authored Apr 25, 2024
1 parent 2ebb572 commit 13483e1
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions test/unit/test_tutorials/test_rec_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import unittest
from typing import List, Optional, Tuple
import os

import numpy as np

Expand Down Expand Up @@ -189,16 +190,25 @@ def setUp(self) -> None:
def test_rec_system(self) -> None:
# load environment
model = SequenceClassificationModel(100).to(device)
print("current dir", os.getcwd())
if os.path.exists("pearl"):
# Meta internal tests
print("pearl directory exists)
model_dir = "pearl/tutorials/single_item_recommender_system_example/"
else:
# Github CI tests
print("Pearl directory exists)
model_dir = "Pearl/tutorials/single_item_recommender_system_example/"
model.load_state_dict(
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
torch.load(
"pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt",
os.path.join("model_dir", "env_model_state_dict.pt"),
weights_only=True,
)
)
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
actions = torch.load(
"pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt",
os.path.join(model_dir, "news_embedding_small.pt"),
weights_only=True,
)
history_length = 8
Expand Down

0 comments on commit 13483e1

Please sign in to comment.