Merge pull request #54 from gomate-community/feature/add_rouge_metric
add rouge metric
Showing 4 changed files with 176 additions and 1 deletion.
@@ -1,5 +1,6 @@
 from .base import Metric, MetricWithLLM, add_attribute
 from ._context_recall import ContextRecall
+from ._answer_rouge_correctness import AnswerRougeCorrectness
 from ._context_reject_rate import ContextRejectRate
 from ._answer_exact_match import AnswerEMCorrectness
-from ._answer_claim_recall import AnswerNLICorrectness
+from ._answer_claim_recall import AnswerNLICorrectness
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-

from typing import List, Any, Callable, Optional
from rouge_score import rouge_scorer
import datasets

from datasets import Dataset
from dataclasses import dataclass

from rageval.metrics import Metric, add_attribute

_DESCRIPTION = """Estimate the ROUGE score by comparing the answer with the groundtruth answers.

ROUGE is case insensitive, so the input text is converted to lower case before computing the score. This metric is a wrapper around https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py
"""

_KWARGS_DESCRIPTION = """\
Args:
    name : str
    rouge_type : str, the ROUGE type to calculate, one of 'rouge1', 'rouge2', 'rougeL', 'rougeLsum'.
        "rouge1": unigram (1-gram) based scoring
        "rouge2": bigram (2-gram) based scoring
        "rougeL": longest common subsequence based scoring
        "rougeLsum": splits the text on "\n" before computing rougeL

Optional Args:
    tokenizer : Callable, a tokenizer can be passed to the scorer, replacing the default tokenizer which tokenizes on whitespace. This is useful especially for non-Latin languages; for example, `jieba.cut` can be used for Chinese.

Functions:
    _compute_one: compute the ROUGE score between a single `answer` and its list of `gt_answers`.

Examples:
    >>> from datasets import Dataset
    >>> import rageval as rl
    >>> sample = {
    ...     "answers": [
    ...         "Some nanomaterials may give rise to various kinds of lung damage."
    ...     ],
    ...     "gt_answers": [
    ...         [
    ...             "Nanoparticles can penetrate the body, affecting the lungs, brain, and other organs,\
    ...             leading to possible respiratory, cardiovascular, and brain health problems.",
    ...             "Due to their small size, nanoparticles can infiltrate the body and impact vital organs,\
    ...             posing risks to respiratory, heart, and neurological health."
    ...         ]
    ...     ]
    ... }
    >>> dataset = Dataset.from_dict(sample)
    >>> metric = rl.metrics.AnswerRougeCorrectness('rougeL')
    >>> score, results = metric.compute(dataset, batch_size=1)
    >>> assert 0 <= score <= 1
    >>> type(results)
    <class 'datasets.arrow_dataset.Dataset'>
"""

_CITATION = """\
@article{lewis2020retrieval,
    title={Retrieval-augmented generation for knowledge-intensive nlp tasks},
    author={Lewis, Patrick and Perez, Ethan and Piktus, Aleksandra and Petroni, Fabio and Karpukhin, Vladimir and Goyal, Naman and K{\"u}ttler, Heinrich and Lewis, Mike and Yih, Wen-tau and Rockt{\"a}schel, Tim and others},
    journal={Advances in Neural Information Processing Systems},
    volume={33},
    pages={9459--9474},
    year={2020}
}
"""


@dataclass
@add_attribute('mtype', 'AnswerCorrectness')
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AnswerRougeCorrectness(Metric):

    name = "answer_rouge_correctness"

    ALIAS = ['answer_rouge_correctness']

    def __init__(self, rouge_type: str, tokenizer: Optional[Callable] = None):
        """Explicitly initialize AnswerRougeCorrectness to ensure the parent class is initialized, and set the ROUGE type and tokenizer."""
        self._required_columns = ['answers', 'gt_answers']
        self.rouge_type = rouge_type
        self.scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True, tokenizer=tokenizer)
        super().__init__()

    def __repr__(self) -> str:
        """:return: Formatted string representation of the metric."""
        return f"{self.ALIAS[0]}"

    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            inputs_description=_KWARGS_DESCRIPTION,
            citation=_CITATION,
            homepage="",
            features=datasets.Features(
                {
                    "answers": datasets.Value("string", id="sequence"),
                    "gt_answers": datasets.Sequence(datasets.Value("string"), id="sequence"),
                }
            ),
            codebase_urls=[],
            reference_urls=[]
        )

    def _validate_data(self, dataset: Dataset) -> bool:
        super()._validate_data(dataset)
        if not all(isinstance(answer, str) for answer in dataset["answers"]):
            raise ValueError("Each answer should be a string.")
        if not all(isinstance(a, list) and all(isinstance(item, str) for item in a) for a in dataset["gt_answers"]):
            raise ValueError("Each element of gt_answers should be a list of strings.")

    def _compute_one(self, answer: str, gt_answers: List[str]) -> float:
        """Evaluate the ROUGE score between a single answer and its groundtruth answers."""
        score = self.scorer.score_multi(gt_answers, answer)
        return score[self.rouge_type].fmeasure

    def _compute_batch(self, dataset: Dataset) -> list:
        """Evaluate the ROUGE scores of a batch of answers."""
        return [
            self._compute_one(answer, gt_answer)
            for answer, gt_answer in zip(dataset["answers"], dataset["gt_answers"])
        ]
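
For reference, a minimal standalone sketch of the call that `_compute_one` wraps: `RougeScorer.score_multi` (available in `rouge_score` 0.1.2, as pinned below) scores one prediction against several references and keeps the best-matching one per ROUGE type. The strings here are illustrative only.

from rouge_score import rouge_scorer

# Score one prediction against multiple references; score_multi keeps
# the best-matching reference for each requested ROUGE type.
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
result = scorer.score_multi(
    ["The quick brown fox jumps over the lazy dog.",
     "The brown fox jumps over the lazy dog."],
    "A quick brown fox jumped over a lazy dog.",
)
print(result['rougeL'].fmeasure)  # a float in [0, 1]; the metric returns this F-measure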
@@ -16,3 +16,4 @@ transformers == 4.37.2
 torch == 2.2.0
 pandas == 2.0.0
 nltk == 3.8.1
+rouge_score == 0.1.2
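
The new dependency is published on PyPI, so assuming a standard pip workflow, `pip install rouge_score==0.1.2` should pull it in along with its own dependencies (notably `absl-py` and `nltk`).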
@@ -0,0 +1,53 @@
import pytest
from datasets import Dataset
from typing import List
from rageval.metrics import AnswerRougeCorrectness


class CharTokenizer:
    """Tokenize text into characters."""

    def tokenize(self, text: str) -> List[str]:
        # Tokenize by characters to avoid a dependency on word segmentation methods.
        return [c for c in text]

@pytest.fixture(scope='module')
def sample():
    test_case = {
        "answers": [
            "###刚刚发声,A股这种情况十分罕见!大聪明逆市抄底330亿,一篇研报引爆全球,市场逻辑生变?",
            "The quick brown fox jumps over the lazy dog."
        ],
        "gt_answers": [
            [
                "刚刚过去的这个月,美股总市值暴跌了将近6万亿美元(折合人民币超过40万亿),这背后的原因可能不仅仅是加息这么简单。最近瑞士信贷知名分析师Zoltan Polzsar撰写了一篇极其重要的文章,详细分析了现有世界秩序的崩坏本质以及美国和西方将要采取的应对策略。在该文中,Zoltan Polzsar直指美国通胀的本质和其长期性。同期,A股市场亦出现了大幅杀跌的情况。"
            ],
            [
                "The quick brown fox jumps over the lazy dog.",
                "The brown fox jumps over the lazy dog."
            ]
        ]
    }
    return test_case


@pytest.fixture(scope='module')
def testset(sample):
    ds = Dataset.from_dict(sample)
    return ds

def test_case_on_answer_rouge_correctness(testset):
    # Test with a character-level tokenizer, suitable for Chinese text.
    chinese_tokenizer = CharTokenizer()
    metric = AnswerRougeCorrectness('rouge1', chinese_tokenizer)
    score, results = metric.compute(testset, 1)
    assert metric.mtype == 'AnswerCorrectness'
    assert 0 <= score <= 1
    assert isinstance(results, Dataset)

    # Test with the default whitespace tokenizer, suitable for English text.
    metric = AnswerRougeCorrectness('rouge1')
    score, results = metric.compute(testset, 1)
    assert metric.mtype == 'AnswerCorrectness'
    assert 0 <= score <= 1
    assert isinstance(results, Dataset)
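
The metric's docstring suggests `jieba.cut` for Chinese; as a word-level alternative to the character-level `CharTokenizer` above, a hedged sketch follows (it assumes the third-party `jieba` package is installed, which this PR does not add; the only contract `rouge_scorer.RougeScorer` requires of a custom tokenizer is a `tokenize(text)` method returning a list of tokens).

from typing import List

import jieba  # hypothetical extra dependency, not added by this PR


class JiebaTokenizer:
    """Tokenize Chinese text into segmented words rather than single characters."""

    def tokenize(self, text: str) -> List[str]:
        # jieba.cut returns a generator of word segments; materialize it as a list.
        return list(jieba.cut(text))


# Usage would mirror CharTokenizer above:
# metric = AnswerRougeCorrectness('rouge1', JiebaTokenizer())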