Skip to content

Commit

Permalink
edit config & remove sonar from demo
Browse files Browse the repository at this point in the history
  • Loading branch information
JooZef315 committed Dec 2, 2024
1 parent 0e3f127 commit 9ac0f96
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 157 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,5 @@ core*

features_outputs
MOCK_dataset
*.pth
*.pth
*.mp4
45 changes: 0 additions & 45 deletions examples/inference/e2e.py

This file was deleted.

9 changes: 5 additions & 4 deletions examples/inference/e2e_demo.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from omegaconf import OmegaConf
from e2e import e2e_pipeline
from run import e2e_pipeline
from src.slt_inference.configs import RunConfig, FeatureExtractionConfig, TranslationConfig

# Define the translation configuration
translation_config = RunConfig(
video_path="D:/Pro/MLH/ctf/video.mp4",
video_path="./video.mp4",
verbose=True,
use_sonar=False,
feature_extraction=FeatureExtractionConfig(
pretrained_model_path="translation/signhiera_mock.pth",
pretrained_model_path="./signhiera_mock.pth",
),
translation=TranslationConfig(
base_model_name="google/t5-v1_1-large",
base_model_name="google-t5/t5-base",
tgt_langs=["eng_Latn", "fra_Latn"]
)
)
Expand Down
51 changes: 17 additions & 34 deletions examples/inference/run.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,19 @@
import os
import time
from pathlib import Path

import hydra
import torch
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf

from src.slt_inference.preprocessing import Preprocessor
from src.slt_inference.feature_extraction import FeatureExtractor
from src.slt_inference.translation import SonarTranslator, Translator
from src.slt_inference.translation import Translator
from src.slt_inference.util import FLORES200_ID2LANG, print_translations
from src.slt_inference.configs import FeatureExtractor, RunConfig
from src.slt_inference.preprocessing import Preprocessor

from omegaconf import DictConfig, OmegaConf


cs = ConfigStore.instance()
cs.store(name="run_config", node=RunConfig)


@hydra.main(config_name="run_config")
def main(config: DictConfig):

os.chdir(hydra.utils.get_original_cwd())

def e2e_pipeline(config: DictConfig):
"""
Main function to run the sign language translation pipeline.
"""
os.chdir(os.getcwd()) # Ensure correct working directory

if config.verbose:
print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}")

Expand All @@ -33,30 +22,24 @@ def main(config: DictConfig):

t0 = time.time()

# Initialize components
preprocessor = Preprocessor(config.preprocessing, device=device)
feature_extractor = FeatureExtractor(config.feature_extraction, device=device)

translator_cls = SonarTranslator if config.use_sonar else Translator
translator_cls = Translator
translator = translator_cls(config.translation, device=device, dtype=dtype)

t1 = time.time()

if config.verbose:
print(f"1. Model loading: {t1 - t0:.3f}s")


# Process input
inputs = preprocessor(Path(config.video_path))
extracted_features = feature_extractor(**inputs)

kwargs = {"tgt_langs": config.translation.tgt_langs} if config.use_sonar else {}
translations = translator(extracted_features, **kwargs)["translations"]

keys = (
[FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs]
if config.use_sonar
else range(config.translation.num_translations)
)
print_translations(keys, translations)

keys = [FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs] if config.use_sonar else range(config.translation.num_translations)

if __name__ == "__main__":
main()
# Output results
print_translations(keys, translations)
4 changes: 2 additions & 2 deletions examples/inference/src/slt_inference/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class FeatureExtractionConfig:
@dataclass
class TranslationConfig:
pretrained_model_path: str = "checkpoints/translator.pth"
tokenizer_path: str = "checkpoints/tokenizer"
base_model_name: str = "google/t5-v1_1-large"
tokenizer_path: str = "google-t5/t5-base"
base_model_name: str = "google-t5/t5-base"
feature_dim: int = 768
decoder_path: str = "checkpoints/sonar_decoder.pt"
decoder_spm_path: str = "checkpoints/decoder_sentencepiece.model"
Expand Down
73 changes: 2 additions & 71 deletions examples/inference/src/slt_inference/translation.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
import time
from pathlib import Path
from typing import Dict, Sequence, Tuple
from typing import Dict, Tuple

import torch
from fairseq2.models.sequence import SequenceBatch
from omegaconf import DictConfig
from torch import nn
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from .modeling.sign_t5 import SignT5Config, SignT5ForConditionalGeneration
from .modeling.sonar_decoder import load_decoder
from .modeling.sonar_t5_encoder import create_sonar_signt5_encoder_model
from .util import load_model


class Translator:
def __init__(
Expand Down Expand Up @@ -47,7 +40,6 @@ def _load_model_and_tokenizer(
)

print("Loading translator")
load_model(model, Path(self.config.pretrained_model_path))

model.eval()
model.to(self.device)
Expand Down Expand Up @@ -80,65 +72,4 @@ def __call__(self, features: torch.Tensor) -> Dict[str, str]:
if self.config.verbose:
print(f"4. Translation: {t1 - t0:.3f}s")

return {"translations": generated_text}


class SonarTranslator:
def __init__(
self,
config: DictConfig,
device: torch.device,
dtype: torch.dtype = torch.float32,
):
self.config = config
self.device = device
self.dtype = dtype

self.encoder, self.decoder = self._load_encoder_and_decoder()

def _load_encoder_and_decoder(
self,
) -> Tuple[nn.Module, nn.Module]:

encoder = create_sonar_signt5_encoder_model(
self.config, device=self.device, dtype=self.dtype
)
decoder = load_decoder(
self.config.decoder_path, self.config.decoder_spm_path, device=self.device
)
decoder.decoder.to(self.dtype)

print("Loading translator")
load_model(encoder, Path(self.config.pretrained_model_path), model_key="student")

encoder.eval()
decoder.decoder.eval()

return encoder, decoder

@torch.inference_mode()
def __call__(
self,
features: torch.Tensor,
tgt_langs: Sequence[str] = ["eng_Latn"],
) -> Dict[str, str]:

t0 = time.time()

features = features.to(self.device)

sentence_embedding = self.encoder(
SequenceBatch(seqs=features.unsqueeze(0), padding_mask=None)
).sentence_embeddings

translations = []
for tgt_lang in tgt_langs:
generated_text = self.decoder.decode(sentence_embedding, tgt_lang)[0]
translations.append(generated_text)

t1 = time.time()

if self.config.verbose:
print(f"4. Translation: {t1 - t0:.3f}s")

return {"translations": translations}
return {"translations": generated_text}

0 comments on commit 9ac0f96

Please sign in to comment.