from typing import Any, Dict, List, Optional, Tuple

import torch


## From https://huggingface.co/spaces/speechbox/whisper-speaker-diarization
## Uses the pre-trained checkpoint Whisper Tiny for the ASR transcriptions and pyannote.audio to label the speakers.
class SpeakerDiarization:
    """Transcribe audio with speaker labels and format the result as text.

    Wraps a ``speechbox.ASRDiarizationPipeline``. The formatting helpers read
    the keys ``"speaker"``, ``"timestamp"`` and ``"text"`` from each segment.
    """

    def __init__(self, model_name="openai/whisper-tiny"):
        """Build the underlying ASR + diarization pipeline.

        Args:
            model_name (str, optional): ASR checkpoint passed to the pipeline
                as ``asr_model``. Defaults to "openai/whisper-tiny".
        """
        # Imported here rather than at module level so the pure formatting
        # helpers below stay importable when speechbox is not installed.
        from speechbox import ASRDiarizationPipeline

        # Device index 0 when CUDA is available, otherwise the CPU.
        self.device = 0 if torch.cuda.is_available() else "cpu"
        self.pipe = ASRDiarizationPipeline.from_pretrained(
            asr_model=model_name,
            device=self.device,
        )

    def transcribe(self, file_upload) -> List[Dict[str, Any]]:
        """Run the pipeline on an audio input.

        Args:
            file_upload: Input audio, handed to the pipeline unchanged.
                NOTE(review): presumably a file path or uploaded file object;
                confirm against the pipeline's accepted inputs.

        Returns:
            List[Dict[str, Any]]: Raw segments exactly as returned by the
            pipeline; ``format_as_transcription`` turns them into text.
        """
        return self.pipe(file_upload)

    def tuple_to_string(
        self,
        start_end_tuple: Tuple[Optional[float], Optional[float]],
        ndigits: int = 1,
    ) -> str:
        """Render a (start, end) pair as a rounded tuple string.

        Args:
            start_end_tuple (Tuple[Optional[float], Optional[float]]): Start
                and end times; only the first two items are read.
            ndigits (int, optional): Number of decimal digits to round to.
                Defaults to 1.

        Returns:
            str: Text such as ``"(0.0, 1.3)"``. A bound that is ``None`` is
            kept as ``None`` instead of being rounded.
        """
        bounds = (start_end_tuple[0], start_end_tuple[1])
        # round(None, ...) raises TypeError, so missing bounds pass through.
        rounded = tuple(
            None if bound is None else round(bound, ndigits) for bound in bounds
        )
        return str(rounded)

    def format_as_transcription(
        self, raw_segments: List[Dict[str, Any]], with_timestamps: bool = False
    ) -> str:
        """Format raw speaker diarization output as a readable transcription.

        Args:
            raw_segments (List[Dict[str, Any]]): Segments with ``"speaker"``
                and ``"text"`` keys, plus ``"timestamp"`` when
                ``with_timestamps`` is true.
            with_timestamps (bool): Whether to put the rounded (start, end)
                times after each speaker label. Defaults to False.

        Returns:
            str: One entry per segment, separated by blank lines. An empty
            segment list gives an empty string.
        """
        # NOTE(review): no separator is inserted before the text; presumably
        # the ASR text already starts with a space. Confirm with real output.
        if with_timestamps:
            entries = (
                f'{chunk["speaker"]} '
                f'{self.tuple_to_string(chunk["timestamp"])}{chunk["text"]}'
                for chunk in raw_segments
            )
        else:
            entries = (
                f'{chunk["speaker"]}{chunk["text"]}' for chunk in raw_segments
            )
        return "\n\n".join(entries)