From 5fabd1e83bfa6f0df144b3aee2987ccb70aec973 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 6 Jun 2024 15:21:32 +0500 Subject: [PATCH] Generation: fix handling of special tokens (#31254) * fix special tokens in generatioon * fix test * add warning * fix the check * warn once * fix --- src/transformers/generation/utils.py | 55 ++++++++++----------- tests/generation/test_framework_agnostic.py | 4 +- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 535e82b8e02828..c6819090892594 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1436,23 +1436,6 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l self._cache.reset() return self._cache - def _get_decoder_start_token_id( - self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None - ) -> int: - decoder_start_token_id = ( - decoder_start_token_id - if decoder_start_token_id is not None - else self.generation_config.decoder_start_token_id - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id - - if decoder_start_token_id is not None: - return decoder_start_token_id - elif bos_token_id is not None: - return bos_token_id - else: - return - def _supports_default_dynamic_cache(self) -> bool: """ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. @@ -1478,25 +1461,32 @@ def _prepare_special_tokens( function). However, if called outside `generate`, consider creating a copy of `generation_config` first. """ - # Convert special tokens to tensors (if they exist) - def _tensor_or_none(token, device=None): + # Convert special tokens to tensors (if they exist either in kwargs or in self.config) + def _tensor_or_none(token_kwargs, token_self, device=None): if device is None: device = self.device + token = token_kwargs if token_kwargs is not None else token_self if token is None or isinstance(token, torch.Tensor): return token return torch.tensor(token, device=device, dtype=torch.long) - # for BC we also try to get `decoder_start_token_id` from model's generation config (#30892) - if self.config.is_encoder_decoder: - generation_config.decoder_start_token_id = self._get_decoder_start_token_id( - generation_config.decoder_start_token_id, generation_config.bos_token_id - ) + bos_token_id = _tensor_or_none( + generation_config.bos_token_id, self.generation_config.bos_token_id, device=device + ) + eos_token_id = _tensor_or_none( + generation_config.eos_token_id, self.generation_config.eos_token_id, device=device + ) + pad_token_id = _tensor_or_none( + generation_config.pad_token_id, self.generation_config.pad_token_id, device=device + ) + decoder_start_token_id = _tensor_or_none( + generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device + ) - bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device) - eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device) - pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device) - decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device) + # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892) + if self.config.is_encoder_decoder: + decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). if eos_token_id is not None and eos_token_id.ndim == 0: @@ -1512,6 +1502,15 @@ def _tensor_or_none(token, device=None): pad_token_id = eos_token_id[0] logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.") + # we can't infer attn mask if pad token is set to be eos token in model's generation config + if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any(): + if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: + logger.warning_once( + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token." + "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` " + "to obtain reliable results." + ) + # Sanity checks/warnings if self.config.is_encoder_decoder and decoder_start_token_id is None: raise ValueError( diff --git a/tests/generation/test_framework_agnostic.py b/tests/generation/test_framework_agnostic.py index f4f13dd8d555ea..634824c2b38ea0 100644 --- a/tests/generation/test_framework_agnostic.py +++ b/tests/generation/test_framework_agnostic.py @@ -161,6 +161,7 @@ def test_transition_scores_greedy_search(self): tokenizer.pad_token = tokenizer.eos_token model = model_cls.from_pretrained("distilbert/distilgpt2") + model.generation_config.eos_token_id = None input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids if is_pt: model = model.to(torch_device) @@ -170,7 +171,6 @@ def test_transition_scores_greedy_search(self): input_ids=input_ids, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id, - eos_token_id=None, return_dict_in_generate=True, output_scores=True, ) @@ -197,6 +197,7 @@ def test_transition_scores_greedy_search_normalized(self): tokenizer.pad_token = tokenizer.eos_token model = model_cls.from_pretrained("distilbert/distilgpt2") + model.generation_config.eos_token_id = None input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids if is_pt: model = model.to(torch_device) @@ -206,7 +207,6 @@ def test_transition_scores_greedy_search_normalized(self): input_ids=input_ids, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id, - eos_token_id=None, return_dict_in_generate=True, output_scores=True, )