Skip to content

Commit

Permalink
edit test to download encoder
Browse files Browse the repository at this point in the history
Signed-off-by: Joozef315 <[email protected]>
  • Loading branch information
JooZef315 committed Nov 6, 2024
1 parent 5b55d99 commit c051573
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,7 @@ dmypy.json
# misc
*.mp4
sweep*/
core*
core*

features_outputs
*.pth
Empty file removed fairseq-sl/pip
Empty file.
4 changes: 3 additions & 1 deletion tests/feature_extraction_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import unittest
from translation.feature_extraction_module import FeatureExtractionConfig, FeatureExtractionModule, LauncherConfig
from unittest.mock import patch, MagicMock
from utils.download_model import get_model_path

class TestFeatureExtractionModule(unittest.TestCase):
def setUp(self):
# Mock the configuration for the FeatureExtractionModule
model_path = get_model_path()
self.config = FeatureExtractionConfig(
data_dir="MOCK_dataset",
pretrained_model_path="signhiera_mock.pth",
pretrained_model_path=model_path,
launcher=LauncherConfig(cluster="local")
)
self.module = FeatureExtractionModule(self.config)
Expand Down
16 changes: 16 additions & 0 deletions utils/download_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
import wget

def get_model_path():
model_path = 'signhiera_mock.pth'
url = 'https://dl.fbaipublicfiles.com/SONAR/asl/signhiera_mock.pth'

# Check if the model file exists
if os.path.exists(model_path):
print(f"Model already exists at: {model_path}")
else:
print("Model not found, downloading...")
filename = wget.download(url, model_path)
print(f"Downloaded model to: {filename}")

return model_path

0 comments on commit c051573

Please sign in to comment.