From 0e3f127630e0de3eac59c6d7f866fe7300c4a8bf Mon Sep 17 00:00:00 2001 From: Joozef315 Date: Fri, 29 Nov 2024 18:17:29 +0200 Subject: [PATCH] add e2e script & e2e demo --- .gitignore | 6 +- examples/inference/README.md | 99 + examples/inference/app.py | 198 ++ examples/inference/e2e.py | 45 + examples/inference/e2e_demo.py | 22 + examples/inference/run.py | 62 + examples/inference/setup.py | 16 + .../inference/src/slt_inference/__init__.py | 0 .../src/slt_inference/checkpoint_utils.py | 0 .../inference/src/slt_inference/configs.py | 85 + .../src/slt_inference/feature_extraction.py | 75 + .../src/slt_inference/modeling/__init__.py | 0 .../src/slt_inference/modeling/sign_hiera.py | 770 +++++++ .../modeling/sign_hiera_utils.py | 332 +++ .../src/slt_inference/modeling/sign_t5.py | 1830 +++++++++++++++++ .../slt_inference/modeling/sonar_decoder.py | 114 + .../modeling/sonar_t5_encoder.py | 96 + .../src/slt_inference/preprocessing.py | 709 +++++++ .../src/slt_inference/translation.py | 144 ++ examples/inference/src/slt_inference/util.py | 265 +++ 20 files changed, 4867 insertions(+), 1 deletion(-) create mode 100644 examples/inference/README.md create mode 100644 examples/inference/app.py create mode 100644 examples/inference/e2e.py create mode 100644 examples/inference/e2e_demo.py create mode 100644 examples/inference/run.py create mode 100644 examples/inference/setup.py create mode 100644 examples/inference/src/slt_inference/__init__.py create mode 100644 examples/inference/src/slt_inference/checkpoint_utils.py create mode 100644 examples/inference/src/slt_inference/configs.py create mode 100644 examples/inference/src/slt_inference/feature_extraction.py create mode 100644 examples/inference/src/slt_inference/modeling/__init__.py create mode 100644 examples/inference/src/slt_inference/modeling/sign_hiera.py create mode 100644 examples/inference/src/slt_inference/modeling/sign_hiera_utils.py create mode 100644 examples/inference/src/slt_inference/modeling/sign_t5.py create mode 100644 examples/inference/src/slt_inference/modeling/sonar_decoder.py create mode 100644 examples/inference/src/slt_inference/modeling/sonar_t5_encoder.py create mode 100644 examples/inference/src/slt_inference/preprocessing.py create mode 100644 examples/inference/src/slt_inference/translation.py create mode 100644 examples/inference/src/slt_inference/util.py diff --git a/.gitignore b/.gitignore index 93690b6..9b2cf64 100644 --- a/.gitignore +++ b/.gitignore @@ -94,4 +94,8 @@ dmypy.json # misc *.mp4 sweep*/ -core* \ No newline at end of file +core* + +features_outputs +MOCK_dataset +*.pth \ No newline at end of file diff --git a/examples/inference/README.md b/examples/inference/README.md new file mode 100644 index 0000000..fc7e133 --- /dev/null +++ b/examples/inference/README.md @@ -0,0 +1,99 @@ + + +### Installation + +1. Create and activate conda environment +```bash +conda create --name sign-language-inference python=3.10 +conda activate sign-language-inference +``` + +2. Install dependencies +```bash +conda install -c -y conda-forge libsndfile==1.0.31 # fairseq2 dependency +pip install -r requirements.txt +``` + +3. Install dlib with CUDA support (alternatively, `pip install dlib`; comes without CUDA support) + +Confirm whether your dlib installation supports CUDA with: `python -c "import dlib; print(dlib.DLIB_USE_CUDA)"`. 
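If it prints `False`, inference still works: keep the default `preprocessing.hog_detector=true` so face detection uses the CPU HOG detector; the CNN detector (`preprocessing.hog_detector=false`) is only worth enabling with a CUDA-enabled dlib build.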
+ +```bash +# This configuration worked on FAIR cluster; make sure to check build logs + +module load cmake/3.15.3/gcc.7.3.0 gcc/12.2.0 cuda/12.1 cudnn/v8.8.1.3-cuda.12.0 + +git clone https://github.com/davisking/dlib.git +cd dlib +python setup.py install \ + --set CUDA_TOOLKIT_ROOT_DIR=/public/apps/cuda/12.1 \ + --set CMAKE_PREFIX_PATH=/public/apps/cudnn/v8.8.1.3-cuda.12.0 \ + --set USE_AVX_INSTRUCTIONS=yes \ + --set DLIB_USE_CUDA=yes + +``` + +4. Install this repo (from repo root) +```bash +pip install -e . +``` + + +### Model checkpoints + +Copy model checkpoints to [checkpoints](checkpoints) folder. +```bash +cp -r /checkpoint/philliprust/slt_inference/checkpoints/* checkpoints/ +``` + +### Running Inference + +```bash +python run.py video_path=/path/to/video.mp4 +``` + +Pass `verbose=True` to print config and time elapsed for various steps in the pipeline. + +Pass `preprocessing.hog_detector=false` if running dlib CNN detector with CUDA support. + +Pass `preprocessing.detection_downsample=false` if the video input resolution is already small, e.g. 224x224. + +Here is a more advanced example for running a SONAR model: + +```bash +python run.py \ + video_path=/path/to/my/video.mp4 \ + use_sonar=true \ + preprocessing.detection_downsample=false \ + feature_extraction.pretrained_model_path=/path/to/trained/signhiera.pth \ + translation.base_model_name=google/t5-v1_1-large \ + translation.pretrained_model_path=/path/to/trained/sonar/encoder/best_model.pth \ + feature_extraction.fp16=true \ + 'translation.tgt_langs=[eng_Latn, fra_Latn, deu_Latn, zho_Hans]' +``` + + +### Running the Gradio app + +To run the Gradio app, you can use the exact same command as for `run.py` but leave out the `video_path`. You can then open the browser and upload a video. As soon as the video finishes uploading, it should start playing automatically which will trigger the inference pipeline. If the video doesn't start automatically, just press the play button manually. + +```bash +python app.py \ + use_sonar=true \ + preprocessing.detection_downsample=false \ + feature_extraction.pretrained_model_path=/path/to/trained/signhiera.pth \ + translation.base_model_name=google/t5-v1_1-large \ + translation.pretrained_model_path=/path/to/trained/sonar/encoder/best_model.pth \ + feature_extraction.fp16=true \ + 'translation.tgt_langs=[eng_Latn, fra_Latn, deu_Latn, zho_Hans]' +``` + +The Gradio app needs to be launched from a GPU machine. If running on a devfair, you can tunnel to your local machine via `ssh -L 8000:localhost:7860 devfair`. After the devfair login, you can go to `localhost:8000` in your browser to use the app. + +### Slicing a video before inferencing + +Videos should ideally be a sentence long. To slice a video, e.g. 
from second 10 to 14, you can use ffmpeg: + +```bash +ffmpeg -ss 10.00 -to 14.00 -i ./video.MOV -c:v libx264 -crf 20 video_slice.mp4 -loglevel info +``` \ No newline at end of file diff --git a/examples/inference/app.py b/examples/inference/app.py new file mode 100644 index 0000000..f372f6d --- /dev/null +++ b/examples/inference/app.py @@ -0,0 +1,198 @@ +import os +from dataclasses import dataclass, field +from pathlib import Path +from queue import Queue +from time import sleep +from typing import List, Optional, Sequence, Tuple + +import gradio as gr +import hydra +import torch +from hydra.core.config_store import ConfigStore +from omegaconf import II, DictConfig, OmegaConf + +from slt_inference.feature_extraction import FeatureExtractor +from slt_inference.preprocessing import Preprocessor +from slt_inference.translation import SonarTranslator, Translator +from slt_inference.util import FLORES200_ID2LANG, print_translations + + +@dataclass +class PreprocessingConfig: + # Bounding box expansion and threshold parameters + up_exp: float = 1.0 + down_exp: float = 3.0 + left_exp: float = 1.5 + right_exp: float = 1.5 + iou_threshold: float = 0.2 + num_ratio_threshold: float = 0.5 + + # Dlib detection parameters + hog_detector: bool = True + detector_path: Optional[str] = "checkpoints/detector.dat" + detection_sampling_rate: int = 16 + detection_downsample: bool = True + + # Sliding window sampling parameters + num_frames: int = 128 + feature_extraction_stride: int = 64 + sampling_rate: int = 2 + target_fps: int = 25 + + # Cropping and Normalization + target_size: int = 224 + mean: Tuple[float, float, float] = (0.45, 0.45, 0.45) + std: Tuple[float, float, float] = (0.225, 0.225, 0.225) + + debug: bool = False + verbose: bool = II("verbose") + + def __post_init__(self): + if self.hog_detector is False: + assert ( + self.detector_path is not None + ), "detector_path must be provded if `hog_detector=False`" + + +@dataclass +class FeatureExtractionConfig: + pretrained_model_path: str = "checkpoints/feature_extractor.pth" + model_name: str = "hiera_base_128x224" + max_batch_size: int = 2 + fp16: bool = True + verbose: bool = II("verbose") + + +@dataclass +class TranslationConfig: + pretrained_model_path: str = "checkpoints/translator.pth" + tokenizer_path: str = "checkpoints/tokenizer" + base_model_name: str = "google/t5-v1_1-large" + feature_dim: int = 768 + decoder_path: str = "checkpoints/sonar_decoder.pt" + decoder_spm_path: str = "checkpoints/decoder_sentencepiece.model" + + # Target languages for Sonar translator + tgt_langs: List[str] = field(default_factory=lambda: ["eng_Latn"]) + + # Generation parameters + num_translations: int = 5 + do_sample: bool = False + num_beams: int = 5 + temperature: float = 1.0 + max_length: int = 128 + + verbose: bool = II("verbose") + + def __post_init__(self): + for lang in self.tgt_langs: + if lang not in FLORES200_ID2LANG: + raise ValueError(f"{lang} is not a valid FLORES-200 language ID") + + +@dataclass +class RunConfig: + preprocessing: PreprocessingConfig = PreprocessingConfig() + feature_extraction: FeatureExtractionConfig = FeatureExtractionConfig() + translation: TranslationConfig = TranslationConfig() + use_sonar: bool = False + verbose: bool = False + + +cs = ConfigStore.instance() +cs.store(name="run_config", node=RunConfig) + +video = None +video_released = False +translation_queue = Queue(maxsize=1) + +css = """ +.app { + max-width: 50% !important; +} +""" + + +@hydra.main(config_name="run_config") +def main(config: DictConfig): + 
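# Hydra changes the working directory per run; switch back so relative checkpoint paths in the config resolve
+    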
os.chdir(hydra.utils.get_original_cwd()) + + print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 if config.feature_extraction.fp16 else torch.float32 + + preprocessor = Preprocessor(config.preprocessing, device=device) + feature_extractor = FeatureExtractor(config.feature_extraction, device=device) + + translator_cls = SonarTranslator if config.use_sonar else Translator + translator = translator_cls(config.translation, device=device, dtype=dtype) + + def release(): + """Triggered when video ends to indicate translation can now be displayed""" + global video_released + video_released = True + + def load_video(video_path: str): + """Load and preprocess the uploaded video""" + global video + video = preprocessor(Path(video_path)) + + def inference_video(): + """Run inference on video and put translations in a queue""" + global translation_queue + extracted_features = feature_extractor(**video) + + kwargs = {"tgt_langs": config.translation.tgt_langs} if config.use_sonar else {} + translations = translator(extracted_features, **kwargs)["translations"] + translation_queue.put(translations) + + keys = ( + [FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs] + if config.use_sonar + else range(config.translation.num_translations) + ) + print_translations(keys, translations) + + def show_translation(): + """Consume translation results from queue and display them""" + + global video_released + # Wait until video has finished playing before showing translation + while not video_released: + sleep(0.05) + + video_released = False + translation_result = translation_queue.get() + + return ( + [gr.Text(v) for v in translation_result] + if config.use_sonar + else gr.Text(translation_result[0]) + ) + + with gr.Blocks(css=css) as demo: + with gr.Column(scale=1): + gr.Markdown("### ASL") + input_video = gr.Video(label="Input Video", height=360, autoplay=True) + + output_texts = [] + if config.use_sonar: + for lang in config.translation.tgt_langs: + gr.Markdown(f"### {FLORES200_ID2LANG[lang]}") + output_texts.append(gr.Text("", interactive=False, label="Translation")) + else: + gr.Markdown("### English") + output_texts.append(gr.Text("", interactive=False, label="Translation")) + + input_video.upload(fn=load_video, inputs=input_video, outputs=None).success( + fn=inference_video, inputs=None, outputs=None + ).success(fn=show_translation, inputs=None, outputs=output_texts) + + input_video.end(fn=release, inputs=None, outputs=None) + + demo.launch() + + +if __name__ == "__main__": + main() diff --git a/examples/inference/e2e.py b/examples/inference/e2e.py new file mode 100644 index 0000000..1169292 --- /dev/null +++ b/examples/inference/e2e.py @@ -0,0 +1,45 @@ +import os +import time +from pathlib import Path +import torch +from omegaconf import DictConfig, OmegaConf +from src.slt_inference.preprocessing import Preprocessor +from src.slt_inference.feature_extraction import FeatureExtractor +from src.slt_inference.translation import SonarTranslator, Translator +from src.slt_inference.util import FLORES200_ID2LANG, print_translations + +def e2e_pipeline(config: DictConfig): + """ + Main function to run the sign language translation pipeline. 
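Runs video preprocessing (face detection and cropping), SignHiera feature extraction, and T5- or SONAR-based translation on the video at `config.video_path`.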
+ """ + os.chdir(os.getcwd()) # Ensure correct working directory + + if config.verbose: + print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 if config.feature_extraction.fp16 else torch.float32 + + t0 = time.time() + + # Initialize components + preprocessor = Preprocessor(config.preprocessing, device=device) + feature_extractor = FeatureExtractor(config.feature_extraction, device=device) + translator_cls = SonarTranslator if config.use_sonar else Translator + translator = translator_cls(config.translation, device=device, dtype=dtype) + + t1 = time.time() + if config.verbose: + print(f"1. Model loading: {t1 - t0:.3f}s") + + # Process input + inputs = preprocessor(Path(config.video_path)) + extracted_features = feature_extractor(**inputs) + + kwargs = {"tgt_langs": config.translation.tgt_langs} if config.use_sonar else {} + translations = translator(extracted_features, **kwargs)["translations"] + + keys = [FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs] if config.use_sonar else range(config.translation.num_translations) + + # Output results + print_translations(keys, translations) diff --git a/examples/inference/e2e_demo.py b/examples/inference/e2e_demo.py new file mode 100644 index 0000000..39cea05 --- /dev/null +++ b/examples/inference/e2e_demo.py @@ -0,0 +1,22 @@ +from omegaconf import OmegaConf +from e2e import e2e_pipeline +from src.slt_inference.configs import RunConfig, FeatureExtractionConfig, TranslationConfig + +# Define the translation configuration +translation_config = RunConfig( + video_path="D:/Pro/MLH/ctf/video.mp4", + verbose=True, + feature_extraction=FeatureExtractionConfig( + pretrained_model_path="translation/signhiera_mock.pth", + ), + translation=TranslationConfig( + base_model_name="google/t5-v1_1-large", + tgt_langs=["eng_Latn", "fra_Latn"] + ) +) + +# Convert it to DictConfig +translation_dict_config = OmegaConf.structured(translation_config) + +# Run pipeline with provided parameters +e2e_pipeline(translation_dict_config) diff --git a/examples/inference/run.py b/examples/inference/run.py new file mode 100644 index 0000000..1f5737f --- /dev/null +++ b/examples/inference/run.py @@ -0,0 +1,62 @@ +import os +import time +from pathlib import Path + +import hydra +import torch +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig, OmegaConf + +from src.slt_inference.feature_extraction import FeatureExtractor +from src.slt_inference.translation import SonarTranslator, Translator +from src.slt_inference.util import FLORES200_ID2LANG, print_translations +from src.slt_inference.configs import FeatureExtractor, RunConfig +from src.slt_inference.preprocessing import Preprocessor + +from omegaconf import DictConfig, OmegaConf + + +cs = ConfigStore.instance() +cs.store(name="run_config", node=RunConfig) + + +@hydra.main(config_name="run_config") +def main(config: DictConfig): + + os.chdir(hydra.utils.get_original_cwd()) + + if config.verbose: + print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float16 if config.feature_extraction.fp16 else torch.float32 + + t0 = time.time() + + preprocessor = Preprocessor(config.preprocessing, device=device) + feature_extractor = FeatureExtractor(config.feature_extraction, device=device) + + translator_cls = SonarTranslator if config.use_sonar else Translator + translator = 
translator_cls(config.translation, device=device, dtype=dtype) + + t1 = time.time() + + if config.verbose: + print(f"1. Model loading: {t1 - t0:.3f}s") + + inputs = preprocessor(Path(config.video_path)) + extracted_features = feature_extractor(**inputs) + + kwargs = {"tgt_langs": config.translation.tgt_langs} if config.use_sonar else {} + translations = translator(extracted_features, **kwargs)["translations"] + + keys = ( + [FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs] + if config.use_sonar + else range(config.translation.num_translations) + ) + print_translations(keys, translations) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/setup.py b/examples/inference/setup.py new file mode 100644 index 0000000..fb00690 --- /dev/null +++ b/examples/inference/setup.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +setup( + name="slt_inference", + version="0.0.1", + author="Phillip Rust", + author_email="philliprust@meta.com", + url="https://github.com/fairinternal/slt_inference", + description="Experimental code for sign language translation inference", + license="", + package_dir={"": "src"}, + packages=find_packages("src"), + zip_safe=True, +) diff --git a/examples/inference/src/slt_inference/__init__.py b/examples/inference/src/slt_inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/inference/src/slt_inference/checkpoint_utils.py b/examples/inference/src/slt_inference/checkpoint_utils.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/inference/src/slt_inference/configs.py b/examples/inference/src/slt_inference/configs.py new file mode 100644 index 0000000..c50be48 --- /dev/null +++ b/examples/inference/src/slt_inference/configs.py @@ -0,0 +1,85 @@ + +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +from omegaconf import II, MISSING + + +@dataclass +class PreprocessingConfig: + # Bounding box expansion and threshold parameters + up_exp: float = 1.0 + down_exp: float = 3.0 + left_exp: float = 1.5 + right_exp: float = 1.5 + iou_threshold: float = 0.2 + num_ratio_threshold: float = 0.5 + + # Dlib detection parameters + hog_detector: bool = True + detector_path: Optional[str] = "checkpoints/detector.dat" + detection_sampling_rate: int = 16 + detection_downsample: bool = True + + # Sliding window sampling parameters + num_frames: int = 128 + feature_extraction_stride: int = 64 + sampling_rate: int = 2 + target_fps: int = 25 + + # Cropping and Normalization + target_size: int = 224 + mean: Tuple[float, float, float] = (0.45, 0.45, 0.45) + std: Tuple[float, float, float] = (0.225, 0.225, 0.225) + + debug: bool = False + verbose: bool = II("verbose") + + def __post_init__(self): + if self.hog_detector is False: + assert ( + self.detector_path is not None + ), "detector_path must be provded if `hog_detector=False`" + + +@dataclass +class FeatureExtractionConfig: + pretrained_model_path: str = "checkpoints/feature_extractor.pth" + model_name: str = "hiera_base_128x224" + max_batch_size: int = 2 + fp16: bool = True + verbose: bool = II("verbose") + + +@dataclass +class TranslationConfig: + pretrained_model_path: str = "checkpoints/translator.pth" + tokenizer_path: str = "checkpoints/tokenizer" + base_model_name: str = "google/t5-v1_1-large" + feature_dim: int = 768 + decoder_path: str = "checkpoints/sonar_decoder.pt" + decoder_spm_path: str = "checkpoints/decoder_sentencepiece.model" + + # Target languages for Sonar translator + 
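# entries must be FLORES-200 language codes, e.g. eng_Latn or fra_Latn
+    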
tgt_langs: List[str] = field(default_factory=lambda: ["eng_Latn"]) + + # Generation parameters + # Note: these are ignored when using SONAR + num_translations: int = 5 + do_sample: bool = False + num_beams: int = 5 + temperature: float = 1.0 + max_length: int = 128 + + verbose: bool = II("verbose") + + +@dataclass +class RunConfig: + video_path: str = MISSING + + preprocessing: PreprocessingConfig = field(default_factory=lambda: PreprocessingConfig()) + feature_extraction: FeatureExtractionConfig = field(default_factory=FeatureExtractionConfig()) + translation: TranslationConfig = field(default_factory=TranslationConfig()) + use_sonar: bool = False + verbose: bool = False diff --git a/examples/inference/src/slt_inference/feature_extraction.py b/examples/inference/src/slt_inference/feature_extraction.py new file mode 100644 index 0000000..58f369b --- /dev/null +++ b/examples/inference/src/slt_inference/feature_extraction.py @@ -0,0 +1,75 @@ +import time +from pathlib import Path +from typing import Any, Generator + +import torch +import torch.nn as nn +from omegaconf import DictConfig + +from .modeling import sign_hiera +from .modeling.sign_hiera import SignHiera +from .util import load_model + + +def shard_generator(data: Any, shard_size: int) -> Generator[Any, None, None]: + for i in range(0, len(data), shard_size): + yield data[i : i + shard_size] + + +class FeatureExtractor: + def __init__(self, config: DictConfig, device: torch.device): + self.config = config + self.device = device + + self.model = self._load_model() + + def _load_model(self) -> SignHiera: + """ + Loads a pretrained SignHiera model for feature extraction and moves it to specified device + """ + + model = sign_hiera.__dict__[self.config.model_name](pretrained=False, strict=False) + + print("Loading feature extractor") + load_model(model, Path(self.config.pretrained_model_path)) + + model.head = nn.Identity() + model.eval() + model.to(self.device) + + return model + + @torch.inference_mode() + def __call__(self, frames: torch.Tensor, padding: torch.Tensor) -> torch.Tensor: + + t0 = time.time() + + frames = frames.to(self.device) + padding = padding.to(self.device) + + if len(frames) > self.config.max_batch_size: + + shard_outputs = [] + + frame_shards = shard_generator(frames, self.config.max_batch_size) + padding_shards = shard_generator(padding, self.config.max_batch_size) + + for frames_shard, padding_shard in zip(frame_shards, padding_shards): + with torch.cuda.amp.autocast(enabled=self.config.fp16): + shard_output = self.model.extract_features(frames_shard, padding=padding_shard) + if len(shard_output.shape) == 1: + shard_output = shard_output.unsqueeze(0) + + shard_outputs.append(shard_output) + + outputs = torch.concatenate(shard_outputs, dim=0) + else: + with torch.cuda.amp.autocast(enabled=self.config.fp16): + outputs = self.model.extract_features(frames, padding=padding) + + t1 = time.time() + + if self.config.verbose: + print(f"3. Feature extraction: {t1 - t0:.3f}s") + + return outputs diff --git a/examples/inference/src/slt_inference/modeling/__init__.py b/examples/inference/src/slt_inference/modeling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/inference/src/slt_inference/modeling/sign_hiera.py b/examples/inference/src/slt_inference/modeling/sign_hiera.py new file mode 100644 index 0000000..46bcc0f --- /dev/null +++ b/examples/inference/src/slt_inference/modeling/sign_hiera.py @@ -0,0 +1,770 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# +# References: +# +# Hiera: https://github.com/facebookresearch/hiera/ +# slowfast: https://github.com/facebookresearch/SlowFast +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# -------------------------------------------------------- + +""" +SignHiera PyTorch model +Adapted from https://github.com/facebookresearch/hiera/blob/v0.1.2/hiera/hiera.py + +Main changes made: +- made clip size variable (to increase from 16 to 128) +- added temporal attention masking at mask unit (MU) level +- added feature extraction +- added loading from pretrained CLIP/FLIP model with Hiera-based vision tower +""" + +import collections.abc +import math +from functools import partial +from itertools import repeat +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .sign_hiera_utils import Reroll, Unroll, conv_nd, do_masked_conv, do_pool, pretrained_model + + +# From PyTorch internals +def _ntuple(n): + """Credit: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/helpers.py""" + + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +class Mlp(nn.Module): + """ + Credit: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/mlp.py + + MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + """ + Credit: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. 
+ + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """ + Credit: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class MaskUnitAttention(nn.Module): + """ + Computes either Mask Unit or Global Attention. Also is able to perform q pooling. + + Note: this assumes the tokens have already been flattened and unrolled into mask units. + See `Unroll` for more details. + """ + + def __init__( + self, + dim: int, + dim_out: int, + heads: int, + q_stride: int = 1, + window_size: int = 0, + use_mask_unit_attn: bool = False, + ): + """ + Args: + - dim, dim_out: The input and output feature dimensions. + - heads: The number of attention heads. + - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4). + - window_size: The current (flattened) size of a mask unit *after* pooling (if any). + - use_mask_unit_attn: Use Mask Unit or Global Attention. + """ + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.heads = heads + self.q_stride = q_stride + + self.head_dim = dim_out // heads + self.scale = (self.head_dim) ** -0.5 + + self.qkv = nn.Linear(dim, 3 * dim_out) + self.proj = nn.Linear(dim_out, dim_out) + + self.window_size = window_size + self.use_mask_unit_attn = use_mask_unit_attn + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + """Input should be of shape [batch, tokens, channels].""" + + B, N, _ = x.shape + num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1 + + qkv = ( + self.qkv(x) + .reshape(B, -1, num_windows, 3, self.heads, self.head_dim) + .permute(3, 0, 4, 2, 1, 5) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + # [B, N, 1] -> [B, W, N//W] + attn_mask = attn_mask.reshape(B, -1, num_windows).permute(0, 2, 1) + # [B, W, N//W] -> [B, H=1, W, L=1, N//W] + attn_mask = attn_mask[:, None, :, None, :] + attn_mask = torch.where(attn_mask == 0, -float("inf"), attn_mask) + + if self.q_stride > 1: + # Refer to Unroll to see how this performs a maxpool-Nd + q = ( + q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim) + .max(dim=3) + .values + ) + + # Only use attention mask when doing global attention. + # This is because we do attention masking at the level of mask units, so when doing mask unit attention, + # the attention mask would either be true or false for the full window, which is not useful + if hasattr(F, "scaled_dot_product_attention"): + # Note: the original paper did *not* use SDPA, it's a free boost! 
+ x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask if not self.use_mask_unit_attn else None + ) + else: + attn = (q * self.scale) @ k.transpose(-1, -2) + if not self.use_mask_unit_attn: + attn += attn_mask + attn = attn.softmax(dim=-1) + x = attn @ v + + x = x.transpose(1, 3).reshape(B, -1, self.dim_out) + x = self.proj(x) + + return x + + +class SignHieraBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + act_layer: nn.Module = nn.GELU, + q_stride: int = 1, + window_size: int = 0, + use_mask_unit_attn: bool = False, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.norm1 = norm_layer(dim) + self.attn = MaskUnitAttention( + dim, dim_out, heads, q_stride, window_size, use_mask_unit_attn + ) + + self.norm2 = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer) + + self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: + # Attention + Q Pooling + x_norm = self.norm1(x) + if self.dim != self.dim_out: + x = do_pool(self.proj(x_norm), stride=self.attn.q_stride) + x = x + self.drop_path(self.attn(x_norm, attn_mask=attn_mask)) + + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Head(nn.Module): + def __init__( + self, + dim: int, + num_classes: int, + dropout_rate: float = 0.0, + act_func: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.softmax(dim=-1), + ): + super().__init__() + self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity() + self.projection = nn.Linear(dim, num_classes) + # act_fun for eval and testing only + self.act_func = act_func + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.dropout(x) + x = self.projection(x) + if not self.training: + x = self.act_func(x) + return x + + +class PatchEmbed(nn.Module): + """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d).""" + + def __init__( + self, + dim_in: int, + dim_out: int, + kernel: Tuple[int, ...], + stride: Tuple[int, ...], + padding: Tuple[int, ...], + ): + super().__init__() + + # Support any number of spatial dimensions + self.spatial_dims = len(kernel) + self.proj = conv_nd(self.spatial_dims)( + dim_in, + dim_out, + kernel_size=kernel, + stride=stride, + padding=padding, + ) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + x = do_masked_conv(x, self.proj, mask) + x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1) + return x + + +class SignHiera(nn.Module): + def __init__( + self, + input_size: Tuple[int, ...] = (224, 224), + in_chans: int = 3, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + num_classes: int = 1000, + stages: Tuple[int, ...] = (2, 3, 16, 3), + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, ...] = (2, 2), + mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1) + # mask_unit_attn: which stages use mask unit attention? + mask_unit_attn: Tuple[bool, ...] = (True, True, False, False), + dim_mul: float = 2.0, + head_mul: float = 2.0, + patch_kernel: Tuple[int, ...] = (7, 7), + patch_stride: Tuple[int, ...] = (4, 4), + patch_padding: Tuple[int, ...] 
= (3, 3), + mlp_ratio: float = 4.0, + drop_path_rate: float = 0.0, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + head_dropout: float = 0.0, + head_init_scale: float = 0.001, + sep_pos_embed: bool = False, + **kwargs, + ): + super().__init__() + + depth = sum(stages) + self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)] + num_tokens = math.prod(self.tokens_spatial_shape) + flat_mu_size = math.prod(mask_unit_size) + flat_q_stride = math.prod(q_stride) + + assert q_pool < len(stages) + self.q_pool, self.q_stride = q_pool, q_stride + self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size + self.mask_spatial_shape = [ + i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size) + ] + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + self.patch_stride = patch_stride + self.feature_dim = int(embed_dim * dim_mul * len(self.stage_ends)) + + self.patch_embed = PatchEmbed( + in_chans, embed_dim, patch_kernel, patch_stride, patch_padding + ) + + self.sep_pos_embed = sep_pos_embed + if sep_pos_embed: + self.pos_embed_spatial = nn.Parameter( + torch.zeros( + 1, + self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], + embed_dim, + ) + ) + self.pos_embed_temporal = nn.Parameter( + torch.zeros(1, self.tokens_spatial_shape[0], embed_dim) + ) + else: + self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim)) + + # Setup roll and reroll modules + self.unroll = Unroll(input_size, patch_stride, [q_stride] * len(self.stage_ends[:-1])) + self.reroll = Reroll( + input_size, + patch_stride, + [q_stride] * len(self.stage_ends[:-1]), + self.stage_ends, + q_pool, + ) + # q_pool locations + q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]] + self.q_pool_blocks = set(q_pool_blocks) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # Transformer blocks + cur_stage = 0 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # Mask unit or global attention. 
+ # Lag by 1 block, so that global attention, + # applied post pooling on lower resolution + use_mask_unit_attn = mask_unit_attn[cur_stage] + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + if i in q_pool_blocks: + flat_mu_size //= flat_q_stride + + block = SignHieraBlock( + dim=embed_dim, + dim_out=dim_out, + heads=num_heads, + mlp_ratio=mlp_ratio, + drop_path=dpr[i], + norm_layer=norm_layer, + q_stride=(flat_q_stride if i in q_pool_blocks else 1), + window_size=flat_mu_size, + use_mask_unit_attn=use_mask_unit_attn, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.norm = norm_layer(embed_dim) + self.head = Head(embed_dim, num_classes, dropout_rate=head_dropout) + + # Initialize everything + if sep_pos_embed: + nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02) + nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02) + else: + nn.init.trunc_normal_(self.pos_embed, std=0.02) + self.apply(partial(self._init_weights)) + self.head.projection.weight.data.mul_(head_init_scale) + self.head.projection.bias.data.mul_(head_init_scale) + + def _init_weights(self, m: nn.Module, init_bias: float = 0.02) -> None: + if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, init_bias) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, init_bias) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self) -> List[str]: + if self.sep_pos_embed: + return ["pos_embed_spatial", "pos_embed_temporal"] + else: + return ["pos_embed"] + + def get_random_mask( + self, + x: torch.Tensor, + mask_ratio: float, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Generates a random mask, mask_ratio fraction are dropped. + 1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc. + """ + B = x.shape[0] + # Tokens selected for masking at mask unit level + num_windows = math.prod(self.mask_spatial_shape) # num_mask_units + len_keep = int(num_windows * (1 - mask_ratio)) + noise = torch.rand(B, num_windows, device=x.device) + + # Attention mask indicates tokens that are non-padding (1) or padding (0) + # Out of the non-padding tokens we take the one with the highest noise + # And bump up noise to 100 to guarantee that it gets masked + # We therefore ensure that at least one masked patch is non-padding + # This is necessary because we only compute loss on non-padding tokens, i.e. 
loss would otherwise be NaN + if attn_mask is not None: + noise_mask = torch.argmax(noise * attn_mask, dim=1) + noise[torch.arange(noise.size(0)), noise_mask] = 100.0 + # First (1x)x#MUyx#MUx tokens will not be padding, so by setting low value we guarantee that + # at least one non-padding token will be kept + noise[ + torch.arange(noise.size(0)), + torch.randint( + self.mask_spatial_shape[-2] * self.mask_spatial_shape[-1], + (noise.size(0),), + ), + ] = -100 + + # Sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # Generate the binary mask: 1 is *keep*, 0 is *remove* + # Note this is opposite to original MAE + mask = torch.zeros([B, num_windows], device=x.device) + mask[:, :len_keep] = 1 + # Unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return mask.bool() + + def get_attention_mask(self, padding: torch.Tensor, device: torch.device) -> torch.Tensor: + """ + Creates a temporal attention mask based on the number of padding frames + """ + attn_mask = torch.ones( + (padding.shape[0], math.prod(self.mask_spatial_shape)), device=device + ) + + # #MUs that are padded + num_padding_mus = ( + (padding // (self.mask_unit_size[0] * self.patch_stride[0])) + * self.mask_spatial_shape[1] + * self.mask_spatial_shape[2] + ) + + for i in range(num_padding_mus.shape[0]): + if num_padding_mus[i] > 0: + attn_mask[i, -num_padding_mus[i] :] = 0 + + return attn_mask + + def get_pos_embed(self) -> torch.Tensor: + if self.sep_pos_embed: + return self.pos_embed_spatial.repeat( + 1, self.tokens_spatial_shape[0], 1 + ) + torch.repeat_interleave( + self.pos_embed_temporal, + self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], + dim=1, + ) + else: + return self.pos_embed + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor = None, + attn_mask: torch.Tensor = None, + return_intermediates: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim. + Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch. 
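attn_mask marks non-padding mask units (1 = real frames, 0 = temporal padding); it is typically built from `padding` via `get_attention_mask`.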
+ """ + + # Slowfast training passes in a list + if isinstance(x, list): + x = x[0] + + if attn_mask is None: + attn_mask = torch.ones( + (x.shape[0], math.prod(self.mask_spatial_shape)), device=x.device + ) + + intermediates = [] + + # Zero out both mask tokens and padding (attn_mask == 0) tokens in patch embedding conv + patch_embed_mask = torch.logical_and(mask, attn_mask) if mask is not None else attn_mask + x = self.patch_embed( + x, mask=patch_embed_mask.view(x.shape[0], 1, *self.mask_spatial_shape) + ) + + x = x + self.get_pos_embed() + x = self.unroll(x) + + # get spatial view of attention mask + attn_mask = attn_mask.view(attn_mask.shape[0], *self.mask_spatial_shape) + + # upsample by mask unit size, then flatten and unsqueeze channel dimension + for i, s in enumerate(self.mask_unit_size): + attn_mask = attn_mask.repeat_interleave(s, i) + attn_mask = attn_mask.view(attn_mask.shape[0], -1).unsqueeze(-1) + attn_mask = self.unroll(attn_mask) + + # Discard masked tokens + if mask is not None: + x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view( + x.shape[0], -1, x.shape[-1] + ) + attn_mask = attn_mask[mask[..., None].tile(1, self.mu_size, attn_mask.shape[2])].view( + attn_mask.shape[0], -1, attn_mask.shape[-1] + ) + + for i, blk in enumerate(self.blocks): + x = blk(x, attn_mask=attn_mask) + + # Downsample attention mask + if i in self.q_pool_blocks: + attn_mask = ( + attn_mask.view(attn_mask.shape[0], math.prod(self.q_stride), -1, 1) + .max(1) + .values + ) + + # if return_intermediates and #i in self.stage_ends: + if i in self.stage_ends: + intermediates.append(self.reroll(x, i, mask=mask)) + + if mask is None: + x = x.mean(dim=1) + x = self.norm(x) + x = self.head(x) + + # x may not always be in spatial order here. + # e.g. if q_pool = 2, mask_unit_size = (8, 8), and + # q_stride = (2, 2), not all unrolls were consumed, + # intermediates[-1] is x in spatial order + if return_intermediates: + return x, intermediates + + return x + + def extract_features( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + padding: Optional[torch.Tensor] = None, + return_attn_mask: bool = False, + ) -> torch.Tensor: + """ + Extract features from a video tensor x + + """ + + attn_mask = ( + attention_mask + if attention_mask is not None + else ( + self.get_attention_mask(padding, device=x.device) if padding is not None else None + ) + ) + x = self.forward(x, attn_mask=attn_mask, return_intermediates=True) + + # Take last intermediate features and mean pool over spatial dimensions + x = x[1][-1].mean(dim=(2, 3)) + + # Remove padding + if padding is not None: + assert x.dim() == 3 + num_padding_units = padding // (self.mask_unit_size[0] * self.patch_stride[0]) + x = torch.cat( + [ + (features[: -num_padding_units[i]] if num_padding_units[i] > 0 else features) + for i, features in enumerate(x) + ], + dim=0, + ) + return x.view(-1, x.shape[-1]) + + def load_weights(self, checkpoint_path: str) -> None: + """ + Loads SignHiera weights from a pretrained model checkpoint + """ + + checkpoint = torch.load(checkpoint_path) + + if "model" in checkpoint.keys(): + checkpoint_model = checkpoint["model"] + else: + checkpoint_model = checkpoint["model_state"] + + _mismatch = False + new_checkpoint_model = {} + for k, v in checkpoint_model.items(): + if k in self.state_dict() and self.state_dict()[k].shape != v.shape: + print(f"Pruning {k} due to size mismatch") + _mismatch = True + else: + new_checkpoint_model[k] = v + + if _mismatch: + print( + "Warning: Not all parameters from 
the checkpoint state dict match the target shape. " + "Please check whether this is intended, e.g. when changing the clip size, or not." + ) + + # load pre-trained model + msg = self.load_state_dict(new_checkpoint_model, strict=False) + print(msg) + + @classmethod + def from_clip_model(cls, model_id: str, clip_model_path: str) -> nn.Module: + """ + Loads a SignHiera encoder from a pretrained CLIP model with SignHiera vision tower + """ + + import sys + + from ssvp_slt.modeling.clip import CLIP, CLIPTextCfg, CLIPVisionCfg + + checkpoint = torch.load(clip_model_path) + args = checkpoint["args"] + model_params = checkpoint["clip"] + + vision_cfg = CLIPVisionCfg( + model_id=args.model, + proj="mlp", + ) + + # FIXME: might cause errors if proj and pooler are not `mlp` and `mean_pooler` + text_cfg = CLIPTextCfg( + hf_model_name=args.text_model_name_or_path, + proj="mlp", + pooler_type="mean_pooler", + ) + clip = CLIP(embed_dim=768, vision_cfg=vision_cfg, text_cfg=text_cfg, output_dict=True) + + print(f"Loading CLIP weights from {clip_model_path}") + msg = clip.load_state_dict(model_params) + print(msg) + + print("Loading SignHiera weights from CLIP vision tower") + model = sys.modules[__name__].__dict__[model_id](**args.__dict__) + msg = model.load_state_dict(clip.visual.transformer.state_dict(), strict=False) + print(msg) + + return model + + +# Video models + + +@pretrained_model( + { + "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth", + }, + default="mae_k400_ft_k400", +) +def hiera_base_16x224(num_classes: int = 400, **kwdargs) -> SignHiera: + return SignHiera( + num_classes=num_classes, # K400 has 400 classes + input_size=(16, 224, 224), + q_stride=(1, 2, 2), + mask_unit_size=(1, 8, 8), + patch_kernel=(3, 7, 7), + patch_stride=(2, 4, 4), + patch_padding=(1, 3, 3), + sep_pos_embed=True, + **kwdargs, + ) + + +@pretrained_model( + { + "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth", + }, + default="mae_k400_ft_k400", +) +def hiera_base_128x224(num_classes: int = 400, **kwdargs) -> SignHiera: + return SignHiera( + num_classes=num_classes, # K400 has 400 classes + input_size=(128, 224, 224), + q_stride=(1, 2, 2), + mask_unit_size=(1, 8, 8), + patch_kernel=(3, 7, 7), + patch_stride=(2, 4, 4), + patch_padding=(1, 3, 3), + sep_pos_embed=True, + q_pool=3, + **kwdargs, + ) diff --git a/examples/inference/src/slt_inference/modeling/sign_hiera_utils.py b/examples/inference/src/slt_inference/modeling/sign_hiera_utils.py new file mode 100644 index 0000000..d3021db --- /dev/null +++ b/examples/inference/src/slt_inference/modeling/sign_hiera_utils.py @@ -0,0 +1,332 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# +# References: +# +# Hiera: https://github.com/facebookresearch/hiera/ +# slowfast: https://github.com/facebookresearch/SlowFast +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# -------------------------------------------------------- + +""" +Utils for SignHiera +Adapted from https://github.com/facebookresearch/hiera/blob/v0.1.2/hiera/hiera_utils.py +""" + +import math +from typing import Callable, Dict, List, Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def pretrained_model(checkpoints: Dict[str, str], default: str = None) -> Callable: + """ + Loads a SignHiera model from a pretrained source (if pretrained=True). + Use "checkpoint" to specify the checkpoint. + """ + + def inner(model_func: Callable) -> Callable: + def model_def( + pretrained: bool = False, + checkpoint: str = default, + strict: bool = True, + **kwdargs, + ) -> nn.Module: + if pretrained: + if checkpoints is None: + raise RuntimeError( + "This model currently doesn't have pretrained weights available." + ) + elif checkpoint is None: + raise RuntimeError("No checkpoint specified.") + elif checkpoint not in checkpoints: + raise RuntimeError( + f"Invalid checkpoint specified ({checkpoint}). Options are: {list(checkpoints.keys())}." + ) + + state_dict = torch.hub.load_state_dict_from_url( + checkpoints[checkpoint], map_location="cpu" + ) + + if "head.projection.weight" in state_dict["model_state"]: + # Set the number of classes equal to the state_dict only if the user doesn't want to overwrite it + if "num_classes" not in kwdargs: + kwdargs["num_classes"] = state_dict["model_state"][ + "head.projection.weight" + ].shape[0] + # If the user specified a different number of classes, + # remove the projection weights or else we'll error out + elif ( + kwdargs["num_classes"] + != state_dict["model_state"]["head.projection.weight"].shape[0] + ): + del state_dict["model_state"]["head.projection.weight"] + del state_dict["model_state"]["head.projection.bias"] + + model = model_func(**kwdargs) + if pretrained: + # Disable being strict when trying to load a encoder-decoder model into an encoder-only model + if "decoder_pos_embed" in state_dict["model_state"] and not hasattr( + model, "decoder_pos_embed" + ): + strict = False + + new_state_dict = {} + for k, v in state_dict["model_state"].items(): + if ( + not strict + and k in model.state_dict() + and model.state_dict()[k].shape != v.shape + ): + print(f"Pruning {k} due to size mismatch") + else: + new_state_dict[k] = v + + missing_keys, unexpected_keys = model.load_state_dict( + new_state_dict, strict=strict + ) + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") + + return model + + return model_def + + return inner + + +def conv_nd(n: int) -> Type[nn.Module]: + """ + Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3. + If you wanted a 4d Hiera, you could probably just implement this for n=4. 
(no promises) + """ + return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n] + + +def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor: + # Refer to `Unroll` to see how this performs a maxpool-Nd + return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values + + +def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor: + # target_size: [(T), (H), W] + # (spatial) mask: [B, C, (t), (h), w] + if mask is None: + return mask + + assert len(mask.shape[2:]) == len(target_size) + if mask.shape[2:] != target_size: + return F.interpolate(mask.float(), size=target_size) + return mask + + +def do_masked_conv( + x: torch.Tensor, conv: nn.Module, mask: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Zero-out the masked regions of the input before conv. + Prevents leakage of masked regions when using overlapping kernels. + """ + if conv is None: + return x + if mask is None: + return conv(x) + + mask = get_resized_mask(target_size=x.shape[2:], mask=mask) + return conv(x * mask.bool()) + + +def undo_windowing(x: torch.Tensor, shape: List[int], mu_shape: List[int]) -> torch.Tensor: + """ + Restore spatial organization by undoing windowed organization of mask units. + + Args: + x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C] + shape: current spatial shape, if it were not organized into mask unit + windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C]. + mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx] + Returns: + x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C] + """ + D = len(shape) + B, C = x.shape[0], x.shape[-1] + # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C] + num_MUs = [s // mu for s, mu in zip(shape, mu_shape)] + x = x.view(B, *num_MUs, *mu_shape, C) + + # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C] + permute = ( + [0] + + sum( + [list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))], + [], + ) + + [len(x.shape) - 1] + ) + x = x.permute(permute).reshape(B, *shape, C) + + return x + + +class Unroll(nn.Module): + """ + Reorders the tokens such that patches are contiguous in memory. + E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as + [B, (Sy, Sx, H // Sy, W // Sx), C] + + This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1). + Not only is this faster, but it also makes it easy to support inputs of arbitrary + dimensions in addition to patch-wise sparsity. + + Performing this operation multiple times in sequence puts entire windows as contiguous + in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of + size 8x8 would be contiguous in memory, allowing operations like mask unit attention + computed easily and efficiently, while also allowing max to be applied sequentially. + + Note: This means that intermediate values of the model are not in HxW order, so they + need to be re-rolled if you want to use the intermediate values as a HxW feature map. + The last block of the network is fine though, since by then the strides are all consumed. + """ + + def __init__( + self, + input_size: Tuple[int, ...], + patch_stride: Tuple[int, ...], + unroll_schedule: List[Tuple[int, ...]], + ): + super().__init__() + self.size = [i // s for i, s in zip(input_size, patch_stride)] + self.schedule = unroll_schedule + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Input: Flattened patch embeddings [B, N, C] + Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. 
performs MaxPoolNd + """ + B, N, C = x.shape + + cur_size = self.size + + x = x.view(*([B] + cur_size + [C])) + + for strides in self.schedule: + # Move patches with the given strides to the batch dimension + + # Create a view of the tensor with the patch stride as separate dims + # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C] + cur_size = [i // s for i, s in zip(cur_size, strides)] + new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C] + + x = x.view(new_shape) + + # Move the patch stride into the batch dimension + # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C] + L = len(new_shape) + permute = [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1] + x = x.permute(permute) + + # Now finally flatten the relevant dims into the batch dimension + x = x.flatten(0, len(strides)) + B *= math.prod(strides) + + x = x.reshape(-1, math.prod(self.size), C) + + return x + + +class Reroll(nn.Module): + """ + Undos the "unroll" operation so that you can use intermediate features. + """ + + def __init__( + self, + input_size: Tuple[int, ...], + patch_stride: Tuple[int, ...], + unroll_schedule: List[Tuple[int, ...]], + stage_ends: List[int], + q_pool: int, + ): + super().__init__() + self.size = [i // s for i, s in zip(input_size, patch_stride)] + + # The first stage has to reverse everything + # The next stage has to reverse all but the first unroll, etc. + self.schedule = {} + size = self.size + for i in range(stage_ends[-1] + 1): + self.schedule[i] = unroll_schedule, size + # schedule unchanged if no pooling at a stage end + if i in stage_ends[:q_pool]: + if len(unroll_schedule) > 0: + size = [n // s for n, s in zip(size, unroll_schedule[0])] + unroll_schedule = unroll_schedule[1:] + + def forward(self, x: torch.Tensor, block_idx: int, mask: torch.Tensor = None) -> torch.Tensor: + """ + Roll the given tensor back up to spatial order assuming it's from the given block. + + If no mask is provided: + - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc. + If a mask is provided: + - Returns [B, #MUs, MUy, MUx, C] for 2d, etc. + """ + schedule, size = self.schedule[block_idx] + B, N, C = x.shape + + D = len(size) + cur_mu_shape = [1] * D + + for strides in schedule: + # Extract the current patch from N + x = x.view(B, *strides, N // math.prod(strides), *cur_mu_shape, C) + + # Move that patch into the current MU + # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C] + L = len(x.shape) + permute = ( + [0, 1 + D] + + sum( + [list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))], + [], + ) + + [L - 1] + ) + x = x.permute(permute) + + # Reshape to [B, N//(Sy*Sx), *MU, C] + for i in range(D): + cur_mu_shape[i] *= strides[i] + x = x.reshape(B, -1, *cur_mu_shape, C) + N = x.shape[1] + + # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C]) + x = x.view(B, N, *cur_mu_shape, C) + + # If masked, return [B, #MUs, MUy, MUx, C] + if mask is not None: + return x + + # If not masked, we can return [B, H, W, C] + x = undo_windowing(x, size, cur_mu_shape) + + return x + + +def unwrap_model(model: nn.Module) -> nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (`torch.nn.Module`): The model to unwrap. 
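Returns the innermost module once all `.module` wrappers are removed.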
+ """ + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model diff --git a/examples/inference/src/slt_inference/modeling/sign_t5.py b/examples/inference/src/slt_inference/modeling/sign_t5.py new file mode 100644 index 0000000..a36681c --- /dev/null +++ b/examples/inference/src/slt_inference/modeling/sign_t5.py @@ -0,0 +1,1830 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +""" +PyTorch SignT5 model. +Adapted from https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/t5/modeling_t5.py + +Main changes made: +- added `FeatureProjection` +- added label smoothing +- added `_reinit_weights` method +""" + +import copy +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, Seq2SeqModelOutput) +from transformers.modeling_utils import PretrainedConfig, PreTrainedModel +from transformers.pytorch_utils import (ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, + prune_linear_layer) +from transformers.utils import DUMMY_INPUTS, DUMMY_MASK, is_torch_fx_proxy, logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map + +logger = logging.get_logger(__name__) + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + +T5_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "t5-small": "https://huggingface.co/t5-small/resolve/main/config.json", + "t5-base": "https://huggingface.co/t5-base/resolve/main/config.json", + "t5-large": "https://huggingface.co/t5-large/resolve/main/config.json", + "t5-3b": "https://huggingface.co/t5-3b/resolve/main/config.json", + "t5-11b": "https://huggingface.co/t5-11b/resolve/main/config.json", +} + + +class SignT5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SignT5Model`]. It is used to + instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the T5 + [t5-small](https://huggingface.co/t5-small) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 32128): + Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`]. + d_model (`int`, *optional*, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 2048): + Size of the intermediate feed forward layer in each `T5Block`. 
+ num_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + classifier_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the + `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = "t5" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + vocab_size=32128, + d_model=768, + d_kv=64, + d_ff=2048, + num_layers=12, + num_decoder_layers=12, + num_heads=12, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="gated-gelu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + classifier_dropout=0.0, + feature_dim=768, + label_smoothing=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.feature_dim = feature_dim + self.label_smoothing = label_smoothing + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.classifier_dropout = classifier_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer." + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. 
" + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = 
getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +class FeatureProjection(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.feature_proj = nn.Linear(in_dim, out_dim) + self.feature_proj_scale = nn.Parameter(torch.ones(out_dim)) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + features = self.feature_proj_scale * self.feature_proj(features) + return features + + +class SignT5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + SignT5LayerNorm = FusedRMSNorm # noqa + + logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm") +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(SignT5LayerNorm) + + +class SignT5DenseActDense(nn.Module): + def __init__(self, config: SignT5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class SignT5DenseGatedActDense(nn.Module): + def __init__(self, config: SignT5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act 
= ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. + # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class SignT5LayerFF(nn.Module): + def __init__(self, config: SignT5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = SignT5DenseGatedActDense(config) + else: + self.DenseReluDense = SignT5DenseActDense(config) + + self.layer_norm = SignT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class SignT5Attention(nn.Module): + def __init__(self, config: SignT5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads + ) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. 
The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
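For intuition about the bucketing above, here is a self-contained restatement of `_relative_position_bucket` for the bidirectional (encoder) case, applied to a small query/key grid. The function and variable names are local to this sketch and it is illustrative only.

```python
import math
import torch

def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    # Bidirectional case: half the buckets for negative offsets, half for positive.
    num_buckets //= 2
    buckets = (relative_position > 0).long() * num_buckets
    rel = relative_position.abs()

    # Half of the remaining buckets cover exact small offsets ...
    max_exact = num_buckets // 2
    is_small = rel < max_exact

    # ... the other half cover logarithmically larger offsets up to max_distance.
    rel_large = max_exact + (
        torch.log(rel.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    rel_large = torch.min(rel_large, torch.full_like(rel_large, num_buckets - 1))

    return buckets + torch.where(is_small, rel, rel_large)

query_len = key_len = 6
context = torch.arange(query_len)[:, None]
memory = torch.arange(key_len)[None, :]
buckets = relative_position_bucket(memory - context)   # shape (6, 6)
print(buckets)
# In compute_bias, each bucket indexes a learned nn.Embedding(num_buckets, n_heads),
# producing the (1, n_heads, query_length, key_length) additive attention bias.
```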
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose( + 1, 2 + ) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + 
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class SignT5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = SignT5Attention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = SignT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class SignT5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = SignT5Attention(config, has_relative_attention_bias=False) + self.layer_norm = SignT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class SignT5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + SignT5LayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + ) + if self.is_decoder: + self.layer.append(SignT5LayerCrossAttention(config)) + + self.layer.append(SignT5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + 
layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. 
Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class SignT5ClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: SignT5Config): + super().__init__() + self.dense = nn.Linear(config.d_model, config.d_model) + self.dropout = nn.Dropout(p=config.classifier_dropout) + self.out_proj = nn.Linear(config.d_model, config.num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class SignT5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
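A tiny sketch (toy values only) of the fp16 guard applied after each sub-layer in `SignT5Block` above: when activations overflow to +/-inf in half precision, they are clamped back to just inside the representable range so the residual additions and layer norms stay finite. The block expresses the same rule branch-free with `torch.where`; the sketch uses a plain `if` for clarity.

```python
import torch

# Toy fp16 activations containing an overflow (illustrative only).
hidden_states = torch.tensor([1.0, -3.5, float("inf"), -float("inf")], dtype=torch.float16)

# Back off from the dtype maximum only when an inf is actually present.
if torch.isinf(hidden_states).any():
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
else:
    clamp_value = torch.finfo(hidden_states.dtype).max

hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
print(hidden_states)  # all entries are now finite fp16 values
```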
+ """ + + config_class = SignT5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["SignT5Block"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, SignT5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance( + module, + (SignT5Model, SignT5ForConditionalGeneration), + ): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "qa_outputs"): + module.qa_outputs.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + module.qa_outputs.bias.data.zero_() + elif isinstance(module, SignT5ClassificationHead): + module.dense.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.dense, "bias") and module.dense.bias is not None: + module.dense.bias.data.zero_() + module.out_proj.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, SignT5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, SignT5DenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, SignT5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_( + mean=0.0, std=factor * 
((d_model * key_value_proj_dim) ** -0.5) + ) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) + ) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model) ** -0.5) + ) + elif isinstance(module, FeatureProjection): + module.feature_proj.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.feature_proj, "bias") and module.feature_proj.bias is not None: + module.feature_proj.bias.data.zero_() + module.feature_proj_scale.data.fill_(1.0) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. " + "See T5 docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class SignT5Stack(SignT5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + if not config.is_decoder: + feature_dim = config.feature_dim if hasattr(config, "feature_dim") else 255 + self.feature_projection = FeatureProjection(feature_dim, config.d_model) + + self.block = nn.ModuleList( + [ + SignT5Block(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = SignT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + def parallelize(self, device_map=None): + warnings.warn( + "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," + " 'block.1': 1, ...}", + FutureWarning, + ) + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = ( + "cpu" + if "cpu" in self.device_map.keys() + else "cuda:" + str(min(self.device_map.keys())) + ) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + + if not self.config.is_decoder: + self.feature_projection = self.feature_projection.to(self.first_device) + + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + + if not self.config.is_decoder: + self.feature_projection = self.feature_projection.to("cpu") + + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + + if not self.config.is_decoder: + self.feature_projection = self.feature_projection.to(self.first_device) + + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to 
initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + if not self.config.is_decoder: + inputs_embeds = self.feature_projection(inputs_embeds) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) + + if use_cache is True: + if not self.is_decoder: + raise ValueError( + f"`use_cache` can only be set to `True` if {self} is used as a decoder" + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to( + hidden_states.device + ) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to( + hidden_states.device + ) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( + hidden_states.device + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.forward, + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key 
value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class SignT5Model(SignT5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: SignT5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = SignT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = SignT5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + warnings.warn( + "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model" + " with `device_map='balanced'` in the call to `from_pretrained`. 
You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':" + " 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + 
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +class SignT5ForConditionalGeneration(SignT5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ + "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "lm_head.weight", + ] + + def __init__(self, config: SignT5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = SignT5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = SignT5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + warnings.warn( + "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you" + " should load your model with `device_map='balanced'` in the call to `from_pretrained`. 
You can also" + " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. 
All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + label_smoothing = ( + self.config.label_smoothing if hasattr(self.config, "label_smoothing") else 0.1 + ) + loss_fct = CrossEntropyLoss( + ignore_index=self.config.pad_token_id, label_smoothing=label_smoothing + ) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), 
labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + decoder_attention_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "decoder_attention_mask": decoder_attention_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" + ) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + def _reinit_weights(self): + print("Reinitializing weights") + self.apply(self._init_weights) diff --git a/examples/inference/src/slt_inference/modeling/sonar_decoder.py b/examples/inference/src/slt_inference/modeling/sonar_decoder.py new file mode 
100644 index 0000000..3a01f75 --- /dev/null +++ b/examples/inference/src/slt_inference/modeling/sonar_decoder.py @@ -0,0 +1,114 @@ +import os +from pathlib import Path +from typing import Union + +import numpy as np +import sentencepiece as spm +import torch +from fairseq.models import FairseqEncoder +from fairseq.models.fairseq_model import FairseqEncoderDecoderModel +from fairseq.models.transformer.transformer_decoder import TransformerDecoderBase +from fairseq.sequence_generator import SequenceGenerator +from torch import device + + +class DummyEncoder(FairseqEncoder): + def __init__(self): + super().__init__(None) + + def forward_torchscript(self, net_input): + return {"encoder_out": [net_input["source"].unsqueeze(0)], "encoder_padding_mask": [None]} + + def reorder_encoder_out(self, encoder_out_dict, new_order): + encoder_out_dict["encoder_out"][0] = encoder_out_dict["encoder_out"][0].index_select( + 1, new_order + ) + return encoder_out_dict + + +class TmodulesTextDecoder: + def __init__( + self, + path_decoder, + path_spm, + beam_size=5, + len_penalty=1.0, + temperature=1.0, + no_repeat_ngram_size=0, + max_len=100, + device=torch.device("cpu"), + ): + self.path_decoder = path_decoder + + checkpoint = torch.load(Path(path_decoder).resolve()) + self.decoder = TransformerDecoderBase( + checkpoint["cfg"], checkpoint["dictionary"], checkpoint["embed_tokens"] + ) + self.decoder.load_state_dict(checkpoint["state_dict"]) + self.decoder.eval().to(device) + self.spm_out = spm.SentencePieceProcessor(model_file=path_spm) + self.dict_out = checkpoint["dictionary"] + self.device = device + self.dummy_encoder = DummyEncoder() + embedding_model = FairseqEncoderDecoderModel(self.dummy_encoder, self.decoder) + self.generator = SequenceGenerator( + [embedding_model], + self.dict_out, + beam_size=beam_size, + len_penalty=len_penalty, + temperature=temperature, + no_repeat_ngram_size=no_repeat_ngram_size, + max_len=max_len, + ) + + def decode_file(self, file_path, lang="eng_Latn"): + nbex = int(os.path.getsize(file_path) / 1024 / 2) + embeddings = np.array(np.memmap(file_path, mode="r", dtype=np.float16, shape=(nbex, 1024))) + return self.decode(embeddings, lang=lang) + + def decode(self, embeddings, lang="eng_Latn", bz=100): + if isinstance(embeddings, (np.ndarray, np.generic)): + embeddings = torch.FloatTensor(embeddings) + + if len(embeddings.shape) == 1: + embeddings = embeddings.unsqueeze(0) + + batches = torch.split(embeddings, bz) + preds = [] + + for batch in batches: + sample = { + "net_input": { + "source": batch.to(self.device), + "padding_mask": torch.tensor([[False]] * batch.shape[0]).to(self.device), + } + } + prefix_tokens = ( + torch.LongTensor([[self.dict_out.index(f"__{lang}__")]]) + .expand(batch.shape[0], 1) + .to(self.device) + ) + preds = preds + self.generator.forward(sample, prefix_tokens=prefix_tokens) + + gens = [] + for i in range(len(preds)): + gens.append( + self.spm_out.decode_pieces( + self.dict_out.string(preds[i][0]["tokens"][1:]).split(" ") + ).replace(" ", "") + ) + return gens + + +def load_decoder( + model_path: str, spm_path: str, *, freeze: bool = True, device: Union[str, device] = "cpu" +): + model = TmodulesTextDecoder( + path_decoder=model_path, path_spm=spm_path, max_len=200, device=device + ) + + if freeze: + for p in model.decoder.parameters(): + p.requires_grad = False + + return model diff --git a/examples/inference/src/slt_inference/modeling/sonar_t5_encoder.py b/examples/inference/src/slt_inference/modeling/sonar_t5_encoder.py new file mode 100644 index 
0000000..2b4fa55 --- /dev/null +++ b/examples/inference/src/slt_inference/modeling/sonar_t5_encoder.py @@ -0,0 +1,96 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from fairseq2.models.sequence import SequenceBatch +from fairseq2.nn.padding import PaddingMask +from fairseq2.typing import DataType, Device +from omegaconf import DictConfig +from torch import Tensor, nn + +from .sign_t5 import SignT5Config, SignT5Model + + +@dataclass +class SonarEncoderOutput: + """Dataclass for both speech and text SONAR encoder outputs""" + + encoded_seqs: Tensor + """Holds the output of the encoder + *Shape:* :math:`(N,S,M)`, where :math:`N` is the batch size, + :math:`S` is the sequence length, and :math:`M` is the + dimensionality of the model. + """ + + sentence_embeddings: Tensor + """ Pooled representation, derived from encoded_seqs by pooling in dim=1 + *Shape:* :math:`(N,M)`, where :math:`N` is the batch size, and :math:`M` is the + dimensionality of the model. + """ + + padding_mask: Optional[PaddingMask] + """Optional, the floating padding mask over sequences (-inf means masked element) + *Shape:* :math:`(N,S)`, where :math:`N` is the batch size, + :math:`S` is the sequence length. + """ + + +class SignT5SonarEncoder(SignT5Model): + def __init__(self, config: SignT5Config, output_dim: int = 1024): + super().__init__(config) + + self.projection_out = nn.Linear(config.d_model, output_dim, bias=False) + self.projection_out.weight.data.normal_(mean=0.0, std=1e-4) + + @property + def bos_idx(self): + return self.config.decoder_start_token_id + + def forward(self, batch: SequenceBatch) -> SonarEncoderOutput: + + seqs = batch.seqs + padding_mask = batch.padding_mask.materialize() if batch.padding_mask is not None else None + + encoded_seqs = self.encoder( + attention_mask=padding_mask, + inputs_embeds=seqs, + )[0] + + decoder_out = self.decoder( + input_ids=self._get_pooling_tokens(batch.batch_size, seqs.device), + encoder_hidden_states=encoded_seqs, + encoder_attention_mask=padding_mask, + )[0] + + sentence_embeddings = self.projection_out(decoder_out).squeeze(1) + + return SonarEncoderOutput( + encoded_seqs=encoded_seqs, + sentence_embeddings=sentence_embeddings, + padding_mask=batch.padding_mask, + ) + + def _get_pooling_tokens(self, batch_size: int, device: Device) -> Tensor: + return torch.tensor( + [self.bos_idx] * batch_size, dtype=torch.int64, device=device + ).unsqueeze(1) + + +def create_sonar_signt5_encoder_model( + config: DictConfig, *, device: Optional[Device] = None, dtype: Optional[DataType] = None +): + + config = SignT5Config.from_pretrained( + config.base_model_name, + decoder_start_token_id=0, + dropout_rate=0.0, + ) + model = SignT5SonarEncoder(config) + + if device is not None: + model.to(device) + + if dtype is not None: + model.to(dtype) + + return model diff --git a/examples/inference/src/slt_inference/preprocessing.py b/examples/inference/src/slt_inference/preprocessing.py new file mode 100644 index 0000000..9ede1da --- /dev/null +++ b/examples/inference/src/slt_inference/preprocessing.py @@ -0,0 +1,709 @@ +import math +import time +from functools import lru_cache +from pathlib import Path +from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union + +import cv2 +import dlib +import numpy as np +import torch +from omegaconf import DictConfig + + +class BboxTuple(NamedTuple): + """ + A type representing a bounding box. 
+ """ + + x1: int + y1: int + x2: int + y2: int + + +Tube = List[Tuple[int, BboxTuple]] + + +def load_video(video_path: Path) -> Tuple[List[np.ndarray], float]: + """ + Load a video into a list of frames. + + Args: + video_path (Path): The path to the video file to load. + + Returns: + Tuple[List[np.ndarray], float]: A list of video frames and the framerate + """ + + cap = cv2.VideoCapture(str(video_path)) + fps = cap.get(cv2.CAP_PROP_FPS) + frames = [] + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + return frames, fps + + +def write_video(frames: np.ndarray, output_path: str, fps: float) -> None: + """ + Write a video frame array to disk + + Args: + frames (np.ndarray): Video frame array of shape `(T, H, W, C)` + output_path (str): Outputh path where video gets written + fps (float): Output framerate + + Returns: + None + """ + + frames = frames[..., ::-1] + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter(output_path, fourcc, fps, (frames.shape[2], frames.shape[1])) + for i in range(frames.shape[0]): + out.write(frames[i]) + out.release() + + +def calculate_optical_flow(frames: Sequence[np.ndarray]) -> List[np.ndarray]: + """ + Calculate the optical flow between consecutive frames in a sequence of frames. + + Args: + frames (Sequence[np.ndarray]): Sequence of frames. + + Returns: + List[np.ndarray]: List of magnitudes representing the optical flow. + """ + previous_gray = None + optical_flow_magnitudes = [] + + for current_frame in frames: + scale_ratio = determine_scale_ratio(current_frame) + + current_gray = cv2.cvtColor(current_frame, cv2.COLOR_RGB2GRAY) + resized_dimensions = ( + int(current_frame.shape[1] * scale_ratio), + int(current_frame.shape[0] * scale_ratio), + ) + current_gray = cv2.resize(current_gray, resized_dimensions) + + if previous_gray is not None: + flow = cv2.calcOpticalFlowFarneback( + previous_gray, current_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0 + ) + magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + normalized_magnitude = normalize_magnitude(magnitude) + resized_magnitude = cv2.resize( + normalized_magnitude, (current_frame.shape[1], current_frame.shape[0]) + ) + else: + resized_magnitude = np.zeros( + (current_frame.shape[0], current_frame.shape[1]), dtype=np.uint8 + ) + + previous_gray = current_gray + optical_flow_magnitudes.append(resized_magnitude) + + return optical_flow_magnitudes + + +def determine_scale_ratio(frame: np.ndarray) -> float: + """ + Determine the scale ratio based on the frame's width. + + Args: + frame (np.ndarray): The current frame. + + Returns: + float: The scale ratio. + """ + if frame.shape[1] > 640 and frame.shape[1] > 960: + return 1 / 4 + elif frame.shape[1] >= 960 and frame.shape[1] < 1280: + return 1 / 6 + elif frame.shape[1] >= 1280: + return 1 / 8 + else: + return 1 / 2 + + +def normalize_magnitude(magnitude: np.ndarray) -> np.ndarray: + """ + Normalize the magnitude of the optical flow. + + Args: + magnitude (np.ndarray): The magnitude of the optical flow. + + Returns: + np.ndarray: The normalized magnitude. + """ + return ( + 255.0 * (magnitude - magnitude.min()) / max(float(magnitude.max() - magnitude.min()), 1) + ).astype(np.uint8) + + +def find_target_bbox( + bbox_arr: Sequence[Sequence[BboxTuple]], + opts: Sequence[np.ndarray], + iou_thr: float = 0.5, + len_ratio_thr: float = 0.5, +) -> Tuple[Optional[BboxTuple], List[Tube],]: + """ + Function to find the target bounding box and tubes. 
+ + Args: + bbox_arr (Sequence[Sequence[BboxTuple]]): Sequence of bounding boxes. + opts (Sequence[np.ndarray]): Sequence of optical flow arrays. + iou_thr (float, optional): Intersection over Union threshold. Defaults to 0.5. + len_ratio_thr (float, optional): Length ratio threshold. Defaults to 0.5. + + Returns: + Tuple[Optional[BboxTuple], List[Tube]: Target bounding box and tubes. + """ + tubes = [] + total_bboxes = sum(len(x) for x in bbox_arr) + + while total_bboxes > 0: + anchor = next((i, bbox_arr[i].pop()) for i, bboxes in enumerate(bbox_arr) if bboxes) + tube = [anchor] + for i, bboxes in enumerate(bbox_arr): + if anchor[0] == i or not bboxes: + continue + ious = np.array([get_iou(anchor[1], bbox) for bbox in bboxes]) + max_iou_index = ious.argmax() + if ious[max_iou_index] > iou_thr: + target_bbox = bboxes.pop(max_iou_index) + tube.append([i, target_bbox]) + tubes.append(tube) + total_bboxes -= len(tube) + + mean_vals_and_tubes = [ + (calculate_mean_val(tube, opts), tube) + for tube in tubes + if len(tube) / len(opts) > len_ratio_thr + ] + _, best_tube = max(mean_vals_and_tubes) if mean_vals_and_tubes else (-1, None) + + target_bbox = ( + tuple(np.array([bbox[1] for bbox in best_tube]).mean(axis=0)) if best_tube else None + ) + return target_bbox, tubes + + +@lru_cache(maxsize=None) +def get_iou(bbox_a: BboxTuple, bbox_b: BboxTuple) -> float: + """ + Calculate the Intersection over Union (IoU) between two bounding boxes. + The IoU is defined as the area of the overlap between the two bounding boxes divided by the area of their union. + + Args: + bbox_a (BboxTuple): The first bounding box, represented as a tuple of (x1, y1, x2, y2). + bbox_b (BboxTuple): The second bounding box, represented as a tuple of (x1, y1, x2, y2). + + Returns: + float: The IoU between the two bounding boxes. + """ + x_a = max(bbox_a[0], bbox_b[0]) + y_a = max(bbox_a[1], bbox_b[1]) + x_b = min(bbox_a[2], bbox_b[2]) + y_b = min(bbox_a[3], bbox_b[3]) + inter_area = max(0, x_b - x_a + 1) * max(0, y_b - y_a + 1) + bbox_a_area = (bbox_a[2] - bbox_a[0] + 1) * (bbox_a[3] - bbox_a[1] + 1) + bbox_b_area = (bbox_b[2] - bbox_b[0] + 1) * (bbox_b[3] - bbox_b[1] + 1) + + iou = inter_area / float(bbox_a_area + bbox_b_area - inter_area) + + return iou + + +def calculate_mean_val(tube: Tube, opts: List[np.ndarray]) -> float: + """ + Function to calculate the mean value of a tube. + + Args: + tube (Tube): Tube to calculate mean value for. + opts (List[np.ndarray]): List of options. + + Returns: + float: Mean value of the tube. + """ + mean_val = sum( + opts[frame_index][max(y0, 0) : y1, max(x0, 0) : x1].mean() + for frame_index, (x0, y0, x1, y1) in ((x[0], tuple(map(int, x[1]))) for x in tube) + ) + return mean_val / len(tube) + + +def calculate_bbox_from_tubes(tubes: Sequence[Tube]) -> Optional[BboxTuple]: + """ + Calculate bounding box from tubes. + + Args: + tubes (Sequence[Tube]): Sequence of tubes. A tube is a sequence of (frame index, bounding box) tuples. + + Returns: + Optional[BboxTuple]: Bounding box, or None if not found. + """ + total_sizes = [ + sum((bbox[3] - bbox[0]) * (bbox[2] - bbox[1]) for _, bbox in tube) for tube in tubes + ] + if max(total_sizes) > 0: + idx = np.array(total_sizes).argmax() + return tuple(np.array([x for _, x in tubes[idx]]).mean(axis=0)) + return None + + +def crop_resize(imgs: Sequence[np.ndarray], bbox: BboxTuple, target_size: int) -> np.ndarray: + """ + This function crops and resizes frames based on the provided bounding box and target size. 
+ + Args: + imgs (Sequence[np.ndarray]): Sequence of frames to be processed. + bbox (BboxTuple): Bounding box coordinates (x0, y0, x1, y1). + target_size (int): The target size for the output images. + + Returns: + np.ndarray: Stacked array of processed frames. + """ + x0, y0, x1, y1 = bbox + + exp = abs((x1 - x0) - (y1 - y0)) / 2 + if x1 - x0 < y1 - y0: + x0, x1 = x0 - exp, x1 + exp + else: + y0, y1 = y0 - exp, y1 + exp + x0, x1, y0, y1 = map(int, (x0, x1, y0, y1)) + + # Calculate expansion values for each side + left_expand = max(-x0, 0) + up_expand = max(-y0, 0) + right_expand = max(x1 - imgs[0].shape[1] + 1, 0) + down_expand = max(y1 - imgs[0].shape[0] + 1, 0) + + # Pad, crop, and resize each frame + rois = np.stack( + [ + cv2.resize( + cv2.copyMakeBorder( + img, + up_expand, + down_expand, + left_expand, + right_expand, + cv2.BORDER_CONSTANT, + (0, 0, 0), + )[y0 + up_expand : y1 + up_expand, x0 + left_expand : x1 + left_expand], + (target_size, target_size), + ) + for img in imgs + ] + ) + + return rois + + +def temporal_sampling( + frames: torch.Tensor, + start_idx: int, + end_idx: int, + num_samples: int, + return_index: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Given the start and end frame index, sample num_samples frames between + the start and end with equal interval. + Args: + frames (tensor): a tensor of video frames, dimension is + `num video frames` x `channel` x `height` x `width`. + start_idx (int): the index of the start frame. + end_idx (int): the index of the end frame. + num_samples (int): number of frames to sample. + Returns: + frames (tersor): a tensor of temporal sampled video frames, dimension is + `num clip frames` x `channel` x `height` x `width`. + """ + index = torch.linspace(start_idx + 0.5, end_idx + 0.5, num_samples, device=frames.device) + index = torch.clamp(index, 0, frames.shape[0] - 1).long() + new_frames = torch.index_select(frames, 0, index) + + if return_index: + return new_frames, index + return new_frames + + +def tensor_normalize( + tensor: torch.Tensor, + mean: Union[torch.Tensor, Tuple[float, float, float]], + std: Union[torch.Tensor, Tuple[float, float, float]], +) -> torch.Tensor: + """ + Normalize a given tensor by subtracting the mean and dividing the std. + Args: + tensor (tensor): tensor to normalize. + mean (tensor or list): mean value to subtract. + std (tensor or list): std to divide. + """ + if tensor.dtype == torch.uint8: + tensor = tensor.float() + tensor = tensor / 255.0 + if isinstance(mean, tuple): + mean = torch.tensor(mean, device=tensor.device) + if isinstance(std, tuple): + std = torch.tensor(std, device=tensor.device) + tensor = tensor - mean + tensor = tensor / std + return tensor + + +def uniform_crop( + images: torch.Tensor, size: int, spatial_idx: int, scale_size: Optional[int] = None +) -> torch.Tensor: + """ + Perform uniform spatial sampling on the images. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. 
+ """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode="bilinear", + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped + + +def get_num_padding_frames( + idx: torch.Tensor, + num_frames: int, + sampling_rate: int, + fps: float, + target_fps: float, +) -> int: + """ + Get the number of padding frames based on the provided parameters + + Args: + idx (torch.Tensor): A tensor containing indices. + num_frames (int): The total number of frames. + sampling_rate (int): The rate at which frames are sampled. + fps (float): The original frames per second. + target_fps (float): The target frames per second. + + Returns: + int: The number of padding frames. + """ + + num_unique = len(torch.unique(idx)) + + # Frames duplicated via interpolation should not count as padding + if target_fps > (fps * sampling_rate): + num_non_padding = math.floor(num_unique * target_fps / (fps * sampling_rate)) + else: + num_non_padding = num_unique + return num_frames - num_non_padding + + +class Preprocessor: + def __init__(self, config: DictConfig, device: torch.device): + self.config = config + self.device = device + + if config.hog_detector: + self.detector = dlib.get_frontal_face_detector() + else: + self.detector = dlib.cnn_face_detection_model_v1(config.detector_path) + + def detect_frame(self, frame: np.ndarray) -> List[BboxTuple]: + """ + Detect faces in a frame using either a HOG detector or a CNN-based detector. + + Args: + frame (np.ndarray): The input frame to be processed. + + Returns: + List[BboxTuple]: A list of bounding box tuples, each tuple containing the + coordinates (left, top, right, bottom) of a detected object. + """ + if self.config.detection_downsample: + scale_ratio = determine_scale_ratio(frame) + frame = cv2.resize(frame, (0, 0), fx=scale_ratio, fy=scale_ratio) + else: + scale_ratio = 1 + + if self.config.hog_detector: + return [ + ( + rect.left() // scale_ratio, + rect.top() // scale_ratio, + rect.right() // scale_ratio, + rect.bottom() // scale_ratio, + ) + for rect in self.detector(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY), 1) + ] + else: + return [ + ( + d.rect.left() // scale_ratio, + d.rect.top() // scale_ratio, + d.rect.right() // scale_ratio, + d.rect.bottom() // scale_ratio, + ) + for d in self.detector(frame, 1) + ] + + def detect_faces(self, frames: Sequence[np.ndarray]) -> Tuple[List[List[BboxTuple]], int]: + """ + Detect faces in a sequence of images. + This function applies a face detector to each image in the input sequence. The detected faces are returned as + bounding boxes. The function also returns the maximum number of faces detected in a single image. + + Args: + frames (Sequence[np.ndarray]): A sequence of frames in which to detect faces. 
+ + Returns: + Tuple[List[List[BboxTuple]], int]: A tuple where the first element is a list of lists of bounding boxes + for the detected faces, and the second element is the maximum number of faces detected in a single image. + """ + bboxes = [self.detect_frame(frame) for frame in frames] + max_num_faces = max([len(x) for x in bboxes]) + + return bboxes, max_num_faces + + def _expand_bboxes(self, bboxes: Sequence[Sequence[BboxTuple]]) -> None: + """ + Expand bounding boxes based on the expansion configuration. + This function iterates over each bounding box in the provided sequence and expands it according to the + expansion configuration defined in `self.config`. The expansion is performed in all four directions: left, + up, right, and down. + + Args: + bboxes (Sequence[Sequence[BboxTuple]]): A sequence of sequences of bounding boxes. Each inner sequence + represents a set of bounding boxes for a particular frame. + + Returns: + None + """ + for i in range(len(bboxes)): + for j in range(len(bboxes[i])): + x0, y0, x1, y1 = bboxes[i][j] + w, h = x1 - x0 + 1, y1 - y0 + 1 + x0, y0, x1, y1 = ( + x0 - w * self.config.left_exp, + y0 - h * self.config.up_exp, + x1 + w * self.config.right_exp, + y1 + h * self.config.down_exp, + ) + bboxes[i][j] = (x0, y0, x1, y1) + + def _bbox_from_bboxes_mean(self, bboxes: Sequence[Sequence[BboxTuple]]) -> BboxTuple: + """ + Calculate the mean bounding box from a sequence of bounding boxes. + This function takes a sequence of bounding boxes, filters out those that do not have exactly one element, + and then calculates the mean bounding box from the remaining ones. + + Args: + bboxes (Sequence[Sequence[BboxTuple]]): A sequence of sequences of bounding boxes. Each inner sequence + represents a set of bounding boxes for a particular frame. + + Returns: + BboxTuple: The mean bounding box calculated from the input bounding boxes. + """ + return tuple(np.array([x for x in bboxes if len(x) == 1]).mean(axis=0)[0]) + + def _try_bbox_from_optical_flow( + self, + frames: Sequence[np.ndarray], + bboxes: Sequence[Sequence[BboxTuple]], + ) -> Optional[BboxTuple]: + """ + Try to find a bounding box from optical flow. + This function calculates the optical flow from the given frames, then tries to find a target bounding box + based on the calculated optical flow and the given bounding boxes. If no bounding box is found, it tries to + calculate a bounding box from the tubes. + + Args: + frames (Sequence[np.ndarray]): A sequence of frames from a video. + bboxes (Sequence[Sequence[BboxTuple]]): A sequence of bounding boxes for each frame. + + Returns: + Optional[BboxTuple]: The found bounding box, or None if no bounding box could be found. + """ + + opts = calculate_optical_flow(frames) + + opts = opts[:: self.config.detection_sampling_rate] + + bbox, tubes = find_target_bbox( + bboxes, opts, self.config.iou_threshold, self.config.num_ratio_threshold + ) + + if bbox is None and tubes: + bbox = calculate_bbox_from_tubes(tubes) + + return bbox + + def _bbox_from_center_crop(self, frame: np.ndarray) -> BboxTuple: + """ + Obtain a bounding box by center cropping a video frame. + This function calculates the center of the frame and then creates a bounding box around the center. + The size of the bounding box is half of the smaller dimension of the frame. + + Args: + frame (np.ndarray): Frame array + + Returns: + BboxTuple: Bounding box for center crop. 
+ """ + W, H = frame.shape[1], frame.shape[0] + cx, cy, size = W // 2, H // 2, min(W, H) // 2 + return (cx - size // 2, cy - size // 2, cx + size // 2, cy + size // 2) + + def _sample_frames_for_feature_extraction( + self, frames: torch.Tensor, fps: float + ) -> Dict[str, Union[int, torch.Tensor]]: + """ + Samples clips with a fixed number of frames in a sliding window for the full video. + The clips are then stacked in the batch dimension. If the video is shorter than `num_frames`, + it will be padded. The method returns the number of padding frames. + This method does not support repeat augmentation. + """ + + frames_list = [] + num_padding_frames_list = [] + + stride = self.config.feature_extraction_stride + sampling_rate = self.config.sampling_rate + num_frames = self.config.num_frames + + clip_sz = sampling_rate * num_frames / self.config.target_fps * fps + + for i in range(0, frames.shape[0], stride * sampling_rate): + start_idx, end_idx = i, i + clip_sz - 1 + new_frames, idx = temporal_sampling( + frames, start_idx, end_idx, num_frames, return_index=True + ) + + num_padding_frames_list.append( + get_num_padding_frames(idx, num_frames, sampling_rate, fps, self.config.target_fps) + ) + + frames_list.append(new_frames) + + if end_idx >= frames.shape[0]: + break + + new_frames = torch.stack(frames_list, dim=0) + + new_frames = tensor_normalize(new_frames, tuple(self.config.mean), tuple(self.config.std)) + frames = new_frames.permute(0, 4, 1, 2, 3) # b t h w c -> b c t h w + + num_padding_frames = torch.tensor( + num_padding_frames_list, dtype=torch.long, device=frames.device + ) + + return {"frames": frames, "padding": num_padding_frames} + + def __call__(self, video_path: Path) -> Dict[str, torch.Tensor]: + """ + Process a video and return regions of interest (ROIs). + This function loads a video, detects faces in the video frames, expands the bounding boxes of the detected faces, + and then tries to find a bounding box either by calculating the mean of the bounding boxes or by using optical flow. + If no bounding box is found, it creates one by center cropping the first frame. Finally, it crops and resizes the + frames according to the found bounding box and returns the resulting regions of interest. + + Args: + video_path (Path): The path to the video file to process. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing a frames tensor and a padding tensor + """ + + assert video_path.is_file(), f"{video_path} does not exist." + + t0 = time.time() + + frames, fps = load_video(video_path) + + t1 = time.time() + + bboxes, max_num_faces = self.detect_faces(frames[:: self.config.detection_sampling_rate]) + + t2 = time.time() + + self._expand_bboxes(bboxes) + + if max_num_faces == 1: + bbox = self._bbox_from_bboxes_mean(bboxes) + else: + bbox = self._try_bbox_from_optical_flow(frames, bboxes) + + if bbox is None: + bbox = self._bbox_from_center_crop(frames[0]) + + t3 = time.time() + + rois = crop_resize(frames, bbox, self.config.target_size) + + t4 = time.time() + + if self.config.debug: + write_video(rois, f"{video_path.stem}_roi_crop.mp4", fps) + + out = self._sample_frames_for_feature_extraction( + torch.from_numpy(rois).to(self.device), fps=fps + ) + + t5 = time.time() + + if self.config.verbose: + print(f"2. 
Preprocessing: {t5 - t0:.3f}s") + print(f" - Loading: {t1 - t0:.3f}s") + print(f" - Face detection: {t2 - t1:.3f}s") + print(f" - Finding bboxes: {t3 - t2:.3f}s") + print(f" - Crop and resize: {t4 - t3:.3f}s") + print(f" - Sampling: {t5 - t4:.3f} s") + + return out diff --git a/examples/inference/src/slt_inference/translation.py b/examples/inference/src/slt_inference/translation.py new file mode 100644 index 0000000..81a0e97 --- /dev/null +++ b/examples/inference/src/slt_inference/translation.py @@ -0,0 +1,144 @@ +import time +from pathlib import Path +from typing import Dict, Sequence, Tuple + +import torch +from fairseq2.models.sequence import SequenceBatch +from omegaconf import DictConfig +from torch import nn +from transformers import AutoTokenizer, PreTrainedTokenizerFast + +from .modeling.sign_t5 import SignT5Config, SignT5ForConditionalGeneration +from .modeling.sonar_decoder import load_decoder +from .modeling.sonar_t5_encoder import create_sonar_signt5_encoder_model +from .util import load_model + + +class Translator: + def __init__( + self, + config: DictConfig, + device: torch.device, + dtype: torch.dtype = torch.float32, + ): + self.config = config + self.device = device + self.dtype = dtype + + self.model, self.tokenizer = self._load_model_and_tokenizer() + + def _load_model_and_tokenizer( + self, + ) -> Tuple[SignT5ForConditionalGeneration, PreTrainedTokenizerFast]: + """ + Loads a pretrained SignT5 model for translation and moves it to specified device + """ + + config = SignT5Config( + decoder_start_token_id=0, + output_past=True, + tie_word_embeddings=False, + feature_dim=self.config.feature_dim, + ) + model = SignT5ForConditionalGeneration._from_config(config) + + tokenizer = AutoTokenizer.from_pretrained( + self.config.tokenizer_path, use_fast=False, legacy=True + ) + + print("Loading translator") + load_model(model, Path(self.config.pretrained_model_path)) + + model.eval() + model.to(self.device) + model.to(self.dtype) + + return model, tokenizer + + @torch.inference_mode() + def __call__(self, features: torch.Tensor) -> Dict[str, str]: + + t0 = time.time() + + features = features.to(self.device) + + generated_tokens = self.model.generate( + inputs_embeds=features.unsqueeze(0), + num_return_sequences=self.config.num_translations, + max_length=self.config.max_length, + num_beams=self.config.num_beams, + do_sample=self.config.do_sample, + temperature=self.config.temperature, + ) + generated_text = [ + t.strip() + for t in self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + ] + + t1 = time.time() + + if self.config.verbose: + print(f"4. 
Translation: {t1 - t0:.3f}s") + + return {"translations": generated_text} + + +class SonarTranslator: + def __init__( + self, + config: DictConfig, + device: torch.device, + dtype: torch.dtype = torch.float32, + ): + self.config = config + self.device = device + self.dtype = dtype + + self.encoder, self.decoder = self._load_encoder_and_decoder() + + def _load_encoder_and_decoder( + self, + ) -> Tuple[nn.Module, nn.Module]: + + encoder = create_sonar_signt5_encoder_model( + self.config, device=self.device, dtype=self.dtype + ) + decoder = load_decoder( + self.config.decoder_path, self.config.decoder_spm_path, device=self.device + ) + decoder.decoder.to(self.dtype) + + print("Loading translator") + load_model(encoder, Path(self.config.pretrained_model_path), model_key="student") + + encoder.eval() + decoder.decoder.eval() + + return encoder, decoder + + @torch.inference_mode() + def __call__( + self, + features: torch.Tensor, + tgt_langs: Sequence[str] = ["eng_Latn"], + ) -> Dict[str, str]: + + t0 = time.time() + + features = features.to(self.device) + + sentence_embedding = self.encoder( + SequenceBatch(seqs=features.unsqueeze(0), padding_mask=None) + ).sentence_embeddings + + translations = [] + for tgt_lang in tgt_langs: + generated_text = self.decoder.decode(sentence_embedding, tgt_lang)[0] + translations.append(generated_text) + + t1 = time.time() + + if self.config.verbose: + print(f"4. Translation: {t1 - t0:.3f}s") + + return {"translations": translations} diff --git a/examples/inference/src/slt_inference/util.py b/examples/inference/src/slt_inference/util.py new file mode 100644 index 0000000..8805c7b --- /dev/null +++ b/examples/inference/src/slt_inference/util.py @@ -0,0 +1,265 @@ +from pathlib import Path +from typing import Sequence + +import torch + +SIGNHIERA_EXPECTED_MISSING = { + "norm.weight", + "norm.bias", + "head.projection.weight", + "head.projection.bias", +} +SIGNHIERA_EXPECTED_EXTRA = ("decoder", "multi_scale_fusion", "encoder_norm", "mask_token") + + +def load_model(model: torch.nn.Module, checkpoint_path: Path, model_key: str = "model") -> None: + """ + Loads only the model from a saved checkpoint + Model loading is not strict and parameters are evicted from the loaded state_dict if their + shape does not match the one in passed `model`. 
+ """ + + with open(checkpoint_path, "rb") as f: + checkpoint = torch.load(f, map_location="cpu") + + # Try all of these if necessary + for candidate_key in [model_key, "model", "model_state"]: + if candidate_key in checkpoint.keys(): + checkpoint_model = checkpoint[candidate_key] + break + + new_checkpoint_model = {} + for k, v in checkpoint_model.items(): + if "feature_proj" in k and "feature_projection" not in k: + k = k.replace("feature_proj", "feature_projection.feature_proj") + + if k in model.state_dict() and model.state_dict()[k].shape != v.shape: + print(f"Pruning {k} due to size mismatch") + else: + new_checkpoint_model[k] = v + + missing_keys, unexpected_keys = model.load_state_dict(new_checkpoint_model, strict=False) + # Filter out keys expected to be missing or extra when loading SignHiera + missing_keys = [k for k in missing_keys if k not in SIGNHIERA_EXPECTED_MISSING] + unexpected_keys = [k for k in unexpected_keys if not k.startswith(SIGNHIERA_EXPECTED_EXTRA)] + if not (missing_keys or unexpected_keys): + print("All keys matched successfully\n") + else: + print(f"{missing_keys = }, {unexpected_keys = }\n") + + +def print_translations(keys: Sequence[str], translations: Sequence[str]) -> None: + assert len(keys) == len(translations) + print(f"\nTranslations:\n{'-'*50}") + for key, translation in zip(keys, translations): + print(f'{key}: "{translation}"') + print(f"{'-' * 50}\n") + + +FLORES200_LANG2ID = { + "Acehnese (Arabic script)": "ace_Arab", + "Acehnese (Latin script)": "ace_Latn", + "Mesopotamian Arabic": "acm_Arab", + "Ta’izzi-Adeni Arabic": "acq_Arab", + "Tunisian Arabic": "aeb_Arab", + "Afrikaans": "afr_Latn", + "South Levantine Arabic": "ajp_Arab", + "Akan": "aka_Latn", + "Amharic": "amh_Ethi", + "North Levantine Arabic": "apc_Arab", + "Modern Standard Arabic": "arb_Arab", + "Modern Standard Arabic (Romanized)": "arb_Latn", + "Najdi Arabic": "ars_Arab", + "Moroccan Arabic": "ary_Arab", + "Egyptian Arabic": "arz_Arab", + "Assamese": "asm_Beng", + "Asturian": "ast_Latn", + "Awadhi": "awa_Deva", + "Central Aymara": "ayr_Latn", + "South Azerbaijani": "azb_Arab", + "North Azerbaijani": "azj_Latn", + "Bashkir": "bak_Cyrl", + "Bambara": "bam_Latn", + "Balinese": "ban_Latn", + "Belarusian": "bel_Cyrl", + "Bemba": "bem_Latn", + "Bengali": "ben_Beng", + "Bhojpuri": "bho_Deva", + "Banjar (Arabic script)": "bjn_Arab", + "Banjar (Latin script)": "bjn_Latn", + "Standard Tibetan": "bod_Tibt", + "Bosnian": "bos_Latn", + "Buginese": "bug_Latn", + "Bulgarian": "bul_Cyrl", + "Catalan": "cat_Latn", + "Cebuano": "ceb_Latn", + "Czech": "ces_Latn", + "Chokwe": "cjk_Latn", + "Central Kurdish": "ckb_Arab", + "Crimean Tatar": "crh_Latn", + "Welsh": "cym_Latn", + "Danish": "dan_Latn", + "German": "deu_Latn", + "Southwestern Dinka": "dik_Latn", + "Dyula": "dyu_Latn", + "Dzongkha": "dzo_Tibt", + "Greek": "ell_Grek", + "English": "eng_Latn", + "Esperanto": "epo_Latn", + "Estonian": "est_Latn", + "Basque": "eus_Latn", + "Ewe": "ewe_Latn", + "Faroese": "fao_Latn", + "Fijian": "fij_Latn", + "Finnish": "fin_Latn", + "Fon": "fon_Latn", + "French": "fra_Latn", + "Friulian": "fur_Latn", + "Nigerian Fulfulde": "fuv_Latn", + "Scottish Gaelic": "gla_Latn", + "Irish": "gle_Latn", + "Galician": "glg_Latn", + "Guarani": "grn_Latn", + "Gujarati": "guj_Gujr", + "Haitian Creole": "hat_Latn", + "Hausa": "hau_Latn", + "Hebrew": "heb_Hebr", + "Hindi": "hin_Deva", + "Chhattisgarhi": "hne_Deva", + "Croatian": "hrv_Latn", + "Hungarian": "hun_Latn", + "Armenian": "hye_Armn", + "Igbo": "ibo_Latn", + "Ilocano": 
"ilo_Latn", + "Indonesian": "ind_Latn", + "Icelandic": "isl_Latn", + "Italian": "ita_Latn", + "Javanese": "jav_Latn", + "Japanese": "jpn_Jpan", + "Kabyle": "kab_Latn", + "Jingpho": "kac_Latn", + "Kamba": "kam_Latn", + "Kannada": "kan_Knda", + "Kashmiri (Arabic script)": "kas_Arab", + "Kashmiri (Devanagari script)": "kas_Deva", + "Georgian": "kat_Geor", + "Central Kanuri (Arabic script)": "knc_Arab", + "Central Kanuri (Latin script)": "knc_Latn", + "Kazakh": "kaz_Cyrl", + "Kabiyè": "kbp_Latn", + "Kabuverdianu": "kea_Latn", + "Khmer": "khm_Khmr", + "Kikuyu": "kik_Latn", + "Kinyarwanda": "kin_Latn", + "Kyrgyz": "kir_Cyrl", + "Kimbundu": "kmb_Latn", + "Northern Kurdish": "kmr_Latn", + "Kikongo": "kon_Latn", + "Korean": "kor_Hang", + "Lao": "lao_Laoo", + "Ligurian": "lij_Latn", + "Limburgish": "lim_Latn", + "Lingala": "lin_Latn", + "Lithuanian": "lit_Latn", + "Lombard": "lmo_Latn", + "Latgalian": "ltg_Latn", + "Luxembourgish": "ltz_Latn", + "Luba-Kasai": "lua_Latn", + "Ganda": "lug_Latn", + "Luo": "luo_Latn", + "Mizo": "lus_Latn", + "Standard Latvian": "lvs_Latn", + "Magahi": "mag_Deva", + "Maithili": "mai_Deva", + "Malayalam": "mal_Mlym", + "Marathi": "mar_Deva", + "Minangkabau (Arabic script)": "min_Arab", + "Minangkabau (Latin script)": "min_Latn", + "Macedonian": "mkd_Cyrl", + "Plateau Malagasy": "plt_Latn", + "Maltese": "mlt_Latn", + "Meitei (Bengali script)": "mni_Beng", + "Halh Mongolian": "khk_Cyrl", + "Mossi": "mos_Latn", + "Maori": "mri_Latn", + "Burmese": "mya_Mymr", + "Dutch": "nld_Latn", + "Norwegian Nynorsk": "nno_Latn", + "Norwegian Bokmål": "nob_Latn", + "Nepali": "npi_Deva", + "Northern Sotho": "nso_Latn", + "Nuer": "nus_Latn", + "Nyanja": "nya_Latn", + "Occitan": "oci_Latn", + "West Central Oromo": "gaz_Latn", + "Odia": "ory_Orya", + "Pangasinan": "pag_Latn", + "Eastern Panjabi": "pan_Guru", + "Papiamento": "pap_Latn", + "Western Persian": "pes_Arab", + "Polish": "pol_Latn", + "Portuguese": "por_Latn", + "Dari": "prs_Arab", + "Southern Pashto": "pbt_Arab", + "Ayacucho Quechua": "quy_Latn", + "Romanian": "ron_Latn", + "Rundi": "run_Latn", + "Russian": "rus_Cyrl", + "Sango": "sag_Latn", + "Sanskrit": "san_Deva", + "Santali": "sat_Olck", + "Sicilian": "scn_Latn", + "Shan": "shn_Mymr", + "Sinhala": "sin_Sinh", + "Slovak": "slk_Latn", + "Slovenian": "slv_Latn", + "Samoan": "smo_Latn", + "Shona": "sna_Latn", + "Sindhi": "snd_Arab", + "Somali": "som_Latn", + "Southern Sotho": "sot_Latn", + "Spanish": "spa_Latn", + "Tosk Albanian": "als_Latn", + "Sardinian": "srd_Latn", + "Serbian": "srp_Cyrl", + "Swati": "ssw_Latn", + "Sundanese": "sun_Latn", + "Swedish": "swe_Latn", + "Swahili": "swh_Latn", + "Silesian": "szl_Latn", + "Tamil": "tam_Taml", + "Tatar": "tat_Cyrl", + "Telugu": "tel_Telu", + "Tajik": "tgk_Cyrl", + "Tagalog": "tgl_Latn", + "Thai": "tha_Thai", + "Tigrinya": "tir_Ethi", + "Tamasheq (Latin script)": "taq_Latn", + "Tamasheq (Tifinagh script)": "taq_Tfng", + "Tok Pisin": "tpi_Latn", + "Tswana": "tsn_Latn", + "Tsonga": "tso_Latn", + "Turkmen": "tuk_Latn", + "Tumbuka": "tum_Latn", + "Turkish": "tur_Latn", + "Twi": "twi_Latn", + "Central Atlas Tamazight": "tzm_Tfng", + "Uyghur": "uig_Arab", + "Ukrainian": "ukr_Cyrl", + "Umbundu": "umb_Latn", + "Urdu": "urd_Arab", + "Northern Uzbek": "uzn_Latn", + "Venetian": "vec_Latn", + "Vietnamese": "vie_Latn", + "Waray": "war_Latn", + "Wolof": "wol_Latn", + "Xhosa": "xho_Latn", + "Eastern Yiddish": "ydd_Hebr", + "Yoruba": "yor_Latn", + "Yue Chinese": "yue_Hant", + "Chinese (Simplified)": "zho_Hans", + "Chinese (Traditional)": "zho_Hant", + 
"Standard Malay": "zsm_Latn", + "Zulu": "zul_Latn", +} +FLORES200_ID2LANG = {v: k for k, v in FLORES200_LANG2ID.items()}