diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index 62ea5645048..fabfca2c0eb 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -134,7 +134,14 @@ def _chrf_score( whitespace: bool = False, return_sentence_level_score: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Wrapper for deprecated import.""" + """Wrapper for deprecated import. + + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> _chrf_score(preds, target) + tensor(0.8640) + + """ _deprecated_root_import_func("chrf_score", "text") return chrf_score( preds=preds, diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 397777f7310..375355b85cb 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -11,9 +11,525 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Tuple, Union +# referenced from +# Library Name: torchtext +# Authors: torchtext authors +# Date: 2021-11-25 +# Link: -from torch import Tensor +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +# Copyright 2017 Maja Popovic + +# The program is distributed under the terms +# of the GNU General Public Licence (GPL) + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. + +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from collections import defaultdict +from itertools import chain +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.text.helper import _validate_inputs + +_EPS_SMOOTHING = tensor(1e-16) +# Taken from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py +_PUNCTUATIONS = set("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~") + + +def _prepare_n_grams_dicts( + n_char_order: int, n_word_order: int +) -> Tuple[ + Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor] +]: + """Prepare dictionaries with default zero values for total ref, hypothesis and matching character and word n-grams. + + Args: + n_char_order: A character n-gram order. + n_word_order: A word n-gram order. + + Return: + Dictionaries with default zero values for total reference, hypothesis and matching character and word + n-grams. + + """ + total_preds_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_preds_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_target_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_target_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_matching_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_matching_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + + return ( + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, + total_matching_char_n_grams, + total_matching_word_n_grams, + ) + + +def _get_characters(sentence: str, whitespace: bool) -> List[str]: + """Split sentence into individual characters. + + Args: + sentence: An input sentence to split. + whitespace: An indication whether to keep whitespaces during character n-gram extraction. + + Return: + A list of separated characters. + + """ + if whitespace: + return list(sentence) + return list(sentence.strip().replace(" ", "")) + + +def _separate_word_and_punctuation(word: str) -> List[str]: + """Separates out punctuations from beginning and end of words for chrF. + + Adapted from https://github.com/m-popovic/chrF and + https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. + + Args: + word: An input word to be separated from a punctuation if present. + + Return: + A list of a single word or a separated word and punctuation. + + """ + if len(word) == 1: + return [word] + + if word[-1] in _PUNCTUATIONS: + return [word[:-1], word[-1]] + if word[0] in _PUNCTUATIONS: + return [word[0], word[1:]] + return [word] + + +def _get_words_and_punctuation(sentence: str) -> List[str]: + """Separates out punctuations from beginning and end of words for chrF for all words in the sentence. + + Args: + sentence: An input sentence to split + + Return: + An aggregated list of separated words and punctuations. + + """ + return list(chain.from_iterable(_separate_word_and_punctuation(word) for word in sentence.strip().split())) + + +def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, Dict[Tuple[str, ...], Tensor]]: + """Calculate n-gram counts. + + Args: + char_or_word_list: A list of characters of words + n_gram_order: The largest number of n-gram. + + Return: + A dictionary of dictionaries with a counts of given n-grams. + + """ + ngrams: Dict[int, Dict[Tuple[str, ...], Tensor]] = defaultdict(lambda: defaultdict(lambda: tensor(0.0))) + for n in range(1, n_gram_order + 1): + for ngram in (tuple(char_or_word_list[i : i + n]) for i in range(len(char_or_word_list) - n + 1)): + ngrams[n][ngram] += tensor(1) + return ngrams + + +def _get_n_grams_counts_and_total_ngrams( + sentence: str, n_char_order: int, n_word_order: int, lowercase: bool, whitespace: bool +) -> Tuple[ + Dict[int, Dict[Tuple[str, ...], Tensor]], + Dict[int, Dict[Tuple[str, ...], Tensor]], + Dict[int, Tensor], + Dict[int, Tensor], +]: + """Get n-grams and total n-grams. + + Args: + sentence: An input sentence + n_char_order: A character n-gram order. + n_word_order: A word n-gram order. + lowercase: An indication whether to enable case-insensitivity. + whitespace: An indication whether to keep whitespaces during character n-gram extraction. + + Return: + char_n_grams_counts: A dictionary of dictionaries with sentence character n-grams. + word_n_grams_counts: A dictionary of dictionaries with sentence word n-grams. + total_char_n_grams: A dictionary containing a total number of sentence character n-grams. + total_word_n_grams: A dictionary containing a total number of sentence word n-grams. + + """ + + def _char_and_word_ngrams_counts( + sentence: str, n_char_order: int, n_word_order: int, lowercase: bool + ) -> Tuple[Dict[int, Dict[Tuple[str, ...], Tensor]], Dict[int, Dict[Tuple[str, ...], Tensor]]]: + """Get a dictionary of dictionaries with a counts of given n-grams.""" + if lowercase: + sentence = sentence.lower() + char_n_grams_counts = _ngram_counts(_get_characters(sentence, whitespace), n_char_order) + word_n_grams_counts = _ngram_counts(_get_words_and_punctuation(sentence), n_word_order) + return char_n_grams_counts, word_n_grams_counts + + def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) -> Dict[int, Tensor]: + """Get total sum of n-grams over n-grams w.r.t n.""" + total_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + for n in n_grams_counts: + total_n_grams[n] = sum(n_grams_counts[n].values()).detach().clone() # type: ignore + return total_n_grams + + char_n_grams_counts, word_n_grams_counts = _char_and_word_ngrams_counts( + sentence, n_char_order, n_word_order, lowercase + ) + total_char_n_grams = _get_total_ngrams(char_n_grams_counts) + total_word_n_grams = _get_total_ngrams(word_n_grams_counts) + + return char_n_grams_counts, word_n_grams_counts, total_char_n_grams, total_word_n_grams + + +def _get_ngram_matches( + hyp_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], + ref_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], +) -> Dict[int, Tensor]: + """Get a number of n-gram matches between reference and hypothesis n-grams. + + Args: + hyp_n_grams_counts: n-grams counts for hypothesis + ref_n_grams_counts: n-grams counts for reference + + Return: + matching_n_grams + + """ + matching_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + for n in hyp_n_grams_counts: + min_n_grams = [ + torch.min(ref_n_grams_counts[n][n_gram], hyp_n_grams_counts[n][n_gram]) for n_gram in hyp_n_grams_counts[n] + ] + matching_n_grams[n] = sum(min_n_grams).detach().clone() # type: ignore + return matching_n_grams + + +def _sum_over_dicts(total_n_grams: Dict[int, Tensor], n_grams: Dict[int, Tensor]) -> Dict[int, Tensor]: + """Aggregate total n-grams to keep corpus-level statistics. + + Args: + total_n_grams: A dictionary containing a total corpus-level number of n-grams. + n_grams: A dictionary containing a sentence-level number of n-grams. + + Return: + A dictionary containing a total corpus-level number of n-grams. + + """ + for n in n_grams: + total_n_grams[n] += n_grams[n] + return total_n_grams + + +def _calculate_fscore( + matching_char_n_grams: Dict[int, Tensor], + matching_word_n_grams: Dict[int, Tensor], + hyp_char_n_grams: Dict[int, Tensor], + hyp_word_n_grams: Dict[int, Tensor], + ref_char_n_grams: Dict[int, Tensor], + ref_word_n_grams: Dict[int, Tensor], + n_order: float, + beta: float, +) -> Tensor: + """Calculate sentence-level chrF/chrF++ score. + + For given hypothesis and reference statistics (either sentence-level or corpus-level) + the chrF/chrF++ score is returned. + + Args: + matching_char_n_grams: + A total number of matching character n-grams between the best matching reference and hypothesis. + matching_word_n_grams: + A total number of matching word n-grams between the best matching reference and hypothesis. + hyp_char_n_grams: A total number of hypothesis character n-grams. + hyp_word_n_grams: A total number of hypothesis word n-grams. + ref_char_n_grams: A total number of reference character n-grams. + ref_word_n_grams: A total number of reference word n-grams. + n_order: A sum of character and word n-gram order. + beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. + + Return: + A chrF/chrF++ score. This function is universal both for sentence-level and corpus-level calculation. + + """ + + def _get_n_gram_fscore( + matching_n_grams: Dict[int, Tensor], ref_n_grams: Dict[int, Tensor], hyp_n_grams: Dict[int, Tensor], beta: float + ) -> Dict[int, Tensor]: + """Get n-gram level f-score.""" + precision: Dict[int, Tensor] = { + n: matching_n_grams[n] / hyp_n_grams[n] if hyp_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams + } + recall: Dict[int, Tensor] = { + n: matching_n_grams[n] / ref_n_grams[n] if ref_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams + } + denominator: Dict[int, Tensor] = { + n: torch.max(beta**2 * precision[n] + recall[n], _EPS_SMOOTHING) for n in matching_n_grams + } + f_score: Dict[int, Tensor] = { + n: (1 + beta**2) * precision[n] * recall[n] / denominator[n] for n in matching_n_grams + } + + return f_score + + char_n_gram_f_score = _get_n_gram_fscore(matching_char_n_grams, ref_char_n_grams, hyp_char_n_grams, beta) + word_n_gram_f_score = _get_n_gram_fscore(matching_word_n_grams, ref_word_n_grams, hyp_word_n_grams, beta) + + return (sum(char_n_gram_f_score.values()) + sum(word_n_gram_f_score.values())) / tensor(n_order) + + +def _calculate_sentence_level_chrf_score( + targets: List[str], + pred_char_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], + pred_word_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], + pred_char_n_grams: Dict[int, Tensor], + pred_word_n_grams: Dict[int, Tensor], + n_char_order: int, + n_word_order: int, + n_order: float, + beta: float, + lowercase: bool, + whitespace: bool, +) -> Tuple[Tensor, Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor]]: + """Calculate the best sentence-level chrF/chrF++ score. + + For a given pre-processed hypothesis, all references are evaluated and score and statistics + for the best matching reference is returned. + + Args: + targets: An iterable of references. + pred_char_n_grams_counts: A dictionary of dictionaries with hypothesis character n-grams. + pred_word_n_grams_counts: A dictionary of dictionaries with hypothesis word n-grams. + pred_char_n_grams: A total number of hypothesis character n-grams. + pred_word_n_grams: A total number of hypothesis word n-grams. + n_char_order: A character n-gram order. + n_word_order: A word n-gram order. + n_order: A sum of character and word n-gram order. + beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. + lowercase: An indication whether to enable case-insensitivity. + whitespace: An indication whether to keep whitespaces during character n-gram extraction. + + Return: + Return chrF/chrF++ score and statistics for the best matching hypothesis and reference. + + f_score: A sentence-level chrF/chrF++ score. + matching_char_n_grams: + A total number of matching character n-grams between the best matching reference and hypothesis. + matching_word_n_grams: + A total number of matching word n-grams between the best matching reference and hypothesis. + target_char_n_grams: A total number of reference character n-grams. + target_word_n_grams: A total number of reference word n-grams. + + """ + best_f_score = tensor(0.0) + best_matching_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_matching_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_target_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_target_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + + for target in targets: + ( + target_char_n_grams_counts, + target_word_n_grams_counts, + target_char_n_grams, + target_word_n_grams, + ) = _get_n_grams_counts_and_total_ngrams(target, n_char_order, n_word_order, lowercase, whitespace) + matching_char_n_grams = _get_ngram_matches(target_char_n_grams_counts, pred_char_n_grams_counts) + matching_word_n_grams = _get_ngram_matches(target_word_n_grams_counts, pred_word_n_grams_counts) + + f_score = _calculate_fscore( + matching_char_n_grams, + matching_word_n_grams, + pred_char_n_grams, + pred_word_n_grams, + target_char_n_grams, + target_word_n_grams, + n_order, + beta, + ) + + if f_score > best_f_score: + best_f_score = f_score + best_matching_char_n_grams = matching_char_n_grams + best_matching_word_n_grams = matching_word_n_grams + best_target_char_n_grams = target_char_n_grams + best_target_word_n_grams = target_word_n_grams + + return ( + best_f_score, + best_matching_char_n_grams, + best_matching_word_n_grams, + best_target_char_n_grams, + best_target_word_n_grams, + ) + + +def _chrf_score_update( + preds: Union[str, Sequence[str]], + target: Union[Sequence[str], Sequence[Sequence[str]]], + total_preds_char_n_grams: Dict[int, Tensor], + total_preds_word_n_grams: Dict[int, Tensor], + total_target_char_n_grams: Dict[int, Tensor], + total_target_word_n_grams: Dict[int, Tensor], + total_matching_char_n_grams: Dict[int, Tensor], + total_matching_word_n_grams: Dict[int, Tensor], + n_char_order: int, + n_word_order: int, + n_order: float, + beta: float, + lowercase: bool, + whitespace: bool, + sentence_chrf_score: Optional[List[Tensor]] = None, +) -> Tuple[ + Dict[int, Tensor], + Dict[int, Tensor], + Dict[int, Tensor], + Dict[int, Tensor], + Dict[int, Tensor], + Dict[int, Tensor], + Optional[List[Tensor]], +]: + """Update function for chrf score. + + Args: + preds: An iterable of hypothesis corpus. + target: An iterable of iterables of reference corpus. + total_preds_char_n_grams: A dictionary containing a total number of hypothesis character n-grams. + total_preds_word_n_grams: A dictionary containing a total number of hypothesis word n-grams. + total_target_char_n_grams: A dictionary containing a total number of reference character n-grams. + total_target_word_n_grams: A dictionary containing a total number of reference word n-grams. + total_matching_char_n_grams: + A dictionary containing a total number of matching character n-grams between references and hypotheses. + total_matching_word_n_grams: + A dictionary containing a total number of total matching word n-grams between references and hypotheses. + n_char_order: A character n-gram order. + n_word_order: A word n-gram order. + n_order: Sum of character and word n-gram order. + beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. + lowercase: An indication whether to enable case-insensitivity. + whitespace: An indication whether to keep whitespaces during character n-gram extraction. + sentence_chrf_score: A list of sentence-level chrF/chrF++ scores. + + Return: + total_target_char_n_grams: number of reference character n-grams. + total_target_word_n_grams: number of reference word n-grams. + total_preds_char_n_grams: number of hypothesis character n-grams. + total_preds_word_n_grams: number of hypothesis word n-grams. + total_matching_char_n_grams: number of matching character n-grams between references and hypotheses. + total_matching_word_n_grams: number of total matching word n-grams between references and hypotheses. + sentence_chrf_score: A list of sentence-level chrF/chrF++ scores. + + Raises: + ValueError: + If length of ``preds`` and ``target`` differs. + + """ + target_corpus, preds = _validate_inputs(target, preds) + + for pred, targets in zip(preds, target_corpus): + ( + pred_char_n_grams_counts, + pred_word_n_grams_counts, + pred_char_n_grams, + pred_word_n_grams, + ) = _get_n_grams_counts_and_total_ngrams(pred, n_char_order, n_word_order, lowercase, whitespace) + total_preds_char_n_grams = _sum_over_dicts(total_preds_char_n_grams, pred_char_n_grams) + total_preds_word_n_grams = _sum_over_dicts(total_preds_word_n_grams, pred_word_n_grams) + + ( + sentence_level_f_score, + matching_char_n_grams, + matching_word_n_grams, + target_char_n_grams, + target_word_n_grams, + ) = _calculate_sentence_level_chrf_score( + targets, # type: ignore + pred_char_n_grams_counts, + pred_word_n_grams_counts, + pred_char_n_grams, + pred_word_n_grams, + n_char_order, + n_word_order, + n_order, + beta, + lowercase, + whitespace, + ) + + if sentence_chrf_score is not None: + sentence_chrf_score.append(sentence_level_f_score.unsqueeze(0)) + + total_target_char_n_grams = _sum_over_dicts(total_target_char_n_grams, target_char_n_grams) + total_target_word_n_grams = _sum_over_dicts(total_target_word_n_grams, target_word_n_grams) + total_matching_char_n_grams = _sum_over_dicts(total_matching_char_n_grams, matching_char_n_grams) + total_matching_word_n_grams = _sum_over_dicts(total_matching_word_n_grams, matching_word_n_grams) + + return ( + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, + total_matching_char_n_grams, + total_matching_word_n_grams, + sentence_chrf_score, + ) + + +def _chrf_score_compute( + total_preds_char_n_grams: Dict[int, Tensor], + total_preds_word_n_grams: Dict[int, Tensor], + total_target_char_n_grams: Dict[int, Tensor], + total_target_word_n_grams: Dict[int, Tensor], + total_matching_char_n_grams: Dict[int, Tensor], + total_matching_word_n_grams: Dict[int, Tensor], + n_order: float, + beta: float, +) -> Tensor: + """Compute chrF/chrF++ score based on pre-computed target, prediction and matching character and word n-grams. + + Args: + total_preds_char_n_grams: number of hypothesis character n-grams. + total_preds_word_n_grams: number of hypothesis word n-grams. + total_target_char_n_grams: number of reference character n-grams. + total_target_word_n_grams: number of reference word n-grams. + total_matching_char_n_grams: number of matching character n-grams between references and hypotheses. + total_matching_word_n_grams: number of total matching word n-grams between references and hypotheses. + n_order: A sum of character and word n-gram order. + beta: + A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. + + Return: + A corpus-level chrF/chrF++ score. + + """ + return _calculate_fscore( + total_matching_char_n_grams, + total_matching_word_n_grams, + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, + n_order, + beta, + ) def chrf_score( @@ -32,10 +548,6 @@ def chrf_score( `chrF++ score`_. This implementation follows the implementations from https://github.com/m-popovic/chrF and https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. - .. attention:: - ChrF has been temporarily removed from the TorchMetrics package - due to licensing issues with the upstream package. - Args: preds: An iterable of hypothesis corpus. target: An iterable of iterables of reference corpus. @@ -50,13 +562,88 @@ def chrf_score( whitespace: An indication whether to keep whitespaces during character n-gram extraction. return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. + Return: + A corpus-level chrF/chrF++ score. + (Optionally) A list of sentence-level chrF/chrF++ scores if `return_sentence_level_score=True`. + + Raises: + ValueError: + If ``n_char_order`` is not an integer greater than or equal to 1. + ValueError: + If ``n_word_order`` is not an integer greater than or equal to 0. + ValueError: + If ``beta`` is smaller than 0. + + Example: + >>> from torchmetrics.functional.text import chrf_score + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> chrf_score(preds, target) + tensor(0.8640) + References: [1] chrF: character n-gram F-score for automatic MT evaluation by Maja Popović `chrF score`_ [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ """ - raise NotImplementedError( - "ChrF has been temporarily removed from the TorchMetrics package" - " due to licensing issues with the upstream package." + if not isinstance(n_char_order, int) or n_char_order < 1: + raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") + if not isinstance(n_word_order, int) or n_word_order < 0: + raise ValueError("Expected argument `n_word_order` to be an integer greater than or equal to 0.") + if beta < 0: + raise ValueError("Expected argument `beta` to be greater than 0.") + + n_order = float(n_char_order + n_word_order) + + ( + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, + total_matching_char_n_grams, + total_matching_word_n_grams, + ) = _prepare_n_grams_dicts(n_char_order, n_word_order) + + sentence_chrf_score: Optional[List[Tensor]] = [] if return_sentence_level_score else None + + ( + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, + total_matching_char_n_grams, + total_matching_word_n_grams, + sentence_chrf_score, + ) = _chrf_score_update( + preds, + target, + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, + total_matching_char_n_grams, + total_matching_word_n_grams, + n_char_order, + n_word_order, + n_order, + beta, + lowercase, + whitespace, + sentence_chrf_score, + ) + + chrf_f_score = _chrf_score_compute( + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, + total_matching_char_n_grams, + total_matching_word_n_grams, + n_order, + beta, ) + + if sentence_chrf_score: + return chrf_f_score, torch.cat(sentence_chrf_score) + return chrf_f_score diff --git a/src/torchmetrics/text/_deprecated.py b/src/torchmetrics/text/_deprecated.py index 77e32730711..d3ba1c4010e 100644 --- a/src/torchmetrics/text/_deprecated.py +++ b/src/torchmetrics/text/_deprecated.py @@ -57,7 +57,15 @@ def __init__( class _CHRFScore(CHRFScore): - """Wrapper for deprecated import.""" + """Wrapper for deprecated import. + + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> chrf = _CHRFScore() + >>> chrf(preds, target) + tensor(0.8640) + + """ def __init__( self, diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 742388d06ff..1ff412ab1a4 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -11,13 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# referenced from +# Library Name: torchtext +# Authors: torchtext authors and @sluks +# Date: 2021-11-25 +# Link: import itertools from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union -from torch import Tensor +import torch +from torch import Tensor, tensor from torchmetrics import Metric +from torchmetrics.functional.text.chrf import _chrf_score_compute, _chrf_score_update, _prepare_n_grams_dicts from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -49,10 +56,6 @@ class CHRFScore(Metric): in `chrF++ score`_. This implementation follows the implementations from https://github.com/m-popovic/chrF and https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. - .. attention:: - ChrF has been temporarily removed from the TorchMetrics package - due to licensing issues with the upstream package. - As input to ``forward`` and ``update`` the metric accepts the following input: - ``preds`` (:class:`~Sequence`): An iterable of hypothesis corpus @@ -73,6 +76,22 @@ class CHRFScore(Metric): return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + Raises: + ValueError: + If ``n_char_order`` is not an integer greater than or equal to 1. + ValueError: + If ``n_word_order`` is not an integer greater than or equal to 0. + ValueError: + If ``beta`` is smaller than 0. + + Example: + >>> from torchmetrics.text import CHRFScore + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> chrf = CHRFScore() + >>> chrf(preds, target) + tensor(0.8640) + """ is_differentiable: bool = False @@ -93,77 +112,73 @@ def __init__( return_sentence_level_score: bool = False, **kwargs: Any, ) -> None: - # super().__init__(**kwargs) - # - # if not isinstance(n_char_order, int) or n_char_order < 1: - # raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") - # self.n_char_order = n_char_order - # if not isinstance(n_word_order, int) or n_word_order < 0: - # raise ValueError("Expected argument `n_word_order` to be an integer greater than or equal to 0.") - # self.n_word_order = n_word_order - # if beta < 0: - # raise ValueError("Expected argument `beta` to be greater than 0.") - # self.beta = beta - # self.lowercase = lowercase - # self.whitespace = whitespace - # self.return_sentence_level_score = return_sentence_level_score - # - # self.n_order = float(n_char_order + n_word_order) - # - # # Adding state dynamically - # for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): - # for n in range(1, n_gram_order + 1): - # state_name = self._get_state_name(text, n_gram_level, n) - # self.add_state(state_name, tensor(0.0), dist_reduce_fx="sum") - # - # if self.return_sentence_level_score: - # self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - raise NotImplementedError( - "ChrF has been temporarily removed from the TorchMetrics package" - " due to licensing issues with the upstream package." - ) + super().__init__(**kwargs) + + if not isinstance(n_char_order, int) or n_char_order < 1: + raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") + self.n_char_order = n_char_order + if not isinstance(n_word_order, int) or n_word_order < 0: + raise ValueError("Expected argument `n_word_order` to be an integer greater than or equal to 0.") + self.n_word_order = n_word_order + if beta < 0: + raise ValueError("Expected argument `beta` to be greater than 0.") + self.beta = beta + self.lowercase = lowercase + self.whitespace = whitespace + self.return_sentence_level_score = return_sentence_level_score + + self.n_order = float(n_char_order + n_word_order) + + # Adding state dynamically + for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): + for n in range(1, n_gram_order + 1): + state_name = self._get_state_name(text, n_gram_level, n) + self.add_state(state_name, tensor(0.0), dist_reduce_fx="sum") + + if self.return_sentence_level_score: + self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: """Update state with predictions and targets.""" - # n_grams_dicts_tuple = _chrf_score_update( - # preds, - # target, - # *self._convert_states_to_dicts(), - # self.n_char_order, - # self.n_word_order, - # self.n_order, - # self.beta, - # self.lowercase, - # self.whitespace, - # self.sentence_chrf_score if self.return_sentence_level_score else None, - # ) - # self._update_states_from_dicts(n_grams_dicts_tuple[:-1]) - # if self.sentence_chrf_score is not None: - # self.sentence_chrf_score = n_grams_dicts_tuple[-1] - - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: # type: ignore[empty-body] + n_grams_dicts_tuple = _chrf_score_update( + preds, + target, + *self._convert_states_to_dicts(), + self.n_char_order, + self.n_word_order, + self.n_order, + self.beta, + self.lowercase, + self.whitespace, + self.sentence_chrf_score if self.return_sentence_level_score else None, + ) + self._update_states_from_dicts(n_grams_dicts_tuple[:-1]) + if self.sentence_chrf_score is not None: + self.sentence_chrf_score = n_grams_dicts_tuple[-1] + + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Calculate chrF/chrF++ score.""" - # if self.sentence_chrf_score is not None: - # return ( - # _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta), - # torch.cat(self.sentence_chrf_score), - # ) - # return _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta) - - def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES: # type: ignore[empty-body] + if self.sentence_chrf_score is not None: + return ( + _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta), + torch.cat(self.sentence_chrf_score), + ) + return _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta) + + def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES: """Convert global metric states to the n-gram dictionaries to be passed in ``_chrf_score_update``.""" - # n_grams_dicts: Dict[str, Dict[int, Tensor]] = dict( - # zip(_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order)) - # ) - # - # for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): - # for n in range(1, n_gram_order + 1): - # dict_name = self._get_dict_name(text, n_gram_level) - # state_name = self._get_state_name(text, n_gram_level, n) - # - # n_grams_dicts[dict_name][n] = getattr(self, state_name) - # - # return tuple(n_grams_dicts.values()) # type: ignore + n_grams_dicts: Dict[str, Dict[int, Tensor]] = dict( + zip(_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order)) + ) + + for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): + for n in range(1, n_gram_order + 1): + dict_name = self._get_dict_name(text, n_gram_level) + state_name = self._get_state_name(text, n_gram_level, n) + + n_grams_dicts[dict_name][n] = getattr(self, state_name) + + return tuple(n_grams_dicts.values()) # type: ignore def _update_states_from_dicts(self, n_grams_dicts_tuple: _DICT_STATES_TYPES) -> None: """Update global metric states based on the n-gram dictionaries calculated on the current batch.""" @@ -206,5 +221,29 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics.text import CHRFScore + >>> metric = CHRFScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics.text import CHRFScore + >>> metric = CHRFScore() + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ return self._plot(val, ax) diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py new file mode 100644 index 00000000000..233c9451381 --- /dev/null +++ b/tests/unittests/text/test_chrf.py @@ -0,0 +1,163 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Sequence + +import pytest +from torch import Tensor, tensor +from torchmetrics.functional.text.chrf import chrf_score +from torchmetrics.text.chrf import CHRFScore + +from unittests.text._helpers import TextTester +from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references + + +def _reference_sacrebleu_chrf( + preds: Sequence[str], + targets: Sequence[Sequence[str]], + char_order: int, + word_order: int, + lowercase: bool, + whitespace: bool, +) -> Tensor: + try: + from sacrebleu import CHRF + except ImportError: + pytest.skip("test requires sacrebleu package to be installed") + + sacrebleu_chrf = CHRF( + char_order=char_order, word_order=word_order, lowercase=lowercase, whitespace=whitespace, eps_smoothing=True + ) + # Sacrebleu CHRF expects different format of input + targets = [[target[i] for target in targets] for i in range(len(targets[0]))] + sacrebleu_chrf = sacrebleu_chrf.corpus_score(preds, targets).score / 100 + return tensor(sacrebleu_chrf) + + +@pytest.mark.parametrize( + ["char_order", "word_order", "lowercase", "whitespace"], + [ + (6, 2, False, False), + (6, 2, False, True), + (4, 2, True, False), + (6, 0, True, False), + (6, 0, True, True), + (4, 0, False, True), + ], +) +@pytest.mark.parametrize( + ["preds", "targets"], + [(_inputs_multiple_references.preds, _inputs_multiple_references.target)], +) +class TestCHRFScore(TextTester): + """Test class for `CHRFScore` metric.""" + + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_chrf_score_class(self, ddp, preds, targets, char_order, word_order, lowercase, whitespace): + """Test class implementation of metric.""" + metric_args = { + "n_char_order": char_order, + "n_word_order": word_order, + "lowercase": lowercase, + "whitespace": whitespace, + } + nltk_metric = partial( + _reference_sacrebleu_chrf, + char_order=char_order, + word_order=word_order, + lowercase=lowercase, + whitespace=whitespace, + ) + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + targets=targets, + metric_class=CHRFScore, + reference_metric=nltk_metric, + metric_args=metric_args, + ) + + def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace): + """Test functional implementation of metric.""" + metric_args = { + "n_char_order": char_order, + "n_word_order": word_order, + "lowercase": lowercase, + "whitespace": whitespace, + } + nltk_metric = partial( + _reference_sacrebleu_chrf, + char_order=char_order, + word_order=word_order, + lowercase=lowercase, + whitespace=whitespace, + ) + + self.run_functional_metric_test( + preds, + targets, + metric_functional=chrf_score, + reference_metric=nltk_metric, + metric_args=metric_args, + ) + + def test_chrf_score_differentiability(self, preds, targets, char_order, word_order, lowercase, whitespace): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + metric_args = { + "n_char_order": char_order, + "n_word_order": word_order, + "lowercase": lowercase, + "whitespace": whitespace, + } + + self.run_differentiability_test( + preds=preds, + targets=targets, + metric_module=CHRFScore, + metric_functional=chrf_score, + metric_args=metric_args, + ) + + +def test_chrf_empty_functional(): + """Test that eed returns 0 when no input is provided.""" + preds = [] + targets = [[]] + assert chrf_score(preds, targets) == tensor(0.0) + + +def test_chrf_empty_class(): + """Test that eed returns 0 when no input is provided.""" + chrf = CHRFScore() + preds = [] + targets = [[]] + assert chrf(preds, targets) == tensor(0.0) + + +def test_chrf_return_sentence_level_score_functional(): + """Test that chrf can return sentence level scores.""" + preds = _inputs_single_sentence_multiple_references.preds + targets = _inputs_single_sentence_multiple_references.target + _, chrf_sentence_score = chrf_score(preds, targets, return_sentence_level_score=True) + isinstance(chrf_sentence_score, Tensor) + + +def test_chrf_return_sentence_level_class(): + """Test that chrf can return sentence level scores.""" + chrf = CHRFScore(return_sentence_level_score=True) + preds = _inputs_single_sentence_multiple_references.preds + targets = _inputs_single_sentence_multiple_references.target + _, chrf_sentence_score = chrf(preds, targets) + isinstance(chrf_sentence_score, Tensor)