
Homogeneize generation params #428

Merged
merged 44 commits from clem_homogeneize_generation_params into main on Jan 2, 2025
Changes from 1 commit
Commits
44 commits
1ff5a78
adding input generation config
clefourrier Dec 9, 2024
12c6a90
added tgi model
clefourrier Dec 9, 2024
c9657d2
grammar is task dependent, removed from the config
clefourrier Dec 9, 2024
ac6565a
added openai config + moved everything to dict
clefourrier Dec 9, 2024
2628571
added generation configs to models
clefourrier Dec 9, 2024
c24bf9b
added generation configs to models
clefourrier Dec 9, 2024
0aa2e19
fix
clefourrier Dec 9, 2024
e3311bd
fix
clefourrier Dec 9, 2024
a3f535f
added doc
clefourrier Dec 9, 2024
286668f
Saved GenerationParameter class in model config classes, then saved i…
clefourrier Dec 10, 2024
0b2475a
changed model args
clefourrier Dec 10, 2024
521559f
test
clefourrier Dec 10, 2024
91363fe
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Dec 11, 2024
c088ab6
updated launchers
clefourrier Dec 11, 2024
e1bd34f
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Dec 12, 2024
3eb7d0f
rename base_model to transformers_model
clefourrier Dec 12, 2024
a585701
removed the use of a GenerationConfig object, as it's got lots of par…
clefourrier Dec 12, 2024
f9ab29b
revert
clefourrier Dec 12, 2024
4833929
fix docs
clefourrier Dec 12, 2024
30bed89
fix #16 by also allowing a generationconfig object to be passed progr…
clefourrier Dec 12, 2024
431b4f2
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Dec 12, 2024
fb4ecdc
Apply suggestions from code review
clefourrier Dec 16, 2024
be99c5e
Update src/lighteval/models/transformers/transformers_model.py
clefourrier Dec 16, 2024
e8b9057
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Dec 16, 2024
dece2f9
removed temperature from default vllm params as it should be passed v…
clefourrier Dec 17, 2024
8e3b7e2
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Dec 17, 2024
5c89fe2
Update src/lighteval/models/transformers/transformers_model.py
clefourrier Dec 18, 2024
6a18b81
logging fix
clefourrier Dec 18, 2024
83cbb10
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Dec 18, 2024
c6f42ca
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Dec 18, 2024
90593a9
added default gen params
clefourrier Dec 18, 2024
ff5026b
Apply suggestions from code review
clefourrier Dec 26, 2024
87d052c
rename file
clefourrier Dec 26, 2024
3f96b95
added from path to openai model
clefourrier Dec 26, 2024
843b572
style
clefourrier Dec 26, 2024
e233190
Update src/lighteval/models/transformers/transformers_model.py
clefourrier Dec 26, 2024
97db620
inferenceendpoint renamed to ie
clefourrier Dec 26, 2024
e2d512b
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Dec 26, 2024
e636f73
style 2
clefourrier Dec 26, 2024
ded4cf0
fix vllm
clefourrier Dec 26, 2024
fddfa6f
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Jan 2, 2025
c0566ee
Merge branch 'main' into clem_homogeneize_generation_params
NathanHB Jan 2, 2025
7a54afa
restore line
clefourrier Jan 2, 2025
4319230
Merge branch 'main' into clem_homogeneize_generation_params
clefourrier Jan 2, 2025
adding input generation config
clefourrier committed Dec 10, 2024

commit 1ff5a78398db36549bb6e1e1ff27f6d301716998
112 changes: 112 additions & 0 deletions src/lighteval/models/model_input.py
@@ -0,0 +1,112 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import dataclass
from typing import Optional

from lighteval.utils.imports import NO_VLLM_ERROR_MSG, is_vllm_available


@dataclass
class GenerationParameters:
early_stopping: Optional[bool] = None # vllm, transformers
repetition_penalty: Optional[float] = None # vllm, transformers, tgi
frequency_penalty: Optional[float] = None # vllm, tgi
length_penalty: Optional[float] = None # vllm, transformers
presence_penalty: Optional[float] = None # vllm

max_new_tokens: Optional[int] = None # vllm, transformers, tgi
min_new_tokens: Optional[int] = None # vllm, transformers

seed: Optional[int] = None # vllm, tgi
stop_tokens: Optional[list[str]] = None # vllm, transformers, tgi
temperature: Optional[float] = None # vllm, transformers, tgi
top_k: Optional[int] = None # vllm, transformers, tgi
min_p: Optional[float] = None # vllm, transformers
top_p: Optional[float] = None # vllm, transformers, tgi
truncate_prompt: Optional[bool] = None # vllm, tgi

def to_vllm(self):
if not is_vllm_available():
raise ImportError(NO_VLLM_ERROR_MSG)
from vllm import SamplingParams

# Task specific sampling params to set in model: n, best_of, use_beam_search
# Generation specific params to set in model: logprobs, prompt_logprobs
args = {
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": self.min_p,
"seed": self.seed,
"length_penalty": self.length_penalty,
"early_stopping": self.early_stopping,
"stop": self.stop_tokens,
"max_tokens": self.max_new_tokens,
"min_tokens": self.min_new_tokens,
"truncate_prompt_tokens": self.truncate_prompt,
}
return SamplingParams(**{k: v for k, v in args.items() if v is not None})

def to_transformers(self):
from transformers import GenerationConfig

# Task specific sampling params to set in model: do_sample, num_return_sequences, num_beams
args = {
"max_new_tokens": self.max_new_tokens,
"min_new_tokens": self.min_new_tokens,
"early_stopping": self.early_stopping,
"stop_strings": self.stop_tokens,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"min_p": self.min_p,
"repetition_penalty": self.repetition_penalty,
"length_penalty": self.length_penalty,
"output_scores": True,
"return_dict_in_generate": True,
}
# Even though we only use the dict representation of the GenerationConfig
# we still create the object as it uses validation steps
return GenerationConfig(**{k: v for k, v in args.items() if v is not None})

def to_tgi(self):
from huggingface_hub import TextGenerationInputGenerateParameters

# Task specific sampling params to set in model: best_of, do_sample
args = {
"decoder_input_details": True,
"details": True,
"frequency_penalty": self.frequency_penalty,
"max_new_tokens": self.max_new_tokens,
"repetition_penalty": self.repetition_penalty,
"seed": self.seed,
"stop": self.stop_tokens,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"truncate": self.truncate_prompt,
}
return TextGenerationInputGenerateParameters(**{k: v for k, v in args.items() if v is not None})
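
For orientation, a minimal usage sketch of the new GenerationParameters class follows. This example is not part of the PR: the parameter values and the commented-out backend calls are illustrative assumptions; only the class and method names come from the file above.

from lighteval.models.model_input import GenerationParameters

# Illustrative values; any subset of fields may be set. Unset fields stay None
# and are dropped when converting to a backend-specific object.
params = GenerationParameters(
    temperature=0.7,
    top_p=0.9,
    max_new_tokens=256,
    stop_tokens=["</s>"],
    repetition_penalty=1.1,
)

# Transformers backend: returns a validated transformers.GenerationConfig.
transformers_config = params.to_transformers()

# vLLM backend (requires vllm to be installed): returns a SamplingParams object.
# vllm_params = params.to_vllm()

# TGI backend: returns a huggingface_hub TextGenerationInputGenerateParameters.
# tgi_params = params.to_tgi()
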
87 changes: 55 additions & 32 deletions src/lighteval/models/transformers/base_model.py
@@ -36,9 +36,11 @@
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
GPTQConfig,
PretrainedConfig,
)
from transformers.generation.utils import GenerateOutput
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
@@ -151,6 +153,7 @@ class BaseModelConfig:
trust_remote_code: bool = False
use_chat_template: bool = False
compile: bool = False
generation_config: GenerationConfig = None

def __post_init__(self):
# Making sure this parameter is a boolean
@@ -256,6 +259,7 @@ def __init__(
self.model_sha = config.get_model_sha()

self.precision = _get_dtype(config.dtype, config=self._config)
self.generation_config = config.generation_config.to_dict()

if is_accelerate_available():
model_size, _ = calculate_maximum_sizes(self.model)
@@ -631,25 +635,29 @@ def greedy_until_multi_turn( # noqa: C901
],
]
)
model_outputs = self.model.generate(
**model_inputs,
max_new_tokens=max_generated_tokens,
stopping_criteria=stopping_criteria,
do_sample=False,
pad_token_id=self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id
else self.tokenizer.eos_token_id,

generation_config = GenerationConfig.from_dict(self.generation_config or {})
generation_config.update(
{
"max_new_tokens": max_generated_tokens,
"pad_token_id": self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id
else self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"do_sample": False,
}
)
model_outputs = model_outputs[0, model_inputs["input_ids"].size(1) :]

model_outputs: GenerateOutput = self.model.generate(
**model_inputs, stopping_criteria=stopping_criteria, generation_config=generation_config
)
model_outputs = model_outputs.sequences[0, model_inputs["input_ids"].size(1) :]
model_generations = [model_outputs]
decoded_generation = self.tokenizer.decode(model_outputs)
for term in stop_tokens:
decoded_generation = decoded_generation.split(term)[0]

input_tokens = [model_inputs["input_ids"]]

for i, multi_turn_context in enumerate(request.context[1:]):
multi_turn_context = multi_turn_context.format(model_response=decoded_generation)
multi_turn_context = multi_turn_context.format(model_response=model_generations[-1])

model_inputs = self.tokenizer(
multi_turn_context,
@@ -671,17 +679,25 @@ def greedy_until_multi_turn( # noqa: C901
]
)

model_outputs = self.model.generate(
generation_config = GenerationConfig.from_dict(self.generation_config or {})
generation_config.update(
{
"max_new_tokens": max_generated_tokens,
"pad_token_id": self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id
else self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"do_sample": False,
}
)

model_outputs: GenerateOutput = self.model.generate(
input_ids=model_inputs["input_ids"],
attention_mask=model_inputs["attention_mask"],
max_new_tokens=max_generated_tokens,
stopping_criteria=stopping_criteria,
do_sample=False,
pad_token_id=self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id
else self.tokenizer.eos_token_id,
generation_config=generation_config,
)
model_outputs = model_outputs[0, model_inputs["input_ids"].size(1) :]
model_outputs = model_outputs.sequences[0, model_inputs["input_ids"].size(1) :]
model_generations.append(model_outputs)
decoded_generation = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
input_tokens.append(model_inputs["input_ids"])
@@ -708,7 +724,7 @@ def greedy_until_multi_turn( # noqa: C901
results.append(
GenerativeMultiturnResponse(
result=answers,
input_tokens=[],
input_tokens=input_tokens,
generated_tokens=[],
truncated_tokens_count=0,
padded_tokens_count=0,
@@ -860,29 +876,36 @@ def _generate(
stopping_criteria = stop_sequences_criteria(self.tokenizer, stop_sequences=stop_tokens, batch=batch)
batch_size, _ = batch.input_ids.shape

generation_config = GenerationConfig.from_dict(self.generation_config or {})
generation_config.update(
{
"max_new_tokens": max_new_tokens,
"pad_token_id": self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id
else self.tokenizer.eos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"do_sample": do_sample,
"num_return_sequences": num_samples,
"output_logits": returns_logits,
"renormalize_logits": True,
}
)

# Compute model generation
outputs = self.model.generate(
outputs: GenerateOutput = self.model.generate(
input_ids=batch.input_ids,
attention_mask=batch.input_mask,
max_new_tokens=max_new_tokens,
stopping_criteria=stopping_criteria,
pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=True,
eos_token_id=self.tokenizer.eos_token_id,
do_sample=do_sample,
num_return_sequences=num_samples,
generation_config=generation_config,
)
if returns_logits:
logits = self.model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
generations = outputs.sequences[:, batch.input_ids.size(1) :]
generations = torch.reshape(generations, (batch_size, num_samples, -1))
generations, len_gens = self.pad_and_gather(generations, num_samples=num_samples)
batch.input_ids, len_ids = self.pad_and_gather(batch.input_ids)

logits, len_logits = None, None
if returns_logits:
logits, len_logits = self.pad_and_gather(logits)
logits, len_logits = self.pad_and_gather(outputs.logits)
logits = logits.cpu().numpy()

# We gather remaining info
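
The recurring pattern in this diff is to rebuild a GenerationConfig from the user-supplied dict and then layer the per-request settings on top before calling model.generate. Below is a standalone sketch of that pattern, not code from the PR: the stored values and token ids are placeholders, and note that transformers' GenerationConfig.update takes keyword arguments.

from transformers import GenerationConfig

# Stand-in for self.generation_config, the dict stored on the model at init time.
stored_config = {"temperature": 0.7, "repetition_penalty": 1.1}

generation_config = GenerationConfig.from_dict(stored_config or {})
# update() accepts keyword arguments and returns any keys it did not recognize.
generation_config.update(
    max_new_tokens=256,   # placeholder for the request's max_generated_tokens
    pad_token_id=0,       # tokenizer.pad_token_id (or eos_token_id) in the real code
    eos_token_id=0,       # tokenizer.eos_token_id in the real code
    do_sample=False,
)

# The result is then passed to model.generate(..., generation_config=generation_config).
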
21 changes: 13 additions & 8 deletions src/lighteval/models/vllm/vllm_model.py
@@ -85,6 +85,7 @@ class VLLMModelConfig:
True # whether to add a space at the start of each continuation in multichoice generation
)
pairwise_tokenization: bool = False # whether to tokenize the context and continuation separately or together.
sampling_params: SamplingParams = None # sampling parameters to use for generation

subfolder: Optional[str] = None
temperature: float = 0.6 # will be used for multi sampling tasks, for tasks requiring no sampling, this will be ignored and set to 0.
@@ -117,6 +118,7 @@ def __init__(
self.precision = _get_dtype(config.dtype, config=self._config)

self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
self.sampling_params = config.sampling_params
self.pairwise_tokenization = config.pairwise_tokenization

@property
@@ -300,16 +302,19 @@ def _generate(
generate: bool = True,
) -> list[GenerativeResponse]:
"""Contains the actual logic of the generation."""
sampling_params = self.sampling_params or SamplingParams()
if generate:
sampling_params = SamplingParams(
temperature=float(self._config.temperature) if num_samples > 1 else 0.0,
n=num_samples,
max_tokens=max_new_tokens,
stop=stop_tokens,
logprobs=1 if returns_logits else 0,
)
sampling_params.temperature = float(self._config.temperature) if num_samples > 1 else 0.0
sampling_params.n = num_samples
sampling_params.max_tokens = max_new_tokens
sampling_params.stop = stop_tokens
sampling_params.logprobs = 1 if returns_logits else 0

else:
sampling_params = SamplingParams(temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False)
sampling_params.temperature = 0
sampling_params.prompt_logprobs = 1
sampling_params.max_tokens = 1
sampling_params.detokenize = False

if self.data_parallel_size > 1:
# vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
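
To show how the new sampling_params config field is meant to interact with per-request overrides, here is a simplified sketch of the generate branch above. It is not code from the PR: the concrete values are assumptions, and it requires vllm to be installed.

from vllm import SamplingParams

# Stand-in for config.sampling_params, optionally provided by the user.
user_params = SamplingParams(repetition_penalty=1.05, seed=42)

# Reuse the user-supplied object if present, otherwise fall back to defaults.
sampling_params = user_params or SamplingParams()

# Per-request overrides, mirroring the generate branch above.
sampling_params.temperature = 0.0   # greedy unless several samples are requested
sampling_params.n = 1
sampling_params.max_tokens = 256
sampling_params.stop = ["</s>"]
sampling_params.logprobs = 0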