
Commit

Saved the GenerationParameters class in the model config classes, then stored it on the models so its other attributes can be used later
clefourrier committed Dec 10, 2024
1 parent 4b63a86 commit 071d502
Showing 8 changed files with 105 additions and 31 deletions.
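
Note on the new object: throughout this diff, GenerationParameters is built either empty or via GenerationParameters.from_dict(config), and exposes backend-specific converters (to_transformers_dict, to_vllm_openai_dict, to_tgi_inferenceendpoint_dict). The snippet below is only a minimal orientation sketch of such a class; the field names and the "generation" config key are assumptions, not the actual lighteval implementation.

from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class GenerationParameters:
    # Assumed sampling fields, for illustration only.
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_new_tokens: Optional[int] = None

    @classmethod
    def from_dict(cls, config: dict) -> "GenerationParameters":
        # Pull generation settings out of the "model" section of a config yaml (key name assumed).
        params = config.get("generation", {}) or {}
        return cls(**{k: v for k, v in params.items() if k in cls.__dataclass_fields__})

    def _as_clean_dict(self) -> dict:
        # Drop unset values so each backend keeps its own defaults.
        return {k: v for k, v in asdict(self).items() if v is not None}

    def to_transformers_dict(self) -> dict:
        return self._as_clean_dict()

    def to_vllm_openai_dict(self) -> dict:
        return self._as_clean_dict()

    def to_tgi_inferenceendpoint_dict(self) -> dict:
        return self._as_clean_dict()
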
2 changes: 1 addition & 1 deletion src/lighteval/main_accelerate.py
@@ -155,7 +155,7 @@ def accelerate( # noqa C901
# We extract the model args
args_dict = {k.split("=")[0]: k.split("=")[1] for k in config["base_params"]["model_args"].split(",")}

args_dict["generation_config"] = GenerationParameters.from_dict(config).to_transformers_dict()
args_dict["generation_parameters"] = GenerationParameters.from_dict(config)

# We store the relevant other args
args_dict["base_model"] = config["merged_weights"]["base_model"]
42 changes: 29 additions & 13 deletions src/lighteval/main_endpoint.py
@@ -23,6 +23,7 @@
from typing import Optional

import typer
import yaml
from typer import Argument, Option
from typing_extensions import Annotated

@@ -42,10 +43,19 @@
@app.command(rich_help_panel="Evaluation Backends")
def openai(
# === general ===
model_name: Annotated[
str, Argument(help="The model name to evaluate (has to be available through the openai API.")
],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
model_name: Annotated[
str,
Argument(
help="The model name to evaluate (has to be available through the openai API. Mutually exclusive with the config path"
),
] = None,
model_config_path: Annotated[
str,
Argument(
help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml). Mutually exclusive with the model name"
),
] = None,
# === Common parameters ===
system_prompt: Annotated[
Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4)
@@ -96,8 +106,12 @@ def openai(

# from lighteval.models.model_input import GenerationParameters
from lighteval.models.endpoints.openai_model import OpenAIModelConfig
from lighteval.models.model_input import GenerationParameters
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

if not ((model_name is None) ^ (model_config_path is None)):
raise typer.Abort("You must define either the model_name or the model_config_path, not both")

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
evaluation_tracker = EvaluationTracker(
output_dir=output_dir,
@@ -109,8 +123,14 @@ def openai(
)

parallelism_manager = ParallelismManager.OPENAI
# sampling_params = GenerationParameters.from_dict(config)
model_config = OpenAIModelConfig(model=model_name) # , sampling_params=sampling_params.to_vllm_openai_dict())

if model_name:
model_config = OpenAIModelConfig(model=model_name)
else:
with open(model_config_path, "r") as f:
config = yaml.safe_load(f)["model"]
generation_parameters = GenerationParameters.from_dict(config)
model_config = OpenAIModelConfig(model=config["model_name"], generation_parameters=generation_parameters)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
@@ -201,8 +221,6 @@ def inference_endpoint(
"""
Evaluate models using inference-endpoints as backend.
"""
import yaml

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModelConfig,
@@ -230,7 +248,7 @@ def inference_endpoint(
# Find a way to add this back
# if config["base_params"].get("endpoint_name", None):
# return InferenceModelConfig(model=config["base_params"]["endpoint_name"])
generation_config = GenerationParameters.from_dict(config)
generation_parameters = GenerationParameters.from_dict(config)
all_params = {
"model_name": config["base_params"].get("model_name", None),
"endpoint_name": config["base_params"].get("endpoint_name", None),
@@ -245,7 +263,7 @@ def inference_endpoint(
"namespace": config.get("instance", {}).get("namespace", None),
"image_url": config.get("instance", {}).get("image_url", None),
"env_vars": config.get("instance", {}).get("env_vars", None),
"generation_config": generation_config.to_tgi_inferenceendpoint_dict(),
"generation_parameters": generation_parameters,
}

model_config = InferenceEndpointModelConfig(
@@ -342,8 +360,6 @@ def tgi(
"""
Evaluate models using TGI as backend.
"""
import yaml

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.endpoints.tgi_model import TGIModelConfig
from lighteval.models.model_input import GenerationParameters
@@ -364,13 +380,13 @@ def tgi(
with open(model_config_path, "r") as f:
config = yaml.safe_load(f)["model"]

generation_config = GenerationParameters.from_dict(config)
generation_parameters = GenerationParameters.from_dict(config)

model_config = TGIModelConfig(
inference_server_address=config["instance"]["inference_server_address"],
inference_server_auth=config["instance"]["inference_server_auth"],
model_id=config["instance"]["model_id"],
generation_config=generation_config.to_tgi_inferenceendpoint_dict(),
generation_parameters=generation_parameters,
)

pipeline_params = PipelineParameters(
31 changes: 27 additions & 4 deletions src/lighteval/main_vllm.py
@@ -22,7 +22,7 @@
import os
from typing import Optional

from typer import Argument, Option
from typer import Abort, Argument, Option
from typing_extensions import Annotated


@@ -37,8 +37,19 @@

def vllm(
# === general ===
model_args: Annotated[str, Argument(help="Model arguments in the form key1=value1,key2=value2,...")],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
model_args: Annotated[
str,
Argument(
help="Model arguments in the form key1=value1,key2=value2,... Mutually exclusive with the config path"
),
] = None,
model_config_path: Annotated[
str,
Argument(
help="Path to model config yaml file. (examples/model_configs/vllm_model.yaml). Mutually exclusive with the model args"
),
] = None,
# === Common parameters ===
use_chat_template: Annotated[
bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4)
@@ -88,10 +99,16 @@ def vllm(
"""
Evaluate models using vllm as backend.
"""
import yaml

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.model_input import GenerationParameters
from lighteval.models.vllm.vllm_model import VLLMModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

if not ((model_args is None) ^ (model_config_path is None)):
raise Abort("You must define either the model_args or the model_config_path, not both")

TOKEN = os.getenv("HF_TOKEN")

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
@@ -118,8 +135,14 @@ def vllm(
system_prompt=system_prompt,
)

model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
model_config = VLLMModelConfig(**model_args_dict)
if model_args:
model_args_dict: dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in model_args.split(",")}
model_config = VLLMModelConfig(**model_args_dict)
else:
with open(model_config_path, "r") as f:
config = yaml.safe_load(f)["model"]
generation_parameters = GenerationParameters.from_dict(config)
model_config = VLLMModelConfig(**model_args_dict, generation_parameters=generation_parameters)

pipeline = Pipeline(
tasks=tasks,
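
As an aside, the mutual-exclusion checks added in main_endpoint.py and main_vllm.py XOR two `is None` tests; since ^ binds tighter than is, each comparison needs its own parentheses, as written above. A small standalone illustration of the intended behaviour (the argument values are made up):

def exactly_one_provided(a, b) -> bool:
    # True only when exactly one of the two CLI inputs is set.
    return (a is None) ^ (b is None)


assert exactly_one_provided("gpt-4o", None)                        # model name only: ok
assert exactly_one_provided(None, "endpoint_model.yaml")           # config path only: ok
assert not exactly_one_provided(None, None)                        # neither: rejected
assert not exactly_one_provided("gpt-4o", "endpoint_model.yaml")   # both: rejected
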
17 changes: 14 additions & 3 deletions src/lighteval/models/endpoints/endpoint_model.py
@@ -49,6 +49,7 @@

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_input import GenerationParameters
from lighteval.models.model_output import GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse
from lighteval.tasks.requests import (
GreedyUntilRequest,
@@ -79,7 +80,11 @@
class InferenceModelConfig:
model: str
add_special_tokens: bool = True
generation_config: dict = dict
generation_parameters: GenerationParameters = None

def __post_init__(self):
if not self.generation_parameters:
self.generation_parameters = GenerationParameters()


@dataclass
@@ -100,7 +105,7 @@ class InferenceEndpointModelConfig:
namespace: str = None # The namespace under which to launch the endpoint. Defaults to the current user's namespace
image_url: str = None
env_vars: dict = None
generation_config: dict = dict
generation_parameters: GenerationParameters = None

def __post_init__(self):
# xor operator, one is None but not the other
@@ -112,6 +117,9 @@ def __post_init__(self):
if not (self.endpoint_name is None) ^ int(self.model_name is None):
raise ValueError("You need to set either endpoint_name or model_name (but not both).")

if not self.generation_parameters:
self.generation_parameters = GenerationParameters()

def get_dtype_args(self) -> Dict[str, str]:
if self.model_dtype is None:
return {}
@@ -284,7 +292,10 @@ def __init__( # noqa: C901
model_dtype=config.model_dtype or "default",
model_size=-1,
)
self.generation_config = TextGenerationInputGenerateParameters(**config.generation_config)
self.generation_parameters = config.generation_parameters
self.generation_config = TextGenerationInputGenerateParameters(
**self.generation_parameters.to_tgi_inferenceendpoint_dict()
)

@staticmethod
def get_larger_hardware_suggestion(cur_instance_type: str = None, cur_instance_size: str = None):
10 changes: 8 additions & 2 deletions src/lighteval/models/endpoints/openai_model.py
@@ -32,6 +32,7 @@
from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
from lighteval.models.abstract_model import LightevalModel
from lighteval.models.endpoints.endpoint_model import ModelInfo
from lighteval.models.model_input import GenerationParameters
from lighteval.models.model_output import (
GenerativeResponse,
LoglikelihoodResponse,
@@ -62,7 +63,11 @@
@dataclass
class OpenAIModelConfig:
model: str
sampling_params: dict = dict
generation_parameters: GenerationParameters = None

def __post_init__(self):
if not self.generation_parameters:
self.generation_parameters = GenerationParameters()


class OpenAIClient(LightevalModel):
@@ -71,7 +76,8 @@ class OpenAIClient(LightevalModel):
def __init__(self, config: OpenAIModelConfig, env_config) -> None:
api_key = os.environ["OPENAI_API_KEY"]
self.client = OpenAI(api_key=api_key)
self.sampling_params = config.sampling_params
self.generation_parameters = config.generation_parameters
self.sampling_params = self.generation_parameters.to_vllm_openai_dict()

self.model_info = ModelInfo(
model_name=config.model,
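
The sampling_params dict computed in OpenAIClient.__init__ above is presumably splatted into the completion request elsewhere in this file; that call site is not part of the diff. A hedged sketch of such usage, with illustrative parameter contents:

from openai import OpenAI


def sample_completion(client: OpenAI, model: str, prompt: str, sampling_params: dict) -> str:
    # Forward the precomputed sampling parameters (e.g. temperature, top_p) to the chat API.
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        **sampling_params,
    )
    return response.choices[0].message.content
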
12 changes: 10 additions & 2 deletions src/lighteval/models/endpoints/tgi_model.py
@@ -29,6 +29,7 @@
from transformers import AutoTokenizer

from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel, ModelInfo
from lighteval.models.model_input import GenerationParameters
from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available


@@ -50,7 +51,11 @@ class TGIModelConfig:
inference_server_address: str
inference_server_auth: str
model_id: str
generation_config: dict = dict
generation_parameters: GenerationParameters = None

def __post_init__(self):
if not self.generation_parameters:
self.generation_parameters = GenerationParameters()


# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
@@ -66,7 +71,10 @@ def __init__(self, config: TGIModelConfig) -> None:
)

self.client = AsyncClient(config.inference_server_address, headers=headers, timeout=240)
self.generation_config = TextGenerationInputGenerateParameters(**config.generation_config)
self.generation_parameters = config.generation_parameters
self.generation_config = TextGenerationInputGenerateParameters(
**self.generation_parameters.to_tgi_inferenceendpoint_dict()
)
self._max_gen_toks = 256
self.model_info = requests.get(f"{config.inference_server_address}/info", headers=headers).json()
if "model_id" not in self.model_info:
15 changes: 10 additions & 5 deletions src/lighteval/models/transformers/base_model.py
@@ -45,6 +45,7 @@

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_input import GenerationParameters
from lighteval.models.model_output import (
Batch,
GenerativeMultiturnResponse,
@@ -153,7 +154,7 @@ class BaseModelConfig:
trust_remote_code: bool = False
use_chat_template: bool = False
compile: bool = False
generation_config: dict = dict
generation_parameters: GenerationParameters = None

def __post_init__(self):
# Making sure this parameter is a boolean
@@ -180,6 +181,9 @@ def __post_init__(self):
if not isinstance(self.device, str):
raise ValueError("Current device must be passed as string.")

if not self.generation_parameters:
self.generation_parameters = GenerationParameters()

def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedConfig:
revision = self.revision
if self.subfolder:
@@ -259,7 +263,8 @@ def __init__(
self.model_sha = config.get_model_sha()

self.precision = _get_dtype(config.dtype, config=self._config)
self.generation_config = config.generation_config
self.generation_parameters = config.generation_parameters
self.generation_config_dict = self.generation_parameters.to_transformers_dict()

if is_accelerate_available():
model_size, _ = calculate_maximum_sizes(self.model)
@@ -636,7 +641,7 @@ def greedy_until_multi_turn( # noqa: C901
]
)

generation_config = GenerationConfig.from_dict(self.generation_config)
generation_config = GenerationConfig.from_dict(self.generation_config_dict)
generation_config.update(
{
"max_new_tokens": max_generated_tokens,
@@ -679,7 +684,7 @@ def greedy_until_multi_turn( # noqa: C901
]
)

generation_config = GenerationConfig.from_dict(self.generation_config)
generation_config = GenerationConfig.from_dict(self.generation_config_dict)
generation_config.update(
{
"max_new_tokens": max_generated_tokens,
@@ -876,7 +881,7 @@ def _generate(
stopping_criteria = stop_sequences_criteria(self.tokenizer, stop_sequences=stop_tokens, batch=batch)
batch_size, _ = batch.input_ids.shape

generation_config = GenerationConfig.from_dict(self.generation_config)
generation_config = GenerationConfig.from_dict(self.generation_config_dict)
generation_config.update(
{
"max_new_tokens": max_new_tokens,
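
For reference, GenerationConfig.from_dict is the standard transformers constructor used in the hunks above, and its update() method takes keyword overrides. A short self-contained example with illustrative values (not the dict lighteval actually builds):

from transformers import GenerationConfig

# Illustrative contents only; in the diff above the dict comes from
# GenerationParameters.to_transformers_dict().
generation_config_dict = {"do_sample": True, "temperature": 0.7, "top_p": 0.95}
generation_config = GenerationConfig.from_dict(generation_config_dict)
generation_config.update(max_new_tokens=256, num_return_sequences=1)
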
7 changes: 6 additions & 1 deletion src/lighteval/models/vllm/vllm_model.py
@@ -32,6 +32,7 @@

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_input import GenerationParameters
from lighteval.models.model_output import (
GenerativeResponse,
LoglikelihoodResponse,
@@ -85,11 +86,15 @@ class VLLMModelConfig:
True # whether to add a space at the start of each continuation in multichoice generation
)
pairwise_tokenization: bool = False # whether to tokenize the context and continuation separately or together.
sampling_params: dict = dict # sampling parameters to use for generation
generation_parameters: GenerationParameters = None # sampling parameters to use for generation

subfolder: Optional[str] = None
temperature: float = 0.6 # used for multi-sampling tasks; for tasks requiring no sampling, this is ignored and set to 0.

def __post_init__(self):
if not self.generation_parameters:
self.generation_parameters = GenerationParameters()


class VLLMModel(LightevalModel):
def __init__(
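
The vllm model-side wiring is not shown in this diff; presumably the new generation_parameters end up as vllm SamplingParams via to_vllm_openai_dict(). A sketch under that assumption (the keys shown are plain vllm kwargs chosen for illustration, not necessarily what the converter emits):

from vllm import SamplingParams


def build_sampling_params(generation_kwargs: dict) -> SamplingParams:
    # Map a plain dict of sampling settings onto vllm's SamplingParams container.
    return SamplingParams(**generation_kwargs)


params = build_sampling_params({"temperature": 0.6, "top_p": 0.9, "max_tokens": 256})
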
