Skip to content

Commit

Permalink
Merge branch 'simplify_task_system' into hynek_function
Browse files Browse the repository at this point in the history
  • Loading branch information
clefourrier authored Jul 9, 2024
2 parents 9f518ad + da016b7 commit 8f98337
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 49 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,17 @@ python run_evals_accelerate.py \
--output_dir "./evals"
```

### Using the dummy model
To debug or obtain random baseline scores for a given set of tasks, you can use the `dummy` model:
```shell
python run_evals_accelerate.py \
--model_args "dummy"\
--tasks <task parameters> \
--output_dir output_dir
```
This "model" randomly generates logprobs (for selection/accuracy tasks) and the string "random baseline" for generation tasks.
You can also select a specific seed for the random logprob values generated by the dummy model: `--model_args "dummy,seed=123"`.

## Deep thanks
`lighteval` was originally built on top of the great [Eleuther AI Harness](https://github.com/EleutherAI/lm-evaluation-harness) (we use the latter to power the [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard)). We also took a lot of inspiration from the amazing [HELM](https://crfm.stanford.edu/helm/latest/), notably for metrics.

Expand Down
37 changes: 17 additions & 20 deletions src/lighteval/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,27 +530,7 @@ def greedy_until(
returns_logits = batch[0].use_logits
num_samples = batch[0].num_samples

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of loosing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context = [c.context for c in batch]
smallest_context = min(len(c) for c in context)
biggest_context = max(len(c) for c in context)
if smallest_context > self.max_length:
hlog_warn(
f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)

if (
biggest_context > self.max_length
): # There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)

# See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
# Will do left truncation and padding, as defined when creating the tokenizer
Expand All @@ -563,6 +543,23 @@ def greedy_until(
add_special_tokens=self.add_special_tokens,
).to(self.device)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of losing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context_size = tokenized["input_ids"].shape[1]
if context_size > self.max_length:
hlog_warn(
f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)
# There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - context_size, max_new_tokens)

prepared_batch = Batch(
input_ids=tokenized["input_ids"],
input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
Expand Down
89 changes: 89 additions & 0 deletions src/lighteval/models/dummy_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# MIT License
#
# Copyright (c) 2024 The HuggingFace Team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# inspired by https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/dummy.py

import random
from typing import Optional

from transformers import AutoTokenizer

from lighteval.models.abstract_model import LightevalModel
from lighteval.models.model_config import DummyModelConfig, EnvConfig
from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
from lighteval.tasks.requests import (
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)


class DummyModel(LightevalModel):
"""Dummy model to generate random baselines."""

def __init__(
self,
config: DummyModelConfig,
env_config: EnvConfig,
):
self.config = config
self.env_config = env_config
self._random = random.Random(self.config.seed)
self._tokenizer = None

@property
def tokenizer(self):
if not self._tokenizer:
self._tokenizer = AutoTokenizer.from_pretrained("gpt2")
return self._tokenizer

@property
def add_special_tokens(self):
return False

@property
def max_length(self) -> int:
return 2048

def greedy_until(
self, requests: list[GreedyUntilRequest], override_bs: Optional[int] = None
) -> list[GenerateReturn]:
return [GenerateReturn(result="random baseline") for _ in range(len(requests))]

def loglikelihood(
self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodReturn]:
return [LoglikelihoodReturn((-self._random.random(), False)) for _ in requests]

def loglikelihood_rolling(
self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodReturn]:
return [LoglikelihoodReturn((-self._random.random(), False)) for _ in requests]

def loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodSingleTokenReturn]:
return [
LoglikelihoodSingleTokenReturn(result=[-self._random.random() for _ in req.tokenized_continuation])
for req in requests
]
24 changes: 21 additions & 3 deletions src/lighteval/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ class TGIModelConfig:
model_id: str


@dataclass
class DummyModelConfig:
seed: int = 42


@dataclass
class InferenceModelConfig:
model: str
Expand Down Expand Up @@ -253,7 +258,16 @@ def nullable_keys() -> list[str]:
return ["namespace", "env_vars", "image_url"]


def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig: # noqa: C901
def create_model_config( # noqa: C901
args: Namespace, accelerator: Union["Accelerator", None]
) -> Union[
BaseModelConfig,
AdapterModelConfig,
DeltaModelConfig,
TGIModelConfig,
InferenceEndpointModelConfig,
DummyModelConfig,
]:
"""
Create a model configuration based on the provided arguments.
Expand All @@ -262,7 +276,7 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
accelerator (Union[Accelerator, None]): accelerator to use for model training.
Returns:
BaseModelConfig: model configuration.
Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, InferenceEndpointModelConfig, DummyModelConfig]: model configuration.
Raises:
ValueError: If both an inference server address and model arguments are provided.
Expand All @@ -271,7 +285,11 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
ValueError: If a base model is specified when not using delta weights or adapter weights.
"""
if args.model_args:
args_dict = {k.split("=")[0]: k.split("=")[1] for k in args.model_args.split(",")}
args_dict = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.model_args.split(",")}

if args_dict.pop("dummy", False):
return DummyModelConfig(**args_dict)

args_dict["accelerator"] = accelerator
args_dict["use_chat_template"] = args.use_chat_template

Expand Down
20 changes: 18 additions & 2 deletions src/lighteval/models/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
from lighteval.models.adapter_model import AdapterModel
from lighteval.models.base_model import BaseModel
from lighteval.models.delta_model import DeltaModel
from lighteval.models.dummy_model import DummyModel
from lighteval.models.endpoint_model import InferenceEndpointModel
from lighteval.models.model_config import (
AdapterModelConfig,
BaseModelConfig,
DeltaModelConfig,
DummyModelConfig,
EnvConfig,
InferenceEndpointModelConfig,
InferenceModelConfig,
Expand All @@ -54,9 +56,16 @@ class ModelInfo:


def load_model( # noqa: C901
config: Union[BaseModelConfig, AdapterModelConfig, DeltaModelConfig, TGIModelConfig, InferenceEndpointModelConfig],
config: Union[
BaseModelConfig,
AdapterModelConfig,
DeltaModelConfig,
TGIModelConfig,
InferenceEndpointModelConfig,
DummyModelConfig,
],
env_config: EnvConfig,
) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient], ModelInfo]:
) -> Tuple[Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel], ModelInfo]:
"""Will load either a model from an inference server or a model from a checkpoint, depending
on the config type.
Expand All @@ -82,6 +91,9 @@ def load_model( # noqa: C901
if isinstance(config, BaseModelConfig):
return load_model_with_accelerate_or_default(config=config, env_config=env_config)

if isinstance(config, DummyModelConfig):
return load_dummy_model(config=config, env_config=env_config)


def load_model_with_tgi(config: TGIModelConfig):
if not is_tgi_available():
Expand Down Expand Up @@ -143,3 +155,7 @@ def load_model_with_accelerate_or_default(
hlog(f"Model info: {model_info}")

return model, model_info


def load_dummy_model(config: DummyModelConfig, env_config: EnvConfig):
return DummyModel(config=config, env_config=env_config), ModelInfo(model_name="dummy", model_sha=str(config.seed))
4 changes: 2 additions & 2 deletions src/lighteval/models/model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class ModelReturn:
result: Union[tuple, list, str]
input_tokens: list[int] = field(default_factory=list) # model inputs
generated_tokens: list[int] = field(default_factory=list) # model generations
truncated_tokens_count: Optional[int] = None # How many tokens truncated
padded_tokens_count: Optional[int] = None # How many tokens of padding
truncated_tokens_count: Optional[int] = 0 # How many tokens truncated
padded_tokens_count: Optional[int] = 0 # How many tokens of padding

def get_result_for_eval(self):
raise NotImplementedError()
Expand Down
39 changes: 18 additions & 21 deletions src/lighteval/models/nanotron_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,27 +1207,7 @@ def greedy_until(
"Nonotron models does not allow sampling evaluations - this is likely to fail or provide problematic results"
)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of loosing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context = [c.context for c in batch] # or tokenized context?
smallest_context = min(len(c) for c in context)
biggest_context = max(len(c) for c in context)
if smallest_context > self.max_length:
hlog_warn(
f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)

if (
biggest_context > self.max_length
): # There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)
context = [c.context for c in batch]

# See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
# Will do left truncation and padding, as defined when creating the tokenizer
Expand All @@ -1240,6 +1220,23 @@ def greedy_until(
add_special_tokens=self.add_special_tokens,
).to(self.device)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of losing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
context_size = tokenized["input_ids"].shape[1]
if context_size > self.max_length:
hlog_warn(
f"The context size of your batch ({context_size}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+ str({i.task_name for i in batch})
+ ". This is likely to lead to some errors." # noqa C401
)
# There will be truncation of at least one sample, maximum generation size will be one
max_new_tokens = 1
else: # We can't allow generation of more than max_length
max_new_tokens = min(self.max_length - context_size, max_new_tokens)

batch_model = Batch(
input_ids=tokenized["input_ids"],
input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
Expand Down
12 changes: 12 additions & 0 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,18 @@ def construct_requests(
generation_size=self.generation_size,
)
]
if self.has_metric_category[MetricCategory.LLM_AS_JUDGE]:
requests[RequestType.GREEDY_UNTIL] += [
GreedyUntilRequest(
task_name=current_task_name,
example_index=document_id_seed,
request_index=0,
context=context,
stop_sequence=self.stop_sequence,
generation_size=self.generation_size,
num_samples=1,
)
]

return requests

Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/tasks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def create_config_tasks(
Create configuration tasks based on the provided meta_table.
Args:
meta_table: meta_table containing task
meta_table: meta_table containing tasks
configurations. If not provided, it will be loaded from TABLE_PATH.
cache_dir: Directory to store cached data. If not
provided, the default cache directory will be used.
Expand Down

0 comments on commit 8f98337

Please sign in to comment.