Skip to content

Commit

Permalink
🚀 [Update] Dataset and Trainer Changes
Browse files Browse the repository at this point in the history
Dataset and Trainer script changes to handle librispeech datasets. Fixing Freeze Script and checkpoint loading
  • Loading branch information
LuluW8071 committed Sep 16, 2024
1 parent e0b23c9 commit 76df8ce
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 110 deletions.
Original file line number Diff line number Diff line change
@@ -1,46 +1,40 @@
import pytorch_lightning as pl
import json
import torchaudio
import torch
import torch.nn as nn
import torchaudio.transforms as transforms
import numpy as np

from torch.utils.data import DataLoader, Dataset
from utils import TextTransform # Comment this for engine inference

from utils import TextTransform # Comment this for ASR engine inference

class LogMelSpec(nn.Module):
def __init__(self, sample_rate=16000, n_mels=80, hop_length=160):
def __init__(self, sample_rate=16000,
hop_length=160, n_mels=80):
super(LogMelSpec, self).__init__()
self.transform = transforms.MelSpectrogram(sample_rate=sample_rate,
n_mels=n_mels,
hop_length=hop_length)
self.transform = transforms.MelSpectrogram(
sample_rate=sample_rate,
n_mels=n_mels,
hop_length=hop_length
)

def forward(self, x):
x = self.transform(x) # mel spectrogram
x = np.log(x + 1e-14) # logarithmic, add small value to avoid inf
x = self.transform(x)
x = torch.log(x + 1e-14) # logarithmic, add small value to avoid inf
return x

def get_featurizer(sample_rate=16000, n_mels=80, hop_length=160):
return LogMelSpec(sample_rate=sample_rate,
n_mels=n_mels,
hop_length=hop_length)

# Custom Dataset Class
def get_featurizer(sample_rate, n_feats=80, hop_length=160):
return LogMelSpec(sample_rate=sample_rate, n_mels=n_feats, hop_length=hop_length)


class CustomAudioDataset(Dataset):
def __init__(self, json_path, transform=None, log_ex=True, valid=False):
print(f'Loading json data from {json_path}')
with open(json_path, 'r') as f:
self.data = json.load(f)
# print(self.data)
self.text_process = TextTransform() # Initialize TextProcess for text processing
def __init__(self, dataset, transform=None, log_ex=True, valid=False):
self.dataset = dataset
self.text_process = TextTransform() # Initialize TextProcess for text processing
self.log_ex = log_ex

if valid:
self.audio_transforms = torch.nn.Sequential(
LogMelSpec()
)
self.audio_transforms = nn.Sequential(LogMelSpec())
else:
time_masks = [torchaudio.transforms.TimeMasking(time_mask_param=15, p=0.05) for _ in range(10)]
self.audio_transforms = nn.Sequential(
Expand All @@ -49,61 +43,61 @@ def __init__(self, json_path, transform=None, log_ex=True, valid=False):
*time_masks,
)


def __len__(self):
return len(self.data)
return len(self.dataset)

def __getitem__(self, idx):
item = self.data[idx]
file_path = item['key']

try:
waveform, _ = torchaudio.load(file_path) # Point to location of audio data
utterance = item['text'].lower() # Point to sentence of audio data
waveform, sample_rate, utterance, _, _, _ = self.dataset[idx]
utterance = utterance.lower()
label = self.text_process.text_to_int(utterance)
spectrogram = self.audio_transforms(waveform) # (channel, feature, time)
spec_len = spectrogram.shape[-1]

# Apply audio transformations
spectrogram = self.audio_transforms(waveform) # (channel, feature, time)

spec_len = spectrogram.shape[-1] // 2
label_len = len(label)

# print(f'SpecShape: {spectrogram.shape} \t shape[-1]: {spectrogram.shape[-1]}')
# print(f'Speclen: {spec_len} \t Label_len: {label_len}')

if spec_len < label_len:
raise Exception('spectrogram len is bigger then label len')
if spectrogram.shape[0] > 1:
raise Exception('dual channel, skipping audio file %s' %file_path)
if spectrogram.shape[2] > 1650*3:
raise Exception('spectrogram to big. size %s' %spectrogram.shape[2])
if label_len == 0:
raise Exception('label len is zero... skipping %s' %file_path)

# Check if spectrogram or label length is valid
if spec_len < label_len or spectrogram.shape[0] > 1 or label_len == 0:
raise ValueError('Invalid spectrogram or label length.')

return spectrogram, label, spec_len, label_len

except FileNotFoundError as fnf_error:
if self.log_ex:
pass

# Skip the file and move to the next available sample
return self.__getitem__(idx - 1 if idx != 0 else idx + 1)

except Exception as e:
# Handle any other exceptions and retry with neighboring samples
if self.log_ex:
print(str(e), file_path)
print(str(e))
return self.__getitem__(idx - 1 if idx != 0 else idx + 1)

def describe(self):
return self.data.describe()


# Lightning Data Module
class SpeechDataModule(pl.LightningDataModule):
def __init__(self, batch_size, train_json, test_json, num_workers):
def __init__(self, batch_size, train_url, test_url, num_workers):
super().__init__()
self.batch_size = batch_size
self.train_json = train_json
self.test_json = test_json
self.train_url = train_url
self.test_url = test_url
self.num_workers = num_workers
self.text_process = TextTransform() # Initialize TextProcess for text processing
self.text_process = TextTransform()

def setup(self, stage=None):
self.train_dataset = CustomAudioDataset(self.train_json,
valid=False)
self.test_dataset = CustomAudioDataset(self.test_json,
valid=True)

# Load multiple training and test URLs
train_dataset = [torchaudio.datasets.LIBRISPEECH("./data", url=url, download=True) for url in self.train_url]
test_dataset = [torchaudio.datasets.LIBRISPEECH("./data", url=url, download=True) for url in self.test_url]

# Concatenate multiple datasets into one
combined_train_dataset = torch.utils.data.ConcatDataset(train_dataset)
combined_test_dataset = torch.utils.data.ConcatDataset(test_dataset)

self.train_dataset = CustomAudioDataset(combined_train_dataset, valid=False)
self.test_dataset = CustomAudioDataset(combined_test_dataset, valid=True)

def data_processing(self, data):
spectrograms, labels, references, input_lengths, label_lengths = [], [], [], [], []
for (spectrogram, label, input_length, label_length) in data:
Expand All @@ -114,6 +108,7 @@ def data_processing(self, data):
input_lengths.append(((spectrogram.shape[-1] - 1) // 2 - 1) // 2)
label_lengths.append(label_length)
references.append(self.text_process.int_to_text(label)) # Convert label back to text

# Pad the spectrograms to have the same width (time dimension)
spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
Expand All @@ -127,20 +122,19 @@ def data_processing(self, data):
mask[i, :, :l] = 0

return spectrograms, labels, input_lengths, label_lengths, references, mask.bool()



def train_dataloader(self):
return DataLoader(self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
collate_fn=lambda x: self.data_processing(x),
num_workers=self.num_workers,
pin_memory=True) # Optimizes data-transfer speed
return DataLoader(self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
collate_fn=self.data_processing,
num_workers=self.num_workers,
pin_memory=True)

def val_dataloader(self):
return DataLoader(self.test_dataset,
batch_size=self.batch_size,
return DataLoader(self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=lambda x: self.data_processing(x),
num_workers=self.num_workers,
pin_memory=True)
collate_fn=self.data_processing,
num_workers=self.num_workers,
pin_memory=True)
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def forward(self, x):

def trace(model):
model.eval()
x = torch.rand(1, 128, 80) # (Batch_size, seq_length, input_feat)
x = torch.rand(1, 300, 80) # (Batch_size, seq_length, input_feat)
traced = torch.jit.trace(model, x)
return traced

Expand Down
Loading

0 comments on commit 76df8ce

Please sign in to comment.