diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py
index df7b3e92c..7f17f24d9 100644
--- a/src/lighteval/models/base_model.py
+++ b/src/lighteval/models/base_model.py
@@ -530,27 +530,7 @@ def greedy_until(
                 returns_logits = batch[0].use_logits
                 num_samples = batch[0].num_samples

-                # The main question for this step is the following:
-                # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
-                # of loosing some meaning, or have some generations that are exceedingly short?
-                # The choice we go for here is to avoid truncating the prompt if we can, since it
-                # should have been managed by the prompt creator/few shot manager if requested by the user.
                 context = [c.context for c in batch]
-                smallest_context = min(len(c) for c in context)
-                biggest_context = max(len(c) for c in context)
-                if smallest_context > self.max_length:
-                    hlog_warn(
-                        f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
-                        + str({i.task_name for i in batch})
-                        + ". This is likely to lead to some errors."  # noqa C401
-                    )
-
-                if (
-                    biggest_context > self.max_length
-                ):  # There will be truncation of at least one sample, maximum generation size will be one
-                    max_new_tokens = 1
-                else:  # We can't allow generation of more than max_length
-                    max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)

                 # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
                 # Will do left truncation and padding, as defined when creating the tokenizer
@@ -563,6 +543,23 @@ def greedy_until(
                     add_special_tokens=self.add_special_tokens,
                 ).to(self.device)

+                # The main question for this step is the following:
+                # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+                # of losing some meaning, or have some generations that are exceedingly short?
+                # The choice we go for here is to avoid truncating the prompt if we can, since it
+                # should have been managed by the prompt creator/few shot manager if requested by the user.
+                context_size = tokenized["input_ids"].shape[1]
+                if context_size > self.max_length:
+                    hlog_warn(
+                        f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+                        + str({i.task_name for i in batch})
+                        + ". This is likely to lead to some errors."  # noqa C401
+                    )
+                    # There will be truncation of at least one sample, maximum generation size will be one
+                    max_new_tokens = 1
+                else:  # We can't allow generation of more than max_length
+                    max_new_tokens = min(self.max_length - context_size, max_new_tokens)
+
                 prepared_batch = Batch(
                     input_ids=tokenized["input_ids"],
                     input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
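For readers skimming the patch, here is a minimal standalone sketch of what the base_model.py hunks change: the prompt-length check now runs on the tokenized batch, so the limit is compared against a token count rather than a character count. The tokenizer choice, the `max_length`/`max_new_tokens` values, and the use of `warnings.warn` in place of `hlog_warn` are assumptions made for the demo, not part of the patch.

```python
# Illustrative sketch only, not the library code: measure the context in tokens
# after tokenization, then cap max_new_tokens so context + generation fits max_length.
import warnings

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # assumed tokenizer for the demo
tokenizer.pad_token = tokenizer.eos_token           # gpt2 has no pad token by default
tokenizer.padding_side = "left"                     # left padding, as in the patched code path
tokenizer.truncation_side = "left"                  # left truncation keeps the end of the prompt

max_length = 1024      # assumed model context window (in tokens)
max_new_tokens = 256   # assumed requested generation size

context = ["A short prompt.", "A much longer prompt " * 50]
tokenized = tokenizer(
    context,
    truncation="longest_first",
    padding="longest",
    return_tensors="pt",
    max_length=max_length - 1,  # illustrative choice: leave room for at least one generated token
    add_special_tokens=True,
)

# The check now sees exactly what the model will see: padded, possibly truncated token ids.
context_size = tokenized["input_ids"].shape[1]
if context_size > max_length:
    warnings.warn(f"Batch context ({context_size} tokens) exceeds the model limit ({max_length}).")
    max_new_tokens = 1  # at least one sample was truncated, so only one new token is safe
else:
    max_new_tokens = min(max_length - context_size, max_new_tokens)

print(context_size, max_new_tokens)
```

Placing the check after the `self.tokenizer(...)` call also means it accounts for the same left truncation and padding that the generation step will see, instead of guessing from raw string lengths.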
diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py
index efe207091..b75bc2b27 100644
--- a/src/lighteval/models/nanotron_model.py
+++ b/src/lighteval/models/nanotron_model.py
@@ -1207,27 +1207,7 @@ def greedy_until(
                     "Nonotron models does not allow sampling evaluations - this is likely to fail or provide problematic results"
                 )

-                # The main question for this step is the following:
-                # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
-                # of loosing some meaning, or have some generations that are exceedingly short?
-                # The choice we go for here is to avoid truncating the prompt if we can, since it
-                # should have been managed by the prompt creator/few shot manager if requested by the user.
-                context = [c.context for c in batch]  # or tokenized context?
-                smallest_context = min(len(c) for c in context)
-                biggest_context = max(len(c) for c in context)
-                if smallest_context > self.max_length:
-                    hlog_warn(
-                        f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
-                        + str({i.task_name for i in batch})
-                        + ". This is likely to lead to some errors."  # noqa C401
-                    )
-
-                if (
-                    biggest_context > self.max_length
-                ):  # There will be truncation of at least one sample, maximum generation size will be one
-                    max_new_tokens = 1
-                else:  # We can't allow generation of more than max_length
-                    max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)
+                context = [c.context for c in batch]

                 # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
                 # Will do left truncation and padding, as defined when creating the tokenizer
@@ -1240,6 +1220,23 @@ def greedy_until(
                     add_special_tokens=self.add_special_tokens,
                 ).to(self.device)

+                # The main question for this step is the following:
+                # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+                # of losing some meaning, or have some generations that are exceedingly short?
+                # The choice we go for here is to avoid truncating the prompt if we can, since it
+                # should have been managed by the prompt creator/few shot manager if requested by the user.
+                context_size = tokenized["input_ids"].shape[1]
+                if context_size > self.max_length:
+                    hlog_warn(
+                        f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+                        + str({i.task_name for i in batch})
+                        + ". This is likely to lead to some errors."  # noqa C401
+                    )
+                    # There will be truncation of at least one sample, maximum generation size will be one
+                    max_new_tokens = 1
+                else:  # We can't allow generation of more than max_length
+                    max_new_tokens = min(self.max_length - context_size, max_new_tokens)
+
                 batch_model = Batch(
                     input_ids=tokenized["input_ids"],
                     input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
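The nanotron_model.py hunks mirror the change above and also settle the removed `# or tokenized context?` comment: `self.max_length` is a budget in tokens, while the old `min(len(c) for c in context)` / `max(len(c) for c in context)` measured characters of the raw prompt strings. A quick illustration of why those two quantities should not be compared directly (the gpt2 tokenizer and the prompt below are arbitrary choices for the demo):

```python
# Illustrative only: character count and token count diverge, so a character-based
# check against a token budget can both miss real truncation and flag prompts that fit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works; gpt2 is just an example

prompt = "The quick brown fox jumps over the lazy dog. " * 40  # ~1800 characters
token_count = len(tokenizer(prompt)["input_ids"])

print(len(prompt), token_count)
# Far fewer tokens than characters: comparing len(prompt) against a token limit like
# self.max_length misestimates how much room is actually left for generation.
```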