Improve dataset validation and errors
BenAAndrew committed Nov 24, 2021
1 parent 2e21b39 commit 301fd10
Showing 6 changed files with 160 additions and 36 deletions.
19 changes: 19 additions & 0 deletions dataset/__init__.py
@@ -0,0 +1,19 @@
from string import punctuation, digits


def get_invalid_characters(text, symbols):
    """
    Returns all invalid characters in text
    Parameters
    ----------
    text : str
        String to check
    symbols : list
        List of symbols that are valid
    Returns
    -------
    set
        All invalid characters
    """
    return set([c for c in text if c not in symbols and c not in punctuation and c not in digits])
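For illustration, the helper treats punctuation and digits as always valid, so only characters missing from the supplied symbol list are reported. A quick usage sketch (mirroring the unit test added to tests/test_dataset.py below):

from dataset import get_invalid_characters

# "1" (digit) and "!" (punctuation) always pass the check; "à" is not
# in the symbol list, so it is the only character returned.
invalid = get_invalid_characters("aà1!", ["a"])
print(invalid)  # {'à'}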
58 changes: 37 additions & 21 deletions dataset/clip_generator.py
@@ -9,12 +9,13 @@
from pydub import AudioSegment
from datetime import datetime

from dataset import get_invalid_characters
from dataset.utils import similarity
import dataset.forced_alignment.align as align
from dataset.forced_alignment.search import FuzzySearch
from dataset.forced_alignment.audio import DEFAULT_RATE
from dataset.audio_processing import change_sample_rate, cut_audio, add_silence
-from training import PUNCTUATION
from training import DEFAULT_ALPHABET, PUNCTUATION


MIN_LENGTH = 1.0
@@ -98,7 +99,7 @@ def _combine_clip(combined_clip, audio_path, output_path):

def generate_clips_from_textfile(
audio_path,
-script_path,
text,
transcription_model,
output_path,
logging=logging,
@@ -113,8 +114,8 @@
----------
audio_path : str
Path to audio file (must have been converted using convert_audio)
-script_path : str
-Path to text file
text : str
Source text
transcription_model : TranscriptionModel
Transcription model
output_path : str
Expand All @@ -133,12 +134,8 @@ def generate_clips_from_textfile(
(list, list)
List of clips and clip lengths in seconds
"""
logging.info(f"Loading script from {script_path}...")
with open(script_path, "r", encoding=CHARACTER_ENCODING) as script_file:
clean_text = script_file.read().lower().strip().replace("\n", " ").replace(" ", " ")

logging.info("Searching text for matching fragments...")
-search = FuzzySearch(clean_text)
search = FuzzySearch(text)

logging.info("Changing sample rate...")
converted_audio_path = change_sample_rate(audio_path, DEFAULT_RATE)
@@ -169,7 +166,7 @@
and "match-end" in fragment
and fragment["match-end"] - fragment["match-start"] > 0
):
-fragment_matched = clean_text[fragment["match-start"] : fragment["match-end"]]
fragment_matched = text[fragment["match-start"] : fragment["match-end"]]
if fragment_matched:
fragment["text"] = fragment_matched
clip_lengths.append(fragment["duration"])
@@ -183,7 +180,7 @@

def generate_clips_from_subtitles(
audio_path,
-subtitle_path,
subs,
transcription_model,
output_path,
logging=logging,
@@ -198,8 +195,8 @@
----------
audio_path : str
Path to audio file (must have been converted using convert_audio)
-subtitle_path : str
-Path to subtitle file
subs : list
List of pysrt subtitle objects
transcription_model : TranscriptionModel
Transcription model
output_path : str
@@ -219,7 +216,6 @@
List of clips and clip lengths in seconds
"""
logging.info("Loading subtitles...")
-subs = pysrt.open(subtitle_path)
total = len(subs)
logging.info(f"{total} subtitle lines detected...")

@@ -272,6 +268,7 @@ def clip_generator(
unlabelled_path,
label_path,
logging=logging,
symbols=DEFAULT_ALPHABET,
min_length=MIN_LENGTH,
max_length=MAX_LENGTH,
silence_padding=0.1,
@@ -300,6 +297,8 @@
Path to save label file to
logging : logging (optional)
Logging object to write logs to
symbols : list (optional)
List of valid symbols
min_length : float (optional)
Minimum duration of a clip in seconds
max_length : float (optional)
@@ -324,16 +323,33 @@
assert not os.path.isdir(
output_path
), "Output directory already exists. Did you mean to use 'Extend existing dataset'?"
-os.makedirs(output_path, exist_ok=False)
-os.makedirs(unlabelled_path, exist_ok=False)
assert os.path.isfile(audio_path), "Audio file not found"
assert os.path.isfile(script_path), "Script file not found"
assert audio_path.endswith(".wav"), "Must be a WAV file"

if script_path.endswith(".srt"):
os.makedirs(output_path, exist_ok=False)
os.makedirs(unlabelled_path, exist_ok=False)

# Validate text
is_subtitle = script_path.endswith(".srt")
logging.info(f"Loading {script_path}...")

if is_subtitle:
subs = pysrt.open(script_path)
text = ' '.join([sub.text for sub in subs])
else:
with open(script_path, "r", encoding=CHARACTER_ENCODING) as script_file:
text = script_file.read()

text = text.lower().strip().replace("\n", " ").replace("  ", " ")
invalid_chars = get_invalid_characters(text, symbols)
assert not invalid_chars, f"Invalid characters in text (missing from language): {','.join(invalid_chars)}"

# Generate clips
if is_subtitle:
clips, unlabelled_clips, clip_lengths = generate_clips_from_subtitles(
audio_path,
-script_path,
subs,
transcription_model,
output_path,
logging,
@@ -344,7 +360,7 @@
else:
clips, unlabelled_clips, clip_lengths = generate_clips_from_textfile(
audio_path,
-script_path,
text,
transcription_model,
output_path,
logging,
@@ -353,6 +369,8 @@
min_confidence,
)

assert clips, "No audio clips could be generated"

if combine_clips:
logging.info("Combining clips")
clips, clip_lengths = clip_combiner(audio_path, output_path, clips, max_length)
Expand All @@ -371,8 +389,6 @@ def clip_generator(
else:
os.remove(os.path.join(output_path, filename))

assert clips, "No audio clips could be generated"

# Produce alignment file
logging.info(f"Produced {len(clips)} final clips")
with open(forced_alignment_path, "w", encoding=CHARACTER_ENCODING) as result_file:
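Taken together, the clip_generator changes mean the script is read and normalised once, then validated against the alphabet before any audio work begins, so bad characters fail fast. A standalone sketch of the new pre-flight check (script_path is hypothetical, and CHARACTER_ENCODING is assumed to be "utf-8"):

import pysrt
from dataset import get_invalid_characters
from training import DEFAULT_ALPHABET

script_path = "script.txt"  # hypothetical input

# Subtitle and plain-text scripts now funnel into one normalised string.
if script_path.endswith(".srt"):
    text = " ".join(sub.text for sub in pysrt.open(script_path))
else:
    with open(script_path, "r", encoding="utf-8") as script_file:
        text = script_file.read()
text = text.lower().strip().replace("\n", " ").replace("  ", " ")

# Validation runs before any sample-rate conversion or clipping.
invalid_chars = get_invalid_characters(text, DEFAULT_ALPHABET)
assert not invalid_chars, f"Invalid characters in text (missing from language): {','.join(invalid_chars)}"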
11 changes: 10 additions & 1 deletion tests/test_dataset.py
@@ -3,8 +3,10 @@
import json
from pathlib import Path
import json
import pysrt

from tests.test_synthesis import MIN_SYNTHESIS_SCORE
from dataset import get_invalid_characters
from dataset.analysis import get_total_audio_duration, get_clip_lengths, validate_dataset
from dataset.clip_generator import generate_clips_from_subtitles, clip_combiner
from dataset.create_dataset import create_dataset
@@ -36,6 +38,12 @@ def transcribe(self, path):
return TRANSCRIPTION[filename]


# Invalid characters
def test_get_invalid_characters():
invalid_chars = get_invalid_characters("aà1!", ["a"])
assert invalid_chars == set("à")


# Dataset creation
def test_create_dataset():
audio_path = os.path.join("test_samples", "audio.wav")
@@ -97,10 +105,11 @@ def test_generate_clips_from_subtitles():
os.makedirs(dataset_directory)
audio_path = os.path.join("test_samples", "audio.wav")
subtitle_path = os.path.join("test_samples", "sub.srt")
subs = pysrt.open(subtitle_path)

clips, unlabelled_clips, clip_lengths = generate_clips_from_subtitles(
audio_path=audio_path,
-subtitle_path=subtitle_path,
subs=subs,
transcription_model=FakeSubtitleTranscriptionModel(),
output_path=dataset_directory,
)
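Because generate_clips_from_subtitles now takes parsed subtitle objects instead of a path, a test (or caller) could also build subtitles in memory — a sketch, assuming pysrt's SubRipItem/SubRipTime API:

import pysrt

# One subtitle line constructed without a fixture file; the function only
# needs an iterable of objects exposing .text, .start and .end.
subs = [
    pysrt.SubRipItem(
        index=1,
        start=pysrt.SubRipTime(seconds=0),
        end=pysrt.SubRipTime(seconds=2, milliseconds=730),
        text="the examination and testimony of the experts",
    )
]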
60 changes: 52 additions & 8 deletions tests/test_training.py
@@ -30,6 +30,8 @@
LEARNING_RATE_PER_64,
BATCH_SIZE_PER_GB,
BASE_SYMBOLS,
train_test_split,
validate_dataset,
)


@@ -246,16 +248,58 @@ def test_load_metadata():
"2820_5100.wav": "enabled the commission to conclude",
"5130_7560.wav": "that five shots may have been",
}
-train_size = 0.67
-train_files, test_files = load_metadata(metadata_path, train_size)
-assert len(train_files) == 2
-for name, text in train_files:
filepaths_and_text = load_metadata(metadata_path)
assert len(filepaths_and_text) == 3
for name, text in filepaths_and_text:
assert data[name] == text

-assert len(test_files) == 1
-name, text = test_files[0]
-assert data[name] == text
-assert name not in [i[0] for i in train_files]

def test_train_test_split():
filepaths_and_text = [
("0_2730.wav", "the examination and testimony of the experts"),
("2820_5100.wav", "enabled the commission to conclude"),
("5130_7560.wav", "that five shots may have been")
]
train_files, test_files = train_test_split(filepaths_and_text, 0.67)
assert train_files == filepaths_and_text[:2]
assert test_files == filepaths_and_text[2:]


# Validate dataset
@mock.patch("os.listdir", return_value=["1.wav", "3.wav"])
def test_validate_dataset_missing_files(listdir):
filepaths_and_text = [
("1.wav", "abc"),
("2.wav", "abc"),
("3.wav", "abc")
]
symbols = ["a", "b", "c"]

exception = ""
try:
validate_dataset(filepaths_and_text, "", symbols)
except AssertionError as e:
exception = str(e)

assert exception == "Missing files: 2.wav"

@mock.patch("os.listdir", return_value=["1.wav", "2.wav"])
def test_validate_dataset_invalid_characters(listdir):
filepaths_and_text = [
("1.wav", "abc"),
("2.wav", "def"),
]
symbols = ["a", "b", "c"]

exception = ""
try:
validate_dataset(filepaths_and_text, "", symbols)
except AssertionError as e:
exception = str(e)

failed_characters = exception.split(":")[1]
for character in ["d","e","f"]:
assert character in failed_characters
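
The try/except-and-compare pattern above could equally be written with pytest.raises, which fails the test automatically when no AssertionError is raised — an equivalent sketch (the import path for validate_dataset is assumed to match the module's existing imports):

import pytest
from unittest import mock

from training.utils import validate_dataset  # assumed import path

@mock.patch("os.listdir", return_value=["1.wav", "3.wav"])
def test_validate_dataset_missing_files(listdir):
    filepaths_and_text = [("1.wav", "abc"), ("2.wav", "abc"), ("3.wav", "abc")]
    # match= is a regex searched against the exception message
    with pytest.raises(AssertionError, match="Missing files: 2.wav"):
        validate_dataset(filepaths_and_text, "", ["a", "b", "c"])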


# Memory
6 changes: 5 additions & 1 deletion training/train.py
@@ -28,6 +28,8 @@
load_symbols,
check_early_stopping,
calc_avgmax_attention,
train_test_split,
validate_dataset,
)
from training.tacotron2_model import Tacotron2, TextMelCollate, Tacotron2Loss
from training.tacotron2_model.utils import process_batch
@@ -134,8 +136,10 @@ def train(

# Load data
logging.info("Loading data...")
-train_files, test_files = load_metadata(metadata_path, train_size)
filepaths_and_text = load_metadata(metadata_path)
symbols = load_symbols(alphabet_path) if alphabet_path else DEFAULT_ALPHABET
validate_dataset(filepaths_and_text, dataset_directory, symbols)
train_files, test_files = train_test_split(filepaths_and_text, train_size)
trainset = VoiceDataset(train_files, dataset_directory, symbols)
valset = VoiceDataset(test_files, dataset_directory, symbols)
collate_fn = TextMelCollate()
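For the training entry point, the practical effect is that a broken dataset now aborts with a single early AssertionError instead of failing mid-training. A sketch of both failure modes (paths and alphabet are hypothetical; helpers assumed importable from training.utils as in train.py):

from training.utils import load_metadata, train_test_split, validate_dataset

filepaths_and_text = load_metadata("dataset/metadata.csv")  # shuffled [filename, text] pairs
try:
    validate_dataset(filepaths_and_text, "dataset/wavs", ["a", "b", "c"])
except AssertionError as e:
    # e.g. "Missing files: clip_42.wav" or "Invalid characters (for alphabet): é,ü"
    print(e)
else:
    train_files, test_files = train_test_split(filepaths_and_text, train_size=0.8)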
42 changes: 37 additions & 5 deletions training/utils.py
@@ -4,6 +4,7 @@
from PIL import Image

from dataset.clip_generator import CHARACTER_ENCODING
from dataset import get_invalid_characters
from training import BASE_SYMBOLS
from training.tacotron2_model.utils import get_mask_from_lengths

@@ -73,14 +74,49 @@ def get_learning_rate(batch_size):
)


-def load_metadata(metadata_path, train_size):
def load_metadata(metadata_path):
"""
Load metadata file and split entries into train and test.
Parameters
----------
metadata_path : str
Path to metadata file
Returns
-------
list
List of samples
"""
with open(metadata_path, encoding=CHARACTER_ENCODING) as f:
filepaths_and_text = [line.strip().split("|") for line in f]
random.shuffle(filepaths_and_text)
return filepaths_and_text


def validate_dataset(filepaths_and_text, dataset_directory, symbols):
missing_files = set()
invalid_characters = set()
wavs = os.listdir(dataset_directory)
for filename, text in filepaths_and_text:
if filename not in wavs:
missing_files.add(filename)
invalid_characters_for_row = get_invalid_characters(text, symbols)
if invalid_characters_for_row:
invalid_characters.update(invalid_characters_for_row)

assert not missing_files, f"Missing files: {(',').join(missing_files)}"
assert not invalid_characters, f"Invalid characters (for alphabet): {(',').join(invalid_characters)}"


def train_test_split(filepaths_and_text, train_size):
"""
Split dataset into train & test data
Parameters
----------
filepaths_and_text : list
List of samples
train_size : float
Percentage of entries to use for training (rest used for testing)
@@ -89,10 +125,6 @@ def load_metadata(metadata_path, train_size):
(list, list)
List of train and test samples
"""
-with open(metadata_path, encoding=CHARACTER_ENCODING) as f:
-filepaths_and_text = [line.strip().split("|") for line in f]
-
-random.shuffle(filepaths_and_text)
train_cutoff = int(len(filepaths_and_text) * train_size)
train_files = filepaths_and_text[:train_cutoff]
test_files = filepaths_and_text[train_cutoff:]
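With shuffling moved into load_metadata, train_test_split itself is deterministic, which is what lets test_train_test_split assert exact slices. The cutoff arithmetic for the three-sample test case:

filepaths_and_text = [("a.wav", "one"), ("b.wav", "two"), ("c.wav", "three")]
train_cutoff = int(len(filepaths_and_text) * 0.67)  # int(2.01) == 2
train_files = filepaths_and_text[:train_cutoff]     # first two samples
test_files = filepaths_and_text[train_cutoff:]      # remaining sample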
