Skip to content

Commit

Permalink
whisper backend [transformers].
Browse files Browse the repository at this point in the history
  • Loading branch information
boocmp committed Aug 19, 2024
1 parent 88c27d7 commit feffe4a
Showing 1 changed file with 70 additions and 1 deletion.
71 changes: 70 additions & 1 deletion src/runners/audio_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
"""

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
"""
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, WhisperProcessor,
import torch
from itertools import groupby
Expand Down Expand Up @@ -230,3 +231,71 @@ def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
)
for text in segments
]
"""

from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch


class BatchableAudioTranscriber(bentoml.Runnable):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True

def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.processor = WhisperProcessor.from_pretrained("openai/whisper-base.en")
self.model = WhisperForConditionalGeneration.from_pretrained(
"openai/whisper-base.en", attn_implementation="sdpa"
).cuda()

self.model.generation_config.cache_implementation = "static"
self.model.forward = torch.compile(
self.model.forward, mode="reduce-overhead", fullgraph=True
)

def transcribe(self, audios):
input_features = self.processor(
audios, return_tensors="pt", sampling_rate=16000, padding=True
).input_features.cuda()

for _ in range(2):
self.model.generate(input_features)

predicted_ids = self.model.generate(input_features)
transcriptions = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcriptions)

return transcriptions

@bentoml.Runnable.method(batchable=True)
def transcribe_audio(self, inputs: list[BatchInput]) -> list[str]:
result = []

# merging audio
ts = datetime.now()
audio_batch = []
for input in inputs:
wav = decode_audio(io.BytesIO(input.audio))
chunks = get_speech_timestamps(wav)
if len(chunks) == 0:
audio_batch.append(np.zeros(16000, dtype=np.float32))
else:
wav = collect_chunks(wav, chunks=chunks)
audio_batch.append(wav)

merge_time = (datetime.now() - ts).total_seconds()

ts = datetime.now()
segments = self.transcribe(audio_batch)
transcribe_time = (datetime.now() - ts).total_seconds()

return [
BatchOutput(
text=text,
batched_count=len(inputs),
merge_audio_time=merge_time,
transcribe_time=transcribe_time,
restore_time=0,
)
for text in segments
]

0 comments on commit feffe4a

Please sign in to comment.