Last PR to make custom tasks work for everyone (#23)
A small one to get into the release; we can iterate on it later if needed.

---------

Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
4 people authored Feb 8, 2024
1 parent f8bd2ab commit bc0b8bc
Showing 11 changed files with 177 additions and 275 deletions.
2 changes: 1 addition & 1 deletion run_evals_accelerate.py
@@ -71,7 +71,7 @@ def get_parser():
     parser.add_argument("--override_batch_size", type=int, default=-1)
     parser.add_argument("--dataset_loading_processes", type=int, default=1)
     parser.add_argument(
-        "--custom_tasks_file",
+        "--custom_tasks",
         type=str,
         default=None,
         help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)",
2 changes: 1 addition & 1 deletion run_evals_nanotron.py
@@ -20,7 +20,7 @@ def get_parser():
     parser.add_argument(
         "--cache-dir",
         type=str,
-        default="",
+        default=None,
         help="Cache directory",
     )

12 changes: 8 additions & 4 deletions src/lighteval/logging/info_loggers.py
@@ -69,8 +69,12 @@ class GeneralConfigLogger:

     def __init__(self) -> None:
         """Stores the current lighteval commit for reproducibility, and starts the evaluation timer."""
-        repo = git.Repo(os.path.dirname(__file__).split("src")[0])
-        self.lighteval_sha = repo.git.rev_parse("HEAD")
+        try:
+            repo = git.Repo(os.path.dirname(__file__).split("src")[0])
+        except git.InvalidGitRepositoryError:
+            repo = None
+
+        self.lighteval_sha = repo.git.rev_parse("HEAD") if repo is not None else "?"
         self.start_time = time.perf_counter()

     def log_args_info(
@@ -543,5 +547,5 @@ def log(self, task_dict: dict[str, LightevalTask]) -> None:
         self.tasks_configs = {name: task.cfg for name, task in task_dict.items()}

     def log_num_docs(self, task_name: str, original_num_docs: int, effective_num_docs: int) -> None:
-        self.tasks_configs[task_name]["original_num_docs"] = original_num_docs
-        self.tasks_configs[task_name]["effective_num_docs"] = effective_num_docs
+        self.tasks_configs[task_name].original_num_docs = original_num_docs
+        self.tasks_configs[task_name].effective_num_docs = effective_num_docs
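
The second hunk works because `task.cfg` is now a `LightevalTaskConfig` instance (the dataclass added in `lighteval_task.py` below) rather than a plain dict, and ordinary dataclass instances accept new attributes at runtime. A minimal, standalone illustration of the behaviour being relied on (not lighteval code):

```python
from dataclasses import dataclass


@dataclass
class TaskCfg:  # stand-in for LightevalTaskConfig
    name: str


cfg = TaskCfg(name="my_task")
# Dataclasses are not slotted by default, so new attributes can be attached,
# which is what log_num_docs now does on each task config.
cfg.original_num_docs = 1000
cfg.effective_num_docs = 980
print(cfg.original_num_docs, cfg.effective_num_docs)  # 1000 980
```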
2 changes: 1 addition & 1 deletion src/lighteval/main_accelerate.py
@@ -60,7 +60,7 @@ def main(args):
     with accelerator.main_process_first() if accelerator is not None else nullcontext():
         task_names_list, few_shots_dict = taskinfo_selector(args.tasks)
         task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict(
-            task_names_list, custom_tasks_file=args.custom_tasks_file
+            task_names_list, custom_tasks=args.custom_tasks
         )
         # Loading all the dataset in a distributed manner
         LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes)
8 changes: 4 additions & 4 deletions src/lighteval/main_nanotron.py
@@ -38,7 +38,7 @@
 def main(
     checkpoint_config_path: str,
     lighteval_config_path: Optional[str] = None,
-    cache_dir: str = None,
+    cache_dir: Optional[str] = None,
     config_cls: Type = Config,
     model_config_cls: Optional[Type] = None,
     model_cls: Optional[Type] = None,
@@ -109,14 +109,14 @@ def main(
     with htrack_block("Tasks loading"):
         with local_ranks_zero_first():
             tasks_selection = lighteval_config.tasks.tasks
-            if lighteval_config.tasks.custom_tasks_file:
-                _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
+            if lighteval_config.tasks.custom_tasks:
+                _, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks)
             if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
                 tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]

             task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
             task_dict = Registry(cache_dir=cache_dir).get_task_dict(
-                task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
+                task_names_list, custom_tasks=lighteval_config.tasks.custom_tasks
             )
             # Loading all the dataset in a distributed manner
             LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)
6 changes: 2 additions & 4 deletions src/lighteval/models/base_model.py
@@ -22,9 +22,7 @@
     LoglikelihoodSingleTokenRequest,
     Request,
 )
-from lighteval.utils import (
-    is_accelerate_available,
-)
+from lighteval.utils import as_list, is_accelerate_available
 from lighteval.utils_parallelism import find_executable_batch_size


@@ -342,7 +340,7 @@ def greedy_until(
             list[GenerateReturn]: list of generated responses.
         """
         for request in requests:
-            request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
+            request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]
             request.tokenized_context = self.tok_encode(request.context)

         dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
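
The `as_list` call matters because `stop_sequence` now comes from `LightevalTaskConfig`, which stores it as a tuple, and concatenating a tuple directly with a list raises a `TypeError`. A standalone sketch of the behaviour (the stand-in `as_list` below only approximates the real helper in `lighteval.utils`):

```python
def as_list(item):
    # Stand-in for lighteval.utils.as_list; assumed to coerce tuples and scalars to lists.
    if isinstance(item, list):
        return item
    if isinstance(item, (tuple, set)):
        return list(item)
    return [item]


stop_sequence = ("\n",)  # the tuple form LightevalTaskConfig now stores
eos_token = "</s>"

# stop_sequence + [eos_token] would raise:
# TypeError: can only concatenate tuple (not "list") to tuple
stops = as_list(stop_sequence) + [eos_token]
print(stops)  # ['\n', '</s>']
```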
97 changes: 76 additions & 21 deletions src/lighteval/tasks/lighteval_task.py
@@ -1,8 +1,9 @@
 import collections
 import random
+from dataclasses import dataclass
 from multiprocessing import Pool
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 from datasets import load_dataset

@@ -39,8 +40,62 @@
     from lighteval.logging.evaluation_tracker import EvaluationTracker


+@dataclass
+class LightevalTaskConfig:
+    name: str
+    prompt_function: str
+    hf_repo: str
+    hf_subset: str
+    metric: Tuple[Union[str, Metrics]]
+    hf_avail_splits: Optional[Tuple[str]] = None
+    evaluation_splits: Optional[Tuple[str]] = None
+    few_shots_split: Optional[str] = None
+    few_shots_select: Optional[str] = None
+    generation_size: int = -1
+    stop_sequence: Optional[Tuple[str]] = None
+    output_regex: Optional[str] = None
+
+    frozen: bool = False
+    suite: Optional[Tuple[str]] = None  # we use this to know if we should use a custom lighteval or bigcode task
+
+    def as_dict(self):
+        return {
+            "name": self.name,
+            "prompt_function": self.prompt_function,
+            "hf_repo": self.hf_repo,
+            "hf_subset": self.hf_subset,
+            "metric": tuple(str(m) for m in self.metric),
+            "hf_avail_splits": self.hf_avail_splits,
+            "evaluation_splits": self.evaluation_splits,
+            "few_shots_split": self.few_shots_split,
+            "few_shots_select": self.few_shots_select,
+            "generation_size": self.generation_size,
+            "stop_sequence": self.stop_sequence,
+            "output_regex": self.output_regex,
+            "frozen": self.frozen,
+            "suite": self.suite,
+        }
+
+    def __post_init__(self):
+        if self.suite is None:
+            self.suite = ["custom"]
+        if self.hf_avail_splits is None:
+            self.hf_avail_splits = ["train", "validation", "test"]
+        if self.evaluation_splits is None:
+            self.evaluation_splits = ["validation"]
+        if self.stop_sequence is None:
+            self.stop_sequence = ["\n"]
+
+        # Convert list to tuple for hashing
+        self.metric = tuple(self.metric)
+        self.hf_avail_splits = tuple(self.hf_avail_splits) if self.hf_avail_splits is not None else None
+        self.evaluation_splits = tuple(self.evaluation_splits) if self.evaluation_splits is not None else None
+        self.suite = tuple(self.suite) if self.suite is not None else None
+        self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None
+
+
 class LightevalTask:
-    def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom_tasks_module=None):
+    def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module=None):
         """
         Initialize a LightEval task.
@@ -60,8 +115,8 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         self._cfg = cfg

         # Dataset info
-        self.hf_repo = cfg["hf_repo"]
-        self.hf_subset = cfg["hf_subset"]
+        self.hf_repo = cfg.hf_repo
+        self.hf_subset = cfg.hf_subset
         self.dataset_path = self.hf_repo
         self.dataset_config_name = self.hf_subset
         self.dataset = None  # Delayed download
@@ -70,22 +125,22 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         self._docs = None

         # Managing splits and few shot
-        self.all_available_splits = as_list(cfg["hf_avail_splits"])
-        if cfg.get("evaluation_splits", None) is None:
+        self.all_available_splits = as_list(cfg.hf_avail_splits)
+        if cfg.evaluation_splits is None:
             raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")

-        self.evaluation_split = as_list(cfg["evaluation_splits"])
-        if cfg.get("few_shots_split", None) is not None:
-            self.fewshot_split = as_list(cfg["few_shots_split"])
+        self.evaluation_split = as_list(cfg.evaluation_splits)
+        if cfg.few_shots_split is not None:
+            self.fewshot_split = as_list(cfg.few_shots_split)
         else:
             self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
         self.fewshot_sampler = FewShotSampler(
-            few_shots_select=cfg["few_shots_select"], few_shots_split=self.fewshot_split
+            few_shots_select=cfg.few_shots_select, few_shots_split=self.fewshot_split
         )

         # Metrics
-        self.metrics = as_list(cfg["metric"])
-        self.suite = as_list(cfg["suite"])
+        self.metrics = as_list(cfg.metric)
+        self.suite = as_list(cfg.suite)
         ignored = [metric for metric in self.metrics if Metrics[metric].value.category == MetricCategory.IGNORED]
         if len(ignored) > 0:
             hlog_warn(f"[WARNING] Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.")
@@ -95,20 +150,20 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
         # Data processing
         # to use once prompt formatting is managed as a module
         if custom_tasks_module is None:
-            self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
-        elif hasattr(custom_tasks_module, cfg["prompt_function"]):
+            self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
+        elif hasattr(custom_tasks_module, cfg.prompt_function):
             # If we have a prompt in both the custom_tasks_module and our tasks_prompt_formatting
             # We take the prompt from the custom_tasks_module
-            if hasattr(tasks_prompt_formatting, cfg["prompt_function"]):
+            if hasattr(tasks_prompt_formatting, cfg.prompt_function):
                 hlog_warn(
-                    f"Be careful you are using custom prompt function {cfg['prompt_function']} and not the default one."
+                    f"Be careful you are using custom prompt function {cfg.prompt_function} and not the default one."
                 )
-            self.formatter = getattr(custom_tasks_module, cfg["prompt_function"])
+            self.formatter = getattr(custom_tasks_module, cfg.prompt_function)
         else:
-            self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
-        self.generation_size = cfg["generation_size"]
-        self.stop_sequence = cfg["stop_sequence"]
-        self.output_regex = cfg["output_regex"]
+            self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
+        self.generation_size = cfg.generation_size
+        self.stop_sequence = cfg.stop_sequence
+        self.output_regex = cfg.output_regex

         # Save options
         self.save_queries: bool = False
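
As a usage sketch of the new config class (the dataset coordinates and metric name below are placeholders, and this assumes a lighteval installation where the module is importable): `__post_init__` fills in defaults and normalises list fields to tuples, and `as_dict()` returns a plain dictionary close to the old dict-based format.

```python
from lighteval.tasks.lighteval_task import LightevalTaskConfig

cfg = LightevalTaskConfig(
    name="my_task",
    prompt_function="prompt_fn",   # resolved by name via getattr() in LightevalTask.__init__
    hf_repo="my_org/my_dataset",   # placeholder dataset coordinates
    hf_subset="default",
    metric=["loglikelihood_acc"],  # must name a member of the Metrics enum
    evaluation_splits=["test"],
    stop_sequence=["\n"],
)

# __post_init__ converts the list fields to tuples ("Convert list to tuple for hashing")...
assert cfg.metric == ("loglikelihood_acc",)
assert cfg.evaluation_splits == ("test",)
# ...and fills in defaults for fields left as None.
assert cfg.suite == ("custom",)
assert cfg.hf_avail_splits == ("train", "validation", "test")

# as_dict() gives back a plain dict, with metric entries stringified.
print(cfg.as_dict()["metric"])  # ('loglikelihood_acc',)
```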
(Diffs for the remaining 4 changed files are not shown.)
