Skip to content

Commit

Permalink
Update drop benchmark to simplify data reading
Browse files Browse the repository at this point in the history
Tested against main baseline using 100 samples

### Gpt-4o
| Branch | Accuracy | StdErr |
| main | 0.856 | 0.0257 |
| f1 | 0.879 | 0.0229 |

### Gpt-4o-mini
| Branch | Accuracy | StdErr |
| main | 0.868 | 0.031 |
| f1 | 0.877 | 0.0285 |
  • Loading branch information
dragonstyle committed Sep 4, 2024
1 parent 09c4351 commit b98c011
Showing 1 changed file with 19 additions and 49 deletions.
68 changes: 19 additions & 49 deletions benchmarks/drop/drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
"""

import re
from typing import Dict, List, Tuple
from typing import Dict, List

from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.scorer import Score, Target, bootstrap_std, mean, scorer
from inspect_ai.scorer._classification import max_f1_score
from inspect_ai.scorer import f1
from inspect_ai.solver import (
TaskState,
generate,
prompt_template,
system_message,
Expand Down Expand Up @@ -71,34 +69,14 @@ def drop(
sample_fields=record_to_sample,
),
plan=build_plan(fewshot=fewshot, fewshot_seed=fewshot_seed),
scorer=drop_f1_scorer(),
scorer=f1(extract_answer),
)


@scorer(metrics=[mean(), bootstrap_std()])
def drop_f1_scorer():
async def score(state: TaskState, target: Target) -> Score:
# Get generated answer and extract relevant answer text
answer = state.output.completion
match = re.search(ANSWER_PATTERN, answer)
answer = match.group(1) if match else answer

# Get target answers (convert str elm to tuple by splitting on "|")
ref_answers = [el for elm in target.target for el in elm.split("|")]

# Compute exact match (EM) and F1 score
f1_score = max_f1_score(answer, ref_answers)

# F1 score reported as main aggregated metric, EM score added to metadata
return Score(
value=f1_score,
answer=answer,
metadata={
"gold_answers": ref_answers,
},
)

return score
def extract_answer(answer: str) -> str:
"""Function to extract the answer from the completion"""
match = re.search(ANSWER_PATTERN, answer)
return match.group(1) if match else answer


def build_plan(
Expand Down Expand Up @@ -144,7 +122,7 @@ def build_plan(
def record_to_sample(record: Dict) -> Sample:
return Sample(
input=format_input(record),
target=format_target(record),
target=get_answers(record),
id=record["query_id"],
)

Expand All @@ -157,21 +135,14 @@ def format_input(doc: Dict) -> str:
return input_str


def format_target(doc: Dict) -> List[str]:
target = get_answers(doc)
# Convert each tuple to str, since 'target' only accepts 'str' or 'List[str]'.
target = ["|".join(elm) for elm in target]
return target


def sample_to_fewshot(sample: Sample) -> str:
target = sample.target[0].split("|")[0]
return f"""{sample.input} {target}"""


# Copied from
# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/drop/utils.py#L23C1-L49C19
def get_answers(doc: Dict) -> List:
def get_answers(doc: Dict) -> list[str]:
def _flatten_validated_answers(validated_answers: Dict) -> List[Dict]:
"""Flattens a dict of lists of validated answers.
Expand All @@ -189,8 +160,8 @@ def _flatten_validated_answers(validated_answers: Dict) -> List[Dict]:
)
return valid_answers

answers = []
answers_set = set()
answers: list[str] = []
answers_set: set[str] = set()
candidates = [doc["answer"]] + _flatten_validated_answers(doc["validated_answers"])
for candidate in candidates:
answer = parse_answer(candidate)
Expand All @@ -201,14 +172,13 @@ def _flatten_validated_answers(validated_answers: Dict) -> List[Dict]:
return answers


def parse_answer(answer: Dict) -> Tuple:
# NOTE: Everything is returned as a tuple for uniformity and hashability.
def parse_answer(answer: Dict) -> str:
# NOTE: Everything is converted to a string here, since this will ultimately
# be treated as a bag of words
if answer["number"] != "":
return (str(answer["number"]),)
return str(answer["number"])
if answer["spans"] != []:
return tuple(answer["spans"])
return (
" ".join(
[answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip(),
)
return ",".join(answer["spans"])
return " ".join(
[answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
).strip()

0 comments on commit b98c011

Please sign in to comment.