diff --git a/worker/tests/data/test_strict_sentence_paragraphs-protected_paragraphs.json b/worker/tests/data/test_strict_sentence_paragraphs-protected_paragraphs.json new file mode 100644 index 00000000..067834ca --- /dev/null +++ b/worker/tests/data/test_strict_sentence_paragraphs-protected_paragraphs.json @@ -0,0 +1,146 @@ +{ + "input": [ + { + "type": "paragraph", + "speaker": null, + "lang": "de", + "children": [ + { + "text": "Willkommen ", + "start": 0.0, + "end": 0.82, + "conf": 0.4493102431297302, + "conf_ts": 0.0 + }, + { + "text": "zum ", + "start": 0.82, + "end": 1.07, + "conf": 0.9744400978088379, + "conf_ts": 0.005903157405555248 + } + ] + }, + { + "type": "paragraph", + "speaker": null, + "lang": "de", + "children": [ + { + "text": "[", + "start": 0.0, + "end": 0.82, + "conf": 0.4493102431297302, + "conf_ts": 0.0 + }, + { + "text": "Music", + "start": 0.82, + "end": 1, + "conf": 0.9744400978088379, + "conf_ts": 0.005903157405555248 + }, + { + "text": "]", + "start": 1, + "end": 1.07, + "conf": 0.4493102431297302, + "conf_ts": 0.0 + } + ] + }, + { + "type": "paragraph", + "speaker": null, + "lang": "de", + "children": [ + { + "text": "letzten ", + "start": 1.07, + "end": 1.65, + "conf": 0.9838394522666931, + "conf_ts": 0.01149927917867899 + }, + { + "text": "Token.", + "start": 1.65, + "end": 2.06, + "conf": 0.9566531777381897, + "conf_ts": 0.0096774036064744 + } + ] + } + ], + "expected": [ + { + "type": "paragraph", + "speaker": null, + "lang": "de", + "children": [ + { + "text": "Willkommen ", + "start": 0.0, + "end": 0.82, + "conf": 0.4493102431297302, + "conf_ts": 0.0 + }, + { + "text": "zum ", + "start": 0.82, + "end": 1.07, + "conf": 0.9744400978088379, + "conf_ts": 0.005903157405555248 + } + ] + }, + { + "type": "paragraph", + "speaker": null, + "lang": "de", + "children": [ + { + "text": "[", + "start": 0.0, + "end": 0.82, + "conf": 0.4493102431297302, + "conf_ts": 0.0 + }, + { + "text": "Music", + "start": 0.82, + "end": 1, + "conf": 0.9744400978088379, + "conf_ts": 0.005903157405555248 + }, + { + "text": "]", + "start": 1, + "end": 1.07, + "conf": 0.4493102431297302, + "conf_ts": 0.0 + } + ] + }, + { + "type": "paragraph", + "speaker": null, + "lang": "de", + "children": [ + { + "text": "letzten ", + "start": 1.07, + "end": 1.65, + "conf": 0.9838394522666931, + "conf_ts": 0.01149927917867899 + }, + { + "text": "Token.", + "start": 1.65, + "end": 2.06, + "conf": 0.9566531777381897, + "conf_ts": 0.0096774036064744 + } + ] + } + ] +} diff --git a/worker/transcribee_worker/whisper_transcribe.py b/worker/transcribee_worker/whisper_transcribe.py index ae8ca224..71cbcd8b 100644 --- a/worker/transcribee_worker/whisper_transcribe.py +++ b/worker/transcribee_worker/whisper_transcribe.py @@ -22,6 +22,11 @@ r".*\d\.\s?$" ), # Don't split on numerals followed by a dot, e.g. "during the 20. century" ] +# Regexes that protect a paragraph from being recombined +DONT_COMBINE_RES = [ + re.compile(r"^\[[^\s]*\]$"), # [MUSIC] + re.compile(r"^\*[^\s]*\*$"), # *Applause* +] def get_model_file(model_name: str): @@ -262,6 +267,13 @@ async def strict_sentence_paragraphs( lang=paragraph.lang, speaker=paragraph.speaker, children=[] ) + elif any(regex.search(paragraph.text()) for regex in DONT_COMBINE_RES): + if acc_paragraph.children: + yield acc_paragraph + acc_paragraph = None + yield paragraph + continue + locale = Locale(paragraph.lang) sentence_iter = BreakIterator.createSentenceInstance(locale) sentence_iter.setText(acc_paragraph.text() + paragraph.text())