From 0bf3570bd6b376c936ea9f04fc15f129e738b168 Mon Sep 17 00:00:00 2001 From: BenAAndrew Date: Sun, 28 Nov 2021 19:54:18 +0000 Subject: [PATCH] Fix tests --- tests/test_dataset.py | 15 +++++++-------- training/clean_text.py | 9 ++++++--- training/utils.py | 2 +- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 6eb4412..bbd142c 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -225,31 +225,26 @@ def test_clip_combiner(): def test_extend_existing_dataset(): dataset_directory = "test-extend-dataset" audio_folder = os.path.join(dataset_directory, "wavs") - unlabelled_path = os.path.join(dataset_directory, "unlabelled") + unlabelled_folder = os.path.join(dataset_directory, "unlabelled") metadata_file = os.path.join(dataset_directory, "metadata.csv") os.makedirs(dataset_directory) os.makedirs(audio_folder) + os.makedirs(unlabelled_folder) with open(metadata_file, "w") as f: pass audio_path = os.path.join("test_samples", "audio.wav") converted_audio_path = os.path.join("test_samples", "audio-converted.wav") text_path = os.path.join("test_samples", "text.txt") - forced_alignment_path = os.path.join(dataset_directory, "align.json") label_path = os.path.join(dataset_directory, "metadata.csv") - info_path = os.path.join(dataset_directory, "info.json") suffix = "extend" min_confidence = 1.0 extend_existing_dataset( text_path=text_path, audio_path=audio_path, transcription_model=FakeTranscriptionModel(), - forced_alignment_path=forced_alignment_path, - output_path=audio_folder, - unlabelled_path=unlabelled_path, - label_path=label_path, + output_folder=dataset_directory, suffix=suffix, - info_path=info_path, min_confidence=min_confidence, combine_clips=False, ) @@ -258,6 +253,10 @@ def test_extend_existing_dataset(): name.split(".")[0] + "-" + suffix + ".wav" for name in EXPECTED_CLIPS ], "Unexpected audio clips" + assert os.listdir(unlabelled_folder) == [ + name.split(".")[0] + "-" + suffix + ".wav" for name in UNMATCHED_CLIPS + ], "Unexpected unlabelled audio clips" + with open(label_path) as f: lines = f.readlines() expected_text = [ diff --git a/training/clean_text.py b/training/clean_text.py index 0f8816e..e6c3aab 100644 --- a/training/clean_text.py +++ b/training/clean_text.py @@ -36,7 +36,7 @@ } -def clean_text(text, symbols=DEFAULT_ALPHABET): +def clean_text(text, symbols=DEFAULT_ALPHABET, remove_invalid_characters=True): """ Cleans text. This includes: - Replacing monetary terms (i.e. $ -> dollars) @@ -49,8 +49,10 @@ def clean_text(text, symbols=DEFAULT_ALPHABET): ---------- text : str Text to clean - symbols : list + symbols : list (optional) List of valid symbols in text (default is English alphabet & punctuation) + remove_invalid_characters : bool (optional) + Whether to remove characters not in symbols list (default is True) Returns ------- @@ -83,7 +85,8 @@ def clean_text(text, symbols=DEFAULT_ALPHABET): # Collapse whitespace text = re.sub(WHITESPACE_RE, " ", text) # Remove banned characters - text = "".join([c for c in text if c in symbols]) + if remove_invalid_characters: + text = "".join([c for c in text if c in symbols]) return text diff --git a/training/utils.py b/training/utils.py index 28c3cb6..cba939d 100644 --- a/training/utils.py +++ b/training/utils.py @@ -100,7 +100,7 @@ def validate_dataset(filepaths_and_text, dataset_directory, symbols): invalid_characters = set() wavs = os.listdir(dataset_directory) for filename, text in filepaths_and_text: - text = clean_text(text, symbols) + text = clean_text(text, remove_invalid_characters=False) if filename not in wavs: missing_files.add(filename) invalid_characters_for_row = get_invalid_characters(text, symbols)