Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose samples via the CLI #228

Merged
merged 9 commits into from
Aug 1, 2024
50 changes: 39 additions & 11 deletions src/lighteval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,70 @@
# SOFTWARE.

import argparse
import os
from dataclasses import asdict
from pprint import pformat

from lighteval.parsers import parser_accelerate, parser_nanotron
from lighteval.tasks.registry import Registry
from lighteval.parsers import parser_accelerate, parser_nanotron, parser_utils_tasks
from lighteval.tasks.registry import Registry, taskinfo_selector


CACHE_DIR = os.getenv("HF_HOME")


def cli_evaluate():
    """Entry point for the `lighteval` CLI.

    Builds an argument parser with three subcommands:
      - ``accelerate``: evaluate with accelerate/transformers as backend,
      - ``nanotron``: evaluate with nanotron as backend,
      - ``tasks``: utilities to list available tasks or inspect task samples.

    Dispatches to the matching backend based on the parsed subcommand.
    Backend modules are imported lazily inside each branch so that an
    unused backend's dependencies are never loaded.
    """
    parser = argparse.ArgumentParser(description="CLI tool for lighteval, a lightweight framework for LLM evaluation")
    subparsers = parser.add_subparsers(help="help for subcommand", dest="subcommand")

    # Subparser for the "accelerate" command
    parser_a = subparsers.add_parser("accelerate", help="use accelerate and transformers as backend for evaluation.")
    parser_accelerate(parser_a)

    # Subparser for the "nanotron" command
    parser_b = subparsers.add_parser("nanotron", help="use nanotron as backend for evaluation.")
    parser_nanotron(parser_b)

    # Subparser for task utils functions
    # NOTE: help text fixed — it was copy-pasted from the nanotron subparser.
    parser_c = subparsers.add_parser("tasks", help="list available tasks or inspect their samples.")
    parser_utils_tasks(parser_c)

    args = parser.parse_args()

    if args.subcommand == "accelerate":
        from lighteval.main_accelerate import main as main_accelerate

        main_accelerate(args)

    elif args.subcommand == "nanotron":
        from lighteval.main_nanotron import main as main_nanotron

        main_nanotron(args.checkpoint_config_path, args.lighteval_override, args.cache_dir)

    elif args.subcommand == "tasks":
        if args.list:
            Registry(cache_dir="").print_all_tasks()

        if args.inspect:
            print(f"Loading the tasks dataset to cache folder: {args.cache_dir}")
            print(
                "All examples will be displayed without few shot, as few shot sample construction requires loading a model and using its tokenizer."
            )
            # Resolve the task ids, then load each task and dump its samples.
            task_names_list, _ = taskinfo_selector(args.inspect)
            task_dict = Registry(cache_dir=args.cache_dir).get_task_dict(task_names_list)
            for name, task in task_dict.items():
                print("-" * 10, name, "-" * 10)
                if args.show_config:
                    print("-" * 10, "CONFIG")
                    task.print_config()
                for ix, sample in enumerate(task.eval_docs()[: int(args.num_samples)]):
                    if ix == 0:
                        print("-" * 10, "SAMPLES")
                    print(f"-- sample {ix} --")
                    print(pformat(asdict(sample), indent=1))

    else:
        print("You did not provide any argument. Exiting")


if __name__ == "__main__":
Expand Down
19 changes: 14 additions & 5 deletions src/lighteval/logging/hierarchical_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys
import time
from datetime import timedelta
from logging import Logger
from typing import Any, Callable

from lighteval.utils import is_accelerate_available, is_nanotron_available
Expand All @@ -37,8 +38,6 @@

logger = get_logger(__name__, log_level="INFO")
else:
from logging import Logger

logger = Logger(__name__, level="INFO")

from colorama import Fore, Style
Expand Down Expand Up @@ -76,6 +75,7 @@ def log(self, x: Any) -> None:


HIERARCHICAL_LOGGER = HierarchicalLogger()
BACKUP_LOGGER = Logger(__name__, level="INFO")


# Exposed public methods
Expand All @@ -84,23 +84,32 @@ def hlog(x: Any) -> None:

Logs a string version of x through the singleton [`HierarchicalLogger`].
"""
HIERARCHICAL_LOGGER.log(x)
try:
HIERARCHICAL_LOGGER.log(x)
except RuntimeError:
BACKUP_LOGGER.warning(x)


def hlog_warn(x: Any) -> None:
    """Warning logger.

    Logs a string version of x, which will appear in a yellow color, through
    the singleton [`HierarchicalLogger`]. If the hierarchical logger raises a
    RuntimeError, the message is emitted through the plain stdlib backup
    logger instead so the warning is never lost.
    """
    # Build the colored message once; both logging paths emit the same text.
    message = Fore.YELLOW + str(x) + Style.RESET_ALL
    try:
        HIERARCHICAL_LOGGER.log(message)
    except RuntimeError:
        BACKUP_LOGGER.warning(message)


def hlog_err(x: Any) -> None:
    """Error logger.

    Logs a string version of x, which will appear in a red color, through
    the singleton [`HierarchicalLogger`]. If the hierarchical logger raises a
    RuntimeError, the message is emitted through the plain stdlib backup
    logger instead so the error is never lost.
    """
    # Build the colored message once; both logging paths emit the same text.
    message = Fore.RED + str(x) + Style.RESET_ALL
    try:
        HIERARCHICAL_LOGGER.log(message)
    except RuntimeError:
        BACKUP_LOGGER.warning(message)


class htrack_block:
Expand Down
1 change: 0 additions & 1 deletion src/lighteval/main_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
hlog_warn("Using either accelerate or text-generation to run this script is advised.")

TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR = os.getenv("HF_HOME")

if is_accelerate_available():
from accelerate import Accelerator, InitProcessGroupKwargs
Expand Down
4 changes: 2 additions & 2 deletions src/lighteval/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from enum import Enum, auto


class MetricCategory(Enum):
class MetricCategory(str, Enum):
TARGET_PERPLEXITY = auto()
PERPLEXITY = auto()
GENERATIVE = auto()
Expand All @@ -37,7 +37,7 @@ class MetricCategory(Enum):
IGNORED = auto()


class MetricUseCase(Enum):
class MetricUseCase(str, Enum):
# General
ACCURACY = auto()
PERPLEXITY = auto()
Expand Down
29 changes: 24 additions & 5 deletions src/lighteval/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ def parser_accelerate(parser=None):
parser.add_argument(
"--public_run", default=False, action="store_true", help="Push results and details to a public repo"
)
parser.add_argument(
"--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
)
parser.add_argument(
"--results_org",
type=str,
Expand Down Expand Up @@ -99,6 +96,9 @@ def parser_accelerate(parser=None):
default=None,
help="Id of a task, e.g. 'original|mmlu:abstract_algebra|5' or path to a texte file with a list of tasks",
)
parser.add_argument(
"--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
)
parser.add_argument("--num_fewshot_seeds", type=int, default=1, help="Number of trials the few shots")
return parser

Expand All @@ -121,8 +121,27 @@ def parser_nanotron(parser=None):
help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config",
)
parser.add_argument(
"--cache-dir",
"--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
)


def parser_utils_tasks(parser=None):
    """Configure the argument parser for the `tasks` utility subcommand.

    Exactly one of `--list` (print all available tasks) or `--inspect`
    (display raw samples for the given task ids) must be provided; they are
    registered in a required mutually exclusive group. Creates a standalone
    parser when none is passed in.

    Args:
        parser: An existing `argparse.ArgumentParser` (e.g. a subparser) to
            populate, or None to create a new one.

    Returns:
        The populated parser, for consistency with `parser_accelerate`.
    """
    if parser is None:
        parser = argparse.ArgumentParser(
            description="CLI tool for lighteval, a lightweight framework for LLM evaluation"
        )

    group = parser.add_mutually_exclusive_group(required=True)

    group.add_argument("--list", action="store_true", help="List available tasks")
    group.add_argument(
        "--inspect",
        type=str,
        default=None,
        help="Id of tasks or path to a text file with a list of tasks (e.g. 'original|mmlu:abstract_algebra|5') for which you want to manually inspect samples.",
    )
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to display")
    parser.add_argument("--show_config", default=False, action="store_true", help="Will display the full task config")
    parser.add_argument(
        "--cache_dir", type=str, default=CACHE_DIR, help="Cache directory used to store datasets and models"
    )
    return parser
30 changes: 29 additions & 1 deletion src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
# SOFTWARE.

import collections
import inspect
import os
import random
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from multiprocessing import Pool
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

from datasets import load_dataset
from pytablewriter import MarkdownTableWriter

from lighteval.few_shot_manager import FewShotSampler
from lighteval.logging.hierarchical_logger import hlog, hlog_warn
Expand Down Expand Up @@ -232,6 +234,32 @@ def __init__( # noqa: C901
def cfg(self):
return self._cfg

def print_config(self):
    """Print this task's configuration as a markdown key/value table.

    Flattens the `self.cfg` dataclass via `asdict`. The "metric" entry (a
    list of per-metric dicts) gets one row per metric field; bound methods
    are rendered by qualified name, other callables by name, and every
    other value via repr().
    """
    md_writer = MarkdownTableWriter()
    md_writer.headers = ["Key", "Value"]

    rows = []
    for key, value in asdict(self.cfg).items():
        if key == "metric":
            # One row per field of each metric, indexed so multiple
            # metrics remain distinguishable in the table.
            for ix, metric in enumerate(value):
                for metric_key, metric_value in metric.items():
                    if inspect.ismethod(metric_value):
                        rows.append([f"{key} {ix}: {metric_key}", metric_value.__qualname__])
                    else:
                        rows.append([f"{key} {ix}: {metric_key}", repr(metric_value)])
        elif callable(value):
            rows.append([key, value.__name__])
        else:
            rows.append([key, repr(value)])

    md_writer.value_matrix = rows
    print(md_writer.dumps())

def doc_to_text_without_instructions(self, doc: Doc) -> str:
"""
Returns the query of the document without the instructions. If the
Expand Down
Loading