Commit aa7b32d

Merge pull request #54 from gomate-community/feature/add_rouge_metric

add rouge metric

faneshion authored Mar 6, 2024
2 parents fbc5cc2 + f032a3b commit aa7b32d

Showing 4 changed files with 176 additions and 1 deletion.
3 changes: 2 additions & 1 deletion rageval/metrics/__init__.py
@@ -1,5 +1,6 @@
from .base import Metric, MetricWithLLM, add_attribute
from ._context_recall import ContextRecall
from ._answer_rouge_correctness import AnswerRougeCorrectness
from ._context_reject_rate import ContextRejectRate
from ._answer_exact_match import AnswerEMCorrectness
from ._answer_claim_recall import AnswerNLICorrectness
120 changes: 120 additions & 0 deletions rageval/metrics/_answer_rouge_correctness.py
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-

from typing import List, Any, Callable, Optional
from rouge_score import rouge_scorer
import datasets

from datasets import Dataset
from dataclasses import dataclass

from rageval.metrics import Metric, add_attribute

_DESCRIPTION = """Estimates ROUGE score by estimating answer and groundtruth answers.
ROUGE is case insensitive, so the input text is converted to lower case before computing the score. This metrics is a wrapper around the https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py
"""

_KWARGS_DESCRIPTION = """\
Args:
name : str
    rouge_type : str, the ROUGE type to calculate. One of 'rouge1', 'rouge2', 'rougeL', 'rougeLsum'.
        "rouge1": unigram (1-gram) based scoring.
        "rouge2": bigram (2-gram) based scoring.
        "rougeL": longest common subsequence based scoring.
        "rougeLsum": like rougeL, but splits the text using "\n".
Optional Args:
    tokenizer : Callable, an optional tokenizer passed to the scorer to replace the default whitespace tokenizer; this is especially useful for non-Latin languages. For example, a tokenizer built on `jieba` can be used for Chinese.
Functions:
    _compute_one: compute the ROUGE score between a single `answer` and the ground-truth answers in `gt_answers`.
Examples:
>>> from datasets import Dataset
>>> import rageval as rl
>>> sample = {
... "answers": [
... "Some nanomaterials may give rise to various kinds of lung damage."
... ],
... "gt_answers":[
... [
... "Nanoparticles can penetrate the body, affecting the lungs, brain, and other organs,\
... leading to possible respiratory, cardiovascular, and brain health problems.",
... "Due to their small size, nanoparticles can infiltrate the body and impact vital organs,\
... posing risks to respiratory, heart, and neurological health."
... ]
... ]
... }
>>> dataset = Dataset.from_dict(sample)
>>> metric = rl.metrics.AnswerRougeCorrectness('rougeL')
    >>> score, results = metric.compute(dataset, batch_size=1)
>>> assert 0 <= score <= 1
>>> type(results)
<class 'datasets.arrow_dataset.Dataset'>
"""

_CITATION = """\
@article{lewis2020retrieval,
title={Retrieval-augmented generation for knowledge-intensive nlp tasks},
author={Lewis, Patrick and Perez, Ethan and Piktus, Aleksandra and Petroni, Fabio and Karpukhin, Vladimir and Goyal, Naman and K{\"u}ttler, Heinrich and Lewis, Mike and Yih, Wen-tau and Rockt{\"a}schel, Tim and others},
journal={Advances in Neural Information Processing Systems},
volume={33},
pages={9459--9474},
year={2020}
}
"""


@dataclass
@add_attribute('mtype', 'AnswerCorrectness')
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class AnswerRougeCorrectness(Metric):

name = "answer_rouge_correctness"

ALIAS = ['answer_rouge_correctness']

def __init__(self, rouge_type: str, tokenizer: Optional[Callable] = None):
"""Explicitly initialize the AnswerRougeCorrectness to ensure all parent class initialized as well as initialize the rouge type and tokenizer."""
self._required_columns = ['answers', 'gt_answers']
self.rouge_type = rouge_type
self.scorer = rouge_scorer.RougeScorer([rouge_type], use_stemmer=True, tokenizer=tokenizer)
super().__init__()

def __repr__(self) -> str:
""":return: Formated string representation of the metric."""
return f"{self.ALIAS[0]}"

def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
inputs_description=_KWARGS_DESCRIPTION,
citation=_CITATION,
homepage="",
features=datasets.Features(
{
"answers": datasets.Value("string", id="sequence"),
"contexts": datasets.Value("string", id="sequence"),
}
),
codebase_urls=[],
reference_urls=[]
)

def _validate_data(self, dataset: Dataset) -> bool:
super()._validate_data(dataset)
if not all(isinstance(answer, str) for answer in dataset["answers"]):
raise ValueError("The type of answers should be a string.")
        if not all(isinstance(a, list) and all(isinstance(item, str) for item in a) for a in dataset["gt_answers"]):
raise ValueError("The type of gt_answers should be a list of strings.")

def _compute_one(self, answer: str, gt_answers: List[str]) -> float:
"""Evaluate the ROUGE between a single answer and groundtruth answers."""
score = self.scorer.score_multi(gt_answers, answer)
return score[self.rouge_type].fmeasure

def _compute_batch(self, dataset: Dataset) -> list:
"""Evaluate the ROUGE of a batch of answers."""
results = [self._compute_one(answer, gt_answer) for answer, gt_answer in zip(dataset["answers"], dataset["gt_answers"])]
return results
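
For reference, `_compute_one` above is a thin wrapper around rouge_score's `score_multi`, which scores a single prediction against several references and keeps the best-matching reference for each ROUGE type; the metric then reports only the F-measure. Below is a minimal standalone sketch of that underlying call, assuming rouge_score == 0.1.2 as pinned in requirements.txt (the example strings are illustrative only):

from rouge_score import rouge_scorer

# Same configuration as AnswerRougeCorrectness.__init__ with the default tokenizer.
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

# score_multi(targets, prediction) scores the prediction against every reference
# and keeps the best-matching one for each ROUGE type.
scores = scorer.score_multi(
    ["The brown fox jumps over the lazy dog.",
     "The quick brown fox jumps over the lazy dog."],
    "The quick brown fox jumps over the lazy dog.",
)

# Each entry is a Score tuple; _compute_one returns only the fmeasure field.
print(scores['rougeL'].precision, scores['rougeL'].recall, scores['rougeL'].fmeasure)
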
1 change: 1 addition & 0 deletions requirements.txt
@@ -16,3 +16,4 @@ transformers == 4.37.2
torch == 2.2.0
pandas == 2.0.0
nltk == 3.8.1
rouge_score == 0.1.2
53 changes: 53 additions & 0 deletions tests/units/test_answer_rouge.py
@@ -0,0 +1,53 @@
import pytest
from datasets import Dataset
from typing import List
from rageval.metrics import AnswerRougeCorrectness

class CharTokenizer:
"""Tokenize text into characters."""
def tokenize(self, text: str) -> List[str]:
# Tokenize by characters to avoid a dependency on word segmentation methods.
return [c for c in text]

@pytest.fixture(scope='module')
def sample():
test_case = {
"answers": [
"###刚刚发声,A股这种情况十分罕见!大聪明逆市抄底330亿,一篇研报引爆全球,市场逻辑生变?",
"The quick brown fox jumps over the lazy dog."
],
"gt_answers": [
[
"刚刚过去的这个月,美股总市值暴跌了将近6万亿美元(折合人民币超过40万亿),这背后的原因可能不仅仅是加息这么简单。最近瑞士信贷知名分析师Zoltan Polzsar撰写了一篇极其重要的文章,详细分析了现有世界秩序的崩坏本质以及美国和西方将要采取的应对策略。在该文中,Zoltan Polzsar直指美国通胀的本质和其长期性。同期,A股市场亦出现了大幅杀跌的情况。"
],
[
"The quick brown fox jumps over the lazy dog.",
"The brown fox jumps over the lazy dog."
]
]
}
return test_case


@pytest.fixture(scope='module')
def testset(sample):
ds = Dataset.from_dict(sample)
return ds


def test_case_on_answer_rouge_correctness(testset):

    # Test with a character-level tokenizer (suitable for Chinese)
chinese_tokenizer = CharTokenizer()
metric = AnswerRougeCorrectness('rouge1', chinese_tokenizer)
score, results = metric.compute(testset, 1)
assert metric.mtype == 'AnswerCorrectness'
assert 0 <= score <= 1
assert isinstance(results, Dataset)

    # Test with the default whitespace tokenizer (English text)
metric = AnswerRougeCorrectness('rouge1')
score, results = metric.compute(testset, 1)
assert metric.mtype == 'AnswerCorrectness'
assert 0 <= score <= 1
assert isinstance(results, Dataset)
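
The CharTokenizer above deliberately avoids word segmentation; for word-level Chinese ROUGE, the `jieba` suggestion from the metric's docstring can be followed by wrapping `jieba.cut` in the `tokenize()` interface that rouge_scorer expects. This is a minimal sketch under that assumption (jieba is not pinned in requirements.txt, and the sample text is shortened from the fixture above):

import jieba  # assumed optional dependency, not listed in requirements.txt
from datasets import Dataset
from rageval.metrics import AnswerRougeCorrectness

class JiebaTokenizer:
    """Wrap jieba.cut in the tokenize() interface expected by rouge_scorer."""
    def tokenize(self, text):
        return list(jieba.cut(text))

sample = {
    "answers": ["A股市场出现了大幅杀跌的情况。"],
    "gt_answers": [["同期,A股市场亦出现了大幅杀跌的情况。"]],
}
dataset = Dataset.from_dict(sample)

metric = AnswerRougeCorrectness('rougeL', tokenizer=JiebaTokenizer())
score, results = metric.compute(dataset, batch_size=1)
assert 0 <= score <= 1
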
