Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extractive Match metric #495

Merged
merged 10 commits into from
Jan 15, 2025
Merged
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ tensorboardX = ["tensorboardX"]
vllm = ["vllm", "ray", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
dev = ["lighteval[accelerate,quality,tests,multilingual]"]
dev = ["lighteval[accelerate,quality,tests,multilingual,math]"]
docs = ["hf-doc-builder", "watchdog"]
extended_tasks = [
"langdetect", # ifeval
Expand All @@ -109,6 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.0"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
106 changes: 105 additions & 1 deletion src/lighteval/metrics/dynamic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Callable, Literal
import logging
from typing import Callable, Literal, Sequence

import numpy as np

Expand All @@ -37,8 +38,22 @@
LogProbTokenNorm,
get_multilingual_normalizer,
)
from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401
ExprExtractionConfig,
ExtractionTarget,
IndicesExtractionConfig,
LatexExtractionConfig,
extract_target_from_pred,
get_extraction_regexes,
)
from lighteval.metrics.utils.math_comparisson import compare_gold_target
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
from lighteval.utils.timeout import timeout


logger = logging.getLogger(__name__)


def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
Expand Down Expand Up @@ -168,3 +183,92 @@ def multilingual_quasi_exact_match_metric(
corpus_level_fn=np.mean,
higher_is_better=True,
)


def multilingual_extractive_match_metric(
language: Language,
hynky1999 marked this conversation as resolved.
Show resolved Hide resolved
gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
aggregation_function: Callable[[list[float]], float] = max,
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
precision: int = 6,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.

Known issues:
- If the task is to simplify an expression, the metric might overestimate the accuracy. This is because if the model doesn't output any anchor for the extraction (e.g final answer is..),
it's possible that the the extracted prediction will be the expression to simplify. Because we do simplifications ourselves, it can thus happen that sympy will correctly simplify the expression,
thus it will match gold, despite model not doing anything. PRs to fix this are welcome.

- There is currently no StringExtractionConfig, so if the gold is \boxed{\text{Friday}} and model outputs Friday it will not match, because nothing will be extracted.

Args:
language: Language
The language of the samples.
gold_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for gold answers. Defaults to extracting simple math expressions.
pred_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for predictions. Defaults to extracting simple math expressions.
aggregation_function: Callable[[list[float]], float]
Function to aggregate scores when multiple golds/predictions are present. Defaults to max.
fallback_mode: Literal["no_fallback", "first_match"]
How to perform extraction. Defaults to "first_match".
- "no_fallback": Only use first successfully parsed matches
- "first_match": Use the first successfully parsed match + first match irregardless the parsing success
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.

Returns:
A sample level metric that extracts and compares mathematical expressions.

"""

@timeout(2)
def add_to_specifics_with_timeout(
formatted_doc: Doc, extracted_predictions: list[list[str]], extracted_golds: list[list[str]]
) -> None:
if formatted_doc.specific is None:
formatted_doc.specific = {}

formatted_doc.specific["extracted_predictions"] = [
str(pred) for preds in extracted_predictions for pred in preds
]
formatted_doc.specific["extracted_golds"] = [str(gold) for golds in extracted_golds for gold in golds]

def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc) -> float:
gold_extraction_regexes = get_extraction_regexes(formatted_doc, gold_extraction_target, language)
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
]
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]

# Assert on empty gold and warn on empty pred
if any(len(g) == 0 for g in extracted_golds):
raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}")

if all(len(p) == 0 for p in extracted_predictions):
logger.warning(f"No predictions found for all predictions. Gold: {golds}, Pred: {predictions}")
hynky1999 marked this conversation as resolved.
Show resolved Hide resolved

# We have to use timeout because the sypmy to str conversion can be very slow
try:
add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
except: # noqa: E722
logger.warning("Timeout when adding extracted predictions and golds to specific")

return aggregation_function(
[
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
for pred in extracted_predictions
]
)

return SampleLevelMetric(
metric_name="extractive_match",
sample_level_fn=sample_level_fn,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn=np.mean,
higher_is_better=True,
)
Loading
Loading