diff --git a/src/lighteval/metrics/__init__.py b/src/lighteval/metrics/__init__.py index 3b17854e7..3a0984bfc 100644 --- a/src/lighteval/metrics/__init__.py +++ b/src/lighteval/metrics/__init__.py @@ -8,11 +8,18 @@ def apply_target_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]): outputs = {} - current_results = [results.pop(0) for _ in range(len(formatted_doc.get_golds()))] + reference_text = formatted_doc.get_golds()[0] + current_result = results.pop(0) + target_logprob = current_result.result[0] + target_acc = current_result.result[1] for metric in metrics: - if Metrics[metric].value.category == MetricCategory.PERPLEXITY: - outputs.update(Metrics[metric].value.compute(results=current_results)) + if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY: + outputs.update( + Metrics[metric].value.compute( + logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text + ) + ) return results, outputs @@ -30,7 +37,9 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr for metric in metrics: if Metrics[metric].value.category == MetricCategory.PERPLEXITY: - outputs.update(Metrics[metric].value.compute(results=current_result, reference_text=reference_text)) + outputs.update( + Metrics[metric].value.compute(logprobs=current_result.result, reference_text=reference_text) + ) return results, outputs diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index ec123741b..e87e3bb58 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -1,6 +1,8 @@ """This module manages all the metrics occurring at the sample level. The results of said metrics are then aggregated using simple function (min, mean, max, ...) at the corpus level. Most metrics fall under this category. """ +from typing import Union + import nltk import numpy as np from nltk.metrics.distance import edit_distance @@ -275,17 +277,16 @@ def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted return 1.0 / (min(ranked_choices) + 1) -def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int: +def acc_golds_likelihood(target_acc: Union[list[int], int], **kwargs) -> int: """Tests if at least one of predicted gold targets' log-likelihood is above 0.5. Args: - results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated. - formatted_doc (Doc): _description_ + target_acc (list[int]): List of scores indicating whether the predictions log-probabilities are above 0.5 aggregated. Returns: int: 1 if at least one of the possible golds had a log-likelihood above 0.5. """ - return max([int(acc_ppl) for _, acc_ppl in results]) + return max([int(acc_ppl) for acc_ppl in as_list(target_acc)]) class ROUGE: diff --git a/src/lighteval/metrics/sample_preparator.py b/src/lighteval/metrics/sample_preparator.py index 659022920..c28ed2470 100644 --- a/src/lighteval/metrics/sample_preparator.py +++ b/src/lighteval/metrics/sample_preparator.py @@ -106,14 +106,14 @@ def count_units(self, text: str) -> int: if self.units_type == "bytes": return len(text.encode("utf-8")) - def prepare(self, results, reference_text, **kwargs): + def prepare(self, logprobs: list[float] | float, reference_text: str, **kwargs): """Prepares an individual perplexity example to the format expected by metrics computed at the corpus level (aggregated). Args: - results (list[float]): List of the logprobabilities computed for each item + logprobs (list[float]): List of the logprobabilities computed for each item of the sequence or single aggregated logprob over the sequence reference_text (str): Current reference text for which to compute the length in self.units_type Returns: PerplexityCorpusMetricInput: Stores the measured logprobs and associated text lengths, counted in the reference unit. """ - return PerplexityCorpusMetricInput(logprobs=results.result, weights=self.count_units(reference_text)) + return PerplexityCorpusMetricInput(logprobs=logprobs, weights=self.count_units(reference_text)) diff --git a/tests/reference_scores/harness_metrics.json b/tests/reference_scores/harness_metrics.json index a6c506f34..1c8c5b91d 100644 --- a/tests/reference_scores/harness_metrics.json +++ b/tests/reference_scores/harness_metrics.json @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a1965f0b9c66cfe1b1f3cc380a80949e32eab92ae8eac079c0339506ce827093 -size 48373142 +oid sha256:408956938a6b7a18b03658bb9772b471efcea4aa04afb0b35d76cecfca6a706e +size 48376580