
Commit be9b745
now use the config class for everyone
clefourrier committed Feb 8, 2024
1 parent a39d6a3 commit be9b745
Showing 3 changed files with 43 additions and 43 deletions.
38 changes: 19 additions & 19 deletions src/lighteval/tasks/lighteval_task.py
@@ -41,7 +41,7 @@


@dataclass
-class CustomEvaluationTask:
+class LightevalTaskConfig:
name: str
prompt_function: str
hf_repo: str
@@ -95,7 +95,7 @@ def __post_init__(self):


class LightevalTask:
-def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom_tasks_module=None):
+def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module=None):
"""
Initialize a LightEval task.
@@ -115,8 +115,8 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
self._cfg = cfg

# Dataset info
self.hf_repo = cfg["hf_repo"]
self.hf_subset = cfg["hf_subset"]
self.hf_repo = cfg.hf_repo
self.hf_subset = cfg.hf_subset
self.dataset_path = self.hf_repo
self.dataset_config_name = self.hf_subset
self.dataset = None # Delayed download
@@ -125,22 +125,22 @@
self._docs = None

# Managing splits and few shot
self.all_available_splits = as_list(cfg["hf_avail_splits"])
self.all_available_splits = as_list(cfg.hf_avail_splits)
if cfg.get("evaluation_splits", None) is None:
raise ValueError(f"The evaluation split for task {self.name} is None. Please select a valid split.")

self.evaluation_split = as_list(cfg["evaluation_splits"])
self.evaluation_split = as_list(cfg.evaluation_splits)
if cfg.get("few_shots_split", None) is not None:
self.fewshot_split = as_list(cfg["few_shots_split"])
self.fewshot_split = as_list(cfg.few_shots_split)
else:
self.fewshot_split = as_list(self.get_first_possible_fewshot_splits())
self.fewshot_sampler = FewShotSampler(
few_shots_select=cfg["few_shots_select"], few_shots_split=self.fewshot_split
few_shots_select=cfg.few_shots_select, few_shots_split=self.fewshot_split
)

# Metrics
self.metrics = as_list(cfg["metric"])
self.suite = as_list(cfg["suite"])
self.metrics = as_list(cfg.metric)
self.suite = as_list(cfg.suite)
ignored = [metric for metric in self.metrics if Metrics[metric].value.category == MetricCategory.IGNORED]
if len(ignored) > 0:
hlog_warn(f"[WARNING] Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.")
@@ -150,20 +150,20 @@ def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom
# Data processing
# to use once prompt formatting is managed as a module
if custom_tasks_module is None:
self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
elif hasattr(custom_tasks_module, cfg["prompt_function"]):
self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
elif hasattr(custom_tasks_module, cfg.prompt_function):
# If we have a prompt in both the custom_tasks_module and our tasks_prompt_formatting
# We take the prompt from the custom_tasks_module
if hasattr(tasks_prompt_formatting, cfg["prompt_function"]):
if hasattr(tasks_prompt_formatting, cfg.prompt_function):
hlog_warn(
f"Be careful you are using custom prompt function {cfg['prompt_function']} and not the default one."
f"Be careful you are using custom prompt function {cfg.prompt_function} and not the default one."
)
self.formatter = getattr(custom_tasks_module, cfg["prompt_function"])
self.formatter = getattr(custom_tasks_module, cfg.prompt_function)
else:
self.formatter = getattr(tasks_prompt_formatting, cfg["prompt_function"])
self.generation_size = cfg["generation_size"]
self.stop_sequence = cfg["stop_sequence"]
self.output_regex = cfg["output_regex"]
self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
self.generation_size = cfg.generation_size
self.stop_sequence = cfg.stop_sequence
self.output_regex = cfg.output_regex

# Save options
self.save_queries: bool = False
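To make the change above concrete, here is a minimal, hypothetical sketch of the dict-to-dataclass switch in LightevalTask.__init__: a stand-in config dataclass (not lighteval's real class; field names are taken from the diff, defaults are illustrative assumptions) is read by attribute access instead of cfg["key"] lookups.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TaskConfigSketch:
    # Stand-in for lighteval.tasks.lighteval_task.LightevalTaskConfig.
    # Field names mirror the diff above; the defaults here are assumptions.
    name: str
    prompt_function: str
    hf_repo: str
    hf_subset: str
    metric: List[str] = field(default_factory=list)
    hf_avail_splits: List[str] = field(default_factory=lambda: ["train", "validation"])
    evaluation_splits: List[str] = field(default_factory=lambda: ["validation"])
    generation_size: Optional[int] = None
    stop_sequence: Optional[List[str]] = None


cfg = TaskConfigSketch(
    name="hellaswag",
    prompt_function="hellaswag_prompt",
    hf_repo="hellaswag",
    hf_subset="default",
    metric=["loglikelihood_acc"],
)

# Old style (dict config):   hf_repo = cfg["hf_repo"]
# New style (config object): hf_repo = cfg.hf_repo
print(cfg.hf_repo, cfg.evaluation_splits)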
6 changes: 3 additions & 3 deletions src/lighteval/tasks/registry.py
@@ -10,7 +10,7 @@
from datasets.load import dataset_module_factory

from lighteval.logging.hierarchical_logger import hlog, hlog_warn
-from lighteval.tasks.lighteval_task import LightevalTask
+from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig


# original is the reimplementation of original evals
@@ -202,7 +202,7 @@ def create_config_tasks(
Dict[str, LightevalTask]: A dictionary of task names mapped to their corresponding LightevalTask classes.
"""

-def create_task(name, cfg, cache_dir):
+def create_task(name, cfg: LightevalTaskConfig, cache_dir: str):
class LightevalTaskFromConfig(LightevalTask):
def __init__(self, custom_tasks_module=None):
super().__init__(name, cfg, cache_dir=cache_dir, custom_tasks_module=custom_tasks_module)
@@ -222,6 +222,6 @@ def __init__(self, custom_tasks_module=None):
continue
for suite in line["suite"]:
if suite in DEFAULT_SUITES:
tasks_with_config[f"{suite}|{line['name']}"] = line
tasks_with_config[f"{suite}|{line['name']}"] = LightevalTaskConfig(**line)

return {task: create_task(task, cfg, cache_dir=cache_dir) for task, cfg in tasks_with_config.items()}
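The registry now wraps each raw task entry into a config object via keyword expansion instead of passing the dict around. A small self-contained sketch of that pattern follows; the dataclass, the DEFAULT_SUITES value, and the example entry are illustrative stand-ins, not lighteval's actual registry data.

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class TaskConfigSketch:
    # Stand-in for LightevalTaskConfig; only the fields used below.
    name: str
    prompt_function: str
    hf_repo: str
    hf_subset: str
    suite: List[str]


DEFAULT_SUITES = ["lighteval"]  # assumed value for this sketch

# A raw registry row as a plain dict (illustrative values).
line = {
    "name": "hellaswag",
    "prompt_function": "hellaswag_prompt",
    "hf_repo": "hellaswag",
    "hf_subset": "default",
    "suite": ["lighteval"],
}

tasks_with_config: Dict[str, TaskConfigSketch] = {}
for suite in line["suite"]:
    if suite in DEFAULT_SUITES:
        # The key change in this commit: store a typed config, not the raw dict.
        tasks_with_config[f"{suite}|{line['name']}"] = TaskConfigSketch(**line)

print(tasks_with_config)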
42 changes: 21 additions & 21 deletions tasks_examples/custom_tasks/custom_evaluation_tasks.py
@@ -9,53 +9,53 @@
from typing import Dict, List, Tuple

from lighteval.metrics import Metrics
-from lighteval.tasks.lighteval_task import CustomEvaluationTask
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES


-_TASKS_STRINGS: List[Tuple[CustomEvaluationTask, str]] = []
-_TASKS: List[CustomEvaluationTask] = []
+_TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = []
+_TASKS: List[LightevalTaskConfig] = []

## COMMON_SENSE_REASONING_TASKS ##
COMMON_SENSE_REASONING_TASKS = [
-CustomEvaluationTask(
+LightevalTaskConfig(
name="hellaswag",
prompt_function="hellaswag_prompt",
hf_repo="hellaswag",
hf_subset="default",
metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="winogrande",
prompt_function="winogrande",
hf_repo="winogrande",
hf_subset="winogrande_xl",
metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="piqa",
prompt_function="piqa_harness",
hf_repo="piqa",
hf_subset="plain_text",
metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="siqa",
prompt_function="siqa_prompt",
hf_repo="lighteval/siqa",
hf_subset="default",
hf_avail_splits=["train", "validation"],
metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="openbookqa",
prompt_function="openbookqa",
hf_repo="openbookqa",
hf_subset="main",
metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="arc:easy",
prompt_function="arc",
hf_repo="ai2_arc",
@@ -64,7 +64,7 @@
generation_size=1,
metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="arc:challenge",
prompt_function="arc",
hf_repo="ai2_arc",
@@ -73,7 +73,7 @@
generation_size=1,
metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="commonsense_qa",
prompt_function="commonsense_qa_prompt",
hf_repo="commonsense_qa",
@@ -131,7 +131,7 @@ def preprocess(text):
## WORLD_KNOWLEDGE_TASKS ##

WORLD_KNOWLEDGE_TASKS = [
-CustomEvaluationTask(
+LightevalTaskConfig(
name="trivia_qa",
prompt_function="triviaqa",
hf_repo="trivia_qa",
@@ -140,7 +140,7 @@ def preprocess(text):
generation_size=20,
stop_sequence=["\n", ".", ","],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="natural_questions",
prompt_function="natural_questions_prompt",
hf_repo="lighteval/natural_questions_clean",
@@ -170,14 +170,14 @@ def natural_questions_prompt(line, task_name: str = None):
## Reading comprehension ##

READING_COMP_TASKS = [
-CustomEvaluationTask(
+LightevalTaskConfig(
name="super_glue:boolq",
prompt_function="boolq_prompt",
hf_repo="super_glue",
hf_subset="boolq",
metric=["target_perplexity"],
),
-CustomEvaluationTask(
+LightevalTaskConfig(
name="quac",
prompt_function="quac",
hf_repo="lighteval/quac_helm",
@@ -204,7 +204,7 @@ def boolq_prompt(line, task_name: str = None):


## MATH ##
-class CustomMathEvaluationTask(CustomEvaluationTask):
+class CustomMathEvaluationTask(LightevalTaskConfig):
"""Custom class for math tasks with all the defaults set"""

def __init__(
@@ -251,7 +251,7 @@ def __init__(
CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"),
CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"),
]
-GSM8K = CustomEvaluationTask(
+GSM8K = LightevalTaskConfig(
name="gsm8k",
prompt_function="gsm8k",
hf_repo="gsm8k",
@@ -272,7 +272,7 @@ def __init__(


## MMLU ##
-class CustomMMLUEvaluationTask(CustomEvaluationTask):
+class CustomMMLUEvaluationTask(LightevalTaskConfig):
def __init__(
self,
name,
@@ -415,7 +415,7 @@ def mmlu_prompt(line, task_name: str = None):
## BBH ##


-class CustomBBHEvaluationTask(CustomEvaluationTask):
+class CustomBBHEvaluationTask(LightevalTaskConfig):
def __init__(
self,
name,
@@ -506,7 +506,7 @@ def bbh_prompt(line, task_name: str = None):


## AGI eval ##
-class CustomAGIEvalEvaluationTask(CustomEvaluationTask):
+class CustomAGIEvalEvaluationTask(LightevalTaskConfig):
def __init__(
self,
name,
@@ -617,7 +617,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None):


## HUMAN EVAL ##
-# human_eval = CustomEvaluationTask(
+# human_eval = LightevalTaskConfig(
# name="human_eval",
# prompt_function="human_eval",
# hf_repo="lighteval/human_eval",
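With the rename, a custom task file declares its tasks directly as LightevalTaskConfig instances. The snippet below is a sketch that simply mirrors one entry from the diff above (the siqa task); it assumes lighteval at this commit is importable and uses only fields already shown in the file.

from typing import List

from lighteval.tasks.lighteval_task import LightevalTaskConfig

_TASKS: List[LightevalTaskConfig] = []

# Mirrors the siqa entry from COMMON_SENSE_REASONING_TASKS above.
siqa = LightevalTaskConfig(
    name="siqa",
    prompt_function="siqa_prompt",
    hf_repo="lighteval/siqa",
    hf_subset="default",
    hf_avail_splits=["train", "validation"],
    metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"],
)
_TASKS.append(siqa)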
