Commit

add e2e script & e2e demo
JooZef315 committed Nov 29, 2024
1 parent 27c0a24 commit 0e3f127
Showing 20 changed files with 4,867 additions and 1 deletion.
6 changes: 5 additions & 1 deletion .gitignore
@@ -94,4 +94,8 @@ dmypy.json
# misc
*.mp4
sweep*/
core*
core*

features_outputs
MOCK_dataset
*.pth
99 changes: 99 additions & 0 deletions examples/inference/README.md
@@ -0,0 +1,99 @@


### Installation

1. Create and activate a conda environment
```bash
conda create --name sign-language-inference python=3.10
conda activate sign-language-inference
```

2. Install dependencies
```bash
conda install -y -c conda-forge libsndfile==1.0.31 # fairseq2 dependency
pip install -r requirements.txt
```

3. Install dlib with CUDA support (alternatively, `pip install dlib`, which comes without CUDA support)

Confirm whether your dlib installation supports CUDA with: `python -c "import dlib; print(dlib.DLIB_USE_CUDA)"`.

```bash
# This configuration worked on the FAIR cluster; make sure to check the build logs

module load cmake/3.15.3/gcc.7.3.0 gcc/12.2.0 cuda/12.1 cudnn/v8.8.1.3-cuda.12.0

git clone https://github.com/davisking/dlib.git
cd dlib
python setup.py install \
--set CUDA_TOOLKIT_ROOT_DIR=/public/apps/cuda/12.1 \
--set CMAKE_PREFIX_PATH=/public/apps/cudnn/v8.8.1.3-cuda.12.0 \
--set USE_AVX_INSTRUCTIONS=yes \
--set DLIB_USE_CUDA=yes

```
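
Once the build finishes, it's worth re-running the check from step 3 to confirm the CUDA-enabled build was picked up:

```bash
python -c "import dlib; print(dlib.DLIB_USE_CUDA)"  # should print True for a CUDA-enabled build
```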

4. Install this repo (from repo root)
```bash
pip install -e .
```


### Model checkpoints

Copy the model checkpoints into the [checkpoints](checkpoints) folder.
```bash
cp -r /checkpoint/philliprust/slt_inference/checkpoints/* checkpoints/
```

### Running Inference

```bash
python run.py video_path=/path/to/video.mp4
```

Pass `verbose=True` to print the config and the time elapsed for the various steps in the pipeline.

Pass `preprocessing.hog_detector=false` to use the dlib CNN detector instead of the default HOG detector (requires a dlib build with CUDA support).

Pass `preprocessing.detection_downsample=false` if the input video resolution is already small, e.g. 224x224.
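
For example, a single command combining these overrides (with a placeholder video path) might look like:

```bash
python run.py \
    video_path=/path/to/video.mp4 \
    verbose=True \
    preprocessing.hog_detector=false \
    preprocessing.detection_downsample=false
```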

Here is a more advanced example for running a SONAR model:

```bash
python run.py \
video_path=/path/to/my/video.mp4 \
use_sonar=true \
preprocessing.detection_downsample=false \
feature_extraction.pretrained_model_path=/path/to/trained/signhiera.pth \
translation.base_model_name=google/t5-v1_1-large \
translation.pretrained_model_path=/path/to/trained/sonar/encoder/best_model.pth \
feature_extraction.fp16=true \
'translation.tgt_langs=[eng_Latn, fra_Latn, deu_Latn, zho_Hans]'
```


### Running the Gradio app

To run the Gradio app, use the exact same command as for `run.py` but leave out the `video_path`. You can then open the browser and upload a video. As soon as the video finishes uploading, it should start playing automatically, which triggers the inference pipeline. If the video doesn't start playing automatically, just press the play button manually.

```bash
python app.py \
use_sonar=true \
preprocessing.detection_downsample=false \
feature_extraction.pretrained_model_path=/path/to/trained/signhiera.pth \
translation.base_model_name=google/t5-v1_1-large \
translation.pretrained_model_path=/path/to/trained/sonar/encoder/best_model.pth \
feature_extraction.fp16=true \
'translation.tgt_langs=[eng_Latn, fra_Latn, deu_Latn, zho_Hans]'
```

The Gradio app needs to be launched from a GPU machine. If running on a devfair, you can tunnel to your local machine via `ssh -L 8000:localhost:7860 devfair`. Once the devfair login completes, go to `localhost:8000` in your browser to use the app.

### Slicing a video before running inference

Videos should ideally be one sentence long. To slice a video, e.g. from second 10 to second 14, you can use ffmpeg:

```bash
ffmpeg -ss 10.00 -to 14.00 -i ./video.MOV -c:v libx264 -crf 20 video_slice.mp4 -loglevel info
```
198 changes: 198 additions & 0 deletions examples/inference/app.py
@@ -0,0 +1,198 @@
import os
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue
from time import sleep
from typing import List, Optional, Sequence, Tuple

import gradio as gr
import hydra
import torch
from hydra.core.config_store import ConfigStore
from omegaconf import II, DictConfig, OmegaConf

from slt_inference.feature_extraction import FeatureExtractor
from slt_inference.preprocessing import Preprocessor
from slt_inference.translation import SonarTranslator, Translator
from slt_inference.util import FLORES200_ID2LANG, print_translations


@dataclass
class PreprocessingConfig:
# Bounding box expansion and threshold parameters
up_exp: float = 1.0
down_exp: float = 3.0
left_exp: float = 1.5
right_exp: float = 1.5
iou_threshold: float = 0.2
num_ratio_threshold: float = 0.5

# Dlib detection parameters
hog_detector: bool = True
detector_path: Optional[str] = "checkpoints/detector.dat"
detection_sampling_rate: int = 16
detection_downsample: bool = True

# Sliding window sampling parameters
num_frames: int = 128
feature_extraction_stride: int = 64
sampling_rate: int = 2
target_fps: int = 25

# Cropping and Normalization
target_size: int = 224
mean: Tuple[float, float, float] = (0.45, 0.45, 0.45)
std: Tuple[float, float, float] = (0.225, 0.225, 0.225)

debug: bool = False
verbose: bool = II("verbose")

def __post_init__(self):
if self.hog_detector is False:
assert (
self.detector_path is not None
            ), "detector_path must be provided if `hog_detector=False`"


@dataclass
class FeatureExtractionConfig:
pretrained_model_path: str = "checkpoints/feature_extractor.pth"
model_name: str = "hiera_base_128x224"
max_batch_size: int = 2
fp16: bool = True
verbose: bool = II("verbose")


@dataclass
class TranslationConfig:
pretrained_model_path: str = "checkpoints/translator.pth"
tokenizer_path: str = "checkpoints/tokenizer"
base_model_name: str = "google/t5-v1_1-large"
feature_dim: int = 768
decoder_path: str = "checkpoints/sonar_decoder.pt"
decoder_spm_path: str = "checkpoints/decoder_sentencepiece.model"

# Target languages for Sonar translator
tgt_langs: List[str] = field(default_factory=lambda: ["eng_Latn"])

# Generation parameters
num_translations: int = 5
do_sample: bool = False
num_beams: int = 5
temperature: float = 1.0
max_length: int = 128

verbose: bool = II("verbose")

def __post_init__(self):
for lang in self.tgt_langs:
if lang not in FLORES200_ID2LANG:
raise ValueError(f"{lang} is not a valid FLORES-200 language ID")


@dataclass
class RunConfig:
preprocessing: PreprocessingConfig = PreprocessingConfig()
feature_extraction: FeatureExtractionConfig = FeatureExtractionConfig()
translation: TranslationConfig = TranslationConfig()
use_sonar: bool = False
verbose: bool = False


cs = ConfigStore.instance()
cs.store(name="run_config", node=RunConfig)

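# Module-level state shared between the Gradio callbacks: the preprocessed
# video, the flag set when playback finishes, and a queue holding the latest
# translation result.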
video = None
video_released = False
translation_queue = Queue(maxsize=1)

css = """
.app {
max-width: 50% !important;
}
"""


@hydra.main(config_name="run_config")
def main(config: DictConfig):
os.chdir(hydra.utils.get_original_cwd())

print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if config.feature_extraction.fp16 else torch.float32

preprocessor = Preprocessor(config.preprocessing, device=device)
feature_extractor = FeatureExtractor(config.feature_extraction, device=device)

translator_cls = SonarTranslator if config.use_sonar else Translator
translator = translator_cls(config.translation, device=device, dtype=dtype)

def release():
"""Triggered when video ends to indicate translation can now be displayed"""
global video_released
video_released = True

def load_video(video_path: str):
"""Load and preprocess the uploaded video"""
global video
video = preprocessor(Path(video_path))

def inference_video():
"""Run inference on video and put translations in a queue"""
global translation_queue
extracted_features = feature_extractor(**video)

kwargs = {"tgt_langs": config.translation.tgt_langs} if config.use_sonar else {}
translations = translator(extracted_features, **kwargs)["translations"]
translation_queue.put(translations)

keys = (
[FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs]
if config.use_sonar
else range(config.translation.num_translations)
)
print_translations(keys, translations)

def show_translation():
"""Consume translation results from queue and display them"""

global video_released
# Wait until video has finished playing before showing translation
while not video_released:
sleep(0.05)

video_released = False
translation_result = translation_queue.get()

return (
[gr.Text(v) for v in translation_result]
if config.use_sonar
else gr.Text(translation_result[0])
)

with gr.Blocks(css=css) as demo:
with gr.Column(scale=1):
gr.Markdown("### ASL")
input_video = gr.Video(label="Input Video", height=360, autoplay=True)

output_texts = []
if config.use_sonar:
for lang in config.translation.tgt_langs:
gr.Markdown(f"### {FLORES200_ID2LANG[lang]}")
output_texts.append(gr.Text("", interactive=False, label="Translation"))
else:
gr.Markdown("### English")
output_texts.append(gr.Text("", interactive=False, label="Translation"))

input_video.upload(fn=load_video, inputs=input_video, outputs=None).success(
fn=inference_video, inputs=None, outputs=None
).success(fn=show_translation, inputs=None, outputs=output_texts)

input_video.end(fn=release, inputs=None, outputs=None)

demo.launch()


if __name__ == "__main__":
main()
45 changes: 45 additions & 0 deletions examples/inference/e2e.py
@@ -0,0 +1,45 @@
import time
from pathlib import Path
import torch
from omegaconf import DictConfig, OmegaConf
from src.slt_inference.preprocessing import Preprocessor
from src.slt_inference.feature_extraction import FeatureExtractor
from src.slt_inference.translation import SonarTranslator, Translator
from src.slt_inference.util import FLORES200_ID2LANG, print_translations

def e2e_pipeline(config: DictConfig):
    """
    Run the full sign language translation pipeline (preprocessing, feature
    extraction, translation) on the video referenced by `config.video_path`.
    """
    # Relative paths in `config` are resolved against the caller's working directory.

if config.verbose:
print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if config.feature_extraction.fp16 else torch.float32

t0 = time.time()

# Initialize components
preprocessor = Preprocessor(config.preprocessing, device=device)
feature_extractor = FeatureExtractor(config.feature_extraction, device=device)
translator_cls = SonarTranslator if config.use_sonar else Translator
translator = translator_cls(config.translation, device=device, dtype=dtype)

t1 = time.time()
if config.verbose:
print(f"1. Model loading: {t1 - t0:.3f}s")

# Process input
inputs = preprocessor(Path(config.video_path))
extracted_features = feature_extractor(**inputs)

kwargs = {"tgt_langs": config.translation.tgt_langs} if config.use_sonar else {}
translations = translator(extracted_features, **kwargs)["translations"]

    keys = (
        [FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs]
        if config.use_sonar
        else range(config.translation.num_translations)
    )

# Output results
print_translations(keys, translations)
22 changes: 22 additions & 0 deletions examples/inference/e2e_demo.py
@@ -0,0 +1,22 @@
from omegaconf import OmegaConf
from e2e import e2e_pipeline
from src.slt_inference.configs import RunConfig, FeatureExtractionConfig, TranslationConfig

# Define the run configuration
run_config = RunConfig(
    video_path="D:/Pro/MLH/ctf/video.mp4",
    verbose=True,
    feature_extraction=FeatureExtractionConfig(
        pretrained_model_path="translation/signhiera_mock.pth",
    ),
    translation=TranslationConfig(
        base_model_name="google/t5-v1_1-large",
        tgt_langs=["eng_Latn", "fra_Latn"]
    )
)

# Convert it to a DictConfig
run_dict_config = OmegaConf.structured(run_config)

# Run the pipeline with the provided configuration
e2e_pipeline(run_dict_config)
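
One way to run this demo, assuming it is launched from the repository root (so the `src.slt_inference` imports resolve) and that the checkpoint and video paths above point to files that exist locally:

```bash
PYTHONPATH=. python examples/inference/e2e_demo.py
```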