Skip to content

Commit

Permalink
Apply comment and Ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Jul 17, 2024
1 parent 82c25fe commit a3dd652
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/lighteval/metrics/harness_compatibility/drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@

import re
import string
from typing import List, Set, Tuple

import numpy as np
from scipy.optimize import linear_sum_assignment
from typing import List, Set, Tuple


def drop_metrics(predictions: list[str], formatted_doc, **kwargs): # noqa: C901
Expand All @@ -41,7 +41,7 @@ def drop_metrics(predictions: list[str], formatted_doc, **kwargs): # noqa: C901
between prediction and the different answers is taken.
For more information, please refer to Section 5 of the DROP paper (https://aclanthology.org/N19-1246/).
TODO: this code is really hard to follow; simplify when possible.
"""

Expand Down Expand Up @@ -69,7 +69,9 @@ def _get_metrics(predicted: List[str], gold: List[str]):
pred_normalized_spans, pred_bags = _answer_to_bags(predicted)
gold_normalized_spans, gold_bags = _answer_to_bags(gold)

if set(pred_normalized_spans) == set(gold_normalized_spans) and len(gold_normalized_spans) == len(gold_normalized_spans):
if set(pred_normalized_spans) == set(gold_normalized_spans) and len(gold_normalized_spans) == len(
gold_normalized_spans
):
exact_match = 1.0
else:
exact_match = 0.0
Expand Down Expand Up @@ -161,7 +163,9 @@ def _normalize(answer: str):
max_f1 = 0
for gold_answer in formatted_doc.specific["golds_no_preprocessing"]:
exact_match, f1_score = _get_metrics(predictions, gold_answer)
if gold_answer[0].strip():
if isinstance(gold_answer, list):
gold_answer = gold_answer[0]
if gold_answer.strip():
max_em = max(max_em, exact_match)
max_f1 = max(max_f1, f1_score)
return {"qem": max_em, "f1": max_f1}

0 comments on commit a3dd652

Please sign in to comment.