From ff5026b10b1723652d912271b4ce714816198d77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com> Date: Thu, 26 Dec 2024 11:16:40 +0100 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com> --- src/lighteval/models/model_input.py | 41 ++----------------- .../models/transformers/transformers_model.py | 4 +- 2 files changed, 6 insertions(+), 39 deletions(-) diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py index d2e5f435..6481124b 100644 --- a/src/lighteval/models/model_input.py +++ b/src/lighteval/models/model_input.py @@ -20,7 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from dataclasses import dataclass +from dataclasses import dataclass, asdict from typing import Optional @@ -57,24 +57,7 @@ def from_dict(cls, config_dict: dict): } } """ - if "generation" not in config_dict: - return GenerationParameters() - return GenerationParameters( - early_stopping=config_dict["generation"].get("early_stopping", None), - repetition_penalty=config_dict["generation"].get("repetition_penalty", None), - frequency_penalty=config_dict["generation"].get("frequency_penalty", None), - length_penalty=config_dict["generation"].get("length_penalty", None), - presence_penalty=config_dict["generation"].get("presence_penalty", None), - max_new_tokens=config_dict["generation"].get("max_new_tokens", None), - min_new_tokens=config_dict["generation"].get("min_new_tokens", None), - seed=config_dict["generation"].get("seed", None), - stop_tokens=config_dict["generation"].get("stop_tokens", None), - temperature=config_dict["generation"].get("temperature", None), - top_k=config_dict["generation"].get("top_k", None), - min_p=config_dict["generation"].get("min_p", None), - top_p=config_dict["generation"].get("top_p", None), - truncate_prompt=config_dict["generation"].get("truncate_prompt", None), - ) + return GenerationParameters(**config_dict.get("generation", {})) def to_vllm_openai_dict(self) -> dict: """Selects relevant generation and sampling parameters for vllm and openai models. @@ -85,23 +68,7 @@ def to_vllm_openai_dict(self) -> dict: """ # Task specific sampling params to set in model: n, best_of, use_beam_search # Generation specific params to set in model: logprobs, prompt_logprobs - args = { - "presence_penalty": self.presence_penalty, - "frequency_penalty": self.frequency_penalty, - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "min_p": self.min_p, - "seed": self.seed, - "length_penalty": self.length_penalty, - "early_stopping": self.early_stopping, - "stop": self.stop_tokens, - "max_tokens": self.max_new_tokens, - "min_tokens": self.min_new_tokens, - "truncate_prompt_tokens": self.truncate_prompt, - } - return {k: v for k, v in args.items() if v is not None} + return {k: v for k, v in asdict(self).items() if v is not None} def to_transformers_dict(self) -> dict: """Selects relevant generation and sampling parameters for transformers models. @@ -117,7 +84,7 @@ def to_transformers_dict(self) -> dict: args = { "max_new_tokens": self.max_new_tokens, "min_new_tokens": self.min_new_tokens, - "early_stopping": self.early_stopping or False, + "early_stopping": self.early_stopping, "stop_strings": self.stop_tokens, "temperature": self.temperature, "top_k": self.top_k, diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index 295ed4b1..34db20a6 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -1350,9 +1350,9 @@ def _loglikelihood_single_token( class BaseModel(TransformersModel): def __post_init__(self): - super() + super().__post_init__() - logger.warning( + warnings.warn( "Careful, the BaseModel name is deprecated and will be removed, you should use TransformersModel instead!" )