Add proposed exact and f1 scorers.
`exact()`
Scorer that normalizes the text of the answer and target(s) and performs an exact match comparison. This scorer returns `CORRECT` when the answer exactly matches one or more of the targets.

`f1()`
Scorer that computes the `F1` score for the answer: the harmonic mean of precision and recall, computed over the token overlap between the answer and the target.
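For example, the new scorers plug into a task like any other scorer. A minimal sketch (not part of this commit; the `example_dataset("theory_of_mind")` dataset and the `plan=[generate()]` solver plan are illustrative assumptions):

``` python
from inspect_ai import Task, task
from inspect_ai.dataset import example_dataset
from inspect_ai.scorer import exact, f1
from inspect_ai.solver import generate


@task
def exact_match_qa():
    # grade each sample by normalized exact match against its target(s)
    return Task(
        dataset=example_dataset("theory_of_mind"),  # assumed example dataset
        plan=[generate()],
        scorer=exact(),
    )


@task
def f1_qa():
    # grade each sample by token-overlap F1 against its target(s)
    return Task(
        dataset=example_dataset("theory_of_mind"),  # assumed example dataset
        plan=[generate()],
        scorer=f1(),
    )
```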
dragonstyle committed Aug 31, 2024
1 parent da4ea9b commit e15dcbc
Showing 4 changed files with 237 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/scorers.qmd
@@ -38,6 +38,14 @@ Inspect includes some simple text matching scorers as well as a couple of model

Scorer for model output that precedes answers with "ANSWER: ". Can extract letters, words, or the remainder of the line.

- `exact()`

Scorer that normalizes the text of the answer and target(s) and performs an exact match comparison. This scorer returns `CORRECT` when the answer exactly matches one or more of the targets.

- `f1()`

Scorer that computes the `F1` score for the answer: the harmonic mean of precision and recall, computed over the token overlap between the answer and the target.

- `model_graded_qa()`

Have another model assess whether the model output is a correct answer based on the grading guidance contained in `target`. Has a built-in template that can be customised.
3 changes: 3 additions & 0 deletions src/inspect_ai/scorer/__init__.py
@@ -1,5 +1,6 @@
from ._answer import AnswerPattern, answer
from ._choice import choice
from ._classification import exact, f1
from ._match import includes, match
from ._metric import (
    CORRECT,
@@ -41,6 +42,8 @@
"answer",
"choice",
"pattern",
"f1",
"exact",
"AnswerPattern",
"Scorer",
"Target",
157 changes: 157 additions & 0 deletions src/inspect_ai/scorer/_classification.py
@@ -0,0 +1,157 @@
import re
import string
from typing import List

from inspect_ai.solver._task_state import TaskState

from ._metric import CORRECT, INCORRECT, Score
from ._metrics import mean, stderr
from ._scorer import Scorer, scorer
from ._target import Target


@scorer(metrics=[mean(), stderr()])
def f1() -> Scorer:
    """Scorer which produces an F1 score.

    Computes the `F1` score for the answer: the harmonic mean of precision and
    recall, computed over the token overlap between the answer and the target.
    """

    async def score(state: TaskState, target: Target) -> Score:
        # get the generated answer and the target(s) to compare against
        answer = state.output.completion
        targets = target.target

        # score the answer against the best-matching target
        f1_score = max_f1_score(answer, targets)
        return Score(value=f1_score, answer=answer)

    return score


@scorer(metrics=[mean(), stderr()])
def exact() -> Scorer:
    """Scorer which produces an exact match score.

    Normalizes the text of the answer and target(s) and performs an exact match
    comparison. Returns `CORRECT` when the answer exactly matches one or more of
    the targets.
    """

    async def score(state: TaskState, target: Target) -> Score:
        # get the generated answer and the target(s) to compare against
        answer = state.output.completion
        targets = target.target

        # CORRECT if the answer exactly matches any target after normalization
        exact_score = max_exact_score(answer, targets)
        return Score(value=CORRECT if exact_score == 1.0 else INCORRECT, answer=answer)

    return score


def max_f1_score(answer: str, targets: List[str]) -> float:
    # find the maximum F1 score for this answer across all non-empty targets
    max_f1 = 0.0
    for target in targets:
        if target.strip():
            f1_score = compute_f1(answer, target)
            max_f1 = max(max_f1, f1_score)
    return max_f1


def max_exact_score(answer: str, targets: List[str]) -> float:
    # find the maximum exact match score for this answer across all non-empty targets
    max_exact = 0.0
    for target in targets:
        if target.strip():
            exact_score = compute_exact(answer, target)
            max_exact = max(max_exact, exact_score)
    return max_exact


def compute_f1(answer: str, target: str) -> float:
    """Compute the SQuAD-style F1 score between a predicted answer and a gold target."""
    answer_words = _to_words(answer)
    target_words = _to_words(target)

    return _f1(answer_words=answer_words, target_words=target_words)


def compute_exact(answer: str, target: str) -> float:
    # normalize both strings and compare them for an exact match
    return 1.0 if _normalize(answer) == _normalize(target) else 0.0


def _to_words(answer: str) -> set[str]:
    normalized = _normalize(answer)
    token_bag = set(normalized.split())
    return token_bag


def _f1(answer_words: set[str], target_words: set[str]) -> float:
    intersection = len(answer_words.intersection(target_words))
    if not answer_words:
        precision = 1.0
    else:
        precision = intersection / float(len(answer_words))
    if not target_words:
        recall = 1.0
    else:
        recall = intersection / float(len(target_words))
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if not (precision == 0.0 and recall == 0.0)
        else 0.0
    )
    return f1


def _is_number(text: str) -> bool:
    try:
        float(text)
        return True
    except ValueError:
        return False


def _remove_articles(text: str) -> str:
    _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return _ARTICLES.sub(" ", text)


def _remove_punc(text: str) -> str:
    exclude = set(string.punctuation)
    is_number = _is_number(text)
    if not is_number:
        return "".join(ch for ch in text if ch not in exclude)
    else:
        return text


def _normalize_whitespace(text: str) -> str:
    return " ".join(text.split())


def _normalize_number(text: str) -> str:
    is_number = _is_number(text)
    if is_number:
        return str(float(text))
    else:
        return text


def _tokenize(text: str) -> List[str]:
    return re.split(" |-", text)


def _normalize(answer: str) -> str:
    """Normalize text to remove extraneous characters and words."""
    tokens = []
    tokenized_answer = _tokenize(answer)
    for token in tokenized_answer:
        token = _remove_punc(token.lower())
        token = _normalize_number(token)
        token = _remove_articles(token)
        token = _normalize_whitespace(token)
        tokens.append(token)
    tokens = [token for token in tokens if token.strip()]
    normalized = " ".join(tokens).strip()
    return normalized
69 changes: 69 additions & 0 deletions tests/scorer/test_classification.py
@@ -0,0 +1,69 @@
import pytest
from test_helpers.utils import simple_task_state

from inspect_ai.scorer import Target
from inspect_ai.scorer._classification import exact, f1
from inspect_ai.scorer._metric import CORRECT, INCORRECT


@pytest.mark.asyncio
async def test_exact_match():
    scorer = exact()
    state = simple_task_state(model_output="foo")
    result = await scorer(state, Target(["foo"]))

    assert result.text == CORRECT


@pytest.mark.asyncio
async def test_exact_match_max():
    scorer = exact()
    state = simple_task_state(model_output="foo")
    result = await scorer(state, Target(["foobar", "boofar", "foo"]))

    assert result.text == CORRECT


@pytest.mark.asyncio
async def test_exact_nonmatch():
    scorer = exact()
    state = simple_task_state(model_output="foo1")
    result = await scorer(state, Target(["foo"]))

    assert result.text == INCORRECT


@pytest.mark.asyncio
async def test_f1_basic_match():
    scorer = f1()
    state = simple_task_state(model_output="foo")
    result = await scorer(state, Target(["foo"]))

    assert result.text == "1.0"


@pytest.mark.asyncio
async def test_f1_basic_nonmatch():
    scorer = f1()
    state = simple_task_state(model_output="foo1")
    result = await scorer(state, Target(["foo"]))

    assert result.text == "0.0"


@pytest.mark.asyncio
async def test_f1_good_match():
    scorer = f1()
    state = simple_task_state(model_output="Paris")
    result = await scorer(state, Target(["Paris, Texas", "Paris"]))

    assert result.text == "1.0"


@pytest.mark.asyncio
async def test_f1_partial_match():
    scorer = f1()
    state = simple_task_state(model_output="Paris")
    result = await scorer(state, Target(["Paris, Texas"]))

    assert result.text == "0.6666666666666666"
