Adding the target perplexity fix back (#15)
---------

Co-authored-by: Thomas Wolf <[email protected]>
clefourrier and thomwolf authored Feb 7, 2024
1 parent 37db422 commit 77eee8c
Showing 4 changed files with 23 additions and 13 deletions.
17 changes: 13 additions & 4 deletions src/lighteval/metrics/__init__.py
@@ -8,11 +8,18 @@
 
 def apply_target_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metrics: list[str]):
     outputs = {}
-    current_results = [results.pop(0) for _ in range(len(formatted_doc.get_golds()))]
+    reference_text = formatted_doc.get_golds()[0]
+    current_result = results.pop(0)
+    target_logprob = current_result.result[0]
+    target_acc = current_result.result[1]
 
     for metric in metrics:
-        if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(Metrics[metric].value.compute(results=current_results))
+        if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY:
+            outputs.update(
+                Metrics[metric].value.compute(
+                    logprobs=target_logprob, target_acc=target_acc, reference_text=reference_text
+                )
+            )
 
     return results, outputs

@@ -30,7 +37,9 @@ def apply_perplexity_metric(results: list[ModelReturn], formatted_doc: Doc, metr
 
     for metric in metrics:
         if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
-            outputs.update(Metrics[metric].value.compute(results=current_result, reference_text=reference_text))
+            outputs.update(
+                Metrics[metric].value.compute(logprobs=current_result.result, reference_text=reference_text)
+            )
 
     return results, outputs
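
Note (not part of the diff): the sketch below mirrors the new unpacking step with a stand-in ModelReturn (a namedtuple used purely for illustration; the real class ships with lighteval), showing how a single popped result now provides both the target logprob and the accuracy-style flag that TARGET_PERPLEXITY metrics consume.

# Minimal sketch of the new unpacking flow, assuming a ModelReturn-like object
# whose .result holds (target_logprob, target_acc); names below are illustrative.
from collections import namedtuple

ModelReturn = namedtuple("ModelReturn", ["result"])  # stand-in for lighteval's class

def unpack_target_perplexity(results, golds):
    """Pop one pending result and split it into the pieces the metric now receives."""
    reference_text = golds[0]
    current_result = results.pop(0)
    target_logprob = current_result.result[0]  # logprob of the gold target
    target_acc = current_result.result[1]      # accuracy-style flag (see acc_golds_likelihood)
    return target_logprob, target_acc, reference_text

results = [ModelReturn(result=(-3.2, 1))]
print(unpack_target_perplexity(results, golds=["Paris"]))  # (-3.2, 1, 'Paris')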

9 changes: 5 additions & 4 deletions src/lighteval/metrics/metrics_sample.py
@@ -1,6 +1,8 @@
"""This module manages all the metrics occurring at the sample level. The results of said metrics are then aggregated
using simple function (min, mean, max, ...) at the corpus level. Most metrics fall under this category.
"""
from typing import Union

import nltk
import numpy as np
from nltk.metrics.distance import edit_distance
@@ -275,17 +277,16 @@ def compute(self, choices_logprob: list[float], gold_ixs: list[float], formatted
         return 1.0 / (min(ranked_choices) + 1)
 
 
-def acc_golds_likelihood(results: list[tuple[float, int]], **kwargs) -> int:
+def acc_golds_likelihood(target_acc: Union[list[int], int], **kwargs) -> int:
     """Tests if at least one of predicted gold targets' log-likelihood is above 0.5.
     Args:
-        results (list[int]): List of tuples containing, for each gold, the predictions log-probabilities associated with whether they are above 0.5 aggregated.
-        formatted_doc (Doc): _description_
+        target_acc (list[int]): List of scores indicating whether the predictions log-probabilities are above 0.5 aggregated.
     Returns:
         int: 1 if at least one of the possible golds had a log-likelihood above 0.5.
     """
-    return max([int(acc_ppl) for _, acc_ppl in results])
+    return max([int(acc_ppl) for acc_ppl in as_list(target_acc)])
 
 
 class ROUGE:
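
Note (not part of the diff): the updated function is easy to exercise in isolation; below is a self-contained copy for illustration, where as_list is a tiny local stand-in for lighteval's helper of the same name.

# Self-contained copy of the updated acc_golds_likelihood, for illustration only;
# as_list here is a minimal stand-in for lighteval's utility.
from typing import Union

def as_list(item):
    return item if isinstance(item, list) else [item]

def acc_golds_likelihood(target_acc: Union[list[int], int], **kwargs) -> int:
    return max(int(acc_ppl) for acc_ppl in as_list(target_acc))

assert acc_golds_likelihood(target_acc=[0, 1, 0]) == 1  # at least one gold scored as likely
assert acc_golds_likelihood(target_acc=0) == 0          # a bare int is wrapped into a list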
6 changes: 3 additions & 3 deletions src/lighteval/metrics/sample_preparator.py
@@ -106,14 +106,14 @@ def count_units(self, text: str) -> int:
         if self.units_type == "bytes":
             return len(text.encode("utf-8"))
 
-    def prepare(self, results, reference_text, **kwargs):
+    def prepare(self, logprobs: list[float] | float, reference_text: str, **kwargs):
         """Prepares an individual perplexity example to the format expected by metrics computed at the corpus level (aggregated).
         Args:
-            results (list[float]): List of the logprobabilities computed for each item
+            logprobs (list[float]): List of the logprobabilities computed for each item of the sequence or single aggregated logprob over the sequence
             reference_text (str): Current reference text for which to compute the length in self.units_type
         Returns:
             PerplexityCorpusMetricInput: Stores the measured logprobs and associated text lengths, counted in the reference unit.
         """
-        return PerplexityCorpusMetricInput(logprobs=results.result, weights=self.count_units(reference_text))
+        return PerplexityCorpusMetricInput(logprobs=logprobs, weights=self.count_units(reference_text))
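
Note (not part of the diff): a rough sketch of what prepare() now does with its arguments, assuming units_type == "bytes"; PerplexityCorpusMetricInput is mimicked with a namedtuple purely for the example.

# Rough mimic of the updated prepare(): callers pass raw logprobs (the old code
# received a ModelReturn and read .result itself). Names here are stand-ins.
from collections import namedtuple

PerplexityCorpusMetricInput = namedtuple("PerplexityCorpusMetricInput", ["logprobs", "weights"])

def prepare(logprobs, reference_text):
    # weights = length of the reference text in the chosen unit (UTF-8 bytes here)
    return PerplexityCorpusMetricInput(logprobs=logprobs, weights=len(reference_text.encode("utf-8")))

print(prepare(logprobs=[-1.7, -0.4, -2.1], reference_text="The quick brown fox"))
# -> PerplexityCorpusMetricInput(logprobs=[-1.7, -0.4, -2.1], weights=19)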
4 changes: 2 additions & 2 deletions tests/reference_scores/harness_metrics.json
Git LFS file not shown
