🚀 [Update] Fix Trainer Logs & Optimize DataLoader Params
LuluW8071 committed Aug 13, 2024
1 parent cef3fb9 commit 6cb9b5a
Showing 2 changed files with 113 additions and 149 deletions.
240 changes: 97 additions & 143 deletions src/Automatic_Speech_Recognition/neuralnet/dataset.py
@@ -1,144 +1,98 @@
import pytorch_lightning as pl
import os
import argparse
import random
import json
import torchaudio
import torch
import torch.nn as nn
import torchaudio.transforms as transforms
import numpy as np

from torch.utils.data import DataLoader, Dataset
from utils import TextTransform


class LogMelSpec(nn.Module):
def __init__(self, sample_rate=16000, n_mels=128, hop_length=350, n_fft=1024):
super(LogMelSpec, self).__init__()
        self.transform = transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels,
                                                   hop_length=hop_length, n_fft=n_fft)

def forward(self, x):
        x = self.transform(x)   # mel spectrogram
        x = np.log(x + 1e-14)   # log scale; the small epsilon avoids log(0)
return x
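
For intuition, a minimal sketch of the feature shape this transform produces when run in this module's context (the 1-second, 16 kHz dummy waveform is an assumption for illustration):

spec = LogMelSpec()(torch.randn(1, 16000))  # dummy mono waveform
print(spec.shape)  # expected (1, 128, 46): (channel, n_mels, time) at the default hop_length=350
# note: np.log on a torch.Tensor may return a NumPy array; torch.log(x + 1e-14) would keep it a tensor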


# Custom Dataset Class
class CustomAudioDataset(Dataset):
def __init__(self, json_path, transform=None, log_ex=True, valid=False):
print(f'Loading json data from {json_path}')
with open(json_path, 'r') as f:
self.data = json.load(f)
# print(self.data)
self.text_process = TextTransform() # Initialize TextProcess for text processing
self.log_ex = log_ex

if valid:
self.audio_transforms = torch.nn.Sequential(
LogMelSpec()
)
else:
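            # train-time augmentation: SpecAugment-style frequency and time masking on the log-mel features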
self.audio_transforms = torch.nn.Sequential(
LogMelSpec(),
transforms.FrequencyMasking(freq_mask_param=30),
transforms.TimeMasking(time_mask_param=70)
)


def __len__(self):
return len(self.data)

def __getitem__(self, idx):
item = self.data[idx]
file_path = item['key']

try:
            waveform, _ = torchaudio.load(file_path)  # load the audio clip
            utterance = item['text'].lower()          # lowercase transcript for this clip
# print(waveform, sample_rate)
# print('Sentences:', utterance)
label = self.text_process.text_to_int(utterance)
spectrogram = self.audio_transforms(waveform) # (channel, feature, time)
spec_len = spectrogram.shape[-1] // 2
label_len = len(label)

# print(f'SpecShape: {spectrogram.shape} \t shape[-1]: {spectrogram.shape[-1]}')
# print(f'Speclen: {spec_len} \t Label_len: {label_len}')

            if spec_len < label_len:
                raise Exception('spectrogram length is smaller than label length')
            if spectrogram.shape[0] > 1:
                raise Exception('dual channel, skipping audio file %s' % file_path)
            if spectrogram.shape[2] > 1650*3:
                raise Exception('spectrogram too big, size %s' % spectrogram.shape[2])
            if label_len == 0:
                raise Exception('label length is zero... skipping %s' % file_path)

# print(f'{idx}. {utterance}')
return spectrogram, label, spec_len, label_len

except Exception as e:
if self.log_ex:
print(str(e), file_path)
return self.__getitem__(idx - 1 if idx != 0 else idx + 1)

def describe(self):
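        # NOTE: self.data is a plain list loaded from JSON, not a pandas DataFrame, so this call would fail if used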
return self.data.describe()


# Lightning Data Module
class SpeechDataModule(pl.LightningDataModule):
def __init__(self, batch_size, train_json, test_json, num_workers):
super().__init__()
self.batch_size = batch_size
self.train_json = train_json
self.test_json = test_json
self.num_workers = num_workers

def setup(self, stage=None):
self.train_dataset = CustomAudioDataset(self.train_json,
valid=False)
self.test_dataset = CustomAudioDataset(self.test_json,
valid=True)

def data_processing(self, data):
spectrograms = []
labels = []
input_lengths = []
label_lengths = []
for (spectrogram, label, input_length, label_length) in data:
if spectrogram is None:
continue

spectrograms.append(spectrogram.squeeze(0).transpose(0, 1))
# print(len(spectrograms))
# print(f'Label Check: {label}')
labels.append(torch.Tensor(label))
input_lengths.append(spectrogram.shape[-1] // 2)
label_lengths.append(len(label))
# Print the shapes of spectrograms before padding
# for spec in spectrograms:
# print("Spec before padding:", spec.shape)

# NOTE: https://www.geeksforgeeks.org/how-do-you-handle-sequence-padding-and-packing-in-pytorch-for-rnns/
spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
# print('Padded Spectrograms: ', spectrograms.shape)
labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

return spectrograms, labels, input_lengths, label_lengths
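
To make the padding step concrete, a toy sketch of the tensor shapes involved (the lengths 50 and 46 below are arbitrary assumptions):

a = torch.randn(50, 128)  # one spectrogram as (time, n_mels) after squeeze/transpose
b = torch.randn(46, 128)
padded = nn.utils.rnn.pad_sequence([a, b], batch_first=True)  # (2, 50, 128), shorter item zero-padded
padded = padded.unsqueeze(1).transpose(2, 3)                  # (2, 1, 128, 50): (batch, channel, n_mels, time)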


def train_dataloader(self):
return DataLoader(self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
collate_fn=lambda x: self.data_processing(x),
num_workers=self.num_workers,
pin_memory=True) # Optimizes data-transfer speed for CUDA

def val_dataloader(self):
return DataLoader(self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=lambda x: self.data_processing(x),
num_workers=self.num_workers,
pin_memory=True)
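
A minimal sketch of wiring this module into training (batch size, paths, and worker count are placeholders, not values from the commit):

data_module = SpeechDataModule(batch_size=32, train_json='train.json', test_json='test.json', num_workers=4)
# trainer.fit(model, data_module)  # trainer and model are built in train.py
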
import csv
import sox
from tqdm import tqdm
from multiprocessing.pool import ThreadPool
from pathlib import Path

# Function to clean text by removing specified characters
def clean_text(text):
    characters_to_remove = ':!,"‘’;—?'
translator = str.maketrans('', '', characters_to_remove)
return text.translate(translator)
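
A quick illustration of clean_text (the sample sentence is made up):

print(clean_text('Wait; really: yes, "quoted" text!'))  # -> Wait really yes quoted text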

def process_file(row, clips_directory, directory, output_format):
file_name = row['path'] # Original file location
clips_name = file_name.rpartition('.')[0] + '.' + output_format
text = clean_text(row['sentence']) # Clean the sentence
audio_path = os.path.join(directory, 'clips', file_name)
output_audio_path = os.path.join(clips_directory, clips_name)

    # Resample to 16 kHz and convert MP3 to the target output format using SoX
    tfm = sox.Transformer()
    tfm.rate(samplerate=16000)
    tfm.build(input_filepath=audio_path, output_filepath=output_audio_path)

return {'key': clips_directory + '/' + clips_name, 'text': text}

def main(args):
data = [] # Empty list to store clips and sentences
directory = args.file_path.rpartition('/')[0]
percent = args.percent

# Create a 'clips' directory inside defined save_json_path
clips_directory = os.path.abspath(os.path.join(args.save_json_path, 'clips'))

if not os.path.exists(clips_directory):
os.makedirs(clips_directory)

with open(args.file_path, encoding="utf-8") as f:
length = sum(1 for _ in f) - 1

with open(args.file_path, newline='', encoding="utf-8") as csv_file:
reader = csv.DictReader(csv_file, delimiter='\t')
data_to_process = [(row, clips_directory, directory, args.output_format) for row in reader]

if args.convert:
print(f"{length} files found. Converting MP3 to {args.output_format.upper()} using {args.num_workers} workers.")
with ThreadPool(args.num_workers) as pool:
data = list(tqdm(pool.imap(lambda x: process_file(*x), data_to_process), total=length))
else:
for row in data_to_process:
file_name = row[0]['path']
clips_name = file_name.rpartition('.')[0] + '.' + args.output_format
text = clean_text(row[0]['sentence'])
data.append({'key': clips_directory + '/' + clips_name, 'text': text})

# Splitting data into train and test set and saving into JSON file
random.shuffle(data)
print("Creating train and test JSON sets")

train_data = data[:int(length * (1 - percent / 100))]
test_data = data[int(length * (1 - percent / 100)):]

with open(os.path.join(args.save_json_path, 'train.json'), 'w', encoding='utf-8') as f:
json.dump(train_data, f, ensure_ascii=False, indent=4)

with open(os.path.join(args.save_json_path, 'test.json'), 'w', encoding='utf-8') as f:
json.dump(test_data, f, ensure_ascii=False, indent=4)

print("Done!")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description=
"""
Utility script to convert CommonVoice MP3 to FLAC and
split train and test JSON files for training ASR model.
"""
)
parser.add_argument('--file_path', type=str, default=None, required=True,
help='path to one of the .tsv files found in cv-corpus')
parser.add_argument('--save_json_path', type=str, default=None, required=True,
help='path to the dir where the json files are supposed to be saved')
parser.add_argument('--percent', type=int, default=10, required=False,
help='percent of clips put into test.json instead of train.json')
parser.add_argument('--convert', default=True, action='store_true',
help='indicates that the script should convert mp3 to flac')
parser.add_argument('--not-convert', dest='convert', action='store_false',
help='indicates that the script should not convert mp3 to flac')
parser.add_argument('-w','--num_workers', type=int, default=2,
help='number of worker threads for processing')
parser.add_argument('--output_format', type=str, default='flac',
help='output audio format (flac or wav)')

args = parser.parse_args()
main(args)
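
For reference, a typical invocation might look like the following (paths are placeholders):

python dataset.py --file_path /path/to/cv-corpus/en/validated.tsv --save_json_path /path/to/output --percent 10 -w 4
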
22 changes: 16 additions & 6 deletions src/Automatic_Speech_Recognition/neuralnet/train.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import CometLogger

# Load API
@@ -25,6 +25,7 @@ def __init__(self, model, args):
self.args = args

# Metrics
self.losses = []
self.loss_fn = nn.CTCLoss(blank=28, zero_infinity=True)
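        # blank=28 presumably matches the character map in utils.TextTransform (29 symbols, blank last)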

def forward(self, x, hidden):
@@ -58,8 +59,9 @@ def training_step(self, batch, batch_idx):

def validation_step(self, batch, batch_idx):
loss, y_pred, labels, label_lengths = self._common_step(batch, batch_idx)
val_cer, val_wer = [], []
self.losses.append(loss)

val_cer, val_wer = [], []
decoded_preds, decoded_targets = GreedyDecoder(y_pred.transpose(0, 1), labels, label_lengths)

# Log predictions
@@ -78,8 +80,15 @@ def validation_step(self, batch, batch_idx):
'val_cer': avg_cer,
'val_wer': avg_wer,
        }, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=labels.size(0))
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_idx)
return loss
return {'val_loss': loss}


def on_validation_epoch_end(self):
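        # Lightning 2.x removed the validation_epoch_end hook, so per-step losses are accumulated in self.losses and averaged here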
avg_loss = torch.stack(self.losses).mean()

self.log('val_loss', avg_loss.item(), on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.losses.clear()  # reset stored losses for the next epoch


def predict_step(self, batch, batch_idx):
pass
@@ -122,7 +131,8 @@ def main(args):
precision=args.precision,
val_check_interval=args.steps,
gradient_clip_val=1.0,
callbacks=[EarlyStopping(monitor="val_loss"),
callbacks=[LearningRateMonitor(logging_interval='epoch'),
EarlyStopping(monitor="val_loss"),
checkpoint_callback
],
logger=comet_logger
@@ -157,4 +167,4 @@ def main(args):
parser.add_argument('--checkpoint_path', default=None, type=str, help='path to a checkpoint file to resume training')

args = parser.parse_args()
main(args)
main(args)
