From f047874a0eebe716831f683f63fafde56481878a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com> Date: Thu, 1 Aug 2024 09:38:07 +0200 Subject: [PATCH 1/2] Fix inference endpoint config (#244) --- src/lighteval/models/model_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py index 6f7af8af4..eb8320c03 100644 --- a/src/lighteval/models/model_config.py +++ b/src/lighteval/models/model_config.py @@ -325,10 +325,10 @@ def create_model_config( # noqa: C901 ) if config["type"] == "endpoint": - reuse_existing_endpoint = config["base_params"]["reuse_existing"] + reuse_existing_endpoint = config["base_params"].get("reuse_existing", None) complete_config_endpoint = all( val not in [None, ""] - for key, val in config["instance"].items() + for key, val in config.get("instance", {}).items() if key not in InferenceEndpointModelConfig.nullable_keys() ) if reuse_existing_endpoint or complete_config_endpoint: From cbae17dcbc9401a3015e0ca2d4f78716ee2ed69e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com> Date: Thu, 1 Aug 2024 10:43:45 +0200 Subject: [PATCH 2/2] Expose samples via the CLI (#228) * add examples of samples (does not include few shot), and add robustness to logger * added cache at all levels * added nice printing for the config and more options for task display * added pprint --------- Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com> --- src/lighteval/__main__.py | 50 +++++++++++++++----- src/lighteval/logging/hierarchical_logger.py | 19 ++++++-- src/lighteval/main_accelerate.py | 1 - src/lighteval/metrics/utils.py | 4 +- src/lighteval/parsers.py | 29 ++++++++++-- src/lighteval/tasks/lighteval_task.py | 30 +++++++++++- 6 files changed, 108 insertions(+), 25 deletions(-) diff --git a/src/lighteval/__main__.py b/src/lighteval/__main__.py index 9deb09251..8cf7c011f 100644 --- a/src/lighteval/__main__.py +++ b/src/lighteval/__main__.py @@ -23,24 +23,32 @@ # SOFTWARE. import argparse +import os +from dataclasses import asdict +from pprint import pformat -from lighteval.parsers import parser_accelerate, parser_nanotron -from lighteval.tasks.registry import Registry +from lighteval.parsers import parser_accelerate, parser_nanotron, parser_utils_tasks +from lighteval.tasks.registry import Registry, taskinfo_selector + + +CACHE_DIR = os.getenv("HF_HOME") def cli_evaluate(): parser = argparse.ArgumentParser(description="CLI tool for lighteval, a lightweight framework for LLM evaluation") subparsers = parser.add_subparsers(help="help for subcommand", dest="subcommand") - # create the parser for the "accelerate" command + # Subparser for the "accelerate" command parser_a = subparsers.add_parser("accelerate", help="use accelerate and transformers as backend for evaluation.") parser_accelerate(parser_a) - # create the parser for the "nanotron" command + # Subparser for the "nanotron" command parser_b = subparsers.add_parser("nanotron", help="use nanotron as backend for evaluation.") parser_nanotron(parser_b) - parser.add_argument("--list-tasks", action="store_true", help="List available tasks") + # Subparser for task utils functions + parser_c = subparsers.add_parser("tasks", help="use nanotron as backend for evaluation.") + parser_utils_tasks(parser_c) args = parser.parse_args() @@ -48,17 +56,37 @@ def cli_evaluate(): from lighteval.main_accelerate import main as main_accelerate main_accelerate(args) - return - if args.subcommand == "nanotron": + elif args.subcommand == "nanotron": from lighteval.main_nanotron import main as main_nanotron main_nanotron(args.checkpoint_config_path, args.lighteval_override, args.cache_dir) - return - if args.list_tasks: - Registry(cache_dir="").print_all_tasks() - return + elif args.subcommand == "tasks": + if args.list: + Registry(cache_dir="").print_all_tasks() + + if args.inspect: + print(f"Loading the tasks dataset to cache folder: {args.cache_dir}") + print( + "All examples will be displayed without few shot, as few shot sample construction requires loading a model and using its tokenizer." + ) + # Loading task + task_names_list, _ = taskinfo_selector(args.inspect) + task_dict = Registry(cache_dir=args.cache_dir).get_task_dict(task_names_list) + for name, task in task_dict.items(): + print("-" * 10, name, "-" * 10) + if args.show_config: + print("-" * 10, "CONFIG") + task.print_config() + for ix, sample in enumerate(task.eval_docs()[: int(args.num_samples)]): + if ix == 0: + print("-" * 10, "SAMPLES") + print(f"-- sample {ix} --") + print(pformat(asdict(sample), indent=1)) + + else: + print("You did not provide any argument. Exiting") if __name__ == "__main__": diff --git a/src/lighteval/logging/hierarchical_logger.py b/src/lighteval/logging/hierarchical_logger.py index 5efb44153..9cf718f18 100644 --- a/src/lighteval/logging/hierarchical_logger.py +++ b/src/lighteval/logging/hierarchical_logger.py @@ -23,6 +23,7 @@ import sys import time from datetime import timedelta +from logging import Logger from typing import Any, Callable from lighteval.utils import is_accelerate_available, is_nanotron_available @@ -37,8 +38,6 @@ logger = get_logger(__name__, log_level="INFO") else: - from logging import Logger - logger = Logger(__name__, level="INFO") from colorama import Fore, Style @@ -76,6 +75,7 @@ def log(self, x: Any) -> None: HIERARCHICAL_LOGGER = HierarchicalLogger() +BACKUP_LOGGER = Logger(__name__, level="INFO") # Exposed public methods @@ -84,7 +84,10 @@ def hlog(x: Any) -> None: Logs a string version of x through the singleton [`HierarchicalLogger`]. """ - HIERARCHICAL_LOGGER.log(x) + try: + HIERARCHICAL_LOGGER.log(x) + except RuntimeError: + BACKUP_LOGGER.warning(x) def hlog_warn(x: Any) -> None: @@ -92,7 +95,10 @@ def hlog_warn(x: Any) -> None: Logs a string version of x, which will appear in a yellow color, through the singleton [`HierarchicalLogger`]. """ - HIERARCHICAL_LOGGER.log(Fore.YELLOW + str(x) + Style.RESET_ALL) + try: + HIERARCHICAL_LOGGER.log(Fore.YELLOW + str(x) + Style.RESET_ALL) + except RuntimeError: + BACKUP_LOGGER.warning(Fore.YELLOW + str(x) + Style.RESET_ALL) def hlog_err(x: Any) -> None: @@ -100,7 +106,10 @@ def hlog_err(x: Any) -> None: Logs a string version of x, which will appear in a red color, through the singleton [`HierarchicalLogger`]. """ - HIERARCHICAL_LOGGER.log(Fore.RED + str(x) + Style.RESET_ALL) + try: + HIERARCHICAL_LOGGER.log(Fore.RED + str(x) + Style.RESET_ALL) + except RuntimeError: + BACKUP_LOGGER.warning(Fore.RED + str(x) + Style.RESET_ALL) class htrack_block: diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index 12122c527..904a68322 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -43,7 +43,6 @@ hlog_warn("Using either accelerate or text-generation to run this script is advised.") TOKEN = os.getenv("HF_TOKEN") -CACHE_DIR = os.getenv("HF_HOME") if is_accelerate_available(): from accelerate import Accelerator, InitProcessGroupKwargs diff --git a/src/lighteval/metrics/utils.py b/src/lighteval/metrics/utils.py index 52e8e0665..0310b5f4b 100644 --- a/src/lighteval/metrics/utils.py +++ b/src/lighteval/metrics/utils.py @@ -24,7 +24,7 @@ from enum import Enum, auto -class MetricCategory(Enum): +class MetricCategory(str, Enum): TARGET_PERPLEXITY = auto() PERPLEXITY = auto() GENERATIVE = auto() @@ -37,7 +37,7 @@ class MetricCategory(Enum): IGNORED = auto() -class MetricUseCase(Enum): +class MetricUseCase(str, Enum): # General ACCURACY = auto() PERPLEXITY = auto() diff --git a/src/lighteval/parsers.py b/src/lighteval/parsers.py index d05ba312f..499d945ec 100644 --- a/src/lighteval/parsers.py +++ b/src/lighteval/parsers.py @@ -66,9 +66,6 @@ def parser_accelerate(parser=None): parser.add_argument( "--public_run", default=False, action="store_true", help="Push results and details to a public repo" ) - parser.add_argument( - "--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models" - ) parser.add_argument( "--results_org", type=str, @@ -99,6 +96,9 @@ def parser_accelerate(parser=None): default=None, help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks", ) + parser.add_argument( + "--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models" + ) parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots") return parser @@ -121,8 +121,27 @@ def parser_nanotron(parser=None): help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config", ) parser.add_argument( - "--cache-dir", + "--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models" + ) + + +def parser_utils_tasks(parser=None): + if parser is None: + parser = argparse.ArgumentParser( + description="CLI tool for lighteval, a lightweight framework for LLM evaluation" + ) + + group = parser.add_mutually_exclusive_group(required=True) + + group.add_argument("--list", action="store_true", help="List available tasks") + group.add_argument( + "--inspect", type=str, default=None, - help="Cache directory", + help="Id of tasks or path to a text file with a list of tasks (e.g. 'original|mmlu:abstract_algebra|5') for which you want to manually inspect samples.", + ) + parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to display") + parser.add_argument("--show_config", default=False, action="store_true", help="Will display the full task config") + parser.add_argument( + "--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models" ) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 07251d696..701357420 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -21,14 +21,16 @@ # SOFTWARE. import collections +import inspect import os import random -from dataclasses import dataclass +from dataclasses import asdict, dataclass from multiprocessing import Pool from pathlib import Path from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union from datasets import load_dataset +from pytablewriter import MarkdownTableWriter from lighteval.few_shot_manager import FewShotSampler from lighteval.logging.hierarchical_logger import hlog, hlog_warn @@ -232,6 +234,32 @@ def __init__( # noqa: C901 def cfg(self): return self._cfg + def print_config(self): + md_writer = MarkdownTableWriter() + md_writer.headers = ["Key", "Value"] + + values = [] + + for k, v in asdict(self.cfg).items(): + if k == "metric": + for ix, metrics in enumerate(v): + for metric_k, metric_v in metrics.items(): + if inspect.ismethod(metric_v): + values.append([f"{k} {ix}: {metric_k}", metric_v.__qualname__]) + else: + values.append([f"{k} {ix}: {metric_k}", repr(metric_v)]) + + else: + if isinstance(v, Callable): + values.append([k, v.__name__]) + else: + values.append([k, repr(v)]) + # print(k, ":", repr(v)) + + md_writer.value_matrix = values + + print(md_writer.dumps()) + def doc_to_text_without_instructions(self, doc: Doc) -> str: """ Returns the query of the document without the instructions. If the