From fc3fafc92dc3a06b47ba941bee49c7400200a2c9 Mon Sep 17 00:00:00 2001
From: Joozef315
Date: Fri, 22 Nov 2024 20:18:59 +0200
Subject: [PATCH] reuse evaluate_full & edit tests

---
 dlib                                       |  1 +
 src/ssvp_slt/data/sign_features_dataset.py |  4 +-
 tests/translation_demo.py                  | 17 ++++----
 tests/translation_module_test.py           | 20 +++++----
 translation/main_translation.py            | 48 +++++++++++++++++-----
 translation/run_translation_module.py      |  4 +-
 6 files changed, 66 insertions(+), 28 deletions(-)
 create mode 160000 dlib

diff --git a/dlib b/dlib
new file mode 160000
index 0000000..3924095
--- /dev/null
+++ b/dlib
@@ -0,0 +1 @@
+Subproject commit 39240959fadbe7a7d1f6f132e35a425f6359c4c4

diff --git a/src/ssvp_slt/data/sign_features_dataset.py b/src/ssvp_slt/data/sign_features_dataset.py
index e16d2f5..dbefedd 100644
--- a/src/ssvp_slt/data/sign_features_dataset.py
+++ b/src/ssvp_slt/data/sign_features_dataset.py
@@ -74,8 +74,8 @@ def _construct_loader(self):
                     break
                 try:
                     feature_name, length, label = line.strip().split("\t")
-                    length = int(length)
-                except Exception:
+                    length = int(float(length))
+                except Exception:
                     invalid += 1
                     continue
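Note on the int(float(length)) change above: some feature-extraction manifests serialize the length column as a float string (e.g. "128.0"), which plain int() rejects with a ValueError. A minimal sketch of the parsing behavior; the manifest line here is invented for illustration:

    # Hypothetical TSV manifest line of the (feature_name, length, label)
    # shape that _construct_loader reads.
    line = "clip_0001\t128.0\thello world"
    feature_name, length, label = line.strip().split("\t")

    # int("128.0") raises ValueError; routing through float() first succeeds.
    length = int(float(length))
    assert length == 128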
diff --git a/tests/translation_demo.py b/tests/translation_demo.py
index 31fcbe8..77d5299 100644
--- a/tests/translation_demo.py
+++ b/tests/translation_demo.py
@@ -1,31 +1,34 @@
-# # Copyright (c) Meta Platforms, Inc. and affiliates.
-# # All rights reserved.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.

-# # This source code is licensed under the license found in the
-# # LICENSE file in the root directory of this source tree.
-# # --------------------------------------------------------
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------

 import sys
 import os
 from omegaconf import OmegaConf

+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

 # Now, import the modules
 from translation.run_translation_module import Config, run_translation, ModelConfig, DataConfig, CommonConfig

+
 # Define the translation configuration
 translation_config = Config(
     common=CommonConfig(
         eval=True,
-        load_model="translation/signhiera_mock.pth"
+        load_model="translation/signhiera_mock.pth",
+        num_workers=0
     ),
     data=DataConfig(
         val_data_dir="features_outputs/0"
     ),
     model=ModelConfig(
         name_or_path="google-t5/t5-base",
-        feature_dim=1024
+        feature_dim=768
     )
 )

diff --git a/tests/translation_module_test.py b/tests/translation_module_test.py
index 7fffe5f..c87eccd 100644
--- a/tests/translation_module_test.py
+++ b/tests/translation_module_test.py
@@ -6,21 +6,27 @@
 # --------------------------------------------------------

 import unittest
-from unittest.mock import MagicMock, patch
-from translation.run_translation_module import TranslationModule, Config
+from unittest.mock import patch
+
+from omegaconf import OmegaConf
+from translation.run_translation_module import run_translation, Config
 import numpy as np

 class TestTranslationModule(unittest.TestCase):
     def setUp(self):
         # Basic setup for testing TranslationModule
         self.config = Config()
-        self.config.model.name_or_path = "translation/signhiera_mock.pth"
-        self.translator = TranslationModule(self.config)
+        self.config.data.val_data_dir = "features_outputs/0"
+        self.config.common.num_workers = 0
+        self.config.model.name_or_path = "google-t5/t5-base"
+        # Convert it to DictConfig
+        translation_dict_config = OmegaConf.structured(self.config)
+        self.translator = run_translation(translation_dict_config)

     @patch("run_translation_module.TranslationModule.run_translation")
     def test_translation_with_mock_features(self, mock_run_translation):
         # Mock feature array that simulates extracted features
-        mock_features = np.random.rand(10, 512)  # 10 timesteps, 512-dim features
+        mock_features = np.random.rand(10, 768)  # 10 timesteps, 768-dim features

         # Mock translation return value
         mock_run_translation.return_value = "This is a test translation."
@@ -35,12 +41,12 @@ def test_translation_with_mock_features(self, mock_run_translation):

     def test_configuration_loading(self):
         # Ensure the configuration fields are loaded as expected
-        self.assertEqual(self.config.model.name_or_path, "translation/signhiera_mock.pth")
+        self.assertEqual(self.config.model.name_or_path, "google-t5/t5-base")

     @patch("translation_module.TranslationModule.run_translation")
     def test_translation_output_type(self, mock_run_translation):
         # Mock feature array for translation
-        mock_features = np.random.rand(10, 512)
+        mock_features = np.random.rand(10, 768)

         # Mock output for translation to simulate text output
         mock_run_translation.return_value = "Translation successful."
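Note on the OmegaConf.structured(self.config) call in setUp above: it wraps the plain Config dataclass in a type-checked DictConfig, which is the shape run_translation expects. A standalone sketch with cut-down stand-in dataclasses (not the repo's real field lists):

    from dataclasses import dataclass, field
    from omegaconf import OmegaConf

    @dataclass
    class ModelConfig:
        name_or_path: str = "google-t5/t5-base"
        feature_dim: int = 768

    @dataclass
    class Config:
        model: ModelConfig = field(default_factory=ModelConfig)

    cfg = OmegaConf.structured(Config())  # DictConfig backed by the dataclass types
    cfg.model.feature_dim = 1024          # fine: matches the declared int type
    # cfg.model.feature_dim = "wide"      # would raise a ValidationError
    print(OmegaConf.to_yaml(cfg))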
diff --git a/translation/main_translation.py b/translation/main_translation.py
index 203aace..8eb6f08 100644
--- a/translation/main_translation.py
+++ b/translation/main_translation.py
@@ -7,7 +7,10 @@

 import datetime
 import time
+import json
+import os

+import evaluate as hf_evaluate
 import ssvp_slt.util.misc as misc
 import torch
 from omegaconf import DictConfig, OmegaConf
@@ -15,26 +18,51 @@
 from translation.engine_translation import evaluate, evaluate_full, train_one_epoch
 from translation.utils_translation import (create_dataloader, create_model_and_tokenizer,
                                            create_optimizer_and_loss_scaler)
-
-
+
 def eval(cfg: DictConfig):
     """
-    Function to handle the evaluation of the model.
+    Function to handle the evaluation of the model on validation data only.
     """
     device = torch.device(cfg.common.device)

     model, tokenizer = create_model_and_tokenizer(cfg)

-    # Load model for finetuning or eval
+    # Load model for evaluation
     if (misc.get_last_checkpoint(cfg) is None or cfg.common.eval) and cfg.common.load_model:
         misc.load_model(model, cfg.common.load_model)

-        evaluate_full(cfg, model.to(device), tokenizer, device)
-    # Create validation data loader and evaluate the model
-    dataloader_val = create_dataloader("val", cfg, tokenizer)
-    val_stats, _, _ = evaluate(cfg, dataloader_val, model.to(device), tokenizer, device)
+    cfg.common.eval = True
+    cfg.common.dist_eval = False

-    # Optionally, print or log val_stats for evaluation feedback
-    print("Validation Stats:", val_stats)
+    dataloader_val = create_dataloader("val", cfg, tokenizer)
+    stats, predictions, references = evaluate(
+        cfg, dataloader_val, model.to(device), tokenizer, device
+    )
+    # Clear GPU for BLEURT computations if enabled
+    model.to("cpu")
+    del model
+    torch.cuda.empty_cache()
+
+    # Compute BLEURT if configured and add full BLEURT array to stats
+    if cfg.common.compute_bleurt:
+        bleurt_metric = hf_evaluate.load("bleurt", module_type="metric", config_name="BLEURT-20")
+        stats["bleurt"] = bleurt_metric.compute(
+            predictions=predictions, references=references
+        )["scores"]
+
+    # Save results
+    output_dir = cfg.common.output_dir
+    print(f"Validation results: {json.dumps(stats, ensure_ascii=False, indent=4)}")
+
+    with open(os.path.join(output_dir, "val_outputs.tsv"), "w", encoding="utf-8") as f:
+        f.write("Prediction\tReference\n")
+        for hyp, ref in zip(predictions, references):
+            f.write(f"{hyp}\t{ref}\n")
+
+    with open(os.path.join(output_dir, "val_results.json"), "w") as f:
+        json.dump(stats, f, ensure_ascii=False, indent=4)
+
+    print(f"Wrote validation outputs to {cfg.common.output_dir}")
+    print("Evaluation completed.")


 def main(cfg: DictConfig):

diff --git a/translation/run_translation_module.py b/translation/run_translation_module.py
index 4dc8ad8..57e0b9a 100644
--- a/translation/run_translation_module.py
+++ b/translation/run_translation_module.py
@@ -13,7 +13,7 @@
 from omegaconf import II, MISSING, DictConfig, OmegaConf
 from ssvp_slt.util.misc import reformat_logger

-# from main_translation import main as translate
+from translation.main_translation import eval as translate

 logger = logging.getLogger(__name__)

@@ -42,7 +42,7 @@ class CommonConfig:
 @dataclass
 class ModelConfig:
     name_or_path: str = None
-    feature_dim: int = 512
+    feature_dim: int = 768
     from_scratch: bool = False
     dropout: float = 0.3
     num_beams: int = 5
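Note on the BLEURT block added to eval(): the call pattern below mirrors the patch and can be run standalone. The predictions and references are placeholders; loading BLEURT-20 requires the BLEURT package and TensorFlow, and the checkpoint is downloaded on first use:

    import evaluate as hf_evaluate

    # Placeholder outputs; in eval() these come from evaluate(...).
    predictions = ["the cat sat on the mat"]
    references = ["a cat was sitting on the mat"]

    bleurt_metric = hf_evaluate.load("bleurt", module_type="metric", config_name="BLEURT-20")
    scores = bleurt_metric.compute(predictions=predictions, references=references)["scores"]
    print(scores)  # one float per prediction/reference pair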