From d9790d99ca4c66914fb01f63e255b07fa0305b2b Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Sat, 30 Jan 2021 12:52:41 +0000 Subject: [PATCH] Pytorch Lightning Integration (#569) * Added minimal code to integrate Pytorch Lightning into training.py * Added autocast support, removed intra epoch checkpointing for simplicity, integrated checkpoint support, fixed validation support * Fixed multi-gpu support * Fixed smoke test, pretrained tests will be broken till new model release Added trains viz logging Precision * Updated README, fixed server class, updated k8s config file, added fix for adam Trains support, removed autocast since this is handled via lightning * Swapped to using tqdm write for readability when checkpointing, added an4 config * Added base script for each dataset, updated default params * Swapped to using native CTC, updated common voice script, removed incorrect lightning version * Updated cv params and output manifest location, set default epochs to the epochs used for previous release * Disable trains logger for now, simplified checkpointing logic for new release * Added new metrics class, removed save_output/verbose for now, using new ModelCheckpoint class for model saving * multiprocess duration collection for speed, allow loading from file path, refactor path name and test * Swap to latest release candidate, fixed flag reference * Format smoke test, update path to best save k model * Update to latest RC * Removed trains logging, rely on PL tensorboard. swap to saving json object for manifest to modify root path * Ensure abs path for manifest root path * Use absolute paths for manifest * Update requirements, abstract all PL trainer arguments * Enable checkpoint callback * Enable checkpoint callback, add verbosity * Add sharded as a dependency for better memory use * Set num workers, add spec augment * Update deepspeech_pytorch/data/utils.py Co-authored-by: Anas Abou Allaban * Specify blank index explicitly * Add blank index to ctc loss * Fix CI * Fix Syntax Warning * Fix install requirements * Use torchaudio (#607) * Use torchaudio * Add torchaudio to reqs * Fixes for testing, update AN4 config, update dockerfile base image * Add noninteractive to remove stalling * revert * Update API Co-authored-by: Sean Narenthiran Co-authored-by: Anas Abou Allaban Co-authored-by: Anas Abou Allaban --- Dockerfile | 12 +- README.md | 90 ++--- configs/an4.yaml | 18 + configs/commonvoice.yaml | 19 + configs/librispeech.yaml | 19 + configs/tedlium.yaml | 19 + data/an4.py | 33 +- data/common_voice.py | 87 +++-- data/librispeech.py | 22 +- data/ted.py | 32 +- data/voxforge.py | 13 +- deepspeech_pytorch/checkpoint.py | 133 ++----- .../configs/inference_config.py | 4 +- deepspeech_pytorch/configs/train_config.py | 68 +--- deepspeech_pytorch/data/data_opts.py | 1 + deepspeech_pytorch/data/utils.py | 71 ++-- deepspeech_pytorch/decoder.py | 32 -- deepspeech_pytorch/enums.py | 14 +- deepspeech_pytorch/inference.py | 50 +-- deepspeech_pytorch/loader/data_loader.py | 76 ++-- deepspeech_pytorch/loader/data_module.py | 62 ++++ deepspeech_pytorch/loader/merge_manifests.py | 31 -- deepspeech_pytorch/logger.py | 72 ---- deepspeech_pytorch/model.py | 206 ++++++----- deepspeech_pytorch/state.py | 171 --------- deepspeech_pytorch/testing.py | 110 ++---- deepspeech_pytorch/training.py | 324 +++--------------- deepspeech_pytorch/utils.py | 10 +- deepspeech_pytorch/validation.py | 170 +++++++++ kubernetes/train.yaml | 6 +- requirements.txt | 9 +- search_lm_params.py | 7 +- server.py | 45 ++- 
tests/pretrained_smoke_test.py | 41 ++- tests/smoke_test.py | 196 +++++++---- train.py | 7 +- translations/README_JP.md | 2 +- 37 files changed, 1043 insertions(+), 1239 deletions(-) create mode 100644 configs/an4.yaml create mode 100644 configs/commonvoice.yaml create mode 100644 configs/librispeech.yaml create mode 100644 configs/tedlium.yaml create mode 100644 deepspeech_pytorch/loader/data_module.py delete mode 100644 deepspeech_pytorch/loader/merge_manifests.py delete mode 100644 deepspeech_pytorch/logger.py delete mode 100644 deepspeech_pytorch/state.py create mode 100644 deepspeech_pytorch/validation.py diff --git a/Dockerfile b/Dockerfile index 18dccac0..e4a5ed37 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM pytorch/pytorch:1.5.1-cuda10.1-cudnn7-devel +FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH WORKDIR /workspace/ @@ -7,20 +7,10 @@ WORKDIR /workspace/ RUN apt-get update -y RUN apt-get install -y git curl ca-certificates bzip2 cmake tree htop bmon iotop sox libsox-dev libsox-fmt-all vim -# install warp-CTC -ENV CUDA_HOME=/usr/local/cuda -RUN git clone https://github.com/SeanNaren/warp-ctc.git -RUN cd warp-ctc; mkdir build; cd build; cmake ..; make -RUN cd warp-ctc; cd pytorch_binding; python setup.py install - # install ctcdecode RUN git clone --recursive https://github.com/parlance/ctcdecode.git RUN cd ctcdecode; pip install . -# install apex -RUN git clone --recursive https://github.com/NVIDIA/apex.git -RUN cd apex; pip install . - # install deepspeech.pytorch ADD . /workspace/deepspeech.pytorch RUN cd deepspeech.pytorch; pip install -r requirements.txt && pip install -e . diff --git a/README.md b/README.md index 9a4c92f8..3bb5b315 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # deepspeech.pytorch [![Build Status](https://travis-ci.org/SeanNaren/deepspeech.pytorch.svg?branch=master)](https://travis-ci.org/SeanNaren/deepspeech.pytorch) -Implementation of DeepSpeech2 for PyTorch. The repo supports training/testing and inference using the [DeepSpeech2](http://arxiv.org/pdf/1512.02595v1.pdf) model. Optionally a [kenlm](https://github.com/kpu/kenlm) language model can be used at inference time. +Implementation of DeepSpeech2 for PyTorch using [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). The repo supports training/testing and inference using the [DeepSpeech2](http://arxiv.org/pdf/1512.02595v1.pdf) model. Optionally a [kenlm](https://github.com/kpu/kenlm) language model can be used at inference time. ## Installation @@ -26,20 +26,6 @@ an Anaconda installation on Ubuntu, with PyTorch installed. Install [PyTorch](https://github.com/pytorch/pytorch#installation) if you haven't already. -Install this fork for Warp-CTC bindings: -``` -git clone https://github.com/SeanNaren/warp-ctc.git -cd warp-ctc; mkdir build; cd build; cmake ..; make -export CUDA_HOME="/usr/local/cuda" -cd ../pytorch_binding && python setup.py install -``` - -Install NVIDIA apex: -``` -git clone --recursive https://github.com/NVIDIA/apex.git -cd apex && pip install . -``` - If you want decoding to support beam search with an optional language model, install ctcdecode: ``` git clone --recursive https://github.com/parlance/ctcdecode.git @@ -93,7 +79,7 @@ Configuration is done via [Hydra](https://github.com/facebookresearch/hydra). Defaults can be seen in [config.py](deepspeech_pytorch/configs/train_config.py). 
Below is how you can override values set already: ``` -python train.py data.train_manifest=data/train_manifest.csv data.val_manifest=data/val_manifest.csv +python train.py data.train_path=data/train_manifest.csv data.val_path=data/val_manifest.csv ``` Use `python train.py --help` for all parameters and options. @@ -103,27 +89,15 @@ You can also specify a config file to keep parameters stored in a yaml file like Create folder `experiment/` and file `experiment/an4.yaml`: ```yaml data: - train_manifest: data/an4_train_manifest.csv - val_manifest: data/an4_val_manifest.csv + train_path: data/an4_train_manifest.csv + val_path: data/an4_val_manifest.csv ``` ``` python train.py +experiment=an4 ``` -There is also [Visdom](https://github.com/facebookresearch/visdom) support to visualize training. Once a server has been started, to use: - -``` -python train.py visualization.visdom=true -``` - -There is also Tensorboard support to visualize training. Follow the instructions to set up. To use: - -``` -python train.py visualization.tensorboard=true visualization.log_dir=log_dir/ # Make sure the Tensorboard instance is made pointing to this log directory -``` - -For both visualisation tools, you can add your own name to the run by changing the `--id` parameter when training. +To see options available, check [here](./deepspeech_pytorch/configs/train_config.py). ### Multi-GPU Training @@ -136,9 +110,10 @@ python -m torchelastic.distributed.launch \ --standalone \ --nnodes=1 \ --nproc_per_node=4 \ - train.py data.train_manifest=data/an4_train_manifest.csv \ - data.val_manifest=data/an4_val_manifest.csv apex.opt_level=O1 data.num_workers=8 \ - data.batch_size=8 training.epochs=70 checkpointing.checkpoint=true checkpointing.save_n_recent_models=3 + train.py data.train_path=data/an4_train_manifest.csv \ + data.val_path=data/an4_val_manifest.csv model.precision=half data.num_workers=8 \ + data.batch_size=8 trainer.max_epochs=70 checkpoint.checkpoint=true checkpointing.save_n_recent_models=3 \ + trainer.accelerator=ddp trainer.gpus=4 ``` You'll see the output for all the processes running on each individual GPU. @@ -169,14 +144,15 @@ python -m torchelastic.distributed.launch \ --rdzv_id=123 \ --rdzv_backend=etcd \ --rdzv_endpoint=$PUBLIC_HOST_NAME:4377 \ - train.py data.train_manifest=/share/data/an4_train_manifest.csv \ - data.val_manifest=/share/data/an4_val_manifest.csv apex.opt_level=O1 \ - data.num_workers=8 checkpointing.save_folder=/share/checkpoints/ \ - checkpointing.checkpoint=true checkpointing.load_auto_checkpoint=true checkpointing.save_n_recent_models=3 \ - data.batch_size=8 training.epochs=70 + train.py data.train_path=/share/data/an4_train_manifest.csv \ + data.val_path=/share/data/an4_val_manifest.csv model.precision=half \ + data.num_workers=8 checkpoint.save_folder=/share/checkpoints/ \ + checkpoint.checkpoint=true checkpoint.load_auto_checkpoint=true checkpointing.save_n_recent_models=3 \ + data.batch_size=8 trainer.max_epochs=70 \ + trainer.accelerator=ddp trainer.gpus=4 trainer.num_nodes=2 ``` -Using the `checkpointing.load_auto_checkpoint=true` flag and the `checkpointing.checkpoint_per_iteration` flag we can re-continue training from the latest saved checkpoint. +Using the `load_auto_checkpoint=true` flag we can re-continue training from the latest saved checkpoint. Currently it is expected that there is an NFS drive/shared mount across all nodes within the cluster to load the latest checkpoint from. 
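Under the hood, resuming with `load_auto_checkpoint=true` amounts to picking the most recently written checkpoint file in the shared save folder (see `FileCheckpointHandler.find_latest_checkpoint` later in this patch). A minimal standalone sketch of that lookup is below — the folder path and filename prefix are illustrative placeholders, not values fixed by this patch:

```python
# Sketch of the auto-resume lookup: the newest checkpoint file wins.
# The folder and the "deepspeech_checkpoint_" prefix are placeholders.
import os
from pathlib import Path


def find_latest_checkpoint(save_folder: str, prefix: str = "deepspeech_checkpoint_"):
    """Return the most recently written checkpoint under save_folder, or None."""
    paths = list(Path(save_folder).rglob(prefix + "*"))
    if not paths:
        return None
    paths.sort(key=os.path.getctime)  # oldest first, newest last
    return str(paths[-1])


latest = find_latest_checkpoint("/share/checkpoints/")
if latest is not None:
    print("Resuming from", latest)
```

Sorting by file creation time rather than parsing epoch numbers out of filenames keeps the lookup independent of the checkpoint naming template.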
@@ -184,13 +160,11 @@ Currently it is expected that there is an NFS drive/shared mount across all node If you are using NVIDIA volta cards or above to train your model, it's highly suggested to turn on mixed precision for speed/memory benefits. More information can be found [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). -Different Optimization levels are available. More information on the Nvidia Apex API can be seen [here](https://nvidia.github.io/apex/amp.html#opt-levels). - ``` -python train.py data.train_manifest=data/train_manifest.csv data.val_manifest=data/val_manifest.csv apex.opt_level=O1 apex.loss_scale=1.0 +python train.py data.train_manifest=data/train_manifest.csv data.val_manifest=data/val_manifest.csv trainer.precision=16 ``` -Training a model in mixed-precision means you can use 32 bit float or half precision at runtime. Float32 is default, to use half precision (Which on V100s come with a speedup and better memory use) use the `--half` flag when testing or transcribing. +Training a model in mixed-precision means you can use 32 bit float or half precision at runtime. Float32 is default, to use half precision (Which on V100s come with a speedup and better memory use) use the `model.precision=half` flag when testing or transcribing. ### Swapping to ADAMW Optimizer @@ -230,29 +204,21 @@ Applies small changes to the tempo and gain when loading audio to increase robus ### Checkpoints -Training supports saving checkpoints of the model to continue training from should an error occur or early termination. To enable epoch -checkpoints use: - -``` -python train.py checkpoint=true -``` +Training supports saving checkpoints of the model to continue training from should an error occur or early termination. -To enable checkpoints every N batches through the epoch as well as epoch saving: +To enable epoch checkpoints use: ``` -python train.py checkpoint=true --checkpoint-per-batch N # N is the number of batches to wait till saving a checkpoint at this batch. +python train.py checkpoint=true ``` -Note for the batch checkpointing system to work, you cannot change the batch size when loading a checkpointed model from it's original training -run. - -To continue from a checkpointed model that has been saved: +To continue from a checkpoint model: ``` python train.py checkpointing.continue_from=models/deepspeech_checkpoint_epoch_N_iter_N.pth ``` -This continues from the same training state as well as recreates the visdom graph to continue from if enabled. +This continues from the same training state. If you would like to start from a previous checkpoint model but not continue training, add the `training.finetune=true` flag to restart training from the `checkpointing.continue_from` weights. @@ -275,7 +241,7 @@ To also note, there is no final softmax layer on the model as when trained, warp To evaluate a trained model on a test set (has to be in the same format as the training set): ``` -python test.py model.model_path=models/deepspeech.pth test_manifest=/path/to/test_manifest.csv +python test.py model.model_path=models/deepspeech.pth test_path=/path/to/test_manifest.csv ``` An example script to output a transcription has been provided: @@ -284,7 +250,7 @@ An example script to output a transcription has been provided: python transcribe.py model.model_path=models/deepspeech.pth audio_path=/path/to/audio.wav ``` -If you used mixed-precision or half precision when training the model, you can use the `--half` flag for a speed/memory benefit. 
+If you used mixed-precision or half precision when training the model, you can use the `model.precision=half` for a speed/memory benefit. ## Inference Server @@ -307,7 +273,7 @@ In addition download the latest pre-trained librispeech model from the releases First we need to generate the acoustic output to be used to evaluate the model on LibriSpeech val. ``` -python test.py data.test_manifest=data/librispeech_val_manifest.csv model.model_path=librispeech_pretrained_v2.pth save_output=librispeech_val_output.npy +python test.py data.test_path=data/librispeech_val_manifest.csv model.model_path=librispeech_pretrained_v2.pth save_output=librispeech_val_output.npy ``` We use a beam width of 128 which gives reasonable results. We suggest using a CPU intensive node to carry out the grid search. @@ -331,7 +297,7 @@ To build your own LM you need to use the KenLM repo found [here](https://github. ### Alternate Decoders By default, `test.py` and `transcribe.py` use a `GreedyDecoder` which picks the highest-likelihood output label at each timestep. Repeated and blank symbols are then filtered to give the final output. -A beam search decoder can optionally be used with the installation of the `ctcdecode` library as described in the Installation section. The `test` and `transcribe` scripts have a `decoder_type` argument. To use the beam decoder, add `lm.decoder_type=beam`. The beam decoder enables additional decoding parameters: +A beam search decoder can optionally be used with the installation of the `ctcdecode` library as described in the Installation section. The `test` and `transcribe` scripts have a `lm` config. To use the beam decoder, add `lm.decoder_type=beam`. The beam decoder enables additional decoding parameters: - **lm.beam_width** how many beams to consider at each timestep - **lm.lm_path** optional binary KenLM language model to use for decoding - **lm.alpha** weight for language model @@ -339,7 +305,7 @@ A beam search decoder can optionally be used with the installation of the `ctcde ### Time offsets -Use the `--offsets` flag to get positional information of each character in the transcription when using `transcribe.py` script. The offsets are based on the size +Use the `offsets=true` flag to get positional information of each character in the transcription when using `transcribe.py` script. The offsets are based on the size of the output tensor, which you need to convert into a format required. For example, based on default parameters you could multiply the offsets by a scalar (duration of file in seconds / size of output) to get the offsets in seconds. 
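To make that conversion concrete, here is the arithmetic spelled out as a tiny sketch — the duration, output size and offsets below are made-up example values, not output from the repo:

```python
# Illustrative only: turning decoder character offsets into seconds.
# Replace the example values with the duration of your wav file, the size of
# the model output for that file, and the offsets printed by transcribe.py.
audio_duration_seconds = 2.5   # length of the input audio (example value)
output_size = 154              # number of timesteps in the model output (example value)
offsets = [12, 31, 48]         # per-character offsets from the decoder (example values)

seconds_per_step = audio_duration_seconds / output_size
offsets_in_seconds = [offset * seconds_per_step for offset in offsets]
print(offsets_in_seconds)
```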
diff --git a/configs/an4.yaml b/configs/an4.yaml new file mode 100644 index 00000000..e413cc19 --- /dev/null +++ b/configs/an4.yaml @@ -0,0 +1,18 @@ +# @package _global_ +data: + train_path: data/an4_train_manifest.json + val_path: data/an4_val_manifest.json + batch_size: 8 + num_workers: 8 +trainer: + max_epochs: 70 + gpus: 1 + precision: 16 + gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients + accelerator: ddp + plugins: ddp_sharded + checkpoint_callback: True +checkpoint: + save_top_k: 1 + monitor: "wer" + verbose: True \ No newline at end of file diff --git a/configs/commonvoice.yaml b/configs/commonvoice.yaml new file mode 100644 index 00000000..18a031f5 --- /dev/null +++ b/configs/commonvoice.yaml @@ -0,0 +1,19 @@ +# @package _global_ +data: + train_path: data/commonvoice_train_manifest.json + val_path: data/commonvoice_dev_manifest.json + num_workers: 8 + augmentation: + spec_augment: True +trainer: + max_epochs: 16 + gpus: 1 + precision: 16 + gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients + accelerator: ddp + plugins: ddp_sharded + checkpoint_callback: True +checkpoint: + save_top_k: 1 + monitor: "wer" + verbose: True \ No newline at end of file diff --git a/configs/librispeech.yaml b/configs/librispeech.yaml new file mode 100644 index 00000000..8a4b5081 --- /dev/null +++ b/configs/librispeech.yaml @@ -0,0 +1,19 @@ +# @package _global_ +data: + train_path: data/libri_train_manifest.json + val_path: data/libri_val_manifest.json + num_workers: 8 + augmentation: + spec_augment: True +trainer: + max_epochs: 16 + gpus: 1 + precision: 16 + gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients + accelerator: ddp + plugins: ddp_sharded + checkpoint_callback: True +checkpoint: + save_top_k: 1 + monitor: "wer" + verbose: True \ No newline at end of file diff --git a/configs/tedlium.yaml b/configs/tedlium.yaml new file mode 100644 index 00000000..3b4857ce --- /dev/null +++ b/configs/tedlium.yaml @@ -0,0 +1,19 @@ +# @package _global_ +data: + train_path: data/ted_train_manifest.json + val_path: data/ted_val_manifest.json + num_workers: 8 + augmentation: + spec_augment: True +trainer: + max_epochs: 16 + gpus: 1 + precision: 16 + gradient_clip_val: 400 # Norm cutoff to prevent explosion of gradients + accelerator: ddp + plugins: ddp_sharded + checkpoint_callback: True +checkpoint: + save_top_k: 1 + monitor: "wer" + verbose: True \ No newline at end of file diff --git a/data/an4.py b/data/an4.py index 7adc8e99..f1ae9ecf 100644 --- a/data/an4.py +++ b/data/an4.py @@ -122,7 +122,8 @@ def download_an4(target_dir: str, min_duration: float, max_duration: float, val_fraction: float, - sample_rate: int): + sample_rate: int, + num_workers: int): root_path = 'an4/' raw_tar_path = 'an4_raw.bigendian.tar.gz' if not os.path.exists(raw_tar_path): @@ -145,18 +146,21 @@ def download_an4(target_dir: str, print('Creating manifests...') create_manifest(data_path=train_path, - output_name='an4_train_manifest.csv', + output_name='an4_train_manifest.json', manifest_path=manifest_dir, min_duration=min_duration, - max_duration=max_duration) + max_duration=max_duration, + num_workers=num_workers) create_manifest(data_path=val_path, - output_name='an4_val_manifest.csv', + output_name='an4_val_manifest.json', manifest_path=manifest_dir, min_duration=min_duration, - max_duration=max_duration) + max_duration=max_duration, + num_workers=num_workers) create_manifest(data_path=test_path, - output_name='an4_test_manifest.csv', - manifest_path=manifest_dir) + 
output_name='an4_test_manifest.json', + manifest_path=manifest_dir, + num_workers=num_workers) if __name__ == '__main__': @@ -166,9 +170,12 @@ def download_an4(target_dir: str, parser.add_argument('--val-fraction', default=0.1, type=float, help='Number of files in the training set to use as validation.') args = parser.parse_args() - download_an4(target_dir=args.target_dir, - manifest_dir=args.manifest_dir, - min_duration=args.min_duration, - max_duration=args.max_duration, - val_fraction=args.val_fraction, - sample_rate=args.sample_rate) + download_an4( + target_dir=args.target_dir, + manifest_dir=args.manifest_dir, + min_duration=args.min_duration, + max_duration=args.max_duration, + val_fraction=args.val_fraction, + sample_rate=args.sample_rate, + num_workers=args.num_workers + ) diff --git a/data/common_voice.py b/data/common_voice.py index 747d6087..75831324 100644 --- a/data/common_voice.py +++ b/data/common_voice.py @@ -1,10 +1,12 @@ -import os -import wget -import tarfile import argparse import csv +import os +import tarfile from multiprocessing.pool import ThreadPool -import subprocess + +from sox import Transformer +import tqdm +import wget from deepspeech_pytorch.data.data_opts import add_data_opts from deepspeech_pytorch.data.utils import create_manifest @@ -13,13 +15,15 @@ parser = add_data_opts(parser) parser.add_argument("--target-dir", default='CommonVoice_dataset/', type=str, help="Directory to store the dataset.") parser.add_argument("--tar-path", type=str, help="Path to the Common Voice *.tar file if downloaded (Optional).") -parser.add_argument('--files-to-process', default="cv-valid-dev.csv,cv-valid-test.csv,cv-valid-train.csv", +parser.add_argument('--files-to-process', nargs='+', default=['test.tsv', 'dev.tsv', 'train.tsv'], type=str, help='list of *.csv file names to process') args = parser.parse_args() -COMMON_VOICE_URL = "https://common-voice-data-download.s3.amazonaws.com/cv_corpus_v1.tar.gz" +VERSION = 'cv-corpus-5.1-2020-06-22' +COMMON_VOICE_URL = "https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/" \ + "{}/en.tar.gz".format(VERSION) -def convert_to_wav(csv_file, target_dir): +def convert_to_wav(csv_file, target_dir, num_workers): """ Read *.csv file description, convert mp3 to wav, process text. Save results to target_dir. 
@@ -31,7 +35,7 @@ def convert_to_wav(csv_file, target_dir): txt_dir = os.path.join(target_dir, 'txt/') os.makedirs(wav_dir, exist_ok=True) os.makedirs(txt_dir, exist_ok=True) - path_to_data = os.path.dirname(csv_file) + audio_clips_path = os.path.dirname(csv_file) + '/clips/' def process(x): file_path, text = x @@ -39,18 +43,23 @@ def process(x): text = text.strip().upper() with open(os.path.join(txt_dir, file_name + '.txt'), 'w') as f: f.write(text) - cmd = "sox {} -r {} -b 16 -c 1 {}".format( - os.path.join(path_to_data, file_path), - args.sample_rate, - os.path.join(wav_dir, file_name + '.wav')) - subprocess.call([cmd], shell=True) + audio_path = os.path.join(audio_clips_path, file_path) + output_wav_path = os.path.join(wav_dir, file_name + '.wav') + + tfm = Transformer() + tfm.rate(samplerate=args.sample_rate) + tfm.build( + input_filepath=audio_path, + output_filepath=output_wav_path + ) print('Converting mp3 to wav for {}.'.format(csv_file)) with open(csv_file) as csvfile: - reader = csv.DictReader(csvfile) - data = [(row['filename'], row['text']) for row in reader] - with ThreadPool(10) as pool: - pool.map(process, data) + reader = csv.DictReader(csvfile, delimiter='\t') + next(reader, None) # skip the headers + data = [(row['path'], row['sentence']) for row in reader] + with ThreadPool(num_workers) as pool: + list(tqdm.tqdm(pool.imap(process, data), total=len(data))) def main(): @@ -58,32 +67,40 @@ def main(): os.makedirs(target_dir, exist_ok=True) target_unpacked_dir = os.path.join(target_dir, "CV_unpacked") - os.makedirs(target_unpacked_dir, exist_ok=True) - if args.tar_path and os.path.exists(args.tar_path): - print('Find existing file {}'.format(args.tar_path)) - target_file = args.tar_path + if os.path.exists(target_unpacked_dir): + print('Find existing folder {}'.format(target_unpacked_dir)) else: - print("Could not find downloaded Common Voice archive, Downloading corpus...") + print("Could not find Common Voice, Downloading corpus...") + filename = wget.download(COMMON_VOICE_URL, target_dir) target_file = os.path.join(target_dir, os.path.basename(filename)) - print("Unpacking corpus to {} ...".format(target_unpacked_dir)) - tar = tarfile.open(target_file) - tar.extractall(target_unpacked_dir) - tar.close() + os.makedirs(target_unpacked_dir, exist_ok=True) + print("Unpacking corpus to {} ...".format(target_unpacked_dir)) + tar = tarfile.open(target_file) + tar.extractall(target_unpacked_dir) + tar.close() + + folder_path = os.path.join(target_unpacked_dir, VERSION + '/en/') # TODO expose the language flag - for csv_file in args.files_to_process.split(','): - convert_to_wav(os.path.join(target_unpacked_dir, 'cv_corpus_v1/', csv_file), - os.path.join(target_dir, os.path.splitext(csv_file)[0])) + for csv_file in args.files_to_process: + convert_to_wav( + csv_file=os.path.join(folder_path, csv_file), + target_dir=os.path.join(target_dir, os.path.splitext(csv_file)[0]), + num_workers=args.num_workers + ) print('Creating manifests...') - for csv_file in args.files_to_process.split(','): - create_manifest(data_path=os.path.join(target_dir, os.path.splitext(csv_file)[0]), - output_name=os.path.splitext(csv_file)[0] + '_manifest.csv', - manifest_path=args.manifest_dir, - min_duration=args.min_duration, - max_duration=args.max_duration) + for csv_file in args.files_to_process: + create_manifest( + data_path=os.path.join(target_dir, os.path.splitext(csv_file)[0]), + output_name='commonvoice_' + os.path.splitext(csv_file)[0] + '_manifest.json', + manifest_path=args.manifest_dir, + 
min_duration=args.min_duration, + max_duration=args.max_duration, + num_workers=args.num_workers + ) if __name__ == "__main__": diff --git a/data/librispeech.py b/data/librispeech.py index a292105b..407668a6 100644 --- a/data/librispeech.py +++ b/data/librispeech.py @@ -103,15 +103,21 @@ def main(): print("Finished {}".format(url)) shutil.rmtree(extracted_dir) if split_type == 'train': # Prune to min/max duration - create_manifest(data_path=split_dir, - output_name='libri_' + split_type + '_manifest.csv', - manifest_path=args.manifest_dir, - min_duration=args.min_duration, - max_duration=args.max_duration) + create_manifest( + data_path=split_dir, + output_name='libri_' + split_type + '_manifest.json', + manifest_path=args.manifest_dir, + min_duration=args.min_duration, + max_duration=args.max_duration, + num_workers=args.num_workers + ) else: - create_manifest(data_path=split_dir, - output_name='libri_' + split_type + '_manifest.csv', - manifest_path=args.manifest_dir) + create_manifest( + data_path=split_dir, + output_name='libri_' + split_type + '_manifest.json', + manifest_path=args.manifest_dir, + num_workers=args.num_workers + ) if __name__ == "__main__": diff --git a/data/ted.py b/data/ted.py index 4acb750f..5b84ce3c 100644 --- a/data/ted.py +++ b/data/ted.py @@ -119,17 +119,27 @@ def main(): prepare_dir(test_ted_dir) print('Creating manifests...') - create_manifest(data_path=train_ted_dir, - output_name='ted_train_manifest.csv', - manifest_path=args.manifest_dir, - min_duration=args.min_duration, - max_duration=args.max_duration) - create_manifest(data_path=val_ted_dir, - output_name='ted_val_manifest.csv', - manifest_path=args.manifest_dir) - create_manifest(data_path=test_ted_dir, - output_name='ted_test_manifest.csv', - manifest_path=args.manifest_dir) + create_manifest( + data_path=train_ted_dir, + output_name='ted_train_manifest.json', + manifest_path=args.manifest_dir, + min_duration=args.min_duration, + max_duration=args.max_duration, + num_workers=args.num_workers + ) + create_manifest( + data_path=val_ted_dir, + output_name='ted_val_manifest.json', + manifest_path=args.manifest_dir, + num_workers=args.num_workers + + ) + create_manifest( + data_path=test_ted_dir, + output_name='ted_test_manifest.json', + manifest_path=args.manifest_dir, + num_workers=args.num_workers + ) if __name__ == "__main__": diff --git a/data/voxforge.py b/data/voxforge.py index f0caf63d..5fbf2f20 100644 --- a/data/voxforge.py +++ b/data/voxforge.py @@ -95,8 +95,11 @@ def prepare_sample(recording_name, url, target_folder): for f in tqdm(all_files, total=len(all_files)): prepare_sample(f.replace(".tgz", ""), VOXFORGE_URL_16kHz + f, target_dir) print('Creating manifests...') - create_manifest(data_path=target_dir, - output_name='voxforge_train_manifest.csv', - manifest_path=args.manifest_dir, - min_duration=args.min_duration, - max_duration=args.max_duration) + create_manifest( + data_path=target_dir, + output_name='voxforge_train_manifest.json', + manifest_path=args.manifest_dir, + min_duration=args.min_duration, + max_duration=args.max_duration, + num_workers=args.num_workers + ) diff --git a/deepspeech_pytorch/checkpoint.py b/deepspeech_pytorch/checkpoint.py index 84a01a23..1d9f891a 100644 --- a/deepspeech_pytorch/checkpoint.py +++ b/deepspeech_pytorch/checkpoint.py @@ -1,86 +1,36 @@ import os -from abc import ABC -from pathlib import Path, PosixPath +from pathlib import Path import hydra -import torch -from deepspeech_pytorch.configs.train_config import GCSCheckpointConfig, CheckpointConfig, 
FileCheckpointConfig -from deepspeech_pytorch.state import TrainingState from google.cloud import storage +from hydra_configs.pytorch_lightning.callbacks import ModelCheckpointConf +from pytorch_lightning.callbacks import ModelCheckpoint +from tqdm import tqdm +from deepspeech_pytorch.configs.train_config import GCSCheckpointConfig -class CheckpointHandler(ABC): - def __init__(self, - cfg: CheckpointConfig, - save_location): - self.checkpoint_prefix = 'deepspeech_checkpoint_' # TODO do we want to expose this? - self.save_location = save_location - self.checkpoint_per_iteration = cfg.checkpoint_per_iteration - self.save_n_recent_models = cfg.save_n_recent_models +class CheckpointHandler(ModelCheckpoint): - if type(self.save_location) == PosixPath: - self.checkpoint_prefix_path = self.save_location / self.checkpoint_prefix - self.best_val_path = self.save_location / cfg.best_val_model_name - else: - self.checkpoint_prefix_path = self.save_location + self.checkpoint_prefix - self.best_val_path = self.save_location + cfg.best_val_model_name - - def save_model(self, - model_path: str, - state: TrainingState, - epoch: int, - i: int = None): - raise NotImplementedError + def __init__(self, cfg: ModelCheckpointConf): + super().__init__( + dirpath=cfg.dirpath, + filename=cfg.filename, + monitor=cfg.monitor, + verbose=cfg.verbose, + save_last=cfg.save_last, + save_top_k=cfg.save_top_k, + save_weights_only=cfg.save_weights_only, + mode=cfg.mode, + period=cfg.period, + prefix=cfg.prefix + ) def find_latest_checkpoint(self): raise NotImplementedError - def check_and_delete_oldest_checkpoint(self): - raise NotImplementedError - - def save_checkpoint_model(self, epoch, state, i=None): - if self.save_n_recent_models > 0: - self.check_and_delete_oldest_checkpoint() - model_path = self._create_checkpoint_path(epoch=epoch, - i=i) - self.save_model(model_path=model_path, - state=state, - epoch=epoch, - i=i) - - def save_iter_checkpoint_model(self, epoch, state, i): - if self.checkpoint_per_iteration > 0 and i > 0 and (i + 1) % self.checkpoint_per_iteration == 0: - self.save_checkpoint_model(epoch=epoch, - state=state, - i=i) - - def save_best_model(self, epoch, state): - self.save_model(model_path=self.best_val_path, - state=state, - epoch=epoch) - - def _create_checkpoint_path(self, epoch, i=None): - """ - Creates path to save checkpoint. - We automatically iterate the epoch and iteration for readibility. - :param epoch: The epoch (index starts at 0). - :param i: The iteration (index starts at 0). - :return: The path to save the model - """ - if i: - checkpoint_path = str(self.checkpoint_prefix_path) + 'epoch_%d_iter_%d.pth' % (epoch + 1, i + 1) - else: - checkpoint_path = str(self.checkpoint_prefix_path) + 'epoch_%d.pth' % (epoch + 1) - return checkpoint_path - class FileCheckpointHandler(CheckpointHandler): - def __init__(self, cfg: FileCheckpointConfig): - self.save_folder = Path(hydra.utils.to_absolute_path(cfg.save_folder)) - self.save_folder.mkdir(parents=True, exist_ok=True) # Ensure save folder exists - super().__init__(cfg=cfg, - save_location=self.save_folder) def find_latest_checkpoint(self): """ @@ -88,7 +38,7 @@ def find_latest_checkpoint(self): If there are no checkpoints, returns None. :return: The latest checkpoint path, or None if no checkpoints are found. 
""" - paths = list(self.save_folder.rglob(self.checkpoint_prefix + '*')) + paths = list(Path(self.dirpath).rglob(self.prefix + '*')) if paths: paths.sort(key=os.path.getctime) latest_checkpoint_path = paths[-1] @@ -96,28 +46,15 @@ def find_latest_checkpoint(self): else: return None - def check_and_delete_oldest_checkpoint(self): - paths = list(self.save_folder.rglob(self.checkpoint_prefix + '*')) - if paths and len(paths) >= self.save_n_recent_models: - paths.sort(key=os.path.getctime) - print("Deleting old checkpoint %s" % str(paths[0])) - os.remove(paths[0]) - - def save_model(self, model_path, state, epoch, i=None): - print("Saving model to %s" % model_path) - torch.save(obj=state.serialize_state(epoch=epoch, - iteration=i), - f=model_path) - class GCSCheckpointHandler(CheckpointHandler): def __init__(self, cfg: GCSCheckpointConfig): self.client = storage.Client() self.local_save_file = hydra.utils.to_absolute_path(cfg.local_save_file) self.gcs_bucket = cfg.gcs_bucket + self.gcs_save_folder = cfg.gcs_save_folder self.bucket = self.client.bucket(bucket_name=self.gcs_bucket) - super().__init__(cfg=cfg, - save_location=cfg.gcs_save_folder) + super().__init__(cfg=cfg) def find_latest_checkpoint(self): """ @@ -126,7 +63,7 @@ def find_latest_checkpoint(self): If there are no checkpoints, returns None. :return: The latest checkpoint path, or None if no checkpoints are found. """ - prefix = self.save_location + self.checkpoint_prefix + prefix = self.gcs_save_folder + self.prefix paths = list(self.client.list_blobs(self.gcs_bucket, prefix=prefix)) if paths: paths.sort(key=lambda x: x.time_created) @@ -136,20 +73,16 @@ def find_latest_checkpoint(self): else: return None - def check_and_delete_oldest_checkpoint(self): - prefix = self.save_location + self.checkpoint_prefix - paths = list(self.client.list_blobs(self.gcs_bucket, prefix=prefix)) - if paths and len(paths) >= self.save_n_recent_models: - paths.sort(key=lambda x: x.time_created) - print("Deleting old checkpoint %s" % paths[0].name) - paths[0].delete() - - def save_model(self, model_path, state, epoch, i=None): - print("Saving model to %s" % model_path) - torch.save(obj=state.serialize_state(epoch=epoch, - iteration=i), - f=self.local_save_file) - self._save_file_to_gcs(model_path) + def _save_model(self, filepath: str, trainer, pl_module): + + # in debugging, track when we save checkpoints + trainer.dev_debugger.track_checkpointing_history(filepath) + + # make paths + if trainer.is_global_zero: + tqdm.write("Saving model to %s" % filepath) + trainer.save_checkpoint(filepath) + self._save_file_to_gcs(filepath) def _save_file_to_gcs(self, model_path): blob = self.bucket.blob(model_path) diff --git a/deepspeech_pytorch/configs/inference_config.py b/deepspeech_pytorch/configs/inference_config.py index 7c8389f4..53464243 100644 --- a/deepspeech_pytorch/configs/inference_config.py +++ b/deepspeech_pytorch/configs/inference_config.py @@ -18,7 +18,7 @@ class LMConfig: @dataclass class ModelConfig: - use_half: bool = True # Use half precision. 
This is recommended when using mixed-precision at training time + precision: int = 32 # Set to 16 to use mixed-precision for inference cuda: bool = True model_path: str = '' @@ -37,7 +37,7 @@ class TranscribeConfig(InferenceConfig): @dataclass class EvalConfig(InferenceConfig): - test_manifest: str = '' # Path to validation manifest csv + test_path: str = '' # Path to validation manifest csv or folder verbose: bool = True # Print out decoded output and error of each sample save_output: str = '' # Saves output of model from test to this file_path batch_size: int = 20 # Batch size for testing diff --git a/deepspeech_pytorch/configs/train_config.py b/deepspeech_pytorch/configs/train_config.py index 4f831a30..68731331 100644 --- a/deepspeech_pytorch/configs/train_config.py +++ b/deepspeech_pytorch/configs/train_config.py @@ -1,25 +1,19 @@ from dataclasses import dataclass, field from typing import Any, List -from deepspeech_pytorch.enums import DistributedBackend, SpectrogramWindow, RNNType +from hydra_configs.pytorch_lightning.callbacks import ModelCheckpointConf +from hydra_configs.pytorch_lightning.trainer import TrainerConf from omegaconf import MISSING +from deepspeech_pytorch.enums import SpectrogramWindow, RNNType + defaults = [ - {"optim": "sgd"}, + {"optim": "adam"}, {"model": "bidirectional"}, - {"checkpointing": "file"} + {"checkpoint": "file"} ] -@dataclass -class TrainingConfig: - no_cuda: bool = False # Enable CPU only training - finetune: bool = False # Fine-tune the model from checkpoint "continue_from" - seed: int = 123456 # Seed for generators - dist_backend: DistributedBackend = DistributedBackend.nccl # If using distribution, the backend to be used - epochs: int = 70 # Number of Training Epochs - - @dataclass class SpectConfig: sample_rate: int = 16000 # The sample rate for the data/model features @@ -40,9 +34,9 @@ class AugmentationConfig: @dataclass class DataConfig: - train_manifest: str = 'data/train_manifest.csv' - val_manifest: str = 'data/val_manifest.csv' - batch_size: int = 20 # Batch size for training + train_path: str = 'data/train_manifest.csv' + val_path: str = 'data/val_manifest.csv' + batch_size: int = 64 # Batch size for training num_workers: int = 4 # Number of workers used in data-loading labels_path: str = 'labels.json' # Contains tokens for model output spect: SpectConfig = SpectConfig() @@ -63,10 +57,9 @@ class UniDirectionalConfig(BiDirectionalConfig): @dataclass class OptimConfig: - learning_rate: float = 3e-4 # Initial Learning Rate - learning_anneal: float = 1.1 # Annealing applied to learning rate after each epoch + learning_rate: float = 1.5e-4 # Initial Learning Rate + learning_anneal: float = 0.99 # Annealing applied to learning rate after each epoch weight_decay: float = 1e-5 # Initial Weight Decay - max_norm: float = 400 # Norm cutoff to prevent explosion of gradients @dataclass @@ -81,40 +74,15 @@ class AdamConfig(OptimConfig): @dataclass -class CheckpointConfig: - continue_from: str = '' # Continue training from checkpoint model - checkpoint: bool = True # Enables epoch checkpoint saving of model - checkpoint_per_iteration: int = 0 # Save checkpoint per N number of iterations. 
Default is disabled - save_n_recent_models: int = 10 # Max number of checkpoints to save, delete older checkpoints - best_val_model_name: str = 'deepspeech_final.pth' # Name to save best validated model within the save folder - load_auto_checkpoint: bool = False # Automatically load the latest checkpoint from save folder - - -@dataclass -class FileCheckpointConfig(CheckpointConfig): - save_folder: str = 'models/' # Location to save checkpoint models - - -@dataclass -class GCSCheckpointConfig(CheckpointConfig): +class GCSCheckpointConfig(ModelCheckpointConf): gcs_bucket: str = MISSING # Bucket to store model checkpoints e.g bucket-name gcs_save_folder: str = MISSING # Folder to store model checkpoints in bucket e.g models/ local_save_file: str = './local_checkpoint.pth' # Place to store temp file on disk @dataclass -class VisualizationConfig: - id: str = 'DeepSpeech training' # Name to use when visualizing/storing the run - visdom: bool = False # Turn on visdom graphing - tensorboard: bool = False # Turn on Tensorboard graphing - log_dir: str = 'visualize/deepspeech_final' # Location of Tensorboard log - log_params: bool = False # Log parameter values and gradients - - -@dataclass -class ApexConfig: - opt_level: str = 'O1' # Apex optimization level, check https://nvidia.github.io/apex/amp.html for more information - loss_scale: int = 1 # Loss scaling used by Apex. Default is 1 due to warp-ctc not supporting scaling of gradients +class DeepSpeechTrainerConf(TrainerConf): + callbacks: Any = MISSING @dataclass @@ -122,9 +90,9 @@ class DeepSpeechConfig: defaults: List[Any] = field(default_factory=lambda: defaults) optim: Any = MISSING model: Any = MISSING - checkpointing: Any = MISSING - training: TrainingConfig = TrainingConfig() + checkpoint: Any = MISSING + trainer: DeepSpeechTrainerConf = DeepSpeechTrainerConf() data: DataConfig = DataConfig() augmentation: AugmentationConfig = AugmentationConfig() - apex: ApexConfig = ApexConfig() - visualization: VisualizationConfig = VisualizationConfig() + seed: int = 123456 # Seed for generators + load_auto_checkpoint: bool = False # Automatically load the latest checkpoint from save folder diff --git a/deepspeech_pytorch/data/data_opts.py b/deepspeech_pytorch/data/data_opts.py index ef177641..bd3e34ff 100644 --- a/deepspeech_pytorch/data/data_opts.py +++ b/deepspeech_pytorch/data/data_opts.py @@ -6,5 +6,6 @@ def add_data_opts(parser): help='Prunes training samples shorter than the min duration (given in seconds, default 1)') data_opts.add_argument('--max-duration', default=15, type=int, help='Prunes training samples longer than the max duration (given in seconds, default 15)') + parser.add_argument('--num-workers', default=4, type=int, help='Number of workers for processing data.') parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate') return parser diff --git a/deepspeech_pytorch/data/utils.py b/deepspeech_pytorch/data/utils.py index b8ea71a1..0ab323d2 100644 --- a/deepspeech_pytorch/data/utils.py +++ b/deepspeech_pytorch/data/utils.py @@ -1,38 +1,65 @@ from __future__ import print_function -import fnmatch -import io +import json import os -import subprocess +from multiprocessing import Pool +from pathlib import Path +from typing import Optional +import sox from tqdm import tqdm -def create_manifest(data_path, output_name, manifest_path, min_duration=None, max_duration=None): - file_paths = [os.path.join(dirpath, f) - for dirpath, dirnames, files in os.walk(data_path) - for f in fnmatch.filter(files, '*.wav')] - file_paths 
= order_and_prune_files(file_paths, min_duration, max_duration) - os.makedirs(manifest_path, exist_ok=True) - with io.FileIO(os.path.join(manifest_path, output_name), "w") as file: - for wav_path in tqdm(file_paths, total=len(file_paths)): - transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt') - sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n' - file.write(sample.encode('utf-8')) - print('\n') +def create_manifest( + data_path: str, + output_name: str, + manifest_path: str, + num_workers: int, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None): + data_path = os.path.abspath(data_path) + file_paths = list(Path(data_path).rglob('*.wav')) + file_paths = order_and_prune_files( + file_paths=file_paths, + min_duration=min_duration, + max_duration=max_duration, + num_workers=num_workers + ) + output_path = Path(manifest_path) / output_name + output_path.parent.mkdir(exist_ok=True, parents=True) -def order_and_prune_files(file_paths, min_duration, max_duration): + manifest = { + 'root_path': data_path, + 'samples': [] + } + for wav_path in tqdm(file_paths, total=len(file_paths)): + wav_path = wav_path.relative_to(data_path) + transcript_path = wav_path.parent.with_name("txt") / wav_path.with_suffix(".txt").name + manifest['samples'].append({ + 'wav_path': wav_path.as_posix(), + 'transcript_path': transcript_path.as_posix() + }) + + output_path.write_text(json.dumps(manifest), encoding='utf8') + + +def _duration_file_path(path): + return path, sox.file_info.duration(path) + + +def order_and_prune_files( + file_paths, + min_duration, + max_duration, + num_workers): + print("Gathering durations...") + with Pool(processes=num_workers) as p: + duration_file_paths = list(tqdm(p.imap(_duration_file_path, file_paths), total=len(file_paths))) print("Sorting manifests...") - duration_file_paths = [(path, float(subprocess.check_output( - ['soxi -D \"%s\"' % path.strip()], shell=True))) for path in file_paths] if min_duration and max_duration: print("Pruning manifests between %d and %d seconds" % (min_duration, max_duration)) duration_file_paths = [(path, duration) for path, duration in duration_file_paths if min_duration <= duration <= max_duration] - def func(element): - return element[1] - - duration_file_paths.sort(key=func) return [x[0] for x in duration_file_paths] # Remove durations diff --git a/deepspeech_pytorch/decoder.py b/deepspeech_pytorch/decoder.py index 6d4a64e4..1d6a3dee 100644 --- a/deepspeech_pytorch/decoder.py +++ b/deepspeech_pytorch/decoder.py @@ -15,7 +15,6 @@ # ---------------------------------------------------------------------------- # Modified to support pytorch Tensors -import Levenshtein as Lev import torch from six.moves import xrange @@ -39,37 +38,6 @@ def __init__(self, labels, blank_index=0): space_index = labels.index(' ') self.space_index = space_index - def wer(self, s1, s2): - """ - Computes the Word Error Rate, defined as the edit distance between the - two provided sentences after tokenizing to words. 
- Arguments: - s1 (string): space-separated sentence - s2 (string): space-separated sentence - """ - - # build mapping of words to integers - b = set(s1.split() + s2.split()) - word2char = dict(zip(b, range(len(b)))) - - # map the words to a char array (Levenshtein packages only accepts - # strings) - w1 = [chr(word2char[w]) for w in s1.split()] - w2 = [chr(word2char[w]) for w in s2.split()] - - return Lev.distance(''.join(w1), ''.join(w2)) - - def cer(self, s1, s2): - """ - Computes the Character Error Rate, defined as the edit distance. - - Arguments: - s1 (string): space-separated sentence - s2 (string): space-separated sentence - """ - s1, s2, = s1.replace(' ', ''), s2.replace(' ', '') - return Lev.distance(s1, s2) - def decode(self, probs, sizes=None): """ Given a matrix of character probabilities, returns the decoder's diff --git a/deepspeech_pytorch/enums.py b/deepspeech_pytorch/enums.py index 328fb07f..023ed767 100644 --- a/deepspeech_pytorch/enums.py +++ b/deepspeech_pytorch/enums.py @@ -1,17 +1,13 @@ from enum import Enum +from torch import nn + class DecoderType(Enum): greedy: str = 'greedy' beam: str = 'beam' -class DistributedBackend(Enum): - gloo = 'gloo' - mpi = 'mpi' - nccl = 'nccl' - - class SpectrogramWindow(Enum): hamming = 'hamming' hann = 'hann' @@ -20,6 +16,6 @@ class SpectrogramWindow(Enum): class RNNType(Enum): - lstm = 'lstm' - rnn = 'rnn' - gru = 'gru' + lstm = nn.LSTM + rnn = nn.RNN + gru = nn.GRU diff --git a/deepspeech_pytorch/inference.py b/deepspeech_pytorch/inference.py index cdb5e677..f15e4dfd 100644 --- a/deepspeech_pytorch/inference.py +++ b/deepspeech_pytorch/inference.py @@ -1,7 +1,9 @@ import json from typing import List +import hydra import torch +from torch.cuda.amp import autocast from deepspeech_pytorch.configs.inference_config import TranscribeConfig from deepspeech_pytorch.decoder import Decoder @@ -42,25 +44,34 @@ def decode_results(decoded_output: List, def transcribe(cfg: TranscribeConfig): device = torch.device("cuda" if cfg.model.cuda else "cpu") - model = load_model(device=device, - model_path=cfg.model.model_path, - use_half=cfg.model.use_half) + model = load_model( + device=device, + model_path=cfg.model.model_path + ) - decoder = load_decoder(labels=model.labels, - cfg=cfg.lm) + decoder = load_decoder( + labels=model.labels, + cfg=cfg.lm + ) - spect_parser = SpectrogramParser(audio_conf=model.audio_conf, - normalize=True) + spect_parser = SpectrogramParser( + audio_conf=model.spect_cfg, + normalize=True + ) - decoded_output, decoded_offsets = run_transcribe(audio_path=cfg.audio_path, - spect_parser=spect_parser, - model=model, - decoder=decoder, - device=device, - use_half=cfg.model.use_half) - results = decode_results(decoded_output=decoded_output, - decoded_offsets=decoded_offsets, - cfg=cfg) + decoded_output, decoded_offsets = run_transcribe( + audio_path=hydra.utils.to_absolute_path(cfg.audio_path), + spect_parser=spect_parser, + model=model, + decoder=decoder, + device=device, + precision=cfg.model.precision + ) + results = decode_results( + decoded_output=decoded_output, + decoded_offsets=decoded_offsets, + cfg=cfg + ) print(json.dumps(results)) @@ -69,13 +80,12 @@ def run_transcribe(audio_path: str, model: DeepSpeech, decoder: Decoder, device: torch.device, - use_half: bool): + precision: int): spect = spect_parser.parse_audio(audio_path).contiguous() spect = spect.view(1, 1, spect.size(0), spect.size(1)) spect = spect.to(device) - if use_half: - spect = spect.half() input_sizes = torch.IntTensor([spect.size(3)]).int() - out, 
output_sizes = model(spect, input_sizes) + with autocast(enabled=precision == 16): + out, output_sizes = model(spect, input_sizes) decoded_output, decoded_offsets = decoder.decode(out, output_sizes) return decoded_output, decoded_offsets diff --git a/deepspeech_pytorch/loader/data_loader.py b/deepspeech_pytorch/loader/data_loader.py index 18b37c8e..71833463 100644 --- a/deepspeech_pytorch/loader/data_loader.py +++ b/deepspeech_pytorch/loader/data_loader.py @@ -1,29 +1,29 @@ +import json import math import os +from pathlib import Path from tempfile import NamedTemporaryFile import librosa import numpy as np -import soundfile as sf import sox import torch from torch.utils.data import Dataset, Sampler, DistributedSampler, DataLoader +import torchaudio from deepspeech_pytorch.configs.train_config import SpectConfig, AugmentationConfig from deepspeech_pytorch.loader.spec_augment import spec_augment +torchaudio.set_audio_backend("sox_io") + def load_audio(path): - sound, sample_rate = sf.read(path, dtype='int16') - # TODO this should be 32768.0 to get twos-complement range. - # TODO the difference is negligible but should be fixed for new models. - sound = sound.astype('float32') / 32767 # normalize audio - if len(sound.shape) > 1: - if sound.shape[1] == 1: - sound = sound.squeeze() - else: - sound = sound.mean(axis=1) # multiple channels, average - return sound + sound, sample_rate = torchaudio.load(path) + if sound.shape[0] == 1: + sound = sound.squeeze() + else: + sound = sound.mean(axis=0) # multiple channels, average + return sound.numpy() class AudioParser(object): @@ -138,30 +138,27 @@ def parse_transcript(self, transcript_path): class SpectrogramDataset(Dataset, SpectrogramParser): def __init__(self, audio_conf: SpectConfig, - manifest_filepath: str, + input_path: str, labels: list, normalize: bool = False, - augmentation_conf: AugmentationConfig = None): + aug_cfg: AugmentationConfig = None): """ Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by a comma. Each new line is a different sample. Example below: /path/to/audio.wav,/path/to/audio.txt ... - + You can also pass the directory of dataset. :param audio_conf: Config containing the sample rate, window and the window length/stride in seconds - :param manifest_filepath: Path to manifest csv as describe above + :param input_path: Path to input. 
:param labels: List containing all the possible characters to map to :param normalize: Apply standard mean and deviation normalization to audio tensor :param augmentation_conf(Optional): Config containing the augmentation parameters """ - with open(manifest_filepath) as f: - ids = f.readlines() - ids = [x.strip().split(',') for x in ids] - self.ids = ids - self.size = len(ids) + self.ids = self._parse_input(input_path) + self.size = len(self.ids) self.labels_map = dict([(labels[i], i) for i in range(len(labels))]) - super(SpectrogramDataset, self).__init__(audio_conf, normalize, augmentation_conf) + super(SpectrogramDataset, self).__init__(audio_conf, normalize, aug_cfg) def __getitem__(self, index): sample = self.ids[index] @@ -170,6 +167,22 @@ def __getitem__(self, index): transcript = self.parse_transcript(transcript_path) return spect, transcript + def _parse_input(self, input_path): + ids = [] + if os.path.isdir(input_path): + for wav_path in Path(input_path).rglob('*.wav'): + transcript_path = str(wav_path).replace('/wav/', '/txt/').replace('.wav', '.txt') + ids.append((wav_path, transcript_path)) + else: + # Assume it is a manifest file + with open(input_path) as f: + manifest = json.load(f) + for sample in manifest['samples']: + wav_path = os.path.join(manifest['root_path'], sample['wav_path']) + transcript_path = os.path.join(manifest['root_path'], sample['transcript_path']) + ids.append((wav_path, transcript_path)) + return ids + def parse_transcript(self, transcript_path): with open(transcript_path, 'r', encoding='utf8') as transcript_file: transcript = transcript_file.read().replace('\n', '') @@ -202,7 +215,7 @@ def func(p): input_percentages[x] = seq_length / float(max_seqlength) target_sizes[x] = len(target) targets.extend(target) - targets = torch.IntTensor(targets) + targets = torch.tensor(targets, dtype=torch.long) return inputs, targets, input_percentages, target_sizes @@ -222,11 +235,12 @@ class DSRandomSampler(Sampler): This is essential since we support saving/loading state during an epoch. """ - def __init__(self, dataset, batch_size=1, start_index=0): + def __init__(self, dataset, batch_size=1): super().__init__(data_source=dataset) self.dataset = dataset - self.start_index = start_index + self.start_index = 0 + self.epoch = 0 self.batch_size = batch_size ids = list(range(len(self.dataset))) self.bins = [ids[i:i + self.batch_size] for i in range(0, len(ids), self.batch_size)] @@ -251,9 +265,6 @@ def __len__(self): def set_epoch(self, epoch): self.epoch = epoch - def reset_training_step(self, training_step): - self.start_index = training_step - class DSElasticDistributedSampler(DistributedSampler): """ @@ -261,9 +272,9 @@ class DSElasticDistributedSampler(DistributedSampler): This is essential since we support saving/loading state during an epoch. 
""" - def __init__(self, dataset, num_replicas=None, rank=None, start_index=0, batch_size=1): + def __init__(self, dataset, num_replicas=None, rank=None, batch_size=1): super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank) - self.start_index = start_index + self.start_index = 0 self.batch_size = batch_size ids = list(range(len(dataset))) self.bins = [ids[i:i + self.batch_size] for i in range(0, len(ids), self.batch_size)] @@ -297,13 +308,6 @@ def __iter__(self): def __len__(self): return self.num_samples - def reset_training_step(self, training_step): - self.start_index = training_step - self.num_samples = int( - math.ceil(float(len(self.bins) - self.start_index) / self.num_replicas) - ) - self.total_size = self.num_samples * self.num_replicas - def audio_with_sox(path, sample_rate, start_time, end_time): """ diff --git a/deepspeech_pytorch/loader/data_module.py b/deepspeech_pytorch/loader/data_module.py new file mode 100644 index 00000000..2b747317 --- /dev/null +++ b/deepspeech_pytorch/loader/data_module.py @@ -0,0 +1,62 @@ +import pytorch_lightning as pl +from hydra.utils import to_absolute_path + +from deepspeech_pytorch.configs.train_config import DataConfig +from deepspeech_pytorch.loader.data_loader import SpectrogramDataset, DSRandomSampler, AudioDataLoader, \ + DSElasticDistributedSampler + + +class DeepSpeechDataModule(pl.LightningDataModule): + + def __init__(self, + labels: list, + data_cfg: DataConfig, + normalize: bool, + is_distributed: bool): + super().__init__() + self.train_path = to_absolute_path(data_cfg.train_path) + self.val_path = to_absolute_path(data_cfg.val_path) + self.labels = labels + self.data_cfg = data_cfg + self.spect_cfg = data_cfg.spect + self.aug_cfg = data_cfg.augmentation + self.normalize = normalize + self.is_distributed = is_distributed + + def train_dataloader(self): + train_dataset = self._create_dataset(self.train_path) + if self.is_distributed: + train_sampler = DSElasticDistributedSampler( + dataset=train_dataset, + batch_size=self.data_cfg.batch_size + ) + else: + train_sampler = DSRandomSampler( + dataset=train_dataset, + batch_size=self.data_cfg.batch_size + ) + train_loader = AudioDataLoader( + dataset=train_dataset, + num_workers=self.data_cfg.num_workers, + batch_sampler=train_sampler + ) + return train_loader + + def val_dataloader(self): + val_dataset = self._create_dataset(self.val_path) + val_loader = AudioDataLoader( + dataset=val_dataset, + num_workers=self.data_cfg.num_workers, + batch_size=self.data_cfg.batch_size + ) + return val_loader + + def _create_dataset(self, input_path): + dataset = SpectrogramDataset( + audio_conf=self.spect_cfg, + input_path=input_path, + labels=self.labels, + normalize=True, + aug_cfg=self.aug_cfg + ) + return dataset diff --git a/deepspeech_pytorch/loader/merge_manifests.py b/deepspeech_pytorch/loader/merge_manifests.py deleted file mode 100644 index a74ece2a..00000000 --- a/deepspeech_pytorch/loader/merge_manifests.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import print_function - -import argparse -import io -import os - -from tqdm import tqdm -from utils import order_and_prune_files - -parser = argparse.ArgumentParser(description='Merges all manifest CSV files in specified folder.') -parser.add_argument('--merge-dir', default='manifests/', help='Path to all manifest files you want to merge') -parser.add_argument('--min-duration', default=1, type=int, - help='Prunes any samples shorter than the min duration (given in seconds, default 1)') 
-parser.add_argument('--max-duration', default=15, type=int, - help='Prunes any samples longer than the max duration (given in seconds, default 15)') -parser.add_argument('--output-path', default='merged_manifest.csv', help='Output path to merged manifest') - -args = parser.parse_args() - -file_paths = [] -for file in os.listdir(args.merge_dir): - if file.endswith(".csv"): - with open(os.path.join(args.merge_dir, file), 'r') as fh: - file_paths += fh.readlines() -file_paths = [file_path.split(',')[0] for file_path in file_paths] -file_paths = order_and_prune_files(file_paths, args.min_duration, args.max_duration) -with io.FileIO(args.output_path, "w") as file: - for wav_path in tqdm(file_paths, total=len(file_paths)): - transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt') - sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n' - file.write(sample.encode('utf-8')) diff --git a/deepspeech_pytorch/logger.py b/deepspeech_pytorch/logger.py deleted file mode 100644 index 3383456c..00000000 --- a/deepspeech_pytorch/logger.py +++ /dev/null @@ -1,72 +0,0 @@ -import os - -import torch - - -def to_np(x): - return x.cpu().numpy() - - -class VisdomLogger(object): - def __init__(self, id, num_epochs): - from visdom import Visdom - self.viz = Visdom() - self.opts = dict(title=id, ylabel='', xlabel='Epoch', legend=['Loss', 'WER', 'CER']) - self.viz_window = None - self.epochs = torch.arange(1, num_epochs + 1) - self.visdom_plotter = True - - def update(self, epoch, values): - x_axis = self.epochs[0:epoch + 1] - y_axis = torch.stack((values.loss_results[:epoch], - values.wer_results[:epoch], - values.cer_results[:epoch]), - dim=1) - self.viz_window = self.viz.line( - X=x_axis, - Y=y_axis, - opts=self.opts, - win=self.viz_window, - update='replace' if self.viz_window else None - ) - - def load_previous_values(self, start_epoch, results_state): - self.update(start_epoch - 1, results_state) # Add all values except the iteration we're starting from - - -class TensorBoardLogger(object): - def __init__(self, id, log_dir, log_params): - os.makedirs(log_dir, exist_ok=True) - from torch.utils.tensorboard import SummaryWriter - self.id = id - self.tensorboard_writer = SummaryWriter(log_dir) - self.log_params = log_params - - def update(self, epoch, results_state, parameters=None): - loss = results_state.loss_results[epoch] - wer = results_state.wer_results[epoch] - cer = results_state.cer_results[epoch] - values = { - 'Avg Train Loss': loss, - 'Avg WER': wer, - 'Avg CER': cer - } - self.tensorboard_writer.add_scalars(self.id, values, epoch + 1) - if self.log_params: - for tag, value in parameters(): - tag = tag.replace('.', '/') - self.tensorboard_writer.add_histogram(tag, to_np(value), epoch + 1) - self.tensorboard_writer.add_histogram(tag + '/grad', to_np(value.grad), epoch + 1) - - def load_previous_values(self, start_epoch, result_state): - loss_results = result_state.loss_results[:start_epoch] - wer_results = result_state.wer_results[:start_epoch] - cer_results = result_state.cer_results[:start_epoch] - - for i in range(start_epoch): - values = { - 'Avg Train Loss': loss_results[i], - 'Avg WER': wer_results[i], - 'Avg CER': cer_results[i] - } - self.tensorboard_writer.add_scalars(self.id, values, i + 1) diff --git a/deepspeech_pytorch/model.py b/deepspeech_pytorch/model.py index 595e1798..e695dde2 100644 --- a/deepspeech_pytorch/model.py +++ b/deepspeech_pytorch/model.py @@ -1,21 +1,18 @@ import math -from collections import OrderedDict +from typing import 
List, Union +import pytorch_lightning as pl import torch import torch.nn as nn import torch.nn.functional as F -# Due to backwards compatibility we need to keep the below structure for mapping RNN type from omegaconf import OmegaConf +from torch.cuda.amp import autocast +from torch.nn import CTCLoss -from deepspeech_pytorch.configs.train_config import SpectConfig -from deepspeech_pytorch.enums import SpectrogramWindow - -supported_rnns = { - 'lstm': nn.LSTM, - 'rnn': nn.RNN, - 'gru': nn.GRU -} -supported_rnns_inv = dict((v, k) for k, v in supported_rnns.items()) +from deepspeech_pytorch.configs.train_config import SpectConfig, BiDirectionalConfig, OptimConfig, AdamConfig, \ + SGDConfig, UniDirectionalConfig +from deepspeech_pytorch.decoder import GreedyDecoder +from deepspeech_pytorch.validation import CharErrorRate, WordErrorRate class SequenceWise(nn.Module): @@ -115,8 +112,15 @@ def __init__(self, n_features, context): self.context = context self.n_features = n_features self.pad = (0, self.context - 1) - self.conv = nn.Conv1d(self.n_features, self.n_features, kernel_size=self.context, stride=1, - groups=self.n_features, padding=0, bias=None) + self.conv = nn.Conv1d( + self.n_features, + self.n_features, + kernel_size=self.context, + stride=1, + groups=self.n_features, + padding=0, + bias=False + ) def forward(self, x): x = x.transpose(0, 1).transpose(1, 2) @@ -131,20 +135,23 @@ def __repr__(self): + ', context=' + str(self.context) + ')' -class DeepSpeech(nn.Module): - def __init__(self, rnn_type, labels, rnn_hidden_size, nb_layers, audio_conf, - bidirectional, context=20): - super(DeepSpeech, self).__init__() +class DeepSpeech(pl.LightningModule): + def __init__(self, + labels: List, + model_cfg: Union[UniDirectionalConfig, BiDirectionalConfig], + precision: int, + optim_cfg: Union[AdamConfig, SGDConfig], + spect_cfg: SpectConfig + ): + super().__init__() + self.save_hyperparameters() + self.model_cfg = model_cfg + self.precision = precision + self.optim_cfg = optim_cfg + self.spect_cfg = spect_cfg + self.bidirectional = True if OmegaConf.get_type(model_cfg) is BiDirectionalConfig else False - self.hidden_size = rnn_hidden_size - self.hidden_layers = nb_layers - self.rnn_type = rnn_type - self.audio_conf = audio_conf self.labels = labels - self.bidirectional = bidirectional - - sample_rate = self.audio_conf.sample_rate - window_size = self.audio_conf.window_size num_classes = len(self.labels) self.conv = MaskConv(nn.Sequential( @@ -156,34 +163,53 @@ def __init__(self, rnn_type, labels, rnn_hidden_size, nb_layers, audio_conf, nn.Hardtanh(0, 20, inplace=True) )) # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1 - rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1) + rnn_input_size = int(math.floor((self.spect_cfg.sample_rate * self.spect_cfg.window_size) / 2) + 1) rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1) rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1) rnn_input_size *= 32 - rnns = [] - rnn = BatchRNN(input_size=rnn_input_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type, - bidirectional=bidirectional, batch_norm=False) - rnns.append(('0', rnn)) - for x in range(nb_layers - 1): - rnn = BatchRNN(input_size=rnn_hidden_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type, - bidirectional=bidirectional) - rnns.append(('%d' % (x + 1), rnn)) - self.rnns = nn.Sequential(OrderedDict(rnns)) + self.rnns = nn.Sequential( + BatchRNN( + input_size=rnn_input_size, + 
hidden_size=self.model_cfg.hidden_size, + rnn_type=self.model_cfg.rnn_type.value, + bidirectional=self.bidirectional, + batch_norm=False + ), + *( + BatchRNN( + input_size=self.model_cfg.hidden_size, + hidden_size=self.model_cfg.hidden_size, + rnn_type=self.model_cfg.rnn_type.value, + bidirectional=self.bidirectional + ) for x in range(self.model_cfg.hidden_layers - 1) + ) + ) + self.lookahead = nn.Sequential( # consider adding batch norm? - Lookahead(rnn_hidden_size, context=context), + Lookahead(self.model_cfg.hidden_size, context=self.model_cfg.lookahead_context), nn.Hardtanh(0, 20, inplace=True) - ) if not bidirectional else None + ) if not self.bidirectional else None fully_connected = nn.Sequential( - nn.BatchNorm1d(rnn_hidden_size), - nn.Linear(rnn_hidden_size, num_classes, bias=False) + nn.BatchNorm1d(self.model_cfg.hidden_size), + nn.Linear(self.model_cfg.hidden_size, num_classes, bias=False) ) self.fc = nn.Sequential( SequenceWise(fully_connected), ) self.inference_softmax = InferenceBatchSoftmax() + self.criterion = CTCLoss(blank=self.labels.index('_'), reduction='sum', zero_infinity=True) + self.evaluation_decoder = GreedyDecoder(self.labels) # Decoder used for validation + self.wer = WordErrorRate( + decoder=self.evaluation_decoder, + target_decoder=self.evaluation_decoder + ) + self.cer = CharErrorRate( + decoder=self.evaluation_decoder, + target_decoder=self.evaluation_decoder + ) def forward(self, x, lengths): lengths = lengths.cpu().int() @@ -206,6 +232,64 @@ def forward(self, x, lengths): x = self.inference_softmax(x) return x, output_lengths + def training_step(self, batch, batch_idx): + inputs, targets, input_percentages, target_sizes = batch + input_sizes = input_percentages.mul_(int(inputs.size(3))).int() + out, output_sizes = self(inputs, input_sizes) + out = out.transpose(0, 1) # TxNxH + out = out.log_softmax(-1) + + loss = self.criterion(out, targets, output_sizes, target_sizes) + return loss + + def validation_step(self, batch, batch_idx): + inputs, targets, input_percentages, target_sizes = batch + input_sizes = input_percentages.mul_(int(inputs.size(3))).int() + inputs = inputs.to(self.device) + with autocast(enabled=self.precision == 16): + out, output_sizes = self(inputs, input_sizes) + decoded_output, _ = self.evaluation_decoder.decode(out, output_sizes) + self.wer( + preds=out, + preds_sizes=output_sizes, + targets=targets, + target_sizes=target_sizes + ) + self.cer( + preds=out, + preds_sizes=output_sizes, + targets=targets, + target_sizes=target_sizes + ) + self.log('wer', self.wer.compute(), prog_bar=True, on_epoch=True) + self.log('cer', self.cer.compute(), prog_bar=True, on_epoch=True) + + def configure_optimizers(self): + if OmegaConf.get_type(self.optim_cfg) is SGDConfig: + optimizer = torch.optim.SGD( + params=self.parameters(), + lr=self.optim_cfg.learning_rate, + momentum=self.optim_cfg.momentum, + nesterov=True, + weight_decay=self.optim_cfg.weight_decay + ) + elif OmegaConf.get_type(self.optim_cfg) is AdamConfig: + optimizer = torch.optim.AdamW( + params=self.parameters(), + lr=self.optim_cfg.learning_rate, + betas=self.optim_cfg.betas, + eps=self.optim_cfg.eps, + weight_decay=self.optim_cfg.weight_decay + ) + else: + raise ValueError("Optimizer has not been specified correctly.") + + scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, + gamma=self.optim_cfg.learning_anneal + ) + return [optimizer], [scheduler] + def get_seq_lens(self, input_length): """ Given a 1D Tensor or Variable containing integer sequence lengths, 
return a 1D tensor or variable @@ -218,47 +302,3 @@ def get_seq_lens(self, input_length): if type(m) == nn.modules.conv.Conv2d: seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) // m.stride[1] + 1) return seq_len.int() - - @classmethod - def load_model(cls, path): - package = torch.load(path, map_location=lambda storage, loc: storage) - model = DeepSpeech.load_model_package(package) - return model - - @classmethod - def load_model_package(cls, package): - # TODO Added for backwards compatibility, should be remove for new release - if OmegaConf.get_type(package['audio_conf']) == dict: - audio_conf = package['audio_conf'] - package['audio_conf'] = SpectConfig(sample_rate=audio_conf['sample_rate'], - window_size=audio_conf['window_size'], - window=SpectrogramWindow(audio_conf['window'])) - model = cls(rnn_hidden_size=package['hidden_size'], - nb_layers=package['hidden_layers'], - labels=package['labels'], - audio_conf=package['audio_conf'], - rnn_type=supported_rnns[package['rnn_type']], - bidirectional=package.get('bidirectional', True)) - model.load_state_dict(package['state_dict']) - return model - - def serialize_state(self): - return { - 'hidden_size': self.hidden_size, - 'hidden_layers': self.hidden_layers, - 'rnn_type': supported_rnns_inv.get(self.rnn_type, self.rnn_type.__name__.lower()), - 'audio_conf': self.audio_conf, - 'labels': self.labels, - 'state_dict': self.state_dict(), - 'bidirectional': self.bidirectional, - } - - @staticmethod - def get_param_size(model): - params = 0 - for p in model.parameters(): - tmp = 1 - for x in p.size(): - tmp *= x - params += tmp - return params diff --git a/deepspeech_pytorch/state.py b/deepspeech_pytorch/state.py deleted file mode 100644 index eea87cce..00000000 --- a/deepspeech_pytorch/state.py +++ /dev/null @@ -1,171 +0,0 @@ -import torch - -from deepspeech_pytorch.model import DeepSpeech -from deepspeech_pytorch.utils import remove_parallel_wrapper - - -class ResultState: - def __init__(self, - loss_results, - wer_results, - cer_results): - self.loss_results = loss_results - self.wer_results = wer_results - self.cer_results = cer_results - - def add_results(self, - epoch, - loss_result, - wer_result, - cer_result): - self.loss_results[epoch] = loss_result - self.wer_results[epoch] = wer_result - self.cer_results[epoch] = cer_result - - def serialize_state(self): - return { - 'loss_results': self.loss_results, - 'wer_results': self.wer_results, - 'cer_results': self.cer_results - } - - -class TrainingState: - def __init__(self, - model, - result_state=None, - optim_state=None, - amp_state=None, - best_wer=None, - avg_loss=0, - epoch=0, - training_step=0): - """ - Wraps around training model and states for saving/loading convenience. - For backwards compatibility there are more states being saved than necessary. 
- """ - self.model = model - self.result_state = result_state - self.optim_state = optim_state - self.amp_state = amp_state - self.best_wer = best_wer - self.avg_loss = avg_loss - self.epoch = epoch - self.training_step = training_step - - def track_optim_state(self, optimizer): - self.optim_state = optimizer.state_dict() - - def track_amp_state(self, amp): - self.amp_state = amp.state_dict() - - def init_results_tracking(self, epochs): - self.result_state = ResultState(loss_results=torch.FloatTensor(epochs), - wer_results=torch.FloatTensor(epochs), - cer_results=torch.FloatTensor(epochs)) - - def add_results(self, - epoch, - loss_result, - wer_result, - cer_result): - self.result_state.add_results(epoch=epoch, - loss_result=loss_result, - wer_result=wer_result, - cer_result=cer_result) - - def init_finetune_states(self, epochs): - """ - Resets the training environment, but keeps model specific states in tact. - This is when fine-tuning a model on another dataset where training is to be reset but the model - weights are to be loaded - :param epochs: Number of epochs fine-tuning. - """ - self.init_results_tracking(epochs) - self._reset_amp_state() - self._reset_optim_state() - self._reset_epoch() - self.reset_training_step() - self._reset_best_wer() - - def serialize_state(self, epoch, iteration): - model = remove_parallel_wrapper(self.model) - model_dict = model.serialize_state() - training_dict = self._serialize_training_state(epoch=epoch, - iteration=iteration) - results_dict = self.result_state.serialize_state() - # Ensure flat structure for backwards compatibility - state_dict = {**model_dict, **training_dict, **results_dict} # Combine dicts - return state_dict - - def _serialize_training_state(self, epoch, iteration): - return { - 'optim_dict': self.optim_state, - 'amp': self.amp_state, - 'avg_loss': self.avg_loss, - 'best_wer': self.best_wer, - 'epoch': epoch + 1, # increment for readability - 'iteration': iteration, - } - - @classmethod - def load_state(cls, state_path): - print("Loading state from model %s" % state_path) - state = torch.load(state_path, map_location=lambda storage, loc: storage) - model = DeepSpeech.load_model_package(state) - optim_state = state['optim_dict'] - amp_state = state.get('amp') - if not amp_state: - print("WARNING: No state for Apex has been stored in the model.") - - epoch = int(state.get('epoch', 1)) - 1 # Index start at 0 for training - training_step = state.get('iteration', None) - if training_step is None: - epoch += 1 # We saved model after epoch finished, start at the next epoch. 
- training_step = 0 - else: - training_step += 1 - avg_loss = int(state.get('avg_loss', 0)) - loss_results = state['loss_results'] - cer_results = state['cer_results'] - wer_results = state['wer_results'] - best_wer = state.get('best_wer') - - result_state = ResultState(loss_results=loss_results, - cer_results=cer_results, - wer_results=wer_results) - return cls(optim_state=optim_state, - amp_state=amp_state, - model=model, - result_state=result_state, - best_wer=best_wer, - avg_loss=avg_loss, - epoch=epoch, - training_step=training_step) - - def set_epoch(self, epoch): - self.epoch = epoch - - def set_best_wer(self, wer): - self.best_wer = wer - - def set_training_step(self, training_step): - self.training_step = training_step - - def reset_training_step(self): - self.training_step = 0 - - def reset_avg_loss(self): - self.avg_loss = 0 - - def _reset_amp_state(self): - self.amp_state = None - - def _reset_optim_state(self): - self.optim_state = None - - def _reset_epoch(self): - self.epoch = 0 - - def _reset_best_wer(self): - self.best_wer = None diff --git a/deepspeech_pytorch/testing.py b/deepspeech_pytorch/testing.py index e18bc1c9..c464e735 100644 --- a/deepspeech_pytorch/testing.py +++ b/deepspeech_pytorch/testing.py @@ -1,94 +1,50 @@ import hydra import torch -from tqdm import tqdm from deepspeech_pytorch.configs.inference_config import EvalConfig from deepspeech_pytorch.decoder import GreedyDecoder from deepspeech_pytorch.loader.data_loader import SpectrogramDataset, AudioDataLoader from deepspeech_pytorch.utils import load_model, load_decoder +from deepspeech_pytorch.validation import run_evaluation @torch.no_grad() def evaluate(cfg: EvalConfig): device = torch.device("cuda" if cfg.model.cuda else "cpu") - model = load_model(device=device, - model_path=cfg.model.model_path, - use_half=cfg.model.use_half) - - decoder = load_decoder(labels=model.labels, - cfg=cfg.lm) - target_decoder = GreedyDecoder(model.labels, - blank_index=model.labels.index('_')) - test_dataset = SpectrogramDataset(audio_conf=model.audio_conf, - manifest_filepath=hydra.utils.to_absolute_path(cfg.test_manifest), - labels=model.labels, - normalize=True) - test_loader = AudioDataLoader(test_dataset, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers) - wer, cer, output_data = run_evaluation(test_loader=test_loader, - device=device, - model=model, - decoder=decoder, - target_decoder=target_decoder, - save_output=cfg.save_output, - verbose=cfg.verbose, - use_half=cfg.model.use_half) + model = load_model( + device=device, + model_path=cfg.model.model_path + ) + + decoder = load_decoder( + labels=model.labels, + cfg=cfg.lm + ) + target_decoder = GreedyDecoder( + labels=model.labels, + blank_index=model.labels.index('_') + ) + test_dataset = SpectrogramDataset( + audio_conf=model.spect_cfg, + input_path=hydra.utils.to_absolute_path(cfg.test_path), + labels=model.labels, + normalize=True + ) + test_loader = AudioDataLoader( + test_dataset, + batch_size=cfg.batch_size, + num_workers=cfg.num_workers + ) + wer, cer = run_evaluation( + test_loader=test_loader, + device=device, + model=model, + decoder=decoder, + target_decoder=target_decoder, + precision=cfg.model.precision + ) print('Test Summary \t' 'Average WER {wer:.3f}\t' 'Average CER {cer:.3f}\t'.format(wer=wer, cer=cer)) - if cfg.save_output: - torch.save(output_data, hydra.utils.to_absolute_path(cfg.save_output)) - - -@torch.no_grad() -def run_evaluation(test_loader, - device, - model, - decoder, - target_decoder, - save_output=None, - verbose=False, - 
use_half=False): - model.eval() - total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0 - output_data = [] - for i, (data) in tqdm(enumerate(test_loader), total=len(test_loader)): - inputs, targets, input_percentages, target_sizes = data - input_sizes = input_percentages.mul_(int(inputs.size(3))).int() - inputs = inputs.to(device) - if use_half: - inputs = inputs.half() - # unflatten targets - split_targets = [] - offset = 0 - for size in target_sizes: - split_targets.append(targets[offset:offset + size]) - offset += size - - out, output_sizes = model(inputs, input_sizes) - - decoded_output, _ = decoder.decode(out, output_sizes) - target_strings = target_decoder.convert_to_strings(split_targets) - - if save_output is not None: - # add output to data array, and continue - output_data.append((out.cpu(), output_sizes, target_strings)) - for x in range(len(target_strings)): - transcript, reference = decoded_output[x][0], target_strings[x][0] - wer_inst = decoder.wer(transcript, reference) - cer_inst = decoder.cer(transcript, reference) - total_wer += wer_inst - total_cer += cer_inst - num_tokens += len(reference.split()) - num_chars += len(reference.replace(' ', '')) - if verbose: - print("Ref:", reference.lower()) - print("Hyp:", transcript.lower()) - print("WER:", float(wer_inst) / len(reference.split()), - "CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n") - wer = float(total_wer) / num_tokens - cer = float(total_cer) / num_chars - return wer * 100, cer * 100, output_data diff --git a/deepspeech_pytorch/training.py b/deepspeech_pytorch/training.py index 86e66bc6..7c018170 100644 --- a/deepspeech_pytorch/training.py +++ b/deepspeech_pytorch/training.py @@ -1,288 +1,54 @@ import json -import os -import random -import time -import numpy as np -import torch.distributed as dist -import torch.utils.data.distributed -from apex import amp +import hydra +from deepspeech_pytorch.checkpoint import CheckpointHandler, GCSCheckpointHandler +from deepspeech_pytorch.configs.train_config import DeepSpeechConfig, GCSCheckpointConfig +from deepspeech_pytorch.loader.data_module import DeepSpeechDataModule +from deepspeech_pytorch.model import DeepSpeech from hydra.utils import to_absolute_path from omegaconf import OmegaConf -from torch.nn.parallel import DistributedDataParallel -from warpctc_pytorch import CTCLoss +from pytorch_lightning import seed_everything -from deepspeech_pytorch.checkpoint import FileCheckpointHandler, GCSCheckpointHandler -from deepspeech_pytorch.configs.train_config import SGDConfig, AdamConfig, BiDirectionalConfig, UniDirectionalConfig, \ - FileCheckpointConfig, GCSCheckpointConfig -from deepspeech_pytorch.decoder import GreedyDecoder -from deepspeech_pytorch.loader.data_loader import SpectrogramDataset, DSRandomSampler, DSElasticDistributedSampler, \ - AudioDataLoader -from deepspeech_pytorch.logger import VisdomLogger, TensorBoardLogger -from deepspeech_pytorch.model import DeepSpeech, supported_rnns -from deepspeech_pytorch.state import TrainingState -from deepspeech_pytorch.testing import run_evaluation -from deepspeech_pytorch.utils import check_loss +def train(cfg: DeepSpeechConfig): + seed_everything(cfg.seed) -class AverageMeter(object): - """Computes and stores the average and current value""" + with open(to_absolute_path(cfg.data.labels_path)) as label_file: + labels = json.load(label_file) - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val 
- self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def train(cfg): - # Set seeds for determinism - torch.manual_seed(cfg.training.seed) - torch.cuda.manual_seed_all(cfg.training.seed) - np.random.seed(cfg.training.seed) - random.seed(cfg.training.seed) - - main_proc = True - device = torch.device("cpu" if cfg.training.no_cuda else "cuda") - - is_distributed = os.environ.get("LOCAL_RANK") # If local rank exists, distributed env - - if is_distributed: - # when using NCCL, on failures, surviving nodes will deadlock on NCCL ops - # because NCCL uses a spin-lock on the device. Set this env var and - # to enable a watchdog thread that will destroy stale NCCL communicators - os.environ["NCCL_BLOCKING_WAIT"] = "1" - - device_id = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(device_id) - print(f"Setting CUDA Device to {device_id}") - - dist.init_process_group(backend=cfg.training.dist_backend.value) - main_proc = device_id == 0 # Main process handles saving of models and reporting - - if OmegaConf.get_type(cfg.checkpointing) == FileCheckpointConfig: - checkpoint_handler = FileCheckpointHandler(cfg=cfg.checkpointing) - elif OmegaConf.get_type(cfg.checkpointing) == GCSCheckpointConfig: - checkpoint_handler = GCSCheckpointHandler(cfg=cfg.checkpointing) - else: - raise ValueError("Checkpoint Config has not been specified correctly.") - - if main_proc and cfg.visualization.visdom: - visdom_logger = VisdomLogger(id=cfg.visualization.id, - num_epochs=cfg.training.epochs) - if main_proc and cfg.visualization.tensorboard: - tensorboard_logger = TensorBoardLogger(id=cfg.visualization.id, - log_dir=to_absolute_path(cfg.visualization.log_dir), - log_params=cfg.visualization.log_params) - - if cfg.checkpointing.load_auto_checkpoint: - latest_checkpoint = checkpoint_handler.find_latest_checkpoint() - if latest_checkpoint: - cfg.checkpointing.continue_from = latest_checkpoint - - if cfg.checkpointing.continue_from: # Starting from previous model - state = TrainingState.load_state(state_path=to_absolute_path(cfg.checkpointing.continue_from)) - model = state.model - if cfg.training.finetune: - state.init_finetune_states(cfg.training.epochs) - - if main_proc and cfg.visualization.visdom: # Add previous scores to visdom graph - visdom_logger.load_previous_values(state.epoch, state.results) - if main_proc and cfg.visualization.tensorboard: # Previous scores to tensorboard logs - tensorboard_logger.load_previous_values(state.epoch, state.results) - else: - # Initialise new model training - with open(to_absolute_path(cfg.data.labels_path)) as label_file: - labels = json.load(label_file) - - if OmegaConf.get_type(cfg.model) is BiDirectionalConfig: - model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size, - nb_layers=cfg.model.hidden_layers, - labels=labels, - rnn_type=supported_rnns[cfg.model.rnn_type.value], - audio_conf=cfg.data.spect, - bidirectional=True) - elif OmegaConf.get_type(cfg.model) is UniDirectionalConfig: - model = DeepSpeech(rnn_hidden_size=cfg.model.hidden_size, - nb_layers=cfg.model.hidden_layers, - labels=labels, - rnn_type=supported_rnns[cfg.model.rnn_type.value], - audio_conf=cfg.data.spect, - bidirectional=False, - context=cfg.model.lookahead_context) + if cfg.trainer.checkpoint_callback: + if OmegaConf.get_type(cfg.checkpoint) is GCSCheckpointConfig: + checkpoint_callback = GCSCheckpointHandler( + cfg=cfg.checkpoint + ) + cfg.trainer.callbacks = [checkpoint_callback] else: - raise ValueError("Model Config has not been specified correctly.") - - state = 
TrainingState(model=model) - state.init_results_tracking(epochs=cfg.training.epochs) - - # Data setup - evaluation_decoder = GreedyDecoder(model.labels) # Decoder used for validation - train_dataset = SpectrogramDataset(audio_conf=model.audio_conf, - manifest_filepath=to_absolute_path(cfg.data.train_manifest), - labels=model.labels, - normalize=True, - augmentation_conf=cfg.data.augmentation) - test_dataset = SpectrogramDataset(audio_conf=model.audio_conf, - manifest_filepath=to_absolute_path(cfg.data.val_manifest), - labels=model.labels, - normalize=True) - if not is_distributed: - train_sampler = DSRandomSampler(dataset=train_dataset, - batch_size=cfg.data.batch_size, - start_index=state.training_step) - else: - train_sampler = DSElasticDistributedSampler(dataset=train_dataset, - batch_size=cfg.data.batch_size, - start_index=state.training_step) - train_loader = AudioDataLoader(dataset=train_dataset, - num_workers=cfg.data.num_workers, - batch_sampler=train_sampler) - test_loader = AudioDataLoader(dataset=test_dataset, - num_workers=cfg.data.num_workers, - batch_size=cfg.data.batch_size) - - model = model.to(device) - parameters = model.parameters() - if OmegaConf.get_type(cfg.optim) is SGDConfig: - optimizer = torch.optim.SGD(parameters, - lr=cfg.optim.learning_rate, - momentum=cfg.optim.momentum, - nesterov=True, - weight_decay=cfg.optim.weight_decay) - elif OmegaConf.get_type(cfg.optim) is AdamConfig: - optimizer = torch.optim.AdamW(parameters, - lr=cfg.optim.learning_rate, - betas=cfg.optim.betas, - eps=cfg.optim.eps, - weight_decay=cfg.optim.weight_decay) - else: - raise ValueError("Optimizer has not been specified correctly.") - - model, optimizer = amp.initialize(model, optimizer, - enabled=not cfg.training.no_cuda, - opt_level=cfg.apex.opt_level, - loss_scale=cfg.apex.loss_scale) - if state.optim_state is not None: - optimizer.load_state_dict(state.optim_state) - if state.amp_state is not None: - amp.load_state_dict(state.amp_state) - - # Track states for optimizer/amp - state.track_optim_state(optimizer) - if not cfg.training.no_cuda: - state.track_amp_state(amp) - - if is_distributed: - model = DistributedDataParallel(model, device_ids=[device_id]) - print(model) - print("Number of parameters: %d" % DeepSpeech.get_param_size(model)) - - criterion = CTCLoss() - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - - for epoch in range(state.epoch, cfg.training.epochs): - model.train() - end = time.time() - start_epoch_time = time.time() - state.set_epoch(epoch=epoch) - train_sampler.set_epoch(epoch=epoch) - train_sampler.reset_training_step(training_step=state.training_step) - for i, (data) in enumerate(train_loader, start=state.training_step): - state.set_training_step(training_step=i) - inputs, targets, input_percentages, target_sizes = data - input_sizes = input_percentages.mul_(int(inputs.size(3))).int() - # measure data loading time - data_time.update(time.time() - end) - inputs = inputs.to(device) - - out, output_sizes = model(inputs, input_sizes) - out = out.transpose(0, 1) # TxNxH - - float_out = out.float() # ensure float32 for loss - loss = criterion(float_out, targets, output_sizes, target_sizes).to(device) - loss = loss / inputs.size(0) # average the loss by minibatch - loss_value = loss.item() - - # Check to ensure valid loss was calculated - valid_loss, error = check_loss(loss, loss_value) - if valid_loss: - optimizer.zero_grad() - - # compute gradient - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - 
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), cfg.optim.max_norm) - optimizer.step() - else: - print(error) - print('Skipping grad update') - loss_value = 0 - - state.avg_loss += loss_value - losses.update(loss_value, inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - print('Epoch: [{0}][{1}/{2}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( - (epoch + 1), (i + 1), len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses)) - - if main_proc and cfg.checkpointing.checkpoint_per_iteration: - checkpoint_handler.save_iter_checkpoint_model(epoch=epoch, i=i, state=state) - del loss, out, float_out - - state.avg_loss /= len(train_dataset) - - epoch_time = time.time() - start_epoch_time - print('Training Summary Epoch: [{0}]\t' - 'Time taken (s): {epoch_time:.0f}\t' - 'Average Loss {loss:.3f}\t'.format(epoch + 1, epoch_time=epoch_time, loss=state.avg_loss)) - - with torch.no_grad(): - wer, cer, output_data = run_evaluation(test_loader=test_loader, - device=device, - model=model, - decoder=evaluation_decoder, - target_decoder=evaluation_decoder) - - state.add_results(epoch=epoch, - loss_result=state.avg_loss, - wer_result=wer, - cer_result=cer) - - print('Validation Summary Epoch: [{0}]\t' - 'Average WER {wer:.3f}\t' - 'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer)) - - if main_proc and cfg.visualization.visdom: - visdom_logger.update(epoch, state.result_state) - if main_proc and cfg.visualization.tensorboard: - tensorboard_logger.update(epoch, state.result_state, model.named_parameters()) - - if main_proc and cfg.checkpointing.checkpoint: # Save epoch checkpoint - checkpoint_handler.save_checkpoint_model(epoch=epoch, state=state) - # anneal lr - for g in optimizer.param_groups: - g['lr'] = g['lr'] / cfg.optim.learning_anneal - print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr'])) - - if main_proc and (state.best_wer is None or state.best_wer > wer): - checkpoint_handler.save_best_model(epoch=epoch, state=state) - state.set_best_wer(wer) - state.reset_avg_loss() - state.reset_training_step() # Reset training step for next epoch + checkpoint_callback = CheckpointHandler( + cfg=cfg.checkpoint + ) + if cfg.load_auto_checkpoint: + resume_from_checkpoint = checkpoint_callback.find_latest_checkpoint() + if resume_from_checkpoint: + cfg.trainer.resume_from_checkpoint = resume_from_checkpoint + + data_loader = DeepSpeechDataModule( + labels=labels, + data_cfg=cfg.data, + normalize=True, + is_distributed=cfg.trainer.gpus > 1 + ) + + model = DeepSpeech( + labels=labels, + model_cfg=cfg.model, + optim_cfg=cfg.optim, + precision=cfg.trainer.precision, + spect_cfg=cfg.data.spect + ) + + trainer = hydra.utils.instantiate( + config=cfg.trainer, + replace_sampler_ddp=False, + callbacks=[checkpoint_callback] if cfg.trainer.checkpoint_callback else None, + ) + trainer.fit(model, data_loader) diff --git a/deepspeech_pytorch/utils.py b/deepspeech_pytorch/utils.py index adbc9265..22916359 100644 --- a/deepspeech_pytorch/utils.py +++ b/deepspeech_pytorch/utils.py @@ -27,13 +27,10 @@ def check_loss(loss, loss_value): def load_model(device, - model_path, - use_half): - model = DeepSpeech.load_model(hydra.utils.to_absolute_path(model_path)) + model_path): + model = DeepSpeech.load_from_checkpoint(hydra.utils.to_absolute_path(model_path)) model.eval() model = model.to(device) - if use_half: - model = 
model.half() return model @@ -49,7 +46,8 @@ def load_decoder(labels, cfg: LMConfig): cutoff_top_n=cfg.cutoff_top_n, cutoff_prob=cfg.cutoff_prob, beam_width=cfg.beam_width, - num_processes=cfg.lm_workers) + num_processes=cfg.lm_workers, + blank_index=labels.index('_')) else: decoder = GreedyDecoder(labels=labels, blank_index=labels.index('_')) diff --git a/deepspeech_pytorch/validation.py b/deepspeech_pytorch/validation.py new file mode 100644 index 00000000..1c47922d --- /dev/null +++ b/deepspeech_pytorch/validation.py @@ -0,0 +1,170 @@ +from abc import ABC, abstractmethod + +import torch +from torch.cuda.amp import autocast +from tqdm import tqdm + +from deepspeech_pytorch.decoder import Decoder, GreedyDecoder + +from pytorch_lightning.metrics import Metric +import Levenshtein as Lev + + +class ErrorRate(Metric, ABC): + def __init__(self, + decoder: Decoder, + target_decoder: GreedyDecoder, + save_output: bool = False, + dist_sync_on_step: bool = False): + super().__init__(dist_sync_on_step=dist_sync_on_step) + self.decoder = decoder + self.target_decoder = target_decoder + self.save_output = save_output + + @abstractmethod + def calculate_metric(self, transcript, reference): + raise NotImplementedError + + def update(self, preds: torch.Tensor, + preds_sizes: torch.Tensor, + targets: torch.Tensor, + target_sizes: torch.Tensor): + # unflatten targets + split_targets = [] + offset = 0 + for size in target_sizes: + split_targets.append(targets[offset:offset + size]) + offset += size + decoded_output, _ = self.decoder.decode(preds, preds_sizes) + target_strings = self.target_decoder.convert_to_strings(split_targets) + for x in range(len(target_strings)): + transcript, reference = decoded_output[x][0], target_strings[x][0] + self.calculate_metric( + transcript=transcript, + reference=reference + ) + + +class CharErrorRate(ErrorRate): + def __init__(self, + decoder: Decoder, + target_decoder: GreedyDecoder, + save_output: bool = False, + dist_sync_on_step: bool = False): + super().__init__( + decoder=decoder, + target_decoder=target_decoder, + save_output=save_output, + dist_sync_on_step=dist_sync_on_step + ) + self.decoder = decoder + self.target_decoder = target_decoder + self.save_output = save_output + self.add_state("cer", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("n_chars", default=torch.tensor(0), dist_reduce_fx="sum") + + def calculate_metric(self, transcript, reference): + cer_inst = self.cer_calc(transcript, reference) + self.cer += cer_inst + self.n_chars += len(reference.replace(' ', '')) + + def compute(self): + cer = float(self.cer) / self.n_chars + return cer.item() * 100 + + def cer_calc(self, s1, s2): + """ + Computes the Character Error Rate, defined as the edit distance. 
+ + Arguments: + s1 (string): space-separated sentence + s2 (string): space-separated sentence + """ + s1, s2, = s1.replace(' ', ''), s2.replace(' ', '') + return Lev.distance(s1, s2) + + +class WordErrorRate(ErrorRate): + def __init__(self, + decoder: Decoder, + target_decoder: GreedyDecoder, + save_output: bool = False, + dist_sync_on_step: bool = False): + super().__init__( + decoder=decoder, + target_decoder=target_decoder, + save_output=save_output, + dist_sync_on_step=dist_sync_on_step + ) + self.decoder = decoder + self.target_decoder = target_decoder + self.save_output = save_output + self.add_state("wer", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("n_tokens", default=torch.tensor(0), dist_reduce_fx="sum") + + def calculate_metric(self, transcript, reference): + wer_inst = self.wer_calc(transcript, reference) + self.wer += wer_inst + self.n_tokens += len(reference.split()) + + def compute(self): + wer = float(self.wer) / self.n_tokens + return wer.item() * 100 + + def wer_calc(self, s1, s2): + """ + Computes the Word Error Rate, defined as the edit distance between the + two provided sentences after tokenizing to words. + Arguments: + s1 (string): space-separated sentence + s2 (string): space-separated sentence + """ + + # build mapping of words to integers + b = set(s1.split() + s2.split()) + word2char = dict(zip(b, range(len(b)))) + + # map the words to a char array (Levenshtein packages only accepts + # strings) + w1 = [chr(word2char[w]) for w in s1.split()] + w2 = [chr(word2char[w]) for w in s2.split()] + + return Lev.distance(''.join(w1), ''.join(w2)) + + +@torch.no_grad() +def run_evaluation(test_loader, + model, + decoder: Decoder, + device: torch.device, + target_decoder: Decoder, + precision: int): + model.eval() + wer = WordErrorRate( + decoder=decoder, + target_decoder=target_decoder + ) + cer = CharErrorRate( + decoder=decoder, + target_decoder=target_decoder + ) + for i, (batch) in tqdm(enumerate(test_loader), total=len(test_loader)): + inputs, targets, input_percentages, target_sizes = batch + input_sizes = input_percentages.mul_(int(inputs.size(3))).int() + inputs = inputs.to(device) + with autocast(enabled=precision == 16): + out, output_sizes = model(inputs, input_sizes) + decoded_output, _ = decoder.decode(out, output_sizes) + wer.update( + preds=out, + preds_sizes=output_sizes, + targets=targets, + target_sizes=target_sizes + ) + cer.update( + preds=out, + preds_sizes=output_sizes, + targets=targets, + target_sizes=target_sizes + ) + return wer.compute(), cer.compute() diff --git a/kubernetes/train.yaml b/kubernetes/train.yaml index f0e60d18..7277e372 100644 --- a/kubernetes/train.yaml +++ b/kubernetes/train.yaml @@ -24,12 +24,14 @@ spec: args: - "--nproc_per_node=1" - "/workspace/deepspeech.pytorch/train.py" - - "data.train_manifest=/audio-data/an4_manifests/an4_train_manifest.csv" - - "data.val_manifest=/audio-data/an4_manifests/an4_val_manifest.csv" + - "data.train_path=/audio-data/an4_manifests/an4_train_manifest.csv" + - "data.val_path=/audio-data/an4_manifests/an4_val_manifest.csv" - "data.labels_path=/workspace/deepspeech.pytorch/labels.json" - "data.num_workers=8" - "training.epochs=70" - "data.batch_size=8" + - "training.multigpu=distributed" + - "model.precision=half" - "checkpointing=gcs" - "checkpointing.gcs_bucket=deepspeech-1234" # Swap this to point to the appropriate GCS bucket - "checkpointing.gcs_save_folder=models/" diff --git a/requirements.txt b/requirements.txt index e6b3fb22..9a4509e2 100644 --- a/requirements.txt +++ 
b/requirements.txt @@ -1,21 +1,20 @@ scipy numpy -soundfile python-levenshtein torch +torchaudio torchelastic -visdom +pytorch-lightning>=1.1 wget librosa -numba==0.43.0 -llvmlite==0.32.1 tqdm matplotlib flask sox sklearn -soundfile pytest hydra-core google-cloud-storage jupyter +git+https://github.com/romesco/hydra-lightning/#subdirectory=hydra-configs-pytorch-lightning +fairscale \ No newline at end of file diff --git a/search_lm_params.py b/search_lm_params.py index 1354532a..5d16b521 100644 --- a/search_lm_params.py +++ b/search_lm_params.py @@ -33,9 +33,10 @@ print("error: LM must be provided for tuning") sys.exit(1) -model = load_model(model_path=args.model_path, - device='cpu', - use_half=False) +model = load_model( + model_path=args.model_path, + device='cpu' +) saved_output = torch.load(args.saved_output) diff --git a/server.py b/server.py index 2f425c59..d0e1a94c 100644 --- a/server.py +++ b/server.py @@ -37,12 +37,14 @@ def transcribe_file(): with NamedTemporaryFile(suffix=file_extension) as tmp_saved_audio_file: file.save(tmp_saved_audio_file.name) logging.info('Transcribing file...') - transcription, _ = run_transcribe(audio_path=tmp_saved_audio_file, - spect_parser=spect_parser, - model=model, - decoder=decoder, - device=device, - use_half=args.half) + transcription, _ = run_transcribe( + audio_path=tmp_saved_audio_file, + spect_parser=spect_parser, + model=model, + decoder=decoder, + device=device, + precision=config.model.precision + ) logging.info('File transcribed') res['status'] = "OK" res['transcription'] = transcription @@ -58,19 +60,32 @@ def main(cfg: ServerConfig): logging.info('Setting up server...') device = torch.device("cuda" if cfg.model.cuda else "cpu") - model = load_model(device=device, - model_path=cfg.model.model_path, - use_half=cfg.model.use_half) + model = load_model( + device=device, + model_path=cfg.model.model_path + ) - decoder = load_decoder(labels=model.labels, - cfg=cfg.lm) + decoder = load_decoder( + labels=model.labels, + cfg=cfg.lm + ) - spect_parser = SpectrogramParser(audio_conf=model.audio_conf, - normalize=True) + spect_parser = SpectrogramParser( + audio_conf=model.audio_conf, + normalize=True + ) - spect_parser = SpectrogramParser(model.audio_conf, normalize=True) + spect_parser = SpectrogramParser( + audio_conf=model.spect_cfg, + normalize=True + ) logging.info('Server initialised') - app.run(host=cfg.host, port=cfg.port, debug=True, use_reloader=False) + app.run( + host=cfg.host, + port=cfg.port, + debug=True, + use_reloader=False + ) if __name__ == "__main__": diff --git a/tests/pretrained_smoke_test.py b/tests/pretrained_smoke_test.py index 964d3594..fa495064 100644 --- a/tests/pretrained_smoke_test.py +++ b/tests/pretrained_smoke_test.py @@ -8,9 +8,9 @@ from tests.smoke_test import DatasetConfig, DeepSpeechSmokeTest pretrained_urls = [ - 'https://github.com/SeanNaren/deepspeech.pytorch/releases/latest/download/an4_pretrained_v2.pth', - 'https://github.com/SeanNaren/deepspeech.pytorch/releases/latest/download/librispeech_pretrained_v2.pth', - 'https://github.com/SeanNaren/deepspeech.pytorch/releases/latest/download/ted_pretrained_v2.pth' + 'https://github.com/SeanNaren/deepspeech.pytorch/releases/latest/download/an4_pretrained_v3.ckpt', + 'https://github.com/SeanNaren/deepspeech.pytorch/releases/latest/download/librispeech_pretrained_v3.ckpt', + 'https://github.com/SeanNaren/deepspeech.pytorch/releases/latest/download/ted_pretrained_v3.ckpt' ] lm_path = 'http://www.openslr.org/resources/11/3-gram.pruned.3e-7.arpa.gz' @@ -20,9 +20,14 
@@ class PretrainedSmokeTest(DeepSpeechSmokeTest): def test_pretrained_eval_inference(self): # Disabled GPU due to using TravisCI - cuda, use_half = False, False - train_manifest, val_manifest, test_manifest = self.download_data(DatasetConfig(target_dir=self.target_dir, - manifest_dir=self.manifest_dir)) + cuda, precision = False, 32 + train_manifest, val_manifest, test_manifest = self.download_data( + DatasetConfig( + target_dir=self.target_dir, + manifest_dir=self.manifest_dir + ), + folders=False + ) wget.download(lm_path) for pretrained_url in pretrained_urls: print("Running Pre-trained Smoke test for: ", pretrained_url) @@ -44,16 +49,20 @@ def test_pretrained_eval_inference(self): ] for lm_config in lm_configs: - self.eval_model(model_path=pretrained_path, - test_manifest=test_manifest, - cuda=cuda, - use_half=use_half, - lm_config=lm_config) - self.inference(test_manifest=test_manifest, - model_path=pretrained_path, - cuda=cuda, - lm_config=lm_config, - use_half=use_half) + self.eval_model( + model_path=pretrained_path, + test_path=test_manifest, + cuda=cuda, + lm_config=lm_config, + precision=precision + ) + self.inference( + test_path=test_manifest, + model_path=pretrained_path, + cuda=cuda, + lm_config=lm_config, + precision=precision + ) if __name__ == '__main__': diff --git a/tests/smoke_test.py b/tests/smoke_test.py index be8032c8..6897fca4 100644 --- a/tests/smoke_test.py +++ b/tests/smoke_test.py @@ -1,18 +1,21 @@ +import json import os import shutil import tempfile import unittest from dataclasses import dataclass +from pathlib import Path from data.an4 import download_an4 from deepspeech_pytorch.configs.inference_config import EvalConfig, ModelConfig, TranscribeConfig, LMConfig from deepspeech_pytorch.configs.train_config import DeepSpeechConfig, AdamConfig, BiDirectionalConfig, \ - FileCheckpointConfig, \ - DataConfig, TrainingConfig + DataConfig, DeepSpeechTrainerConf from deepspeech_pytorch.enums import DecoderType from deepspeech_pytorch.inference import transcribe from deepspeech_pytorch.testing import evaluate from deepspeech_pytorch.training import train +from hydra_configs.pytorch_lightning.callbacks import ModelCheckpointConf +from hydra_configs.pytorch_lightning.trainer import TrainerConf @dataclass @@ -23,6 +26,7 @@ class DatasetConfig: max_duration: float = 15 val_fraction: float = 0.1 sample_rate: int = 16000 + num_workers: int = 4 class DeepSpeechSmokeTest(unittest.TestCase): @@ -37,25 +41,41 @@ def tearDown(self): shutil.rmtree(self.model_dir) def build_train_evaluate_model(self, + limit_train_batches: int, + limit_val_batches: int, epoch: int, batch_size: int, model_config: BiDirectionalConfig, - use_half: bool, - cuda: bool): - train_manifest, val_manifest, test_manifest = self.download_data(DatasetConfig(target_dir=self.target_dir, - manifest_dir=self.manifest_dir)) - - train_cfg = self.create_training_config(epoch=epoch, - batch_size=batch_size, - train_manifest=train_manifest, - val_manifest=val_manifest, - model_config=model_config, - cuda=cuda) + precision: int, + gpus: int, + folders: bool): + cuda = gpus > 0 + + train_path, val_path, test_path = self.download_data( + DatasetConfig( + target_dir=self.target_dir, + manifest_dir=self.manifest_dir + ), + folders=folders + ) + + train_cfg = self.create_training_config( + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + max_epochs=epoch, + batch_size=batch_size, + train_path=train_path, + val_path=val_path, + model_config=model_config, + precision=precision, + gpus=gpus + ) 
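The pretrained smoke test above now exercises the v3 Lightning checkpoints (`*.ckpt`) rather than the old `*.pth` packages. As a minimal sketch — assuming one of the release checkpoints listed above has already been downloaded locally — the hyperparameters saved via `save_hyperparameters()` let the model be restored directly:

```
# Minimal sketch: restore a released v3 Lightning checkpoint for inference.
# Assumes librispeech_pretrained_v3.ckpt was downloaded from the release URL
# listed in tests/pretrained_smoke_test.py above.
import torch

from deepspeech_pytorch.model import DeepSpeech

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DeepSpeech.load_from_checkpoint('librispeech_pretrained_v3.ckpt')
model.eval()
model = model.to(device)

print(model.labels)     # label set stored with the checkpoint
print(model.spect_cfg)  # spectrogram config used at training time
```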
print("Running Training DeepSpeech Model Smoke Test") train(train_cfg) # Expected final model path after training - model_path = self.model_dir + '/deepspeech_final.pth' + print(os.listdir(self.model_dir)) + model_path = self.model_dir + '/last.ckpt' assert os.path.exists(model_path) lm_configs = [ @@ -68,91 +88,124 @@ def build_train_evaluate_model(self, for lm_config in lm_configs: self.eval_model( model_path=model_path, - test_manifest=test_manifest, + test_path=test_path, cuda=cuda, - use_half=use_half, + precision=precision, lm_config=lm_config ) - self.inference(test_manifest=test_manifest, + self.inference(test_path=test_path, model_path=model_path, cuda=cuda, - use_half=use_half, + precision=precision, lm_config=lm_config) def eval_model(self, model_path: str, - test_manifest: str, + test_path: str, cuda: bool, - use_half: bool, + precision: int, lm_config: LMConfig): # Due to using TravisCI with no GPU support we have to disable cuda eval_cfg = EvalConfig( model=ModelConfig( cuda=cuda, model_path=model_path, - use_half=use_half + precision=precision ), lm=lm_config, - test_manifest=test_manifest + test_path=test_path ) evaluate(eval_cfg) def inference(self, - test_manifest: str, + test_path: str, model_path: str, cuda: bool, - use_half: bool, + precision: int, lm_config: LMConfig): # Select one file from our test manifest to run inference - with open(test_manifest) as f: - file_path = next(f).strip().split(',')[0] + if os.path.isdir(test_path): + file_path = next(Path(test_path).rglob('*.wav')) + else: + with open(test_path) as f: + # select a file to use for inference test + manifest = json.load(f) + file_name = manifest['samples'][0]['wav_path'] + directory = manifest['root_path'] + file_path = os.path.join(directory, file_name) transcribe_cfg = TranscribeConfig( model=ModelConfig( cuda=cuda, model_path=model_path, - use_half=use_half + precision=precision ), lm=lm_config, audio_path=file_path ) transcribe(transcribe_cfg) - def download_data(self, cfg: DatasetConfig): - download_an4(target_dir=cfg.target_dir, - manifest_dir=cfg.manifest_dir, - min_duration=cfg.min_duration, - max_duration=cfg.max_duration, - val_fraction=cfg.val_fraction, - sample_rate=cfg.sample_rate) - # Expected manifests paths - train_manifest = os.path.join(self.manifest_dir, 'an4_train_manifest.csv') - val_manifest = os.path.join(self.manifest_dir, 'an4_val_manifest.csv') - test_manifest = os.path.join(self.manifest_dir, 'an4_test_manifest.csv') + def download_data(self, + cfg: DatasetConfig, + folders: bool): + download_an4( + target_dir=cfg.target_dir, + manifest_dir=cfg.manifest_dir, + min_duration=cfg.min_duration, + max_duration=cfg.max_duration, + val_fraction=cfg.val_fraction, + sample_rate=cfg.sample_rate, + num_workers=cfg.num_workers + ) + + # Expected output paths + if folders: + train_path = os.path.join(self.target_dir, 'train/') + val_path = os.path.join(self.target_dir, 'val/') + test_path = os.path.join(self.target_dir, 'test/') + else: + train_path = os.path.join(self.manifest_dir, 'an4_train_manifest.json') + val_path = os.path.join(self.manifest_dir, 'an4_val_manifest.json') + test_path = os.path.join(self.manifest_dir, 'an4_test_manifest.json') # Assert manifest paths exists - assert os.path.exists(train_manifest) - assert os.path.exists(val_manifest) - assert os.path.exists(test_manifest) - return train_manifest, val_manifest, test_manifest + assert os.path.exists(train_path) + assert os.path.exists(val_path) + assert os.path.exists(test_path) + return train_path, val_path, test_path 
def create_training_config(self, - epoch: int, + limit_train_batches: int, + limit_val_batches: int, + max_epochs: int, batch_size: int, - train_manifest: str, - val_manifest: str, + train_path: str, + val_path: str, model_config: BiDirectionalConfig, - cuda: bool): + precision: int, + gpus: int): return DeepSpeechConfig( - training=TrainingConfig(epochs=epoch, - no_cuda=not cuda), - data=DataConfig(train_manifest=train_manifest, - val_manifest=val_manifest, - batch_size=batch_size), + trainer=DeepSpeechTrainerConf( + max_epochs=max_epochs, + precision=precision, + gpus=gpus, + checkpoint_callback=True, + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches + ), + data=DataConfig( + train_path=train_path, + val_path=val_path, + batch_size=batch_size + ), optim=AdamConfig(), model=model_config, - checkpointing=FileCheckpointConfig(save_folder=self.model_dir) + checkpoint=ModelCheckpointConf( + dirpath=self.model_dir, + save_last=True, + verbose=True + ) ) @@ -160,13 +213,38 @@ class AN4SmokeTest(DeepSpeechSmokeTest): def test_train_eval_inference(self): # Hardcoded sizes to reduce memory/time, and disabled GPU due to using TravisCI - model_cfg = BiDirectionalConfig(hidden_size=10, - hidden_layers=1) - self.build_train_evaluate_model(epoch=1, - batch_size=10, - model_config=model_cfg, - cuda=False, - use_half=False) + model_cfg = BiDirectionalConfig( + hidden_size=10, + hidden_layers=1 + ) + self.build_train_evaluate_model( + limit_train_batches=1, + limit_val_batches=1, + epoch=1, + batch_size=10, + model_config=model_cfg, + precision=32, + gpus=0, + folders=False + ) + + def test_train_eval_inference_folder(self): + """Test train/eval/inference using folder directories rather than manifest files""" + # Hardcoded sizes to reduce memory/time, and disabled GPU due to using TravisCI + model_cfg = BiDirectionalConfig( + hidden_size=10, + hidden_layers=1 + ) + self.build_train_evaluate_model( + limit_train_batches=1, + limit_val_batches=1, + epoch=1, + batch_size=10, + model_config=model_cfg, + precision=32, + gpus=0, + folders=True + ) if __name__ == '__main__': diff --git a/train.py b/train.py index 8433147d..0cabd379 100644 --- a/train.py +++ b/train.py @@ -1,16 +1,17 @@ import hydra from hydra.core.config_store import ConfigStore +from hydra_configs.pytorch_lightning.callbacks import ModelCheckpointConf from deepspeech_pytorch.configs.train_config import DeepSpeechConfig, AdamConfig, SGDConfig, BiDirectionalConfig, \ - UniDirectionalConfig, GCSCheckpointConfig, FileCheckpointConfig + UniDirectionalConfig, GCSCheckpointConfig from deepspeech_pytorch.training import train cs = ConfigStore.instance() cs.store(name="config", node=DeepSpeechConfig) cs.store(group="optim", name="sgd", node=SGDConfig) cs.store(group="optim", name="adam", node=AdamConfig) -cs.store(group="checkpointing", name="file", node=FileCheckpointConfig) -cs.store(group="checkpointing", name="gcs", node=GCSCheckpointConfig) +cs.store(group="checkpoint", name="file", node=ModelCheckpointConf) +cs.store(group="checkpoint", name="gcs", node=GCSCheckpointConfig) cs.store(group="model", name="bidirectional", node=BiDirectionalConfig) cs.store(group="model", name="unidirectional", node=UniDirectionalConfig) diff --git a/translations/README_JP.md b/translations/README_JP.md index 0b895daf..3cd364b7 100644 --- a/translations/README_JP.md +++ b/translations/README_JP.md @@ -223,7 +223,7 @@ python model.py --model-path models/deepspeech.pth 
トレーニング済みのモデルをテスト用データセットで評価したい場合は以下のスクリプトを実行してください。もちろん、フォーマットはトレーニング用データセットと同様である必要があります。 ``` -python test.py --model-path models/deepspeech.pth --test-manifest /path/to/test_manifest.csv --cuda +python test.py --model-path models/deepspeech.pth --test-path /path/to/test_manifest.csv --cuda ``` 一つの音声の文字起こしをテストできるスクリプトも用意されています。
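Putting the refactor together, a minimal end-to-end sketch of transcribing a single file with a v3 checkpoint, mirroring the flow `server.py` follows after this patch. It assumes `run_transcribe` is importable from `deepspeech_pytorch.inference` and `SpectrogramParser` from `deepspeech_pytorch.loader.data_loader`, that the default `LMConfig` falls back to greedy decoding (so no language model is needed), and that the checkpoint and audio paths are hypothetical placeholders:

```
# Minimal sketch mirroring the server.py flow in this patch.
# Assumptions: run_transcribe lives in deepspeech_pytorch.inference;
# 'an4_pretrained_v3.ckpt' and 'audio.wav' are hypothetical local paths.
import torch

from deepspeech_pytorch.configs.inference_config import LMConfig
from deepspeech_pytorch.inference import run_transcribe
from deepspeech_pytorch.loader.data_loader import SpectrogramParser
from deepspeech_pytorch.utils import load_decoder, load_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = load_model(device=device, model_path='an4_pretrained_v3.ckpt')
decoder = load_decoder(labels=model.labels, cfg=LMConfig())
spect_parser = SpectrogramParser(audio_conf=model.spect_cfg, normalize=True)

transcription, _ = run_transcribe(
    audio_path='audio.wav',
    spect_parser=spect_parser,
    model=model,
    decoder=decoder,
    device=device,
    precision=32
)
print(transcription)
```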