Skip to content

Commit

Permalink
Pytorch Lightning Integration (#569)
Browse files Browse the repository at this point in the history
* Added minimal code to integrate Pytorch Lightning into training.py

* Added autocast support, removed intra epoch checkpointing for simplicity, integrated checkpoint support, fixed validation support

* Fixed multi-gpu support

* Fixed smoke test, pretrained tests will be broken till new model release

Added trains viz logging

Precision

* Updated README, fixed server class, updated k8s config file, added fix for adam Trains support, removed autocast since this is handled via lightning

* Swapped to using tqdm write for readability when checkpointing, added an4 config

* Added base script for each dataset, updated default params

* Swapped to using native CTC, updated common voice script, removed incorrect lightning version

* Updated cv params and output manifest location, set default epochs to the epochs used for previous release

* Disable trains logger for now, simplified checkpointing logic for new release

* Added new metrics class, removed save_output/verbose for now, using new ModelCheckpoint class for model saving

* multiprocess duration collection for speed, allow loading from file path, refactor path name and test

* Swap to latest release candidate, fixed flag reference

* Format smoke test, update path to best save k model

* Update to latest RC

* Removed trains logging, rely on PL tensorboard. swap to saving json object for manifest to modify root path

* Ensure abs path for manifest root path

* Use absolute paths for manifest

* Update requirements, abstract all PL trainer arguments

* Enable checkpoint callback

* Enable checkpoint callback, add verbosity

* Add sharded as a dependency for better memory use

* Set num workers, add spec augment

* Update deepspeech_pytorch/data/utils.py

Co-authored-by: Anas Abou Allaban <[email protected]>

* Specify blank index explicitly

* Add blank index to ctc loss

* Fix CI

* Fix Syntax Warning

* Fix install requirements

* Use torchaudio (#607)

* Use torchaudio

* Add torchaudio to reqs

* Fixes for testing, update AN4 config, update dockerfile base image

* Add noninteractive to remove stalling

* revert

* Update API

Co-authored-by: Sean Narenthiran <[email protected]>
Co-authored-by: Anas Abou Allaban <[email protected]>
Co-authored-by: Anas Abou Allaban <[email protected]>
  • Loading branch information
4 people authored Jan 30, 2021
1 parent 4cb209a commit d9790d9
Show file tree
Hide file tree
Showing 37 changed files with 1,043 additions and 1,239 deletions.
12 changes: 1 addition & 11 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM pytorch/pytorch:1.5.1-cuda10.1-cudnn7-devel
FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel
ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH

WORKDIR /workspace/
Expand All @@ -7,20 +7,10 @@ WORKDIR /workspace/
RUN apt-get update -y
RUN apt-get install -y git curl ca-certificates bzip2 cmake tree htop bmon iotop sox libsox-dev libsox-fmt-all vim

# install warp-CTC
ENV CUDA_HOME=/usr/local/cuda
RUN git clone https://github.com/SeanNaren/warp-ctc.git
RUN cd warp-ctc; mkdir build; cd build; cmake ..; make
RUN cd warp-ctc; cd pytorch_binding; python setup.py install

# install ctcdecode
RUN git clone --recursive https://github.com/parlance/ctcdecode.git
RUN cd ctcdecode; pip install .

# install apex
RUN git clone --recursive https://github.com/NVIDIA/apex.git
RUN cd apex; pip install .

# install deepspeech.pytorch
ADD . /workspace/deepspeech.pytorch
RUN cd deepspeech.pytorch; pip install -r requirements.txt && pip install -e .
Expand Down
90 changes: 28 additions & 62 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# deepspeech.pytorch
[![Build Status](https://travis-ci.org/SeanNaren/deepspeech.pytorch.svg?branch=master)](https://travis-ci.org/SeanNaren/deepspeech.pytorch)

Implementation of DeepSpeech2 for PyTorch. The repo supports training/testing and inference using the [DeepSpeech2](http://arxiv.org/pdf/1512.02595v1.pdf) model. Optionally a [kenlm](https://github.com/kpu/kenlm) language model can be used at inference time.
Implementation of DeepSpeech2 for PyTorch using [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). The repo supports training/testing and inference using the [DeepSpeech2](http://arxiv.org/pdf/1512.02595v1.pdf) model. Optionally a [kenlm](https://github.com/kpu/kenlm) language model can be used at inference time.

## Installation

Expand All @@ -26,20 +26,6 @@ an Anaconda installation on Ubuntu, with PyTorch installed.

Install [PyTorch](https://github.com/pytorch/pytorch#installation) if you haven't already.

Install this fork for Warp-CTC bindings:
```
git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc; mkdir build; cd build; cmake ..; make
export CUDA_HOME="/usr/local/cuda"
cd ../pytorch_binding && python setup.py install
```

Install NVIDIA apex:
```
git clone --recursive https://github.com/NVIDIA/apex.git
cd apex && pip install .
```

If you want decoding to support beam search with an optional language model, install ctcdecode:
```
git clone --recursive https://github.com/parlance/ctcdecode.git
Expand Down Expand Up @@ -93,7 +79,7 @@ Configuration is done via [Hydra](https://github.com/facebookresearch/hydra).
Defaults can be seen in [config.py](deepspeech_pytorch/configs/train_config.py). Below is how you can override values set already:

```
python train.py data.train_manifest=data/train_manifest.csv data.val_manifest=data/val_manifest.csv
python train.py data.train_path=data/train_manifest.csv data.val_path=data/val_manifest.csv
```

Use `python train.py --help` for all parameters and options.
Expand All @@ -103,27 +89,15 @@ You can also specify a config file to keep parameters stored in a yaml file like
Create folder `experiment/` and file `experiment/an4.yaml`:
```yaml
data:
train_manifest: data/an4_train_manifest.csv
val_manifest: data/an4_val_manifest.csv
train_path: data/an4_train_manifest.csv
val_path: data/an4_val_manifest.csv
```
```
python train.py +experiment=an4
```

There is also [Visdom](https://github.com/facebookresearch/visdom) support to visualize training. Once a server has been started, to use:

```
python train.py visualization.visdom=true
```

There is also Tensorboard support to visualize training. Follow the instructions to set up. To use:

```
python train.py visualization.tensorboard=true visualization.log_dir=log_dir/ # Make sure the Tensorboard instance is made pointing to this log directory
```

For both visualisation tools, you can add your own name to the run by changing the `--id` parameter when training.
To see options available, check [here](./deepspeech_pytorch/configs/train_config.py).

### Multi-GPU Training

Expand All @@ -136,9 +110,10 @@ python -m torchelastic.distributed.launch \
--standalone \
--nnodes=1 \
--nproc_per_node=4 \
train.py data.train_manifest=data/an4_train_manifest.csv \
data.val_manifest=data/an4_val_manifest.csv apex.opt_level=O1 data.num_workers=8 \
data.batch_size=8 training.epochs=70 checkpointing.checkpoint=true checkpointing.save_n_recent_models=3
train.py data.train_path=data/an4_train_manifest.csv \
data.val_path=data/an4_val_manifest.csv model.precision=half data.num_workers=8 \
data.batch_size=8 trainer.max_epochs=70 checkpoint.checkpoint=true checkpointing.save_n_recent_models=3 \
trainer.accelerator=ddp trainer.gpus=4
```

You'll see the output for all the processes running on each individual GPU.
Expand Down Expand Up @@ -169,28 +144,27 @@ python -m torchelastic.distributed.launch \
--rdzv_id=123 \
--rdzv_backend=etcd \
--rdzv_endpoint=$PUBLIC_HOST_NAME:4377 \
train.py data.train_manifest=/share/data/an4_train_manifest.csv \
data.val_manifest=/share/data/an4_val_manifest.csv apex.opt_level=O1 \
data.num_workers=8 checkpointing.save_folder=/share/checkpoints/ \
checkpointing.checkpoint=true checkpointing.load_auto_checkpoint=true checkpointing.save_n_recent_models=3 \
data.batch_size=8 training.epochs=70
train.py data.train_path=/share/data/an4_train_manifest.csv \
data.val_path=/share/data/an4_val_manifest.csv model.precision=half \
data.num_workers=8 checkpoint.save_folder=/share/checkpoints/ \
checkpoint.checkpoint=true checkpoint.load_auto_checkpoint=true checkpointing.save_n_recent_models=3 \
data.batch_size=8 trainer.max_epochs=70 \
trainer.accelerator=ddp trainer.gpus=4 trainer.num_nodes=2
```

Using the `checkpointing.load_auto_checkpoint=true` flag and the `checkpointing.checkpoint_per_iteration` flag we can re-continue training from the latest saved checkpoint.
Using the `load_auto_checkpoint=true` flag we can re-continue training from the latest saved checkpoint.

Currently it is expected that there is an NFS drive/shared mount across all nodes within the cluster to load the latest checkpoint from.

### Mixed Precision

If you are using NVIDIA volta cards or above to train your model, it's highly suggested to turn on mixed precision for speed/memory benefits. More information can be found [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html).

Different Optimization levels are available. More information on the Nvidia Apex API can be seen [here](https://nvidia.github.io/apex/amp.html#opt-levels).

```
python train.py data.train_manifest=data/train_manifest.csv data.val_manifest=data/val_manifest.csv apex.opt_level=O1 apex.loss_scale=1.0
python train.py data.train_manifest=data/train_manifest.csv data.val_manifest=data/val_manifest.csv trainer.precision=16
```

Training a model in mixed-precision means you can use 32 bit float or half precision at runtime. Float32 is default, to use half precision (Which on V100s come with a speedup and better memory use) use the `--half` flag when testing or transcribing.
Training a model in mixed-precision means you can use 32 bit float or half precision at runtime. Float32 is default, to use half precision (Which on V100s come with a speedup and better memory use) use the `model.precision=half` flag when testing or transcribing.

### Swapping to ADAMW Optimizer

Expand Down Expand Up @@ -230,29 +204,21 @@ Applies small changes to the tempo and gain when loading audio to increase robus

### Checkpoints

Training supports saving checkpoints of the model to continue training from should an error occur or early termination. To enable epoch
checkpoints use:

```
python train.py checkpoint=true
```
Training supports saving checkpoints of the model to continue training from should an error occur or early termination.

To enable checkpoints every N batches through the epoch as well as epoch saving:
To enable epoch checkpoints use:

```
python train.py checkpoint=true --checkpoint-per-batch N # N is the number of batches to wait till saving a checkpoint at this batch.
python train.py checkpoint=true
```

Note for the batch checkpointing system to work, you cannot change the batch size when loading a checkpointed model from it's original training
run.

To continue from a checkpointed model that has been saved:
To continue from a checkpoint model:

```
python train.py checkpointing.continue_from=models/deepspeech_checkpoint_epoch_N_iter_N.pth
```

This continues from the same training state as well as recreates the visdom graph to continue from if enabled.
This continues from the same training state.

If you would like to start from a previous checkpoint model but not continue training, add the `training.finetune=true` flag to restart training
from the `checkpointing.continue_from` weights.
Expand All @@ -275,7 +241,7 @@ To also note, there is no final softmax layer on the model as when trained, warp
To evaluate a trained model on a test set (has to be in the same format as the training set):

```
python test.py model.model_path=models/deepspeech.pth test_manifest=/path/to/test_manifest.csv
python test.py model.model_path=models/deepspeech.pth test_path=/path/to/test_manifest.csv
```

An example script to output a transcription has been provided:
Expand All @@ -284,7 +250,7 @@ An example script to output a transcription has been provided:
python transcribe.py model.model_path=models/deepspeech.pth audio_path=/path/to/audio.wav
```

If you used mixed-precision or half precision when training the model, you can use the `--half` flag for a speed/memory benefit.
If you used mixed-precision or half precision when training the model, you can use the `model.precision=half` for a speed/memory benefit.

## Inference Server

Expand All @@ -307,7 +273,7 @@ In addition download the latest pre-trained librispeech model from the releases

First we need to generate the acoustic output to be used to evaluate the model on LibriSpeech val.
```
python test.py data.test_manifest=data/librispeech_val_manifest.csv model.model_path=librispeech_pretrained_v2.pth save_output=librispeech_val_output.npy
python test.py data.test_path=data/librispeech_val_manifest.csv model.model_path=librispeech_pretrained_v2.pth save_output=librispeech_val_output.npy
```

We use a beam width of 128 which gives reasonable results. We suggest using a CPU intensive node to carry out the grid search.
Expand All @@ -331,15 +297,15 @@ To build your own LM you need to use the KenLM repo found [here](https://github.
### Alternate Decoders
By default, `test.py` and `transcribe.py` use a `GreedyDecoder` which picks the highest-likelihood output label at each timestep. Repeated and blank symbols are then filtered to give the final output.

A beam search decoder can optionally be used with the installation of the `ctcdecode` library as described in the Installation section. The `test` and `transcribe` scripts have a `decoder_type` argument. To use the beam decoder, add `lm.decoder_type=beam`. The beam decoder enables additional decoding parameters:
A beam search decoder can optionally be used with the installation of the `ctcdecode` library as described in the Installation section. The `test` and `transcribe` scripts have a `lm` config. To use the beam decoder, add `lm.decoder_type=beam`. The beam decoder enables additional decoding parameters:
- **lm.beam_width** how many beams to consider at each timestep
- **lm.lm_path** optional binary KenLM language model to use for decoding
- **lm.alpha** weight for language model
- **lm.beta** bonus weight for words

### Time offsets

Use the `--offsets` flag to get positional information of each character in the transcription when using `transcribe.py` script. The offsets are based on the size
Use the `offsets=true` flag to get positional information of each character in the transcription when using `transcribe.py` script. The offsets are based on the size
of the output tensor, which you need to convert into a format required.
For example, based on default parameters you could multiply the offsets by a scalar (duration of file in seconds / size of output) to get the offsets in seconds.

Expand Down
18 changes: 18 additions & 0 deletions configs/an4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# @package _global_
data:
train_path: data/an4_train_manifest.json
val_path: data/an4_val_manifest.json
batch_size: 8
num_workers: 8
trainer:
max_epochs: 70
gpus: 1
precision: 16
gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients
accelerator: ddp
plugins: ddp_sharded
checkpoint_callback: True
checkpoint:
save_top_k: 1
monitor: "wer"
verbose: True
19 changes: 19 additions & 0 deletions configs/commonvoice.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# @package _global_
data:
train_path: data/commonvoice_train_manifest.json
val_path: data/commonvoice_dev_manifest.json
num_workers: 8
augmentation:
spec_augment: True
trainer:
max_epochs: 16
gpus: 1
precision: 16
gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients
accelerator: ddp
plugins: ddp_sharded
checkpoint_callback: True
checkpoint:
save_top_k: 1
monitor: "wer"
verbose: True
19 changes: 19 additions & 0 deletions configs/librispeech.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# @package _global_
data:
train_path: data/libri_train_manifest.json
val_path: data/libri_val_manifest.json
num_workers: 8
augmentation:
spec_augment: True
trainer:
max_epochs: 16
gpus: 1
precision: 16
gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients
accelerator: ddp
plugins: ddp_sharded
checkpoint_callback: True
checkpoint:
save_top_k: 1
monitor: "wer"
verbose: True
19 changes: 19 additions & 0 deletions configs/tedlium.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# @package _global_
data:
train_path: data/ted_train_manifest.json
val_path: data/ted_val_manifest.json
num_workers: 8
augmentation:
spec_augment: True
trainer:
max_epochs: 16
gpus: 1
precision: 16
gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients
accelerator: ddp
plugins: ddp_sharded
checkpoint_callback: True
checkpoint:
save_top_k: 1
monitor: "wer"
verbose: True
33 changes: 20 additions & 13 deletions data/an4.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def download_an4(target_dir: str,
min_duration: float,
max_duration: float,
val_fraction: float,
sample_rate: int):
sample_rate: int,
num_workers: int):
root_path = 'an4/'
raw_tar_path = 'an4_raw.bigendian.tar.gz'
if not os.path.exists(raw_tar_path):
Expand All @@ -145,18 +146,21 @@ def download_an4(target_dir: str,

print('Creating manifests...')
create_manifest(data_path=train_path,
output_name='an4_train_manifest.csv',
output_name='an4_train_manifest.json',
manifest_path=manifest_dir,
min_duration=min_duration,
max_duration=max_duration)
max_duration=max_duration,
num_workers=num_workers)
create_manifest(data_path=val_path,
output_name='an4_val_manifest.csv',
output_name='an4_val_manifest.json',
manifest_path=manifest_dir,
min_duration=min_duration,
max_duration=max_duration)
max_duration=max_duration,
num_workers=num_workers)
create_manifest(data_path=test_path,
output_name='an4_test_manifest.csv',
manifest_path=manifest_dir)
output_name='an4_test_manifest.json',
manifest_path=manifest_dir,
num_workers=num_workers)


if __name__ == '__main__':
Expand All @@ -166,9 +170,12 @@ def download_an4(target_dir: str,
parser.add_argument('--val-fraction', default=0.1, type=float,
help='Number of files in the training set to use as validation.')
args = parser.parse_args()
download_an4(target_dir=args.target_dir,
manifest_dir=args.manifest_dir,
min_duration=args.min_duration,
max_duration=args.max_duration,
val_fraction=args.val_fraction,
sample_rate=args.sample_rate)
download_an4(
target_dir=args.target_dir,
manifest_dir=args.manifest_dir,
min_duration=args.min_duration,
max_duration=args.max_duration,
val_fraction=args.val_fraction,
sample_rate=args.sample_rate,
num_workers=args.num_workers
)
Loading

0 comments on commit d9790d9

Please sign in to comment.