Skip to content

Commit

Permalink
reuse evaluate_full & edit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JooZef315 committed Nov 22, 2024
1 parent 18f3d56 commit fc3fafc
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 28 deletions.
1 change: 1 addition & 0 deletions dlib
Submodule dlib added at 392409
4 changes: 2 additions & 2 deletions src/ssvp_slt/data/sign_features_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def _construct_loader(self):
break
try:
feature_name, length, label = line.strip().split("\t")
length = int(length)
except Exception:
length = int(float(length))
except Exception:
invalid += 1
continue

Expand Down
17 changes: 10 additions & 7 deletions tests/translation_demo.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
# # Copyright (c) Meta Platforms, Inc. and affiliates.
# # All rights reserved.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# # This source code is licensed under the license found in the
# # LICENSE file in the root directory of this source tree.
# # --------------------------------------------------------
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import sys
import os
from omegaconf import OmegaConf
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Now, import the modules
from translation.run_translation_module import Config, run_translation, ModelConfig, DataConfig, CommonConfig


# Define the translation configuration
translation_config = Config(
common=CommonConfig(
eval=True,
load_model="translation/signhiera_mock.pth"
load_model="translation/signhiera_mock.pth",
num_workers=0
),
data=DataConfig(
val_data_dir="features_outputs/0"
),
model=ModelConfig(
name_or_path="google-t5/t5-base",
feature_dim=1024
feature_dim=768
)
)

Expand Down
20 changes: 13 additions & 7 deletions tests/translation_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,27 @@
# --------------------------------------------------------

import unittest
from unittest.mock import MagicMock, patch
from translation.run_translation_module import TranslationModule, Config
from unittest.mock import patch

from omegaconf import OmegaConf
from translation.run_translation_module import run_translation, Config
import numpy as np

class TestTranslationModule(unittest.TestCase):
def setUp(self):
# Basic setup for testing TranslationModule
self.config = Config()
self.config.model.name_or_path = "translation/signhiera_mock.pth"
self.translator = TranslationModule(self.config)
self.config.data.val_data_dir = "features_outputs/0"
self.config.common.num_workers = 0
self.config.model.name_or_path = "google-t5/t5-base"
# Convert it to DictConfig
translation_dict_config = OmegaConf.structured(self.config)
self.translator = run_translation(translation_dict_config)

@patch("run_translation_module.TranslationModule.run_translation")
def test_translation_with_mock_features(self, mock_run_translation):
# Mock feature array that simulates extracted features
mock_features = np.random.rand(10, 512) # 10 timesteps, 512-dim features
mock_features = np.random.rand(10, 768) # 10 timesteps, 768-dim features

# Mock translation return value
mock_run_translation.return_value = "This is a test translation."
Expand All @@ -35,12 +41,12 @@ def test_translation_with_mock_features(self, mock_run_translation):

def test_configuration_loading(self):
# Ensure the configuration fields are loaded as expected
self.assertEqual(self.config.model.name_or_path, "translation/signhiera_mock.pth")
self.assertEqual(self.config.model.name_or_path, "google-t5/t5-base")

@patch("translation_module.TranslationModule.run_translation")
def test_translation_output_type(self, mock_run_translation):
# Mock feature array for translation
mock_features = np.random.rand(10, 512)
mock_features = np.random.rand(10, 768)

# Mock output for translation to simulate text output
mock_run_translation.return_value = "Translation successful."
Expand Down
48 changes: 38 additions & 10 deletions translation/main_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,62 @@

import datetime
import time
import json
import os

import evaluate as hf_evaluate
import ssvp_slt.util.misc as misc
import torch
from omegaconf import DictConfig, OmegaConf

from translation.engine_translation import evaluate, evaluate_full, train_one_epoch
from translation.utils_translation import (create_dataloader, create_model_and_tokenizer,
create_optimizer_and_loss_scaler)



def eval(cfg: DictConfig):
"""
Function to handle the evaluation of the model.
Function to handle the evaluation of the model on validation data only.
"""
device = torch.device(cfg.common.device)
model, tokenizer = create_model_and_tokenizer(cfg)

# Load model for finetuning or eval
# Load model for evaluation
if (misc.get_last_checkpoint(cfg) is None or cfg.common.eval) and cfg.common.load_model:
misc.load_model(model, cfg.common.load_model)

evaluate_full(cfg, model.to(device), tokenizer, device)
# Create validation data loader and evaluate the model
dataloader_val = create_dataloader("val", cfg, tokenizer)
val_stats, _, _ = evaluate(cfg, dataloader_val, model.to(device), tokenizer, device)
cfg.common.eval = True
cfg.common.dist_eval = False

# Optionally, print or log val_stats for evaluation feedback
print("Validation Stats:", val_stats)
dataloader_val = create_dataloader("val", cfg, tokenizer)
stats, predictions, references = evaluate(
cfg, dataloader_val, model.to(device), tokenizer, device
)
# Clear GPU for BLEURT computations if enabled
model.to("cpu")
del model
torch.cuda.empty_cache()

# Compute BLEURT if configured and add full BLEURT array to stats
if cfg.common.compute_bleurt:
bleurt_metric = hf_evaluate.load("bleurt", module_type="metric", config_name="BLEURT-20")
stats["bleurt"] = bleurt_metric.compute(
predictions=predictions, references=references
)["scores"]

# Save results
output_dir = cfg.common.output_dir
print(f"Validation results: {json.dumps(stats, ensure_ascii=False, indent=4)}")

with open(os.path.join(output_dir, "val_outputs.tsv"), "w", encoding="utf-8") as f:
f.write("Prediction\tReference\n")
for hyp, ref in zip(predictions, references):
f.write(f"{hyp}\t{ref}\n")

with open(os.path.join(output_dir, "val_results.json"), "w") as f:
json.dump(stats, f, ensure_ascii=False, indent=4)

print(f"Wrote validation outputs to {cfg.common.output_dir}")
print("Evaluation completed.")


def main(cfg: DictConfig):
Expand Down
4 changes: 2 additions & 2 deletions translation/run_translation_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from omegaconf import II, MISSING, DictConfig, OmegaConf
from ssvp_slt.util.misc import reformat_logger
# from main_translation import main as translate

from translation.main_translation import eval as translate

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,7 +42,7 @@ class CommonConfig:
@dataclass
class ModelConfig:
name_or_path: str = None
feature_dim: int = 512
feature_dim: int = 768
from_scratch: bool = False
dropout: float = 0.3
num_beams: int = 5
Expand Down

0 comments on commit fc3fafc

Please sign in to comment.