Skip to content

Commit

Permalink
Generate: check that attention_mask is 2D (huggingface#33575)
Browse files Browse the repository at this point in the history
Check that the `attention_mask` passed to `generate` is 2D when the model input is `input_ids`, raising a `ValueError` otherwise.
  • Loading branch information
gante authored Sep 19, 2024
1 parent 413008c commit d9d59e7
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,10 @@ def generate(
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)
elif kwargs_has_attention_mask:
# TODO (joao): generalize this check with other types of inputs
if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
raise ValueError("`attention_mask` passed to `generate` must be 2D.")

if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
Expand Down

0 comments on commit d9d59e7

Please sign in to comment.