Commit

Add Whisper speaker diarization
DnzzL committed Feb 2, 2023
1 parent faf5729 commit dd79922
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions bechdelai/audio/speaker_diarization.py
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Tuple

import torch
from speechbox import ASRDiarizationPipeline


## From https://huggingface.co/spaces/speechbox/whisper-speaker-diarization
## Uses the pre-trained checkpoint Whisper Tiny for the ASR transcriptions and pyannote.audio to label the speakers.
class SpeakerDiarization:
    def __init__(self, model_name="openai/whisper-tiny"):
        self.device = 0 if torch.cuda.is_available() else "cpu"
        self.pipe = ASRDiarizationPipeline.from_pretrained(
            asr_model=model_name,
            device=self.device,
        )

    def transcribe(self, file_upload) -> List[Dict[str, Any]]:
        """Transcribe an audio file with speaker diarization
        Args:
            file_upload: Input audio file passed to the diarization pipeline
        Returns:
            List[Dict[str, Any]]: Diarized segments with "speaker", "text" and "timestamp" keys
        """
        segments = self.pipe(file_upload)
        return segments

    def tuple_to_string(self, start_end_tuple: Tuple[float, float], ndigits: int = 1) -> str:
        """Turn a tuple of floats into a string
        Args:
            start_end_tuple (Tuple[float, float]): Start and end times
            ndigits (int, optional): Number of digits to round to. Defaults to 1.
        Returns:
            str: String representation of the rounded tuple
        """
        return str((round(start_end_tuple[0], ndigits), round(start_end_tuple[1], ndigits)))

    def format_as_transcription(self, raw_segments: List[Dict[str, Any]], with_timestamps: bool = False) -> str:
        """Format raw speaker diarization output as a human-readable transcription
        Args:
            raw_segments (List[Dict[str, Any]]): Raw speaker diarization output
            with_timestamps (bool, optional): Whether to include timestamps. Defaults to False.
        Returns:
            str: Transcription of the audio file, one speaker turn per paragraph
        """
        if with_timestamps:
            return "\n\n".join(
                chunk["speaker"] + " " + self.tuple_to_string(chunk["timestamp"]) + chunk["text"]
                for chunk in raw_segments
            )
        else:
            return "\n\n".join(chunk["speaker"] + chunk["text"] for chunk in raw_segments)
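For reference, a minimal usage sketch of the class added in this commit (not part of the diff itself). The audio file path is illustrative, and the segment keys shown in the comment are the ones consumed by format_as_transcription:

from bechdelai.audio.speaker_diarization import SpeakerDiarization

# Defaults to the openai/whisper-tiny checkpoint; "interview.wav" is a placeholder path
diarizer = SpeakerDiarization()
segments = diarizer.transcribe("interview.wav")
# Each segment is a dict such as:
# {"speaker": "SPEAKER_00", "text": " Hello there.", "timestamp": (0.0, 2.3)}
print(diarizer.format_as_transcription(segments, with_timestamps=True))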


