Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Naren authored Feb 16, 2017
1 parent ce07d8b commit 906771b
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 130 deletions.
25 changes: 17 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# deepspeech.pytorch

* Add tests for dataloading
* Fix validation. Assume problems with SequenceWise module and using batch norm in 3d mode.
* Support LibriSpeech via multi-processed scripts

Implementation of [Baidu Warp-CTC](https://github.com/baidu-research/warp-ctc) using pytorch.
Implementation of DeepSpeech2 using [Baidu Warp-CTC](https://github.com/baidu-research/warp-ctc).
Creates a network based on the [DeepSpeech2](http://arxiv.org/pdf/1512.02595v1.pdf) architecture, trained with the CTC activation function.

# Installation
Expand Down Expand Up @@ -39,16 +35,29 @@ pip install -r requirements.txt
Currently only supports an4. To download and setup the an4 dataset run below command in the root folder of the repo:

```
python get_an4.py
python create_dataset_manifest.py --root_path dataset/
cd data; python an4.py
```

This will generate csv manifests files used to load the data for training.

LibriSpeech formatting is in the works.

### Custom Dataset

To create a custom dataset you must create a CSV file containing the locations of the training data. This has to be in the format of:

```
/path/to/audio.wav,/path/to/text.txt
/path/to/audio2.wav,/path/to/text2.txt
...
```

The first path is to the audio file, and the second path is to a text file containing the transcript on one line. This can then be used as stated below.

## Training


```
python main.py --train_manifest train_manifest.csv --test_manifest test_manifest.csv
python train.py --train_manifest data/train_manifest.csv --val_manifest data/val_manifest.csv
```

1 change: 1 addition & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import data_loader
23 changes: 11 additions & 12 deletions data/get_an4.py → data/an4.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
parser = argparse.ArgumentParser(description='Processes and downloads an4.')
parser.add_argument('--an4_path', default='an4_dataset/', help='Path to save dataset')
parser.add_argument('--sample_rate', default=16000, type=int, help='Sample rate')
args = parser.parse_args()


def format_data(data_tag, name, wav_folder):
def _format_data(root_path, data_tag, name, wav_folder):
data_path = args.an4_path + data_tag + '/' + name + '/'
new_transcript_path = data_path + '/txt/'
new_wav_path = data_path + '/wav/'
Expand All @@ -25,11 +26,11 @@ def format_data(data_tag, name, wav_folder):
transcripts = root_path + 'etc/an4_%s.transcription' % data_tag
train_path = wav_path + wav_folder

convert_audio_to_wav(train_path)
format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path)
_convert_audio_to_wav(train_path)
_format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path)


def convert_audio_to_wav(train_path):
def _convert_audio_to_wav(train_path):
with os.popen('find %s -type f -name "*.raw"' % train_path) as pipe:
for line in pipe:
raw_path = line.strip()
Expand All @@ -39,15 +40,15 @@ def convert_audio_to_wav(train_path):
os.system(cmd)


def format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path):
def _format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path):
with open(file_ids, 'r') as f:
with open(transcripts, 'r') as t:
paths = f.readlines()
transcripts = t.readlines()
for x in range(len(paths)):
path = wav_path + paths[x].strip() + '.wav'
filename = path.split('/')[-1]
extracted_transcript = process_transcript(transcripts, x)
extracted_transcript = _process_transcript(transcripts, x)
current_path = os.path.abspath(path)
new_path = new_wav_path + filename
text_path = new_transcript_path + filename.replace('.wav', '.txt')
Expand All @@ -56,28 +57,26 @@ def format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_p
os.rename(current_path, new_path)


def process_transcript(transcripts, x):
def _process_transcript(transcripts, x):
extracted_transcript = transcripts[x].split('(')[0].strip("<s>").split('<')[0].strip().upper()
return extracted_transcript


def main():
global args, root_path
args = parser.parse_args()
root_path = 'an4/'
name = 'an4'
subprocess.call(['wget http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz'], shell=True)
subprocess.call(['tar -xzvf an4_raw.bigendian.tar.gz'], stdout=open(os.devnull, 'wb'), shell=True)
os.makedirs(args.an4_path)
format_data('train', name, 'an4_clstk')
format_data('test', name, 'an4test_clstk')
_format_data(root_path, 'train', name, 'an4_clstk')
_format_data(root_path, 'test', name, 'an4test_clstk')
shutil.rmtree(root_path)
os.remove('an4_raw.bigendian.tar.gz')
train_path = args.an4_path + '/train/'
test_path = args.an4_path + '/test/'
print ('Creating manifests...')
create_manifest(train_path, 'train')
create_manifest(test_path, 'test')
create_manifest(test_path, 'val')


if __name__ == '__main__':
Expand Down
114 changes: 81 additions & 33 deletions data/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,51 @@
import librosa
import numpy as np
import scipy.signal
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import librosa
import numpy as np

windows = {'hamming': scipy.signal.hamming, 'hann': scipy.signal.hann, 'blackman': scipy.signal.blackman,
'bartlett': scipy.signal.bartlett}

class AudioDataset(Dataset):
def __init__(self, conf):
super(AudioDataset, self).__init__()
with open(conf['manifest_filename']) as f:
ids = f.readlines()
ids = [x.strip().split(',') for x in ids]
self.ids = ids
self.size = len(ids)
self.conf = conf
self.audio_conf = conf['audio']
self.alphabet_map = dict([(conf['alphabet'][i], i) for i in range(len(conf['alphabet']))])
self.normalize = conf.get('normalize', False)

def __getitem__(self, index):
sample = self.ids[index]
audio_path, transcript_path = sample[0], sample[1]
spect = self.spectrogram(audio_path)
transcript = self.parse_transcript(transcript_path)
return spect, transcript

class AudioParser(object):
def parse_transcript(self, transcript_path):
with open(transcript_path, 'r') as transcript_file:
transcript = transcript_file.read().replace('\n', '')
transcript = [self.alphabet_map[x] for x in list(transcript)]
return transcript
"""
:param transcript_path: Path where transcript is stored from the manifest file
:return: Transcript in training/testing format
"""
raise NotImplementedError

def parse_audio(self, audio_path):
"""
:param audio_path: Path where audio is stored from the manifest file
:return: Audio in training/testing format
"""
raise NotImplementedError


def spectrogram(self, audio_path):
y, _ = librosa.core.load(audio_path, sr=self.audio_conf['sample_rate'])
n_fft = int(self.audio_conf['sample_rate'] * self.audio_conf['window_size'])
class SpectrogramParser(AudioParser):
def __init__(self, audio_conf, normalize=False):
"""
Parses audio file into spectrogram with optional normalization
:param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
:param normalize(default False): Apply standard mean and deviation normalization to audio tensor
"""
super(SpectrogramParser, self).__init__()
self.window_stride = audio_conf['window_stride']
self.window_size = audio_conf['window_size']
self.sample_rate = audio_conf['sample_rate']
self.window = windows.get(audio_conf['window'], windows['hamming'])
self.normalize = normalize

def parse_audio(self, audio_path):
y, _ = librosa.core.load(audio_path, sr=self.sample_rate)
n_fft = int(self.sample_rate * self.window_size)
win_length = n_fft
hop_length = int(self.audio_conf['sample_rate'] * self.audio_conf['window_stride'])
window = scipy.signal.hamming # TODO if statement to select window based on conf
hop_length = int(self.sample_rate * self.window_stride)
# STFT
D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window)
D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=self.window)
spect, phase = librosa.magphase(D)
# S = log(S+1)
spect = np.log1p(spect)
Expand All @@ -52,11 +58,50 @@ def spectrogram(self, audio_path):

return spect

def parse_transcript(self, transcript_path):
raise NotImplementedError


class SpectrogramDataset(Dataset, SpectrogramParser):
def __init__(self, audio_conf, manifest_filepath, labels, normalize=False):
"""
Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
a comma. Each new line is a different sample. Example below:
/path/to/audio.wav,/path/to/audio.txt
...
:param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
:param manifest_filepath: Path to manifest csv as describe above
:param labels: String containing all the possible characters to map to
:param normalize: Apply standard mean and deviation normalization to audio tensor
"""
with open(manifest_filepath) as f:
ids = f.readlines()
ids = [x.strip().split(',') for x in ids]
self.ids = ids
self.size = len(ids)
self.labels_map = dict([(labels[i], i) for i in range(len(labels))])
super(SpectrogramDataset, self).__init__(audio_conf, normalize)

def __getitem__(self, index):
sample = self.ids[index]
audio_path, transcript_path = sample[0], sample[1]
spect = self.parse_audio(audio_path)
transcript = self.parse_transcript(transcript_path)
return spect, transcript

def parse_transcript(self, transcript_path):
with open(transcript_path, 'r') as transcript_file:
transcript = transcript_file.read().replace('\n', '')
transcript = [self.labels_map[x] for x in list(transcript)]
return transcript

def __len__(self):
return self.size


def collate_fn(batch):
def _collate_fn(batch):
def func(p):
return p[0].size(1)

Expand All @@ -83,5 +128,8 @@ def func(p):

class AudioDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
"""
Creates a data loader for AudioDatasets.
"""
super(AudioDataLoader, self).__init__(*args, **kwargs)
self.collate_fn = collate_fn
self.collate_fn = _collate_fn
54 changes: 34 additions & 20 deletions data/utils.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,51 @@
import argparse
from __future__ import print_function

import fnmatch
import io
import os

import subprocess

parser = argparse.ArgumentParser(description='Creates training and testing manifests')
parser.add_argument('--root_path', default='an4_dataset', help='Path to the dataset')

"""
We need to add progress bars like will did in gen audio. Just copy that code.
We also need a call (basically the same in the dataloader, a find that gives us the total number of wav files)
"""
def _update_progress(progress):
print("\rProgress: [{0:50s}] {1:.1f}%".format('#' * int(progress * 50),
progress * 100), end="")


def create_manifest(data_path, tag, ordered=True):
manifest_path = '%s_manifest.csv' % tag
file_paths = []
with os.popen('find %s -type f -name "*.wav"' % data_path) as pipe:
for file_path in pipe:
file_paths.append(file_path.strip())
wav_files = [os.path.join(dirpath, f)
for dirpath, dirnames, files in os.walk(data_path)
for f in fnmatch.filter(files, '*.wav')]
size = len(wav_files)
counter = 0
for file_path in wav_files:
file_paths.append(file_path.strip())
counter += 1
_update_progress(counter / float(size))
print('\n')
if ordered:
print("Sorting files by length...")

def func(element):
output = subprocess.check_output(
['soxi -D %s' % element.strip()],
shell=True
)
return float(output)

file_paths.sort(key=func)
_order_files(file_paths)
counter = 0
with io.FileIO(manifest_path, "w") as file:
for wav_path in file_paths:
transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt')
sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n'
file.write(sample)
counter += 1
_update_progress(counter / float(size))
print('\n')


def _order_files(file_paths):
print("Sorting files by length...")

def func(element):
output = subprocess.check_output(
['soxi -D %s' % element.strip()],
shell=True
)
return float(output)

file_paths.sort(key=func)
12 changes: 6 additions & 6 deletions decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class Decoder(object):
helper functions. Subclasses should implement the decode() method.
Arguments:
alphabet (string): mapping from integers to characters.
labels (string): mapping from integers to characters.
blank_index (int, optional): index for the blank '_' character. Defaults to 0.
space_index (int, optional): index for the space ' ' character. Defaults to 28.
"""

def __init__(self, alphabet, blank_index=0, space_index=1):
# e.g. alphabet = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ#"
self.alphabet = alphabet
self.int_to_char = dict([(i, c) for (i, c) in enumerate(alphabet)])
def __init__(self, labels, blank_index=0, space_index=1):
# e.g. labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ#"
self.labels = labels
self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)])
self.blank_index = blank_index
self.space_index = space_index

Expand Down Expand Up @@ -73,7 +73,7 @@ def process_string(self, remove_repetitions, sequence):
# skip.
if remove_repetitions and i != 0 and char == sequence[i - 1]:
pass
elif char == self.alphabet[self.space_index]:
elif char == self.labels[self.space_index]:
string += ' '
else:
string = string + char
Expand Down
Loading

0 comments on commit 906771b

Please sign in to comment.