From bc0b8bc42d195b406dd7d3f2cab1ac86bf454339 Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Thu, 8 Feb 2024 10:53:31 +0100 Subject: [PATCH] Last PR to make custom tasks work for everyone (#23) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Small one to be in the release. Can iterate later if needed --------- Co-authored-by: clementine@huggingface.co Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com> --- run_evals_accelerate.py | 2 +- run_evals_nanotron.py | 2 +- src/lighteval/logging/info_loggers.py | 12 +- src/lighteval/main_accelerate.py | 2 +- src/lighteval/main_nanotron.py | 8 +- src/lighteval/models/base_model.py | 6 +- src/lighteval/tasks/lighteval_task.py | 97 ++++++++--- src/lighteval/tasks/registry.py | 88 ++++++---- .../custom_tasks/custom_evaluation_tasks.py | 72 ++++---- .../custom_tasks/custom_evaluation_utils.py | 159 ------------------ .../lighteval_config_override_template.yaml | 4 +- 11 files changed, 177 insertions(+), 275 deletions(-) delete mode 100644 tasks_examples/custom_tasks/custom_evaluation_utils.py diff --git a/run_evals_accelerate.py b/run_evals_accelerate.py index 7002c8747..0dfc658e2 100644 --- a/run_evals_accelerate.py +++ b/run_evals_accelerate.py @@ -71,7 +71,7 @@ def get_parser(): parser.add_argument("--override_batch_size", type=int, default=-1) parser.add_argument("--dataset_loading_processes", type=int, default=1) parser.add_argument( - "--custom_tasks_file", + "--custom_tasks", type=str, default=None, help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)", diff --git a/run_evals_nanotron.py b/run_evals_nanotron.py index 9b98d0057..c82d27bd3 100644 --- a/run_evals_nanotron.py +++ b/run_evals_nanotron.py @@ -20,7 +20,7 @@ def get_parser(): parser.add_argument( "--cache-dir", type=str, - default="", + default=None, help="Cache directory", ) diff --git a/src/lighteval/logging/info_loggers.py b/src/lighteval/logging/info_loggers.py index 194e65f5f..1b5ecdbd6 100644 --- a/src/lighteval/logging/info_loggers.py +++ b/src/lighteval/logging/info_loggers.py @@ -69,8 +69,12 @@ class GeneralConfigLogger: def __init__(self) -> None: """Stores the current lighteval commit for reproducibility, and starts the evaluation timer.""" - repo = git.Repo(os.path.dirname(__file__).split("src")[0]) - self.lighteval_sha = repo.git.rev_parse("HEAD") + try: + repo = git.Repo(os.path.dirname(__file__).split("src")[0]) + except git.InvalidGitRepositoryError: + repo = None + + self.lighteval_sha = repo.git.rev_parse("HEAD") if repo is not None else "?" 
self.start_time = time.perf_counter() def log_args_info( @@ -543,5 +547,5 @@ def log(self, task_dict: dict[str, LightevalTask]) -> None: self.tasks_configs = {name: task.cfg for name, task in task_dict.items()} def log_num_docs(self, task_name: str, original_num_docs: int, effective_num_docs: int) -> None: - self.tasks_configs[task_name]["original_num_docs"] = original_num_docs - self.tasks_configs[task_name]["effective_num_docs"] = effective_num_docs + self.tasks_configs[task_name].original_num_docs = original_num_docs + self.tasks_configs[task_name].effective_num_docs = effective_num_docs diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index b048fbd98..091e179b8 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -60,7 +60,7 @@ def main(args): with accelerator.main_process_first() if accelerator is not None else nullcontext(): task_names_list, few_shots_dict = taskinfo_selector(args.tasks) task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict( - task_names_list, custom_tasks_file=args.custom_tasks_file + task_names_list, custom_tasks=args.custom_tasks ) # Loading all the dataset in a distributed manner LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes) diff --git a/src/lighteval/main_nanotron.py b/src/lighteval/main_nanotron.py index 523e73ce3..b52be757b 100644 --- a/src/lighteval/main_nanotron.py +++ b/src/lighteval/main_nanotron.py @@ -38,7 +38,7 @@ def main( checkpoint_config_path: str, lighteval_config_path: Optional[str] = None, - cache_dir: str = None, + cache_dir: Optional[str] = None, config_cls: Type = Config, model_config_cls: Optional[Type] = None, model_cls: Optional[Type] = None, @@ -109,14 +109,14 @@ def main( with htrack_block("Tasks loading"): with local_ranks_zero_first(): tasks_selection = lighteval_config.tasks.tasks - if lighteval_config.tasks.custom_tasks_file: - _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file) + if lighteval_config.tasks.custom_tasks: + _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks) if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict: tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks] task_names_list, few_shots_dict = taskinfo_selector(tasks_selection) task_dict = Registry(cache_dir=cache_dir).get_task_dict( - task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file + task_names_list, custom_tasks=lighteval_config.tasks.custom_tasks ) # Loading all the dataset in a distributed manner LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes) diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index 0f753a491..8ebe90de4 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -22,9 +22,7 @@ LoglikelihoodSingleTokenRequest, Request, ) -from lighteval.utils import ( - is_accelerate_available, -) +from lighteval.utils import as_list, is_accelerate_available from lighteval.utils_parallelism import find_executable_batch_size @@ -342,7 +340,7 @@ def greedy_until( list[GenerateReturn]: list of generated responses. 
""" for request in requests: - request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token] + request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token] request.tokenized_context = self.tok_encode(request.context) dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index e16963a9c..7176f423a 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -1,8 +1,9 @@ import collections import random +from dataclasses import dataclass from multiprocessing import Pool from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Union from datasets import load_dataset @@ -39,8 +40,62 @@ from lighteval.logging.evaluation_tracker import EvaluationTracker +@dataclass +class LightevalTaskConfig: + name: str + prompt_function: str + hf_repo: str + hf_subset: str + metric: Tuple[Union[str, Metrics]] + hf_avail_splits: Optional[Tuple[str]] = None + evaluation_splits: Optional[Tuple[str]] = None + few_shots_split: Optional[str] = None + few_shots_select: Optional[str] = None + generation_size: int = -1 + stop_sequence: Optional[Tuple[str]] = None + output_regex: Optional[str] = None + + frozen: bool = False + suite: Optional[Tuple[str]] = None # we use this to know if we should use a custom lighteval or bigcode task + + def as_dict(self): + return { + "name": self.name, + "prompt_function": self.prompt_function, + "hf_repo": self.hf_repo, + "hf_subset": self.hf_subset, + "metric": tuple(str(m) for m in self.metric), + "hf_avail_splits": self.hf_avail_splits, + "evaluation_splits": self.evaluation_splits, + "few_shots_split": self.few_shots_split, + "few_shots_select": self.few_shots_select, + "generation_size": self.generation_size, + "stop_sequence": self.stop_sequence, + "output_regex": self.output_regex, + "frozen": self.frozen, + "suite": self.suite, + } + + def __post_init__(self): + if self.suite is None: + self.suite = ["custom"] + if self.hf_avail_splits is None: + self.hf_avail_splits = ["train", "validation", "test"] + if self.evaluation_splits is None: + self.evaluation_splits = ["validation"] + if self.stop_sequence is None: + self.stop_sequence = ["\n"] + + # Convert list to tuple for hashing + self.metric = tuple(self.metric) + self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None + self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits is not None else None + self.suite = tuple(self.suite) if self.suite is not None else None + self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None + + class LightevalTask: - def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom_tasks_module=None): + def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module=None): """ Initialize a LightEval task. 
@@ -60,8 +115,8 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom self._cfg = cfg # Dataset info - self.hf_repo = cfg["hf_repo"] - self.hf_subset = cfg["hf_subset"] + self.hf_repo = cfg.hf_repo + self.hf_subset = cfg.hf_subset self.dataset_path = self.hf_repo self.dataset_config_name = self.hf_subset self.dataset = None # Delayed download @@ -70,22 +125,22 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom self._docs = None # Managing splits and few shot - self.all_available_splits = as_list(cfg["hf_avail_splits"]) - if cfg.get("evaluation_splits", None) is None: + self.all_available_splits = as_list(cfg.hf_avail_splits) + if cfg.evaluation_splits is None: raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.") - self.evaluation_split = as_list(cfg["evaluation_splits"]) - if cfg.get("few_shots_split", None) is not None: - self.fewshot_split = as_list(cfg["few_shots_split"]) + self.evaluation_split = as_list(cfg.evaluation_splits) + if cfg.few_shots_split is not None: + self.fewshot_split = as_list(cfg.few_shots_split) else: self.fewshot_split = as_list(self.get_first_possible_fewshot_splits()) self.fewshot_sampler = FewShotSampler( - few_shots_select=cfg["few_shots_select"], few_shots_split=self.fewshot_split + few_shots_select=cfg.few_shots_select, few_shots_split=self.fewshot_split ) # Metrics - self.metrics = as_list(cfg["metric"]) - self.suite = as_list(cfg["suite"]) + self.metrics = as_list(cfg.metric) + self.suite = as_list(cfg.suite) ignored = [metric for metric in self.metrics if Metrics[metric].value.category == MetricCategory.IGNORED] if len(ignored) > 0: hlog_warn(f"[WARNING] Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.") @@ -95,20 +150,20 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom # Data processing # to use once prompt formatting is managed as a module if custom_tasks_module is None: - self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"]) - elif hasattr(custom_tasks_module, cfg["prompt_function"]): + self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function) + elif hasattr(custom_tasks_module, cfg.prompt_function): # If we have a prompt in both the custom_tasks_module and our tasks_prompt_formatting # We take the prompt from the custom_tasks_module - if hasattr(tasks_prompt_formatting, cfg["prompt_function"]): + if hasattr(tasks_prompt_formatting, cfg.prompt_function): hlog_warn( - f"Be careful you are using custom prompt function {cfg['prompt_function']} and not the default one." + f"Be careful you are using custom prompt function {cfg.prompt_function} and not the default one." 
) - self.formatter = getattr(custom_tasks_module, cfg["prompt_function"]) + self.formatter = getattr(custom_tasks_module, cfg.prompt_function) else: - self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"]) - self.generation_size = cfg["generation_size"] - self.stop_sequence = cfg["stop_sequence"] - self.output_regex = cfg["output_regex"] + self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function) + self.generation_size = cfg.generation_size + self.stop_sequence = cfg.stop_sequence + self.output_regex = cfg.output_regex # Save options self.save_queries: bool = False diff --git a/src/lighteval/tasks/registry.py b/src/lighteval/tasks/registry.py index f662bf5a8..1e7db339e 100644 --- a/src/lighteval/tasks/registry.py +++ b/src/lighteval/tasks/registry.py @@ -1,15 +1,16 @@ import collections import importlib import os +from pathlib import Path from pprint import pformat from types import ModuleType -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from datasets import Dataset from datasets.load import dataset_module_factory from lighteval.logging.hierarchical_logger import hlog, hlog_warn -from lighteval.tasks.lighteval_task import LightevalTask +from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig # original is the reimplementation of original evals @@ -57,44 +58,48 @@ def get_task_class( ValueError: If the task is not found in the task registry or custom task registry. """ if task_name in self.TASK_REGISTRY: + if custom_tasks_registry is not None and task_name in custom_tasks_registry: + hlog_warn( + f"One of the tasks you requested ({task_name}) exists both in the default and custom tasks. Selecting the default task." + ) return self.TASK_REGISTRY[task_name] - elif custom_tasks_registry is not None and task_name in custom_tasks_registry: + if custom_tasks_registry is not None and task_name in custom_tasks_registry: return custom_tasks_registry[task_name] - else: - hlog_warn(f"{task_name} not found in provided tasks") - hlog_warn(pformat(self.TASK_REGISTRY)) - raise ValueError( - f"Cannot find tasks {task_name} in task list or in custom task registry ({custom_tasks_registry})" - ) + hlog_warn(f"{task_name} not found in provided tasks") + hlog_warn(pformat(self.TASK_REGISTRY)) + raise ValueError( + f"Cannot find tasks {task_name} in task list or in custom task registry ({custom_tasks_registry})" + ) def get_task_dict( - self, task_name_list: List[str], custom_tasks_file: Optional[str] = None + self, task_name_list: List[str], custom_tasks: Optional[Union[str, ModuleType]] = None ) -> Dict[str, LightevalTask]: """ Get a dictionary of tasks based on the task name list. Args: task_name_list (List[str]): A list of task names. - custom_tasks_file (Optional[str]): Path to the custom tasks file. + custom_tasks (Optional[Union[str, ModuleType]]): Path to the custom tasks file or name of a module to import containing custom tasks or the module it-self Returns: Dict[str, LightevalTask]: A dictionary containing the tasks. Notes: - - If custom_tasks_file is provided, it will import the custom tasks module and create a custom tasks registry. + - If custom_tasks is provided, it will import the custom tasks module and create a custom tasks registry. - Each task in the task_name_list will be instantiated with the corresponding task class. 
""" - if custom_tasks_file is not None: - dataset_module = dataset_module_factory(str(custom_tasks_file)) - custom_tasks_module = importlib.import_module(dataset_module.module_path) + # Import custom tasks provided by the user + custom_tasks_registry = None + custom_tasks_module = None + if custom_tasks is not None: + custom_tasks_module = create_custom_tasks_module(custom_tasks=custom_tasks) + if custom_tasks_module is not None: custom_tasks_registry = create_config_tasks( meta_table=custom_tasks_module.TASKS_TABLE, cache_dir=self.cache_dir ) hlog(custom_tasks_registry) - else: - custom_tasks_module = None - custom_tasks_registry = None + # Select relevant tasks given the subset asked for by the user tasks_dict = {} for task_name in task_name_list: task_class = self.get_task_class(task_name, custom_tasks_registry=custom_tasks_registry) @@ -103,9 +108,32 @@ def get_task_dict( return tasks_dict -def get_custom_tasks(custom_tasks_file: str) -> Tuple[ModuleType, str]: - dataset_module = dataset_module_factory(str(custom_tasks_file)) - custom_tasks_module = importlib.import_module(dataset_module.module_path) +def create_custom_tasks_module(custom_tasks: Union[str, ModuleType]) -> ModuleType: + """Creates a custom task module to load tasks defined by the user in their own file. + + Args: + custom_tasks (Optional[Union[str, ModuleType]]): Path to the custom tasks file or name of a module to import containing custom tasks or the module it-self + + Returns: + ModuleType: The newly imported/created custom tasks modules + """ + if isinstance(custom_tasks, ModuleType): + return custom_tasks + if isinstance(custom_tasks, (str, Path)) and os.path.exists(custom_tasks): + dataset_module = dataset_module_factory(str(custom_tasks)) + return importlib.import_module(dataset_module.module_path) + if isinstance(custom_tasks, (str, Path)): + return importlib.import_module(custom_tasks) + raise ValueError(f"Cannot import custom tasks from {custom_tasks}") + + +def get_custom_tasks(custom_tasks: Union[str, ModuleType]) -> Tuple[ModuleType, str]: + """Get custom tasks from the given custom tasks file or module. + + Args: + custom_tasks (Optional[Union[str, ModuleType]]): Path to the custom tasks file or name of a module to import containing custom tasks or the module it-self + """ + custom_tasks_module = create_custom_tasks_module(custom_tasks=custom_tasks) tasks_string = "" if hasattr(custom_tasks_module, "TASKS_GROUPS"): tasks_string = custom_tasks_module.TASKS_GROUPS @@ -116,7 +144,7 @@ def taskinfo_selector( tasks: str, ) -> tuple[list[str], dict[str, list[tuple[int, bool]]]]: """ - Selects task information based on the given tasks and description dictionary path. + Converts a input string of tasks name to task information usable by lighteval. Args: tasks (str): A string containing a comma-separated list of tasks in the @@ -174,7 +202,7 @@ def create_config_tasks( Dict[str, LightevalTask]: A dictionary of task names mapped to their corresponding LightevalTask classes. 
""" - def create_task(name, cfg, cache_dir): + def create_task(name, cfg: LightevalTaskConfig, cache_dir: str): class LightevalTaskFromConfig(LightevalTask): def __init__(self, custom_tasks_module=None): super().__init__(name, cfg, cache_dir=cache_dir, custom_tasks_module=custom_tasks_module) @@ -194,18 +222,6 @@ def __init__(self, custom_tasks_module=None): continue for suite in line["suite"]: if suite in DEFAULT_SUITES: - tasks_with_config[f"{suite}|{line['name']}"] = line + tasks_with_config[f"{suite}|{line['name']}"] = LightevalTaskConfig(**line) return {task: create_task(task, cfg, cache_dir=cache_dir) for task, cfg in tasks_with_config.items()} - - -def task_to_suites(suites_selection: list = None): - task_to_suites = {} - meta_table = Dataset.from_json(TABLE_PATH) - for line in meta_table: - if suites_selection is None: - task_to_suites[line["name"]] = line["suite"] - else: - task_to_suites[line["name"]] = [suite for suite in line["suite"] if suite in suites_selection] - - return task_to_suites diff --git a/tasks_examples/custom_tasks/custom_evaluation_tasks.py b/tasks_examples/custom_tasks/custom_evaluation_tasks.py index b0dae200c..0ed928e59 100644 --- a/tasks_examples/custom_tasks/custom_evaluation_tasks.py +++ b/tasks_examples/custom_tasks/custom_evaluation_tasks.py @@ -6,44 +6,41 @@ """ import re from dataclasses import asdict -from typing import Dict, List +from typing import Dict, List, Tuple +from lighteval.metrics import Metrics +from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc +from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES -from .custom_evaluation_utils import * - -# fmt: off -LETTER_INDICES = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] -# fmt: on - -_TASKS_STRINGS: List[Tuple[CustomEvaluationTask, str]] = [] -_TASKS: List[CustomEvaluationTask] = [] +_TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] +_TASKS: List[LightevalTaskConfig] = [] ## COMMON_SENSE_REASONING_TASKS ## COMMON_SENSE_REASONING_TASKS = [ - CustomEvaluationTask( + LightevalTaskConfig( name="hellaswag", prompt_function="hellaswag_prompt", hf_repo="hellaswag", hf_subset="default", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], ), - CustomEvaluationTask( + LightevalTaskConfig( name="winogrande", prompt_function="winogrande", hf_repo="winogrande", hf_subset="winogrande_xl", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], ), - CustomEvaluationTask( + LightevalTaskConfig( name="piqa", prompt_function="piqa_harness", hf_repo="piqa", hf_subset="plain_text", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], ), - CustomEvaluationTask( + LightevalTaskConfig( name="siqa", prompt_function="siqa_prompt", hf_repo="lighteval/siqa", @@ -51,14 +48,14 @@ hf_avail_splits=["train", "validation"], metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], ), - CustomEvaluationTask( + LightevalTaskConfig( name="openbookqa", prompt_function="openbookqa", hf_repo="openbookqa", hf_subset="main", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], ), - CustomEvaluationTask( + LightevalTaskConfig( name="arc:easy", prompt_function="arc", hf_repo="ai2_arc", @@ -67,7 +64,7 @@ generation_size=1, metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], ), - CustomEvaluationTask( + LightevalTaskConfig( name="arc:challenge", prompt_function="arc", hf_repo="ai2_arc", @@ -76,7 +73,7 @@ generation_size=1, 
metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], ), - CustomEvaluationTask( + LightevalTaskConfig( name="commonsense_qa", prompt_function="commonsense_qa_prompt", hf_repo="commonsense_qa", @@ -134,21 +131,21 @@ def preprocess(text): ## WORLD_KNOWLEDGE_TASKS ## WORLD_KNOWLEDGE_TASKS = [ - CustomEvaluationTask( + LightevalTaskConfig( name="trivia_qa", prompt_function="triviaqa", hf_repo="trivia_qa", hf_subset="rc.nocontext", - metric=[Metrics.quasi_exact_match2], + metric=[Metrics.quasi_exact_match], generation_size=20, stop_sequence=["\n", ".", ","], ), - CustomEvaluationTask( + LightevalTaskConfig( name="natural_questions", prompt_function="natural_questions_prompt", hf_repo="lighteval/natural_questions_clean", hf_subset="default", - metric=[Metrics.quasi_exact_match2], + metric=[Metrics.quasi_exact_match], generation_size=20, stop_sequence=["\n", ".", ","], ), @@ -173,19 +170,19 @@ def natural_questions_prompt(line, task_name: str = None): ## Reading comprehension ## READING_COMP_TASKS = [ - CustomEvaluationTask( + LightevalTaskConfig( name="super_glue:boolq", prompt_function="boolq_prompt", hf_repo="super_glue", hf_subset="boolq", metric=["target_perplexity"], ), - CustomEvaluationTask( + LightevalTaskConfig( name="quac", prompt_function="quac", hf_repo="lighteval/quac_helm", hf_subset="deault", - metric=[Metrics.quasi_exact_match2], + metric=[Metrics.quasi_exact_match], generation_size=20, stop_sequence=["\n", ".", ","], ), @@ -207,7 +204,7 @@ def boolq_prompt(line, task_name: str = None): ## MATH ## -class CustomMathEvaluationTask(CustomEvaluationTask): +class CustomMathEvaluationTask(LightevalTaskConfig): """Custom class for math tasks with all the defaults set""" def __init__( @@ -216,7 +213,7 @@ def __init__( prompt_function="math", hf_repo="lighteval/MATH", hf_subset=None, - metric=[Metrics.math_quasi_exact_match], + metric=[Metrics.quasi_exact_match_math], hf_avail_splits=None, evaluation_splits=["test"], few_shots_split=None, @@ -254,7 +251,7 @@ def __init__( CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"), CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"), ] -GSM8K = CustomEvaluationTask( +GSM8K = LightevalTaskConfig( name="gsm8k", prompt_function="gsm8k", hf_repo="gsm8k", @@ -275,7 +272,7 @@ def __init__( ## MMLU ## -class CustomMMLUEvaluationTask(CustomEvaluationTask): +class CustomMMLUEvaluationTask(LightevalTaskConfig): def __init__( self, name, @@ -418,7 +415,7 @@ def mmlu_prompt(line, task_name: str = None): ## BBH ## -class CustomBBHEvaluationTask(CustomEvaluationTask): +class CustomBBHEvaluationTask(LightevalTaskConfig): def __init__( self, name, @@ -509,7 +506,7 @@ def bbh_prompt(line, task_name: str = None): ## AGI eval ## -class CustomAGIEvalEvaluationTask(CustomEvaluationTask): +class CustomAGIEvalEvaluationTask(LightevalTaskConfig): def __init__( self, name, @@ -556,7 +553,7 @@ def __init__( name="agi_eval:math", hf_subset="math", prompt_function="agi_eval_math_prompt", - metric=[Metrics.exact_match, Metrics.quasi_exact_match2], + metric=[Metrics.exact_match, Metrics.quasi_exact_match], generation_size=40, ), CustomAGIEvalEvaluationTask(name="agi_eval:sat-en", hf_subset="sat-en"), @@ -620,7 +617,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None): ## HUMAN EVAL ## -# human_eval = CustomEvaluationTask( +# human_eval = LightevalTaskConfig( # name="human_eval", # prompt_function="human_eval", # hf_repo="lighteval/human_eval", @@ -628,23 +625,14 @@ def agi_eval_prompt_no_letters(line, 
task_name: str = None): # ), -def has_generative_metrics(task: CustomEvaluationTask) -> bool: - for metric in task.metric: - if metric in NEEDS_GENERATION_ONLY: - return True - return False - - EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING]) # Convert to dict for lighteval -TASKS_TABLE = [asdict(task) for task in _TASKS] +TASKS_TABLE = [task.as_dict() for task in _TASKS] # You can have a few pre-organised groups of tasks TASKS_GROUPS = { "all": ",".join(t[1] for t in _TASKS_STRINGS), "early-signal": EARLY_SIGNAL_TASKS, - "non-generatives": ",".join(t for k, t in _TASKS_STRINGS if not has_generative_metrics(k)), - "generatives": ",".join(t for k, t in _TASKS_STRINGS if has_generative_metrics(k)), } if __name__ == "__main__": diff --git a/tasks_examples/custom_tasks/custom_evaluation_utils.py b/tasks_examples/custom_tasks/custom_evaluation_utils.py deleted file mode 100644 index d3f005db1..000000000 --- a/tasks_examples/custom_tasks/custom_evaluation_utils.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -Custom evaluation tasks for lighteval -""" -from dataclasses import dataclass -from enum import Enum, auto -from typing import Optional, Tuple, Union - - -class Metrics(Enum): - any_target_loglikelihood_acc = auto() - bert_score = auto() - bias = auto() - bits_per_byte = auto() - bleu = auto() - bleu_1 = auto() - bleu_4 = auto() - byte_perplexity = auto() - chrf = auto() - code_eval_APPS = auto() - code_eval_HE = auto() - copyright = auto() - disinformation = auto() - exact_match = auto() - exact_set_match = auto() - extractiveness = auto() - f1_from_bags = auto() - f1_quasi = auto() - f1_sequence = auto() - f1_set_match = auto() - faithfulness = auto() - iou_set_match = auto() - log_prob = auto() - loglikelihood_acc = auto() - loglikelihood_acc_norm = auto() - loglikelihood_acc_norm_nospace = auto() - loglikelihood_acc_norm_single_token = auto() - loglikelihood_acc_single_token = auto() - loglikelihood_f1 = auto() - loglikelihood_f1_single_token = auto() - math_quasi_exact_match = auto() - mc_taco = auto() - mcc = auto() - mcc_single_token = auto() - mrr = auto() - mrr_single_token = auto() - multi_fi_numeric = auto() - one_choice_loglikelihood_acc = auto() - perfect_exact_match = auto() - prediction_perplexity = auto() - prefix_exact_match = auto() - prefix_quasi_exact_match = auto() - quasi_exact_match = auto() - quasi_exact_match2 = auto() - ranking = auto() - recall_at_1_single_token = auto() - recall_at_2_single_token = auto() - recall_at_1 = auto() - recall_at_2 = auto() - rouge = auto() - rouge_1 = auto() - rouge_2 = auto() - rouge_l = auto() - target_perplexity = auto() - ter = auto() - toxicity = auto() - truthfulqa_mc_metrics = auto() - word_perplexity = auto() - - def __str__(self): - return self.name.replace("_at_", "@") - - -NEEDS_GENERATION_ONLY = [ - "perfect_exact_match", - "exact_match", - "quasi_exact_match", - "quasi_exact_match2", - "prefix_exact_match", - "prefix_quasi_exact_match", - "math_quasi_exact_match", - "iou_set_match", - "exact_set_match", - "f1_sequence", - "f1_quasi", - "f1_set_match", - "f1_from_bags", - "chrf", - "ter", - "rouge", - "rouge_1", - "rouge_2", - "rouge_l", - "faithfulness", - "extractiveness", - "bert_score", - "bleu", - "bleu_1", - "bleu_4", - "bias", - "toxicity", - "code_eval_HE", - "code_eval_APPS", - "copyright", -] - - -@dataclass(unsafe_hash=True) -class CustomEvaluationTask: - name: str - prompt_function: str - hf_repo: str - hf_subset: str - metric: Tuple[Union[str, Metrics]] - 
hf_avail_splits: Optional[Tuple[str]] = None - evaluation_splits: Optional[Tuple[str]] = None - few_shots_split: Optional[str] = None - few_shots_select: Optional[str] = None - generation_size: int = -1 - stop_sequence: Optional[Tuple[str]] = None - output_regex: Optional[str] = None - - frozen: bool = False - suite: Optional[Tuple[str]] = None # we use this to know if we should use a custom lighteval or bigcode task - - def __post_init__(self): - self.metric = [str(m) for m in self.metric] - if self.suite is None: - self.suite = ["custom"] - if self.hf_avail_splits is None: - self.hf_avail_splits = ["train", "validation", "test"] - if self.evaluation_splits is None: - self.evaluation_splits = ["validation"] - if self.stop_sequence is None: - self.stop_sequence = ["\n"] - - # Convert list to tuple for hashing - self.metric = tuple(self.metric) - self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits else None - self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits else None - self.suite = tuple(self.suite) if self.suite else None - self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence else None - - -@dataclass(unsafe_hash=True) -class BigCodeEvaluationTask: - name: str - bigcode_task: str - bigcode_task_kwargs: Optional[dict] = None - n_samples: int = 1 - prefix: Optional[str] = None - - suite: Tuple[str] = None - - def __post_init__(self): - if self.suite is None: - self.suite = ("bigcode",) - - # Convert list to tuple for hashing - self.suite = tuple(self.suite) diff --git a/tasks_examples/custom_tasks/lighteval_config_override_template.yaml b/tasks_examples/custom_tasks/lighteval_config_override_template.yaml index 6544a88af..390a81ecb 100644 --- a/tasks_examples/custom_tasks/lighteval_config_override_template.yaml +++ b/tasks_examples/custom_tasks/lighteval_config_override_template.yaml @@ -20,9 +20,9 @@ lighteval: tp_linear_async_communication: false tp_mode: ALL_REDUCE tasks: - custom_tasks_file: /fsx/thomwolf/github/lighteval/tasks_examples/custom_tasks/custom_evaluation_tasks.py + custom_tasks: /fsx/thomwolf/github/lighteval/tasks_examples/custom_tasks/custom_evaluation_tasks.py dataset_loading_processes: 8 - max_samples: 1000 + max_samples: 10 multichoice_continuations_start_space: null no_multichoice_continuations_start_space: null num_fewshot_seeds: null
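For context, here is a minimal sketch of the kind of user-defined task file this PR enables (not part of the patch itself): it assumes a hypothetical dataset "your_org/your_dataset" with "question"/"choices"/"gold" columns, a task named "my_task", and Doc fields mirroring the built-in prompt formatters; the "suite|task|fewshot|truncation" task-string format is likewise an assumption based on lighteval's usual convention.

# custom_tasks_minimal.py -- hypothetical example, not shipped with this PR
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


def my_prompt(line, task_name: str = None):
    # Turn one dataset row into a Doc; the column names below are placeholders
    # and must match the actual dataset schema.
    return Doc(
        task_name=task_name,
        query=f"Question: {line['question']}\nAnswer:",
        choices=[f" {c}" for c in line["choices"]],
        gold_index=int(line["gold"]),
    )


MY_TASK = LightevalTaskConfig(
    name="my_task",
    prompt_function="my_prompt",      # resolved in this module first, then in tasks_prompt_formatting
    hf_repo="your_org/your_dataset",  # placeholder Hugging Face dataset id
    hf_subset="default",
    metric=["loglikelihood_acc"],
    evaluation_splits=["test"],
    # suite defaults to ["custom"], so the task is addressed as "custom|my_task"
)

# The registry consumes TASKS_TABLE (a list of dicts) and rebuilds
# LightevalTaskConfig objects from it via LightevalTaskConfig(**line).
TASKS_TABLE = [MY_TASK.as_dict()]

# Optional: named task groups that can be passed instead of an explicit task string
# (the exact "suite|task|fewshot|truncation" format is an assumption).
TASKS_GROUPS = {"all": "custom|my_task|0|0"}

if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])

With the renamed argument, such a file would be passed as --custom_tasks path/to/custom_tasks_minimal.py to run_evals_accelerate.py, or set as tasks.custom_tasks in the nanotron override YAML; per create_custom_tasks_module, an importable module name or an already-imported module object should work as well.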