Commit: Fix tests

BenAAndrew committed Nov 28, 2021
1 parent 9fcd00e commit 0bf3570
Showing 3 changed files with 14 additions and 12 deletions.
15 changes: 7 additions & 8 deletions tests/test_dataset.py
@@ -225,31 +225,26 @@ def test_clip_combiner():
def test_extend_existing_dataset():
    dataset_directory = "test-extend-dataset"
    audio_folder = os.path.join(dataset_directory, "wavs")
    unlabelled_path = os.path.join(dataset_directory, "unlabelled")
    unlabelled_folder = os.path.join(dataset_directory, "unlabelled")
    metadata_file = os.path.join(dataset_directory, "metadata.csv")
    os.makedirs(dataset_directory)
    os.makedirs(audio_folder)
    os.makedirs(unlabelled_folder)
    with open(metadata_file, "w") as f:
        pass

    audio_path = os.path.join("test_samples", "audio.wav")
    converted_audio_path = os.path.join("test_samples", "audio-converted.wav")
    text_path = os.path.join("test_samples", "text.txt")
    forced_alignment_path = os.path.join(dataset_directory, "align.json")
    label_path = os.path.join(dataset_directory, "metadata.csv")
    info_path = os.path.join(dataset_directory, "info.json")
    suffix = "extend"
    min_confidence = 1.0
    extend_existing_dataset(
        text_path=text_path,
        audio_path=audio_path,
        transcription_model=FakeTranscriptionModel(),
        forced_alignment_path=forced_alignment_path,
        output_path=audio_folder,
        unlabelled_path=unlabelled_path,
        label_path=label_path,
        output_folder=dataset_directory,
        suffix=suffix,
        info_path=info_path,
        min_confidence=min_confidence,
        combine_clips=False,
    )
@@ -258,6 +253,10 @@ def test_extend_existing_dataset():
        name.split(".")[0] + "-" + suffix + ".wav" for name in EXPECTED_CLIPS
    ], "Unexpected audio clips"

    assert os.listdir(unlabelled_folder) == [
        name.split(".")[0] + "-" + suffix + ".wav" for name in UNMATCHED_CLIPS
    ], "Unexpected unlabelled audio clips"

    with open(label_path) as f:
        lines = f.readlines()
    expected_text = [
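Both assertions build the expected filenames by inserting the suffix before the ".wav" extension. A minimal, self-contained sketch of that naming rule (the clip names below are illustrative stand-ins, not the repository's EXPECTED_CLIPS or UNMATCHED_CLIPS fixtures):

    # Sketch of the naming rule asserted above: each clip keeps its base name
    # and gains "-<suffix>" before the extension. Clip names are hypothetical.
    EXPECTED_CLIPS = ["0_2500.wav", "2500_5000.wav"]  # illustrative values only
    suffix = "extend"

    expected = [name.split(".")[0] + "-" + suffix + ".wav" for name in EXPECTED_CLIPS]
    assert expected == ["0_2500-extend.wav", "2500_5000-extend.wav"]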
9 changes: 6 additions & 3 deletions training/clean_text.py
@@ -36,7 +36,7 @@
}


def clean_text(text, symbols=DEFAULT_ALPHABET):
def clean_text(text, symbols=DEFAULT_ALPHABET, remove_invalid_characters=True):
"""
Cleans text. This includes:
- Replacing monetary terms (i.e. $ -> dollars)
@@ -49,8 +49,10 @@ def clean_text(text, symbols=DEFAULT_ALPHABET):
    ----------
    text : str
        Text to clean
    symbols : list
    symbols : list (optional)
        List of valid symbols in text (default is English alphabet & punctuation)
    remove_invalid_characters : bool (optional)
        Whether to remove characters not in symbols list (default is True)

    Returns
    -------
@@ -83,7 +85,8 @@ def clean_text(text, symbols=DEFAULT_ALPHABET):
    # Collapse whitespace
    text = re.sub(WHITESPACE_RE, " ", text)
    # Remove banned characters
    text = "".join([c for c in text if c in symbols])
    if remove_invalid_characters:
        text = "".join([c for c in text if c in symbols])
    return text


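The guarded filtering step above is the only behaviour the new flag changes. A stand-alone sketch of its effect, using a toy symbols list rather than the real DEFAULT_ALPHABET and omitting the earlier cleaning steps:

    # Stand-in for the guarded filtering step; "symbols" is a toy alphabet,
    # not the real DEFAULT_ALPHABET, and the currency/whitespace cleaning
    # that clean_text also performs is left out.
    def filter_symbols(text, symbols, remove_invalid_characters=True):
        if remove_invalid_characters:
            text = "".join([c for c in text if c in symbols])
        return text

    symbols = list("abcdefghijklmnopqrstuvwxyz ,!")
    print(filter_symbols("héllo, wörld!", symbols))         # "hllo, wrld!"
    print(filter_symbols("héllo, wörld!", symbols, False))  # "héllo, wörld!" (invalid characters kept)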
2 changes: 1 addition & 1 deletion training/utils.py
@@ -100,7 +100,7 @@ def validate_dataset(filepaths_and_text, dataset_directory, symbols):
    invalid_characters = set()
    wavs = os.listdir(dataset_directory)
    for filename, text in filepaths_and_text:
        text = clean_text(text, symbols)
        text = clean_text(text, remove_invalid_characters=False)
        if filename not in wavs:
            missing_files.add(filename)
        invalid_characters_for_row = get_invalid_characters(text, symbols)
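With remove_invalid_characters=False, the cleaned text still contains any offending characters, so the get_invalid_characters call a few lines below can report them. The helper's body is not part of this diff; a hypothetical minimal version, shown only to illustrate the intent:

    # Hypothetical sketch - the real get_invalid_characters in training/utils.py
    # is not shown in this diff and may differ.
    def get_invalid_characters(text, symbols):
        return {c for c in text if c not in symbols}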
