Add e2e script for video translation #10

Merged: 8 commits, Dec 4, 2024

6 changes: 6 additions & 0 deletions inference/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
80 changes: 80 additions & 0 deletions inference/configs.py
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

from dataclasses import dataclass, field
from typing import List, Optional, Tuple

from omegaconf import II, MISSING


@dataclass
class PreprocessingConfig:
# Bounding box expansion and threshold parameters
up_exp: float = 1.0
down_exp: float = 3.0
left_exp: float = 1.5
right_exp: float = 1.5
iou_threshold: float = 0.2
num_ratio_threshold: float = 0.5

# Dlib detection parameters
hog_detector: bool = True
detector_path: Optional[str] = "checkpoints/detector.dat"
detection_sampling_rate: int = 16
detection_downsample: bool = True

# Sliding window sampling parameters
num_frames: int = 128
feature_extraction_stride: int = 64
sampling_rate: int = 2
target_fps: int = 25

# Cropping and Normalization
target_size: int = 224
mean: Tuple[float, float, float] = (0.45, 0.45, 0.45)
std: Tuple[float, float, float] = (0.225, 0.225, 0.225)

debug: bool = False
verbose: bool = II("verbose")

def __post_init__(self):
if self.hog_detector is False:
assert (
self.detector_path is not None
), "detector_path must be provded if `hog_detector=False`"


@dataclass
class FeatureExtractionConfig:
pretrained_model_path: str = "checkpoints/feature_extractor.pth"
model_name: str = "hiera_base_128x224"
max_batch_size: int = 2
fp16: bool = True
verbose: bool = II("verbose")


@dataclass
class TranslationConfig:
pretrained_model_path: str = "checkpoints/translator.pth"
tokenizer_path: str = "google-t5/t5-base"
base_model_name: str = "google-t5/t5-base"
feature_dim: int = 768
num_translations: int = 5
do_sample: bool = False
num_beams: int = 5
temperature: float = 1.0
max_length: int = 128
verbose: bool = II("verbose")


@dataclass
class RunConfig:
video_path: str = MISSING
preprocessing: PreprocessingConfig = field(default_factory=PreprocessingConfig)
feature_extraction: FeatureExtractionConfig = field(default_factory=FeatureExtractionConfig)
translation: TranslationConfig = field(default_factory=TranslationConfig)
verbose: bool = False
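
For orientation, here is a minimal sketch (not part of the diff) of how these structured configs compose under OmegaConf. The dotlist override is an assumed usage pattern, not something the PR prescribes; the point is that the `II("verbose")` interpolations in the sub-configs resolve against the top-level `verbose` flag:

from omegaconf import OmegaConf

base = OmegaConf.structured(RunConfig)
overrides = OmegaConf.from_dotlist(["video_path=clip.mp4", "verbose=true"])
config = OmegaConf.merge(base, overrides)

# Sub-config `verbose` fields interpolate to the root flag
assert config.preprocessing.verbose is True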
82 changes: 82 additions & 0 deletions inference/feature_extraction.py
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import time
from pathlib import Path
from typing import Any, Generator

import torch
import torch.nn as nn
from omegaconf import DictConfig

from ssvp_slt.modeling import sign_hiera
from ssvp_slt.modeling.sign_hiera import SignHiera
from ssvp_slt.util.misc import load_model


def shard_generator(data: Any, shard_size: int) -> Generator[Any, None, None]:
"""Yields successive shards of at most `shard_size` elements from `data`."""
for i in range(0, len(data), shard_size):
yield data[i : i + shard_size]


class FeatureExtractor:
def __init__(self, config: DictConfig, device: torch.device):
self.config = config
self.device = device

self.model = self._load_model()

def _load_model(self) -> SignHiera:
"""
Loads a pretrained SignHiera model for feature extraction and moves it to the specified device
"""

model = sign_hiera.__dict__[self.config.model_name](pretrained=False, strict=False)

print("Loading feature extractor")
load_model(model, Path(self.config.pretrained_model_path))

model.head = nn.Identity()
model.eval()
model.to(self.device)

return model

@torch.inference_mode()
def __call__(self, frames: torch.Tensor, padding: torch.Tensor) -> torch.Tensor:

t0 = time.time()

frames = frames.to(self.device)
padding = padding.to(self.device)

if len(frames) > self.config.max_batch_size:

shard_outputs = []

frame_shards = shard_generator(frames, self.config.max_batch_size)
padding_shards = shard_generator(padding, self.config.max_batch_size)

for frames_shard, padding_shard in zip(frame_shards, padding_shards):
with torch.cuda.amp.autocast(enabled=self.config.fp16):
shard_output = self.model.extract_features(frames_shard, padding=padding_shard)
if len(shard_output.shape) == 1:
shard_output = shard_output.unsqueeze(0)

shard_outputs.append(shard_output)

outputs = torch.concatenate(shard_outputs, dim=0)
else:
with torch.cuda.amp.autocast(enabled=self.config.fp16):
outputs = self.model.extract_features(frames, padding=padding)

t1 = time.time()

if self.config.verbose:
print(f"3. Feature extraction: {t1 - t0:.3f}s")

return outputs
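
Hypothetical standalone usage of FeatureExtractor (not part of the diff), assuming the PR's `configs` module is importable and the default checkpoint exists locally. The tensor layout is an assumption inferred from the config defaults (num_frames=128, target_size=224, RGB) and the model name hiera_base_128x224:

import torch
from omegaconf import OmegaConf

run_cfg = OmegaConf.structured(RunConfig(video_path="clip.mp4"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extractor = FeatureExtractor(run_cfg.feature_extraction, device=device)

frames = torch.randn(4, 3, 128, 224, 224)  # (clips, channels, frames, H, W), assumed layout
padding = torch.zeros(4, dtype=torch.long)  # per-clip padding, assumed format
features = extractor(frames, padding)  # sharded internally into batches of max_batch_size=2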
76 changes: 76 additions & 0 deletions inference/run.py
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import time
from typing import Sequence

import torch
from omegaconf import DictConfig, OmegaConf

# TODO: add Preprocessor class
from configs import FeatureExtractionConfig, RunConfig, TranslationConfig
from feature_extraction import FeatureExtractor
from translation import Translator

def print_translations(keys: Sequence[str], translations: Sequence[str]) -> None:
assert len(keys) == len(translations)
print(f"\nTranslations:\n{'-'*50}")
for key, translation in zip(keys, translations):
print(f'{key}: "{translation}"')
print(f"{'-' * 50}\n")

def run_pipeline(config: DictConfig):
"""
Main function to run the sign language translation pipeline.
"""

if config.verbose:
print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if config.feature_extraction.fp16 else torch.float32

t0 = time.time()

# Initialize components
feature_extractor = FeatureExtractor(config.feature_extraction, device=device)
translator = Translator(config.translation, device=device, dtype=dtype)

t1 = time.time()
if config.verbose:
print(f"1. Model loading: {t1 - t0:.3f}s")

# TODO: perform preprocessing of config.video_path (Preprocessor class not yet
# implemented); `inputs` is expected to be a dict with "frames" and "padding"
# tensors matching the signature of FeatureExtractor.__call__
extracted_features = feature_extractor(**inputs)

translations = translator(extracted_features)["translations"]

keys = [str(i) for i in range(config.translation.num_translations)]

# Output results
print_translations(keys, translations)


# Define the run configuration
run_config = RunConfig(
video_path="path/to/your/video.mp4",
verbose=True,
feature_extraction=FeatureExtractionConfig(
pretrained_model_path="path/to/your/model.pth",
),
translation=TranslationConfig(
base_model_name="google-t5/t5-base",
),
)

# Convert it to a DictConfig
run_dict_config = OmegaConf.structured(run_config)

# Run the pipeline with the provided parameters
run_pipeline(run_dict_config)
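
As written, the pipeline cannot run end to end: `inputs` in run_pipeline stays undefined until the Preprocessor lands. Below is a sketch of the interface the preprocessing step would need to satisfy, inferred from FeatureExtractor.__call__; the class and its signature are hypothetical, nothing in the PR defines them:

from typing import Dict

import torch

class Preprocessor:  # hypothetical, not in this PR
    def __init__(self, config: "PreprocessingConfig", device: torch.device):
        self.config = config
        self.device = device

    def __call__(self, video_path: str) -> Dict[str, torch.Tensor]:
        """Detect and crop the signer, then return {"frames": ..., "padding": ...}."""
        raise NotImplementedError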
82 changes: 82 additions & 0 deletions inference/translation.py
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import time
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from omegaconf import DictConfig
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from ssvp_slt.modeling.sign_t5 import SignT5Config, SignT5ForConditionalGeneration
from ssvp_slt.util.misc import load_model

class Translator:
def __init__(
self,
config: DictConfig,
device: torch.device,
dtype: torch.dtype = torch.float32,
):
self.config = config
self.device = device
self.dtype = dtype

self.model, self.tokenizer = self._load_model_and_tokenizer()

def _load_model_and_tokenizer(
self,
) -> Tuple[SignT5ForConditionalGeneration, PreTrainedTokenizerFast]:
"""
Loads a pretrained SignT5 model for translation and moves it to the specified device
"""

config = SignT5Config(
decoder_start_token_id=0,
output_past=True,
tie_word_embeddings=False,
feature_dim=self.config.feature_dim,
)
model = SignT5ForConditionalGeneration._from_config(config)

tokenizer = AutoTokenizer.from_pretrained(
self.config.tokenizer_path, use_fast=False, legacy=True
)

print("Loading translator")

model.eval()
model.to(self.device)
model.to(self.dtype)

return model, tokenizer

@torch.inference_mode()
def __call__(self, features: torch.Tensor) -> Dict[str, List[str]]:

t0 = time.time()

features = features.to(self.device)

generated_tokens = self.model.generate(
inputs_embeds=features.unsqueeze(0),
num_return_sequences=self.config.num_translations,
max_length=self.config.max_length,
num_beams=self.config.num_beams,
do_sample=self.config.do_sample,
temperature=self.config.temperature,
)
generated_text = [
t.strip()
for t in self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
]

t1 = time.time()

if self.config.verbose:
print(f"4. Translation: {t1 - t0:.3f}s")

return {"translations": generated_text}