This sub-folder contains all the training and inference scripts to reproduce the Distil-Whisper project. Distil-Whisper is written in JAX to leverage the fast training and inference speed offered by TPU v4 hardware. However, it also works efficiently on GPU hardware without any additional code changes.
Reproducing the Distil-Whisper project requires four stages to be completed in successive order:
- Pseudo-labelling
- Initialisation
- Training
- Evaluation
This README is partitioned according to the four stages. Each section provides a minimal example for running the scripts used in the project. The final scripts used to train the model are referenced in-line.
Note that the JAX/Flax experiments were performed on English ASR only. For multilingual training, the PyTorch Training Code can be used instead, enabling anyone to run Whisper distillation on a language of their choice.
Distil-Whisper is written in Python, JAX and Flax, and heavily leverages the Flax Whisper implementation in 🤗 Transformers. The instructions for installing the package are as follows:
- Install JAX from the official instructions, ensuring you install the correct version for your hardware (GPU or TPU).
- Install the distil_whisper package by cloning the repository and performing an editable installation:
git clone https://github.com/huggingface/distil-whisper.git
cd distil-whisper/training/flax
pip install -e .
Pseudo-labelling is the process of generating target text predictions for the input audio data using the teacher model. The generated text labels then replace the ground truth text labels when performing distillation. The rationale for using pseudo-labels instead of ground truth labels is to circumvent the issue of inconsistent transcription formatting across datasets.
The Python script run_pseudo_labelling.py is a flexible inference script that can be used to generate pseudo-labels under a range of settings, including both greedy and beam-search decoding. It is also compatible with 🤗 Datasets streaming mode, allowing users to load massive audio datasets with no disk space requirements. For more information on streaming mode, the reader is referred to the blog post: A Complete Guide to Audio Datasets.
The following script demonstrates how to pseudo-label the LibriSpeech 960h dataset with greedy decoding and streaming mode:
#!/usr/bin/env bash
python run_pseudo_labelling.py \
--model_name_or_path "openai/whisper-large-v2" \
--dataset_name "librispeech_asr" \
--dataset_config_name "all" \
--data_split_name "train.clean.100+train.clean.360+train.other.500" \
--text_column_name "text" \
--output_dir "./transcriptions" \
--per_device_eval_batch_size 16 \
--max_label_length 256 \
--dtype "bfloat16" \
--report_to "wandb" \
--dataloader_num_workers 16 \
--streaming \
--push_to_hub \
--generation_num_beams 1 # for greedy, set >1 for beam
The script will save the generated pseudo-labels alongside the file ids to the output directory output_dir. Adding the --push_to_hub argument uploads the generated pseudo-labels to the Hugging Face Hub on save.
The directory pseudo_labelling_scripts contains a collection of bash scripts for pseudo-labelling all 10 audio datasets used in the project. The datasets with the Whisper generated transcriptions can be found on the Hugging Face Hub under the Distil Whisper organisation. They can be re-used should you wish to bypass the data labelling stage of the reproduction.
The script create_student_model.py can be used to initialise a small student model from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is initialised by copying maximally spaced layers from the teacher, as per the DistilBart recommendations.
The following command demonstrates how to initialise a student model from the large-v2 checkpoint, with all 32 encoder layers and 2 decoder layers. The 2 student decoder layers are copied from teacher layers 1 and 32 respectively, as the maximally spaced layers.
#!/usr/bin/env bash
python create_student_model.py \
--teacher_checkpoint "openai/whisper-large-v2" \
--encoder_layers 32 \
--decoder_layers 2 \
--save_dir "./large-32-2" \
--push_to_hub
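The "maximally spaced" selection described above can be illustrated with a short sketch. The helper name `maximally_spaced_layers`, the rounding rule, and the single-layer behaviour are assumptions made for illustration. They are not taken from create_student_model.py, which may select layers differently in edge cases.

```python
def maximally_spaced_layers(teacher_layers: int, student_layers: int) -> list[int]:
    """Return 0-indexed teacher layer ids to copy into the student (illustrative sketch)."""
    if not 1 <= student_layers <= teacher_layers:
        raise ValueError("student_layers must be between 1 and teacher_layers")
    if student_layers == 1:
        # Assumption: with a single student layer, keep the last teacher layer.
        return [teacher_layers - 1]
    # Spread the student layers evenly between the first and last teacher layers.
    step = (teacher_layers - 1) / (student_layers - 1)
    return [round(i * step) for i in range(student_layers)]


# 0-indexed ids 0 and 31 correspond to teacher layers 1 and 32 in the text above.
print(maximally_spaced_layers(32, 2))  # [0, 31]
```

With 32 teacher layers and 2 student layers, the sketch picks the first and last teacher layers, matching the example command.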
The script run_distillation.py is an end-to-end script for loading multiple datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation from DistilBart, which is a combination of a cross-entropy, KL-divergence and mean-square error (MSE) loss.
The weight assigned to the MSE loss is configurable. The others are fixed to the values from the DistilBART paper.
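Assuming the combination is a weighted sum, the training objective can be written as follows. The symbols are illustrative and do not come from the training script, and the values of the fixed weights are not reproduced here:

$$
\mathcal{L} = \alpha_{\text{CE}} \, \mathcal{L}_{\text{CE}} + \alpha_{\text{KL}} \, \mathcal{L}_{\text{KL}} + \alpha_{\text{MSE}} \, \mathcal{L}_{\text{MSE}}
$$

Under this reading, $\mathcal{L}_{\text{CE}}$ is the cross-entropy between the student predictions and the pseudo-labels, $\mathcal{L}_{\text{KL}}$ is the KL-divergence between the teacher and student output distributions, and $\mathcal{L}_{\text{MSE}}$ is the mean-square error term. Only $\alpha_{\text{MSE}}$ is configurable.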
The following command takes the LibriSpeech 960h dataset that was pseudo-labelled in the first stage and trains the 2-layer decoder model initialised in the previous step. Note that multiple training datasets and splits can be loaded by separating the dataset arguments with + symbols. Thus, the script generalises to any number of training datasets.
#!/usr/bin/env bash
python3 run_distillation.py \
--model_name_or_path "./large-32-2" \
--teacher_model_name_or_path "openai/whisper-large-v2" \
--train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr" \
--train_dataset_config_name "all+all+all" \
--train_split_name "train.clean.100+train.clean.360+train.other.500" \
--train_dataset_samples "100+360+500" \
--eval_dataset_name "librispeech_asr" \
--eval_dataset_config_name "all" \
--eval_split_name "validation.clean" \
--eval_steps 5000 \
--save_steps 5000 \
--warmup_steps 500 \
--learning_rate 0.0001 \
--lr_scheduler_type "constant_with_warmup" \
--logging_steps 25 \
--save_total_limit 1 \
--max_steps 20000 \
--wer_threshold 10 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 64 \
--dataloader_num_workers 16 \
--dtype "bfloat16" \
--output_dir "./" \
--do_train \
--do_eval \
--use_scan \
--gradient_checkpointing \
--overwrite_output_dir \
--predict_with_generate \
--freeze_encoder \
--streaming \
--use_auth_token \
--push_to_hub
The above training script will take approximately 20 hours to complete on a TPU v4-8 and yield a final WER of 2.3%.
Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a saved checkpoint pushed to the Hugging Face Hub can be found here: large-32-2.
There are a few noteworthy arguments that can be configured to give optimal training performance:
- train_dataset_samples: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split, however this might require downloading the dataset offline to compute these statistics.
- wer_threshold: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > wer_threshold are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong.
- freeze_encoder: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes.
- dtype: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
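To make the role of train_dataset_samples concrete, the sketch below converts a "+"-separated string into per-split sampling probabilities. It assumes the probabilities are simply proportional to the values supplied. The helper name `sampling_probabilities` is hypothetical, and the actual dataloader may compute them differently.

```python
def sampling_probabilities(train_dataset_samples: str) -> list[float]:
    """Turn a "+"-separated string of per-split sizes into sampling probabilities."""
    sizes = [float(value) for value in train_dataset_samples.split("+")]
    if any(size <= 0 for size in sizes):
        raise ValueError("every split size must be positive")
    total = sum(sizes)
    # Assumption: each split is sampled in proportion to its declared size.
    return [size / total for size in sizes]


splits = "train.clean.100+train.clean.360+train.other.500".split("+")
probabilities = sampling_probabilities("100+360+500")
for split, probability in zip(splits, probabilities):
    print(f"{split}: {probability:.3f}")
```

Using hours of audio as the size, as in the example command, the 500h split would be drawn roughly five times as often as the 100h split under this assumption.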
The Distil Whisper project extends the above script to train on a combined dataset formed from 12 open-source ASR datasets, totalling 22k hours and over 50k speakers. Template scripts to run training on this composite dataset can be found in the directory distillation_scripts.
There are two types of evaluation performed in Distil-Whisper:
- Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
- Long form: evaluation on audio samples longer than 30s in duration. Examples include entire TED talks or earnings calls.
Both forms of evaluation are performed using the word-error rate (WER) metric.
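For reference, the WER of a single sample can be sketched as the word-level edit distance divided by the number of reference words. The function below is a minimal illustration with a hypothetical name. It applies no text normalisation, and the evaluation scripts may compute the metric differently, for example over a whole corpus.

```python
def word_error_rate(reference: str, prediction: str) -> float:
    """Word-level edit distance divided by the reference length, in percent."""
    ref_words = reference.split()
    pred_words = prediction.split()
    if not ref_words:
        raise ValueError("reference must contain at least one word")
    # previous[j]: edit distance between the reference words seen so far and pred_words[:j]
    previous = list(range(len(pred_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current = [i]
        for j, pred_word in enumerate(pred_words, start=1):
            substitution = previous[j - 1] + (ref_word != pred_word)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return 100.0 * previous[-1] / len(ref_words)


# One substituted word out of six reference words.
print(f"{word_error_rate('the cat sat on the mat', 'the cat sat on a mat'):.1f}%")
```

The same per-sample quantity is what a filter such as wer_threshold would compare against its threshold, assuming the threshold is expressed in percent.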
The script run_eval.py can be used to evaluate a trained student model over multiple validation sets. The following example demonstrates how to evaluate the student model trained in the previous step on the LibriSpeech validation.clean and validation.other dev sets. Again, it leverages streaming mode to bypass the need to download the data offline:
#!/usr/bin/env bash
python run_eval.py \
--model_name_or_path "./large-32-2" \
--dataset_name "librispeech_asr+librispeech_asr" \
--dataset_config_name "all+all" \
--dataset_split_name "validation.clean+validation.other" \
--output_dir "./large-32-2" \
--per_device_eval_batch_size 64 \
--dtype "bfloat16" \
--dataloader_num_workers 16 \
--report_to "wandb" \
--streaming \
--predict_with_generate
Long form evaluation runs on the premise that a single long audio file can be chunked into smaller segments and inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction. A small overlap (or stride) is used between adjacent segments to ensure a continuous transcription across chunks.
This style of chunked inference is performed using the FlaxWhisperPipeline class, which is heavily inspired by Whisper JAX.
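The chunking step can be sketched as computing overlapping time windows over the audio. The helper `chunk_windows` and its `overlap_s` parameter are hypothetical and only illustrate the idea. The real pipeline may define and apply the stride differently.

```python
def chunk_windows(duration_s: float, chunk_length_s: float, overlap_s: float) -> list[tuple[float, float]]:
    """Return (start, end) windows in seconds that cover the audio with a fixed overlap."""
    if duration_s <= 0 or chunk_length_s <= 0:
        raise ValueError("duration_s and chunk_length_s must be positive")
    if not 0 <= overlap_s < chunk_length_s:
        raise ValueError("overlap_s must be non-negative and smaller than chunk_length_s")
    step = chunk_length_s - overlap_s
    windows = []
    start = 0.0
    while True:
        # The final window is truncated at the end of the audio.
        end = min(start + chunk_length_s, duration_s)
        windows.append((start, end))
        if end >= duration_s:
            break
        start += step
    return windows


# A 40s file split into 15s chunks that overlap by 2.5s.
print(chunk_windows(40.0, 15.0, 2.5))  # [(0.0, 15.0), (12.5, 27.5), (25.0, 40.0)]
```

Each window can be transcribed independently, and the overlapping regions give the merging step shared context for joining the transcriptions at the boundaries.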
The script run_long_form_transcription.py can be used to evaluate the trained student model on an arbitrary number of long-form evaluation sets. The following script demonstrates how to evaluate the example student model on two such test sets, Earnings 21 and Earnings 22:
#!/usr/bin/env bash
python run_long_form_transcription.py \
--model_name_or_path "./large-32-2" \
--dataset_name "distil-whisper/earnings21+distil-whisper/earnings22" \
--dataset_config_name "default+default" \
--dataset_split_name "test+test" \
--text_column_name "transcription+transcription" \
--output_dir "./large-32-2" \
--per_device_eval_batch_size 64 \
--chunk_length_s 15 \
--dtype "bfloat16" \
--report_to "wandb" \
--streaming
The argument chunk_length_s controls the length of the chunked audio samples. It should be set to match the typical length of audio the student model was trained on. If unsure about what value of chunk_length_s is optimal for your case, it is recommended to run a sweep over a range of candidate values. A template script for running a WandB sweep can be found under run_chunk_length_s_sweep.yaml.
We found little-to-no difference in the downstream performance of the distilled model after pseudo-labelling with either greedy or beam-search decoding. We attribute this to the minimal difference in performance of the pre-trained Whisper model under greedy and beam-search decoding, which gives pseudo-labelled transcriptions of similar quality. We encourage users to generate pseudo-labels using greedy decoding, given that it runs significantly faster. Beam search is only advised if the pre-trained model is hallucinating significantly on the audio inputs, in which case it helps reduce the frequency and severity of hallucinations. If using beam search, the number of beams can be kept low: even 2 beams help reduce the number of hallucinations significantly.
Whisper is trained on a timestamp prediction task as part of the pre-training set-up. Here, a fixed proportion of the pre-training data includes sequence-level timestamps as part of the transcription labels:
<|0.00|> Hey, this is a test transcription. <|3.42|>
Timestamp prediction is useful for enriching the transcriptions with timing information for downstream tasks, such as aligning the Whisper transcription with the output of a speaker diarization system, and also reduces the frequency of hallucinations.
The pseudo-labelling script run_pseudo_labelling.py can be extended to predict timestamp information in the audio data by appending the --return_timestamps flag to the launch command. The timestamp-labelled data can be passed to the training script in exactly the same way as the non-timestamped version, and the pre-processing function will take care of encoding the timestamps and appending the required task tokens.
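As an illustration of the timestamp format shown above, the sketch below parses a timestamped transcription string into (start, end, text) segments. It is a hypothetical helper for downstream use, such as alignment with diarization output. It is not part of the pseudo-labelling or training scripts, and it assumes every segment is delimited by a pair of `<|seconds|>` markers.

```python
import re

# Matches timestamp markers such as <|0.00|> and captures the time in seconds.
TIMESTAMP_PATTERN = re.compile(r"<\|(\d+\.\d+)\|>")


def parse_timestamped_segments(text: str) -> list[tuple[float, float, str]]:
    """Split a timestamped transcription into (start, end, text) segments."""
    # With a capturing group, re.split yields [before, t0, text0, t1, text1, ...].
    parts = TIMESTAMP_PATTERN.split(text)
    times = [float(value) for value in parts[1::2]]
    texts = [chunk.strip() for chunk in parts[2::2]]
    segments = []
    for start, end, chunk in zip(times, times[1:], texts):
        # Skip the empty gaps between back-to-back markers.
        if chunk:
            segments.append((start, end, chunk))
    return segments


example = "<|0.00|> Hey, this is a test transcription. <|3.42|>"
print(parse_timestamped_segments(example))
```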
Whisper is also pre-trained on a prompting task, where the transcription for the preceding utterance is fed as context to the current one:
<|startofprev|> This is the previous context from the preceding utterance.<|startoftranscript|> And this is the current utterance.<|endoftranscript|>
Annotating the transcriptions with previous context labels is only possible for datasets where we have consecutive files and unique speaker ids, since we need to ensure segment i directly follows on from segment i-1 if we use it as the prompt.
As per the Whisper paper, we mask out the loss over the previous context tokens. At inference time, we can replace the previous context with a “prompt” to encourage the model to generate text in the style of the prompt (e.g. for specific named entities, or styles of transcription).
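The loss masking over the previous context can be sketched as a per-token mask that is 0 for every token before `<|startoftranscript|>` and 1 from there onwards. This is a toy illustration on word-level string tokens with a hypothetical helper name. The exact mask boundary and the token-id handling in the training script are assumptions and may differ.

```python
def prompt_loss_mask(tokens: list[str], start_token: str = "<|startoftranscript|>") -> list[int]:
    """Return 1 for tokens that count towards the loss and 0 for previous-context tokens."""
    if start_token not in tokens:
        # No previous context: every token contributes to the loss.
        return [1] * len(tokens)
    boundary = tokens.index(start_token)
    # Assumption: everything strictly before the start token is masked out.
    return [0] * boundary + [1] * (len(tokens) - boundary)


tokens = ["<|startofprev|>", "previous", "context", "<|startoftranscript|>", "current", "utterance"]
print(prompt_loss_mask(tokens))  # [0, 0, 0, 1, 1, 1]
```

Multiplying the per-token loss by such a mask before averaging means the student is never penalised for the prompt text, only for the current utterance.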
- 🤗 Hugging Face Transformers for the base Whisper implementation
- Google's TPU Research Cloud (TRC) programme for their generous provision of Cloud TPUs