Showing 1 changed file with 58 additions and 0 deletions.
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Tuple

import torch
from speechbox import ASRDiarizationPipeline


## From https://huggingface.co/spaces/speechbox/whisper-speaker-diarization
## Uses the pre-trained Whisper Tiny checkpoint for the ASR transcriptions and pyannote.audio to label the speakers.
class SpeakerDiarization:
    def __init__(self, model_name="openai/whisper-tiny"):
        # Device index 0 selects the first CUDA GPU; fall back to the CPU otherwise.
        self.device = 0 if torch.cuda.is_available() else "cpu"
        self.pipe = ASRDiarizationPipeline.from_pretrained(
            asr_model=model_name,
            device=self.device,
        )

    def transcribe(self, file_upload) -> List[Dict[str, Any]]:
        """Transcribe an audio file with speaker diarization.

        Args:
            file_upload: Input audio file accepted by the pipeline

        Returns:
            List[Dict[str, Any]]: Diarized segments, each with a speaker label,
                a (start, end) timestamp tuple and the transcribed text
        """
        segments = self.pipe(file_upload)
        return segments

    def tuple_to_string(self, start_end_tuple: Tuple[float, float], ndigits: int = 1) -> str:
        """Turn a tuple of floats into a string.

        Args:
            start_end_tuple (Tuple[float, float]): Start and end times
            ndigits (int, optional): Number of digits to round to. Defaults to 1.

        Returns:
            str: String representation of the rounded tuple
        """
        return str((round(start_end_tuple[0], ndigits), round(start_end_tuple[1], ndigits)))

    def format_as_transcription(self, raw_segments: List[Dict[str, Any]], with_timestamps: bool = False) -> str:
        """Format raw speaker diarization output as a human-readable transcription.

        Args:
            raw_segments (List[Dict[str, Any]]): Raw speaker diarization output
            with_timestamps (bool): Whether to include timestamps in the transcription

        Returns:
            str: Transcription of the audio file, one paragraph per speaker turn
        """
        if with_timestamps:
            return "\n\n".join(
                chunk["speaker"] + " " + self.tuple_to_string(chunk["timestamp"]) + chunk["text"]
                for chunk in raw_segments
            )
        return "\n\n".join(chunk["speaker"] + chunk["text"] for chunk in raw_segments)
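
Below is a minimal usage sketch of the class added in this commit. The module name speaker_diarization, the audio path audio.wav, and the assumption that speechbox and pyannote.audio are installed (with any Hugging Face access token the pyannote models require) are illustrative only and not part of the committed file.

# Hypothetical usage of the SpeakerDiarization helper above.
from speaker_diarization import SpeakerDiarization  # assumed module name for this file

diarizer = SpeakerDiarization(model_name="openai/whisper-tiny")
segments = diarizer.transcribe("audio.wav")  # placeholder path to a local audio file
print(diarizer.format_as_transcription(segments, with_timestamps=True))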