From 77eee8cb025c0df900fcc0872e5204b9b96b55cf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com>
Date: Wed, 7 Feb 2024 14:53:41 +0100
Subject: [PATCH] Adding the target perplexity fix back (#15)

---------

Co-authored-by: Thomas Wolf
---
 src/lighteval/metrics/__init__.py           | 17 +++++++++++++----
 src/lighteval/metrics/metrics_sample.py     |  9 +++++----
 src/lighteval/metrics/sample_preparator.py  |  6 +++---
 tests/reference_scores/harness_metrics.json |  4 ++--
 4 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/src/lighteval/metrics/__init__.py b/src/lighteval/metrics/__init__.py
index 3b17854e..3a0984bf 100644
--- a/src/lighteval/metrics/__init__.py
+++ b/src/lighteval/metrics/__init__.py
@@ -8,11 +8,18 @@
 def apply_target_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
     outputs = {}
 
-    current_results = [results.pop(0) for _ in range(len(formatted_doc.get_golds()))]
+    reference_text = formatted_doc.get_golds()[0]
+    current_result = results.pop(0)
+    target_logprob = current_result.result[0]
+    target_acc = current_result.result[1]
 
     for metric in metrics:
-        if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(Metrics[metric].value.compute(results=current_results))
+        if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY:
+            outputs.update(
+                Metrics[metric].value.compute(
+                    logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text
+                )
+            )
 
     return results, outputs
 
@@ -30,7 +37,9 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr
 
     for metric in metrics:
         if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(Metrics[metric].value.compute(results=current_result, reference_text=reference_text))
+            outputs.update(
+                Metrics[metric].value.compute(logprobs=current_result.result, reference_text=reference_text)
+            )
 
     return results, outputs
 
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index ec123741..e87e3bb5 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -1,6 +1,8 @@
 """This module manages all the metrics occurring at the sample level. The results of said metrics are then aggregated
 using simple function (min, mean, max, ...) at the corpus level. Most metrics fall under this category.
 """
+from typing import Union
+
 import nltk
 import numpy as np
 from nltk.metrics.distance import edit_distance
@@ -275,17 +277,16 @@ def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted
         return 1.0 / (min(ranked_choices) + 1)
 
 
-def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int:
+def acc_golds_likelihood(target_acc: Union[list[int], int], **kwargs) -> int:
     """Tests if at least one of predicted gold targets' log-likelihood is above 0.5.
 
     Args:
-        results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated.
-        formatted_doc (Doc): _description_
+        target_acc (list[int]): List of scores indicating whether the predictions log-probabilities are above 0.5 aggregated.
 
     Returns:
         int: 1 if at least one of the possible golds had a log-likelihood above 0.5.
     """
-    return max([int(acc_ppl) for _, acc_ppl in results])
+    return max([int(acc_ppl) for acc_ppl in as_list(target_acc)])
 
 
 class ROUGE:
diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py
index 65902292..c28ed247 100644
--- a/src/lighteval/metrics/sample_preparator.py
+++ b/src/lighteval/metrics/sample_preparator.py
@@ -106,14 +106,14 @@ def count_units(self, text: str) -> int:
         if self.units_type == "bytes":
             return len(text.encode("utf-8"))
 
-    def prepare(self, results, reference_text, **kwargs):
+    def prepare(self, logprobs: list[float] | float, reference_text: str, **kwargs):
         """Prepares an individual perplexity example to the format expected by metrics computed at the corpus level (aggregated).
 
         Args:
-            results (list[float]): List of the logprobabilities computed for each item
+            logprobs (list[float]): List of the logprobabilities computed for each item of the sequence or single aggregated logprob over the sequence
            reference_text (str): Current reference text for which to compute the length in self.units_type
 
         Returns:
            PerplexityCorpusMetricInput: Stores the measured logprobs and associated text lengths, counted in the reference unit.
         """
-        return PerplexityCorpusMetricInput(logprobs=results.result, weights=self.count_units(reference_text))
+        return PerplexityCorpusMetricInput(logprobs=logprobs, weights=self.count_units(reference_text))
diff --git a/tests/reference_scores/harness_metrics.json b/tests/reference_scores/harness_metrics.json
index a6c506f3..1c8c5b91 100644
--- a/tests/reference_scores/harness_metrics.json
+++ b/tests/reference_scores/harness_metrics.json
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a1965f0b9c66cfe1b1f3cc380a80949e32eab92ae8eac079c0339506ce827093
-size 48373142
+oid sha256:408956938a6b7a18b03658bb9772b471efcea4aa04afb0b35d76cecfca6a706e
+size 48376580