From d43c9a324c81fbe1d85811ded08e56f37abac2d8 Mon Sep 17 00:00:00 2001
From: Atsuki Yamaguchi <30075338+gucci-j@users.noreply.github.com>
Date: Wed, 17 Jul 2024 13:34:40 +0100
Subject: [PATCH 1/4] Fix _init_max_length in base_model.py (#185)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Update base_model.py

* Update base_model.py

* Removed try-except in base_model.py

* Update src/lighteval/models/base_model.py

* Revert "Update base_model.py"

This reverts commit 003d3896a85ac4d34e8b48a86cbd50ccb9a394c0.

---------

Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
---
 src/lighteval/models/base_model.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py
index 3e483d44..e5e63db9 100644
--- a/src/lighteval/models/base_model.py
+++ b/src/lighteval/models/base_model.py
@@ -267,8 +267,6 @@ def _init_max_length(self, max_length) -> int:
             if hasattr(self._config, attr):
                 return getattr(self._config, attr)
 
-        if hasattr(self.tokenizer, "model_max_length"):
-            return self.tokenizer.model_max_length
         # Default max sequence length setting for when no `max_length` is provided
         # or no max length config setting is found in the model or tokenizer.
         return 2048
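
The hunk above drops the `tokenizer.model_max_length` fallback, so the resolution order becomes: an explicit `max_length`, then a length-like attribute on the model config, then a fixed default (tokenizers commonly expose a huge placeholder in `model_max_length` when no real limit is configured). Below is a minimal, self-contained sketch of that fallback chain; `SimpleConfig`, `resolve_max_length`, and the 2048 default are illustrative names and values, not lighteval's API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleConfig:
    # Stand-in for a transformers model config; only length-related fields matter here.
    max_position_embeddings: Optional[int] = None
    n_positions: Optional[int] = None
    n_ctx: Optional[int] = None


def resolve_max_length(model_config: SimpleConfig, max_length: Optional[int] = None, default: int = 2048) -> int:
    # Prefer an explicit value, then the first length-like attribute set on the
    # model config, and only then fall back to a fixed default.
    if max_length is not None:
        return int(max_length)
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        value = getattr(model_config, attr, None)
        if value is not None:
            return int(value)
    # Trusting tokenizer.model_max_length is avoided on purpose: it is often a
    # very large sentinel value rather than the model's real context size.
    return default


if __name__ == "__main__":
    print(resolve_max_length(SimpleConfig(max_position_embeddings=4096)))  # -> 4096
    print(resolve_max_length(SimpleConfig()))                              # -> 2048
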
From 951cd5b586b4afdb84c0dc43bc057f91450dcae0 Mon Sep 17 00:00:00 2001
From: Sadra Barikbin
Date: Wed, 17 Jul 2024 17:18:49 +0330
Subject: [PATCH 2/4] Make evaluator invariant of input request type order (#215)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com>
Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
---
 src/lighteval/evaluator.py            | 11 +++++------
 src/lighteval/metrics/__init__.py     |  8 +++-----
 src/lighteval/tasks/lighteval_task.py | 22 +++++++++++-----------
 3 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/src/lighteval/evaluator.py b/src/lighteval/evaluator.py
index 883e5ef7..331b5f38 100644
--- a/src/lighteval/evaluator.py
+++ b/src/lighteval/evaluator.py
@@ -64,7 +64,7 @@ def evaluate(  # noqa: C901
     :return
         Dictionary of results
     """
-    # A request output tupe is a Tuple where the first element is the index of
+    # A request output tuple is a Tuple where the first element is the index of
     # the request for one document of one task i.e.
     # task: "arc_easy", doc: "0"# request: "0" -> request_index = 0,
     # We can have multiple requests per doc for multi choice tasks for example.
@@ -75,8 +75,11 @@ def evaluate(  # noqa: C901
     )
 
     example_id_response_dict: dict[TaskExampleId, list[RequestIndexModelResponseTuple]] = collections.defaultdict(list)
-    for request_type, requests in requests_dict.items():
+    for request_type in RequestType:
+        if request_type not in requests_dict:
+            continue
         hlog(f"Running {request_type} requests")
+        requests = requests_dict[request_type]
         # These are all the request type from the request factory at the moment
         if request_type == RequestType.LOGLIKELIHOOD:
             full_resps = lm.loglikelihood(requests, override_bs=override_bs)
@@ -99,10 +102,6 @@ def evaluate(  # noqa: C901
 
     # ===== unpack results and sort back in order and return control to Task =====
     for task_example_id, prediction_list in example_id_response_dict.items():
-        # ===== Unpack the request =====
-        prediction_list.sort(
-            key=lambda x: x.request_index
-        )  # When we use Loglikelihood for several tokens we have all the options here
         model_responses = [x.model_response for x in prediction_list]
         cur_task_name = task_example_id.task_name.rsplit("|", 1)[0]
diff --git a/src/lighteval/metrics/__init__.py b/src/lighteval/metrics/__init__.py
index 5b1e8b7c..1b105d74 100644
--- a/src/lighteval/metrics/__init__.py
+++ b/src/lighteval/metrics/__init__.py
@@ -116,15 +116,14 @@ def apply_generative_metric(
 
 def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[Metric]):
     outputs = {}
-    if len(formatted_doc.choices) != len(results):
-        raise ValueError("Length of results is not equal to the length of the choices")
+    mc_results = results[: len(formatted_doc.choices)]
     if len(formatted_doc.choices) <= 1:
         raise ValueError(
             "You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
         )
 
     # Todo: make better system with return_bool_score instead of taking first element
-    choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))]  # sum(
+    choices_logprob = [mc_results[i].result[0] for i in range(len(formatted_doc.choices))]  # sum(
     gold_ixs = as_list(formatted_doc.gold_index)
 
     for metric in metrics:
@@ -132,8 +131,7 @@ def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, met
         outputs.update(
             metric.compute(choices_logprob=choices_logprob, gold_ixs=gold_ixs, formatted_doc=formatted_doc)
         )
-
-    return results, outputs
+    return results[len(formatted_doc.choices) :], outputs
 
 
 def apply_multichoice_metric_one_token(results: list[ModelReturn], formatted_doc: Doc, metrics: list[Metric]):
diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index 6595571f..07251d69 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -539,6 +539,16 @@ def process_results(self, formatted_doc: Doc, results: list[ModelReturn]) -> dic
                 results=results, formatted_doc=formatted_doc, metrics=self.metrics
             )
             outputs.update(cur_outputs)
+        if self.has_metric_category[MetricCategory.MULTICHOICE]:
+            results, cur_outputs = apply_multichoice_metric(
+                results=results, formatted_doc=formatted_doc, metrics=self.metrics
+            )
+            outputs.update(cur_outputs)
+        if self.has_metric_category[MetricCategory.MULTICHOICE_ONE_TOKEN]:
+            results, cur_outputs = apply_multichoice_metric_one_token(
+                results=results, formatted_doc=formatted_doc, metrics=self.metrics
+            )
+            outputs.update(cur_outputs)
         if self.has_metric_category[MetricCategory.PERPLEXITY]:
             results, cur_outputs = apply_perplexity_metric(
                 results=results, formatted_doc=formatted_doc, metrics=self.metrics
@@ -557,16 +567,6 @@ def process_results(self, formatted_doc: Doc, results: list[ModelReturn]) -> dic
                 max_num_samples=max(self.num_samples),
             )
             outputs.update(cur_outputs)
-        if self.has_metric_category[MetricCategory.MULTICHOICE]:
-            results, cur_outputs = apply_multichoice_metric(
-                results=results, formatted_doc=formatted_doc, metrics=self.metrics
-            )
-            outputs.update(cur_outputs)
-        if self.has_metric_category[MetricCategory.MULTICHOICE_ONE_TOKEN]:
-            results, cur_outputs = apply_multichoice_metric_one_token(
-                results=results, formatted_doc=formatted_doc, metrics=self.metrics
-            )
-            outputs.update(cur_outputs)
         if (
             self.has_metric_category[MetricCategory.LLM_AS_JUDGE_MULTI_TURN]
             or self.has_metric_category[MetricCategory.LLM_AS_JUDGE]
@@ -643,7 +643,7 @@ def create_requests_from_tasks(  # noqa: C901
 ) -> Tuple[dict[RequestType, list[Request]], dict[TaskExampleId, Doc]]:
     """
     Takes a task dict and a fewshot dict and returns a dict of requests, a dict
-    of docs, and a dict of requests origins.  The construction of prompts and
+    of docs, and a dict of requests origins. The construction of prompts and
     thus the managing of few shots is done here.
 
     Args:
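
The patch above makes the evaluator walk `RequestType` in its declaration order instead of relying on the insertion order of `requests_dict`, and it lets `apply_multichoice_metric` score only the first `len(choices)` results before handing the remainder back for the next metric. Here is a small self-contained sketch of both ideas, using made-up names (`RequestKind`, `run_requests`, `consume_multichoice`) rather than lighteval's real classes.

from enum import Enum, auto


class RequestKind(Enum):
    LOGLIKELIHOOD = auto()
    GREEDY_UNTIL = auto()


def run_requests(requests_by_kind: dict) -> list:
    responses = []
    # Iterating over the Enum (not the dict) makes the outcome independent of the
    # order in which request types were inserted into the dict.
    for kind in RequestKind:
        if kind not in requests_by_kind:
            continue
        requests = requests_by_kind[kind]
        responses.extend(f"{kind.name}:{r}" for r in requests)
    return responses


def consume_multichoice(results: list, num_choices: int):
    # Score only the first `num_choices` entries and return the rest,
    # so the next metric starts from the remaining results.
    scored, rest = results[:num_choices], results[num_choices:]
    return scored, rest


if __name__ == "__main__":
    # Same requests in two different dict orders -> identical responses.
    a = run_requests({RequestKind.GREEDY_UNTIL: [1], RequestKind.LOGLIKELIHOOD: [2, 3]})
    b = run_requests({RequestKind.LOGLIKELIHOOD: [2, 3], RequestKind.GREEDY_UNTIL: [1]})
    assert a == b

    scored, rest = consume_multichoice(["A", "B", "C", "gen"], num_choices=3)
    print(scored, rest)  # ['A', 'B', 'C'] ['gen']
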
From 44f9a461bd366267cffce1317ad106477083c9dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com>
Date: Wed, 17 Jul 2024 16:36:28 +0200
Subject: [PATCH 3/4] Quantization related issues (#224)

Fixes #200 and #176
---
 src/lighteval/models/base_model.py   |  6 ++---
 src/lighteval/models/model_config.py | 33 +++++++++++++++++++++-------
 2 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py
index e5e63db9..f1ba6151 100644
--- a/src/lighteval/models/base_model.py
+++ b/src/lighteval/models/base_model.py
@@ -29,7 +29,7 @@
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
 from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn
@@ -88,9 +88,7 @@ def __init__(
         self.multichoice_continuations_start_space = config.multichoice_continuations_start_space
 
         # We are in DP (and launch the script with `accelerate launch`)
-        if not config.model_parallel and config.quantization_config is None:
-            # might need to use accelerate instead
-            # self.model = config.accelerator.prepare(self.model)
+        if not config.model_parallel and not isinstance(config.quantization_config, BitsAndBytesConfig):
             hlog(f"Using Data Parallelism, putting model on device {self._device}")
             self.model = self.model.to(self._device)
diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py
index b6f4bb5d..75a29d02 100644
--- a/src/lighteval/models/model_config.py
+++ b/src/lighteval/models/model_config.py
@@ -95,7 +95,8 @@ class BaseModelConfig:
             Use `dtype="auto"` to derive the type from the model's weights.
         device (Union[int, str]): device to use for model training.
         quantization_config (Optional[BitsAndBytesConfig]): quantization
-            configuration for the model. Needed for 4-bit and 8-bit precision.
+            configuration for the model, manually provided to load a normally floating point
+            model at a quantized precision. Needed for 4-bit and 8-bit precision.
         trust_remote_code (bool): Whether to trust remote code during model loading.
@@ -144,13 +145,29 @@ def _init_configs(self, model_name: str, env_config: EnvConfig) -> PretrainedCon
             cache_dir=env_config.cache_dir,
             token=env_config.token,
         )
-        if getattr(auto_config, "quantization_config", False) and self.quantization_config is None:
-            if not is_autogptq_available():
-                raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
-            hlog(
-                "`quantization_config` is None but was found in the model's config, using the one found in config.json"
-            )
-            self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+
+        # Gathering the model's automatic quantization config, if available
+        try:
+            model_auto_quantization_config = auto_config.quantization_config
+            hlog("An automatic quantization config was found in the model's config. Using it to load the model")
+        except (AttributeError, KeyError):
+            model_auto_quantization_config = None
+
+        if model_auto_quantization_config is not None:
+            if self.quantization_config is not None:
+                # We don't load models quantized by default with a different user provided conf
+                raise ValueError("You manually requested quantization on a model already quantized!")
+
+            # We add the quantization to the model params we store
+            if model_auto_quantization_config["quant_method"] == "gptq":
+                if not is_autogptq_available():
+                    raise ImportError(NO_AUTOGPTQ_ERROR_MSG)
+                auto_config.quantization_config["use_exllama"] = None
+                self.quantization_config = GPTQConfig(**auto_config.quantization_config, disable_exllama=True)
+            elif model_auto_quantization_config["quant_method"] == "bitsandbytes":
+                if not is_bnb_available():
+                    raise ImportError(NO_BNB_ERROR_MSG)
+                self.quantization_config = BitsAndBytesConfig(**auto_config.quantization_config)
 
         return auto_config
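
The new `_init_configs` logic reads the quantization config embedded in a checkpoint (when present), refuses a conflicting user-supplied one, and branches on `quant_method` to build either a GPTQ or a bitsandbytes config. Below is a rough, dependency-free sketch of that dispatch, with plain dicts standing in for transformers' `GPTQConfig`/`BitsAndBytesConfig` and a made-up `pick_quantization_config` helper; it illustrates the control flow only.

from typing import Optional


def pick_quantization_config(model_config: dict, user_config: Optional[dict] = None) -> Optional[dict]:
    auto_config = model_config.get("quantization_config")
    if auto_config is None:
        # Nothing embedded in the checkpoint: use whatever the user asked for (possibly nothing).
        return user_config
    if user_config is not None:
        # The checkpoint is already quantized; a different user config cannot be applied on top.
        raise ValueError("You manually requested quantization on a model already quantized!")

    method = auto_config.get("quant_method")
    if method == "gptq":
        return {"loader": "auto-gptq", **auto_config}
    if method == "bitsandbytes":
        return {"loader": "bitsandbytes", **auto_config}
    raise ValueError(f"Unsupported quant_method: {method!r}")


if __name__ == "__main__":
    cfg = {"quantization_config": {"quant_method": "gptq", "bits": 4}}
    print(pick_quantization_config(cfg))  # {'loader': 'auto-gptq', 'quant_method': 'gptq', 'bits': 4}
    print(pick_quantization_config({}, user_config={"quant_method": "bitsandbytes"}))
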
From 66ed7a28aa67e6c4a4ab1d93510d45be84802534 Mon Sep 17 00:00:00 2001
From: Sadra Barikbin
Date: Thu, 18 Jul 2024 14:06:14 +0330
Subject: [PATCH 4/4] Fix a tiny bug in DROP metric (#229)

* Fix the metric

* Apply comment and Ruff
---
 .../metrics/harness_compatibility/drop.py | 37 +++++++++++++------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/src/lighteval/metrics/harness_compatibility/drop.py b/src/lighteval/metrics/harness_compatibility/drop.py
index 57b9f18e..d6c8ac30 100644
--- a/src/lighteval/metrics/harness_compatibility/drop.py
+++ b/src/lighteval/metrics/harness_compatibility/drop.py
@@ -22,17 +22,30 @@
 
 import re
 import string
+from typing import List, Set, Tuple
 
 import numpy as np
 from scipy.optimize import linear_sum_assignment
 
 
 def drop_metrics(predictions: list[str], formatted_doc, **kwargs):  # noqa: C901
-    """F1 score from bag of words: comes from Harness Drop
+    """F1 score from bag of words: comes from Harness Drop. DROP offers two metrics,
+    a quasi exact match and a numeracy-focused F1 score. Quasi in the sense that it
+    does some normalizations before matching and numeracy-focused in the sense that
+    if there's number mismatch between the target and prediction F1 score is set to 0.
+    F1 score is computed using the intersection of target and prediction's BoW
+    representations with the additional spice that if the answer and/or prediction is
+    comprised of multiple spans, a greedy matching is done between the two sets of spans
+    (based on the very BoW overlap) and the average over F1 of pairs is returned.
+    DROP also accepts multiple answers in which case, the maximum of F1/ Exact Match
+    between prediction and the different answers is taken.
+
+    For more information, please refer to the section 5 of the DROP paper (https://aclanthology.org/N19-1246/).
+    Todo: this code is really hard to follow, simplify when possible
     """
 
-    def _answer_to_bags(answer):
+    def _answer_to_bags(answer: List[str]) -> Tuple[List[str], List[Set[str]]]:
         if isinstance(answer, (list, tuple)):
             raw_spans = answer
         else:
@@ -45,7 +58,7 @@ def _answer_to_bags(answer):
             token_bags.append(set(normalized_span.split()))
         return normalized_spans, token_bags
 
-    def _get_metrics(predicted, gold):
+    def _get_metrics(predicted: List[str], gold: List[str]):
         """
         Takes a predicted answer and a gold answer (that are both either a string or a list of
         strings), and returns exact match and the DROP F1 metric for the prediction. If you are
@@ -53,15 +66,17 @@ def _get_metrics(predicted, gold):
         validation, or while training), this is the function you want to call, after using
         :func:`answer_json_to_strings` when reading the gold answer from the released data file.
         """
-        predicted_bags = _answer_to_bags(predicted)
-        gold_bags = _answer_to_bags(gold)
+        pred_normalized_spans, pred_bags = _answer_to_bags(predicted)
+        gold_normalized_spans, gold_bags = _answer_to_bags(gold)
 
-        if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
+        if set(pred_normalized_spans) == set(gold_normalized_spans) and len(gold_normalized_spans) == len(
+            gold_normalized_spans
+        ):
             exact_match = 1.0
         else:
             exact_match = 0.0
 
-        f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
+        f1_per_bag = _align_bags(pred_bags, gold_bags)
         f1 = np.mean(f1_per_bag)
         f1 = round(f1, 2)
         return exact_match, f1
@@ -73,7 +88,7 @@ def _is_number(text):
         except ValueError:
             return False
 
-    def _match_numbers_if_present(gold_bag, predicted_bag):
+    def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]):
         gold_numbers = set()
         predicted_numbers = set()
         for word in gold_bag:
@@ -86,7 +101,7 @@ def _match_numbers_if_present(gold_bag, predicted_bag):
                 return True
         return False
 
-    def _align_bags(predicted, gold):
+    def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> np.array:
         """
         Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
         between them and gets maximum metric values over all the answers.
@@ -136,7 +151,7 @@ def _fix_number(text):
     def _tokenize(text):
         return re.split(" |-", text)
 
-    def _normalize(answer):
+    def _normalize(answer: str):
         tokens = [
             _white_space_fix(_remove_articles(_fix_number(_remove_punc(token.lower())))) for token in _tokenize(answer)
         ]
@@ -147,9 +162,9 @@ def _normalize(answer):
     max_em = 0
     max_f1 = 0
     for gold_answer in formatted_doc.specific["golds_no_preprocessing"]:
+        exact_match, f1_score = _get_metrics(predictions, gold_answer)
         if isinstance(gold_answer, list):
             gold_answer = gold_answer[0]
-        exact_match, f1_score = _get_metrics(predictions, gold_answer)
         if gold_answer.strip():
             max_em = max(max_em, exact_match)
             max_f1 = max(max_f1, f1_score)
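
The fix in the last hunk calls `_get_metrics` before the `gold_answer[0]` unwrapping, so a list-valued gold answer is scored as a whole rather than by its first span only, while the unwrapping is kept just for the emptiness check. For readers unfamiliar with the span alignment that `_align_bags` relies on, here is a compact, self-contained illustration of the mechanism: score every (gold, predicted) pair with bag-of-words F1 and pick the best one-to-one assignment with `scipy.optimize.linear_sum_assignment`. The DROP-specific normalization and number-matching rules are left out, so this is a sketch of the idea, not the library implementation.

import numpy as np
from scipy.optimize import linear_sum_assignment


def bag_f1(predicted: set, gold: set) -> float:
    # Token-level F1 between two bag-of-words representations.
    overlap = len(predicted & gold)
    if overlap == 0:
        return 0.0
    precision = overlap / len(predicted)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)


def align_bags(predicted: list[set], gold: list[set]) -> np.ndarray:
    # Build a |gold| x |predicted| score matrix and maximize the total F1 of the assignment.
    scores = np.array([[bag_f1(pred, g) for pred in predicted] for g in gold])
    row_ind, col_ind = linear_sum_assignment(-scores)  # negate: the solver minimizes cost
    max_scores = np.zeros(max(len(gold), len(predicted)))
    for r, c in zip(row_ind, col_ind):
        max_scores[r] = scores[r, c]
    return max_scores


if __name__ == "__main__":
    predicted = [{"four", "touchdowns"}, {"brady"}]
    gold = [{"brady"}, {"four", "touchdowns"}]
    print(align_bags(predicted, gold))  # [1. 1.] despite the different span order
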