From 4b63a8648c9968ce754bbdcb15f2067d1272ec4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine?= Date: Mon, 9 Dec 2024 19:31:30 +0100 Subject: [PATCH] added doc --- src/lighteval/models/model_input.py | 38 ++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py index 569cd7dc..ae74aa55 100644 --- a/src/lighteval/models/model_input.py +++ b/src/lighteval/models/model_input.py @@ -44,7 +44,19 @@ class GenerationParameters: truncate_prompt: Optional[bool] = None # vllm, tgi @classmethod - def from_dict(cls, config_dict): + def from_dict(cls, config_dict: dict): + """Creates a GenerationParameters object from a config dictionary + + Args: + config_dict (dict): Config dictionary. Must match the following shape: + {"generation_parameters": + { + "early_stopping": value, + ... + "truncate_prompt": value + } + } + """ if "generation_parameters" not in config_dict: return cls cls.early_stopping = config_dict["generation_parameters"].get("early_stopping", None) @@ -63,7 +75,13 @@ def from_dict(cls, config_dict): cls.truncate_prompt = config_dict["generation_parameters"].get("truncate_prompt", None) return cls - def to_vllm_openai_dict(self): + def to_vllm_openai_dict(self) -> dict: + """Selects relevant generation and sampling parameters for vllm and openai models. + Doc: https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html + + Returns: + dict: The parameters used to create a vllm.SamplingParams object, or passed as-is as OpenAI parameters, in the model config. 
+ """ # Task specific sampling params to set in model: n, best_of, use_beam_search # Generation specific params to set in model: logprobs, prompt_logprobs args = { @@ -84,7 +102,13 @@ def to_vllm_openai_dict(self): } return {k: v for k, v in args.items() if v is not None} - def to_transformers_dict(self): + def to_transformers_dict(self) -> dict: + """Selects relevant generation and sampling parameters for transformers models. + Doc: https://huggingface.co/docs/transformers/v4.46.3/en/main_classes/text_generation#transformers.GenerationConfig + + Returns: + dict: The parameters to create a transformers.GenerationConfig in the model config. + """ # Task specific sampling params to set in model: do_sample, num_return_sequences, num_beans args = { "max_new_tokens": self.max_new_tokens, @@ -104,7 +128,13 @@ def to_transformers_dict(self): # we still create the object as it uses validation steps return {k: v for k, v in args.items() if v is not None} - def to_tgi_inferenceendpoint_dict(self): + def to_tgi_inferenceendpoint_dict(self) -> dict: + """Selects relevant generation and sampling parameters for tgi or inference endpoint models. + Doc: https://huggingface.co/docs/huggingface_hub/v0.26.3/en/package_reference/inference_types#huggingface_hub.TextGenerationInputGenerateParameters + + Returns: + dict: The parameters to create a huggingface_hub.TextGenerationInputGenerateParameters in the model config. + """ # Task specific sampling params to set in model: best_of, do_sample args = { "decoder_input_details": True,