Skip to content

Commit

Permalink
Merge branch 'master' into newmetric/nrmse
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Oct 12, 2024
2 parents b7a116d + 0990ecf commit a1d44ac
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 5 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new metric `ProcrustesDistance` to new domain Shape ([#2723](https://github.com/Lightning-AI/torchmetrics/pull/2723))


- Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776))


- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))


Expand Down
5 changes: 4 additions & 1 deletion src/torchmetrics/functional/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def bert_score(
rescale_with_baseline: bool = False,
baseline_path: Optional[str] = None,
baseline_url: Optional[str] = None,
truncation: bool = False,
) -> Dict[str, Union[Tensor, List[float], str]]:
"""`Bert_score Evaluating Text Generation`_ for text similirity matching.
Expand Down Expand Up @@ -323,6 +324,7 @@ def bert_score(
of the files from `BERT_score`_
baseline_path: A path to the user's own local csv/tsv file with the baseline scale.
baseline_url: A url path to the user's own csv/tsv file with the baseline scale.
truncation: An indication of whether the input sequences should be truncated to the maximum length.
Returns:
Python dictionary containing the keys ``precision``, ``recall`` and ``f1`` with corresponding values.
Expand Down Expand Up @@ -417,13 +419,14 @@ def bert_score(

# We ignore mypy typing below as the proper typing is ensured by conditions above, only mypy cannot infer that.
if _are_valid_lists:
target_dataset = TextDataset(target, tokenizer, max_length, idf=idf) # type: ignore
target_dataset = TextDataset(target, tokenizer, max_length, idf=idf, truncation=truncation) # type: ignore
preds_dataset = TextDataset(
preds, # type: ignore
tokenizer,
max_length,
idf=idf,
tokens_idf=target_dataset.tokens_idf,
truncation=truncation,
)
elif _are_valid_tensors:
target_dataset = TokenizedDataset(**target, idf=idf) # type: ignore
Expand Down
6 changes: 4 additions & 2 deletions src/torchmetrics/functional/text/helper_embedding_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,11 @@ def __init__(
tokenizer: Any,
max_length: int = 512,
preprocess_text_fn: Callable[
[List[str], Any, int], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]]
[List[str], Any, int, bool], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]]
] = _preprocess_text,
idf: bool = False,
tokens_idf: Optional[Dict[int, float]] = None,
truncation: bool = False,
) -> None:
"""Initialize text dataset class.
Expand All @@ -209,9 +210,10 @@ def __init__(
preprocess_text_fn: A function used for processing the input sentences.
idf: An indication of whether calculate token inverse document frequencies to weight the model embeddings.
tokens_idf: Inverse document frequencies (these should be calculated on reference sentences).
truncation: An indication of whether tokenized sequences should be truncated to the maximum length.
"""
_text = preprocess_text_fn(text, tokenizer, max_length)
_text = preprocess_text_fn(text, tokenizer, max_length, truncation)
if isinstance(_text, tuple):
self.text, self.sorting_indices = _text
else:
Expand Down
7 changes: 5 additions & 2 deletions src/torchmetrics/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class BERTScore(Metric):
of the files from `BERT_score`_.
baseline_path: A path to the user's own local csv/tsv file with the baseline scale.
baseline_url: A url path to the user's own csv/tsv file with the baseline scale.
truncation: An indication of whether the input sequences should be truncated to the ``max_length``.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
rescale_with_baseline: bool = False,
baseline_path: Optional[str] = None,
baseline_url: Optional[str] = None,
truncation: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -169,6 +171,7 @@ def __init__(
self.rescale_with_baseline = rescale_with_baseline
self.baseline_path = baseline_path
self.baseline_url = baseline_url
self.truncation = truncation

if user_tokenizer:
self.tokenizer = user_tokenizer
Expand Down Expand Up @@ -210,15 +213,15 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s
preds,
self.tokenizer,
self.max_length,
truncation=False,
truncation=self.truncation,
sort_according_length=False,
own_tokenizer=self.user_tokenizer,
)
target_dict, _ = _preprocess_text(
target,
self.tokenizer,
self.max_length,
truncation=False,
truncation=self.truncation,
sort_according_length=False,
own_tokenizer=self.user_tokenizer,
)
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,20 @@ def test_bertscore_sorting(idf: bool):

# First index should be the self-comparison - sorting by length should not shuffle this
assert score["f1"][0] > score["f1"][1]


@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize("truncation", [True, False])
def test_bertscore_truncation(truncation: bool):
    """Verify the ``truncation`` flag: over-long inputs score when enabled, raise when disabled."""
    # Inputs far longer than any transformer's maximum sequence length.
    long_preds = ["abc " * 2000]
    long_targets = ["def " * 2000]
    metric = BERTScore(truncation=truncation)

    if not truncation:
        # Without truncation the tokenized sequence overflows the model's
        # position embeddings, which surfaces as a tensor-size mismatch.
        with pytest.raises(RuntimeError, match="The expanded size of the tensor.*must match.*"):
            metric(long_preds, long_targets)
    else:
        score = metric(long_preds, long_targets)
        assert score["f1"] > 0.0

0 comments on commit a1d44ac

Please sign in to comment.