diff --git a/README.md b/README.md
index 20aa9514..7968a3c9 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,6 @@
 # deepspeech.pytorch
 
-* Add tests for dataloading
-* Fix validation. Assume problems with SequenceWise module and using batch norm in 3d mode.
-* Support LibriSpeech via multi-processed scripts
-
-Implementation of [Baidu Warp-CTC](https://github.com/baidu-research/warp-ctc) using pytorch.
+Implementation of DeepSpeech2 using [Baidu Warp-CTC](https://github.com/baidu-research/warp-ctc).
 
 Creates a network based on the [DeepSpeech2](http://arxiv.org/pdf/1512.02595v1.pdf) architecture, trained with the CTC activation function.
 
 # Installation
@@ -39,16 +35,29 @@ pip install -r requirements.txt
 Currently only supports an4. To download and setup the an4 dataset run below command in the root folder of the repo:
 
 ```
-python get_an4.py
-python create_dataset_manifest.py --root_path dataset/
+cd data; python an4.py
 ```
 
 This will generate csv manifests files used to load the data for training.
+LibriSpeech formatting is in the works.
+
+### Custom Dataset
+
+To create a custom dataset you must create a CSV file containing the locations of the training data. This has to be in the format of:
+
+```
+/path/to/audio.wav,/path/to/text.txt
+/path/to/audio2.wav,/path/to/text2.txt
+...
+```
+
+The first path is to the audio file, and the second path is to a text file containing the transcript on one line. This can then be used as stated below.
+
 ## Training
 
 ```
-python main.py --train_manifest train_manifest.csv --test_manifest test_manifest.csv
+python train.py --train_manifest data/train_manifest.csv --val_manifest data/val_manifest.csv
 ```
diff --git a/data/__init__.py b/data/__init__.py
index e69de29b..194b058e 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -0,0 +1 @@
+from . 
import data_loader
diff --git a/data/get_an4.py b/data/an4.py
similarity index 80%
rename from data/get_an4.py
rename to data/an4.py
index db3a8121..3795d5c4 100644
--- a/data/get_an4.py
+++ b/data/an4.py
@@ -10,9 +10,10 @@
 parser = argparse.ArgumentParser(description='Processes and downloads an4.')
 parser.add_argument('--an4_path', default='an4_dataset/', help='Path to save dataset')
 parser.add_argument('--sample_rate', default=16000, type=int, help='Sample rate')
+args = parser.parse_args()
 
 
-def format_data(data_tag, name, wav_folder):
+def _format_data(root_path, data_tag, name, wav_folder):
     data_path = args.an4_path + data_tag + '/' + name + '/'
     new_transcript_path = data_path + '/txt/'
     new_wav_path = data_path + '/wav/'
@@ -25,11 +26,11 @@ def format_data(data_tag, name, wav_folder):
     transcripts = root_path + 'etc/an4_%s.transcription' % data_tag
     train_path = wav_path + wav_folder
 
-    convert_audio_to_wav(train_path)
-    format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path)
+    _convert_audio_to_wav(train_path)
+    _format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path)
 
 
-def convert_audio_to_wav(train_path):
+def _convert_audio_to_wav(train_path):
     with os.popen('find %s -type f -name "*.raw"' % train_path) as pipe:
         for line in pipe:
             raw_path = line.strip()
@@ -39,7 +40,7 @@
             os.system(cmd)
 
 
-def format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path):
+def _format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path):
     with open(file_ids, 'r') as f:
         with open(transcripts, 'r') as t:
             paths = f.readlines()
@@ -47,7 +48,7 @@ def format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_p
         for x in range(len(paths)):
            path = wav_path + paths[x].strip() + '.wav'
            filename = path.split('/')[-1]
-            extracted_transcript = process_transcript(transcripts, x)
+            extracted_transcript = _process_transcript(transcripts, x)
            current_path = os.path.abspath(path)
            new_path = new_wav_path + filename
            text_path = new_transcript_path + filename.replace('.wav', '.txt')
@@ -56,28 +57,26 @@ def format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_p
             os.rename(current_path, new_path)
 
 
-def process_transcript(transcripts, x):
+def _process_transcript(transcripts, x):
     extracted_transcript = transcripts[x].split('(')[0].strip("").split('<')[0].strip().upper()
     return extracted_transcript
 
 
 def main():
-    global args, root_path
-    args = parser.parse_args()
     root_path = 'an4/'
     name = 'an4'
     subprocess.call(['wget http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz'], shell=True)
     subprocess.call(['tar -xzvf an4_raw.bigendian.tar.gz'], stdout=open(os.devnull, 'wb'), shell=True)
     os.makedirs(args.an4_path)
-    format_data('train', name, 'an4_clstk')
-    format_data('test', name, 'an4test_clstk')
+    _format_data(root_path, 'train', name, 'an4_clstk')
+    _format_data(root_path, 'test', name, 'an4test_clstk')
     shutil.rmtree(root_path)
     os.remove('an4_raw.bigendian.tar.gz')
     train_path = args.an4_path + '/train/'
     test_path = args.an4_path + '/test/'
     print ('Creating manifests...')
     create_manifest(train_path, 'train')
-    create_manifest(test_path, 'test')
+    create_manifest(test_path, 'val')
 
 
 if __name__ == '__main__':
diff --git a/data/data_loader.py b/data/data_loader.py
index 4c1f0c03..dcea23e4 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -1,45 +1,51 @@
+import librosa
+import numpy as np
 import scipy.signal
 import torch
 from torch.utils.data import 
DataLoader from torch.utils.data import Dataset -import librosa -import numpy as np +windows = {'hamming': scipy.signal.hamming, 'hann': scipy.signal.hann, 'blackman': scipy.signal.blackman, + 'bartlett': scipy.signal.bartlett} -class AudioDataset(Dataset): - def __init__(self, conf): - super(AudioDataset, self).__init__() - with open(conf['manifest_filename']) as f: - ids = f.readlines() - ids = [x.strip().split(',') for x in ids] - self.ids = ids - self.size = len(ids) - self.conf = conf - self.audio_conf = conf['audio'] - self.alphabet_map = dict([(conf['alphabet'][i], i) for i in range(len(conf['alphabet']))]) - self.normalize = conf.get('normalize', False) - - def __getitem__(self, index): - sample = self.ids[index] - audio_path, transcript_path = sample[0], sample[1] - spect = self.spectrogram(audio_path) - transcript = self.parse_transcript(transcript_path) - return spect, transcript +class AudioParser(object): def parse_transcript(self, transcript_path): - with open(transcript_path, 'r') as transcript_file: - transcript = transcript_file.read().replace('\n', '') - transcript = [self.alphabet_map[x] for x in list(transcript)] - return transcript + """ + :param transcript_path: Path where transcript is stored from the manifest file + :return: Transcript in training/testing format + """ + raise NotImplementedError + + def parse_audio(self, audio_path): + """ + :param audio_path: Path where audio is stored from the manifest file + :return: Audio in training/testing format + """ + raise NotImplementedError + - def spectrogram(self, audio_path): - y, _ = librosa.core.load(audio_path, sr=self.audio_conf['sample_rate']) - n_fft = int(self.audio_conf['sample_rate'] * self.audio_conf['window_size']) +class SpectrogramParser(AudioParser): + def __init__(self, audio_conf, normalize=False): + """ + Parses audio file into spectrogram with optional normalization + :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds + :param normalize(default False): Apply standard mean and deviation normalization to audio tensor + """ + super(SpectrogramParser, self).__init__() + self.window_stride = audio_conf['window_stride'] + self.window_size = audio_conf['window_size'] + self.sample_rate = audio_conf['sample_rate'] + self.window = windows.get(audio_conf['window'], windows['hamming']) + self.normalize = normalize + + def parse_audio(self, audio_path): + y, _ = librosa.core.load(audio_path, sr=self.sample_rate) + n_fft = int(self.sample_rate * self.window_size) win_length = n_fft - hop_length = int(self.audio_conf['sample_rate'] * self.audio_conf['window_stride']) - window = scipy.signal.hamming # TODO if statement to select window based on conf + hop_length = int(self.sample_rate * self.window_stride) # STFT - D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window) + D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=self.window) spect, phase = librosa.magphase(D) # S = log(S+1) spect = np.log1p(spect) @@ -52,11 +58,50 @@ def spectrogram(self, audio_path): return spect + def parse_transcript(self, transcript_path): + raise NotImplementedError + + +class SpectrogramDataset(Dataset, SpectrogramParser): + def __init__(self, audio_conf, manifest_filepath, labels, normalize=False): + """ + Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by + a comma. Each new line is a different sample. 
Example below: + + /path/to/audio.wav,/path/to/audio.txt + ... + + :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds + :param manifest_filepath: Path to manifest csv as describe above + :param labels: String containing all the possible characters to map to + :param normalize: Apply standard mean and deviation normalization to audio tensor + """ + with open(manifest_filepath) as f: + ids = f.readlines() + ids = [x.strip().split(',') for x in ids] + self.ids = ids + self.size = len(ids) + self.labels_map = dict([(labels[i], i) for i in range(len(labels))]) + super(SpectrogramDataset, self).__init__(audio_conf, normalize) + + def __getitem__(self, index): + sample = self.ids[index] + audio_path, transcript_path = sample[0], sample[1] + spect = self.parse_audio(audio_path) + transcript = self.parse_transcript(transcript_path) + return spect, transcript + + def parse_transcript(self, transcript_path): + with open(transcript_path, 'r') as transcript_file: + transcript = transcript_file.read().replace('\n', '') + transcript = [self.labels_map[x] for x in list(transcript)] + return transcript + def __len__(self): return self.size -def collate_fn(batch): +def _collate_fn(batch): def func(p): return p[0].size(1) @@ -83,5 +128,8 @@ def func(p): class AudioDataLoader(DataLoader): def __init__(self, *args, **kwargs): + """ + Creates a data loader for AudioDatasets. + """ super(AudioDataLoader, self).__init__(*args, **kwargs) - self.collate_fn = collate_fn + self.collate_fn = _collate_fn diff --git a/data/utils.py b/data/utils.py index ffc756e2..8ea1bd0e 100644 --- a/data/utils.py +++ b/data/utils.py @@ -1,37 +1,51 @@ -import argparse +from __future__ import print_function + +import fnmatch import io import os import subprocess -parser = argparse.ArgumentParser(description='Creates training and testing manifests') -parser.add_argument('--root_path', default='an4_dataset', help='Path to the dataset') -""" -We need to add progress bars like will did in gen audio. Just copy that code. 
-We also need a call (basically the same in the dataloader, a find that gives us the total number of wav files) -""" +def _update_progress(progress): + print("\rProgress: [{0:50s}] {1:.1f}%".format('#' * int(progress * 50), + progress * 100), end="") def create_manifest(data_path, tag, ordered=True): manifest_path = '%s_manifest.csv' % tag file_paths = [] - with os.popen('find %s -type f -name "*.wav"' % data_path) as pipe: - for file_path in pipe: - file_paths.append(file_path.strip()) + wav_files = [os.path.join(dirpath, f) + for dirpath, dirnames, files in os.walk(data_path) + for f in fnmatch.filter(files, '*.wav')] + size = len(wav_files) + counter = 0 + for file_path in wav_files: + file_paths.append(file_path.strip()) + counter += 1 + _update_progress(counter / float(size)) + print('\n') if ordered: - print("Sorting files by length...") - - def func(element): - output = subprocess.check_output( - ['soxi -D %s' % element.strip()], - shell=True - ) - return float(output) - - file_paths.sort(key=func) + _order_files(file_paths) + counter = 0 with io.FileIO(manifest_path, "w") as file: for wav_path in file_paths: transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt') sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n' file.write(sample) + counter += 1 + _update_progress(counter / float(size)) + print('\n') + + +def _order_files(file_paths): + print("Sorting files by length...") + + def func(element): + output = subprocess.check_output( + ['soxi -D %s' % element.strip()], + shell=True + ) + return float(output) + + file_paths.sort(key=func) diff --git a/decoder.py b/decoder.py index 078c9629..155a5e4b 100644 --- a/decoder.py +++ b/decoder.py @@ -25,15 +25,15 @@ class Decoder(object): helper functions. Subclasses should implement the decode() method. Arguments: - alphabet (string): mapping from integers to characters. + labels (string): mapping from integers to characters. blank_index (int, optional): index for the blank '_' character. Defaults to 0. space_index (int, optional): index for the space ' ' character. Defaults to 28. """ - def __init__(self, alphabet, blank_index=0, space_index=1): - # e.g. alphabet = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ#" - self.alphabet = alphabet - self.int_to_char = dict([(i, c) for (i, c) in enumerate(alphabet)]) + def __init__(self, labels, blank_index=0, space_index=1): + # e.g. labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ#" + self.labels = labels + self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)]) self.blank_index = blank_index self.space_index = space_index @@ -73,7 +73,7 @@ def process_string(self, remove_repetitions, sequence): # skip. 
if remove_repetitions and i != 0 and char == sequence[i - 1]: pass - elif char == self.alphabet[self.space_index]: + elif char == self.labels[self.space_index]: string += ' ' else: string = string + char diff --git a/labels.json b/labels.json new file mode 100644 index 00000000..396d6b68 --- /dev/null +++ b/labels.json @@ -0,0 +1,31 @@ +[ + "_", + "'", + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", + "Q", + "R", + "S", + "T", + "U", + "V", + "W", + "X", + "Y", + "Z", + " " +] \ No newline at end of file diff --git a/model.py b/model.py index d0dd0673..53cbe604 100644 --- a/model.py +++ b/model.py @@ -3,20 +3,24 @@ import torch import torch.nn as nn from torch.autograd import Variable +import math class SequenceWise(nn.Module): - def __init__(self, module, batch_first=False): + def __init__(self, module): + """ + Collapses input of dim T*N*H to (T*N)*H, and applies to a module. + Allows handling of variable sequence lengths and minibatch sizes. + :param module: Module to apply input to. + """ super(SequenceWise, self).__init__() self.module = module - self.batch_first = batch_first def forward(self, x): t, n = x.size(0), x.size(1) x = x.view(t * n, -1) x = self.module(x) - x = x.view(n, t, -1) - x = x.transpose(0, 1).contiguous() + x = x.view(t, n, -1) return x def __repr__(self): @@ -33,7 +37,7 @@ def __init__(self, input_size, hidden_size, bidirectional=False, batch_norm=True self.hidden_size = hidden_size self.batch_norm_activate = batch_norm self.bidirectional = bidirectional - self.batch_norm = nn.BatchNorm1d(input_size) + self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, bidirectional=bidirectional, bias=False) self.num_directions = 2 if bidirectional else 1 @@ -42,29 +46,29 @@ def forward(self, x): c0 = Variable(torch.zeros(self.num_directions, x.size(1), self.hidden_size).type_as(x.data)) h0 = Variable(torch.zeros(self.num_directions, x.size(1), self.hidden_size).type_as(x.data)) if self.batch_norm_activate: - t, n = x.size(0), x.size(1) - x = x.view(n, -1, t) x = self.batch_norm(x) - x = x.transpose(1, 2).transpose(0, 1) - x = x.contiguous() x, _ = self.rnn(x, (c0, h0)) if self.bidirectional: - x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxD*2) -> (TxNxD) by sum + x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) by sum return x class DeepSpeech(nn.Module): def __init__(self, num_classes=29, rnn_hidden_size=400, nb_layers=4, bidirectional=True): super(DeepSpeech, self).__init__() - rnn_input_size = 32 * 41 # TODO this is only for 16khz, work this out for any window_size/stride/sample_rate self.conv = nn.Sequential( nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2)), nn.BatchNorm2d(32), nn.Hardtanh(0, 20, inplace=True), - nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(1, 2)), + nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1)), nn.BatchNorm2d(32), nn.Hardtanh(0, 20, inplace=True) ) + # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1 + rnn_input_size = int(math.floor((16000 * 0.02) / 2) + 1) + rnn_input_size = int(math.floor(rnn_input_size - 41) / 2 + 1) + rnn_input_size = int(math.floor(rnn_input_size - 21) / 2 + 1) + rnn_input_size *= 32 rnns = [] rnn = BatchLSTM(input_size=rnn_input_size, hidden_size=rnn_hidden_size, bidirectional=bidirectional, batch_norm=False) @@ -79,7 +83,7 @@ def __init__(self, 
num_classes=29, rnn_hidden_size=400, nb_layers=4, bidirection nn.Linear(rnn_hidden_size, num_classes, bias=False) ) self.fc = nn.Sequential( - SequenceWise(fully_connected, batch_first=True), + SequenceWise(fully_connected), ) def forward(self, x): @@ -87,7 +91,7 @@ def forward(self, x): sizes = x.size() x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # Collapse feature dimension - x = x.transpose(1, 2).transpose(0, 1).contiguous() # seqLength x batch x features + x = x.transpose(1, 2).transpose(0, 1).contiguous() # TxNxH x = self.rnns(x) diff --git a/predict.py b/predict.py new file mode 100644 index 00000000..e0e1d3b8 --- /dev/null +++ b/predict.py @@ -0,0 +1,44 @@ +import argparse +import json + +import torch +from torch.autograd import Variable + +from data.data_loader import SpectrogramParser +from decoder import ArgMaxDecoder +from model import DeepSpeech + +parser = argparse.ArgumentParser(description='DeepSpeech prediction') +parser.add_argument('--sample_rate', default=16000, type=int, help='Sample rate') +parser.add_argument('--labels_path', default='labels.json', help='Contains all characters for prediction') +parser.add_argument('--model_path', default='models/deepspeech_final.pth.tar', + help='Path to model file created by training') +parser.add_argument('--audio_path', default='audio.wav', + help='Audio file to predict on') +parser.add_argument('--window_size', default=.02, type=float, help='Window size for spectrogram in seconds') +parser.add_argument('--window_stride', default=.01, type=float, help='Window stride for spectrogram in seconds') +parser.add_argument('--window', default='hamming', help='Window type for spectrogram generation') +parser.add_argument('--cuda', default=True, type=bool, help='Use cuda to train model') +args = parser.parse_args() + +if __name__ == '__main__': + package = torch.load(args.model_path) + model = DeepSpeech(rnn_hidden_size=package['hidden_size'], nb_layers=package['hidden_layers'], + num_classes=package['nout']) + if args.cuda: + model = torch.nn.DataParallel(model).cuda() + model.load_state_dict(package['state_dict']) + audio_conf = dict(sample_rate=args.sample_rate, + window_size=args.window_size, + window_stride=args.window_stride, + window=args.window) + with open(args.labels_path) as label_file: + labels = str(''.join(json.load(label_file))) + decoder = ArgMaxDecoder(labels) + parser = SpectrogramParser(audio_conf, normalize=True) + spect = parser.parse_audio(args.audio_path).contiguous() + spect = spect.view(1, 1, spect.size(0), spect.size(1)) + out = model(Variable(spect)) + out = out.transpose(0, 1) # TxNxH + decoded_output = decoder.decode(out.data) + print(decoded_output[0]) diff --git a/test/test.py b/test/test.py index c6ffcdbd..4fac6164 100644 --- a/test/test.py +++ b/test/test.py @@ -12,7 +12,7 @@ def test_decoder(self): [[[0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 1]], [[0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 1]]]) .transpose(0, 1)) # seqLength x batch x outputDim - decoder = ArgMaxDecoder(alphabet="_'ABCDEFGHIJKLMNOPQRSTUVWXYZ ") + decoder = ArgMaxDecoder(labels="_'ABCDEFGHIJKLMNOPQRSTUVWXYZ ") decoded = decoder.decode(input.data, None) expected_decoding = ['BAD', 'D'] self.assertItemsEqual(expected_decoding, decoded) diff --git a/main.py b/train.py similarity index 69% rename from main.py rename to train.py index 3734a561..d7c2ba50 100644 --- a/main.py +++ b/train.py @@ -1,26 +1,30 @@ import argparse +import errno +import json +import os import time import torch from torch.autograd import 
Variable from warpctc_pytorch import CTCLoss +from data.data_loader import AudioDataLoader, SpectrogramDataset from decoder import ArgMaxDecoder from model import DeepSpeech -from data.data_loader import AudioDataLoader, AudioDataset -parser = argparse.ArgumentParser(description='DeepSpeech pytorch params') +parser = argparse.ArgumentParser(description='DeepSpeech training') parser.add_argument('--train_manifest', metavar='DIR', help='path to train manifest csv', default='data/train_manifest.csv') -parser.add_argument('--test_manifest', metavar='DIR', - help='path to test manifest csv', default='data/test_manifest.csv') +parser.add_argument('--val_manifest', metavar='DIR', + help='path to validation manifest csv', default='data/val_manifest.csv') parser.add_argument('--sample_rate', default=16000, type=int, help='Sample rate') parser.add_argument('--batch_size', default=20, type=int, help='Batch size for training') parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading') -parser.add_argument('--frame_length', default=.02, type=float, help='Window size for spectrogram in seconds') -parser.add_argument('--frame_stride', default=.01, type=float, help='Window stride for spectrogram in seconds') +parser.add_argument('--labels_path', default='labels.json', help='Contains all characters for prediction') +parser.add_argument('--window_size', default=.02, type=float, help='Window size for spectrogram in seconds') +parser.add_argument('--window_stride', default=.01, type=float, help='Window stride for spectrogram in seconds') parser.add_argument('--window', default='hamming', help='Window type for spectrogram generation') -parser.add_argument('--hidden_size', default=512, type=int, help='Hidden size of RNNs') +parser.add_argument('--hidden_size', default=400, type=int, help='Hidden size of RNNs') parser.add_argument('--hidden_layers', default=4, type=int, help='Number of RNN layers') parser.add_argument('--epochs', default=70, type=int, help='Number of training epochs') parser.add_argument('--cuda', default=True, type=bool, help='Use cuda to train model') @@ -29,6 +33,10 @@ parser.add_argument('--max_norm', default=400, type=int, help='Norm cutoff to prevent explosion of gradients') parser.add_argument('--learning_anneal', default=1.1, type=float, help='Annealing applied to learning rate every epoch') parser.add_argument('--silent', default=True, type=bool, help='Turn off progress tracking per iteration') +parser.add_argument('--epoch_save', default=False, type=bool, help='Save model every epoch') +parser.add_argument('--save_folder', default='models/', help='Location to save epoch models') +parser.add_argument('--final_model_path', default='models/deepspeech_final.pth.tar', + help='Location to save final model') class AverageMeter(object): @@ -50,35 +58,48 @@ def update(self, val, n=1): self.avg = self.sum / self.count +def checkpoint(model, args, nout, epoch=None): + package = { + 'epoch': epoch if epoch else 'N/A', + 'hidden_size': args.hidden_size, + 'hidden_layers': args.hidden_layers, + 'nout': nout, + 'state_dict': model.state_dict(), + } + return package + + def main(): args = parser.parse_args() - + save_folder = args.save_folder + try: + os.makedirs(save_folder) + except OSError as e: + if e.errno == errno.EEXIST: + print('Directory already exists.') + else: + raise criterion = CTCLoss() - alphabet = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ " - - audio_config = dict(sample_rate=16000, - window_size=0.02, - window_stride=0.01, - window_type='hamming', - ) 
- - train_dataloader_config = dict(type="audio,transcription", - audio=audio_config, - manifest_filename=args.train_manifest, - alphabet=alphabet, - normalize=True) - test_dataloader_config = dict(type="audio,transcription", - audio=audio_config, - manifest_filename=args.test_manifest, - alphabet=alphabet, - normalize=True) - train_loader = AudioDataLoader(AudioDataset(train_dataloader_config), args.batch_size, + + with open(args.labels_path) as label_file: + labels = str(''.join(json.load(label_file))) + + audio_conf = dict(sample_rate=args.sample_rate, + window_size=args.window_size, + window_stride=args.window_stride, + window=args.window) + + train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels, + normalize=True) + test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels, + normalize=True) + train_loader = AudioDataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers) - test_loader = AudioDataLoader(AudioDataset(test_dataloader_config), args.batch_size, + test_loader = AudioDataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) - model = DeepSpeech(rnn_hidden_size=args.hidden_size, nb_layers=args.hidden_layers, num_classes=len(alphabet)) - decoder = ArgMaxDecoder(alphabet=alphabet) + model = DeepSpeech(rnn_hidden_size=args.hidden_size, nb_layers=args.hidden_layers, num_classes=len(labels)) + decoder = ArgMaxDecoder(labels) if args.cuda: model = torch.nn.DataParallel(model).cuda() print(model) @@ -90,7 +111,7 @@ def main(): data_time = AverageMeter() losses = AverageMeter() - for epoch in range(args.epochs - 1): + for epoch in range(args.epochs): model.train() end = time.time() avg_loss = 0 @@ -106,7 +127,7 @@ def main(): inputs = inputs.cuda() out = model(inputs) - out = out.transpose(0, 1) # seqLength x batchSize x alphabet + out = out.transpose(0, 1) # TxNxH seq_length = out.size(0) sizes = Variable(input_percentages.mul_(int(seq_length)).int()) @@ -175,7 +196,7 @@ def main(): inputs = inputs.cuda() out = model(inputs) - out = out.transpose(0, 1) # seqLength x batchSize x alphabet + out = out.transpose(0, 1) # TxNxH seq_length = out.size(0) sizes = Variable(input_percentages.mul_(int(seq_length)).int()) @@ -191,13 +212,14 @@ def main(): wer = total_wer / len(test_loader.dataset) cer = total_cer / len(test_loader.dataset) - # We need to format the targets into actual sentences print('Validation Summary Epoch: [{0}]\t' 'Average WER {wer:.0f}\t' 'Average CER {cer:.0f}\t'.format( (epoch + 1), wer=wer * 100, cer=cer * 100)) - decoded_output = decoder.decode(out.data, sizes) - print (decoded_output) + if args.epoch_save: + file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch) + torch.save(checkpoint(model, args, len(labels), epoch), file_path) + torch.save(checkpoint(model, args, len(labels)), args.final_model_path) if __name__ == '__main__':
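The pieces this diff introduces are wired together the same way in both train.py and predict.py: a manifest CSV is read by `SpectrogramDataset`, batched by `AudioDataLoader`, and the `labels` string from labels.json is shared by `DeepSpeech` and `ArgMaxDecoder`. Below is a minimal usage sketch, not part of the patch; the manifest path, batch size and hyperparameters are illustrative assumptions, while the class and argument names come from `data/data_loader.py`, `model.py` and `decoder.py` above.

```
import json

from data.data_loader import SpectrogramDataset, AudioDataLoader
from decoder import ArgMaxDecoder
from model import DeepSpeech

# The same labels.json drives the dataset, the model output size and the decoder.
with open('labels.json') as label_file:
    labels = str(''.join(json.load(label_file)))

# Audio configuration matching the defaults used in train.py and predict.py.
audio_conf = dict(sample_rate=16000, window_size=.02, window_stride=.01, window='hamming')

# data/train_manifest.csv holds "/path/to/audio.wav,/path/to/text.txt" rows,
# as produced by data/an4.py or written by hand for a custom dataset.
train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath='data/train_manifest.csv',
                                   labels=labels, normalize=True)
train_loader = AudioDataLoader(train_dataset, batch_size=20, num_workers=4)

model = DeepSpeech(rnn_hidden_size=400, nb_layers=4, num_classes=len(labels))
decoder = ArgMaxDecoder(labels)
```

predict.py in this diff shows the corresponding single-file inference flow end to end.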