Showing 20 changed files with 4,867 additions and 1 deletion.

@@ -94,4 +94,8 @@ dmypy.json
# misc
*.mp4
sweep*/
core*
core*

features_outputs
MOCK_dataset
*.pth

@@ -0,0 +1,99 @@

### Installation

1. Create and activate a conda environment
```bash
conda create --name sign-language-inference python=3.10
conda activate sign-language-inference
```

2. Install dependencies
```bash
conda install -y -c conda-forge libsndfile==1.0.31  # fairseq2 dependency
pip install -r requirements.txt
```

3. Install dlib with CUDA support (alternatively, `pip install dlib`, which comes without CUDA support)

Confirm whether your dlib installation supports CUDA with: `python -c "import dlib; print(dlib.DLIB_USE_CUDA)"`.

```bash
# This configuration worked on the FAIR cluster; make sure to check the build logs

module load cmake/3.15.3/gcc.7.3.0 gcc/12.2.0 cuda/12.1 cudnn/v8.8.1.3-cuda.12.0

git clone https://github.com/davisking/dlib.git
cd dlib
python setup.py install \
    --set CUDA_TOOLKIT_ROOT_DIR=/public/apps/cuda/12.1 \
    --set CMAKE_PREFIX_PATH=/public/apps/cudnn/v8.8.1.3-cuda.12.0 \
    --set USE_AVX_INSTRUCTIONS=yes \
    --set DLIB_USE_CUDA=yes
```

4. Install this repo (from repo root)
```bash
pip install -e .
```
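
To verify that the editable install worked, a minimal sanity check is to import one of the package modules used by the pipeline (this only checks importability, not that model checkpoints are in place):

```bash
python -c "from slt_inference.preprocessing import Preprocessor; print('slt_inference OK')"
```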

### Model checkpoints

Copy the model checkpoints to the [checkpoints](checkpoints) folder.
```bash
cp -r /checkpoint/philliprust/slt_inference/checkpoints/* checkpoints/
```
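
For reference, the default config paths point to the following files under `checkpoints/`; the exact contents may differ depending on which checkpoints you copy:

```bash
ls checkpoints/
# detector.dat                 (dlib CNN detector, used when preprocessing.hog_detector=false)
# feature_extractor.pth        (SignHiera feature extractor)
# translator.pth  tokenizer/   (translation model and its tokenizer)
# sonar_decoder.pt  decoder_sentencepiece.model   (only needed with use_sonar=true)
```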

### Running Inference

```bash
python run.py video_path=/path/to/video.mp4
```

Pass `verbose=True` to print the config and the time elapsed for the various steps in the pipeline.

Pass `preprocessing.hog_detector=false` to use the dlib CNN detector instead of the HOG detector (only recommended if your dlib build has CUDA support).

Pass `preprocessing.detection_downsample=false` if the video input resolution is already small, e.g. 224x224.
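
For example, a basic invocation combining these options might look like the following (the video path is a placeholder):

```bash
python run.py \
    video_path=/path/to/video.mp4 \
    verbose=True \
    preprocessing.hog_detector=false \
    preprocessing.detection_downsample=false
```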

Here is a more advanced example for running a SONAR model:

```bash
python run.py \
    video_path=/path/to/my/video.mp4 \
    use_sonar=true \
    preprocessing.detection_downsample=false \
    feature_extraction.pretrained_model_path=/path/to/trained/signhiera.pth \
    translation.base_model_name=google/t5-v1_1-large \
    translation.pretrained_model_path=/path/to/trained/sonar/encoder/best_model.pth \
    feature_extraction.fp16=true \
    'translation.tgt_langs=[eng_Latn, fra_Latn, deu_Latn, zho_Hans]'
```

### Running the Gradio app

To run the Gradio app, use the exact same command as for `run.py`, but leave out the `video_path`. You can then open the browser and upload a video. As soon as the upload finishes, the video should start playing automatically, which triggers the inference pipeline. If the video doesn't start playing automatically, just press the play button manually.

```bash
python app.py \
    use_sonar=true \
    preprocessing.detection_downsample=false \
    feature_extraction.pretrained_model_path=/path/to/trained/signhiera.pth \
    translation.base_model_name=google/t5-v1_1-large \
    translation.pretrained_model_path=/path/to/trained/sonar/encoder/best_model.pth \
    feature_extraction.fp16=true \
    'translation.tgt_langs=[eng_Latn, fra_Latn, deu_Latn, zho_Hans]'
```

The Gradio app needs to be launched from a GPU machine. If running on a devfair, you can tunnel to your local machine via `ssh -L 8000:localhost:7860 devfair`. After logging in to the devfair, go to `localhost:8000` in your browser to use the app.
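
For instance, from your local machine (assuming the Gradio server uses its default port 7860, as in the setup above):

```bash
# Forward local port 8000 to the Gradio app running on the devfair
ssh -L 8000:localhost:7860 devfair
# Then open http://localhost:8000 in your local browser
```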

### Slicing a video before running inference

Videos should ideally be one sentence long. To slice a video, e.g. from second 10 to second 14, you can use ffmpeg:

```bash
ffmpeg -ss 10.00 -to 14.00 -i ./video.MOV -c:v libx264 -crf 20 video_slice.mp4 -loglevel info
```

@@ -0,0 +1,198 @@
import os
from dataclasses import dataclass, field
from pathlib import Path
from queue import Queue
from time import sleep
from typing import List, Optional, Sequence, Tuple

import gradio as gr
import hydra
import torch
from hydra.core.config_store import ConfigStore
from omegaconf import II, DictConfig, OmegaConf

from slt_inference.feature_extraction import FeatureExtractor
from slt_inference.preprocessing import Preprocessor
from slt_inference.translation import SonarTranslator, Translator
from slt_inference.util import FLORES200_ID2LANG, print_translations


@dataclass
class PreprocessingConfig:
    # Bounding box expansion and threshold parameters
    up_exp: float = 1.0
    down_exp: float = 3.0
    left_exp: float = 1.5
    right_exp: float = 1.5
    iou_threshold: float = 0.2
    num_ratio_threshold: float = 0.5

    # Dlib detection parameters
    hog_detector: bool = True
    detector_path: Optional[str] = "checkpoints/detector.dat"
    detection_sampling_rate: int = 16
    detection_downsample: bool = True

    # Sliding window sampling parameters
    num_frames: int = 128
    feature_extraction_stride: int = 64
    sampling_rate: int = 2
    target_fps: int = 25

    # Cropping and Normalization
    target_size: int = 224
    mean: Tuple[float, float, float] = (0.45, 0.45, 0.45)
    std: Tuple[float, float, float] = (0.225, 0.225, 0.225)

    debug: bool = False
    verbose: bool = II("verbose")

    def __post_init__(self):
        if self.hog_detector is False:
            assert (
                self.detector_path is not None
            ), "detector_path must be provided if `hog_detector=False`"


@dataclass
class FeatureExtractionConfig:
    pretrained_model_path: str = "checkpoints/feature_extractor.pth"
    model_name: str = "hiera_base_128x224"
    max_batch_size: int = 2
    fp16: bool = True
    verbose: bool = II("verbose")


@dataclass
class TranslationConfig:
    pretrained_model_path: str = "checkpoints/translator.pth"
    tokenizer_path: str = "checkpoints/tokenizer"
    base_model_name: str = "google/t5-v1_1-large"
    feature_dim: int = 768
    decoder_path: str = "checkpoints/sonar_decoder.pt"
    decoder_spm_path: str = "checkpoints/decoder_sentencepiece.model"

    # Target languages for Sonar translator
    tgt_langs: List[str] = field(default_factory=lambda: ["eng_Latn"])

    # Generation parameters
    num_translations: int = 5
    do_sample: bool = False
    num_beams: int = 5
    temperature: float = 1.0
    max_length: int = 128

    verbose: bool = II("verbose")

    def __post_init__(self):
        for lang in self.tgt_langs:
            if lang not in FLORES200_ID2LANG:
                raise ValueError(f"{lang} is not a valid FLORES-200 language ID")


@dataclass
class RunConfig:
    preprocessing: PreprocessingConfig = PreprocessingConfig()
    feature_extraction: FeatureExtractionConfig = FeatureExtractionConfig()
    translation: TranslationConfig = TranslationConfig()
    use_sonar: bool = False
    verbose: bool = False


cs = ConfigStore.instance()
cs.store(name="run_config", node=RunConfig)

video = None
video_released = False
translation_queue = Queue(maxsize=1)

css = """
.app {
    max-width: 50% !important;
}
"""


@hydra.main(config_name="run_config")
def main(config: DictConfig):
    os.chdir(hydra.utils.get_original_cwd())

    print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if config.feature_extraction.fp16 else torch.float32

    preprocessor = Preprocessor(config.preprocessing, device=device)
    feature_extractor = FeatureExtractor(config.feature_extraction, device=device)

    translator_cls = SonarTranslator if config.use_sonar else Translator
    translator = translator_cls(config.translation, device=device, dtype=dtype)

    def release():
        """Triggered when video ends to indicate translation can now be displayed"""
        global video_released
        video_released = True

    def load_video(video_path: str):
        """Load and preprocess the uploaded video"""
        global video
        video = preprocessor(Path(video_path))

    def inference_video():
        """Run inference on video and put translations in a queue"""
        global translation_queue
        extracted_features = feature_extractor(**video)

        kwargs = {"tgt_langs": config.translation.tgt_langs} if config.use_sonar else {}
        translations = translator(extracted_features, **kwargs)["translations"]
        translation_queue.put(translations)

        keys = (
            [FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs]
            if config.use_sonar
            else range(config.translation.num_translations)
        )
        print_translations(keys, translations)

    def show_translation():
        """Consume translation results from queue and display them"""

        global video_released
        # Wait until video has finished playing before showing translation
        while not video_released:
            sleep(0.05)

        video_released = False
        translation_result = translation_queue.get()

        return (
            [gr.Text(v) for v in translation_result]
            if config.use_sonar
            else gr.Text(translation_result[0])
        )

    with gr.Blocks(css=css) as demo:
        with gr.Column(scale=1):
            gr.Markdown("### ASL")
            input_video = gr.Video(label="Input Video", height=360, autoplay=True)

            output_texts = []
            if config.use_sonar:
                for lang in config.translation.tgt_langs:
                    gr.Markdown(f"### {FLORES200_ID2LANG[lang]}")
                    output_texts.append(gr.Text("", interactive=False, label="Translation"))
            else:
                gr.Markdown("### English")
                output_texts.append(gr.Text("", interactive=False, label="Translation"))

        input_video.upload(fn=load_video, inputs=input_video, outputs=None).success(
            fn=inference_video, inputs=None, outputs=None
        ).success(fn=show_translation, inputs=None, outputs=output_texts)

        input_video.end(fn=release, inputs=None, outputs=None)

    demo.launch()


if __name__ == "__main__":
    main()

@@ -0,0 +1,45 @@
import os
import time
from pathlib import Path
import torch
from omegaconf import DictConfig, OmegaConf
from src.slt_inference.preprocessing import Preprocessor
from src.slt_inference.feature_extraction import FeatureExtractor
from src.slt_inference.translation import SonarTranslator, Translator
from src.slt_inference.util import FLORES200_ID2LANG, print_translations


def e2e_pipeline(config: DictConfig):
    """
    Main function to run the sign language translation pipeline.
    """
    os.chdir(os.getcwd())  # Effectively a no-op: paths are resolved relative to the caller's working directory

    if config.verbose:
        print(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if config.feature_extraction.fp16 else torch.float32

    t0 = time.time()

    # Initialize components
    preprocessor = Preprocessor(config.preprocessing, device=device)
    feature_extractor = FeatureExtractor(config.feature_extraction, device=device)
    translator_cls = SonarTranslator if config.use_sonar else Translator
    translator = translator_cls(config.translation, device=device, dtype=dtype)

    t1 = time.time()
    if config.verbose:
        print(f"1. Model loading: {t1 - t0:.3f}s")

    # Process input
    inputs = preprocessor(Path(config.video_path))
    extracted_features = feature_extractor(**inputs)

    kwargs = {"tgt_langs": config.translation.tgt_langs} if config.use_sonar else {}
    translations = translator(extracted_features, **kwargs)["translations"]

    keys = (
        [FLORES200_ID2LANG[lang] for lang in config.translation.tgt_langs]
        if config.use_sonar
        else range(config.translation.num_translations)
    )

    # Output results
    print_translations(keys, translations)

@@ -0,0 +1,22 @@
from omegaconf import OmegaConf
from e2e import e2e_pipeline
from src.slt_inference.configs import RunConfig, FeatureExtractionConfig, TranslationConfig

# Define the translation configuration
translation_config = RunConfig(
    video_path="D:/Pro/MLH/ctf/video.mp4",
    verbose=True,
    feature_extraction=FeatureExtractionConfig(
        pretrained_model_path="translation/signhiera_mock.pth",
    ),
    translation=TranslationConfig(
        base_model_name="google/t5-v1_1-large",
        tgt_langs=["eng_Latn", "fra_Latn"]
    )
)

# Convert it to DictConfig
translation_dict_config = OmegaConf.structured(translation_config)

# Run pipeline with provided parameters
e2e_pipeline(translation_dict_config)