Skip to content

Commit

Permalink
Expose samples via the CLI (#228)
Browse files Browse the repository at this point in the history
* add examples of samples (does not include few shot), and add robustness to logger

* added cache at all levels

* added nice printing for the config and more options for task display

* added pprint

---------

Co-authored-by: Nathan Habib <[email protected]>
  • Loading branch information
clefourrier and NathanHB authored Aug 1, 2024
1 parent f047874 commit cbae17d
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 25 deletions.
50 changes: 39 additions & 11 deletions src/lighteval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,70 @@
# SOFTWARE.

import argparse
import os
from dataclasses import asdict
from pprint import pformat

from lighteval.parsers import parser_accelerate, parser_nanotron
from lighteval.tasks.registry import Registry
from lighteval.parsers import parser_accelerate, parser_nanotron, parser_utils_tasks
from lighteval.tasks.registry import Registry, taskinfo_selector


# Cache location read from the HF_HOME environment variable; None when the variable is unset.
CACHE_DIR = os.getenv("HF_HOME")


def cli_evaluate():
    """Entry point of the lighteval command line interface.

    Builds one sub-parser per subcommand and dispatches on the parsed
    `subcommand`:

    - `accelerate`: forwards the parsed arguments to `lighteval.main_accelerate.main`.
    - `nanotron`: forwards the checkpoint config path, the lighteval override
      and the cache directory to `lighteval.main_nanotron.main`.
    - `tasks`: lists all registered tasks (`--list`) or prints the first
      samples, and optionally the config, of the requested tasks (`--inspect`).

    When no subcommand is given, a short message is printed and nothing else happens.
    """
    parser = argparse.ArgumentParser(description="CLI tool for lighteval, a lightweight framework for LLM evaluation")
    subparsers = parser.add_subparsers(help="help for subcommand", dest="subcommand")

    # Subparser for the "accelerate" command
    parser_a = subparsers.add_parser("accelerate", help="use accelerate and transformers as backend for evaluation.")
    parser_accelerate(parser_a)

    # Subparser for the "nanotron" command
    parser_b = subparsers.add_parser("nanotron", help="use nanotron as backend for evaluation.")
    parser_nanotron(parser_b)

    # Subparser for task utils functions
    parser_c = subparsers.add_parser("tasks", help="list available tasks or inspect samples of given tasks.")
    parser_utils_tasks(parser_c)

    args = parser.parse_args()

    if args.subcommand == "accelerate":
        # Backend modules are imported lazily, only for the selected subcommand.
        from lighteval.main_accelerate import main as main_accelerate

        main_accelerate(args)

    elif args.subcommand == "nanotron":
        from lighteval.main_nanotron import main as main_nanotron

        main_nanotron(args.checkpoint_config_path, args.lighteval_override, args.cache_dir)

    elif args.subcommand == "tasks":
        if args.list:
            Registry(cache_dir="").print_all_tasks()

        if args.inspect:
            print(f"Loading the tasks dataset to cache folder: {args.cache_dir}")
            print(
                "All examples will be displayed without few shot, as few shot sample construction requires loading a model and using its tokenizer."
            )
            # Loading task
            task_names_list, _ = taskinfo_selector(args.inspect)
            task_dict = Registry(cache_dir=args.cache_dir).get_task_dict(task_names_list)
            for name, task in task_dict.items():
                print("-" * 10, name, "-" * 10)
                if args.show_config:
                    print("-" * 10, "CONFIG")
                    task.print_config()
                for ix, sample in enumerate(task.eval_docs()[: int(args.num_samples)]):
                    # The section header is only printed when there is at least one sample.
                    if ix == 0:
                        print("-" * 10, "SAMPLES")
                    print(f"-- sample {ix} --")
                    print(pformat(asdict(sample), indent=1))

    else:
        print("You did not provide any argument. Exiting")


if __name__ == "__main__":
Expand Down
19 changes: 14 additions & 5 deletions src/lighteval/logging/hierarchical_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys
import time
from datetime import timedelta
from logging import Logger
from typing import Any, Callable

from lighteval.utils import is_accelerate_available, is_nanotron_available
Expand All @@ -37,8 +38,6 @@

logger = get_logger(__name__, log_level="INFO")
else:
from logging import Logger

logger = Logger(__name__, level="INFO")

from colorama import Fore, Style
Expand Down Expand Up @@ -76,6 +75,7 @@ def log(self, x: Any) -> None:


# Module-level singleton used by the public hlog* helpers below.
HIERARCHICAL_LOGGER = HierarchicalLogger()
# Plain standard-library logger the hlog* helpers fall back to when the hierarchical logger raises RuntimeError.
BACKUP_LOGGER = Logger(__name__, level="INFO")


# Exposed public methods
Expand All @@ -84,23 +84,32 @@ def hlog(x: Any) -> None:
Logs a string version of x through the singleton [`HierarchicalLogger`].
"""
HIERARCHICAL_LOGGER.log(x)
try:
HIERARCHICAL_LOGGER.log(x)
except RuntimeError:
BACKUP_LOGGER.warning(x)


def hlog_warn(x: Any) -> None:
    """Warning logger.
    Logs a string version of x, which will appear in a yellow color, through the singleton [`HierarchicalLogger`].
    If the hierarchical logger raises a `RuntimeError`, the same message is emitted through `BACKUP_LOGGER` instead.
    """
    # Build the colored message once so both code paths log exactly the same text.
    message = Fore.YELLOW + str(x) + Style.RESET_ALL
    try:
        HIERARCHICAL_LOGGER.log(message)
    except RuntimeError:
        BACKUP_LOGGER.warning(message)


def hlog_err(x: Any) -> None:
    """Error logger.
    Logs a string version of x, which will appear in a red color, through the singleton [`HierarchicalLogger`].
    If the hierarchical logger raises a `RuntimeError`, the same message is emitted through `BACKUP_LOGGER`
    at error level instead.
    """
    # Build the colored message once so both code paths log exactly the same text.
    message = Fore.RED + str(x) + Style.RESET_ALL
    try:
        HIERARCHICAL_LOGGER.log(message)
    except RuntimeError:
        # Errors keep their severity on the fallback path.
        BACKUP_LOGGER.error(message)


class htrack_block:
Expand Down
1 change: 0 additions & 1 deletion src/lighteval/main_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
hlog_warn("Using either accelerate or text-generation to run this script is advised.")

TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("HF_HOME")

if is_accelerate_available():
from accelerate import Accelerator, InitProcessGroupKwargs
Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from enum import Enum, auto


class MetricCategory(Enum):
class MetricCategory(str, Enum):
TARGET_PERPLEXITY = auto()
PERPLEXITY = auto()
GENERATIVE = auto()
Expand All @@ -37,7 +37,7 @@ class MetricCategory(Enum):
IGNORED = auto()


class MetricUseCase(Enum):
class MetricUseCase(str, Enum):
# General
ACCURACY = auto()
PERPLEXITY = auto()
Expand Down
29 changes: 24 additions & 5 deletions src/lighteval/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ def parser_accelerate(parser=None):
parser.add_argument(
"--public_run", default=False, action="store_true", help="Push results and details to a public repo"
)
parser.add_argument(
"--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
)
parser.add_argument(
"--results_org",
type=str,
Expand Down Expand Up @@ -99,6 +96,9 @@ def parser_accelerate(parser=None):
default=None,
help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks",
)
parser.add_argument(
"--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
)
parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots")
return parser

Expand All @@ -121,8 +121,27 @@ def parser_nanotron(parser=None):
help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
)
parser.add_argument(
"--cache-dir",
"--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
)


def parser_utils_tasks(parser=None):
    """Register the arguments of the `tasks` utility subcommand.

    Args:
        parser: The `argparse` parser (or sub-parser) to populate. When `None`,
            a new `argparse.ArgumentParser` is created.

    Returns:
        The populated parser, so that a parser created here is not lost to the caller.

    Exactly one of `--list` and `--inspect` must be provided; `--num_samples`,
    `--show_config` and `--cache_dir` are optional.
    """
    if parser is None:
        parser = argparse.ArgumentParser(
            description="CLI tool for lighteval, a lightweight framework for LLM evaluation"
        )

    # `--list` and `--inspect` are alternative actions: one of them is required, both together are rejected.
    group = parser.add_mutually_exclusive_group(required=True)

    group.add_argument("--list", action="store_true", help="List available tasks")
    group.add_argument(
        "--inspect",
        type=str,
        default=None,
        help="Id of tasks or path to a text file with a list of tasks (e.g. 'original|mmlu:abstract_algebra|5') for which you want to manually inspect samples.",
    )
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to display")
    parser.add_argument("--show_config", default=False, action="store_true", help="Will display the full task config")
    parser.add_argument(
        "--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
    )
    return parser
30 changes: 29 additions & 1 deletion src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
# SOFTWARE.

import collections
import inspect
import os
import random
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from multiprocessing import Pool
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

from datasets import load_dataset
from pytablewriter import MarkdownTableWriter

from lighteval.few_shot_manager import FewShotSampler
from lighteval.logging.hierarchical_logger import hlog, hlog_warn
Expand Down Expand Up @@ -232,6 +234,32 @@ def __init__( # noqa: C901
def cfg(self):
return self._cfg

def print_config(self):
    """Print the task configuration as a two-column ("Key", "Value") markdown table.

    The config (`self.cfg`) is converted to a dict with `dataclasses.asdict`.
    Each entry becomes one row, except for the `metric` entry, which is
    expanded into one row per field of each metric (labelled
    `"metric <index>: <field>"`).

    Bound methods are displayed by their qualified name and other callables
    by their `__name__` (falling back to `repr` for callables without one,
    such as `functools.partial` objects); every other value is shown with `repr`.
    """
    md_writer = MarkdownTableWriter()
    md_writer.headers = ["Key", "Value"]

    values = []

    for k, v in asdict(self.cfg).items():
        if k == "metric":
            # Each metric is iterated as a mapping of its fields (presumably the
            # result of `asdict` on a metric dataclass - confirm against the config definition).
            for ix, metrics in enumerate(v):
                for metric_k, metric_v in metrics.items():
                    label = f"{k} {ix}: {metric_k}"
                    if inspect.ismethod(metric_v):
                        values.append([label, metric_v.__qualname__])
                    else:
                        values.append([label, repr(metric_v)])
        elif callable(v):
            # Not every callable exposes `__name__`; fall back to its repr instead of failing.
            values.append([k, getattr(v, "__name__", repr(v))])
        else:
            values.append([k, repr(v)])

    md_writer.value_matrix = values

    print(md_writer.dumps())

def doc_to_text_without_instructions(self, doc: Doc) -> str:
"""
Returns the query of the document without the instructions. If the
Expand Down

0 comments on commit cbae17d

Please sign in to comment.