-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🚀 [Update] DataLoader Changes, Distributed gpu support & Full decode …
…logs
- Loading branch information
Showing
4 changed files
with
197 additions
and
151 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,98 +1,153 @@ | ||
import os | ||
import argparse | ||
import random | ||
import pytorch_lightning as pl | ||
import json | ||
import csv | ||
import sox | ||
from tqdm import tqdm | ||
from multiprocessing.pool import ThreadPool | ||
from pathlib import Path | ||
|
||
# Function to clean text by removing specified characters | ||
def clean_text(text):
    """Return ``text`` with the punctuation marks listed below stripped out.

    All other characters, whitespace included, are left exactly as they are.
    """
    unwanted = ':!,"‘’;—?'
    # Mapping a code point to None makes str.translate drop that character.
    deletion_table = dict.fromkeys(map(ord, unwanted))
    return text.translate(deletion_table)
|
||
def process_file(row, clips_directory, directory, output_format):
    """Re-encode one clip with SoX at 16 kHz and return its manifest entry.

    Args:
        row: Mapping for one TSV row; its 'path' and 'sentence' entries are read.
        clips_directory: Directory that receives the converted audio file.
        directory: Directory that contains the source 'clips' folder.
        output_format: Extension (without the dot) for the converted file.

    Returns:
        dict with 'key' (location of the converted file) and 'text'
        (the sentence after clean_text).
    """
    file_name = row['path']  # Original file location
    # os.path.splitext only strips a real extension of the last path
    # component. rpartition('.') returned '' for names without a dot (so the
    # output collapsed to '.<format>') and cut at dots inside directory names.
    clips_name = os.path.splitext(file_name)[0] + '.' + output_format
    text = clean_text(row['sentence'])  # Clean the sentence
    audio_path = os.path.join(directory, 'clips', file_name)
    output_audio_path = os.path.join(clips_directory, clips_name)

    # Convert the source audio to the requested format using Sox
    tfm = sox.Transformer()
    tfm.rate(samplerate=16000)
    tfm.build(input_filepath=audio_path, output_filepath=output_audio_path)

    return {'key': clips_directory + '/' + clips_name, 'text': text}
|
||
def main(args):
    """Build train/test JSON manifests from a TSV file, optionally converting audio.

    Reads the tab-separated file at ``args.file_path``, creates a 'clips'
    directory under ``args.save_json_path``, optionally converts every clip
    with ``process_file`` on a thread pool, shuffles the resulting entries and
    writes 'train.json' and 'test.json' into ``args.save_json_path``.

    Args:
        args: Namespace providing ``file_path``, ``save_json_path``,
            ``percent``, ``convert``, ``num_workers`` and ``output_format``.

    Raises:
        ValueError: If ``args.percent`` is outside the 0-100 range.
    """
    percent = args.percent
    if not 0 <= percent <= 100:
        raise ValueError(f"percent must be between 0 and 100, got {percent}")

    # dirname handles root-level and platform-specific paths, which
    # rpartition('/') did not.
    directory = os.path.dirname(args.file_path)

    # Create a 'clips' directory inside defined save_json_path
    clips_directory = os.path.abspath(os.path.join(args.save_json_path, 'clips'))
    os.makedirs(clips_directory, exist_ok=True)

    with open(args.file_path, newline='', encoding="utf-8") as csv_file:
        reader = csv.DictReader(csv_file, delimiter='\t')
        data_to_process = [(row, clips_directory, directory, args.output_format) for row in reader]

    # Count the parsed rows rather than raw lines: a line count disagrees with
    # the data whenever the file has trailing blank lines, which skewed the
    # split below.
    length = len(data_to_process)

    if args.convert:
        print(f"{length} files found. Converting MP3 to {args.output_format.upper()} using {args.num_workers} workers.")
        with ThreadPool(args.num_workers) as pool:
            data = list(tqdm(pool.imap(lambda x: process_file(*x), data_to_process), total=length))
    else:
        # No conversion requested: only describe where the files are expected.
        data = []
        for row, *_ in data_to_process:
            clips_name = os.path.splitext(row['path'])[0] + '.' + args.output_format
            data.append({'key': clips_directory + '/' + clips_name,
                         'text': clean_text(row['sentence'])})

    # Splitting data into train and test set and saving into JSON file
    random.shuffle(data)
    print("Creating train and test JSON sets")

    # Integer arithmetic avoids float truncation errors such as
    # int(10 * (1 - 90 / 100)) == 0.
    split_index = int(length * (100 - percent) // 100)
    train_data = data[:split_index]
    test_data = data[split_index:]

    with open(os.path.join(args.save_json_path, 'train.json'), 'w', encoding='utf-8') as f:
        json.dump(train_data, f, ensure_ascii=False, indent=4)

    with open(os.path.join(args.save_json_path, 'test.json'), 'w', encoding='utf-8') as f:
        json.dump(test_data, f, ensure_ascii=False, indent=4)

    print("Done!")
|
||
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description=
        """
        Utility script to convert CommonVoice MP3 to FLAC and
        split train and test JSON files for training ASR model.
        """
    )

    # Each entry is (option flags, keyword arguments for add_argument);
    # the order here is the order shown in --help.
    argument_specs = [
        (('--file_path',),
         dict(type=str, default=None, required=True,
              help='path to one of the .tsv files found in cv-corpus')),
        (('--save_json_path',),
         dict(type=str, default=None, required=True,
              help='path to the dir where the json files are supposed to be saved')),
        (('--percent',),
         dict(type=int, default=10, required=False,
              help='percent of clips put into test.json instead of train.json')),
        (('--convert',),
         dict(default=True, action='store_true',
              help='indicates that the script should convert mp3 to flac')),
        # Writes to the same destination as --convert, switching it off.
        (('--not-convert',),
         dict(dest='convert', action='store_false',
              help='indicates that the script should not convert mp3 to flac')),
        (('-w', '--num_workers'),
         dict(type=int, default=2,
              help='number of worker threads for processing')),
        (('--output_format',),
         dict(type=str, default='flac',
              help='output audio format (flac or wav)')),
    ]
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    main(args)
import torchaudio | ||
import torch | ||
import torch.nn as nn | ||
import torchaudio.transforms as transforms | ||
import numpy as np | ||
|
||
from torch.utils.data import DataLoader, Dataset | ||
from utils import TextTransform # Comment this for engine inference | ||
|
||
|
||
class LogMelSpec(nn.Module):
    """Module that computes a mel spectrogram and takes its natural logarithm.

    The constructor arguments are forwarded unchanged to
    ``torchaudio.transforms.MelSpectrogram``.
    """

    def __init__(self, sample_rate=16000, n_mels=128, win_length=400, hop_length=160, n_fft=1024):
        super().__init__()
        self.transform = transforms.MelSpectrogram(sample_rate=sample_rate,
                                                   n_mels=n_mels,
                                                   win_length=win_length,
                                                   hop_length=hop_length,
                                                   n_fft=n_fft)

    def forward(self, x):
        """Return ``log(mel_spectrogram(x) + 1e-14)`` as a torch tensor."""
        x = self.transform(x)  # mel spectrogram
        # torch.log keeps the computation inside torch, so it works on any
        # device and with autograd; np.log fails for CUDA tensors and for
        # tensors that require grad. The small offset avoids log(0) = -inf.
        return torch.log(x + 1e-14)
|
||
def get_featurizer(sample_rate=16000, n_mels=128, win_length=400, hop_length=160, n_fft=1024):
    """Create a LogMelSpec configured with the given spectrogram settings."""
    settings = {
        'sample_rate': sample_rate,
        'n_mels': n_mels,
        'win_length': win_length,
        'hop_length': hop_length,
        'n_fft': n_fft,
    }
    return LogMelSpec(**settings)
|
||
# Custom Dataset Class | ||
class CustomAudioDataset(Dataset):
    """Dataset yielding (spectrogram, label, spec_len, label_len) per manifest entry.

    The manifest is a JSON file whose entries provide 'key' (audio file
    location) and 'text' (transcript).
    """

    def __init__(self, json_path, transform=None, log_ex=True, valid=False):
        """Load the manifest and set up the audio transforms.

        Args:
            json_path: Path of the JSON manifest to load.
            transform: Accepted for interface compatibility; not used here.
            log_ex: When True, print the error for every entry that is skipped.
            valid: When True only LogMelSpec is applied; otherwise frequency
                and time masking are applied after it.
        """
        print(f'Loading json data from {json_path}')
        # Explicit UTF-8 so non-ASCII transcripts load regardless of the
        # platform's default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.text_process = TextTransform()  # Initialize TextProcess for text processing
        self.log_ex = log_ex

        if valid:
            self.audio_transforms = torch.nn.Sequential(
                LogMelSpec()
            )
        else:
            self.audio_transforms = torch.nn.Sequential(
                LogMelSpec(),
                transforms.FrequencyMasking(freq_mask_param=30),
                transforms.TimeMasking(time_mask_param=70)
            )

    def __len__(self):
        return len(self.data)

    def _load_sample(self, item, file_path):
        """Load and validate one entry; raise Exception if it must be skipped."""
        waveform, _ = torchaudio.load(file_path)  # Point to location of audio data
        utterance = item['text'].lower()  # Point to sentence of audio data
        label = self.text_process.text_to_int(utterance)
        spectrogram = self.audio_transforms(waveform)  # (channel, feature, time)
        # NOTE(review): the halving presumably mirrors a 2x time reduction in
        # the model - confirm against the model definition.
        spec_len = spectrogram.shape[-1] // 2
        label_len = len(label)

        if spec_len < label_len:
            raise Exception('spectrogram len is smaller than label len')
        if spectrogram.shape[0] > 1:
            raise Exception('dual channel, skipping audio file %s' % file_path)
        if spectrogram.shape[2] > 1650 * 3:
            raise Exception('spectrogram to big. size %s' % spectrogram.shape[2])
        if label_len == 0:
            raise Exception('label len is zero... skipping %s' % file_path)

        return spectrogram, label, spec_len, label_len

    def __getitem__(self, idx):
        """Return the sample at ``idx``, or the nearest usable one if it fails.

        Fallback order is idx, idx-1, ..., 0, then idx+1, ..., len-1. Each
        entry is tried at most once, so the search always terminates; the
        previous recursive fallback bounced between indices 0 and 1 until
        RecursionError when both were unusable.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
            RuntimeError: If no entry in the dataset can be loaded.
        """
        total = len(self.data)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError('dataset index out of range')

        for candidates in (range(idx, -1, -1), range(idx + 1, total)):
            for candidate in candidates:
                item = self.data[candidate]
                file_path = item['key']
                try:
                    return self._load_sample(item, file_path)
                except Exception as e:
                    # Deliberate best effort: skip the bad entry, try the next.
                    if self.log_ex:
                        print(str(e), file_path)

        raise RuntimeError('no loadable sample found among %d entries' % total)

    def describe(self):
        # NOTE(review): self.data comes from json.load, whose plain list/dict
        # results have no describe(); this raises AttributeError unless
        # self.data is replaced with an object that provides it.
        return self.data.describe()
|
||
|
||
# Lightning Data Module | ||
class SpeechDataModule(pl.LightningDataModule):
    """Lightning data module providing train and validation loaders.

    Both loaders read CustomAudioDataset instances built from JSON manifests
    and batch them with ``data_processing``.
    """

    def __init__(self, batch_size, train_json, test_json, num_workers):
        """Store the loader configuration.

        Args:
            batch_size: Batch size used by both loaders.
            train_json: Manifest path for the training dataset.
            test_json: Manifest path for the validation dataset.
            num_workers: Number of DataLoader worker processes.
        """
        super().__init__()
        self.batch_size = batch_size
        self.train_json = train_json
        self.test_json = test_json
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Create the datasets; only the training set gets masking (valid=False)."""
        self.train_dataset = CustomAudioDataset(self.train_json,
                                                valid=False)
        self.test_dataset = CustomAudioDataset(self.test_json,
                                               valid=True)

    def data_processing(self, data):
        """Collate dataset samples into padded batch tensors.

        Args:
            data: Iterable of (spectrogram, label, input_length, label_length)
                tuples; entries whose spectrogram is None are dropped.

        Returns:
            Tuple of (padded spectrograms, padded labels, input_lengths,
            label_lengths); the two length values are plain Python lists.
        """
        spectrograms = []
        labels = []
        input_lengths = []
        label_lengths = []
        for (spectrogram, label, _input_length, _label_length) in data:
            if spectrogram is None:
                continue

            # Assuming (1, feature, time) input, this yields (time, feature)
            # so pad_sequence pads along time.
            spectrograms.append(spectrogram.squeeze(0).transpose(0, 1))
            labels.append(torch.Tensor(label))
            # Lengths are recomputed from the tensors that are batched.
            input_lengths.append(spectrogram.shape[-1] // 2)
            label_lengths.append(len(label))

        # NOTE: https://www.geeksforgeeks.org/how-do-you-handle-sequence-padding-and-packing-in-pytorch-for-rnns/
        # Pad along time, then restore a (batch, 1, feature, time) layout.
        spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
        labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

        return spectrograms, labels, input_lengths, label_lengths

    def train_dataloader(self):
        """Return the shuffled training loader."""
        # The bound method is passed directly: a lambda cannot be pickled,
        # which breaks worker start-up wherever workers are spawned.
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          collate_fn=self.data_processing,
                          num_workers=self.num_workers,
                          pin_memory=True)  # Optimizes data-transfer speed for CUDA

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          shuffle=False,
                          collate_fn=self.data_processing,
                          num_workers=self.num_workers,
                          pin_memory=True)
Oops, something went wrong.