Skip to content

Commit

Permalink
add distill whisper results (k2-fsa#1648)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuekaizhang authored and Your Name committed Aug 9, 2024
1 parent 2a56d44 commit 4446a04
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
9 changes: 5 additions & 4 deletions egs/multi_zh-hans/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |
|Model| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|-|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| | Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
|whisper-large-v2-ft |Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |
|whisper-large-v2-ft-distill |Greedy Search | 24.91 | 26.73 | 0.91 | 0.94 | 2.71 | 2.98 | 17.65 | 2.81 | 2.47 | 5.16 | 2.10 | 6.27 | 8.34 |

Command for training is:
```bash
Expand Down
11 changes: 7 additions & 4 deletions egs/speechio/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
| 7 | aispeech_api_zh | 3.62% | 2023.12 |
| 8 | **whisper-large-ft-v1** | **4.32%** | 2024.04 |
| 9 | **whisper-large-ft-v0.5** | **4.60%** | 2024.04 |
| 10 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
| 11 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 12 | baidu_pro_api_zh | 7.29% | 2023.12 |
| 10 | **whisper-large-ft-v1-distill** | **4.71%** | 2024.04 |
| 11 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
| 12 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 13 | baidu_pro_api_zh | 7.29% | 2023.12 |

Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)

For **whisper-large-ft-v1-distill**, instead of actually using distillation loss for training, the model structure and parameter initialization method from the [distill-whisper](https://arxiv.org/abs/2311.00430) paper were adopted: only the first and last layers of the decoder were retained.

<details><summary> Detail all models </summary><p>

| Model | Training Set | Note |
Expand All @@ -31,7 +34,7 @@ Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leade
|[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/tree/main/exp_large_v2)| wenetspeech | greedy_search, 3 epochs|
|[whisper-large-ft-v0.5](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/blob/main/epoch-2-avg-5.pt)| wenetspeech(updated) | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy_search, 2 epochs |
|[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 3 epochs|

|[whisper-large-ft-v1-distill](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1-distill)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 6 epochs|
</details>


Expand Down
12 changes: 11 additions & 1 deletion egs/speechio/ASR/whisper/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

Expand Down Expand Up @@ -215,7 +216,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
help="""The model name to use.
""",
)
Expand All @@ -227,6 +228,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction",
)

parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)

return parser


Expand Down Expand Up @@ -431,6 +439,8 @@ def main():

if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
if params.avg > 1:
Expand Down

0 comments on commit 4446a04

Please sign in to comment.