diff --git a/pyproject.toml b/pyproject.toml index 0096350f6..56faf1b4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ keywords = ["evaluation", "nlp", "llm"] dependencies = [ # Base dependencies "transformers>=4.36.0", - "huggingface_hub==0.19.4", + "huggingface_hub==0.20.3", "torch>=2.0", "GitPython==3.1.31", # for logging "datasets>=2.14.0", diff --git a/run_evals_accelerate.py b/run_evals_accelerate.py index 337a1c2f7..7002c8747 100644 --- a/run_evals_accelerate.py +++ b/run_evals_accelerate.py @@ -6,8 +6,10 @@ def get_parser(): parser = argparse.ArgumentParser() group = parser.add_mutually_exclusive_group(required=True) - weight_type_group = parser.add_mutually_exclusive_group() + task_type_group = parser.add_mutually_exclusive_group(required=True) + # Model type 1) Base model + weight_type_group = parser.add_mutually_exclusive_group() weight_type_group.add_argument( "--delta_weights", action="store_true", @@ -24,39 +26,50 @@ def get_parser(): "--base_model", type=str, default=None, help="name of the base model to be used for delta or adapter weights" ) - parser.add_argument("--model_args", required=True) - parser.add_argument("--output_dir", required=True) + task_type_group.add_argument("--model_args") parser.add_argument("--model_dtype", type=str, default=None) parser.add_argument( "--multichoice_continuations_start_space", action="store_true", - help="Whether to force multiple choice continuations starts with a space", + help="Whether to force multiple choice continuations to start with a space", ) parser.add_argument( "--no_multichoice_continuations_start_space", action="store_true", - help="Whether to force multiple choice continuations do not starts with a space", + help="Whether to force multiple choice continuations to not start with a space", ) + parser.add_argument("--use_chat_template", default=False, action="store_true") + # Model type 2) TGI + task_type_group.add_argument("--inference_server_address", type=str) + parser.add_argument("--inference_server_auth", type=str, default=None) + # Model type 3) Inference endpoints + task_type_group.add_argument("--endpoint_model_name", type=str) + parser.add_argument("--accelerator", type=str, default=None) + parser.add_argument("--vendor", type=str, default=None) + parser.add_argument("--region", type=str, default=None) + parser.add_argument("--instance_size", type=str, default=None) + parser.add_argument("--instance_type", type=str, default=None) + parser.add_argument("--reuse_existing", default=False, action="store_true") + # Debug + parser.add_argument("--max_samples", type=int, default=None) + parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="") + # Saving parser.add_argument("--push_results_to_hub", default=False, action="store_true") parser.add_argument("--save_details", action="store_true") parser.add_argument("--push_details_to_hub", default=False, action="store_true") parser.add_argument( "--public_run", default=False, action="store_true", help="Push results and details to a public repo" ) - parser.add_argument("--max_samples", type=int, default=None) - parser.add_argument("--override_batch_size", type=int, default=-1) - parser.add_argument("--dataset_loading_processes", type=int, default=1) - parser.add_argument("--inference_server_address", type=str, default=None) - parser.add_argument("--inference_server_auth", type=str, default=None) - parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots") parser.add_argument("--cache_dir", type=str, default=CACHE_DIR) parser.add_argument( "--results_org", type=str, help="Hub organisation where you want to store the results. Your current token must have write access to it", ) - parser.add_argument("--job_id", type=str, help="Optional Job ID for future reference", default="") - parser.add_argument("--use_chat_template", default=False, action="store_true") + # Common parameters + parser.add_argument("--output_dir", required=True) + parser.add_argument("--override_batch_size", type=int, default=-1) + parser.add_argument("--dataset_loading_processes", type=int, default=1) parser.add_argument( "--custom_tasks_file", type=str, @@ -69,6 +82,7 @@ def get_parser(): default=None, help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks", ) + parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots") return parser diff --git a/src/lighteval/data.py b/src/lighteval/data.py index 3152ba4d7..5af88ff35 100644 --- a/src/lighteval/data.py +++ b/src/lighteval/data.py @@ -5,7 +5,14 @@ from torch.utils.data.distributed import DistributedSampler, T_co from lighteval.logging.hierarchical_logger import hlog_warn -from lighteval.tasks.requests import Request +from lighteval.tasks.requests import ( + GreedyUntilRequest, + GreedyUntilWithLogitsRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, + LoglikelihoodSingleTokenRequest, + Request, +) class DynamicBatchDataset(Dataset): @@ -28,6 +35,9 @@ def __init__( requests (List): A list of requests. dataset_splits (int): The number of dataset splits. """ + # We make sure the requests contain the tokenized versions of their values + if any(r.tokenized_context is None for r in requests): + raise ValueError("You passed a request for which tokenization had not happened yet.") # sort the requests using the collate function and save the original order enumerated_requests = list(enumerate(requests)) @@ -124,12 +134,12 @@ def __len__(self) -> int: """ return self.split_end - self.split_start - def _sorting_criteria(self, x) -> int: + def _sorting_criteria(self, request) -> int: raise NotImplementedError() class LoglikelihoodDataset(DynamicBatchDataset): - def _sorting_criteria(self, x) -> int: + def _sorting_criteria(self, request: LoglikelihoodRequest | LoglikelihoodRollingRequest) -> int: """ Collates the input data for batching. @@ -149,13 +159,12 @@ def _sorting_criteria(self, x) -> int: Returns: tuple: A tuple containing the sorted input data. """ - - toks = x[1] + x[2] + toks = request.tokenized_context + request.tokenized_continuation return -len(toks) class LoglikelihoodSingleTokenDataset(DynamicBatchDataset): - def _sorting_criteria(self, x) -> int: + def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int: """ Collates the input data for batching. @@ -167,19 +176,14 @@ def _sorting_criteria(self, x) -> int: is useful to simplify the batching logic and more importantly to make automatic adaptive batches much much easier to implement - any OOMs will happen right away rather than near the end - - Args: - x (tuple): A tuple containing the input data. - - Returns: - tuple: A tuple containing the collated data. """ - toks = x[1] # We take only the prompt, no need for the continuation (since it's a list of single tokens) + # We take only the prompt, no need for the continuation (since it's a list of single tokens) + toks = request.tokenized_context return -len(toks) class GenerativeTaskDataset(DynamicBatchDataset): - def _sorting_criteria(self, x) -> int: + def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsRequest) -> int: """ Collate function for generating batches. @@ -189,9 +193,8 @@ def _sorting_criteria(self, x) -> int: Returns: Any: The collated data. """ - toks = x[0] - meta_data = x[1] - _, gen_length = meta_data[0], meta_data[1] + toks = request.tokenized_context + gen_length = request.generation_size return -(len(toks) + gen_length) @@ -211,7 +214,7 @@ def __getitem__(self, index) -> Request: """ return index, self.sorted_data[index + self.split_start] - def _sorting_criteria(self, x) -> int: + def _sorting_criteria(self, request) -> int: """ Collate function for generating batches. @@ -221,9 +224,8 @@ def _sorting_criteria(self, x) -> int: Returns: Any: The collated data. """ - toks = x[0] - meta_data = x[1] - _, gen_length = meta_data[0], meta_data[1] + toks = request.tokenized_context + gen_length = request.generation_size return -(len(toks) + gen_length) diff --git a/src/lighteval/evaluator.py b/src/lighteval/evaluator.py index 6ca5ed59d..c547bc1ad 100644 --- a/src/lighteval/evaluator.py +++ b/src/lighteval/evaluator.py @@ -8,7 +8,7 @@ from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.logging.hierarchical_logger import hlog from lighteval.models.base_model import BaseModel -from lighteval.models.inference_client import ModelClient +from lighteval.models.tgi_model import ModelClient from lighteval.tasks.lighteval_task import LightevalTask from lighteval.tasks.requests import Doc, Request, RequestType, TaskExampleId diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 349c8fc27..b048fbd98 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -121,4 +121,6 @@ def main(args): print(make_results_table(final_dict)) + model.cleanup() + return final_dict diff --git a/src/lighteval/metrics/__init__.py b/src/lighteval/metrics/__init__.py index 3b17854e7..6dc58ff57 100644 --- a/src/lighteval/metrics/__init__.py +++ b/src/lighteval/metrics/__init__.py @@ -8,11 +8,18 @@ def apply_target_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]): outputs = {} - current_results = [results.pop(0) for _ in range(len(formatted_doc.get_golds()))] + reference_text = formatted_doc.get_golds()[0] + current_result = results.pop(0) + target_logprob = current_result.result[0] + target_acc = current_result.result[1] for metric in metrics: - if Metrics[metric].value.category == MetricCategory.PERPLEXITY: - outputs.update(Metrics[metric].value.compute(results=current_results)) + if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY: + outputs.update( + Metrics[metric].value.compute( + logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text + ) + ) return results, outputs @@ -30,7 +37,9 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr for metric in metrics: if Metrics[metric].value.category == MetricCategory.PERPLEXITY: - outputs.update(Metrics[metric].value.compute(results=current_result, reference_text=reference_text)) + outputs.update( + Metrics[metric].value.compute(logprobs=current_result.result, reference_text=reference_text) + ) return results, outputs @@ -85,7 +94,9 @@ def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, met raise ValueError( "You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead." ) - choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))] + + # Todo: make better system with return_bool_score instead of taking first element + choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))] # sum( gold_ixs = as_list(formatted_doc.gold_index) for metric in metrics: diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index ec123741b..e87e3bb58 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -1,6 +1,8 @@ """This module manages all the metrics occurring at the sample level. The results of said metrics are then aggregated using simple function (min, mean, max, ...) at the corpus level. Most metrics fall under this category. """ +from typing import Union + import nltk import numpy as np from nltk.metrics.distance import edit_distance @@ -275,17 +277,16 @@ def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted return 1.0 / (min(ranked_choices) + 1) -def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int: +def acc_golds_likelihood(target_acc: Union[list[int], int], **kwargs) -> int: """Tests if at least one of predicted gold targets' log-likelihood is above 0.5. Args: - results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated. - formatted_doc (Doc): _description_ + target_acc (list[int]): List of scores indicating whether the predictions log-probabilities are above 0.5 aggregated. Returns: int: 1 if at least one of the possible golds had a log-likelihood above 0.5. """ - return max([int(acc_ppl) for _, acc_ppl in results]) + return max([int(acc_ppl) for acc_ppl in as_list(target_acc)]) class ROUGE: diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py index 659022920..c28ed2470 100644 --- a/src/lighteval/metrics/sample_preparator.py +++ b/src/lighteval/metrics/sample_preparator.py @@ -106,14 +106,14 @@ def count_units(self, text: str) -> int: if self.units_type == "bytes": return len(text.encode("utf-8")) - def prepare(self, results, reference_text, **kwargs): + def prepare(self, logprobs: list[float] | float, reference_text: str, **kwargs): """Prepares an individual perplexity example to the format expected by metrics computed at the corpus level (aggregated). Args: - results (list[float]): List of the logprobabilities computed for each item + logprobs (list[float]): List of the logprobabilities computed for each item of the sequence or single aggregated logprob over the sequence reference_text (str): Current reference text for which to compute the length in self.units_type Returns: PerplexityCorpusMetricInput: Stores the measured logprobs and associated text lengths, counted in the reference unit. """ - return PerplexityCorpusMetricInput(logprobs=results.result, weights=self.count_units(reference_text)) + return PerplexityCorpusMetricInput(logprobs=logprobs, weights=self.count_units(reference_text)) diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py new file mode 100644 index 000000000..13c1b438a --- /dev/null +++ b/src/lighteval/models/abstract_model.py @@ -0,0 +1,155 @@ +from abc import ABC, abstractmethod +from typing import Optional, Union + +import torch +from transformers import BatchEncoding + +from lighteval.models.model_config import EnvConfig +from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn +from lighteval.tasks.requests import ( + GreedyUntilRequest, + GreedyUntilWithLogitsRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, + LoglikelihoodSingleTokenRequest, +) + + +TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding] + + +class LightevalModel(ABC): + DATASET_SPLITS = 4 + + """Abstract model class defining the API that every model to plug into lighteval must follow.""" + + @abstractmethod + def __init__( + self, + config, + env_config: EnvConfig, + ): + return NotImplemented + + def cleanup(self): + """Clean up operations if needed, such as closing an endpoint.""" + return + + @property + @abstractmethod + def tokenizer(self): + raise NotImplementedError + + @property + @abstractmethod + def add_special_tokens(self): + raise NotImplementedError + + @property + @abstractmethod + def max_length(self) -> int: + """Return the maximum sequence length of the model.""" + raise NotImplementedError + + @property + def disable_tqdm(self) -> bool: + raise NotImplementedError + + def greedy_until_with_logits( + self, + requests: list[GreedyUntilWithLogitsRequest], + override_bs: Optional[int] = None, + ) -> list[GenerateReturn]: + """ + Generates sequences greedily until a stopping condition is met, + returning both the generated sequences and the logits. + + Args: + requests (list[tuple[str, dict]]): A list of input requests, + where each request is a tuple containing a prompt string and a dictionary of additional parameters. + disable_tqdm (bool, optional): Whether to disable the tqdm progress bar. Defaults to False. + override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None. + + Returns: + list[GenerateReturn]: A list of GenerateReturn objects, + where each object contains the generated sequence and the corresponding logits. + """ + return self.greedy_until( + requests=requests, + override_bs=override_bs, + returns_logits=True, + ) + + @abstractmethod + def greedy_until( + self, + requests: list[GreedyUntilRequest], + returns_logits: bool = False, + override_bs: Optional[int] = None, + ) -> list[GenerateReturn]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False. + disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerateReturn]: list of generated responses. + """ + return NotImplemented + + @abstractmethod + def loglikelihood( + self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodReturn]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + return NotImplemented + + @abstractmethod + def loglikelihood_rolling( + self, requests: list[LoglikelihoodRollingRequest], override_bs=None + ) -> list[LoglikelihoodReturn]: + """This function is used to compute the log likelihood of the context for perplexity metrics.""" + return NotImplemented + + @abstractmethod + def loglikelihood_single_token( + self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodSingleTokenReturn]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + return NotImplemented + + # Tokenization utils + def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence: + if add_special_tokens is None: + add_special_tokens = self.add_special_tokens + if isinstance(str_to_encode, str): + return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens) + return self.tokenizer( + str_to_encode, + padding=True, + add_special_tokens=add_special_tokens, + return_tensors="pt", + ) + + def tok_encode_pair(self, context, continuation): + """Encodes a context, continuation pair by taking care of the spaces in between.""" + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + whole_enc = self.tok_encode(context + continuation) + context_enc = self.tok_encode(context) + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + return context_enc, continuation_enc + + def tok_decode(self, tokens: torch.LongTensor) -> list[str]: + return self.tokenizer.batch_decode(tokens, skip_special_tokens=True) diff --git a/src/lighteval/models/adapter_model.py b/src/lighteval/models/adapter_model.py index 3c3da120a..bc0af1f79 100644 --- a/src/lighteval/models/adapter_model.py +++ b/src/lighteval/models/adapter_model.py @@ -1,7 +1,7 @@ from contextlib import nullcontext import torch -from transformers import AutoModel, PreTrainedTokenizer +from transformers import AutoModelForCausalLM, PreTrainedTokenizer from lighteval.logging.hierarchical_logger import hlog from lighteval.models.base_model import BaseModel @@ -20,7 +20,7 @@ def _create_auto_tokenizer(self, config: AdapterModelConfig, env_config: EnvConf # (= the parent model, not the model of interest) return self._create_auto_tokenizer_with_name(config.base_model, config=config, env_config=env_config) - def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModel: + def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModelForCausalLM: """Returns a PeftModel from a base model and a version fined tuned using PEFT.""" torch_dtype = _get_dtype(config.dtype, self._config) config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel) @@ -31,7 +31,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) if self.accelerator.is_local_main_process if self.accelerator is not None else nullcontext(): hlog(f"Loading model from {adapter_weights} and applying adapter to {config.base_model}") - base = self.AUTO_MODEL_CLASS.from_pretrained( + base = AutoModelForCausalLM.from_pretrained( config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token ) # Should pass revision @@ -43,7 +43,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) hlog(f"Loading model from {merged_path}") - model = self.AUTO_MODEL_CLASS.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( merged_path, max_memory=max_memory, device_map=device_map, diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index ebcb15fe8..0f753a491 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -6,10 +6,11 @@ import transformers from torch.utils.data import DataLoader from tqdm import tqdm -from transformers import AutoTokenizer, BatchEncoding +from transformers import AutoModelForCausalLM, AutoTokenizer from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn +from lighteval.models.abstract_model import LightevalModel from lighteval.models.model_config import BaseModelConfig, EnvConfig from lighteval.models.model_output import Batch, GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn from lighteval.models.utils import _get_dtype, _get_precision, _simplify_name @@ -19,6 +20,7 @@ LoglikelihoodRequest, LoglikelihoodRollingRequest, LoglikelihoodSingleTokenRequest, + Request, ) from lighteval.utils import ( is_accelerate_available, @@ -31,35 +33,23 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" -TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding] - -DATASET_SPLITS = 4 STARTING_BATCH_SIZE = 512 -class BaseModel: - AUTO_CONFIG_CLASS: transformers.AutoConfig = transformers.AutoConfig - AUTO_TOKENIZER_CLASS: transformers.AutoTokenizer = transformers.AutoTokenizer - AUTO_MODEL_CLASS: transformers.AutoModel = transformers.AutoModelForCausalLM - - # Default max sequence length setting for when no `max_length` is provided - # or no max length config setting is found in the model or tokenizer. - _DEFAULT_MAX_LENGTH: int = 2048 - +class BaseModel(LightevalModel): def __init__( self, config: BaseModelConfig, env_config: EnvConfig, ): """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.""" + self._config = config.init_configs(env_config) self.accelerator = config.accelerator self._batch_size = config.batch_size - self._max_gen_toks = config.max_gen_toks - self._max_length = config.max_length - self._config = config.init_configs(env_config) + self._max_length = self._init_max_length(config.max_length) - self._add_special_tokens = config.add_special_tokens - self.tokenizer = self._create_auto_tokenizer(config, env_config) + self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False + self._tokenizer = self._create_auto_tokenizer(config, env_config) # If model_parallel is not set we compare the number of process with the number of GPUs self.model = self._create_auto_model(config, env_config) @@ -81,16 +71,17 @@ def __init__( self.precision = _get_precision(config, model_auto_config=self._config) - def _encode_pair(self, context, continuation): - n_spaces = len(context) - len(context.rstrip()) - if n_spaces > 0: - continuation = context[-n_spaces:] + continuation - context = context[:-n_spaces] - whole_enc = self.tok_encode(context + continuation) - context_enc = self.tok_encode(context) - context_enc_len = len(context_enc) - continuation_enc = whole_enc[context_enc_len:] - return context_enc, continuation_enc + @property + def tokenizer(self): + return self._tokenizer + + @property + def add_special_tokens(self): + return self._add_special_tokens + + @property + def max_length(self) -> int: + return self._max_length def init_model_parallel(self, model_parallel: bool = None) -> Tuple[bool, Optional[dict], Optional[str]]: """Compute all the parameters related to model_parallel""" @@ -155,7 +146,7 @@ def _create_auto_model(self, config: BaseModelConfig, env_config: EnvConfig) -> config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel) torch_dtype = _get_dtype(config.dtype, self._config) - model = self.AUTO_MODEL_CLASS.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( config.pretrained, revision=config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""), max_memory=max_memory, @@ -222,46 +213,7 @@ def _create_auto_tokenizer_with_name( return tokenizer - @property - def add_special_tokens(self) -> bool: - """ - Determines whether to include special tokens in encoded text. - TODO: Remove these conditionals once HuggingFace supports a way to - check whether or not an arbitrary model was trained with special tokens. - - Returns: - bool: True if special tokens should be included, False otherwise. - - Raises: - ValueError: If the `add_special_tokens` value cannot be determined from the model class. - """ - if self._add_special_tokens is not None: - return self._add_special_tokens - elif self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM: - return False - elif self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM: - return True - else: - raise ValueError( - "Could not determine `add_special_tokens` value from the model " - "class. Set to `True` or `False` depending on whether the model " - "was pre-trained with special tokens." - ) - - @property - def eot_token(self) -> str: - return self.tokenizer.eos_token - - @property - def eot_token_id(self) -> int: - return self.tokenizer.eos_token_id - - @property - def max_gen_toks(self) -> int: - return self._max_gen_toks - - @property - def max_length(self) -> int: + def _init_max_length(self, max_length) -> int: """Return the maximum sequence length of the model. NOTE: Different model configurations have different max sequence length attribute names. @@ -276,10 +228,10 @@ def max_length(self) -> int: based on the model's configuration or tokenizer's model_max_length attribute. Returns: - None + int: Max length to use depending on the available args and config """ - if self._max_length is not None: - return self._max_length + if max_length is not None: + return max_length # Try to get the sequence length from the model config. seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") @@ -289,7 +241,9 @@ def max_length(self) -> int: if hasattr(self.tokenizer, "model_max_length"): return self.tokenizer.model_max_length - return self._DEFAULT_MAX_LENGTH + # Default max sequence length setting for when no `max_length` is provided + # or no max length config setting is found in the model or tokenizer. + return 2048 @property def batch_size(self) -> int: @@ -301,22 +255,26 @@ def batch_size(self) -> int: def device(self) -> Union[int, str, torch.device]: return self._device - def tok_encode(self, string: str, add_special_tokens: Optional[bool] = None) -> TokenSequence: - # TODO: Merge `tok_encode_batch` here. - if add_special_tokens is None: - add_special_tokens = self.add_special_tokens - return self.tokenizer.encode(string, add_special_tokens=add_special_tokens) - - def tok_encode_batch(self, strings: list[str]) -> TokenSequence: - return self.tokenizer( - strings, - padding=True, - add_special_tokens=self.add_special_tokens, - return_tensors="pt", - ) + @property + def disable_tqdm(self) -> bool: + disable_tqdm = False + if self.accelerator: + disable_tqdm = bool(not self.accelerator.is_main_process) + return disable_tqdm - def tok_decode(self, tokens: torch.LongTensor) -> list[str]: - return self.tokenizer.batch_decode(tokens, skip_special_tokens=True) + def _check_continuations_start_space(self, continuation: str) -> str: + """Some models tokenizer want a space at the beginning and other not. We update this if needed here. + multichoice_continuations_start_space can be: + - True (add a space if these isn't one) + - False (remove a space if there is one) + - None (Don't touch - default) + Todo: find a way to add this back WITHOUT breaking compatibility with the harness + """ + if self.multichoice_continuations_start_space is True and continuation[0] != " ": + continuation = " " + continuation + if self.multichoice_continuations_start_space is False and continuation[0] == " ": + continuation = continuation.lstrip() + return continuation def _model_call(self, inputs: torch.Tensor) -> torch.Tensor: return self.model(inputs).logits @@ -343,9 +301,7 @@ def forward_batch(batch_size): def greedy_until_with_logits( self, requests: list[GreedyUntilWithLogitsRequest], - disable_tqdm: bool = False, override_bs: Optional[int] = None, - dataset_splits: int = 4, ) -> list[GenerateReturn]: """ Generates sequences greedily until a stopping condition is met, @@ -354,9 +310,7 @@ def greedy_until_with_logits( Args: requests (list[tuple[str, dict]]): A list of input requests, where each request is a tuple containing a prompt string and a dictionary of additional parameters. - disable_tqdm (bool, optional): Whether to disable the tqdm progress bar. Defaults to False. override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None. - dataset_splits (int, optional): Number of splits to divide the dataset into for parallel generation. Defaults to 4. Returns: list[GenerateReturn]: A list of GenerateReturn objects, @@ -366,16 +320,14 @@ def greedy_until_with_logits( return self.greedy_until( requests, returns_logits=True, - disable_tqdm=disable_tqdm, + disable_tqdm=self.disable_tqdm, override_bs=override_bs, - dataset_splits=dataset_splits, ) def greedy_until( self, requests: list[GreedyUntilRequest], returns_logits: bool = False, - disable_tqdm: bool = False, override_bs: Optional[int] = None, ) -> list[GenerateReturn]: """ @@ -384,56 +336,58 @@ def greedy_until( Args: requests (list[Request]): list of requests containing the context and ending conditions. returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False. - disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False. override_bs (int, optional): Override the batch size for generation. Defaults to None. - dataset_splits (int, optional): Number of splits to divide the dataset into. Defaults to 4. Returns: list[GenerateReturn]: list of generated responses. """ - inputs = [ - ( - request.context, - (request.stop_sequence + [self.eot_token], request.generation_size), - ) - for request in requests - ] - dataset = GenerativeTaskDataset(requests=inputs, dataset_splits=DATASET_SPLITS) + for request in requests: + request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token] + request.tokenized_context = self.tok_encode(request.context) + + dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) starting_batch_size = STARTING_BATCH_SIZE results = [] for split_start, split_end in tqdm( - dataset.splits_start_end_iterator(), total=DATASET_SPLITS, desc="Splits", position=0, disable=disable_tqdm + dataset.splits_start_end_iterator(), + total=self.DATASET_SPLITS, + desc="Splits", + position=0, + disable=self.disable_tqdm, ): - # longest context in the current split - context, (_, max_gen) = dataset[0] - longest_context_continuation_size_in_split = len(context) + max_gen + # Longest context in the current split is the first item (since we sort reversed) + longest_context_continuation_size_in_split = len(dataset[0].tokenized_context) + dataset[0].generation_size max_context_continuation_size_allowed = min(longest_context_continuation_size_in_split, self.max_length) batch_size = self._get_batch_size( override_bs=override_bs, max_input_length=max_context_continuation_size_allowed, starting_batch_size=starting_batch_size, ) + # For next iteration, since the batch will be smaller, we'll test a bigger batch size starting_batch_size = batch_size * 2 dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch) if self.accelerator: dataloader = self.accelerator.prepare(dataloader) - for batch in tqdm(dataloader, desc="Greedy generation", position=1, leave=False, disable=disable_tqdm): + for batch in tqdm( + dataloader, desc="Greedy generation", position=1, leave=False, disable=self.disable_tqdm + ): # NOTE: we are assuming all items in a batch behave similarly (same # stop_tokens and max_tokens genrated) which is not necessarily # the case! Because of that we only use batch size of 1 - stop_tokens, max_generated_tokens = batch[0][1] - context = [c[0] for c in batch] - max_context_allowed_by_model = self.max_length - max_generated_tokens + stop_tokens = batch[0].stop_sequence + max_generated_tokens = batch[0].generation_size + context = [c.context for c in batch] + max_context_size_allowed = self.max_length - max_generated_tokens tokenized = self.tokenizer( context, padding=True, truncation=True, return_tensors="pt", - max_length=max_context_allowed_by_model, + max_length=max_context_size_allowed, add_special_tokens=self.add_special_tokens, ).to(self.device) @@ -445,65 +399,86 @@ def greedy_until( padded=[0] * len(tokenized["input_ids"]), ) - # responses, logits and input_ids have all been gathered accross GPUs already - # but we also grab the original length of these vectors, which have been padded - # while being gathered - the added info - responses, logits, input_ids, len_resps, len_logits, len_ids = self._model_generate( - input_ids=prepared_batch.input_ids, - attention_mask=prepared_batch.input_mask, + cur_reponses = self._generate( + batch=prepared_batch, max_tokens=max_generated_tokens, - stop=stop_tokens, + stop_tokens=stop_tokens, returns_logits=returns_logits, ) + results.extend(cur_reponses) - if returns_logits: - logits = logits.cpu().numpy() + return dataset.get_original_order(results) - batch_truncated = torch.tensor(prepared_batch.truncated, device=self.device) - if self.accelerator: - batch_truncated = self.accelerator.gather_for_metrics(batch_truncated) - batch_padded = torch.tensor(prepared_batch.padded, device=self.device) - if self.accelerator: - batch_padded = self.accelerator.gather_for_metrics(batch_padded) + def _generate( + self, + batch: Batch, + max_tokens: int, + stop_tokens: list[str], + returns_logits: Optional[bool] = False, + ) -> list[GenerateReturn]: + """Contains the actual logic of the generation. + First computes the stop sequences, then generates the predictions, then converts the outputs to GenerateReturn. + """ + stopping_criteria = stop_sequences_criteria(self.tokenizer, stop_sequences=stop_tokens, batch=batch) - for ix, (response, batched_input, trunc, padded) in enumerate( - zip(responses, input_ids, batch_truncated, batch_padded) - ): - # Ensure the generated responses do not contain the stop sequences. - response = response[: len_resps[ix]] - decoded_response = self.tok_decode([response])[0] - - for term in stop_tokens: - decoded_response = decoded_response.split(term)[0] - - cur_response = GenerateReturn( - result=decoded_response, - logits=logits[ix][: len_logits[ix]] if returns_logits else None, - generated_tokens=response, - input_tokens=batched_input[: len_ids[ix]], - truncated_tokens_count=trunc.cpu().item(), - padded_tokens_count=padded.cpu().item(), - ) - results.append(cur_response) + # Compute model generation + outputs = self.model.generate( + input_ids=batch.input_ids, + attention_mask=batch.input_mask, + max_new_tokens=max_tokens, + stopping_criteria=stopping_criteria, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id, + return_dict_in_generate=True, + output_scores=True, + ) + if returns_logits: + logits = self.model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True) + generations = outputs.sequences[:, batch.input_ids.size(1) :] + generations, len_gens = self.pad_and_gather(generations) + batch.input_ids, len_ids = self.pad_and_gather(batch.input_ids) - return dataset.get_original_order(results) + logits, len_logits = None, None + if returns_logits: + logits, len_logits = self.pad_and_gather(logits) + logits = logits.cpu().numpy() - def _check_continuations_start_space(self, continuation: str) -> str: - """Some models tokenizer want a space at the beginning and other not. We update this if needed here. - multichoice_continuations_start_space can be: - - True (add a space if these isn't one) - - False (remove a space if there is one) - - None (Don't touch - default) - Todo: find a way to add this back WITHOUT breaking compatibility with the harness - """ - if self.multichoice_continuations_start_space is True and continuation[0] != " ": - continuation = " " + continuation - if self.multichoice_continuations_start_space is False and continuation[0] == " ": - continuation = continuation.lstrip() - return continuation + # We gather remaining info + batch.truncated = torch.tensor(batch.truncated, device=self.device) + if self.accelerator: + batch.truncated = self.accelerator.gather_for_metrics(batch.truncated) + batch.padded = torch.tensor(batch.padded, device=self.device) + if self.accelerator: + batch.padded = self.accelerator.gather_for_metrics(batch.padded) + + # We convert to GenerateReturn outputs + all_responses = [] + for ix, (generation, batched_input, trunc, padded) in enumerate( + zip(generations, batch.input_ids, batch.truncated, batch.padded) + ): + # Ensure the generated responses do not contain the stop sequences. + generation = generation[: len_gens[ix]] + decoded_generation = self.tok_decode([generation])[0] + + for term in stop_tokens: + decoded_generation = decoded_generation.split(term)[0] + + cur_response = GenerateReturn( + result=decoded_generation, + logits=logits[ix][: len_logits[ix]] if returns_logits else None, + generated_tokens=generation, + input_tokens=batched_input[: len_ids[ix]], + truncated_tokens_count=trunc.cpu().item(), + padded_tokens_count=padded.cpu().item(), + ) + all_responses.append(cur_response) + + return all_responses def loglikelihood( - self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + self, + requests: list[LoglikelihoodRequest], + override_bs: Optional[int] = None, ) -> list[LoglikelihoodReturn]: """Tokenize the context and continuation and compute the log likelihood of those tokenized sequences. @@ -514,111 +489,101 @@ def loglikelihood( Returns: list[Tuple[float, bool]]: _description_ """ - tokenized_reqs = [] - for request in requests: - context, continuation = request.context, request.choice - - if context == "": - context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(continuation) + if request.context == "": + request.tokenized_context = [self.tokenizer.eos_token_id] + request.tokenized_continuation = self.tok_encode(request.choice) else: - # DO NOT CHANGE THE FOLLOWING LINE! - # It is mandatory for compatibility with the harness!!! - context_enc, continuation_enc = self._encode_pair(context, continuation) - - tokenized_reqs.append(((context, continuation), context_enc, continuation_enc)) + # The following line is mandatory for compatibility with the harness + request.tokenized_context, request.tokenized_continuation = self.tok_encode_pair( + request.context, request.choice + ) - return self._loglikelihood_tokens(tokenized_reqs, override_bs=override_bs, dataset_splits=DATASET_SPLITS) + return self._loglikelihood_tokens(requests, override_bs=override_bs) def loglikelihood_rolling( - self, requests: list[LoglikelihoodRollingRequest], override_bs=None + self, + requests: list[LoglikelihoodRollingRequest], + override_bs=None, ) -> list[LoglikelihoodReturn]: """This function is used to compute the log likelihood of the context for perplexity metrics.""" - tokenized_reqs = [] for request in requests: # tuple of one elem - context = request.context - fake_context_enc, context_enc = [self.eot_token_id], self.tok_encode(context) - - tokenized_reqs.append((("", context), fake_context_enc, context_enc)) + request.tokenized_context = [self.tokenizer.eos_token_id] # Fake context + request.tokenized_continuation = self.tok_encode(request.context) - disable_tqdm = False - if self.accelerator: - disable_tqdm = bool(not self.accelerator.is_main_process) results = self._loglikelihood_tokens( - tokenized_reqs, + requests, override_bs=override_bs, - disable_tqdm=disable_tqdm, return_bool_score=False, - dataset_splits=DATASET_SPLITS, ) return results def _loglikelihood_tokens( self, - requests, - disable_tqdm: bool = False, + requests: list[LoglikelihoodRequest], override_bs: int = -1, - dataset_splits: int = 4, return_bool_score: bool = True, ) -> list[LoglikelihoodReturn]: - dataset = LoglikelihoodDataset(requests=requests, dataset_splits=dataset_splits) + dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) starting_batch_size = STARTING_BATCH_SIZE res = [] for split_start, split_end in tqdm(dataset.splits_start_end_iterator()): - _, context_enc, continuation_enc = dataset[0] - max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]) + context_enc = dataset[0].tokenized_context + continuation_enc = dataset[0].tokenized_continuation + max_context_continuation_size_allowed = len( + (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1] + ) batch_size = self._get_batch_size( - override_bs=override_bs, max_input_length=max_context, starting_batch_size=starting_batch_size + override_bs=override_bs, + max_input_length=max_context_continuation_size_allowed, + starting_batch_size=starting_batch_size, ) - dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch) starting_batch_size = batch_size * 2 + dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch) if self.accelerator: dataloader = self.accelerator.prepare(dataloader) - for batch in tqdm(dataloader, disable=disable_tqdm): - inputs = [ - context_enc + continuation_enc[:-1] for _, context_enc, continuation_enc in batch - ] # The last token doesn't need to be input in the model - prepared_batch = self.prepare_batch(inputs, padding_length=max_context, max_context=max_context) + for batch in tqdm(dataloader, disable=self.disable_tqdm): + prepared_batch = self.prepare_batch( + batch, + padding_length=max_context_continuation_size_allowed, + max_context=max_context_continuation_size_allowed, + ) - out = self._model_call(prepared_batch.input_ids) - multi_logits = F.log_softmax(out, dim=-1) # [batch, padding_length, vocab] + model_output = self._model_call(prepared_batch.input_ids) + logits = F.log_softmax(model_output, dim=-1) # [batch, padding_length, vocab] logits_sum = [] max_equals = [] batch_cont_tokens = [] - for (_, _, cont_toks), logits, inplen in zip(batch, multi_logits, prepared_batch.input_lengths): + for cur_request, cur_logits, inplen in zip(batch, logits, prepared_batch.input_lengths): + cont_toks = torch.tensor(cur_request.tokenized_continuation, dtype=torch.long, device=self.device) + contlen = cont_toks.shape[0] # We only look at the continuation tokens - contlen = len(cont_toks) if contlen > inplen: - # continuation is longer than the allowed context size, everything is a continuation - logits = logits.unsqueeze(0).to(self.device) # [1, seq, vocab] - cont_toks = ( - torch.tensor(cont_toks, dtype=torch.long, device=self.device)[:inplen] - .unsqueeze(0) - .to(self.device) - ) # [1, seq] + # Continuation is longer than the input size, we are in rolling mode (only continuation) + cur_logits = cur_logits.unsqueeze(0).to(self.device) # [1, seq, vocab] + cont_toks = cont_toks[:inplen].unsqueeze(0).to(self.device) # [1, seq] else: - logits = logits[inplen - contlen : inplen].unsqueeze(0).to(self.device) # [1, seq, vocab] - cont_toks = ( - torch.tensor(cont_toks, dtype=torch.long, device=self.device).unsqueeze(0).to(self.device) - ) # [1, seq] + cur_logits = ( + cur_logits[inplen - contlen : inplen].unsqueeze(0).to(self.device) + ) # [1, seq, voc] + cont_toks = cont_toks.unsqueeze(0).to(self.device) # [1, seq] # Check if per-token argmax is exactly equal to continuation - greedy_tokens = logits.argmax(dim=-1).to(self.device) + greedy_tokens = cur_logits.argmax(dim=-1).to(self.device) # Sometimes the continuation is longer than allowed by the model, we only look at the first tokens max_equal = (greedy_tokens == cont_toks).all().squeeze(0).to(self.device) # Obtain log-probs at the corresponding continuation token indices - # last_token_slice = logits[:, -1, :].squeeze(0).tolist() - logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + cur_logits = torch.gather(cur_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] # Answer: (log prob, is-exact-match) - logits_sum.append(logits.sum()) + logits_sum.append(cur_logits.sum()) max_equals.append(max_equal) batch_cont_tokens.append(cont_toks) @@ -649,6 +614,7 @@ def _loglikelihood_tokens( zip(logits, batch_cont_tokens, max_equal, batched_inputs, batch_truncated, batch_padded) ): answer = LoglikelihoodReturn( + # todo: we might want to store the logits unsummed result=(float(logit.sum()), bool(maxe)) if return_bool_score else float(logit.sum()), input_tokens=batched_input[: len_inputs[ix]].cpu().tolist(), generated_tokens=cont_tokens[: len_tokens[ix]].cpu().tolist(), @@ -658,7 +624,7 @@ def _loglikelihood_tokens( res.append(answer) # Clean up GPUS - del out + del model_output del logits del batched_inputs del batch_truncated @@ -666,9 +632,18 @@ def _loglikelihood_tokens( return dataset.get_original_order(res) - def prepare_batch(self, batch: list[str], padding_length: int, max_context: Optional[int] = None): + def prepare_batch( + self, batch: list[Request], padding_length: int, max_context: Optional[int] = None, single_token: bool = False + ): """Tokenize a batch of inputs and return also the length, truncations and padding""" - inputs = [] + if single_token: + inputs = [request.tokenized_context for request in batch] + else: + inputs = [ + request.tokenized_context + request.tokenized_continuation[:-1] for request in batch + ] # The last token (an eos) doesn't need to be given to the model + + input_tokens = [] attention_masks = [] input_lengths = [] truncated = [] @@ -678,57 +653,33 @@ def prepare_batch(self, batch: list[str], padding_length: int, max_context: Opti hlog_warn("max_context is None, using max_length") max_context = self.max_length - # because vectorizing is annoying, we first convert each (context, continuation) pair to padded - # tensors, then we pack them together into a batch, call the model, and then pick it all apart - # again because vectorizing is annoying - # Each sample is concatenated and cut to lenght or padded to max_length - for tokens in batch: - truncated.append(max(len(tokens) - max_context, 0)) - - # how this all works: - # CTX CONT - # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] - # gpt2 \ \ - # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the - # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice - - # when too long to fit in context, truncate from the left - inp = torch.tensor( - (tokens)[-max_context:], # [:-1], - dtype=torch.long, - ).to(self.device) - - (inplen,) = inp.shape - - # since in _collate we make sure length is descending, the longest is always the first one. - padding_length = padding_length if padding_length is not None else inplen - - if padding_length - inplen < 0: - hlog_err(f"{tokens=}") - hlog_err(f"{max_context=}") - hlog_err(f"{padding_length=}") - hlog_err(f"{inp.shape=}") - hlog_err(f"Padding length {padding_length} is smaller than input length {inplen}") + for orig_tokens in inputs: + truncated.append(max(len(orig_tokens) - max_context, 0)) + + # Truncate from the left if needed to fit in the model's context + tokens = torch.tensor((orig_tokens)[-max_context:], dtype=torch.long).to(self.device) + sequence_len = tokens.shape[0] + + # We add padding, if needed + padding_length = padding_length if padding_length is not None else sequence_len + + if padding_length - sequence_len < 0: + hlog_err(f"Padding length {padding_length} is smaller than input length {sequence_len}") raise ValueError("Negative padding") - padded.append(padding_length - inplen) + padded.append(padding_length - sequence_len) + # Right padding - it likely would be better to do left padding + tokens = F.pad(tokens, (0, padding_length - sequence_len), value=self.tokenizer.pad_token_id) - # pad length from seq to padding_length - att = torch.tensor([1] * len(inp) + [0] * (padding_length - inplen), dtype=torch.long).to(self.device) - inp = torch.cat( - [ - inp, # [seq] - torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device), # [padding_length - seq] - ], - dim=0, - ) - attention_masks.append(att.unsqueeze(0)) + # We create the attention mask to ignore padding + mask = tokens == self.tokenizer.pad_token_id + attention_masks.append(mask) - inputs.append(inp.unsqueeze(0)) # [1, padding_length] - input_lengths.append(inplen) + input_tokens.append(tokens.unsqueeze(0)) # [1, padding_length] + input_lengths.append(sequence_len) - batched_inputs = torch.cat(inputs, dim=0) # [batch, padding_length] + batched_inputs = torch.cat(input_tokens, dim=0) # [batch, padding_length] attention_masks = torch.cat(attention_masks, dim=0) return Batch( @@ -767,20 +718,14 @@ def loglikelihood_single_token( Returns: list[Tuple[float, bool]]: _description_ """ - tokenized_reqs = [] - for request in requests: - context = request.context - continuations = request.choices - - if context == "": - # end of text as context - context_enc = [self.eot_token_id] + if request.context == "": + request.tokenized_context = [self.tokenizer.eos_token_id] else: - context_enc = self.tok_encode(context) + request.tokenized_context = self.tok_encode(request.context) # Some models tokenizer want a space at the beginning and other not - continuations = [self._check_continuations_start_space(c) for c in continuations] + continuations = [self._check_continuations_start_space(c) for c in request.choices] # We must not accidentally prepend a continuation with a start of sentence token. continuations_enc = [self.tok_encode(c, add_special_tokens=False) for c in continuations] @@ -789,23 +734,19 @@ def loglikelihood_single_token( f"Trying to do single token multiple choice but one choice has several tokens: {continuations_enc}. " "If the additional pre-token is a space, try to set --no_multichoice_continuations_start_space " ) + request.tokenized_continuation = continuations_enc - tokenized_reqs.append(((context, continuations), context_enc, continuations_enc)) - disable_tqdm = False - if self.accelerator: - disable_tqdm = bool(not self.accelerator.is_main_process) - - return self._loglikelihood_single_token(tokenized_reqs, override_bs=override_bs, disable_tqdm=disable_tqdm) + return self._loglikelihood_single_token(requests, override_bs=override_bs) def _loglikelihood_single_token( - self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 4 + self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: int = -1 ) -> list[LoglikelihoodSingleTokenReturn]: - dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=dataset_splits) + dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) starting_batch_size = STARTING_BATCH_SIZE res = [] for split_start, split_end in tqdm(dataset.splits_start_end_iterator()): - _, context_enc, _ = dataset[0] + context_enc = dataset[0].tokenized_context max_context = len(context_enc[-self.max_length :]) batch_size = self._get_batch_size(override_bs=override_bs, max_input_length=max_context) starting_batch_size = batch_size * 2 @@ -814,24 +755,23 @@ def _loglikelihood_single_token( if self.accelerator is not None: dataloader = self.accelerator.prepare(dataloader) - for batch in tqdm(dataloader, disable=disable_tqdm, position=1): - inputs = [context_enc for _, context_enc, _ in batch] - - prepared_batch = self.prepare_batch(inputs, padding_length=max_context, max_context=max_context) + for batch in tqdm(dataloader, disable=self.disable_tqdm, position=1): + prepared_batch = self.prepare_batch( + batch, padding_length=max_context, max_context=max_context, single_token=True + ) out = self._model_call(prepared_batch.input_ids) # [batch, padding_length, vocab] - out = F.log_softmax(out, dim=-1) # we do a softmax over the options, no the vocab batch_probs = [] batch_cont_tokens = [] - for (_, _, cont_toks), logits, inplen in zip(batch, out, prepared_batch.input_lengths): + for cur_request, logits, inplen in zip(batch, out, prepared_batch.input_lengths): # Get the last token logits = logits[inplen - 1] # [vocab] - cont_toks = torch.tensor(cont_toks, dtype=torch.long, device=self.device).squeeze( - -1 - ) # [num_choices] + cont_toks = torch.tensor( + cur_request.tokenized_continuation, dtype=torch.long, device=self.device + ).squeeze(-1) # [num_choices] # Obtain log-probs at the corresponding continuation token indices # last_token_slice = logits[:, -1, :].squeeze(0).tolist() @@ -878,57 +818,6 @@ def _loglikelihood_single_token( return dataset.get_original_order(res) - def _model_generate( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - max_tokens: int, - stop: Optional[list[str]] = None, - returns_logits: Optional[bool] = False, - ) -> TokenSequence: - # max_tokens is the maximum number of *new* tokens - # Ensure that the context does not encroach into the `space` for the generation.` - # input_ids = inputs["input_ids"][:, max_tokens - self.max_length :].to(self.device) - # if inputs["input_ids"].shape != input_ids.shape: - # hlog_warn( - # f"The input was truncated from {inputs['input_ids'].shape} to {input_ids.shape}, with a model maximum context length of {self.max_length} and max {max_tokens} to generate." - # ) - # attention_mask = inputs["attention_mask"][:, max_tokens - self.max_length :].to(self.device) - - stopping_criteria = stop_sequences_criteria( - self.tokenizer, - stop_sequences=stop, - initial_decoder_input_length=input_ids.shape[1], - batch_size=input_ids.shape[0], - ) - - # Depending on whether accelerate is used or not - outputs = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=max_tokens, - stopping_criteria=stopping_criteria, - do_sample=False, - pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id, - return_dict_in_generate=True, - output_scores=returns_logits, - ) - if returns_logits: - transition_scores = self.model.compute_transition_scores( - outputs.sequences, outputs.scores, normalize_logits=True - ) - generations = outputs.sequences[:, input_ids.size(1) :] - generations, len_gens = self.pad_and_gather(generations) - input_ids, len_ids = self.pad_and_gather(input_ids) - - if returns_logits: - # Used input_ids to get its max_length - transition_scores, len_logits = self.pad_and_gather(transition_scores) - else: - transition_scores, len_logits = None, None - - return generations, transition_scores, input_ids, len_gens, len_logits, len_ids - class MultiTokenEOSCriteria(transformers.StoppingCriteria): """Criteria to stop on the specified multi-token sequence.""" @@ -937,9 +826,11 @@ def __init__( self, sequence: str, tokenizer: transformers.PreTrainedTokenizer, - initial_decoder_input_length: int, - batch_size: int, + batch: Batch, ): + initial_decoder_input_length = batch.input_ids.shape[1] + batch_size = batch.input_ids.shape[0] + self.initial_decoder_input_length = initial_decoder_input_length self.done_tracker = [False] * batch_size self.sequence = sequence @@ -962,14 +853,10 @@ def __call__(self, input_ids, scores, **kwargs) -> bool: def stop_sequences_criteria( tokenizer: transformers.PreTrainedTokenizer, stop_sequences: list[str], - initial_decoder_input_length: int, - batch_size: int, + batch: Batch, ) -> transformers.StoppingCriteriaList: return transformers.StoppingCriteriaList( [ - *[ - MultiTokenEOSCriteria(sequence, tokenizer, initial_decoder_input_length, batch_size) - for sequence in stop_sequences - ], + *[MultiTokenEOSCriteria(sequence, tokenizer, batch) for sequence in stop_sequences], ] ) diff --git a/src/lighteval/models/delta_model.py b/src/lighteval/models/delta_model.py index 1233470b9..e0da2cec5 100644 --- a/src/lighteval/models/delta_model.py +++ b/src/lighteval/models/delta_model.py @@ -2,7 +2,7 @@ import torch from tqdm import tqdm -from transformers import AutoModel +from transformers import AutoModelForCausalLM from lighteval.logging.hierarchical_logger import hlog from lighteval.models.base_model import BaseModel @@ -15,7 +15,7 @@ def _create_auto_model( self, config: DeltaModelConfig, env_config: EnvConfig, - ) -> AutoModel: + ) -> AutoModelForCausalLM: """Returns a model created by adding the weights of a delta model to a base model.""" config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel) torch_dtype = _get_dtype(config.dtype, self._config) @@ -26,10 +26,10 @@ def _create_auto_model( if self.accelerator.is_main_process if self.accelerator is not None else nullcontext(): hlog(f"Loading base and delta models from {config.base_model} and {delta_model}") - base = self.AUTO_MODEL_CLASS.from_pretrained( + base = AutoModelForCausalLM.from_pretrained( config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token ) - delta = self.AUTO_MODEL_CLASS.from_pretrained( + delta = AutoModelForCausalLM.from_pretrained( delta_model, revision=config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""), torch_dtype=torch.float16, @@ -46,7 +46,7 @@ def _create_auto_model( hlog(f"Loading delta-applied model from {delta_model}-delta-applied") - model = self.AUTO_MODEL_CLASS.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( merged_path, max_memory=max_memory, device_map=device_map, diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py new file mode 100644 index 000000000..b5c033a4f --- /dev/null +++ b/src/lighteval/models/endpoint_model.py @@ -0,0 +1,349 @@ +import asyncio +from typing import Coroutine, List, Optional, Union + +from huggingface_hub import ( + AsyncInferenceClient, + InferenceClient, + InferenceEndpoint, + InferenceEndpointTimeoutError, + create_inference_endpoint, + get_inference_endpoint, +) +from huggingface_hub.inference._text_generation import TextGenerationResponse +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoTokenizer + +from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset +from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn +from lighteval.models.abstract_model import LightevalModel +from lighteval.models.model_config import EnvConfig, InferenceEndpointModelConfig, InferenceModelConfig +from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn +from lighteval.tasks.requests import ( + GreedyUntilRequest, + GreedyUntilWithLogitsRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, + LoglikelihoodSingleTokenRequest, +) +from lighteval.utils import as_list + + +BATCH_SIZE = 50 + + +class InferenceEndpointModel(LightevalModel): + """InferenceEndpointModels can be used both with the free inference client, or with inference + endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation. + """ + + def __init__( + self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig + ) -> None: + if isinstance(config, InferenceEndpointModelConfig): + if config.should_reuse_existing: + self.endpoint = get_inference_endpoint(name=config.name, token=env_config.token) + else: + self.endpoint: InferenceEndpoint = create_inference_endpoint( + name=config.name, + repository=config.repository, + framework=config.framework, + task="text-generation", + accelerator=config.accelerator, + vendor=config.vendor, + region=config.region, + type=config.endpoint_type, + instance_size=config.instance_size, + instance_type=config.instance_type, + token=env_config.token, + custom_image={ + "health_route": "/health", + "env": { + # Documentaiton: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher + "MAX_BATCH_PREFILL_TOKENS": "2048", + "MAX_INPUT_LENGTH": "2047", + "MAX_TOTAL_TOKENS": "2048", + "MODEL_ID": "/repository", + }, + "url": "ghcr.io/huggingface/text-generation-inference:1.1.0", + }, + ) + hlog("Deploying your endpoint. Please wait.") + try: + self.endpoint.wait(timeout=600) # Waits for the endpoint to be deployed + except InferenceEndpointTimeoutError as e: + hlog_err("Endpoint did not start within 10 minutes, there was a timeout.") + raise e + hlog("Endpoint successfully deployed!") + self.name = config.repository + self.revision = self.endpoint.revision + self.async_client: AsyncInferenceClient = self.endpoint.async_client + self.client: InferenceClient = self.endpoint.client + + else: # Free inference client + self.endpoint = None + self.name = config.model + self.revision = "default" + self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token) + self.client = InferenceClient(model=config.model, token=env_config.token) + + self.use_async = False # for debug - async use is faster + + self._tokenizer = AutoTokenizer.from_pretrained(self.name) + + @property + def tokenizer(self): + return self._tokenizer + + def cleanup(self): + if self.endpoint is not None: + self.endpoint.delete() + hlog_warn( + "You deleted your endpoint after using it. You'll need to create it again if you need to reuse it." + ) + + def max_length(self): + if self._max_length is not None: + return self._max_length + + if hasattr(self.tokenizer, "model_max_length"): + self._max_length = self.tokenizer.model_max_length + else: + self._max_length = 2048 + return self._max_length + + def __async_process_request( + self, context: str, stop_tokens: list[str], max_tokens: int + ) -> Coroutine[None, list[TextGenerationResponse], str]: + # Todo: add an option to launch with conversational instead for chat prompts + # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational + generated_text = self.async_client.text_generation( + prompt=context, + details=True, + decoder_input_details=True, + max_new_tokens=max_tokens, + stop_sequences=stop_tokens, + # truncate=, + ) + + return generated_text + + def __process_request(self, context: str, stop_tokens: list[str], max_tokens: int) -> TextGenerationResponse: + # Todo: add an option to launch with conversational instead for chat prompts + # https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient.conversational + generated_text = self.client.text_generation( + prompt=context, + details=True, + decoder_input_details=True, + max_new_tokens=max_tokens, + stop_sequences=stop_tokens, + # truncate=, + ) + + return generated_text + + async def __async_process_batch_generate( + self, + requests: list[GreedyUntilRequest | GreedyUntilWithLogitsRequest], + ) -> list[TextGenerationResponse]: + return await asyncio.gather( + *[ + self.__async_process_request( + context=request.context, + stop_tokens=as_list(request.stop_sequence), + max_tokens=request.generation_size, + ) + for request in requests + ] + ) + + def __process_batch_generate( + self, + requests: list[GreedyUntilRequest | GreedyUntilWithLogitsRequest], + ) -> list[TextGenerationResponse]: + return [ + self.__process_request( + context=request.context, + stop_tokens=as_list(request.stop_sequence), + max_tokens=request.generation_size, + ) + for request in requests + ] + + async def __async_process_batch_logprob( + self, requests: list[LoglikelihoodRequest], rolling: bool = False + ) -> list[TextGenerationResponse]: + return await asyncio.gather( + *[ + self.__async_process_request( + context=request.context if rolling else request.context + request.choice, + stop_tokens=[], + max_tokens=1, + ) + for request in requests + ] + ) + + def __process_batch_logprob( + self, requests: list[LoglikelihoodRequest], rolling: bool = False + ) -> list[TextGenerationResponse]: + return [ + self.__process_request( + context=request.context if rolling else request.context + request.choice, + stop_tokens=[], + max_tokens=1, + ) + for request in requests + ] + + def greedy_until_with_logits( + self, + requests: list[GreedyUntilWithLogitsRequest], + override_bs: Optional[int] = None, + ) -> list[GenerateReturn]: + """ + Generates sequences greedily until a stopping condition is met, + returning both the generated sequences and the logits. + + Args: + requests (list[tuple[str, dict]]): A list of input requests, + where each request is a tuple containing a prompt string and a dictionary of additional parameters. + override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None. + + Returns: + list[GenerateReturn]: A list of GenerateReturn objects, + where each object contains the generated sequence and the corresponding logits. + """ + + return self.greedy_until( + requests, + returns_logits=True, + override_bs=override_bs, + ) + + def greedy_until( + self, + requests: List[GreedyUntilRequest], + returns_logits: bool = False, + override_bs: Optional[int] = None, + ) -> List[GenerateReturn]: + for request in requests: + request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token] + + dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) + batch_size = override_bs if override_bs is not None else BATCH_SIZE + results: List[str] = [] + + for _, _ in tqdm( + dataset.splits_start_end_iterator(), + total=self.DATASET_SPLITS, + desc="Splits", + position=0, + disable=self.disable_tqdm, + ): + dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch) + + for batch in tqdm( + dataloader, desc="Greedy generation", position=1, leave=False, disable=self.disable_tqdm + ): + if self.use_async: + responses = asyncio.run(self.__async_process_batch_generate(batch, returns_logits)) + else: + responses = self.__process_batch_generate(batch, returns_logits) + for response in responses: + results.append( + GenerateReturn( + result=response.generated_text, + logits=[item.logprob for item in response.details.prefill] if returns_logits else None, + truncated_tokens_count=-1, + padded_tokens_count=-1, + ) + ) + + return results + + def loglikelihood( + self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodReturn]: + for request in requests: + request.tokenized_context = self.tok_encode(request.context) + request.tokenized_continuation = self.tok_encode(request.choice) + dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) + batch_size = override_bs if override_bs is not None else BATCH_SIZE + results: List[str] = [] + + for _, _ in tqdm( + dataset.splits_start_end_iterator(), + total=self.DATASET_SPLITS, + desc="Splits", + position=0, + disable=self.disable_tqdm, + ): + dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch) + + for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm): + if self.use_async: + responses = asyncio.run(self.__async_process_batch_logprob(batch)) + else: + responses = self.__process_batch_logprob(batch) + for ix, response in enumerate(responses): + len_choice = len(batch[ix].tokenized_continuation) + results.append( + LoglikelihoodReturn( + result=[ + t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None + ], + input_tokens=[t.id for t in response.details.prefill[:-len_choice]], + generated_tokens=[t.id for t in response.details.prefill[-len_choice:]], + truncated_tokens_count=-1, + padded_tokens_count=-1, + ) + ) + + return results + + def loglikelihood_rolling( + self, requests: list[LoglikelihoodRollingRequest], override_bs=None + ) -> list[LoglikelihoodReturn]: + """This function is used to compute the log likelihood of the context for perplexity metrics.""" + for request in requests: + request.tokenized_context = [self.tokenizer.eos_token_id] + request.tokenized_continuation = self.tok_encode(request.context) + + dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) + batch_size = override_bs if override_bs is not None else BATCH_SIZE + results: List[str] = [] + + for _, _ in tqdm( + dataset.splits_start_end_iterator(), + total=self.DATASET_SPLITS, + desc="Splits", + position=0, + disable=self.disable_tqdm, + ): + dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda batch: batch) + + for batch in tqdm(dataloader, desc="Loglikleihoods", position=1, leave=False, disable=self.disable_tqdm): + if self.use_async: + responses = asyncio.run(self.__async_process_batch_logprob(batch, rolling=True)) + else: + responses = self.__process_batch_logprob(batch, rolling=True) + for response in responses: + results.append( + LoglikelihoodReturn( + result=[t.logprob for t in response.details.tokens[:-1]], + input_tokens=[t.id for t in response.details.prefill], + generated_tokens=[t.id for t in response.details.tokens[:-1]], + truncated_tokens_count=-1, + padded_tokens_count=-1, + ) + ) + + return results + + def loglikelihood_single_token( + self, + requests: list[LoglikelihoodSingleTokenRequest], + override_bs: Optional[int] = None, + ) -> list[LoglikelihoodSingleTokenReturn]: + raise ValueError("Endpoint models can't use single token metrics. Change the metric to the standard version") diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index 44cc2edf9..f1bbef43f 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -196,6 +196,25 @@ class TGIModelConfig: inference_server_auth: str +@dataclass +class InferenceModelConfig: + model: str + + +@dataclass +class InferenceEndpointModelConfig: + name: str + repository: str + accelerator: str + vendor: str + region: str + instance_size: str + instance_type: str + framework: str = "pytorch" + endpoint_type: str = "protected" + should_reuse_existing: bool = False + + def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig: # noqa: C901 """ Create a model configuration based on the provided arguments. @@ -215,7 +234,34 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None] """ if args.inference_server_address is not None and args.model_args is not None: raise ValueError("You cannot both use an inference server and load a model from its checkpoint.") + if args.inference_server_address is not None and args.endpoint_model_name is not None: + raise ValueError("You cannot both use a local inference server and load a model from an inference endpoint.") + if args.endpoint_model_name is not None and args.model_args is not None: + raise ValueError("You cannot both load a model from its checkpoint and from an inference endpoint.") + + # TGI + if args.inference_server_address is not None: + return TGIModelConfig( + inference_server_address=args.inference_server_address, inference_server_auth=args.inference_server_auth + ) + + # Endpoint + if args.endpoint_model_name: + if args.reuse_existing or args.vendor is not None: + model = args.endpoint_model_name.split("/")[1].lower() + return InferenceEndpointModelConfig( + name=f"{model}-lighteval", + repository=args.endpoint_model_name, + accelerator=args.accelerator, + region=args.region, + vendor=args.vendor, + instance_size=args.instance_size, + instance_type=args.instance_type, + should_reuse_existing=args.reuse_existing, + ) + return InferenceModelConfig(model=args.endpoint_model_name) + # Base multichoice_continuations_start_space = args.multichoice_continuations_start_space if not multichoice_continuations_start_space and not args.no_multichoice_continuations_start_space: multichoice_continuations_start_space = None diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index 98fd200e4..5bb7d0911 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -5,14 +5,17 @@ from lighteval.models.adapter_model import AdapterModel from lighteval.models.base_model import BaseModel from lighteval.models.delta_model import DeltaModel -from lighteval.models.inference_client import ModelClient +from lighteval.models.endpoint_model import InferenceEndpointModel from lighteval.models.model_config import ( AdapterModelConfig, BaseModelConfig, DeltaModelConfig, EnvConfig, + InferenceEndpointModelConfig, + InferenceModelConfig, TGIModelConfig, ) +from lighteval.models.tgi_model import ModelClient from lighteval.utils import NO_TGI_ERROR_MSG, is_accelerate_available, is_tgi_available @@ -29,7 +32,8 @@ class ModelInfo: def load_model( # noqa: C901 - config: Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig], env_config: EnvConfig + config: Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, InferenceEndpointModelConfig], + env_config: EnvConfig, ) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient], ModelInfo]: """Will load either a model from an inference server or a model from a checkpoint. depending on the arguments passed to the program. @@ -50,6 +54,9 @@ def load_model( # noqa: C901 if isinstance(config, TGIModelConfig): return load_model_with_tgi(config) + if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig): + return load_model_with_inference_endpoints(config, env_config=env_config) + if isinstance(config, BaseModelConfig): return load_model_with_accelerate_or_default(config=config, env_config=env_config) @@ -73,6 +80,18 @@ def load_model_with_tgi(config: TGIModelConfig): return model, model_info +def load_model_with_inference_endpoints(config: InferenceEndpointModelConfig, env_config: EnvConfig): + hlog("Spin up model using inference endpoint.") + model = InferenceEndpointModel(config=config, env_config=env_config) + model_info = ModelInfo( + model_name=model.name, + model_sha=model.revision, + model_dtype="default", + model_size=-1, + ) + return model, model_info + + def load_model_with_accelerate_or_default( config: Union[AdapterModelConfig, BaseModelConfig, DeltaModelConfig], env_config: EnvConfig ): diff --git a/src/lighteval/models/nanotron_model.py b/src/lighteval/models/nanotron_model.py index e634244e3..1d69f10dd 100644 --- a/src/lighteval/models/nanotron_model.py +++ b/src/lighteval/models/nanotron_model.py @@ -236,16 +236,6 @@ def add_special_tokens(self) -> bool: return self._add_special_tokens else: return False - # elif self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM: - # return False - # elif self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM: - # return True - # else: - # raise ValueError( - # "Could not determine `add_special_tokens` value from the model " - # "class. Set to `True` or `False` depending on whether the model " - # "was pre-trained with special tokens." - # ) @property def eot_token(self) -> str: diff --git a/src/lighteval/models/inference_client.py b/src/lighteval/models/tgi_model.py similarity index 100% rename from src/lighteval/models/inference_client.py rename to src/lighteval/models/tgi_model.py diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py index 2b31bd5ee..baf6e01af 100644 --- a/src/lighteval/tasks/requests.py +++ b/src/lighteval/tasks/requests.py @@ -46,6 +46,8 @@ class LoglikelihoodRequest(Request): choice: str request_type = RequestType.LOGLIKELIHOOD + tokenized_context: list[int] = None + tokenized_continuation: list[int] = None @dataclass @@ -61,6 +63,8 @@ class LoglikelihoodSingleTokenRequest(Request): choices: list[str] request_type = RequestType.LOGLIKELIHOOD_SINGLE_TOKEN + tokenized_context: list[int] = None + tokenized_continuation: list[int] = None @dataclass @@ -72,6 +76,8 @@ class LoglikelihoodRollingRequest(Request): """ request_type = RequestType.LOGLIKELIHOOD_ROLLING + tokenized_context: list[int] = None + tokenized_continuation: list[int] = None @dataclass @@ -88,6 +94,7 @@ class GreedyUntilRequest(Request): stop_sequence: str generation_size: int request_type = RequestType.GREEDY_UNTIL + tokenized_context: list[int] = None @dataclass @@ -105,6 +112,7 @@ class GreedyUntilWithLogitsRequest(Request): stop_sequence: str generation_size: int request_type = RequestType.GREEDY_UNTIL_WITH_LOGITS + tokenized_context: list[int] = None class TaskExampleId(NamedTuple): diff --git a/tasks_examples/open_llm_leaderboard_tasks.txt b/tasks_examples/open_llm_leaderboard_tasks.txt index 41c0ff35a..5736e9537 100644 --- a/tasks_examples/open_llm_leaderboard_tasks.txt +++ b/tasks_examples/open_llm_leaderboard_tasks.txt @@ -57,4 +57,4 @@ lighteval|mmlu:security_studies|5|0 lighteval|mmlu:sociology|5|0 lighteval|mmlu:us_foreign_policy|5|0 lighteval|mmlu:virology|5|0 -lighteval|mmlu:world_religions|5|0 \ No newline at end of file +lighteval|mmlu:world_religions|5|0 diff --git a/tests/reference_scores/harness_metrics.json b/tests/reference_scores/harness_metrics.json index a6c506f34..1c8c5b91d 100644 --- a/tests/reference_scores/harness_metrics.json +++ b/tests/reference_scores/harness_metrics.json @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a1965f0b9c66cfe1b1f3cc380a80949e32eab92ae8eac079c0339506ce827093 -size 48373142 +oid sha256:408956938a6b7a18b03658bb9772b471efcea4aa04afb0b35d76cecfca6a706e +size 48376580 diff --git a/tests/test_unit_reorder.py b/tests/test_unit_reorder.py index 17af515a7..7ae24546c 100644 --- a/tests/test_unit_reorder.py +++ b/tests/test_unit_reorder.py @@ -1,29 +1,77 @@ +import pytest +from transformers import AutoTokenizer + from lighteval.data import GenerativeTaskDataset +from lighteval.tasks.requests import GreedyUntilRequest # test data that will need to be sorted by length of the string -data = [ - ("1 The quick brown fox jumps over the lazy dog", ([":", "stop"], 10)), - ("2 The quick brown fox jumps over the lazy dog njsa", ([":", "stop"], 10)), - ("Some text", ([":", "stop"], 10)), - ("some more text", ([":", "stop"], 10)), - ("not sure what to write here", ([":", "stop"], 10)), +TEST_DATA = [ + GreedyUntilRequest( + task_name="test", + example_index=0, + request_index=0, + context="1 The quick brown fox jumps over the lazy dog", + stop_sequence=[":", "stop"], + generation_size=10, + ), + GreedyUntilRequest( + task_name="test", + example_index=2, + request_index=0, + context="2 The quick brown fox jumps over the lazy dog njsa", + stop_sequence=[":", "stop"], + generation_size=10, + ), + GreedyUntilRequest( + task_name="test", + example_index=5, + request_index=0, + context="Some text", + stop_sequence=[":", "stop"], + generation_size=10, + ), + GreedyUntilRequest( + task_name="test", + example_index=21, + request_index=0, + context="some more text", + stop_sequence=[":", "stop"], + generation_size=10, + ), + GreedyUntilRequest( + task_name="test", + example_index=1, + request_index=0, + context="not sure what to write here", + stop_sequence=[":", "stop"], + generation_size=10, + ), ] DATASET_SPLITS = 1 class TestReorderGenerativeTaskDataset: + def test_dataset_needs_tokenization(self): + with pytest.raises(ValueError): + GenerativeTaskDataset(requests=TEST_DATA, dataset_splits=DATASET_SPLITS) + def test_reorder_dataset(self): + tokenizer = AutoTokenizer.from_pretrained("gpt2") + data = TEST_DATA.copy() + for request in data: + request.tokenized_context = tokenizer.encode(request.context) + dataset = GenerativeTaskDataset(requests=data, dataset_splits=DATASET_SPLITS) sorted_data = dataset.sorted_data original_data = dataset.get_original_order(sorted_data) for i in range(len(sorted_data) - 1): - assert len(sorted_data[i][0]) >= len( - sorted_data[i + 1][0] - ), f"dataset[{i}][0] = {sorted_data[i][0]} is shorter than dataset[{i+1}][0] = {sorted_data[i+1][0]}" + assert ( + len(sorted_data[i].context) >= len(sorted_data[i + 1].context) + ), f"dataset[{i}][0] = {sorted_data[i].context} is shorter than dataset[{i+1}][0] = {sorted_data[i+1].context}" assert len(sorted_data) == len( original_data