From 24008aa1ed67c4f75c90107b4937178a1452519d Mon Sep 17 00:00:00 2001 From: Max Bain Date: Sun, 7 May 2023 15:32:58 +0100 Subject: [PATCH] fix long segments, break into sentences using nltk, improve align logic, improve diarize (sentence-based) --- requirements.txt | 3 +- whisperx/alignment.py | 553 +++++++++++++++-------------------------- whisperx/asr.py | 155 +----------- whisperx/diarize.py | 82 +++--- whisperx/transcribe.py | 17 +- whisperx/utils.py | 69 +++-- 6 files changed, 287 insertions(+), 592 deletions(-) diff --git a/requirements.txt b/requirements.txt index f4f9c217..ec90a07f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ faster-whisper transformers ffmpeg-python==0.2.0 pandas -setuptools==65.6.3 \ No newline at end of file +setuptools==65.6.3 +nltk \ No newline at end of file diff --git a/whisperx/alignment.py b/whisperx/alignment.py index e63e6e5c..2812c10b 100644 --- a/whisperx/alignment.py +++ b/whisperx/alignment.py @@ -13,6 +13,7 @@ from .audio import SAMPLE_RATE, load_audio from .utils import interpolate_nans +import nltk LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] @@ -84,386 +85,226 @@ def align( align_model_metadata: dict, audio: Union[str, np.ndarray, torch.Tensor], device: str, - extend_duration: float = 0.0, - start_from_previous: bool = True, interpolate_method: str = "nearest", + return_char_alignments: bool = False, ): """ - Force align phoneme recognition predictions to known transcription - - Parameters - ---------- - transcript: Iterator[dict] - The Whisper model instance - - model: torch.nn.Module - Alignment model (wav2vec2) - - audio: Union[str, np.ndarray, torch.Tensor] - The path to the audio file to open, or the audio waveform - - device: str - cuda device - - diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]} - diarization segments with speaker labels. - - extend_duration: float - Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds - - If the gzip compression ratio is above this value, treat as failed - - interpolate_method: str ["nearest", "linear", "ignore"] - Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary. - "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output. - - Returns - ------- - A dictionary containing the resulting text ("text") and segment-level details ("segments"), and - the spoken language ("language"), which is detected when `decode_options["language"]` is None. + Align phoneme recognition predictions to known transcription. """ + if not torch.is_tensor(audio): if isinstance(audio, str): audio = load_audio(audio) audio = torch.from_numpy(audio) if len(audio.shape) == 1: audio = audio.unsqueeze(0) - + MAX_DURATION = audio.shape[1] / SAMPLE_RATE model_dictionary = align_model_metadata["dictionary"] model_lang = align_model_metadata["language"] model_type = align_model_metadata["type"] - aligned_segments = [] - - prev_t2 = 0 - - char_segments_arr = { - "segment-idx": [], - "subsegment-idx": [], - "word-idx": [], - "char": [], - "start": [], - "end": [], - "score": [], - } - + # 1. Preprocess to keep only characters in dictionary for sdx, segment in enumerate(transcript): - while True: - segment_align_success = False - - # strip spaces at beginning / end, but keep track of the amount. 
- num_leading = len(segment["text"]) - len(segment["text"].lstrip()) - num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) - transcription = segment["text"] - - # TODO: convert number tokenizer / symbols to phonetic words for alignment. - # e.g. "$300" -> "three hundred dollars" - # currently "$300" is ignored since no characters present in the phonetic dictionary + # strip spaces at beginning / end, but keep track of the amount. + num_leading = len(segment["text"]) - len(segment["text"].lstrip()) + num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) + text = segment["text"] + + # split into words + if model_lang not in LANGUAGES_WITHOUT_SPACES: + per_word = text.split(" ") + else: + per_word = text - # split into words + clean_char, clean_cdx = [], [] + for cdx, char in enumerate(text): + char_ = char.lower() + # wav2vec2 models use "|" character to represent spaces if model_lang not in LANGUAGES_WITHOUT_SPACES: - per_word = transcription.split(" ") - else: - per_word = transcription - - # first check that characters in transcription can be aligned (they are contained in align model"s dictionary) - clean_char, clean_cdx = [], [] - for cdx, char in enumerate(transcription): - char_ = char.lower() - # wav2vec2 models use "|" character to represent spaces - if model_lang not in LANGUAGES_WITHOUT_SPACES: - char_ = char_.replace(" ", "|") - - # ignore whitespace at beginning and end of transcript - if cdx < num_leading: - pass - elif cdx > len(transcription) - num_trailing - 1: - pass - elif char_ in model_dictionary.keys(): - clean_char.append(char_) - clean_cdx.append(cdx) - - clean_wdx = [] - for wdx, wrd in enumerate(per_word): - if any([c in model_dictionary.keys() for c in wrd]): - clean_wdx.append(wdx) - - # if no characters are in the dictionary, then we skip this segment... 
- if len(clean_char) == 0: - print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') - break - - transcription_cleaned = "".join(clean_char) - tokens = [model_dictionary[c] for c in transcription_cleaned] - - # we only pad if not using VAD filtering - if "seg_text" not in segment: - # pad according original timestamps - t1 = max(segment["start"] - extend_duration, 0) - t2 = min(segment["end"] + extend_duration, MAX_DURATION) - - # use prev_t2 as current t1 if it"s later - if start_from_previous and t1 < prev_t2: - t1 = prev_t2 - - # check if timestamp range is still valid - if t1 >= MAX_DURATION: - print("Failed to align segment: original start time longer than audio duration, skipping...") - break - if t2 - t1 < 0.02: - print("Failed to align segment: duration smaller than 0.02s time precision") - break - - f1 = int(t1 * SAMPLE_RATE) - f2 = int(t2 * SAMPLE_RATE) - - waveform_segment = audio[:, f1:f2] - - with torch.inference_mode(): - if model_type == "torchaudio": - emissions, _ = model(waveform_segment.to(device)) - elif model_type == "huggingface": - emissions = model(waveform_segment.to(device)).logits - else: - raise NotImplementedError(f"Align model of type {model_type} not supported.") - emissions = torch.log_softmax(emissions, dim=-1) - - emission = emissions[0].cpu().detach() + char_ = char_.replace(" ", "|") + + # ignore whitespace at beginning and end of transcript + if cdx < num_leading: + pass + elif cdx > len(text) - num_trailing - 1: + pass + elif char_ in model_dictionary.keys(): + clean_char.append(char_) + clean_cdx.append(cdx) + + clean_wdx = [] + for wdx, wrd in enumerate(per_word): + if any([c in model_dictionary.keys() for c in wrd]): + clean_wdx.append(wdx) + + sentence_spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text)) + + segment["clean_char"] = clean_char + segment["clean_cdx"] = clean_cdx + segment["clean_wdx"] = clean_wdx + segment["sentence_spans"] = sentence_spans + + aligned_segments = [] - blank_id = 0 - for char, code in model_dictionary.items(): - if char == '[pad]' or char == '': - blank_id = code + # 2. 
Get prediction matrix from alignment model & align + for sdx, segment in enumerate(transcript): + t1 = segment["start"] + t2 = segment["end"] + text = segment["text"] + + aligned_seg = { + "start": t1, + "end": t2, + "text": text, + "words": [], + } + + if return_char_alignments: + aligned_seg["chars"] = [] + + # check we can align + if len(segment["clean_char"]) == 0: + print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...') + aligned_segments.append(aligned_seg) + continue + + if t1 >= MAX_DURATION or t2 - t1 < 0.02: + print("Failed to align segment: original start time longer than audio duration, skipping...") + aligned_segments.append(aligned_seg) + continue + + text_clean = "".join(segment["clean_char"]) + tokens = [model_dictionary[c] for c in text_clean] + + f1 = int(t1 * SAMPLE_RATE) + f2 = int(t2 * SAMPLE_RATE) + + # TODO: Probably can get some speedup gain with batched inference here + waveform_segment = audio[:, f1:f2] + + with torch.inference_mode(): + if model_type == "torchaudio": + emissions, _ = model(waveform_segment.to(device)) + elif model_type == "huggingface": + emissions = model(waveform_segment.to(device)).logits + else: + raise NotImplementedError(f"Align model of type {model_type} not supported.") + emissions = torch.log_softmax(emissions, dim=-1) + + emission = emissions[0].cpu().detach() + + blank_id = 0 + for char, code in model_dictionary.items(): + if char == '[pad]' or char == '': + blank_id = code + + trellis = get_trellis(emission, tokens, blank_id) + path = backtrack(trellis, emission, tokens, blank_id) + + if path is None: + print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') + aligned_segments.append(aligned_seg) + continue + + char_segments = merge_repeats(path, text_clean) + + duration = t2 -t1 + ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) + + # assign timestamps to aligned characters + char_segments_arr = [] + word_idx = 0 + for cdx, char in enumerate(text): + start, end, score = None, None, None + if cdx in segment["clean_cdx"]: + char_seg = char_segments[segment["clean_cdx"].index(cdx)] + start = round(char_seg.start * ratio + t1, 3) + end = round(char_seg.end * ratio + t1, 3) + score = round(char_seg.score, 3) + + char_segments_arr.append( + { + "char": char, + "start": start, + "end": end, + "score": score, + "word-idx": word_idx, + } + ) - trellis = get_trellis(emission, tokens, blank_id) - path = backtrack(trellis, emission, tokens, blank_id) - if path is None: - print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...') - break - char_segments = merge_repeats(path, transcription_cleaned) - # word_segments = merge_words(char_segments) + # increment word_idx, nltk word tokenization would probably be more robust here, but us space for now... 
+ if model_lang in LANGUAGES_WITHOUT_SPACES: + word_idx += 1 + elif cdx == len(text) - 1 or text[cdx+1] == " ": + word_idx += 1 - - # sub-segments - if "seg-text" not in segment: - segment["seg-text"] = [transcription] - - seg_lens = [0] + [len(x) for x in segment["seg-text"]] - seg_lens_cumsum = list(np.cumsum(seg_lens)) - sub_seg_idx = 0 - - wdx = 0 - duration = t2 - t1 - ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) - for cdx, char in enumerate(transcription + " "): - is_last = False - if cdx == len(transcription): - break - elif cdx+1 == len(transcription): - is_last = True - - - start, end, score = None, None, None - if cdx in clean_cdx: - char_seg = char_segments[clean_cdx.index(cdx)] - start = round(char_seg.start * ratio + t1, 3) - end = round(char_seg.end * ratio + t1, 3) - score = char_seg.score - - char_segments_arr["char"].append(char) - char_segments_arr["start"].append(start) - char_segments_arr["end"].append(end) - char_segments_arr["score"].append(score) - char_segments_arr["word-idx"].append(wdx) - char_segments_arr["segment-idx"].append(sdx) - char_segments_arr["subsegment-idx"].append(sub_seg_idx) - - # word-level info - if model_lang in LANGUAGES_WITHOUT_SPACES: - # character == word - wdx += 1 - elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1: - wdx += 1 - - if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1: - wdx = 0 - sub_seg_idx += 1 - - prev_t2 = segment["end"] - - segment_align_success = True - # end while True loop - break - - # reset prev_t2 due to drifting issues - if not segment_align_success: - prev_t2 = 0 + char_segments_arr = pd.DataFrame(char_segments_arr) + + aligned_subsegments = [] + # assign sentence_idx to each character index + char_segments_arr["sentence-idx"] = None + for sdx, (sstart, send) in enumerate(segment["sentence_spans"]): + curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)] + char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx - char_segments_arr = pd.DataFrame(char_segments_arr) - not_space = char_segments_arr["char"] != " " - - per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False) - char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index() - per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) - per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"]) - per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"]) - char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount() - per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup - - word_segments_arr = {} - - # start of word is first char with a timestamp - word_segments_arr["start"] = per_word_grp["start"].min().values - # end of word is last char with a timestamp - word_segments_arr["end"] = per_word_grp["end"].max().values - # score of word is mean (excluding nan) - word_segments_arr["score"] = per_word_grp["score"].mean().values - - word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values - word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1 - word_segments_arr = pd.DataFrame(word_segments_arr) - - word_segments_arr[["segment-idx", "subsegment-idx", 
"word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int) - segments_arr = {} - segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"] - segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"] - segments_arr = pd.DataFrame(segments_arr) - segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]] - segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1 - - # interpolate missing words / sub-segments - if interpolate_method != "ignore": - wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False) - wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False) - # we still know which word timestamps are interpolated because their score == nan - word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - - word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - - sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False) - segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method)) - - # merge words & subsegments which are missing times - word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"]) - - word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min) - word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max) - word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True) - - seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"]) - segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min) - segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max) - segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True) - else: - word_segments_arr.dropna(inplace=True) - segments_arr.dropna(inplace=True) - - # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps... 
- segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True) - segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True) - segments_arr['subsegment-idx-start'].fillna(0, inplace=True) - segments_arr['subsegment-idx-end'].fillna(1, inplace=True) - - - aligned_segments = [] - aligned_segments_word = [] - - word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True) - char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True) - - for sdx, srow in segments_arr.iterrows(): - - seg_idx = int(srow["segment-idx"]) - sub_start = int(srow["subsegment-idx-start"]) - sub_end = int(srow["subsegment-idx-end"]) - - seg = transcript[seg_idx] - text = "".join(seg["seg-text"][sub_start:sub_end]) - - wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] - wseg["start"].fillna(srow["start"], inplace=True) - wseg["end"].fillna(srow["end"], inplace=True) - wseg["segment-text-start"].fillna(0, inplace=True) - wseg["segment-text-end"].fillna(len(text)-1, inplace=True) - - cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1] - # fixes bug for single segment in transcript - cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0 - cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1 - if 'level_1' in cseg: del cseg['level_1'] - if 'level_0' in cseg: del cseg['level_0'] - cseg.reset_index(inplace=True) - - def get_raw_text(word_row): - return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1] - - word_list = [] - wdx = 0 - curr_text = get_raw_text(wseg.iloc[wdx]) - if not curr_text.startswith(" "): - curr_text = " " + curr_text + sentence_text = text[sstart:send] + sentence_start = curr_chars["start"].min() + sentence_end = curr_chars["end"].max() + sentence_words = [] + + for word_idx in curr_chars["word-idx"].unique(): + word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx] + word_text = "".join(word_chars["char"].tolist()).strip() + if len(word_text) == 0: + continue + word_start = word_chars["start"].min() + word_end = word_chars["end"].max() + word_score = round(word_chars["score"].mean(), 3) + + # -1 indicates unalignable + word_segment = {"word": word_text} + + if not np.isnan(word_start): + word_segment["start"] = word_start + if not np.isnan(word_end): + word_segment["end"] = word_end + if not np.isnan(word_score): + word_segment["score"] = word_score + + sentence_words.append(word_segment) - if len(wseg) > 1: - for _, wrow in wseg.iloc[1:].iterrows(): - if wrow['start'] != wseg.iloc[wdx]['start']: - word_start = wseg.iloc[wdx]['start'] - word_end = wseg.iloc[wdx]['end'] - - aligned_segments_word.append( - { - "text": curr_text.strip(), - "start": word_start, - "end": word_end - } - ) - - word_list.append( - { - "word": curr_text.rstrip(), - "start": word_start, - "end": word_end, - } - ) - - curr_text = " " - curr_text += get_raw_text(wrow) + " " - wdx += 1 - - aligned_segments_word.append( - { - "text": curr_text.strip(), - "start": wseg.iloc[wdx]["start"], - "end": wseg.iloc[wdx]["end"] - } - ) - - word_list.append( - { - "word": curr_text.rstrip(), - "start": wseg.iloc[wdx]['start'], - "end": wseg.iloc[wdx]['end'], - } - ) - - aligned_segments.append( - { - "start": srow["start"], - "end": srow["end"], - "text": text, - "words": word_list, - "word-segments": wseg, - "char-segments": cseg - } - ) - - - return {"segments": aligned_segments, "word_segments": aligned_segments_word} - + 
aligned_subsegments.append({
+                "text": sentence_text,
+                "start": sentence_start,
+                "end": sentence_end,
+                "words": sentence_words,
+            })
+
+            if return_char_alignments:
+                curr_chars = curr_chars[["char", "start", "end", "score"]]
+                curr_chars.fillna(-1, inplace=True)
+                curr_chars = curr_chars.to_dict("records")
+                curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
+                aligned_subsegments[-1]["chars"] = curr_chars
+
+        aligned_subsegments = pd.DataFrame(aligned_subsegments)
+        aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
+        aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
+        # concatenate sentences with same timestamps
+        agg_dict = {"text": " ".join, "words": "sum"}
+        if return_char_alignments:
+            agg_dict["chars"] = "sum"
+        aligned_subsegments = aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
+        aligned_subsegments = aligned_subsegments.to_dict('records')
+        aligned_segments += aligned_subsegments
+
+    # create word_segments list
+    word_segments = []
+    for segment in aligned_segments:
+        word_segments += segment["words"]
+
+    return {"segments": aligned_segments, "word_segments": word_segments}
 
 """
 source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
diff --git a/whisperx/asr.py b/whisperx/asr.py
index ba6220bd..f2c54f6c 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -78,7 +78,7 @@ def load_model(whisper_arch, device, compute_type="float16", asr_options=None, l
 class WhisperModel(faster_whisper.WhisperModel):
     '''
     FasterWhisperModel provides batched inference for faster-whisper.
-    Currently only works in non-timestamp mode.
+    Currently only works in non-timestamp mode, with a fixed prompt for all samples in the batch.
     '''
     def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
@@ -140,6 +140,13 @@ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
         return self.model.encode(features, to_cpu=to_cpu)
 
 class FasterWhisperPipeline(Pipeline):
+    """
+    Huggingface Pipeline wrapper for FasterWhisperModel.
+ """ + # TODO: + # - add support for timestamp mode + # - add support for custom inference kwargs + def __init__( self, model, @@ -261,149 +268,3 @@ def detect_language(self, audio: np.ndarray): language = language_token[2:-2] print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...") return language - -if __name__ == "__main__": - main_type = "simple" - import time - - import jiwer - from tqdm import tqdm - from whisper.normalizers import EnglishTextNormalizer - - from benchmark.tedlium import parse_tedlium_annos - - if main_type == "complex": - from faster_whisper.tokenizer import Tokenizer - from faster_whisper.transcribe import TranscriptionOptions - from faster_whisper.vad import (SpeechTimestampsMap, - get_speech_timestamps) - - from whisperx.vad import load_vad_model, merge_chunks - - from .audio import SAMPLE_RATE, load_audio, log_mel_spectrogram - faster_t_options = TranscriptionOptions( - beam_size=5, - best_of=5, - patience=1, - length_penalty=1, - temperatures=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0], - compression_ratio_threshold=2.4, - log_prob_threshold=-1.0, - no_speech_threshold=0.6, - condition_on_previous_text=False, - initial_prompt=None, - prefix=None, - suppress_blank=True, - suppress_tokens=[-1], - without_timestamps=True, - max_initial_timestamp=0.0, - word_timestamps=False, - prepend_punctuations="\"'“¿([{-", - append_punctuations="\"'.。,,!!??::”)]}、" - ) - whisper_arch = "large-v2" - device = "cuda" - batch_size = 16 - model = WhisperModel(whisper_arch, device="cuda", compute_type="float16",) - tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task="transcribe", language="en") - model = FasterWhisperPipeline(model, tokenizer, faster_t_options, device=-1) - fn = "DanielKahneman_2010.wav" - wav_dir = f"/tmp/test/wav/" - vad_model = load_vad_model("cuda", 0.6, 0.3) - audio = load_audio(os.path.join(wav_dir, fn)) - vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) - vad_segments = merge_chunks(vad_segments, 30) - - def data(audio, segments): - for seg in segments: - f1 = int(seg['start'] * SAMPLE_RATE) - f2 = int(seg['end'] * SAMPLE_RATE) - # print(f2-f1) - yield {'inputs': audio[f1:f2]} - vad_method="pyannote" - - wav_dir = f"/tmp/test/wav/" - wer_li = [] - time_li = [] - for fn in os.listdir(wav_dir): - if fn == "RobertGupta_2010U.wav": - continue - base_fn = fn.split('.')[0] - audio_fp = os.path.join(wav_dir, fn) - - audio = load_audio(audio_fp) - t1 = time.time() - if vad_method == "pyannote": - vad_segments = vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) - vad_segments = merge_chunks(vad_segments, 30) - elif vad_method == "silero": - vad_segments = get_speech_timestamps(audio, threshold=0.5, max_speech_duration_s=30) - vad_segments = [{"start": x["start"] / SAMPLE_RATE, "end": x["end"] / SAMPLE_RATE} for x in vad_segments] - new_segs = [] - curr_start = vad_segments[0]['start'] - curr_end = vad_segments[0]['end'] - for seg in vad_segments[1:]: - if seg['end'] - curr_start > 30: - new_segs.append({"start": curr_start, "end": curr_end}) - curr_start = seg['start'] - curr_end = seg['end'] - else: - curr_end = seg['end'] - new_segs.append({"start": curr_start, "end": curr_end}) - vad_segments = new_segs - text = [] - # for idx, out in tqdm(enumerate(model(data(audio_fp, vad_segments), batch_size=batch_size)), total=len(vad_segments)): - for idx, out in enumerate(model(data(audio, vad_segments), batch_size=batch_size)): - 
text.append(out['text']) - t2 = time.time() - if batch_size == 1: - text = [x[0] for x in text] - text = " ".join(text) - - normalizer = EnglishTextNormalizer() - text = normalizer(text) - gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/")) - - wer_result = jiwer.wer(gt_corpus, text) - print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn)) - - wer_li.append(wer_result) - time_li.append(t2-t1) - print("# Avg Mean...") - print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li))) - print("Time: %.2f" % (sum(time_li)/len(time_li))) - elif main_type == "simple": - model = load_model( - "large-v2", - device="cuda", - language="en", - ) - - wav_dir = f"/tmp/test/wav/" - wer_li = [] - time_li = [] - for fn in os.listdir(wav_dir): - if fn == "RobertGupta_2010U.wav": - continue - # fn = "DanielKahneman_2010.wav" - base_fn = fn.split('.')[0] - audio_fp = os.path.join(wav_dir, fn) - - audio = load_audio(audio_fp) - t1 = time.time() - out = model.transcribe(audio_fp, batch_size=8)["segments"] - t2 = time.time() - - text = " ".join([x['text'] for x in out]) - normalizer = EnglishTextNormalizer() - text = normalizer(text) - gt_corpus = normalizer(parse_tedlium_annos(base_fn, "/tmp/test/")) - - wer_result = jiwer.wer(gt_corpus, text) - print("WER: %.2f \t time: %.2f \t [%s]" % (wer_result * 100, t2-t1, fn)) - - wer_li.append(wer_result) - time_li.append(t2-t1) - print("# Avg Mean...") - print("WER: %.2f" % (sum(wer_li) * 100/len(wer_li))) - print("Time: %.2f" % (sum(time_li)/len(time_li))) diff --git a/whisperx/diarize.py b/whisperx/diarize.py index 93ff41de..320d2a48 100644 --- a/whisperx/diarize.py +++ b/whisperx/diarize.py @@ -11,7 +11,6 @@ def __init__( use_auth_token=None, device: Optional[Union[str, torch.device]] = "cpu", ): - self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token) if isinstance(device, str): device = torch.device(device) self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) @@ -21,59 +20,44 @@ def __call__(self, audio, min_speakers=None, max_speakers=None): diarize_df = pd.DataFrame(segments.itertracks(yield_label=True)) diarize_df['start'] = diarize_df[0].apply(lambda x: x.start) diarize_df['end'] = diarize_df[0].apply(lambda x: x.end) + diarize_df.rename(columns={2: "speaker"}, inplace=True) return diarize_df -def assign_word_speakers(diarize_df, result_segments, fill_nearest=False): - for seg in result_segments: - wdf = seg['word-segments'] - if len(wdf['start'].dropna()) == 0: - wdf['start'] = seg['start'] - wdf['end'] = seg['end'] - speakers = [] - for wdx, wrow in wdf.iterrows(): - if not np.isnan(wrow['start']): - diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start']) - diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start']) - # remove no hit - if not fill_nearest: - dia_tmp = diarize_df[diarize_df['intersection'] > 0] - else: - dia_tmp = diarize_df - if len(dia_tmp) == 0: - speaker = None - else: - speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2] - else: - speaker = None - speakers.append(speaker) - seg['word-segments']['speaker'] = speakers - speaker_count = pd.Series(speakers).value_counts() - if len(speaker_count) == 0: - seg["speaker"]= "UNKNOWN" +def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): + transcript_segments = transcript_result["segments"] + for seg in transcript_segments: + # assign speaker to 
segment (if any)
+        diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
+        diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
+        # remove no hit, otherwise we look for closest (even negative intersection...)
+        if not fill_nearest:
+            dia_tmp = diarize_df[diarize_df['intersection'] > 0]
         else:
-            seg["speaker"] = speaker_count.index[0]
+            dia_tmp = diarize_df
+        if len(dia_tmp) > 0:
+            # sum over speakers
+            speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+            seg["speaker"] = speaker
+
+        # assign speaker to words
+        if 'words' in seg:
+            for word in seg['words']:
+                if 'start' in word:
+                    diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
+                    diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
+                    # remove no hit
+                    if not fill_nearest:
+                        dia_tmp = diarize_df[diarize_df['intersection'] > 0]
+                    else:
+                        dia_tmp = diarize_df
+                    if len(dia_tmp) > 0:
+                        # sum over speakers
+                        speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+                        word["speaker"] = speaker
+
+    return transcript_result
 
-    # create word level segments for .srt
-    word_seg = []
-    for seg in result_segments:
-        wseg = pd.DataFrame(seg["word-segments"])
-        for wdx, wrow in wseg.iterrows():
-            if wrow["start"] is not None:
-                speaker = wrow['speaker']
-                if speaker is None or speaker == np.nan:
-                    speaker = "UNKNOWN"
-                word_seg.append(
-                    {
-                        "start": wrow["start"],
-                        "end": wrow["end"],
-                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
-                    }
-                )
-
-    # TODO: create segments but split words on new speaker
-
-    return result_segments, word_seg
 
 class Segment:
     def __init__(self, start, end, speaker=None):
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index f3f63fe5..b89a545e 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -64,14 +64,11 @@ def cli():
     parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
     parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --no_align) the maximum number of lines in a segment")
     parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) whether to break the aligned output into sentence-level segments or keep the original chunk-level segments")
 
-    # parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
-    # parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
-    # parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
 
     parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
     parser.add_argument("--hf_token", 
type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models") - # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.") # fmt: on args = parser.parse_args().__dict__ @@ -97,7 +94,6 @@ def cli(): min_speakers: int = args.pop("min_speakers") max_speakers: int = args.pop("max_speakers") - # TODO: check model loading works. if model_name.endswith(".en") and args["language"] not in {"en", "English"}: if args["language"] is not None: @@ -176,6 +172,7 @@ def cli(): align_model, align_metadata = load_align_model(result["language"], device) print(">>Performing alignment...") result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method) + results.append((result, audio_path)) # Unload align model @@ -193,18 +190,10 @@ def cli(): diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) for result, input_audio_path in tmp_results: diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) - results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"]) - result = {"segments": results_segments, "word_segments": word_segments} + result = assign_word_speakers(diarize_segments, result) results.append((result, input_audio_path)) - # >> Write for result, audio_path in results: - # Remove pandas dataframes from result so that - # we can serialize the result with json - for seg in result["segments"]: - seg.pop("word-segments", None) - seg.pop("char-segments", None) - writer(result, audio_path, writer_args) if __name__ == "__main__": diff --git a/whisperx/utils.py b/whisperx/utils.py index 3401a848..d042bb70 100644 --- a/whisperx/utils.py +++ b/whisperx/utils.py @@ -231,11 +231,16 @@ def iterate_subtitles(): line_count = 1 # the next subtitle to yield (a list of word timings with whitespace) subtitle: list[dict] = [] - last = result["segments"][0]["words"][0]["start"] + times = [] + last = result["segments"][0]["start"] for segment in result["segments"]: for i, original_timing in enumerate(segment["words"]): timing = original_timing.copy() - long_pause = not preserve_segments and timing["start"] - last > 3.0 + long_pause = not preserve_segments + if "start" in timing: + long_pause = long_pause and timing["start"] - last > 3.0 + else: + long_pause = False has_room = line_len + len(timing["word"]) <= max_line_width seg_break = i == 0 and len(subtitle) > 0 and preserve_segments if line_len > 0 and has_room and not long_pause and not seg_break: @@ -251,8 +256,9 @@ def iterate_subtitles(): or seg_break ): # subtitle break - yield subtitle + yield subtitle, times subtitle = [] + times = [] line_count = 1 elif line_len > 0: # line break @@ -260,40 +266,53 @@ def iterate_subtitles(): timing["word"] = "\n" + timing["word"] line_len = len(timing["word"].strip()) subtitle.append(timing) - last = timing["start"] + times.append((segment["start"], segment["end"], segment.get("speaker"))) + if "start" in timing: + last = timing["start"] if len(subtitle) > 0: - yield subtitle + yield subtitle, times if "words" in result["segments"][0]: - for subtitle in iterate_subtitles(): - subtitle_start = self.format_timestamp(subtitle[0]["start"]) - subtitle_end = self.format_timestamp(subtitle[-1]["end"]) - subtitle_text = "".join([word["word"] for word in subtitle]) - if highlight_words: + for subtitle, _ in iterate_subtitles(): + sstart, 
ssend, speaker = _[0] + subtitle_start = self.format_timestamp(sstart) + subtitle_end = self.format_timestamp(ssend) + subtitle_text = " ".join([word["word"] for word in subtitle]) + has_timing = any(["start" in word for word in subtitle]) + + # add [$SPEAKER_ID]: to each subtitle if speaker is available + prefix = "" + if speaker is not None: + prefix = f"[{speaker}]: " + + if highlight_words and has_timing: last = subtitle_start all_words = [timing["word"] for timing in subtitle] for i, this_word in enumerate(subtitle): - start = self.format_timestamp(this_word["start"]) - end = self.format_timestamp(this_word["end"]) - if last != start: - yield last, start, subtitle_text - - yield start, end, "".join( - [ - re.sub(r"^(\s*)(.*)$", r"\1\2", word) - if j == i - else word - for j, word in enumerate(all_words) - ] - ) - last = end + if "start" in this_word: + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, subtitle_text + + yield start, end, prefix + " ".join( + [ + re.sub(r"^(\s*)(.*)$", r"\1\2", word) + if j == i + else word + for j, word in enumerate(all_words) + ] + ) + last = end else: - yield subtitle_start, subtitle_end, subtitle_text + yield subtitle_start, subtitle_end, prefix + subtitle_text else: for segment in result["segments"]: segment_start = self.format_timestamp(segment["start"]) segment_end = self.format_timestamp(segment["end"]) segment_text = segment["text"].strip().replace("-->", "->") + if "speaker" in segment: + segment_text = f"[{segment['speaker']}]: {segment_text}" yield segment_start, segment_end, segment_text def format_timestamp(self, seconds: float):
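
Usage sketch (not part of the patch): a minimal example of how the revised pipeline composes after this change, mirroring the calls in whisperx/transcribe.py above. It assumes load_model, load_audio, load_align_model, align, DiarizationPipeline and assign_word_speakers are re-exported at the whisperx package level; if your version does not re-export them, import them from whisperx.asr, whisperx.alignment and whisperx.diarize directly. "audio.wav" and the Hugging Face token are placeholders.

import whisperx

device = "cuda"
audio_path = "audio.wav"
audio = whisperx.load_audio(audio_path)

# 1. Batched transcription with faster-whisper
model = whisperx.load_model("large-v2", device, compute_type="float16")
result = model.transcribe(audio, batch_size=16)

# 2. Alignment: segments now come back split into sentences (via nltk's Punkt
#    tokenizer, as in alignment.py above), each carrying a "words" list;
#    set return_char_alignments=True to also keep per-character timings.
align_model, align_metadata = whisperx.load_align_model(result["language"], device)
result = whisperx.align(result["segments"], align_model, align_metadata, audio, device,
                        interpolate_method="nearest", return_char_alignments=False)

# 3. Diarization: assign_word_speakers now takes and returns the whole result dict,
#    attaching a "speaker" key to each segment and to each word where one is found.
diarize_model = whisperx.DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device=device)
diarize_segments = diarize_model(audio_path, min_speakers=None, max_speakers=None)
result = whisperx.assign_word_speakers(diarize_segments, result)

for seg in result["segments"]:
    print(seg["start"], seg["end"], seg.get("speaker", "UNKNOWN"), seg["text"])

The sentence spans used by align() come from nltk's Punkt tokenizer, called exactly as in alignment.py above; a quick illustration with a made-up string:

import nltk

text = "Hello world. This is WhisperX."
spans = list(nltk.tokenize.punkt.PunktSentenceTokenizer().span_tokenize(text))
# each (start, end) character span becomes one sentence-level segment in align()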