From 301fd10a0ca7ab314ab64ccf9501f5e101782030 Mon Sep 17 00:00:00 2001 From: BenAAndrew Date: Wed, 24 Nov 2021 20:19:06 +0000 Subject: [PATCH] Improve dataset validation and errors --- dataset/__init__.py | 19 +++++++++++++ dataset/clip_generator.py | 58 +++++++++++++++++++++++-------------- tests/test_dataset.py | 11 ++++++- tests/test_training.py | 60 +++++++++++++++++++++++++++++++++------ training/train.py | 6 +++- training/utils.py | 42 +++++++++++++++++++++++---- 6 files changed, 160 insertions(+), 36 deletions(-) diff --git a/dataset/__init__.py b/dataset/__init__.py index e69de29..8362fa2 100644 --- a/dataset/__init__.py +++ b/dataset/__init__.py @@ -0,0 +1,19 @@ +from string import punctuation, digits + +def get_invalid_characters(text, symbols): + """ + Returns all invalid characters in text + + Parameters + ---------- + text : str + String to check + symbols : list + List of symbols that are valid + + Returns + ------- + set + All invalid characters + """ + return set([c for c in text if c not in symbols and c not in punctuation and c not in digits]) diff --git a/dataset/clip_generator.py b/dataset/clip_generator.py index 24da3e8..c2d8c3d 100644 --- a/dataset/clip_generator.py +++ b/dataset/clip_generator.py @@ -9,12 +9,13 @@ from pydub import AudioSegment from datetime import datetime +from dataset import get_invalid_characters from dataset.utils import similarity import dataset.forced_alignment.align as align from dataset.forced_alignment.search import FuzzySearch from dataset.forced_alignment.audio import DEFAULT_RATE from dataset.audio_processing import change_sample_rate, cut_audio, add_silence -from training import PUNCTUATION +from training import DEFAULT_ALPHABET, PUNCTUATION MIN_LENGTH = 1.0 @@ -98,7 +99,7 @@ def _combine_clip(combined_clip, audio_path, output_path): def generate_clips_from_textfile( audio_path, - script_path, + text, transcription_model, output_path, logging=logging, @@ -113,8 +114,8 @@ def generate_clips_from_textfile( ---------- audio_path : str Path to audio file (must have been converted using convert_audio) - script_path : str - Path to text file + text : str + Source text transcription_model : TranscriptionModel Transcription model output_path : str @@ -133,12 +134,8 @@ def generate_clips_from_textfile( (list, list) List of clips and clip lengths in seconds """ - logging.info(f"Loading script from {script_path}...") - with open(script_path, "r", encoding=CHARACTER_ENCODING) as script_file: - clean_text = script_file.read().lower().strip().replace("\n", " ").replace(" ", " ") - logging.info("Searching text for matching fragments...") - search = FuzzySearch(clean_text) + search = FuzzySearch(text) logging.info("Changing sample rate...") converted_audio_path = change_sample_rate(audio_path, DEFAULT_RATE) @@ -169,7 +166,7 @@ def generate_clips_from_textfile( and "match-end" in fragment and fragment["match-end"] - fragment["match-start"] > 0 ): - fragment_matched = clean_text[fragment["match-start"] : fragment["match-end"]] + fragment_matched = text[fragment["match-start"] : fragment["match-end"]] if fragment_matched: fragment["text"] = fragment_matched clip_lengths.append(fragment["duration"]) @@ -183,7 +180,7 @@ def generate_clips_from_textfile( def generate_clips_from_subtitles( audio_path, - subtitle_path, + subs, transcription_model, output_path, logging=logging, @@ -198,8 +195,8 @@ def generate_clips_from_subtitles( ---------- audio_path : str Path to audio file (must have been converted using convert_audio) - subtitle_path : str - Path to subtitle file + subs : list + List of pysrt subtitle objects transcription_model : TranscriptionModel Transcription model output_path : str @@ -219,7 +216,6 @@ def generate_clips_from_subtitles( List of clips and clip lengths in seconds """ logging.info("Loading subtitles...") - subs = pysrt.open(subtitle_path) total = len(subs) logging.info(f"{total} subtitle lines detected...") @@ -272,6 +268,7 @@ def clip_generator( unlabelled_path, label_path, logging=logging, + symbols=DEFAULT_ALPHABET, min_length=MIN_LENGTH, max_length=MAX_LENGTH, silence_padding=0.1, @@ -300,6 +297,8 @@ def clip_generator( Path to save label file to logging : logging (optional) Logging object to write logs to + symbols : list (optional) + List of valid symbols min_length : float (optional) Minimum duration of a clip in seconds max_length : float (optional) @@ -324,16 +323,33 @@ def clip_generator( assert not os.path.isdir( output_path ), "Output directory already exists. Did you mean to use 'Extend existing dataset'?" - os.makedirs(output_path, exist_ok=False) - os.makedirs(unlabelled_path, exist_ok=False) assert os.path.isfile(audio_path), "Audio file not found" assert os.path.isfile(script_path), "Script file not found" assert audio_path.endswith(".wav"), "Must be a WAV file" - if script_path.endswith(".srt"): + os.makedirs(output_path, exist_ok=False) + os.makedirs(unlabelled_path, exist_ok=False) + + # Validate text + is_subtitle = script_path.endswith(".srt") + logging.info(f"Loading {script_path}...") + + if is_subtitle: + subs = pysrt.open(script_path) + text = ' '.join([sub.text for sub in subs]) + else: + with open(script_path, "r", encoding=CHARACTER_ENCODING) as script_file: + text = script_file.read() + + text = text.lower().strip().replace("\n", " ").replace(" ", " ") + invalid_chars = get_invalid_characters(text, symbols) + assert not invalid_chars, f"Invalid characters in text (missing from language): {','.join(invalid_chars)}" + + # Generate clips + if is_subtitle: clips, unlabelled_clips, clip_lengths = generate_clips_from_subtitles( audio_path, - script_path, + subs, transcription_model, output_path, logging, @@ -344,7 +360,7 @@ def clip_generator( else: clips, unlabelled_clips, clip_lengths = generate_clips_from_textfile( audio_path, - script_path, + text, transcription_model, output_path, logging, @@ -353,6 +369,8 @@ def clip_generator( min_confidence, ) + assert clips, "No audio clips could be generated" + if combine_clips: logging.info("Combining clips") clips, clip_lengths = clip_combiner(audio_path, output_path, clips, max_length) @@ -371,8 +389,6 @@ def clip_generator( else: os.remove(os.path.join(output_path, filename)) - assert clips, "No audio clips could be generated" - # Produce alignment file logging.info(f"Produced {len(clips)} final clips") with open(forced_alignment_path, "w", encoding=CHARACTER_ENCODING) as result_file: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 75ee7f6..bfeccec 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,8 +3,10 @@ import json from pathlib import Path import json +import pysrt from tests.test_synthesis import MIN_SYNTHESIS_SCORE +from dataset import get_invalid_characters from dataset.analysis import get_total_audio_duration, get_clip_lengths, validate_dataset from dataset.clip_generator import generate_clips_from_subtitles, clip_combiner from dataset.create_dataset import create_dataset @@ -36,6 +38,12 @@ def transcribe(self, path): return TRANSCRIPTION[filename] +# Invalid characters +def test_get_invalid_characters(): + invalid_chars = get_invalid_characters("aà1!", ["a"]) + assert invalid_chars == set("à") + + # Dataset creation def test_create_dataset(): audio_path = os.path.join("test_samples", "audio.wav") @@ -97,10 +105,11 @@ def test_generate_clips_from_subtitles(): os.makedirs(dataset_directory) audio_path = os.path.join("test_samples", "audio.wav") subtitle_path = os.path.join("test_samples", "sub.srt") + subs = pysrt.open(subtitle_path) clips, unlabelled_clips, clip_lengths = generate_clips_from_subtitles( audio_path=audio_path, - subtitle_path=subtitle_path, + subs=subs, transcription_model=FakeSubtitleTranscriptionModel(), output_path=dataset_directory, ) diff --git a/tests/test_training.py b/tests/test_training.py index 50c163e..226b490 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -30,6 +30,8 @@ LEARNING_RATE_PER_64, BATCH_SIZE_PER_GB, BASE_SYMBOLS, + train_test_split, + validate_dataset, ) @@ -246,16 +248,58 @@ def test_load_metadata(): "2820_5100.wav": "enabled the commission to conclude", "5130_7560.wav": "that five shots may have been", } - train_size = 0.67 - train_files, test_files = load_metadata(metadata_path, train_size) - assert len(train_files) == 2 - for name, text in train_files: + filepaths_and_text = load_metadata(metadata_path) + assert len(filepaths_and_text) == 3 + for name, text in filepaths_and_text: assert data[name] == text - assert len(test_files) == 1 - name, text = test_files[0] - assert data[name] == text - assert name not in [i[0] for i in train_files] + +def test_train_test_split(): + filepaths_and_text = [ + ("0_2730.wav", "the examination and testimony of the experts"), + ("2820_5100.wav", "enabled the commission to conclude"), + ("5130_7560.wav", "that five shots may have been") + ] + train_files, test_files = train_test_split(filepaths_and_text, 0.67) + assert train_files == filepaths_and_text[:2] + assert test_files == filepaths_and_text[2:] + + +# Validate dataset +@mock.patch("os.listdir", return_value=["1.wav", "3.wav"]) +def test_validate_dataset_missing_files(listdir): + filepaths_and_text = [ + ("1.wav", "abc"), + ("2.wav", "abc"), + ("3.wav", "abc") + ] + symbols = ["a", "b", "c"] + + exception = "" + try: + validate_dataset(filepaths_and_text, "", symbols) + except AssertionError as e: + exception = str(e) + + assert exception == "Missing files: 2.wav" + +@mock.patch("os.listdir", return_value=["1.wav", "2.wav"]) +def test_validate_dataset_invalid_characters(listdir): + filepaths_and_text = [ + ("1.wav", "abc"), + ("2.wav", "def"), + ] + symbols = ["a", "b", "c"] + + exception = "" + try: + validate_dataset(filepaths_and_text, "", symbols) + except AssertionError as e: + exception = str(e) + + failed_characters = exception.split(":")[1] + for character in ["d","e","f"]: + assert character in failed_characters # Memory diff --git a/training/train.py b/training/train.py index 9bcbbca..0283cbd 100644 --- a/training/train.py +++ b/training/train.py @@ -28,6 +28,8 @@ load_symbols, check_early_stopping, calc_avgmax_attention, + train_test_split, + validate_dataset, ) from training.tacotron2_model import Tacotron2, TextMelCollate, Tacotron2Loss from training.tacotron2_model.utils import process_batch @@ -134,8 +136,10 @@ def train( # Load data logging.info("Loading data...") - train_files, test_files = load_metadata(metadata_path, train_size) + filepaths_and_text = load_metadata(metadata_path) symbols = load_symbols(alphabet_path) if alphabet_path else DEFAULT_ALPHABET + validate_dataset(filepaths_and_text, dataset_directory, symbols) + train_files, test_files = train_test_split(filepaths_and_text, train_size) trainset = VoiceDataset(train_files, dataset_directory, symbols) valset = VoiceDataset(test_files, dataset_directory, symbols) collate_fn = TextMelCollate() diff --git a/training/utils.py b/training/utils.py index 3aca6a8..f6d1807 100644 --- a/training/utils.py +++ b/training/utils.py @@ -4,6 +4,7 @@ from PIL import Image from dataset.clip_generator import CHARACTER_ENCODING +from dataset import get_invalid_characters from training import BASE_SYMBOLS from training.tacotron2_model.utils import get_mask_from_lengths @@ -73,7 +74,7 @@ def get_learning_rate(batch_size): ) -def load_metadata(metadata_path, train_size): +def load_metadata(metadata_path): """ Load metadata file and split entries into train and test. @@ -81,6 +82,41 @@ def load_metadata(metadata_path, train_size): ---------- metadata_path : str Path to metadata file + + Returns + ------- + list + List of samples + """ + with open(metadata_path, encoding=CHARACTER_ENCODING) as f: + filepaths_and_text = [line.strip().split("|") for line in f] + random.shuffle(filepaths_and_text) + return filepaths_and_text + + +def validate_dataset(filepaths_and_text, dataset_directory, symbols): + missing_files = set() + invalid_characters = set() + wavs = os.listdir(dataset_directory) + for filename, text in filepaths_and_text: + if filename not in wavs: + missing_files.add(filename) + invalid_characters_for_row = get_invalid_characters(text, symbols) + if invalid_characters_for_row: + invalid_characters.update(invalid_characters_for_row) + + assert not missing_files, f"Missing files: {(',').join(missing_files)}" + assert not invalid_characters, f"Invalid characters (for alphabet): {(',').join(invalid_characters)}" + + +def train_test_split(filepaths_and_text, train_size): + """ + Split dataset into train & test data + + Parameters + ---------- + filepaths_and_text : list + List of samples train_size : float Percentage of entries to use for training (rest used for testing) @@ -89,10 +125,6 @@ def load_metadata(metadata_path, train_size): (list, list) List of train and test samples """ - with open(metadata_path, encoding=CHARACTER_ENCODING) as f: - filepaths_and_text = [line.strip().split("|") for line in f] - - random.shuffle(filepaths_and_text) train_cutoff = int(len(filepaths_and_text) * train_size) train_files = filepaths_and_text[:train_cutoff] test_files = filepaths_and_text[train_cutoff:]