
Commit

🚀 [Update] DataLoader Changes, Distributed gpu support & Full decode logs
LuluW8071 committed Aug 14, 2024
1 parent 6cb9b5a commit d4b699c
Showing 4 changed files with 197 additions and 151 deletions.
23 changes: 0 additions & 23 deletions .github/workflows/security_check.yml

This file was deleted.

2 changes: 1 addition & 1 deletion src/Automatic_Speech_Recognition/engine.py
@@ -153,7 +153,7 @@ def __call__(self, x):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="demoing the speech recognition engine in terminal.")
    parser.add_argument('--model_file', type=str, default=None, required=True,
-                       help='optimized file to load. use freeze_model.py')
+                       help='optimized file to load. use freeze.py')
    parser.add_argument('--ken_lm_file', type=str, default=None, required=False,
                        help='If you have an ngram lm use to decode')
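
# Example invocation (file names are illustrative; the KenLM file is optional):
#   python engine.py --model_file model.optimized.pt --ken_lm_file 3gram.arpa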

249 changes: 152 additions & 97 deletions src/Automatic_Speech_Recognition/neuralnet/dataset.py
@@ -1,98 +1,153 @@
import os
import argparse
import random
import pytorch_lightning as pl
import json
import csv
import sox
from tqdm import tqdm
from multiprocessing.pool import ThreadPool
from pathlib import Path

# Function to clean text by removing specified characters
def clean_text(text):
    characters_to_remove = ':!,"‘’;—?'
    translator = str.maketrans('', '', characters_to_remove)
    return text.translate(translator)
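# Example: clean_text('Hello, world!') -> 'Hello world'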

def process_file(row, clips_directory, directory, output_format):
    file_name = row['path']                # Original file location
    clips_name = file_name.rpartition('.')[0] + '.' + output_format
    text = clean_text(row['sentence'])     # Clean the sentence
    audio_path = os.path.join(directory, 'clips', file_name)
    output_audio_path = os.path.join(clips_directory, clips_name)

    # Convert MP3 to the output format (FLAC or WAV) at 16 kHz using Sox
    tfm = sox.Transformer()
    tfm.rate(samplerate=16000)
    tfm.build(input_filepath=audio_path, output_filepath=output_audio_path)

    return {'key': clips_directory + '/' + clips_name, 'text': text}

def main(args):
    data = []  # Empty list to store clips and sentences
    directory = args.file_path.rpartition('/')[0]
    percent = args.percent

    # Create a 'clips' directory inside the defined save_json_path
    clips_directory = os.path.abspath(os.path.join(args.save_json_path, 'clips'))

    if not os.path.exists(clips_directory):
        os.makedirs(clips_directory)

    with open(args.file_path, encoding="utf-8") as f:
        length = sum(1 for _ in f) - 1

    with open(args.file_path, newline='', encoding="utf-8") as csv_file:
        reader = csv.DictReader(csv_file, delimiter='\t')
        data_to_process = [(row, clips_directory, directory, args.output_format) for row in reader]

    if args.convert:
        print(f"{length} files found. Converting MP3 to {args.output_format.upper()} using {args.num_workers} workers.")
        with ThreadPool(args.num_workers) as pool:
            data = list(tqdm(pool.imap(lambda x: process_file(*x), data_to_process), total=length))
    else:
        for row in data_to_process:
            file_name = row[0]['path']
            clips_name = file_name.rpartition('.')[0] + '.' + args.output_format
            text = clean_text(row[0]['sentence'])
            data.append({'key': clips_directory + '/' + clips_name, 'text': text})

    # Split data into train and test sets and save them as JSON files
    random.shuffle(data)
    print("Creating train and test JSON sets")

    train_data = data[:int(length * (1 - percent / 100))]
    test_data = data[int(length * (1 - percent / 100)):]
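    # Example: with length = 1000 TSV rows and percent = 10, the first 900
    # shuffled entries go to train.json and the remaining 100 to test.json.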

    with open(os.path.join(args.save_json_path, 'train.json'), 'w', encoding='utf-8') as f:
        json.dump(train_data, f, ensure_ascii=False, indent=4)

    with open(os.path.join(args.save_json_path, 'test.json'), 'w', encoding='utf-8') as f:
        json.dump(test_data, f, ensure_ascii=False, indent=4)

    print("Done!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=
        """
        Utility script to convert CommonVoice MP3 clips to FLAC/WAV and
        split them into train and test JSON files for training the ASR model.
        """
    )
    parser.add_argument('--file_path', type=str, default=None, required=True,
                        help='path to one of the .tsv files found in cv-corpus')
    parser.add_argument('--save_json_path', type=str, default=None, required=True,
                        help='path to the dir where the json files are supposed to be saved')
    parser.add_argument('--percent', type=int, default=10, required=False,
                        help='percent of clips put into test.json instead of train.json')
    parser.add_argument('--convert', default=True, action='store_true',
                        help='indicates that the script should convert mp3 to flac')
    parser.add_argument('--not-convert', dest='convert', action='store_false',
                        help='indicates that the script should not convert mp3 to flac')
    parser.add_argument('-w', '--num_workers', type=int, default=2,
                        help='number of worker threads for processing')
    parser.add_argument('--output_format', type=str, default='flac',
                        help='output audio format (flac or wav)')

    args = parser.parse_args()
    main(args)
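# Example invocation of the conversion utility above (illustrative paths; before
# this commit the utility lived in dataset.py):
#   python dataset.py --file_path cv-corpus/en/validated.tsv --save_json_path data -w 4
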
import json
import torchaudio
import torch
import torch.nn as nn
import torchaudio.transforms as transforms
import numpy as np
import pytorch_lightning as pl

from torch.utils.data import DataLoader, Dataset
from utils import TextTransform  # Comment this for engine inference


class LogMelSpec(nn.Module):
    def __init__(self, sample_rate=16000, n_mels=128, win_length=400, hop_length=160, n_fft=1024):
        super(LogMelSpec, self).__init__()
        self.transform = transforms.MelSpectrogram(sample_rate=sample_rate,
                                                   n_mels=n_mels,
                                                   win_length=win_length,
                                                   hop_length=hop_length,
                                                   n_fft=n_fft)

    def forward(self, x):
        x = self.transform(x)     # mel spectrogram
        x = np.log(x + 1e-14)     # logarithmic, add small value to avoid inf
        return x

def get_featurizer(sample_rate=16000, n_mels=128, win_length=400, hop_length=160, n_fft=1024):
    return LogMelSpec(sample_rate=sample_rate,
                      n_mels=n_mels,
                      win_length=win_length,
                      hop_length=hop_length,
                      n_fft=n_fft)
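
# A minimal usage sketch of the featurizer (illustrative, not part of the module):
# with the defaults above, hop_length=160 at 16 kHz gives one frame every 10 ms,
# so one second of mono audio yields 101 frames of 128 mel bins.
#
#   featurizer = get_featurizer()
#   waveform = torch.randn(1, 16000)    # 1 s of dummy audio
#   log_mel = featurizer(waveform)      # shape: (1, 128, 101)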

# Custom Dataset Class
class CustomAudioDataset(Dataset):
    def __init__(self, json_path, transform=None, log_ex=True, valid=False):
        print(f'Loading json data from {json_path}')
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        # print(self.data)
        self.text_process = TextTransform()  # Initialize TextTransform for text processing
        self.log_ex = log_ex

        if valid:
            self.audio_transforms = torch.nn.Sequential(
                LogMelSpec()
            )
        else:
            self.audio_transforms = torch.nn.Sequential(
                LogMelSpec(),
                transforms.FrequencyMasking(freq_mask_param=30),
                transforms.TimeMasking(time_mask_param=70)
            )


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        file_path = item['key']

        try:
            waveform, _ = torchaudio.load(file_path)    # Point to location of audio data
            utterance = item['text'].lower()            # Point to sentence of audio data
            # print(waveform, sample_rate)
            # print('Sentences:', utterance)
            label = self.text_process.text_to_int(utterance)
            spectrogram = self.audio_transforms(waveform)   # (channel, feature, time)
            spec_len = spectrogram.shape[-1] // 2
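            # Assumption: halving the frame count matches the ~2x time
            # downsampling in the acoustic model, so the CTC input lengths
            # line up with the model's output frames.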
            label_len = len(label)

            # print(f'SpecShape: {spectrogram.shape} \t shape[-1]: {spectrogram.shape[-1]}')
            # print(f'Speclen: {spec_len} \t Label_len: {label_len}')

            if spec_len < label_len:
                raise Exception('spectrogram len is smaller than label len')
            if spectrogram.shape[0] > 1:
                raise Exception('dual channel, skipping audio file %s' % file_path)
            if spectrogram.shape[2] > 1650 * 3:
                raise Exception('spectrogram too big. size %s' % spectrogram.shape[2])
            if label_len == 0:
                raise Exception('label len is zero... skipping %s' % file_path)

            # print(f'{idx}. {utterance}')
            return spectrogram, label, spec_len, label_len

        except Exception as e:
            if self.log_ex:
                print(str(e), file_path)
            return self.__getitem__(idx - 1 if idx != 0 else idx + 1)

    def describe(self):
        return self.data.describe()
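
# A minimal usage sketch (assumes a train.json produced by the conversion utility;
# the path is illustrative):
#
#   dataset = CustomAudioDataset('data/train.json', valid=True)
#   spectrogram, label, spec_len, label_len = dataset[0]
#   # spectrogram: (1, 128, time); label: token ids from TextTransform;
#   # spec_len == spectrogram.shape[-1] // 2; label_len == len(label)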


# Lightning Data Module
class SpeechDataModule(pl.LightningDataModule):
    def __init__(self, batch_size, train_json, test_json, num_workers):
        super().__init__()
        self.batch_size = batch_size
        self.train_json = train_json
        self.test_json = test_json
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_dataset = CustomAudioDataset(self.train_json,
                                                valid=False)
        self.test_dataset = CustomAudioDataset(self.test_json,
                                               valid=True)

    def data_processing(self, data):
        spectrograms = []
        labels = []
        input_lengths = []
        label_lengths = []
        for (spectrogram, label, input_length, label_length) in data:
            if spectrogram is None:
                continue

            spectrograms.append(spectrogram.squeeze(0).transpose(0, 1))
            # print(len(spectrograms))
            # print(f'Label Check: {label}')
            labels.append(torch.Tensor(label))
            input_lengths.append(spectrogram.shape[-1] // 2)
            label_lengths.append(len(label))

        # Print the shapes of spectrograms before padding
        # for spec in spectrograms:
        #     print("Spec before padding:", spec.shape)

        # NOTE: https://www.geeksforgeeks.org/how-do-you-handle-sequence-padding-and-packing-in-pytorch-for-rnns/
        spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
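        # After padding: (batch, 1, n_mels, max_time); pad_sequence pads every
        # spectrogram in the batch to the length of the longest one.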
        # print('Padded Spectrograms: ', spectrograms.shape)
        labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

        return spectrograms, labels, input_lengths, label_lengths


    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          collate_fn=lambda x: self.data_processing(x),
                          num_workers=self.num_workers,
                          pin_memory=True)   # Optimizes data-transfer speed for CUDA

    def val_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          collate_fn=lambda x: self.data_processing(x),
                          num_workers=self.num_workers,
                          pin_memory=True)
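
# A minimal usage sketch (batch size and JSON paths are illustrative):
#
#   data_module = SpeechDataModule(batch_size=16,
#                                  train_json='data/train.json',
#                                  test_json='data/test.json',
#                                  num_workers=4)
#   data_module.setup()
#   specs, labels, input_lengths, label_lengths = next(iter(data_module.train_dataloader()))
#   # specs: (16, 1, 128, max_time); labels: (16, max_label_len)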
