From feffe4a1f3a0851a4896f9819b2bb2ed0810613f Mon Sep 17 00:00:00 2001
From: boocmp
Date: Mon, 19 Aug 2024 16:01:10 +0700
Subject: [PATCH] whisper backend [transformers].

---
 src/runners/audio_transcriber.py | 71 +++++++++++++++++++++++++++++++-
 1 file changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/runners/audio_transcriber.py b/src/runners/audio_transcriber.py
index ce75782..25166fb 100644
--- a/src/runners/audio_transcriber.py
+++ b/src/runners/audio_transcriber.py
@@ -167,7 +167,8 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
 """
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+"""
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, WhisperProcessor,
 import torch
 
 from itertools import groupby
@@ -230,3 +231,71 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
             )
             for text in segments
         ]
+"""
+
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
+import torch
+
+
+class BatchableAudioTranscriber(bentoml.Runnable):
+    SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
+    SUPPORTS_CPU_MULTI_THREADING = True
+
+    def __init__(self):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.processor = WhisperProcessor.from_pretrained("openai/whisper-base.en")
+        self.model = WhisperForConditionalGeneration.from_pretrained(
+            "openai/whisper-base.en", attn_implementation="sdpa"
+        ).to(self.device)
+
+        # Static KV cache + compiled forward: fixed shapes for faster batched generation.
+        self.model.generation_config.cache_implementation = "static"
+        self.model.forward = torch.compile(
+            self.model.forward, mode="reduce-overhead", fullgraph=True
+        )
+
+    def transcribe(self, audios):
+        input_features = self.processor(
+            audios, return_tensors="pt", sampling_rate=16000, padding=True
+        ).input_features.to(self.device)
+
+        # Warm-up generations so torch.compile captures the graph before the real pass.
+        for _ in range(2):
+            self.model.generate(input_features)
+
+        predicted_ids = self.model.generate(input_features)
+        transcriptions = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
+        print(transcriptions)
+
+        return transcriptions
+
+    @bentoml.Runnable.method(batchable=True)
+    def transcribe_audio(self, inputs: list[BatchInput]) -> list[BatchOutput]:
+        # Decode each request and keep only the speech chunks found by VAD;
+        # fall back to one second of silence when no speech is detected.
+        ts = datetime.now()
+        audio_batch = []
+        for input in inputs:
+            wav = decode_audio(io.BytesIO(input.audio))
+            chunks = get_speech_timestamps(wav)
+            if len(chunks) == 0:
+                audio_batch.append(np.zeros(16000, dtype=np.float32))
+            else:
+                wav = collect_chunks(wav, chunks=chunks)
+                audio_batch.append(wav)
+
+        merge_time = (datetime.now() - ts).total_seconds()
+
+        ts = datetime.now()
+        segments = self.transcribe(audio_batch)
+        transcribe_time = (datetime.now() - ts).total_seconds()
+
+        return [
+            BatchOutput(
+                text=text,
+                batched_count=len(inputs),
+                merge_audio_time=merge_time,
+                transcribe_time=transcribe_time,
+                restore_time=0,
+            )
+            for text in segments
+        ]
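
Smoke-test sketch (not part of the patch): a minimal way to exercise the new Whisper path directly, bypassing BentoML batching and the VAD step. The import path runners.audio_transcriber and the direct construction of BatchableAudioTranscriber are assumptions for illustration; in the service the class is instantiated by BentoML as a runnable.

    # Hypothetical smoke test for the Whisper backend (assumes src/ is on PYTHONPATH).
    import numpy as np

    from runners.audio_transcriber import BatchableAudioTranscriber  # assumed import path

    transcriber = BatchableAudioTranscriber()

    # One second of 16 kHz mono float32 silence, matching sampling_rate=16000
    # that transcribe() passes to the processor.
    silence = np.zeros(16000, dtype=np.float32)

    # The first call is slow: torch.compile traces the forward pass and the
    # two warm-up generations run before the decoded one.
    print(transcriber.transcribe([silence]))  # -> one (likely empty) transcription

On a CPU-only machine this only works because the model and features are moved with .to(self.device); forcing .cuda() would raise when no GPU is available, which is why the patch above uses the device selected in __init__.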