Skip to content

Commit

Permalink
Add BBH (#7)
Browse files Browse the repository at this point in the history
* manage logprobs for one token, when several tasks have different number of choices
* add bbh to test suite
  • Loading branch information
clefourrier authored Mar 8, 2024
1 parent bca2b1d commit e324a83
Show file tree
Hide file tree
Showing 8 changed files with 588 additions and 14 deletions.
2 changes: 0 additions & 2 deletions src/lighteval/main_accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def main(args):
model_config = create_model_config(args=args, accelerator=accelerator)

with htrack_block("Model loading"):
# We need to load the model in the main process first to avoid downloading the model multiple times
with accelerator.main_process_first() if accelerator is not None else nullcontext():
model, model_info = load_model(config=model_config, env_config=env_config)
evaluation_tracker.general_config_logger.log_model_info(model_info)
Expand All @@ -84,7 +83,6 @@ def main(args):
task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict(
task_names_list, custom_tasks=args.custom_tasks
)
# Loading all the dataset in a distributed manner
LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes)

evaluation_tracker.task_config_logger.log(task_dict)
Expand Down
7 changes: 5 additions & 2 deletions src/lighteval/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import torch
import torch.nn.functional as F
import transformers
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
Expand Down Expand Up @@ -834,9 +835,11 @@ def _loglikelihood_single_token(
# Sync all
# Need reshape before gather
batched_inputs, len_inputs = self.pad_and_gather(prepared_batch.input_ids)
batch_probs = torch.stack(batch_probs)
# We sometimes have different tasks with a different number of choices.
# Padding to -10000 makes sure that we won't reach index problems later as all log probs will be smaller than that
batch_probs = pad_sequence(batch_probs, batch_first=True, padding_value=-10000000)
batch_probs, len_probs = self.pad_and_gather(batch_probs)
batch_cont_tokens = torch.stack(batch_cont_tokens)
batch_cont_tokens = pad_sequence(batch_cont_tokens, batch_first=True, padding_value=-10000000)
batch_cont_tokens, len_cont = self.pad_and_gather(batch_cont_tokens)

# No reshape
Expand Down
25 changes: 18 additions & 7 deletions src/lighteval/tasks/lighteval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from datasets import DownloadMode, load_dataset
from datasets import load_dataset

from lighteval.few_shot_manager import FewShotSampler
from lighteval.logging.hierarchical_logger import hlog, hlog_warn
Expand Down Expand Up @@ -108,6 +108,8 @@ class LightevalTaskConfig:

trust_dataset: bool = None

must_remove_duplicate_docs: bool = None

def as_dict(self):
return {
"name": self.name,
Expand Down Expand Up @@ -213,6 +215,9 @@ def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str]
self.generation_size = cfg.generation_size
self.stop_sequence = cfg.stop_sequence
self.output_regex = cfg.output_regex
self.must_remove_duplicate_docs = cfg.must_remove_duplicate_docs
if self.must_remove_duplicate_docs is None:
self.must_remove_duplicate_docs = False

# Save options
self.save_queries: bool = False
Expand Down Expand Up @@ -318,6 +323,14 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
docs.extend(as_list(cur_docs))
return docs

def remove_duplicate_docs(self, docs: list[Doc]) -> list[Doc]:
seen_examples, res = set(), []
for doc in docs:
if str(doc) not in seen_examples:
res.append(doc)
seen_examples.add(str(doc))
return res

def fewshot_docs(self) -> list[Doc]:
"""
Returns the few shot documents. If the few shot documents are not
Expand Down Expand Up @@ -346,6 +359,8 @@ def eval_docs(self) -> list[Doc]:
"""
if self._docs is None:
self._docs = self._get_docs_from_split(self.evaluation_split)
if self.must_remove_duplicate_docs:
self._docs = self.remove_duplicate_docs(self._docs)
return self._docs

def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str:
Expand All @@ -360,12 +375,8 @@ def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str:
Returns:
str: Target of the document, which is the correct answer for a document.
"""
if few_shot:
if formatted_doc.target_for_fewshot_sorting is not None:
return formatted_doc.target_for_fewshot_sorting

# likely we mostly need one example not all
return formatted_doc.get_golds()[0]
return as_list(formatted_doc.get_golds(few_shot=few_shot))[0]

# Requests
def get_request_type(self) -> list[RequestType]:
Expand Down Expand Up @@ -572,7 +583,7 @@ def download_dataset_worker(args):
name=dataset_config_name,
data_dir=None,
cache_dir=None,
download_mode=DownloadMode.FORCE_REDOWNLOAD, # None
download_mode=None,
trust_remote_code=trust_dataset,
)
return dataset
Expand Down
17 changes: 14 additions & 3 deletions src/lighteval/tasks/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import dataclass
import json
from dataclasses import asdict, dataclass
from enum import Enum, auto
from typing import NamedTuple, Optional, Union

Expand Down Expand Up @@ -175,12 +176,22 @@ class Doc:
num_asked_few_shots: int = -1
num_effective_few_shots: int = -1

def get_golds(self):
def get_golds(self, few_shot: bool = False):
"""Return gold targets extracted from the target dict"""
gold_indices = as_list(self.gold_index)
if few_shot and self.target_for_fewshot_sorting is not None:
choices = self.target_for_fewshot_sorting
if isinstance(choices, str): # correct choice is already selected
return choices
else:
choices = self.choices
golds = []
for gold_ix in gold_indices:
local_golds = as_list(self.choices[gold_ix])
local_golds = as_list(choices[gold_ix])
for local_gold in local_golds:
golds.append(local_gold)
return golds

def __repr__(self):
doc_dict = asdict(self)
return json.dumps(doc_dict)
235 changes: 235 additions & 0 deletions src/lighteval/tasks/tasks_prompt_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
import re
import string

import numpy as np
import pycountry

from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.tasks.requests import Doc
from lighteval.utils import as_list

Expand Down Expand Up @@ -137,6 +139,239 @@ def process_path(path: str) -> str:
return queries


def bbh_harness(line, task_name: str = None):
line = {k: v for k, v in line.items() if v not in [None, ""]}

task_prefix = line.get("task_prefix", "")
example_input_prefix = line.get("example_input_prefix", "\nQ: ")
query = f"{task_prefix}{example_input_prefix}{line['input']}"

rng = np.random.RandomState(seed=42)
choice_prefix = line.get("choice_prefix", "\n choice: ")
append_choices = bool(line.get("append_choices", True))
# default
correct_index = line["target_idx"]
choices = line["choices"]
if append_choices:
choices = list(rng.permutation(sorted(line["choices"])))
query = f"{query}{choice_prefix}{choice_prefix.join(choices)}"
gold_item = line["choices"][line["target_idx"]]
correct_index = choices.index(gold_item)

example_output_prefix = line.get("example_output_prefix", "\nA: ")
query = f"{query}{example_output_prefix}"

return Doc(
task_name=task_name,
query=query,
choices=choices,
gold_index=correct_index,
target_for_fewshot_sorting=choices,
instruction=line.get("task_prefix", None),
)


def bbh_lighteval(line, task_name: str = None):
line = {k: v for k, v in line.items() if v is not None}

query = line.get("task_prefix", "")
query += line.get("example_input_prefix", "\nQuestion: ")
query += line["input"]
query += line.get("choice_prefix", "\n Choices: ")
query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])])
query += line.get("example_output_prefix", "\nAnswer: ")

return Doc(
task_name=task_name,
query=query,
choices=LETTER_INDICES[: len(line["choices"])],
gold_index=line["target_idx"],
target_for_fewshot_sorting=LETTER_INDICES[: len(line["choices"])],
instruction=line.get("task_prefix", None),
)


def bbh(line, instruction, choices, task_name: str = None):
return Doc(
task_name=task_name,
query=f"{instruction}Q: {line['input']}\nA:",
choices=choices,
gold_index=choices.index(line["target"]),
target_for_fewshot_sorting=[f" {c}" for c in choices],
instruction=instruction,
)


def bbh_boolean_expressions(line, task_name: str = None):
instruction = "Evaluate the result of a random Boolean expression.\n\n"
choices = ["False", "True"]
return bbh(line, instruction, choices, task_name)


def bbh_causal_judgment(line, task_name: str = None):
instruction = "Answer questions about causal attribution.\n\n"
choices = ["Yes", "No"]
return bbh(line, instruction, choices, task_name)


def bbh_date_understanding(line, task_name: str = None):
instruction = "Infer the date from context.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:6]]
return bbh(line, instruction, choices, task_name)


def bbh_disambiguation_qa(line, task_name: str = None):
instruction = "Clarify the meaning of sentences with ambiguous pronouns.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:3]]
return bbh(line, instruction, choices, task_name)


def bbh_dyck_languages(line, task_name: str = None): # Can only be done in generative
instruction = "Correctly close a Dyck-n word.\n\n"
choices = [line["target"]]
return bbh(line, instruction, choices, task_name)


def bbh_formal_fallacies(line, task_name: str = None):
instruction = "Distinguish deductively valid arguments from formal fallacies.\n\n"
choices = ["valid", "invalid"]
return bbh(line, instruction, choices, task_name)


def bbh_geometric_shapes(line, task_name: str = None):
instruction = "Name geometric shapes from their SVG paths.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:11]]
return bbh(line, instruction, choices, task_name)


def bbh_hyperbaton(line, task_name: str = None):
instruction = "Order adjectives correctly in English sentences.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:2]]
return bbh(line, instruction, choices, task_name)


def bbh_logical_deduction_five_objects(line, task_name: str = None):
instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:5]]
return bbh(line, instruction, choices, task_name)


def bbh_logical_deduction_seven_objects(line, task_name: str = None):
instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:7]]
return bbh(line, instruction, choices, task_name)


def bbh_logical_deduction_three_objects(line, task_name: str = None):
instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:3]]
return bbh(line, instruction, choices, task_name)


def bbh_movie_recommendation(line, task_name: str = None):
if line["target"] == "Monsters, Inc": # this line is not correctly formatted
hlog_warn("One sample removed from task bbh:movie_recommentation because its line is incorrectly formatted.")
return []
instruction = "Recommend movies similar to the given list of movies.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:6]]
return bbh(line, instruction, choices, task_name)


def bbh_multistep_arithmetic_two(line, task_name: str = None):
instruction = "Solve multi-step arithmetic problems.\n\n" # Can only be done in generative
choices = [line["target"]]
return bbh(line, instruction, choices, task_name)


def bbh_navigate(line, task_name: str = None):
instruction = (
"Given a series of navigation instructions, determine whether one would end up back at the starting point.\n\n"
)
choices = ["Yes", "No"]
return bbh(line, instruction, choices, task_name)


def bbh_object_counting(line, task_name: str = None):
instruction = "Questions that involve enumerating objects and asking the model to count them.\n\n"
choices = [str(i) for i in range(1, 19)]
return bbh(line, instruction, choices, task_name)


def bbh_penguins_in_a_table(line, task_name: str = None):
instruction = "Answer questions about a table of penguins and their attributes.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:5]]
return bbh(line, instruction, choices, task_name)


def bbh_reasoning_about_colored_objects(line, task_name: str = None):
instruction = "Answer extremely simple questions about the colors of objects on a surface.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:18]]
return bbh(line, instruction, choices, task_name)


def bbh_ruin_names(line, task_name: str = None):
if line["target"] in ["dearth, wind, & fire", "rita, sue and bob poo"]: # line not correctly formatted
hlog_warn("One sample removed from task bbh:ruin_names because its line is incorrectly formatted.")
return []
instruction = "Select the humorous edit that 'ruins' the input movie or musical artist name.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:6]]
return bbh(line, instruction, choices, task_name)


def bbh_salient_translation_error_detection(line, task_name: str = None):
instruction = "Detect the type of error in an English translation of a German source sentence.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:6]]
return bbh(line, instruction, choices, task_name)


def bbh_snarks(line, task_name: str = None):
instruction = 'Determine which of two sentences is sarcastic.\n\nAccording to Cambridge University Dictionary, sarcasm is "the use of remarks that clearly mean the opposite of what they say, made in order to hurt someone\'s feelings or to criticize something in a humorous way." Sarcastic sentences often contain satirical or ironic utterances, hyperboles, ambivalent or witty remarks.\n\n'
choices = [f"({c})" for c in LETTER_INDICES[:2]]
return bbh(line, instruction, choices, task_name)


def bbh_sports_understanding(line, task_name: str = None):
instruction = "Determine whether an artificially constructed sentence relating to sports is plausible or not.\n\n"
choices = ["yes", "no"]
return bbh(line, instruction, choices, task_name)


def bbh_temporal_sequences(line, task_name: str = None):
instruction = "Task description: Answer questions about which times certain events could have occurred.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:4]]
return bbh(line, instruction, choices, task_name)


def bbh_tracking_shuffled_objects_five_objects(line, task_name: str = None):
instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:5]]
return bbh(line, instruction, choices, task_name)


def bbh_tracking_shuffled_objects_seven_objects(line, task_name: str = None):
instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:7]]
return bbh(line, instruction, choices, task_name)


def bbh_tracking_shuffled_objects_three_objects(line, task_name: str = None):
instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n"
choices = [f"({c})" for c in LETTER_INDICES[:3]]
return bbh(line, instruction, choices, task_name)


def bbh_web_of_lies(line, task_name: str = None):
instruction = "Evaluate a random boolean function expressed as a word problem.\n\n"
choices = ["Yes", "No"]
return bbh(line, instruction, choices, task_name)


def bbh_word_sorting(line, task_name: str = None):
instruction = "Sort a list of words.\n\n" # Can only be done in generative
choices = [line["target"]]
return bbh(line, instruction, choices, task_name)


def bbq(line, task_name: str = None): # HELM
query = f"The following are multiple choice questions (with answers).\nPassage: {line['context']}\nQuestion: {line['question']}"
query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])])
Expand Down
Loading

0 comments on commit e324a83

Please sign in to comment.