Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Homogeneize generation params #428

Open
wants to merge 31 commits into
base: main
Choose a base branch
from

Conversation

clefourrier
Copy link
Member

@clefourrier clefourrier commented Dec 9, 2024

This PR does 3 things:

  • Provide an homogeneized API for people to use to provide model generation parameters in model configs. Those parameters are notably provided to all models which can take them (vllm, open ai, tgi, transformers, ...)
  • Renames BaseModel to TransformersModel
  • Also allows TransformersModels to use a transformers.GenerationConfig object directly, when created programmatically

I would put system_prompt, fewshot_seeds, and use_chat_template in the GenerationParameters too since they are generation parameters logically, but it can be another PR

Closes #16

@clefourrier clefourrier changed the base branch from main to refacto_model December 9, 2024 14:31
@HuggingFaceDocBuilderDev
Copy link
Collaborator

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@clefourrier clefourrier force-pushed the clem_homogeneize_generation_params branch from 071d502 to 286668f Compare December 10, 2024 13:55
@clefourrier
Copy link
Member Author

Interestingly, using an explicit generation config seems to mess up this quite a lot

@clefourrier clefourrier requested a review from NathanHB December 12, 2024 13:14
Copy link
Member

@albertvillanova albertvillanova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! I only reviewed it partially. Some comments below.

src/lighteval/models/transformers/transformers_model.py Outdated Show resolved Hide resolved
src/lighteval/models/transformers/transformers_model.py Outdated Show resolved Hide resolved
Co-authored-by: Nathan Habib <[email protected]>
Co-authored-by: Albert Villanova del Moral <[email protected]>
@clefourrier clefourrier requested a review from NathanHB December 18, 2024 14:34
Copy link
Member

@NathanHB NathanHB left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice to have a coherent config, few nits and good to go

@@ -117,6 +122,7 @@ def __init__(
self.precision = _get_dtype(config.dtype, config=self._config)

self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
self.sampling_params = SamplingParams(**config.sampling_params)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where does the config.sampling_params come from ? I do not see them in VLLMModelConfig

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after testing it indeed does not run

def __post_init__(self):
super()

logger.warning(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logger.warning(
warnings.warn(

super()

logger.warning(
"Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!"
"Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!",
FutureWarning,

@@ -124,6 +127,8 @@ class BaseModelConfig:
model at a quantized precision. Needed for 4-bit and 8-bit precision.
trust_remote_code (bool): Whether to trust remote code during model
loading.
generation_parameters (GenerationParameters): Range of parameters which will affect the generation.
generation_config (GenerationConfig): GenerationConfig object (only passed during manual creation)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we have both generation paramaters and config ?

@@ -1297,6 +1348,15 @@ def _loglikelihood_single_token(
return dataset.get_original_order(res)


class BaseModel(TransformersModel):
def __post_init__(self):
super()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
super()
super().__post_init__()

@@ -44,7 +44,7 @@ def accelerate( # noqa C901
model_args: Annotated[
str,
Argument(
help="Model arguments in the form key1=value1,key2=value2,... or path to yaml config file (see examples/model_configs/base_model.yaml)"
help="Model arguments in the form key1=value1,key2=value2,... or path to yaml config file (see examples/model_configs/transformers_model.yaml)"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to rename the file

  • examples/model_configs/base_model.yaml
    to
  • examples/model_configs/transformers_model.yaml

Comment on lines +60 to +77
if "generation" not in config_dict:
return GenerationParameters()
return GenerationParameters(
early_stopping=config_dict["generation"].get("early_stopping", None),
repetition_penalty=config_dict["generation"].get("repetition_penalty", None),
frequency_penalty=config_dict["generation"].get("frequency_penalty", None),
length_penalty=config_dict["generation"].get("length_penalty", None),
presence_penalty=config_dict["generation"].get("presence_penalty", None),
max_new_tokens=config_dict["generation"].get("max_new_tokens", None),
min_new_tokens=config_dict["generation"].get("min_new_tokens", None),
seed=config_dict["generation"].get("seed", None),
stop_tokens=config_dict["generation"].get("stop_tokens", None),
temperature=config_dict["generation"].get("temperature", None),
top_k=config_dict["generation"].get("top_k", None),
min_p=config_dict["generation"].get("min_p", None),
top_p=config_dict["generation"].get("top_p", None),
truncate_prompt=config_dict["generation"].get("truncate_prompt", None),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if "generation" not in config_dict:
return GenerationParameters()
return GenerationParameters(
early_stopping=config_dict["generation"].get("early_stopping", None),
repetition_penalty=config_dict["generation"].get("repetition_penalty", None),
frequency_penalty=config_dict["generation"].get("frequency_penalty", None),
length_penalty=config_dict["generation"].get("length_penalty", None),
presence_penalty=config_dict["generation"].get("presence_penalty", None),
max_new_tokens=config_dict["generation"].get("max_new_tokens", None),
min_new_tokens=config_dict["generation"].get("min_new_tokens", None),
seed=config_dict["generation"].get("seed", None),
stop_tokens=config_dict["generation"].get("stop_tokens", None),
temperature=config_dict["generation"].get("temperature", None),
top_k=config_dict["generation"].get("top_k", None),
min_p=config_dict["generation"].get("min_p", None),
top_p=config_dict["generation"].get("top_p", None),
truncate_prompt=config_dict["generation"].get("truncate_prompt", None),
)
return GenerationParameters(**config_dict.get("generation", {}))

Comment on lines +88 to +104
args = {
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": self.min_p,
"seed": self.seed,
"length_penalty": self.length_penalty,
"early_stopping": self.early_stopping,
"stop": self.stop_tokens,
"max_tokens": self.max_new_tokens,
"min_tokens": self.min_new_tokens,
"truncate_prompt_tokens": self.truncate_prompt,
}
return {k: v for k, v in args.items() if v is not None}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
args = {
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": self.min_p,
"seed": self.seed,
"length_penalty": self.length_penalty,
"early_stopping": self.early_stopping,
"stop": self.stop_tokens,
"max_tokens": self.max_new_tokens,
"min_tokens": self.min_new_tokens,
"truncate_prompt_tokens": self.truncate_prompt,
}
return {k: v for k, v in args.items() if v is not None}
return {k: v for k, v in asdict(self).items() if v is not None}

# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import dataclass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from dataclasses import dataclass
from dataclasses import dataclass, asdict

Comment on lines +105 to +107
with open(model_args, "r") as f:
config = yaml.safe_load(f)["model"]
generation_parameters = GenerationParameters.from_dict(config)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth adding a from_path method?

}
return {k: v for k, v in args.items() if v is not None}

def to_tgi_inferenceendpoint_dict(self) -> dict:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe renaming to to_tgi_endpoint_dict?

  • TGI has the word inference included

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's because it's either tgi or inference endpoints. Maybe tgi_ie?

args = {
"max_new_tokens": self.max_new_tokens,
"min_new_tokens": self.min_new_tokens,
"early_stopping": self.early_stopping or False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"early_stopping": self.early_stopping or False,
"early_stopping": self.early_stopping,

Why or False? If None, it will not be passed to transformers GenerationConfig and it defaults to False:

https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/generation/configuration_utils.py#L368

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow passing a GenerationConfig for generative evals
4 participants