From 4e3610e22553f4706db8a6b22733519e9704cd24 Mon Sep 17 00:00:00 2001 From: arda-argmax Date: Wed, 4 Dec 2024 17:11:48 -0800 Subject: [PATCH] fix model download and vad fail --- whisperkit/pipelines.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/whisperkit/pipelines.py b/whisperkit/pipelines.py index bff0020..bd06510 100644 --- a/whisperkit/pipelines.py +++ b/whisperkit/pipelines.py @@ -13,6 +13,7 @@ from typing import Optional from argmaxtools.utils import _maybe_git_clone, get_logger +from huggingface_hub import snapshot_download from whisperkit import _constants @@ -132,8 +133,29 @@ def clone_models(self): """ Download WhisperKit model files from Hugging Face Hub (only the files needed for `self.whisper_version`) """ - self.models_dir = os.path.join(self.repo_dir, "models") # dummy - self.results_dir = os.path.join(self.repo_dir, "results") + self.models_dir = os.path.join( + self.repo_dir, "Models", self.whisper_version.replace("/", "_")) + + os.makedirs(self.models_dir, exist_ok=True) + + snapshot_download( + repo_id=_constants.MODEL_REPO_ID, + allow_patterns=f"{self.whisper_version.replace('/', '_')}/*", + revision=self.model_commit_hash, + local_dir=os.path.dirname(self.models_dir), + local_dir_use_symlinks=True + ) + + if self.model_commit_hash is None: + self.model_commit_hash = subprocess.run( + f"git ls-remote git@hf.co:{_constants.MODEL_REPO_ID}", + shell=True, stdout=subprocess.PIPE + ).stdout.decode("utf-8").rsplit("\n")[0].rsplit("\t")[0] + logger.info( + "--model-commit-hash not specified, " + f"imputing with HEAD={self.model_commit_hash}") + + self.results_dir = os.path.join(self.models_dir, "results") os.makedirs(self.results_dir, exist_ok=True) def transcribe(self, audio_file_path: str, forced_language: Optional[str] = None) -> str: @@ -143,11 +165,10 @@ def transcribe(self, audio_file_path: str, forced_language: Optional[str] = None self.cli_path, "transcribe", "--audio-path", audio_file_path, - "--model-prefix", self.whisper_version.rsplit("/")[0], - "--model", self.whisper_version.rsplit("/")[1], + "--model-path", self.models_dir, "--text-decoder-compute-units", self._text_decoder_compute_units, "--audio-encoder-compute-units", self._audio_encoder_compute_units, - "--chunking-strategy", "vad", + # "--chunking-strategy", "vad", "--report-path", self.results_dir, "--report", "--word-timestamps" if self._word_timestamps else "", "" if forced_language is None else f"--use-prefill-prompt --language {forced_language}", @@ -182,8 +203,7 @@ def transcribe_folder(self, audio_folder_path: str, forced_language: Optional[st self.cli_path, "transcribe", "--audio-folder", audio_folder_path, - "--model-prefix", self.whisper_version.rsplit("/")[0], - "--model", self.whisper_version.rsplit("/")[1], + "--model-path", self.models_dir, "--text-decoder-compute-units", self._text_decoder_compute_units, "--audio-encoder-compute-units", self._audio_encoder_compute_units, "--report-path", self.results_dir, "--report",