diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 1790d8d3c765fa..1b4ecb831bf98f 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1571,7 +1571,7 @@ def detect_language( ) with torch.no_grad(): - logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1] + logits = self(**inputs, decoder_input_ids=decoder_input_ids, use_cache=False).logits[:, -1] non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool) non_lang_mask[list(generation_config.lang_to_id.values())] = False