Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
thomwolf committed Feb 7, 2024
1 parent 9e63d63 commit ddd7fac
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 28 deletions.
2 changes: 1 addition & 1 deletion run_evals_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_parser():
parser.add_argument("--override_batch_size", type=int, default=-1)
parser.add_argument("--dataset_loading_processes", type=int, default=1)
parser.add_argument(
"--custom_tasks_file",
"--custom_tasks",
type=str,
default=None,
help="Path to a file with custom tasks (a TASK list of dict and potentially prompt formating functions)",
Expand Down
2 changes: 1 addition & 1 deletion run_evals_nanotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_parser():
parser.add_argument(
"--cache-dir",
type=str,
default="",
default=None,
help="Cache directory",
)

Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/main_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def main(args):
with accelerator.main_process_first() if accelerator is not None else nullcontext():
task_names_list, few_shots_dict = taskinfo_selector(args.tasks)
task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict(
task_names_list, custom_tasks_file=args.custom_tasks_file
task_names_list, custom_tasks=args.custom_tasks
)
# Loading all the datasets in a distributed manner
LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes)
Expand Down
8 changes: 4 additions & 4 deletions src/lighteval/main_nanotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
def main(
checkpoint_config_path: str,
lighteval_config_path: Optional[str] = None,
cache_dir: str = None,
cache_dir: Optional[str] = None,
config_cls: Type = Config,
model_config_cls: Optional[Type] = None,
model_cls: Optional[Type] = None,
Expand Down Expand Up @@ -109,14 +109,14 @@ def main(
with htrack_block("Tasks loading"):
with local_ranks_zero_first():
tasks_selection = lighteval_config.tasks.tasks
if lighteval_config.tasks.custom_tasks_file:
_, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks_file)
if lighteval_config.tasks.custom_tasks:
_, tasks_groups_dict = get_custom_tasks(lighteval_config.tasks.custom_tasks)
if tasks_groups_dict and lighteval_config.tasks.tasks in tasks_groups_dict:
tasks_selection = tasks_groups_dict[lighteval_config.tasks.tasks]

task_names_list, few_shots_dict = taskinfo_selector(tasks_selection)
task_dict = Registry(cache_dir=cache_dir).get_task_dict(
task_names_list, custom_tasks_file=lighteval_config.tasks.custom_tasks_file
task_names_list, custom_tasks=lighteval_config.tasks.custom_tasks
)
# Loading all the datasets in a distributed manner
LightevalTask.load_datasets(task_dict.values(), lighteval_config.tasks.dataset_loading_processes)
Expand Down
18 changes: 18 additions & 0 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,24 @@ class CustomEvaluationTask:
frozen: bool = False
suite: Optional[Tuple[str]] = None # we use this to know if we should use a custom lighteval or bigcode task

def as_dict(self):
    """Serialize this task definition to a plain dict.

    Every field is copied through unchanged except ``metric``, which is
    converted to a tuple of strings so the result contains only simple,
    comparable values. Insertion order matches the declared field order.
    """
    entries = [
        ("name", self.name),
        ("prompt_function", self.prompt_function),
        ("hf_repo", self.hf_repo),
        ("hf_subset", self.hf_subset),
        # Metrics may be enum-like objects; keep only their string form.
        ("metric", tuple(str(m) for m in self.metric)),
        ("hf_avail_splits", self.hf_avail_splits),
        ("evaluation_splits", self.evaluation_splits),
        ("few_shots_split", self.few_shots_split),
        ("few_shots_select", self.few_shots_select),
        ("generation_size", self.generation_size),
        ("stop_sequence", self.stop_sequence),
        ("output_regex", self.output_regex),
        ("frozen", self.frozen),
        ("suite", self.suite),
    ]
    return dict(entries)

def __post_init__(self):
if self.suite is None:
self.suite = ["custom"]
Expand Down
46 changes: 33 additions & 13 deletions src/lighteval/tasks/registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import collections
import importlib
import os
from pathlib import Path
from pprint import pformat
from types import ModuleType
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

from datasets import Dataset
from datasets.load import dataset_module_factory
Expand Down Expand Up @@ -68,32 +69,38 @@ def get_task_class(
)

def get_task_dict(
self, task_name_list: List[str], custom_tasks_file: Optional[str] = None
self, task_name_list: List[str], custom_tasks: Optional[Union[str, ModuleType]] = None
) -> Dict[str, LightevalTask]:
"""
Get a dictionary of tasks based on the task name list.
Args:
task_name_list (List[str]): A list of task names.
custom_tasks_file (Optional[str]): Path to the custom tasks file.
custom_tasks (Optional[Union[str, ModuleType]]): Path to the custom tasks file, the name of an importable module containing custom tasks, or the module itself
Returns:
Dict[str, LightevalTask]: A dictionary containing the tasks.
Notes:
- If custom_tasks_file is provided, it will import the custom tasks module and create a custom tasks registry.
- If custom_tasks is provided, it will import the custom tasks module and create a custom tasks registry.
- Each task in the task_name_list will be instantiated with the corresponding task class.
"""
if custom_tasks_file is not None:
dataset_module = dataset_module_factory(str(custom_tasks_file))
custom_tasks_module = importlib.import_module(dataset_module.module_path)
custom_tasks_registry = None
if custom_tasks is not None:
if isinstance(custom_tasks, ModuleType):
custom_tasks_module = custom_tasks
elif isinstance(custom_tasks, (str, Path)) and os.path.exists(custom_tasks):
dataset_module = dataset_module_factory(str(custom_tasks))
custom_tasks_module = importlib.import_module(dataset_module.module_path)
elif isinstance(custom_tasks, (str, Path)):
custom_tasks_module = importlib.import_module(custom_tasks)
else:
raise ValueError(f"Cannot import custom tasks from {custom_tasks}")
if custom_tasks_module is not None:
custom_tasks_registry = create_config_tasks(
meta_table=custom_tasks_module.TASKS_TABLE, cache_dir=self.cache_dir
)
hlog(custom_tasks_registry)
else:
custom_tasks_module = None
custom_tasks_registry = None

tasks_dict = {}
for task_name in task_name_list:
Expand All @@ -103,9 +110,22 @@ def get_task_dict(
return tasks_dict


def get_custom_tasks(custom_tasks_file: str) -> Tuple[ModuleType, str]:
dataset_module = dataset_module_factory(str(custom_tasks_file))
custom_tasks_module = importlib.import_module(dataset_module.module_path)
def get_custom_tasks(custom_tasks: Optional[Union[str, ModuleType]] = None) -> Tuple[ModuleType, str]:
"""Get custom tasks from the given custom tasks file or module.
Args:
custom_tasks (Optional[Union[str, ModuleType]]): Path to the custom tasks file, the name of an importable module containing custom tasks, or the module itself
"""
if custom_tasks is not None:
if isinstance(custom_tasks, ModuleType):
custom_tasks_module = custom_tasks
elif isinstance(custom_tasks, (str, Path)) and os.path.exists(custom_tasks):
dataset_module = dataset_module_factory(str(custom_tasks))
custom_tasks_module = importlib.import_module(dataset_module.module_path)
elif isinstance(custom_tasks, (str, Path)):
custom_tasks_module = importlib.import_module(custom_tasks)
else:
raise ValueError(f"Cannot import custom tasks from {custom_tasks}")
tasks_string = ""
if hasattr(custom_tasks_module, "TASKS_GROUPS"):
tasks_string = custom_tasks_module.TASKS_GROUPS
Expand Down
12 changes: 6 additions & 6 deletions tasks_examples/custom_tasks/custom_evaluation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def preprocess(text):
prompt_function="triviaqa",
hf_repo="trivia_qa",
hf_subset="rc.nocontext",
metric=[Metrics.quasi_exact_match2],
metric=[Metrics.quasi_exact_match],
generation_size=20,
stop_sequence=["\n", ".", ","],
),
Expand All @@ -145,7 +145,7 @@ def preprocess(text):
prompt_function="natural_questions_prompt",
hf_repo="lighteval/natural_questions_clean",
hf_subset="default",
metric=[Metrics.quasi_exact_match2],
metric=[Metrics.quasi_exact_match],
generation_size=20,
stop_sequence=["\n", ".", ","],
),
Expand Down Expand Up @@ -182,7 +182,7 @@ def natural_questions_prompt(line, task_name: str = None):
prompt_function="quac",
hf_repo="lighteval/quac_helm",
hf_subset="deault",
metric=[Metrics.quasi_exact_match2],
metric=[Metrics.quasi_exact_match],
generation_size=20,
stop_sequence=["\n", ".", ","],
),
Expand Down Expand Up @@ -213,7 +213,7 @@ def __init__(
prompt_function="math",
hf_repo="lighteval/MATH",
hf_subset=None,
metric=[Metrics.math_quasi_exact_match],
metric=[Metrics.quasi_exact_match_math],
hf_avail_splits=None,
evaluation_splits=["test"],
few_shots_split=None,
Expand Down Expand Up @@ -553,7 +553,7 @@ def __init__(
name="agi_eval:math",
hf_subset="math",
prompt_function="agi_eval_math_prompt",
metric=[Metrics.exact_match, Metrics.quasi_exact_match2],
metric=[Metrics.exact_match, Metrics.quasi_exact_match],
generation_size=40,
),
CustomAGIEvalEvaluationTask(name="agi_eval:sat-en", hf_subset="sat-en"),
Expand Down Expand Up @@ -628,7 +628,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None):
EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING])

# Convert to dict for lighteval
TASKS_TABLE = [asdict(task) for task in _TASKS]
TASKS_TABLE = [task.as_dict() for task in _TASKS]
# You can have a few pre-organised groups of tasks
TASKS_GROUPS = {
"all": ",".join(t[1] for t in _TASKS_STRINGS),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ lighteval:
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
tasks:
custom_tasks_file: /fsx/thomwolf/github/lighteval/tasks_examples/custom_tasks/custom_evaluation_tasks.py
custom_tasks: /fsx/thomwolf/github/lighteval/tasks_examples/custom_tasks/custom_evaluation_tasks.py
dataset_loading_processes: 8
max_samples: 1000
max_samples: 10
multichoice_continuations_start_space: null
no_multichoice_continuations_start_space: null
num_fewshot_seeds: null
Expand Down

0 comments on commit ddd7fac

Please sign in to comment.