Pytorch Lightning Integration (#569)
* Added minimal code to integrate PyTorch Lightning into training.py
* Added autocast support, removed intra-epoch checkpointing for simplicity, integrated checkpoint support, fixed validation support
* Fixed multi-GPU support
* Fixed smoke test; pretrained tests will be broken till new model release. Added Trains viz logging, precision
* Updated README, fixed server class, updated k8s config file, added fix for Adam. Trains support; removed autocast since this is handled via Lightning
* Swapped to using tqdm.write for readability when checkpointing, added AN4 config
* Added base script for each dataset, updated default params
* Swapped to using native CTC, updated Common Voice script, removed incorrect Lightning version
* Updated CV params and output manifest location, set default epochs to the epochs used for the previous release
* Disable Trains logger for now, simplified checkpointing logic for new release
* Added new metrics class, removed save_output/verbose for now, using new ModelCheckpoint class for model saving
* Multiprocess duration collection for speed, allow loading from file path, refactor path name and test
* Swap to latest release candidate, fixed flag reference
* Format smoke test, update path to best save k model
* Update to latest RC
* Removed Trains logging, rely on PL TensorBoard. Swap to saving JSON object for manifest to modify root path
* Ensure abs path for manifest root path
* Use absolute paths for manifest
* Update requirements, abstract all PL trainer arguments
* Enable checkpoint callback
* Enable checkpoint callback, add verbosity
* Add sharded as a dependency for better memory use
* Set num workers, add spec augment
* Update deepspeech_pytorch/data/utils.py (Co-authored-by: Anas Abou Allaban <[email protected]>)
* Specify blank index explicitly
* Add blank index to CTC loss
* Fix CI
* Fix SyntaxWarning
* Fix install requirements
* Use torchaudio (#607): use torchaudio; add torchaudio to reqs; fixes for testing, update AN4 config, update Dockerfile base image; add noninteractive to remove stalling; revert; update API

Co-authored-by: Sean Narenthiran <[email protected]>
Co-authored-by: Anas Abou Allaban <[email protected]>
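The commit message mentions swapping to a JSON object for the manifest (with an absolute root path) and multiprocess duration collection for speed. A minimal sketch of what that could look like — the field names, the `duration_of` placeholder, and the thread pool are illustrative assumptions, not the repository's actual format or code:

```python
# Sketch: a JSON-object manifest with an absolute root path, plus parallel
# duration collection over its samples. Field names ("root_path", "samples",
# "wav_path") are assumptions inferred from the commit message.
import json
import os
from multiprocessing.pool import ThreadPool

manifest = {
    "root_path": os.path.abspath("data/an4_dataset"),  # abs path, per the commit
    "samples": [
        {"wav_path": "train/an4-0001.wav", "transcript_path": "train/an4-0001.txt"},
        {"wav_path": "train/an4-0002.wav", "transcript_path": "train/an4-0002.txt"},
    ],
}

def duration_of(sample):
    """Placeholder duration probe; the real code would inspect the audio file."""
    return 1.0  # hypothetical fixed value for illustration only

# "Multiprocess duration collection for speed": map over samples in parallel.
# A thread pool is used here so the snippet runs anywhere; the project could
# equally use a process pool.
with ThreadPool(processes=4) as pool:
    durations = pool.map(duration_of, manifest["samples"])

total_duration = sum(durations)
manifest_json = json.dumps(manifest)  # serialized as a JSON object, not CSV
```

Storing the root path once in the JSON object is what makes it easy to relocate a dataset: only `root_path` needs rewriting, not every sample entry.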
1 parent 4cb209a · commit d9790d9 · 37 changed files with 1,043 additions and 1,239 deletions.
AN4 config (new file, +18 lines):

```yaml
# @package _global_
data:
  train_path: data/an4_train_manifest.json
  val_path: data/an4_val_manifest.json
  batch_size: 8
  num_workers: 8
trainer:
  max_epochs: 70
  gpus: 1
  precision: 16
  gradient_clip_val: 400  # Norm cutoff to prevent explosion of gradients
  accelerator: ddp
  plugins: ddp_sharded
  checkpoint_callback: True
checkpoint:
  save_top_k: 1
  monitor: "wer"
  verbose: True
```
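The commit message says all PL trainer arguments were abstracted into the config, and the YAML above keeps `trainer` and `checkpoint` as separate sections. A minimal sketch of how such a config dict could be split into `Trainer` kwargs and `ModelCheckpoint` kwargs — the helper name is hypothetical, not the repository's actual code:

```python
# Sketch: split a loaded Hydra-style config dict into the keyword arguments
# for pytorch_lightning.Trainer and for the ModelCheckpoint callback.
# split_trainer_config is an illustrative name, not the project's real helper.

def split_trainer_config(cfg):
    """Return (trainer_kwargs, checkpoint_kwargs) from a loaded config dict."""
    trainer_kwargs = dict(cfg.get("trainer", {}))
    checkpoint_kwargs = dict(cfg.get("checkpoint", {}))
    return trainer_kwargs, checkpoint_kwargs

# Values mirror the AN4 config above.
cfg = {
    "trainer": {"max_epochs": 70, "gpus": 1, "precision": 16,
                "gradient_clip_val": 400, "checkpoint_callback": True},
    "checkpoint": {"save_top_k": 1, "monitor": "wer", "verbose": True},
}
trainer_kwargs, ckpt_kwargs = split_trainer_config(cfg)
# trainer_kwargs would then feed pl.Trainer(**trainer_kwargs), and
# ckpt_kwargs would feed ModelCheckpoint(**ckpt_kwargs).
```

Keeping the two sections separate means every `Trainer` flag stays overridable from the command line without the training script enumerating them.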
Common Voice config (new file, +19 lines):

```yaml
# @package _global_
data:
  train_path: data/commonvoice_train_manifest.json
  val_path: data/commonvoice_dev_manifest.json
  num_workers: 8
augmentation:
  spec_augment: True
trainer:
  max_epochs: 16
  gpus: 1
  precision: 16
  gradient_clip_val: 400  # Norm cutoff to prevent explosion of gradients
  accelerator: ddp
  plugins: ddp_sharded
  checkpoint_callback: True
checkpoint:
  save_top_k: 1
  monitor: "wer"
  verbose: True
```
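The `spec_augment: True` flag added in this commit enables SpecAugment-style masking. The project applies this to tensors (the commit also switches to torchaudio); the pure-Python toy below only illustrates the core idea of zeroing a band of frequency bins and a span of time steps in a spectrogram — it is not the repository's implementation:

```python
# Toy SpecAugment: zero one contiguous frequency band and one contiguous
# run of time steps in a [time][freq] 2-D list. Parameter names and the
# fixed seed are illustrative only.
import random

def spec_augment(spec, max_freq_mask=2, max_time_mask=2, seed=0):
    rng = random.Random(seed)
    n_time, n_freq = len(spec), len(spec[0])
    out = [row[:] for row in spec]  # copy so the input is untouched
    # Frequency mask: zero a band of f consecutive frequency bins.
    f = rng.randint(1, max_freq_mask)
    f0 = rng.randint(0, n_freq - f)
    for row in out:
        for j in range(f0, f0 + f):
            row[j] = 0.0
    # Time mask: zero t consecutive time steps entirely.
    t = rng.randint(1, max_time_mask)
    t0 = rng.randint(0, n_time - t)
    for i in range(t0, t0 + t):
        out[i] = [0.0] * n_freq
    return out

spec = [[1.0] * 4 for _ in range(5)]  # 5 time steps x 4 frequency bins
masked = spec_augment(spec)
```

Because masking hides whole bands rather than random points, the model cannot rely on any single frequency region or moment in time, which is what makes SpecAugment an effective regularizer for speech.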
LibriSpeech config (new file, +19 lines):

```yaml
# @package _global_
data:
  train_path: data/libri_train_manifest.json
  val_path: data/libri_val_manifest.json
  num_workers: 8
augmentation:
  spec_augment: True
trainer:
  max_epochs: 16
  gpus: 1
  precision: 16
  gradient_clip_val: 400  # Norm cutoff to prevent explosion of gradients
  accelerator: ddp
  plugins: ddp_sharded
  checkpoint_callback: True
checkpoint:
  save_top_k: 1
  monitor: "wer"
  verbose: True
```
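Every config sets `gradient_clip_val: 400`, the "norm cutoff to prevent explosion of gradients"; Lightning handles the clipping itself. In miniature, clipping by global norm rescales all gradients so their combined L2 norm does not exceed the cutoff — sketched here on a flat list rather than real parameter tensors:

```python
# Sketch of gradient clipping by global L2 norm, the behavior behind
# Lightning's gradient_clip_val. Works on a flat list of floats here;
# the real thing operates on all parameter gradients of the model.
import math

def clip_grad_norm(grads, max_norm):
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm  # shrink uniformly, preserving direction
        grads = [g * scale for g in grads]
    return grads

# Norm of [300, 400] is 500, above the cutoff of 400, so all components
# are scaled by 400/500 = 0.8.
clipped = clip_grad_norm([300.0, 400.0], max_norm=400)
```

Scaling uniformly (rather than clamping each component) keeps the update direction intact while bounding its magnitude, which is why norm clipping is the standard choice for recurrent speech models like this one.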
TED-LIUM config (new file, +19 lines):

```yaml
# @package _global_
data:
  train_path: data/ted_train_manifest.json
  val_path: data/ted_val_manifest.json
  num_workers: 8
augmentation:
  spec_augment: True
trainer:
  max_epochs: 16
  gpus: 1
  precision: 16
  gradient_clip_val: 400  # Norm cutoff to prevent explosion of gradients
  accelerator: ddp
  plugins: ddp_sharded
  checkpoint_callback: True
checkpoint:
  save_top_k: 1
  monitor: "wer"
  verbose: True
```
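All four configs checkpoint on `monitor: "wer"`, and the commit adds a new metrics class to compute it. The project's own class may normalize text differently, but the textbook word error rate is just word-level Levenshtein distance divided by reference length:

```python
# Textbook WER: minimum word-level edits (substitutions + insertions +
# deletions) to turn the hypothesis into the reference, divided by the
# number of reference words. Illustrative only; not the project's class.
def wer(reference, hypothesis):
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,          # deletion
                           dp[i][j - 1] + 1,          # insertion
                           dp[i - 1][j - 1] + cost)   # substitution / match
    return dp[len(ref)][len(hyp)] / len(ref)

print(wer("the cat sat", "the cat sat"))  # 0.0
print(wer("the cat sat", "the bat sat"))  # one substitution over three words
```

Because lower WER is better, the `ModelCheckpoint` monitoring it keeps the single best checkpoint (`save_top_k: 1`) by minimum validation WER.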