
Commit

[Bugfix] Avoid truncating the outputs based on string lengths (#201)
* Fix context size

* - redundant condition

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
anton-l and clefourrier authored Jul 8, 2024
1 parent 843a0f8 commit 6064695
Showing 2 changed files with 35 additions and 41 deletions.
37 changes: 17 additions & 20 deletions src/lighteval/models/base_model.py
@@ -530,27 +530,7 @@ def greedy_until(
returns_logits = batch[0].use_logits
num_samples = batch[0].num_samples

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of loosing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context = [c.context for c in batch]
smallest_context = min(len(c) for c in context)
biggest_context = max(len(c) for c in context)
if smallest_context > self.max_length:
hlog_warn(
f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)

if (
biggest_context > self.max_length
): # There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)

# See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
# Will do left truncation and padding, as defined when creating the tokenizer
@@ -563,6 +543,23 @@ def greedy_until(
add_special_tokens=self.add_special_tokens,
).to(self.device)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of losing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context_size = tokenized["input_ids"].shape[1]
if context_size > self.max_length:
hlog_warn(
f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)
# There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - context_size, max_new_tokens)

prepared_batch = Batch(
input_ids=tokenized["input_ids"],
input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
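For context, a minimal sketch of why the pre-patch check misfired. Everything below is illustrative (the tokenizer, prompt, and limits are made up, not taken from lighteval): self.max_length is measured in tokens, while len(context) counts characters, so a perfectly legal prompt could trip the old condition and collapse max_new_tokens to 1.

from transformers import AutoTokenizer

# Any Hugging Face tokenizer works for the illustration; gpt2 is arbitrary.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

prompt = "Answer the following question concisely: " * 20  # ~840 characters, far fewer tokens
max_length = 512        # model context limit, in tokens
max_new_tokens = 256    # hypothetical requested generation budget

# Old logic: compares a character count against a token limit -> spurious truncation.
if len(prompt) > max_length:  # 840 > 512 fires even though the prompt is well under 512 tokens
    max_new_tokens = 1

# New logic (what this commit switches to): measure the prompt after tokenization.
max_new_tokens = 256  # reset the budget before redoing the comparison
context_size = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]
if context_size > max_length:
    max_new_tokens = 1
else:
    max_new_tokens = min(max_length - context_size, max_new_tokens)

Under the old check, any prompt longer than max_length characters (easy to hit with few-shot examples) was left with a one-token generation budget, which is the output truncation the commit title refers to.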
39 changes: 18 additions & 21 deletions src/lighteval/models/nanotron_model.py
@@ -1207,27 +1207,7 @@ def greedy_until(
"Nonotron models does not allow sampling evaluations - this is likely to fail or provide problematic results"
)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of loosing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context = [c.context for c in batch] # or tokenized context?
smallest_context = min(len(c) for c in context)
biggest_context = max(len(c) for c in context)
if smallest_context > self.max_length:
hlog_warn(
f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)

if (
biggest_context > self.max_length
): # There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)
context = [c.context for c in batch]

# See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
# Will do left truncation and padding, as defined when creating the tokenizer
@@ -1240,6 +1220,23 @@ def greedy_until(
add_special_tokens=self.add_special_tokens,
).to(self.device)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of losing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context_size = tokenized["input_ids"].shape[1]
if context_size > self.max_length:
hlog_warn(
f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)
# There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - context_size, max_new_tokens)

batch_model = Batch(
input_ids=tokenized["input_ids"],
input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
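The nanotron path gets the same fix, and in both files the check runs on the whole padded batch: because the tokenizer pads (and left-truncates) every prompt to the longest one, input_ids.shape[1] is the length of the longest context, so a single generation budget is shared by the batch. A hedged sketch of that behaviour, with generic tokenizer settings standing in for the arguments lighteval actually passes (those are collapsed out of this diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary tokenizer for the demo
tokenizer.pad_token = tokenizer.eos_token          # gpt2 ships without a pad token
tokenizer.padding_side = "left"                    # mirror the left padding noted in the diff
tokenizer.truncation_side = "left"                 # and the left truncation

max_length = 512      # stand-in for self.max_length
max_new_tokens = 256  # hypothetical requested budget

batch = ["A short prompt.", "A much longer few-shot prompt. " * 40]
tokenized = tokenizer(
    batch,
    padding=True,           # pad every row to the longest prompt in the batch
    truncation=True,
    max_length=max_length,
    return_tensors="pt",
)

context_size = tokenized["input_ids"].shape[1]  # length of the longest (padded) row
if context_size > max_length:
    max_new_tokens = 1  # at least one sample was truncated; generate a single token
else:
    max_new_tokens = min(max_length - context_size, max_new_tokens)
print(context_size, max_new_tokens)  # both rows share the budget implied by the longest prompt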
