diff --git a/src/lighteval/main_accelerate.py b/src/lighteval/main_accelerate.py index f106485f..3718b7ae 100644 --- a/src/lighteval/main_accelerate.py +++ b/src/lighteval/main_accelerate.py @@ -73,7 +73,6 @@ def main(args): model_config = create_model_config(args=args, accelerator=accelerator) with htrack_block("Model loading"): - # We need to load the model in the main process first to avoid downloading the model multiple times with accelerator.main_process_first() if accelerator is not None else nullcontext(): model, model_info = load_model(config=model_config, env_config=env_config) evaluation_tracker.general_config_logger.log_model_info(model_info) @@ -84,7 +83,6 @@ def main(args): task_dict = Registry(cache_dir=env_config.cache_dir).get_task_dict( task_names_list, custom_tasks=args.custom_tasks ) - # Loading all the dataset in a distributed manner LightevalTask.load_datasets(task_dict.values(), args.dataset_loading_processes) evaluation_tracker.task_config_logger.log(task_dict) diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py index 556d9f51..db913c04 100644 --- a/src/lighteval/models/base_model.py +++ b/src/lighteval/models/base_model.py @@ -26,6 +26,7 @@ import torch import torch.nn.functional as F import transformers +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer @@ -834,9 +835,11 @@ def _loglikelihood_single_token( # Sync all # Need reshape before gather batched_inputs, len_inputs = self.pad_and_gather(prepared_batch.input_ids) - batch_probs = torch.stack(batch_probs) + # We sometimes have different tasks with a different number of choices. + # Padding with -10000000 makes sure that we won't reach index problems later, as all real log probs will be greater than the padding value + batch_probs = pad_sequence(batch_probs, batch_first=True, padding_value=-10000000) batch_probs, len_probs = self.pad_and_gather(batch_probs) - batch_cont_tokens = torch.stack(batch_cont_tokens) + batch_cont_tokens = pad_sequence(batch_cont_tokens, batch_first=True, padding_value=-10000000) batch_cont_tokens, len_cont = self.pad_and_gather(batch_cont_tokens) # No reshape diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index a94b5995..3f310270 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -27,7 +27,7 @@ from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Tuple, Union -from datasets import DownloadMode, load_dataset +from datasets import load_dataset from lighteval.few_shot_manager import FewShotSampler from lighteval.logging.hierarchical_logger import hlog, hlog_warn @@ -108,6 +108,8 @@ class LightevalTaskConfig: trust_dataset: bool = None + must_remove_duplicate_docs: bool = None + def as_dict(self): return { "name": self.name, @@ -213,6 +215,9 @@ def __init__(self, name: str, cfg: LightevalTaskConfig, cache_dir: Optional[str] self.generation_size = cfg.generation_size self.stop_sequence = cfg.stop_sequence self.output_regex = cfg.output_regex + self.must_remove_duplicate_docs = cfg.must_remove_duplicate_docs + if self.must_remove_duplicate_docs is None: + self.must_remove_duplicate_docs = False # Save options self.save_queries: bool = False @@ -318,6 +323,14 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]: docs.extend(as_list(cur_docs)) return docs + def remove_duplicate_docs(self, docs: list[Doc]) -> list[Doc]: + 
seen_examples, res = set(), [] + for doc in docs: + if str(doc) not in seen_examples: + res.append(doc) + seen_examples.add(str(doc)) + return res + def fewshot_docs(self) -> list[Doc]: """ Returns the few shot documents. If the few shot documents are not @@ -346,6 +359,8 @@ def eval_docs(self) -> list[Doc]: """ if self._docs is None: self._docs = self._get_docs_from_split(self.evaluation_split) + if self.must_remove_duplicate_docs: + self._docs = self.remove_duplicate_docs(self._docs) return self._docs def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str: @@ -360,12 +375,8 @@ def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str: Returns: str: Target of the document, which is the correct answer for a document. """ - if few_shot: - if formatted_doc.target_for_fewshot_sorting is not None: - return formatted_doc.target_for_fewshot_sorting - # likely we mostly need one example not all - return formatted_doc.get_golds()[0] + return as_list(formatted_doc.get_golds(few_shot=few_shot))[0] # Requests def get_request_type(self) -> list[RequestType]: @@ -572,7 +583,7 @@ def download_dataset_worker(args): name=dataset_config_name, data_dir=None, cache_dir=None, - download_mode=DownloadMode.FORCE_REDOWNLOAD, # None + download_mode=None, trust_remote_code=trust_dataset, ) return dataset diff --git a/src/lighteval/tasks/requests.py b/src/lighteval/tasks/requests.py index 9f97680c..b1fd6062 100644 --- a/src/lighteval/tasks/requests.py +++ b/src/lighteval/tasks/requests.py @@ -20,7 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from dataclasses import dataclass +import json +from dataclasses import asdict, dataclass from enum import Enum, auto from typing import NamedTuple, Optional, Union @@ -175,12 +176,22 @@ class Doc: num_asked_few_shots: int = -1 num_effective_few_shots: int = -1 - def get_golds(self): + def get_golds(self, few_shot: bool = False): """Return gold targets extracted from the target dict""" gold_indices = as_list(self.gold_index) + if few_shot and self.target_for_fewshot_sorting is not None: + choices = self.target_for_fewshot_sorting + if isinstance(choices, str): # correct choice is already selected + return choices + else: + choices = self.choices golds = [] for gold_ix in gold_indices: - local_golds = as_list(self.choices[gold_ix]) + local_golds = as_list(choices[gold_ix]) for local_gold in local_golds: golds.append(local_gold) return golds + + def __repr__(self): + doc_dict = asdict(self) + return json.dumps(doc_dict) diff --git a/src/lighteval/tasks/tasks_prompt_formatting.py b/src/lighteval/tasks/tasks_prompt_formatting.py index a29bc65e..08f088b2 100644 --- a/src/lighteval/tasks/tasks_prompt_formatting.py +++ b/src/lighteval/tasks/tasks_prompt_formatting.py @@ -26,8 +26,10 @@ import re import string +import numpy as np import pycountry +from lighteval.logging.hierarchical_logger import hlog_warn from lighteval.tasks.requests import Doc from lighteval.utils import as_list @@ -137,6 +139,239 @@ def process_path(path: str) -> str: return queries +def bbh_harness(line, task_name: str = None): + line = {k: v for k, v in line.items() if v not in [None, ""]} + + task_prefix = line.get("task_prefix", "") + example_input_prefix = line.get("example_input_prefix", "\nQ: ") + query = f"{task_prefix}{example_input_prefix}{line['input']}" + + rng = np.random.RandomState(seed=42) + choice_prefix = line.get("choice_prefix", "\n choice: ") + append_choices = bool(line.get("append_choices", True)) + 
# default + correct_index = line["target_idx"] + choices = line["choices"] + if append_choices: + choices = list(rng.permutation(sorted(line["choices"]))) + query = f"{query}{choice_prefix}{choice_prefix.join(choices)}" + gold_item = line["choices"][line["target_idx"]] + correct_index = choices.index(gold_item) + + example_output_prefix = line.get("example_output_prefix", "\nA: ") + query = f"{query}{example_output_prefix}" + + return Doc( + task_name=task_name, + query=query, + choices=choices, + gold_index=correct_index, + target_for_fewshot_sorting=choices, + instruction=line.get("task_prefix", None), + ) + + +def bbh_lighteval(line, task_name: str = None): + line = {k: v for k, v in line.items() if v is not None} + + query = line.get("task_prefix", "") + query += line.get("example_input_prefix", "\nQuestion: ") + query += line["input"] + query += line.get("choice_prefix", "\n Choices: ") + query += "".join([f"\n{key}. {choice}" for key, choice in zip(LETTER_INDICES, line["choices"])]) + query += line.get("example_output_prefix", "\nAnswer: ") + + return Doc( + task_name=task_name, + query=query, + choices=LETTER_INDICES[: len(line["choices"])], + gold_index=line["target_idx"], + target_for_fewshot_sorting=LETTER_INDICES[: len(line["choices"])], + instruction=line.get("task_prefix", None), + ) + + +def bbh(line, instruction, choices, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{instruction}Q: {line['input']}\nA:", + choices=choices, + gold_index=choices.index(line["target"]), + target_for_fewshot_sorting=[f" {c}" for c in choices], + instruction=instruction, + ) + + +def bbh_boolean_expressions(line, task_name: str = None): + instruction = "Evaluate the result of a random Boolean expression.\n\n" + choices = ["False", "True"] + return bbh(line, instruction, choices, task_name) + + +def bbh_causal_judgment(line, task_name: str = None): + instruction = "Answer questions about causal attribution.\n\n" + choices = ["Yes", "No"] + return bbh(line, instruction, choices, task_name) + + +def bbh_date_understanding(line, task_name: str = None): + instruction = "Infer the date from context.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:6]] + return bbh(line, instruction, choices, task_name) + + +def bbh_disambiguation_qa(line, task_name: str = None): + instruction = "Clarify the meaning of sentences with ambiguous pronouns.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:3]] + return bbh(line, instruction, choices, task_name) + + +def bbh_dyck_languages(line, task_name: str = None): # Can only be done in generative + instruction = "Correctly close a Dyck-n word.\n\n" + choices = [line["target"]] + return bbh(line, instruction, choices, task_name) + + +def bbh_formal_fallacies(line, task_name: str = None): + instruction = "Distinguish deductively valid arguments from formal fallacies.\n\n" + choices = ["valid", "invalid"] + return bbh(line, instruction, choices, task_name) + + +def bbh_geometric_shapes(line, task_name: str = None): + instruction = "Name geometric shapes from their SVG paths.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:11]] + return bbh(line, instruction, choices, task_name) + + +def bbh_hyperbaton(line, task_name: str = None): + instruction = "Order adjectives correctly in English sentences.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:2]] + return bbh(line, instruction, choices, task_name) + + +def bbh_logical_deduction_five_objects(line, task_name: str = None): + instruction = "A logical deduction task which requires deducing 
the order of a sequence of objects.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:5]] + return bbh(line, instruction, choices, task_name) + + +def bbh_logical_deduction_seven_objects(line, task_name: str = None): + instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:7]] + return bbh(line, instruction, choices, task_name) + + +def bbh_logical_deduction_three_objects(line, task_name: str = None): + instruction = "A logical deduction task which requires deducing the order of a sequence of objects.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:3]] + return bbh(line, instruction, choices, task_name) + + +def bbh_movie_recommendation(line, task_name: str = None): + if line["target"] == "Monsters, Inc": # this line is not correctly formatted + hlog_warn("One sample removed from task bbh:movie_recommendation because its line is incorrectly formatted.") + return [] + instruction = "Recommend movies similar to the given list of movies.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:6]] + return bbh(line, instruction, choices, task_name) + + +def bbh_multistep_arithmetic_two(line, task_name: str = None): + instruction = "Solve multi-step arithmetic problems.\n\n" # Can only be done in generative + choices = [line["target"]] + return bbh(line, instruction, choices, task_name) + + +def bbh_navigate(line, task_name: str = None): + instruction = ( + "Given a series of navigation instructions, determine whether one would end up back at the starting point.\n\n" + ) + choices = ["Yes", "No"] + return bbh(line, instruction, choices, task_name) + + +def bbh_object_counting(line, task_name: str = None): + instruction = "Questions that involve enumerating objects and asking the model to count them.\n\n" + choices = [str(i) for i in range(1, 19)] + return bbh(line, instruction, choices, task_name) + + +def bbh_penguins_in_a_table(line, task_name: str = None): + instruction = "Answer questions about a table of penguins and their attributes.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:5]] + return bbh(line, instruction, choices, task_name) + + +def bbh_reasoning_about_colored_objects(line, task_name: str = None): + instruction = "Answer extremely simple questions about the colors of objects on a surface.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:18]] + return bbh(line, instruction, choices, task_name) + + +def bbh_ruin_names(line, task_name: str = None): + if line["target"] in ["dearth, wind, & fire", "rita, sue and bob poo"]: # line not correctly formatted + hlog_warn("One sample removed from task bbh:ruin_names because its line is incorrectly formatted.") + return [] + instruction = "Select the humorous edit that 'ruins' the input movie or musical artist name.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:6]] + return bbh(line, instruction, choices, task_name) + + +def bbh_salient_translation_error_detection(line, task_name: str = None): + instruction = "Detect the type of error in an English translation of a German source sentence.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:6]] + return bbh(line, instruction, choices, task_name) + + +def bbh_snarks(line, task_name: str = None): + instruction = 'Determine which of two sentences is sarcastic.\n\nAccording to Cambridge University Dictionary, sarcasm is "the use of remarks that clearly mean the opposite of what they say, made in order to hurt someone\'s feelings or to criticize something in a humorous way." 
Sarcastic sentences often contain satirical or ironic utterances, hyperboles, ambivalent or witty remarks.\n\n' + choices = [f"({c})" for c in LETTER_INDICES[:2]] + return bbh(line, instruction, choices, task_name) + + +def bbh_sports_understanding(line, task_name: str = None): + instruction = "Determine whether an artificially constructed sentence relating to sports is plausible or not.\n\n" + choices = ["yes", "no"] + return bbh(line, instruction, choices, task_name) + + +def bbh_temporal_sequences(line, task_name: str = None): + instruction = "Task description: Answer questions about which times certain events could have occurred.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:4]] + return bbh(line, instruction, choices, task_name) + + +def bbh_tracking_shuffled_objects_five_objects(line, task_name: str = None): + instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:5]] + return bbh(line, instruction, choices, task_name) + + +def bbh_tracking_shuffled_objects_seven_objects(line, task_name: str = None): + instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:7]] + return bbh(line, instruction, choices, task_name) + + +def bbh_tracking_shuffled_objects_three_objects(line, task_name: str = None): + instruction = "A task requiring determining the final positions of a set of objects given their initial positions and a description of a sequence of swaps.\n\n" + choices = [f"({c})" for c in LETTER_INDICES[:3]] + return bbh(line, instruction, choices, task_name) + + +def bbh_web_of_lies(line, task_name: str = None): + instruction = "Evaluate a random boolean function expressed as a word problem.\n\n" + choices = ["Yes", "No"] + return bbh(line, instruction, choices, task_name) + + +def bbh_word_sorting(line, task_name: str = None): + instruction = "Sort a list of words.\n\n" # Can only be done in generative + choices = [line["target"]] + return bbh(line, instruction, choices, task_name) + + def bbq(line, task_name: str = None): # HELM query = f"The following are multiple choice questions (with answers).\nPassage: {line['context']}\nQuestion: {line['question']}" query += "".join([f"\n{key}. 
{choice}" for key, choice in zip(LETTER_INDICES, line["choices"])]) diff --git a/src/lighteval/tasks/tasks_table.jsonl b/src/lighteval/tasks/tasks_table.jsonl index 98434c57..f7268e12 100644 --- a/src/lighteval/tasks/tasks_table.jsonl +++ b/src/lighteval/tasks/tasks_table.jsonl @@ -28,6 +28,33 @@ {"name":"auto_categorization","suite":["bigbench","bigbench_json"],"prompt_function":"bigbench","hf_repo":"bigbench","hf_subset":"auto_categorization","hf_avail_splits":["default","train","validation"],"evaluation_splits":["default"],"few_shots_split":null,"few_shots_select":null,"generation_size":1,"metric":["bleu"],"stop_sequence":["\n"],"output_regex":null,"frozen":false, "trust_dataset": true} {"name":"auto_debugging","suite":["bigbench_lite","bigbench","bigbench_json"],"prompt_function":"bigbench_linefeed_before_and_after_query","hf_repo":"bigbench","hf_subset":"auto_debugging","hf_avail_splits":["default","train","validation"],"evaluation_splits":["default"],"few_shots_split":null,"few_shots_select":null,"generation_size":100,"metric":["perfect_exact_match"],"stop_sequence":null,"output_regex":"[^\\.\\?\\!\\;\\n]+", "trust_dataset": true} {"name":"babi_qa","suite":["helm"],"prompt_function":"babi_qa","hf_repo":"facebook\/babi_qa","hf_subset":"en-valid-qa1","hf_avail_splits":["train","test","validation"],"evaluation_splits":["validation","test"],"few_shots_split":null,"few_shots_select":null,"generation_size":-1,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match"],"stop_sequence":["\n"],"output_regex":null,"frozen":false, "trust_dataset": true} +{"name":"bbh:boolean_expressions","suite":["harness"],"prompt_function":"bbh_boolean_expressions","hf_repo":"lukaemon/bbh","hf_subset":"boolean_expressions","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:causal_judgment","suite":["harness"],"prompt_function":"bbh_causal_judgment","hf_repo":"lukaemon/bbh","hf_subset":"causal_judgement","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:date_understanding","suite":["harness"],"prompt_function":"bbh_date_understanding","hf_repo":"lukaemon/bbh","hf_subset":"date_understanding","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:disambiguation_qa","suite":["harness"],"prompt_function":"bbh_disambiguation_qa","hf_repo":"lukaemon/bbh","hf_subset":"disambiguation_qa","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} 
+{"name":"bbh:dyck_languages","suite":["harness"],"prompt_function":"bbh_dyck_languages","hf_repo":"lukaemon/bbh","hf_subset":"dyck_languages","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:formal_fallacies","suite":["harness"],"prompt_function":"bbh_formal_fallacies","hf_repo":"lukaemon/bbh","hf_subset":"formal_fallacies","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:geometric_shapes","suite":["harness"],"prompt_function":"bbh_geometric_shapes","hf_repo":"lukaemon/bbh","hf_subset":"geometric_shapes","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:hyperbaton","suite":["harness"],"prompt_function":"bbh_hyperbaton","hf_repo":"lukaemon/bbh","hf_subset":"hyperbaton","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:logical_deduction_five_objects","suite":["harness"],"prompt_function":"bbh_logical_deduction_five_objects","hf_repo":"lukaemon/bbh","hf_subset":"logical_deduction_five_objects","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:logical_deduction_seven_objects","suite":["harness"],"prompt_function":"bbh_logical_deduction_seven_objects","hf_repo":"lukaemon/bbh","hf_subset":"logical_deduction_seven_objects","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:logical_deduction_three_objects","suite":["harness"],"prompt_function":"bbh_logical_deduction_three_objects","hf_repo":"lukaemon/bbh","hf_subset":"logical_deduction_three_objects","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} 
+{"name":"bbh:movie_recommendation","suite":["harness"],"prompt_function":"bbh_movie_recommendation","hf_repo":"lukaemon/bbh","hf_subset":"movie_recommendation","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:multistep_arithmetic_two","suite":["harness"],"prompt_function":"bbh_multistep_arithmetic_two","hf_repo":"lukaemon/bbh","hf_subset":"multistep_arithmetic_two","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:navigate","suite":["harness"],"prompt_function":"bbh_navigate","hf_repo":"lukaemon/bbh","hf_subset":"navigate","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:object_counting","suite":["harness"],"prompt_function":"bbh_object_counting","hf_repo":"lukaemon/bbh","hf_subset":"object_counting","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:penguins_in_a_table","suite":["harness"],"prompt_function":"bbh_penguins_in_a_table","hf_repo":"lukaemon/bbh","hf_subset":"penguins_in_a_table","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:reasoning_about_colored_objects","suite":["harness"],"prompt_function":"bbh_reasoning_about_colored_objects","hf_repo":"lukaemon/bbh","hf_subset":"reasoning_about_colored_objects","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:ruin_names","suite":["harness"],"prompt_function":"bbh_ruin_names","hf_repo":"lukaemon/bbh","hf_subset":"ruin_names","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} 
+{"name":"bbh:salient_translation_error_detection","suite":["harness"],"prompt_function":"bbh_salient_translation_error_detection","hf_repo":"lukaemon/bbh","hf_subset":"salient_translation_error_detection","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:snarks","suite":["harness"],"prompt_function":"bbh_snarks","hf_repo":"lukaemon/bbh","hf_subset":"snarks","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:sports_understanding","suite":["harness"],"prompt_function":"bbh_sports_understanding","hf_repo":"lukaemon/bbh","hf_subset":"sports_understanding","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:temporal_sequences","suite":["harness"],"prompt_function":"bbh_temporal_sequences","hf_repo":"lukaemon/bbh","hf_subset":"temporal_sequences","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:tracking_shuffled_objects_five_objects","suite":["harness"],"prompt_function":"bbh_tracking_shuffled_objects_five_objects","hf_repo":"lukaemon/bbh","hf_subset":"tracking_shuffled_objects_five_objects","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:tracking_shuffled_objects_seven_objects","suite":["harness"],"prompt_function":"bbh_tracking_shuffled_objects_seven_objects","hf_repo":"lukaemon/bbh","hf_subset":"tracking_shuffled_objects_seven_objects","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:tracking_shuffled_objects_three_objects","suite":["harness"],"prompt_function":"bbh_tracking_shuffled_objects_three_objects","hf_repo":"lukaemon/bbh","hf_subset":"tracking_shuffled_objects_three_objects","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", 
"\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:web_of_lies","suite":["harness"],"prompt_function":"bbh_web_of_lies","hf_repo":"lukaemon/bbh","hf_subset":"web_of_lies","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} +{"name":"bbh:word_sorting","suite":["harness"],"prompt_function":"bbh_word_sorting","hf_repo":"lukaemon/bbh","hf_subset":"word_sorting","hf_avail_splits":["test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":20,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["", "Q:", "\n\n"],"output_regex":null,"frozen":false, "trust_dataset":true} {"name":"bbq","suite":["helm"],"prompt_function":"bbq","hf_repo":"lighteval\/bbq_helm","hf_subset":"all","hf_avail_splits":["train","test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":-1,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["\n"],"output_regex":null,"frozen":false, "trust_dataset": true} {"name":"bbq:Age","suite":["helm"],"prompt_function":"bbq","hf_repo":"lighteval\/bbq_helm","hf_subset":"Age","hf_avail_splits":["train","test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":-1,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["\n"],"output_regex":null,"frozen":false, "trust_dataset": true} {"name":"bbq:Disability_status","suite":["helm"],"prompt_function":"bbq","hf_repo":"lighteval\/bbq_helm","hf_subset":"Disability_status","hf_avail_splits":["train","test"],"evaluation_splits":["test"],"few_shots_split":null,"few_shots_select":null,"generation_size":-1,"metric":["exact_match","quasi_exact_match","prefix_exact_match","prefix_quasi_exact_match","perfect_exact_match"],"stop_sequence":["\n"],"output_regex":null,"frozen":false, "trust_dataset": true} diff --git a/tasks_examples/bbh.txt b/tasks_examples/bbh.txt new file mode 100644 index 00000000..6b90fa3a --- /dev/null +++ b/tasks_examples/bbh.txt @@ -0,0 +1,36 @@ +lighteval|bigbench:causal_judgment|3|0 +lighteval|bigbench:date_understanding|3|0 +lighteval|bigbench:disambiguation_qa|3|0 +lighteval|bigbench:geometric_shapes|3|0 +lighteval|bigbench:logical_deduction_five_objects|3|0 +lighteval|bigbench:logical_deduction_seven_objects|3|0 +lighteval|bigbench:logical_deduction_three_objects|3|0 +lighteval|bigbench:movie_recommendation|3|0 +lighteval|bigbench:navigate|3|0 +lighteval|bigbench:reasoning_about_colored_objects|3|0 +lighteval|bigbench:ruin_names|3|0 +lighteval|bigbench:salient_translation_error_detection|3|0 +lighteval|bigbench:snarks|3|0 +lighteval|bigbench:sports_understanding|3|0 +lighteval|bigbench:temporal_sequences|3|0 +lighteval|bigbench:tracking_shuffled_objects_five_objects|3|0 +lighteval|bigbench:tracking_shuffled_objects_seven_objects|3|0 +lighteval|bigbench:tracking_shuffled_objects_three_objects|3|0 +harness|bigbench:causal_judgment|3|0 +harness|bigbench:date_understanding|3|0 +harness|bigbench:disambiguation_qa|3|0 +harness|bigbench:geometric_shapes|3|0 
+harness|bigbench:logical_deduction_five_objects|3|0 +harness|bigbench:logical_deduction_seven_objects|3|0 +harness|bigbench:logical_deduction_three_objects|3|0 +harness|bigbench:movie_recommendation|3|0 +harness|bigbench:navigate|3|0 +harness|bigbench:reasoning_about_colored_objects|3|0 +harness|bigbench:ruin_names|3|0 +harness|bigbench:salient_translation_error_detection|3|0 +harness|bigbench:snarks|3|0 +harness|bigbench:sports_understanding|3|0 +harness|bigbench:temporal_sequences|3|0 +harness|bigbench:tracking_shuffled_objects_five_objects|3|0 +harness|bigbench:tracking_shuffled_objects_seven_objects|3|0 +harness|bigbench:tracking_shuffled_objects_three_objects|3|0 diff --git a/tests/reference_scores/reference_task_scores.py b/tests/reference_scores/reference_task_scores.py index f901a476..92695470 100644 --- a/tests/reference_scores/reference_task_scores.py +++ b/tests/reference_scores/reference_task_scores.py @@ -190,6 +190,133 @@ "pqem_stderr": 0.00428777857558616, }, "leaderboard|gsm8k|5|0": {"qem": 0.006065200909780136, "qem_stderr": 0.0021386703014604626}, + # "harness|gsm8k|5|0": {"acc": 0.004548900682335102, "acc_stderr": 0.0018535550440036204}, Actual harness results + "harness|bigbench:causal_judgment|3|0": { + "acc": 0.4842, + "acc_stderr": 0.0364, + "acc_norm": 0.4947, + "acc_norm_stderr": 0.0364, + }, + "harness|bigbench:date_understanding|3|0": { + "acc": 0.2764, + "acc_stderr": 0.0233, + "acc_norm": 0.2764, + "acc_norm_stderr": 0.0233, + }, + "harness|bigbench:disambiguation_qa|3|0": { + "acc": 0.3372, + "acc_stderr": 0.0295, + "acc_norm": 0.3450, + "acc_norm_stderr": 0.0297, + }, + "harness|bigbench:geometric_shapes|3|0": { + "acc": 0.1058, + "acc_stderr": 0.0163, + "acc_norm": 0.1476, + "acc_norm_stderr": 0.0187, + }, + "harness|bigbench:logical_deduction_five_objects|3|0": { + "acc": 0.2080, + "acc_stderr": 0.0182, + "acc_norm": 0.2120, + "acc_norm_stderr": 0.0183, + }, + "harness|bigbench:logical_deduction_seven_objects|3|0": { + "acc": 0.1743, + "acc_stderr": 0.0143, + "acc_norm": 0.1743, + "acc_norm_stderr": 0.0143, + }, + "harness|bigbench:logical_deduction_three_objects|3|0": { + "acc": 0.3033, + "acc_stderr": 0.0266, + "acc_norm": 0.3167, + "acc_norm_stderr": 0.0269, + }, + "harness|bigbench:movie_recommendation|3|0": { + "acc": 0.3900, + "acc_stderr": 0.0218, + "acc_norm": 0.3460, + "acc_norm_stderr": 0.0213, + }, + "harness|bigbench:navigate|3|0": { + "acc": 0.4990, + "acc_stderr": 0.0158, + "acc_norm": 0.5000, + "acc_norm_stderr": 0.0158, + }, + "harness|bigbench:reasoning_about_colored_objects|3|0": { + "acc": 0.1665, + "acc_stderr": 0.0083, + "acc_norm": 0.1535, + "acc_norm_stderr": 0.0081, + }, + "harness|bigbench:ruin_names|3|0": { + "acc": 0.3393, + "acc_stderr": 0.0224, + "acc_norm": 0.3237, + "acc_norm_stderr": 0.0221, + }, + "harness|bigbench:salient_translation_error_detection|3|0": { + "acc": 0.1834, + "acc_stderr": 0.0123, + "acc_norm": 0.1834, + "acc_norm_stderr": 0.0123, + }, + "harness|bigbench:snarks|3|0": { + "acc": 0.5359, + "acc_stderr": 0.0372, + "acc_norm": 0.5359, + "acc_norm_stderr": 0.0372, + }, + "harness|bigbench:sports_understanding|3|0": { + "acc": 0.5010, + "acc_stderr": 0.0159, + "acc_norm": 0.5020, + "acc_norm_stderr": 0.0159, + }, + "harness|bigbench:temporal_sequences|3|0": { + "acc": 0.2700, + "acc_stderr": 0.0140, + "acc_norm": 0.2710, + "acc_norm_stderr": 0.0141, + }, + "harness|bigbench:tracking_shuffled_objects_five_objects|3|0": { + "acc": 0.1928, + "acc_stderr": 0.0112, + "acc_norm": 0.1976, + "acc_norm_stderr": 0.0113, 
+ }, + "harness|bigbench:tracking_shuffled_objects_seven_objects|3|0": { + "acc": 0.1463, + "acc_stderr": 0.0085, + "acc_norm": 0.1383, + "acc_norm_stderr": 0.0083, + }, + "harness|bigbench:tracking_shuffled_objects_three_objects|3|0": { + "acc": 0.3033, + "acc_stderr": 0.0266, + "acc_norm": 0.3167, + "acc_norm_stderr": 0.0269, + }, + "lighteval|bigbench:causal_judgment|3|0": {"acc": 0.5158, "acc_stderr": 0.0364}, + "lighteval|bigbench:date_understanding|3|0": {"acc": 0.0000, "acc_stderr": 0.0000}, + "lighteval|bigbench:disambiguation_qa|3|0": {"acc": 0.2984, "acc_stderr": 0.0285}, + "lighteval|bigbench:geometric_shapes|3|0": {"acc": 0.0972, "acc_stderr": 0.0156}, + "lighteval|bigbench:logical_deduction_five_objects|3|0": {"acc": 0.2000, "acc_stderr": 0.0179}, + "lighteval|bigbench:logical_deduction_seven_objects|3|0": {"acc": 0.1429, "acc_stderr": 0.0132}, + "lighteval|bigbench:logical_deduction_three_objects|3|0": {"acc": 0.3333, "acc_stderr": 0.0273}, + "lighteval|bigbench:movie_recommendation|3|0": {"acc": 0.2540, "acc_stderr": 0.0195}, + "lighteval|bigbench:navigate|3|0": {"acc": 0.4990, "acc_stderr": 0.0158}, + "lighteval|bigbench:reasoning_about_colored_objects|3|0": {"acc": 0.1560, "acc_stderr": 0.0081}, + "lighteval|bigbench:ruin_names|3|0": {"acc": 0.2411, "acc_stderr": 0.0202}, + "lighteval|bigbench:salient_translation_error_detection|3|0": {"acc": 0.1673, "acc_stderr": 0.0118}, + "lighteval|bigbench:snarks|3|0": {"acc": 0.4696, "acc_stderr": 0.0372}, + "lighteval|bigbench:sports_understanding|3|0": {"acc": 0.4990, "acc_stderr": 0.0158}, + "lighteval|bigbench:temporal_sequences|3|0": {"acc": 1.0000, "acc_stderr": 0.0000}, + "lighteval|bigbench:tracking_shuffled_objects_five_objects|3|0": {"acc": 0.1976, "acc_stderr": 0.0113}, + "lighteval|bigbench:tracking_shuffled_objects_seven_objects|3|0": {"acc": 0.1406, "acc_stderr": 0.0083}, + "lighteval|bigbench:tracking_shuffled_objects_three_objects|3|0": {"acc": 0.3333, "acc_stderr": 0.0273}, }, } @@ -362,5 +489,131 @@ "pqem_stderr": 0.09999999999999999, }, "leaderboard|gsm8k|5|0": {"qem": 0.0, "qem_stderr": 0.0}, + "harness|bigbench:causal_judgment|3|0": { + "acc": 0.6000, + "acc_stderr": 0.1633, + "acc_norm": 0.5000, + "acc_norm_stderr": 0.1667, + }, + "harness|bigbench:date_understanding|3|0": { + "acc": 0.2000, + "acc_stderr": 0.1333, + "acc_norm": 0.2000, + "acc_norm_stderr": 0.1333, + }, + "harness|bigbench:disambiguation_qa|3|0": { + "acc": 0.7000, + "acc_stderr": 0.1528, + "acc_norm": 0.3000, + "acc_norm_stderr": 0.1528, + }, + "harness|bigbench:geometric_shapes|3|0": { + "acc": 0.0000, + "acc_stderr": 0.0000, + "acc_norm": 0.2000, + "acc_norm_stderr": 0.1333, + }, + "harness|bigbench:logical_deduction_five_objects|3|0": { + "acc": 0.3000, + "acc_stderr": 0.1528, + "acc_norm": 0.3000, + "acc_norm_stderr": 0.1528, + }, + "harness|bigbench:logical_deduction_seven_objects|3|0": { + "acc": 0.2000, + "acc_stderr": 0.1333, + "acc_norm": 0.2000, + "acc_norm_stderr": 0.1333, + }, + "harness|bigbench:logical_deduction_three_objects|3|0": { + "acc": 0.2000, + "acc_stderr": 0.1333, + "acc_norm": 0.3000, + "acc_norm_stderr": 0.1528, + }, + "harness|bigbench:movie_recommendation|3|0": { + "acc": 0.5000, + "acc_stderr": 0.1667, + "acc_norm": 0.3000, + "acc_norm_stderr": 0.1528, + }, + "harness|bigbench:navigate|3|0": { + "acc": 0.6000, + "acc_stderr": 0.1633, + "acc_norm": 0.6000, + "acc_norm_stderr": 0.1633, + }, + "harness|bigbench:reasoning_about_colored_objects|3|0": { + "acc": 0.2000, + "acc_stderr": 0.1333, + "acc_norm": 0.1000, + 
"acc_norm_stderr": 0.1000, + }, + "harness|bigbench:ruin_names|3|0": { + "acc": 0.2000, + "acc_stderr": 0.1333, + "acc_norm": 0.2000, + "acc_norm_stderr": 0.1333, + }, + "harness|bigbench:salient_translation_error_detection|3|0": { + "acc": 0.1000, + "acc_stderr": 0.1000, + "acc_norm": 0.1000, + "acc_norm_stderr": 0.1000, + }, + "harness|bigbench:snarks|3|0": { + "acc": 0.4000, + "acc_stderr": 0.1633, + "acc_norm": 0.4000, + "acc_norm_stderr": 0.1633, + }, + "harness|bigbench:sports_understanding|3|0": { + "acc": 0.6000, + "acc_stderr": 0.1633, + "acc_norm": 0.6000, + "acc_norm_stderr": 0.1633, + }, + "harness|bigbench:temporal_sequences|3|0": { + "acc": 0.1000, + "acc_stderr": 0.1000, + "acc_norm": 0.1000, + "acc_norm_stderr": 0.1000, + }, + "harness|bigbench:tracking_shuffled_objects_five_objects|3|0": { + "acc": 0.2000, + "acc_stderr": 0.1333, + "acc_norm": 0.1000, + "acc_norm_stderr": 0.1000, + }, + "harness|bigbench:tracking_shuffled_objects_seven_objects|3|0": { + "acc": 0.1000, + "acc_stderr": 0.1000, + "acc_norm": 0.0000, + "acc_norm_stderr": 0.0000, + }, + "harness|bigbench:tracking_shuffled_objects_three_objects|3|0": { + "acc": 0.2000, + "acc_stderr": 0.1333, + "acc_norm": 0.3000, + "acc_norm_stderr": 0.1528, + }, + "lighteval|bigbench:causal_judgment|3|0": {"acc": 0.5000, "acc_stderr": 0.1667}, + "lighteval|bigbench:date_understanding|3|0": {"acc": 0.0000, "acc_stderr": 0.0000}, + "lighteval|bigbench:disambiguation_qa|3|0": {"acc": 0.7000, "acc_stderr": 0.1528}, + "lighteval|bigbench:geometric_shapes|3|0": {"acc": 0.2000, "acc_stderr": 0.1333}, + "lighteval|bigbench:logical_deduction_five_objects|3|0": {"acc": 0.1000, "acc_stderr": 0.1000}, + "lighteval|bigbench:logical_deduction_seven_objects|3|0": {"acc": 0.2000, "acc_stderr": 0.1333}, + "lighteval|bigbench:logical_deduction_three_objects|3|0": {"acc": 0.4000, "acc_stderr": 0.1633}, + "lighteval|bigbench:movie_recommendation|3|0": {"acc": 0.3000, "acc_stderr": 0.1528}, + "lighteval|bigbench:navigate|3|0": {"acc": 0.4000, "acc_stderr": 0.1633}, + "lighteval|bigbench:reasoning_about_colored_objects|3|0": {"acc": 0.2000, "acc_stderr": 0.1333}, + "lighteval|bigbench:ruin_names|3|0": {"acc": 0.3000, "acc_stderr": 0.1528}, + "lighteval|bigbench:salient_translation_error_detection|3|0": {"acc": 0.4000, "acc_stderr": 0.1633}, + "lighteval|bigbench:snarks|3|0": {"acc": 0.6000, "acc_stderr": 0.1633}, + "lighteval|bigbench:sports_understanding|3|0": {"acc": 0.6000, "acc_stderr": 0.1633}, + "lighteval|bigbench:temporal_sequences|3|0": {"acc": 1.0000, "acc_stderr": 0.0000}, + "lighteval|bigbench:tracking_shuffled_objects_five_objects|3|0": {"acc": 0.2000, "acc_stderr": 0.1333}, + "lighteval|bigbench:tracking_shuffled_objects_seven_objects|3|0": {"acc": 0.3000, "acc_stderr": 0.1528}, + "lighteval|bigbench:tracking_shuffled_objects_three_objects|3|0": {"acc": 0.4000, "acc_stderr": 0.1633}, }, }