From 7fcaab354853abf08d71f80340c29814dae17e91 Mon Sep 17 00:00:00 2001 From: shaltielshmid Date: Thu, 4 Jul 2024 11:23:10 +0300 Subject: [PATCH] Updated tgi_model and added parameters for endpoint_model (#208) * Added image url parameter * Fixed up tgi model config * Undid tgi available check * Adjust tgi parameter names, and checked for attr existence * Fixed task Id in argparse * Removed obfuscation from private functions, to allow inheritance to override * Updated tgi model to inherit from endpoint and just modify client calls * Added option to specify model id in config for tgi model * Added option to specify custom env vars * Updated env vras * Applied ruff format * Added docs + readme * Ruff format --- README.md | 22 +++- examples/model_configs/endpoint_model.yaml | 5 +- examples/model_configs/tgi_model.yaml | 1 + run_evals_accelerate.py | 5 +- src/lighteval/models/endpoint_model.py | 35 +++--- src/lighteval/models/model_config.py | 17 ++- src/lighteval/models/model_loader.py | 6 +- src/lighteval/models/tgi_model.py | 127 ++++++++------------- src/lighteval/utils.py | 2 +- 9 files changed, 109 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index 735a62da0..b90fc9764 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,7 @@ accelerate launch --multi_gpu --num_processes= run_evals_accelerate.py --output_dir output_dir ``` -Examples of possible configuration files are provided in `examples/model_configs`. +You can find the template of the expected model configuration in [examples/model_configs/base_model.yaml_](./examples/model_configs/base_model.yaml). ### Evaluating a large model with pipeline parallelism @@ -182,6 +182,25 @@ python run_evals_accelerate.py \ --output_dir output_dir ``` +### Evaluate the model on a server/container. + +An alternative to launching the evaluation locally is to serve the model on a TGI-compatible server/container and then run the evaluation by sending requests to the server. The command is the same as before, except you specify a path to a yaml config file (detailed below): + +```shell +python run_evals_accelerate.py \ + --model_config_path="/path/to/config/file"\ + --tasks \ + --output_dir output_dir +``` + +There are two types of configuration files that can be provided for running on the server: + +1. [endpoint_model.yaml](./examples/model_configs/endpoint_model.yaml): This configuration allows you to launch the model using [HuggingFace's Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated). You can specify in the configuration file all the relevant parameters, and then `lighteval` will automatically deploy the endpoint, run the evaluation, and finally delete the endpoint (unless you specify an endpoint that was already launched, in which case the endpoint won't be deleted afterwards). + +2. [tgi_model.yaml](./examples/model_configs/tgi_model.yaml): This configuration lets you specify the URL of a model running in a TGI container, such as one deployed on HuggingFace's serverless inference. + +Templates for these configurations can be found in [examples/model_configs](./examples/model_configs/). + ### Evaluate a model on extended, community, or custom tasks. Independently of the default tasks provided in `lighteval` that you will find in the `tasks_table.jsonl` file, you can use `lighteval` to evaluate models on tasks that require special processing (or have been added by the community). These tasks have their own evaluation suites and are defined as follows: @@ -190,7 +209,6 @@ Independently of the default tasks provided in `lighteval` that you will find in * `community`: tasks that have been added by the community. See the [`community_tasks`](./community_tasks) folder for examples. * `custom`: tasks that are defined locally and not present in the core library. Use this suite if you want to experiment with designing a special metric or task. - For example, to run an extended task like `ifeval`, you can run: ```shell python run_evals_accelerate.py \ diff --git a/examples/model_configs/endpoint_model.yaml b/examples/model_configs/endpoint_model.yaml index cc05dcf57..9e0db4374 100644 --- a/examples/model_configs/endpoint_model.yaml +++ b/examples/model_configs/endpoint_model.yaml @@ -5,7 +5,7 @@ model: model: "meta-llama/Llama-2-7b-hf" revision: "main" dtype: "float16" # can be any of "awq", "eetq", "gptq", "4bit' or "8bit" (will use bitsandbytes), "bfloat16" or "float16" - reuse_existing: false # if true, ignore all params in instance + reuse_existing: false # if true, ignore all params in instance, and don't delete the endpoint after evaluation instance: accelerator: "gpu" region: "eu-west-1" @@ -15,5 +15,8 @@ model: framework: "pytorch" endpoint_type: "protected" namespace: null # The namespace under which to launch the endopint. Defaults to the current user's namespace + image_url: null # Optionally specify the docker image to use when launching the endpoint model. E.g., launching models with later releases of the TGI container with support for newer models. + env_vars: + null # Optional environment variables to include when launching the endpoint. e.g., `MAX_INPUT_LENGTH: 2048` generation: add_special_tokens: true diff --git a/examples/model_configs/tgi_model.yaml b/examples/model_configs/tgi_model.yaml index 4cfb80860..5e45641f9 100644 --- a/examples/model_configs/tgi_model.yaml +++ b/examples/model_configs/tgi_model.yaml @@ -3,3 +3,4 @@ model: instance: inference_server_address: "" inference_server_auth: null + model_id: null # Optional, only required if the TGI container was launched with model_id pointing to a local directory \ No newline at end of file diff --git a/run_evals_accelerate.py b/run_evals_accelerate.py index a743cb496..23e46cb05 100644 --- a/run_evals_accelerate.py +++ b/run_evals_accelerate.py @@ -20,10 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -""" Example run command: +"""Example run command: accelerate config accelerate launch run_evals_accelerate.py --tasks="leaderboard|hellaswag|5|1" --output_dir "/scratch/evals" --model_args "pretrained=gpt2" """ + import argparse from lighteval.main_accelerate import CACHE_DIR, main @@ -70,7 +71,7 @@ def get_parser(): "--tasks", type=str, default=None, - help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks", + help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5|0' or path to a texte file with a list of tasks", ) parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots") return parser diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py index d79e0f912..b7e9af31b 100644 --- a/src/lighteval/models/endpoint_model.py +++ b/src/lighteval/models/endpoint_model.py @@ -92,8 +92,9 @@ def __init__( "MAX_TOTAL_TOKENS": "2048", "MODEL_ID": "/repository", **config.get_dtype_args(), + **config.get_custom_env_vars(), }, - "url": "ghcr.io/huggingface/text-generation-inference:1.1.0", + "url": (config.image_url or "ghcr.io/huggingface/text-generation-inference:1.1.0"), }, ) hlog("Deploying your endpoint. Please wait.") @@ -149,7 +150,7 @@ def max_length(self): self._max_length = 2048 return self._max_length - def __async_process_request( + def _async_process_request( self, context: str, stop_tokens: list[str], max_tokens: int ) -> Coroutine[None, list[TextGenerationOutput], str]: # Todo: add an option to launch with conversational instead for chat prompts @@ -165,7 +166,7 @@ def __async_process_request( return generated_text - def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput: + def _process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationOutput: # Todo: add an option to launch with conversational instead for chat prompts # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational generated_text = self.client.text_generation( @@ -179,13 +180,13 @@ def __process_request(self, context: str, stop_tokens: list[str], max_tokens: in return generated_text - async def __async_process_batch_generate( + async def _async_process_batch_generate( self, requests: list[GreedyUntilRequest], ) -> list[TextGenerationOutput]: return await asyncio.gather( *[ - self.__async_process_request( + self._async_process_request( context=request.context, stop_tokens=as_list(request.stop_sequence), max_tokens=request.generation_size, @@ -194,12 +195,12 @@ async def __async_process_batch_generate( ] ) - def __process_batch_generate( + def _process_batch_generate( self, requests: list[GreedyUntilRequest], ) -> list[TextGenerationOutput]: return [ - self.__process_request( + self._process_request( context=request.context, stop_tokens=as_list(request.stop_sequence), max_tokens=request.generation_size, @@ -207,12 +208,12 @@ def __process_batch_generate( for request in requests ] - async def __async_process_batch_logprob( + async def _async_process_batch_logprob( self, requests: list[LoglikelihoodRequest], rolling: bool = False ) -> list[TextGenerationOutput]: return await asyncio.gather( *[ - self.__async_process_request( + self._async_process_request( context=request.context if rolling else request.context + request.choice, stop_tokens=[], max_tokens=1, @@ -221,11 +222,11 @@ async def __async_process_batch_logprob( ] ) - def __process_batch_logprob( + def _process_batch_logprob( self, requests: list[LoglikelihoodRequest], rolling: bool = False ) -> list[TextGenerationOutput]: return [ - self.__process_request( + self._process_request( context=request.context if rolling else request.context + request.choice, stop_tokens=[], max_tokens=1, @@ -267,9 +268,9 @@ def greedy_until( ) if self.use_async: - responses = asyncio.run(self.__async_process_batch_generate(batch)) + responses = asyncio.run(self._async_process_batch_generate(batch)) else: - responses = self.__process_batch_generate(batch) + responses = self._process_batch_generate(batch) for response in responses: results.append( GenerateReturn( @@ -303,9 +304,9 @@ def loglikelihood( for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm): if self.use_async: - responses = asyncio.run(self.__async_process_batch_logprob(batch)) + responses = asyncio.run(self._async_process_batch_logprob(batch)) else: - responses = self.__process_batch_logprob(batch) + responses = self._process_batch_logprob(batch) for cur_request, response in zip(batch, responses): cont_toks = torch.tensor(cur_request.tokenized_continuation) len_choice = len(cont_toks) @@ -351,9 +352,9 @@ def loglikelihood_rolling( dataloader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm ): if self.use_async: - responses = asyncio.run(self.__async_process_batch_logprob(batch, rolling=True)) + responses = asyncio.run(self._async_process_batch_logprob(batch, rolling=True)) else: - responses = self.__process_batch_logprob(batch, rolling=True) + responses = self._process_batch_logprob(batch, rolling=True) for response in responses: logits = [t.logprob for t in response.details.tokens[:-1]] diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index 5cb7c89d6..f2736e1af 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -200,6 +200,7 @@ def init_configs(self, env_config: EnvConfig): class TGIModelConfig: inference_server_address: str inference_server_auth: str + model_id: str @dataclass @@ -224,6 +225,8 @@ class InferenceEndpointModelConfig: add_special_tokens: bool = True revision: str = "main" namespace: str = None # The namespace under which to launch the endopint. Defaults to the current user's namespace + image_url: str = None + env_vars: dict = None def get_dtype_args(self) -> Dict[str, str]: model_dtype = self.model_dtype.lower() @@ -237,6 +240,9 @@ def get_dtype_args(self) -> Dict[str, str]: return {"DTYPE": model_dtype} return {} + def get_custom_env_vars(self) -> Dict[str, str]: + return {k: str(v) for k, v in self.env_vars.items()} if self.env_vars else {} + @staticmethod def nullable_keys() -> list[str]: """ @@ -244,7 +250,7 @@ def nullable_keys() -> list[str]: keys be specified in the configuration in order to launch the endpoint. This function returns the list of keys that are not required and can remain None. """ - return ["namespace"] + return ["namespace", "env_vars", "image_url"] def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig: # noqa: C901 @@ -271,7 +277,7 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None] return BaseModelConfig(**args_dict) - if args.model_config: + if hasattr(args, "model_config") and args.model_config: config = args.model_config["model"] else: with open(args.model_config_path, "r") as f: @@ -279,8 +285,9 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None] if config["type"] == "tgi": return TGIModelConfig( - inference_server_address=args["instance"]["inference_server_address"], - inference_server_auth=args["instance"]["inference_server_auth"], + inference_server_address=config["instance"]["inference_server_address"], + inference_server_auth=config["instance"]["inference_server_auth"], + model_id=config["instance"]["model_id"], ) if config["type"] == "endpoint": @@ -303,6 +310,8 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None] instance_size=config["instance"]["instance_size"], instance_type=config["instance"]["instance_type"], namespace=config["instance"]["namespace"], + image_url=config["instance"].get("image_url", None), + env_vars=config["instance"].get("env_vars", None), ) return InferenceModelConfig(model=config["base_params"]["endpoint_name"]) diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index 3af5be263..dd55b4241 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -88,10 +88,12 @@ def load_model_with_tgi(config: TGIModelConfig): raise ImportError(NO_TGI_ERROR_MSG) hlog(f"Load model from inference server: {config.inference_server_address}") - model = ModelClient(address=config.inference_server_address, auth_token=config.inference_server_auth) + model = ModelClient( + address=config.inference_server_address, auth_token=config.inference_server_auth, model_id=config.model_id + ) model_name = str(model.model_info["model_id"]) model_sha = model.model_info["model_sha"] - model_precision = model.model_info["dtype"] + model_precision = model.model_info["model_dtype"] model_size = -1 model_info = ModelInfo( model_name=model_name, diff --git a/src/lighteval/models/tgi_model.py b/src/lighteval/models/tgi_model.py index 5d519667b..754152587 100644 --- a/src/lighteval/models/tgi_model.py +++ b/src/lighteval/models/tgi_model.py @@ -21,15 +21,14 @@ # SOFTWARE. import asyncio -import math -from typing import Coroutine, List, Tuple, Union +from typing import Coroutine -import numpy as np import requests -from tqdm import tqdm +from huggingface_hub import TextGenerationOutput from transformers import AutoTokenizer -from lighteval.utils import NO_TGI_ERROR_MSG, as_list, is_tgi_available +from lighteval.models.endpoint_model import InferenceEndpointModel +from lighteval.utils import NO_TGI_ERROR_MSG, is_tgi_available if is_tgi_available(): @@ -45,99 +44,63 @@ def divide_chunks(array, n): yield array[i : i + n] -class ModelClient: +# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite +# the client functions, since they use a different client. +class ModelClient(InferenceEndpointModel): _DEFAULT_MAX_LENGTH: int = 4096 - def __init__( - self, - address, - auth_token=None, - ) -> None: + def __init__(self, address, auth_token=None, model_id=None) -> None: if not is_tgi_available(): raise ImportError(NO_TGI_ERROR_MSG) - headers = {} if auth_token is None else {"Authorization": f"Basic {auth_token}"} + headers = {} if auth_token is None else {"Authorization": f"Bearer {auth_token}"} self.client = AsyncClient(address, headers=headers, timeout=240) self._max_gen_toks = 256 - self.model_info = requests.get(f"{address}/info").json() - self.tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"]) - - def __process_request_generate(self, request: Tuple[str, Union[Tuple, List]]) -> Coroutine[None, List, str]: - context, stopping_arugments = request - - if isinstance(stopping_arugments, tuple): - stop_sequence_arg, max_gen_tokens_arg = stopping_arugments - stop_sequences = as_list(stop_sequence_arg) - # Todo @clefourrier add proper messaging explaining this - # we don't want people to be surprised because they set a max len in the model overwritten by the eval - max_tokens = max_gen_tokens_arg - else: - stop_sequences = as_list(stopping_arugments) - max_tokens = self._max_gen_toks - - if stop_sequences is None or stop_sequences == [None]: - stop_sequences = [] - + self.model_info = requests.get(f"{address}/info", headers=headers).json() + if "model_id" not in self.model_info: + raise ValueError("Error occured when fetching info: " + str(self.model_info)) + if model_id: + self.model_info["model_id"] = model_id + self._tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"]) + self._add_special_tokens = True + self.use_async = True + + def _async_process_request( + self, context: str, stop_tokens: list[str], max_tokens: int + ) -> Coroutine[None, list[TextGenerationOutput], str]: + # Todo: add an option to launch with conversational instead for chat prompts generated_text = self.client.generate( - context, - max_new_tokens=max_tokens, + prompt=context, decoder_input_details=True, - stop_sequences=stop_sequences, - seed=42, - truncate=ModelClient._DEFAULT_MAX_LENGTH, + max_new_tokens=max_tokens, + stop_sequences=stop_tokens, ) return generated_text - async def __process_batch_generate(self, requests: List[Tuple[str, Union[Tuple, List]]]): - return await asyncio.gather(*[self.__process_request_generate(request) for request in requests]) - - def greedy_until(self, requests: List[Tuple[str, Union[Tuple, List]]], override_bs=None) -> List[str]: - generated_texts: List[str] = [] - - batch_size = override_bs if override_bs > 0 else BATCH_SIZE - - for batch in tqdm( - divide_chunks(requests, batch_size), total=math.ceil(len(requests) // batch_size), maxinterval=2 - ): - results = asyncio.run(self.__process_batch_generate(batch)) - generated_texts.extend([result.generated_text for result in results]) - - return generated_texts + def _process_request(self, *args, **kwargs) -> TextGenerationOutput: + return asyncio.run(self._async_process_request(*args, **kwargs)) - def __process_request_logprob(self, request: Tuple[str, str]) -> Coroutine[None, List, str]: - context, choice = request - out = self.client.generate(context + choice, max_new_tokens=1, decoder_input_details=True) - return out - - async def __process_batch_logprob(self, requests: List[Tuple[str, str]]): - return await asyncio.gather(*[self.__process_request_logprob(request) for request in requests]) - - def loglikelihood(self, requests: List[Tuple[str, str]], override_bs=None) -> List[Tuple[float, bool]]: - res: List[Tuple[float, bool]] = [] - - batch_size = override_bs if override_bs > 0 else BATCH_SIZE - - for batch in tqdm( - divide_chunks(requests, batch_size), total=math.ceil(len(requests) // batch_size), maxinterval=1 - ): - results = asyncio.run(self.__process_batch_logprob(batch)) - details = [result.details.prefill for result in results] - - for detail, (context, choice) in zip(details, batch): - tokenized_context = self.tokenizer.tokenize(context, add_special_tokens=True) - tokenized_input = self.tokenizer.tokenize(context + choice, add_special_tokens=True) + def set_cache_hook(self, cache_hook): + self.cache_hook = cache_hook - i = 0 - while i < len(tokenized_context) and tokenized_input[i] == tokenized_context[i]: - i += 1 + @property + def tokenizer(self): + return self._tokenizer - logprobs = [token.logprob for token in detail[i:]] + @property + def add_special_tokens(self): + return self._add_special_tokens - logit_sum: float = np.sum(logprobs) - res.append((logit_sum, False)) + @property + def max_length(self) -> int: + if hasattr(self.tokenizer, "model_max_length"): + return self.tokenizer.model_max_length + return ModelClient._DEFAULT_MAX_LENGTH - return res + @property + def disable_tqdm(self) -> bool: + False - def set_cache_hook(self, cache_hook): - self.cache_hook = cache_hook + def cleanup(self): + pass diff --git a/src/lighteval/utils.py b/src/lighteval/utils.py index d3c32e994..3380fc9a5 100644 --- a/src/lighteval/utils.py +++ b/src/lighteval/utils.py @@ -153,7 +153,7 @@ def is_accelerate_available() -> bool: def is_tgi_available() -> bool: - return importlib.util.find_spec("text-generation") is not None + return importlib.util.find_spec("text_generation") is not None NO_TGI_ERROR_MSG = "You are trying to start a text generation inference endpoint, but text-generation is not present in your local environement. Please install it using pip."