Add two classification scorers: `exact()`, which normalizes the text of the answer and target(s) and performs an exact text comparison, returning `CORRECT` when the answer exactly matches one or more targets; and `f1()`, which computes the `F1` score for the answer (balancing precision and recall by taking their harmonic mean).
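Below is a sketch of how one of these scorers might be attached to an evaluation task. The dataset name is illustrative, and since the tests below import the scorers from the private `inspect_ai.scorer._classification` module, the sketch does the same:

from inspect_ai import Task, task
from inspect_ai.dataset import example_dataset
from inspect_ai.scorer._classification import f1

@task
def qa_f1():
    # Grade free-form answers against the sample targets with token-level F1
    return Task(
        dataset=example_dataset("theory_of_mind"),
        scorer=f1(),
    )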
Commit e15dcbc (1 parent: da4ea9b). Showing 4 changed files with 237 additions and 0 deletions.
inspect_ai/scorer/_classification.py (new file, +157 lines):
import re
import string
from typing import List

from inspect_ai.solver._task_state import TaskState

from ._metric import CORRECT, INCORRECT, Score
from ._metrics import mean, stderr
from ._scorer import Scorer, scorer
from ._target import Target


@scorer(metrics=[mean(), stderr()])
def f1() -> Scorer:
    """Scorer which produces an F1 score.

    Computes the F1 score for the answer (which balances precision and
    recall by taking their harmonic mean).
    """

    async def score(state: TaskState, target: Target) -> Score:
        # Get generated answer and extract relevant answer text
        answer = state.output.completion
        targets = target.target

        f1_score = max_f1_score(answer, targets)
        return Score(value=f1_score, answer=answer)

    return score


@scorer(metrics=[mean(), stderr()])
def exact() -> Scorer:
    """Scorer which produces an exact match score.

    Normalizes the text of the answer and target(s) and performs an exact
    matching comparison of the text. This scorer will return CORRECT when
    the answer is an exact match to one or more targets.
    """

    async def score(state: TaskState, target: Target) -> Score:
        # Get generated answer and extract relevant answer text
        answer = state.output.completion
        targets = target.target

        exact_score = max_exact_score(answer, targets)
        return Score(value=CORRECT if exact_score == 1.0 else INCORRECT, answer=answer)

    return score


def max_f1_score(answer: str, targets: List[str]) -> float:
    # Find the maximum F1 score for this answer across all targets,
    # skipping empty/whitespace-only targets
    max_f1 = 0.0
    for target in targets:
        if target.strip():
            f1_score = compute_f1(answer, target)
            max_f1 = max(max_f1, f1_score)
    return max_f1


def max_exact_score(answer: str, targets: List[str]) -> float:
    # Find the maximum exact score for this answer across all targets,
    # skipping empty/whitespace-only targets
    max_exact = 0.0
    for target in targets:
        if target.strip():
            exact_score = compute_exact(answer, target)
            max_exact = max(max_exact, exact_score)
    return max_exact


def compute_f1(answer: str, target: str) -> float:
    """Takes a predicted answer and a gold answer (both strings) and returns the SQuAD F1 metric for the prediction."""
    answer_words = _to_words(answer)
    target_words = _to_words(target)

    return _f1(answer_words=answer_words, target_words=target_words)


def compute_exact(answer: str, target: str) -> float:
    # Normalize both strings before comparing so that case, punctuation,
    # articles, and whitespace differences do not prevent an exact match
    return 1.0 if _normalize(answer) == _normalize(target) else 0.0


def _to_words(answer: str) -> set[str]:
    normalized = _normalize(answer)
    token_bag = set(normalized.split())
    return token_bag


def _f1(answer_words: set[str], target_words: set[str]) -> float:
    intersection = len(answer_words.intersection(target_words))
    if not answer_words:
        precision = 1.0
    else:
        precision = intersection / float(len(answer_words))
    if not target_words:
        recall = 1.0
    else:
        recall = intersection / float(len(target_words))
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if not (precision == 0.0 and recall == 0.0)
        else 0.0
    )
    return f1


def _is_number(text: str) -> bool:
    try:
        float(text)
        return True
    except ValueError:
        return False


def _remove_articles(text: str) -> str:
    _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return _ARTICLES.sub(" ", text)


def _remove_punc(text: str) -> str:
    exclude = set(string.punctuation)
    is_number = _is_number(text)
    if not is_number:
        return "".join(ch for ch in text if ch not in exclude)
    else:
        return text


def _normalize_whitespace(text: str) -> str:
    return " ".join(text.split())


def _normalize_number(text: str) -> str:
    is_number = _is_number(text)
    if is_number:
        return str(float(text))
    else:
        return text


def _tokenize(text: str) -> List[str]:
    return re.split(" |-", text)


def _normalize(answer: str) -> str:
    """Normalize text to remove extraneous characters and words."""
    tokens = []
    tokenized_answer = _tokenize(answer)
    for token in tokenized_answer:
        token = _remove_punc(token.lower())
        token = _normalize_number(token)
        token = _remove_articles(token)
        token = _normalize_whitespace(token)
        tokens.append(token)
    tokens = [token for token in tokens if token.strip()]
    normalized = " ".join(tokens).strip()
    return normalized
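For intuition, here is a small sketch (not part of the commit) of what the token-bag F1 works out to for a partially matching answer; it reproduces the 0.666... value asserted in the tests below:

from inspect_ai.scorer._classification import compute_f1

# "Paris, Texas" normalizes to the token bag {"paris", "texas"} and "Paris"
# to {"paris"}: precision = 1/1, recall = 1/2, so F1 = (2 * 1.0 * 0.5) / 1.5
print(compute_f1("Paris", "Paris, Texas"))  # 0.6666666666666666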
Unit tests for the new scorers (new file, +69 lines):
import pytest
from test_helpers.utils import simple_task_state

from inspect_ai.scorer import Target
from inspect_ai.scorer._classification import exact, f1
from inspect_ai.scorer._metric import CORRECT, INCORRECT


@pytest.mark.asyncio
async def test_exact_match():
    scorer = exact()
    state = simple_task_state(model_output="foo")
    result = await scorer(state, Target(["foo"]))

    assert result.text == CORRECT


@pytest.mark.asyncio
async def test_exact_match_max():
    scorer = exact()
    state = simple_task_state(model_output="foo")
    result = await scorer(state, Target(["foobar", "boofar", "foo"]))

    assert result.text == CORRECT


@pytest.mark.asyncio
async def test_exact_nonmatch():
    scorer = exact()
    state = simple_task_state(model_output="foo1")
    result = await scorer(state, Target(["foo"]))

    assert result.text == INCORRECT


@pytest.mark.asyncio
async def test_f1_basic_match():
    scorer = f1()
    state = simple_task_state(model_output="foo")
    result = await scorer(state, Target(["foo"]))

    assert result.text == "1.0"


@pytest.mark.asyncio
async def test_f1_basic_nonmatch():
    scorer = f1()
    state = simple_task_state(model_output="foo1")
    result = await scorer(state, Target(["foo"]))

    assert result.text == "0.0"


@pytest.mark.asyncio
async def test_f1_good_match():
    scorer = f1()
    state = simple_task_state(model_output="Paris")
    result = await scorer(state, Target(["Paris, Texas", "Paris"]))

    assert result.text == "1.0"


@pytest.mark.asyncio
async def test_f1_partial_match():
    scorer = f1()
    state = simple_task_state(model_output="Paris")
    result = await scorer(state, Target(["Paris, Texas"]))

    assert result.text == "0.6666666666666666"
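Finally, a quick sketch (not part of the commit) of the normalization that backs the exact scorer; `_normalize` is a private helper, and the comparison relies on `compute_exact` normalizing both sides as in the file above:

from inspect_ai.scorer._classification import _normalize, compute_exact

# Case, punctuation, articles, and extra whitespace are all stripped before
# comparison, so these two strings are treated as an exact match
print(_normalize("The  Answer!"))               # "answer"
print(compute_exact("The Answer!", "answer"))   # 1.0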