Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A small improvement in metrics_sample.py::ROUGE #217

Merged
merged 9 commits into from
Aug 14, 2024
27 changes: 20 additions & 7 deletions src/lighteval/metrics/metrics_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def __init__(
normalize_gold: callable = None,
normalize_pred: callable = None,
aggregation_function: callable = None,
tokenizer: object = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use the tokenizer object here transformers.PreTrainedTokenizer

):
"""A ROUGE wrapper method. Relies on `rouge_scorer`.

Expand All @@ -338,6 +339,8 @@ def __init__(
Defaults to None if no normalization is applied.
normalize_pred (callable, optional): Function to use to normalize the predicted strings.
Defaults to None if no normalization is applied.
tokenizer (object, optional): An object with `tokenize` method to be used by rouge scorer. If None, rouge-scorer's
default tokenizer will be used.
"""
if aggregation_function and bootstrap:
hlog_warn("Can't use both bootstrapping and an aggregation function in Rouge. Keeping bootstrap.")
Expand All @@ -350,7 +353,7 @@ def __init__(
raise ValueError(
f"Rouge was initialised with method {methods}, which is not in {','.join(self.ALLOWED_ROUGE_METHODS)}"
)
self.scorer = rouge_scorer.RougeScorer([methods])
self.scorer = rouge_scorer.RougeScorer([methods], tokenizer=tokenizer)
self.multiple_golds = multiple_golds
self.bootstrap = bootstrap
self.normalize_gold = normalize_gold
Expand Down Expand Up @@ -416,8 +419,18 @@ def __init__(
normalize_gold: callable = None,
normalize_pred: callable = None,
):
"""A BERT scorer class. Relies on some called extracted from `bert-score`. By default, will use the
`microsoft/deberta-large-mnli` as scorer
r"""A BERT scorer class. Relies on some called extracted from `bert-score`. By default, will use the
`microsoft/deberta-large-mnli` as scorer. For each tokenized (pred, target) pair, it computes Precision,
Recall and F1 as following:

Precision = \sum_{t=1}^{len(pred)} \div{max(Cos.Sim.(pred_t, target))}{IDF(pred_t)}

Recall = \sum_{t=1}^{len(target)} \div{max(Cos.Sim.(target_t, pred))}{IDF(target_t)}

F1 = \div{Precision * Recall}{Precision + Recall}

in which `Cos.Sim.` is the Cosine Similarity metric and `IDF(.)` represents the Inverse Document
Frequency of its input token. It defaults to 1 for all tokens and 0 for EOS and SEP tokens.

Args:
normalize_gold (callable, optional): Function to use to normalize the reference strings.
Expand Down Expand Up @@ -563,19 +576,19 @@ def __init__(
self.strip_prediction = strip_prediction
self.sample_aggregations = {"longest_common_prefix_length": max, "edit_distance": min, "edit_similarity": max}

def compute(self, gold: list[str], predictions: list[str], **kwargs) -> dict:
def compute(self, golds: list[str], predictions: list[str], **kwargs) -> dict:
"""Computes all the requested metrics on the golds and prediction.

Args:
gold (list[str]): A list of possible golds. If it contains more than one item, only the first one is kept.
golds (list[str]): A list of possible golds. If it contains more than one item, only the first one is kept.
predictions (list[str]): Predicted strings.

Returns:
dict: The different scores computed
"""
if len(gold) > 0:
if len(golds) > 1:
hlog_warn("Provided more than one gold to compute a string distance metric. Just using the first one.")
reference = gold[0]
reference = golds[0]

result = {m: [] for m in self.metric_types}
for sequence in predictions:
Expand Down
Loading