Zipformer recipe for SPGISpeech (#1449)
marcoyang1998 authored Feb 22, 2024
1 parent 819bb45 commit 2483b8b
Showing 16 changed files with 2,912 additions and 18 deletions.
97 changes: 81 additions & 16 deletions egs/spgispeech/ASR/RESULTS.md
@@ -1,5 +1,70 @@
## Results

### SPGISpeech BPE training results (Zipformer Transducer)

#### 2024-01-05

#### Zipformer encoder + embedding decoder

Transducer: Zipformer encoder + stateless decoder.

The WERs are:

| | dev | val | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.08 | 2.14 | --epoch 30 --avg 10 |
| modified beam search | 2.05 | 2.09 | --epoch 30 --avg 10 --beam-size 4 |
| fast beam search | 2.07 | 2.17 | --epoch 30 --avg 10 --beam 20 --max-contexts 8 --max-states 64 |

**NOTE:** SPGISpeech transcripts come in two styles, `ortho` (orthographic) and `norm`
(normalized). The WERs above were measured against the normalized transcripts.
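To make the `ortho`/`norm` distinction concrete, here is a toy sketch of the normalization direction. This is an illustration only, not the dataset's official normalizer: the real SPGISpeech normalization also handles numerals, currency amounts, and abbreviations.

```python
import re

def toy_normalize(text: str) -> str:
    """Toy normalization: lowercase and strip punctuation.

    Illustrative only; the official SPGISpeech normalization
    does much more (numbers, currency, abbreviations, ...).
    """
    text = text.lower()
    text = re.sub(r"[^\w\s']", " ", text)      # drop punctuation, keep apostrophes
    return re.sub(r"\s+", " ", text).strip()   # collapse whitespace

print(toy_normalize("Hello, World!"))  # -> hello world
```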

The training command for reproducing the results above is:

```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python zipformer/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --num-workers 2 \
  --max-duration 1000
```

The decoding command is:
```
# greedy search
python ./zipformer/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./zipformer/exp \
  --max-duration 1000 \
  --decoding-method greedy_search

# modified beam search
python ./zipformer/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./zipformer/exp \
  --max-duration 1000 \
  --decoding-method modified_beam_search

# fast beam search
python ./zipformer/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./zipformer/exp \
  --max-duration 1000 \
  --decoding-method fast_beam_search \
  --beam 4 \
  --max-contexts 4 \
  --max-states 8
```
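The `--epoch 30 --avg 10` options decode with the parameters of epoch 30 averaged over the last 10 checkpoints. Conceptually the averaging is a plain elementwise mean of parameters; the following toy sketch uses floats in place of tensors (icefall's actual averaging code operates on PyTorch state dicts and its exact scheme may differ):

```python
def average_checkpoints(state_dicts):
    """Elementwise mean of parameters across checkpoints.

    Toy sketch: plain floats stand in for parameter tensors.
    """
    n = len(state_dicts)
    return {k: sum(sd[k] for sd in state_dicts) / n for k in state_dicts[0]}

avg = average_checkpoints([{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 2.0}])
print(avg)  # {'w': 2.0, 'b': 1.0}
```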

### SPGISpeech BPE training results (Pruned Transducer)

#### 2022-05-11
@@ -43,28 +108,28 @@ The decoding command is:
```
# greedy search
./pruned_transducer_stateless2/decode.py \
  --iter 696000 --avg 10 \
  --exp-dir ./pruned_transducer_stateless2/exp \
  --max-duration 100 \
  --decoding-method greedy_search

# modified beam search
./pruned_transducer_stateless2/decode.py \
  --iter 696000 --avg 10 \
  --exp-dir ./pruned_transducer_stateless2/exp \
  --max-duration 100 \
  --decoding-method modified_beam_search \
  --beam-size 4

# fast beam search
./pruned_transducer_stateless2/decode.py \
  --iter 696000 --avg 10 \
  --exp-dir ./pruned_transducer_stateless2/exp \
  --max-duration 1500 \
  --decoding-method fast_beam_search \
  --beam 4 \
  --max-contexts 4 \
  --max-states 8
```

Pretrained model is available at <https://huggingface.co/desh2608/icefall-asr-spgispeech-pruned-transducer-stateless2>
23 changes: 21 additions & 2 deletions egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -102,6 +102,20 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=False,
help="When enabled, the last batch will be dropped",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--gap",
type=float,
@@ -143,7 +157,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
group.add_argument(
"--num-workers",
type=int,
default=8,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
@@ -176,7 +190,7 @@ def train_dataloaders(
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

transforms = []
if self.args.enable_musan:
@@ -223,11 +237,13 @@ def train_dataloaders(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
else:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)

logging.info("Using DynamicBucketingSampler.")
@@ -276,10 +292,12 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
@@ -303,6 +321,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False
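With `--return-cuts` enabled, every batch exposes the originating cuts under `batch["supervisions"]["cut"]`, letting downstream code recover per-utterance metadata. A toy sketch of the access pattern, using plain strings in place of real lhotse `Cut` objects:

```python
# Toy stand-in for a batch from K2SpeechRecognitionDataset with
# return_cuts=True; in a real batch the "cut" field holds lhotse Cuts.
batch = {
    "supervisions": {
        "text": ["hello world", "good morning"],
        "cut": ["cut-id-0001", "cut-id-0002"],  # lhotse Cut objects in reality
    },
}

# Pair each hypothesis-reference with its source cut, e.g. for error analysis.
pairs = list(zip(batch["supervisions"]["cut"], batch["supervisions"]["text"]))
print(pairs[0])  # ('cut-id-0001', 'hello world')
```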
1 change: 1 addition & 0 deletions egs/spgispeech/ASR/zipformer/asr_datamodule.py
1 change: 1 addition & 0 deletions egs/spgispeech/ASR/zipformer/beam_search.py
