Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ability to use functions for prompt definition #207

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ repos:

- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: 'v0.1.6'
rev: 'v0.2.2'
hooks:
- id: ruff
args: ['--fix']
Expand Down
2 changes: 2 additions & 0 deletions src/lighteval/logging/evaluation_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o):
    """Serialize objects json cannot handle natively.

    Dataclass instances are converted to plain dicts; callables are
    represented by their name. Anything else is deferred to the base
    encoder (which raises TypeError for unsupported types).
    """
    if is_dataclass(o):
        return asdict(o)
    return o.__name__ if callable(o) else super().default(o)


Expand Down
67 changes: 41 additions & 26 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dataclasses import dataclass
from multiprocessing import Pool
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

from datasets import load_dataset

Expand Down Expand Up @@ -61,6 +61,8 @@
if TYPE_CHECKING:
from lighteval.logging.evaluation_tracker import EvaluationTracker

FormatterType = Callable[[dict, str], Doc]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you include a string?
I don't see which prompt formatter takes a string as input



@dataclass
class LightevalTaskConfig:
Expand Down Expand Up @@ -89,7 +91,7 @@ class LightevalTaskConfig:
"""

name: str
prompt_function: str
prompt_function: FormatterType | str
hf_repo: str
hf_subset: str
metric: Tuple[Union[str, Metrics]]
Expand Down Expand Up @@ -149,6 +151,38 @@ def __post_init__(self):
self.stop_sequence = tuple(self.stop_sequence) if self.stop_sequence is not None else None


def load_prompt_function(prompt_function: str, custom_tasks_module: list | None) -> FormatterType:
    """
    Tries to load the prompt function defined as string.

    Custom task modules are searched first; the built-in
    `tasks_prompt_formatting` module is used as a fallback when no custom
    module defines a function with this name.

    Arguments:
        prompt_function (str): Name of the prompt function to load.
        custom_tasks_module (list | None): List of custom modules to search for the prompt function.

    Returns:
        FormatterType: The prompt function.

    Raises:
        ValueError: If the prompt function is defined in several custom modules (ambiguous).
        AttributeError: If the function exists neither in the custom modules nor in the defaults.
    """
    if custom_tasks_module is None:
        return getattr(tasks_prompt_formatting, prompt_function)

    # Collect every custom module that defines a function with this name.
    formatter = [getattr(module, prompt_function) for module in custom_tasks_module if hasattr(module, prompt_function)]

    if len(formatter) == 0:  # Default version
        return getattr(tasks_prompt_formatting, prompt_function)

    if len(formatter) > 1:
        # Ambiguous: the same name is exported by several custom modules.
        raise ValueError(
            f"You defined the prompt function {prompt_function} several times in the different custom modules you are loading."
        )

    # Exactly one custom definition: it shadows the default one, so warn
    # when a same-named default also exists.
    if hasattr(tasks_prompt_formatting, prompt_function):
        hlog_warn(f"Be careful you are using custom prompt function {prompt_function} and not the default one.")
    return formatter[0]


class LightevalTask:
def __init__( # noqa: C901
self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] = None, custom_tasks_module: list = None
Expand Down Expand Up @@ -209,31 +243,12 @@ def __init__( # noqa: C901
self.num_samples = [1] + [
int(metric.replace("maj_at_", "").split("_")[0]) for metric in self.metrics if "maj_at_" in metric
]

# Data processing
# to use once prompt formatting is managed as a module
if custom_tasks_module is None:
self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
self.formatter: FormatterType
if isinstance(cfg.prompt_function, str):
self.formatter = load_prompt_function(cfg.prompt_function, custom_tasks_module)
else:
formatter = []
for module in custom_tasks_module:
if hasattr(module, cfg.prompt_function):
formatter.append(getattr(module, cfg.prompt_function))

if len(formatter) == 0: # Default version
self.formatter = getattr(tasks_prompt_formatting, cfg.prompt_function)
elif len(formatter) == 1:
# If we have a prompt in both the module and our tasks_prompt_formatting
# We take the prompt from the module
if hasattr(tasks_prompt_formatting, cfg.prompt_function):
hlog_warn(
f"Be careful you are using custom prompt function {cfg.prompt_function} and not the default one."
)
self.formatter = formatter[0]
else:
raise Exception(
f"You defined the prompt function {cfg.prompt_function} several times in the different custom modules you are loading."
)
self.formatter = cfg.prompt_function

self.generation_size = cfg.generation_size
self.stop_sequence = cfg.stop_sequence
self.output_regex = cfg.output_regex
Expand Down
Loading