From cc97f16f01fbeec40cc032e828f01fe401ca0e29 Mon Sep 17 00:00:00 2001 From: "Guilherme Paulino-Passos @ DoC-cluster" Date: Mon, 9 Sep 2024 19:21:30 +0100 Subject: [PATCH 1/3] fix test_bertscore_sorting bug + validate idf arg --- src/torchmetrics/functional/text/bert.py | 2 ++ tests/unittests/text/test_bertscore.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index cfdb8c743b4..18a1ae94ef6 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -354,6 +354,8 @@ def bert_score( preds = list(preds) if not isinstance(target, (str, list, dict)): # dict for BERTScore class compute call target = list(target) + if not isinstance(idf, bool): + raise ValueError(f"The value of idf must be a boolean. Value passed:{idf=}") if verbose and (not _TQDM_AVAILABLE): raise ModuleNotFoundError( diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index dfd6d60a0e5..230514b9091 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -175,7 +175,7 @@ def test_bertscore_differentiability( @skip_on_connection_issues() @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") @pytest.mark.parametrize( - "idf", + ["idf"], [(False,), (True,)], ) def test_bertscore_sorting(idf: bool): From 6f50ffb56b4665c94a07b178459e7c13094dfc7b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Sep 2024 16:18:19 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/text/test_bertscore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index 230514b9091..dfd6d60a0e5 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -175,7 +175,7 @@ def test_bertscore_differentiability( @skip_on_connection_issues() @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") @pytest.mark.parametrize( - ["idf"], + "idf", [(False,), (True,)], ) def test_bertscore_sorting(idf: bool): From 51062d5961b51048a0c8211d9fd829d058674f85 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 10 Oct 2024 10:50:03 +0200 Subject: [PATCH 3/3] update code --- src/torchmetrics/functional/text/bert.py | 7 +++++-- tests/unittests/text/test_bertscore.py | 5 +---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index 18a1ae94ef6..c073ac6e74f 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -349,13 +349,16 @@ def bert_score( """ if len(preds) != len(target): - raise ValueError("Number of predicted and reference sententes must be the same!") + raise ValueError( + "Expected number of predicted and reference sententes to be the same, but got" + f"{len(preds)} and {len(target)}" + ) if not isinstance(preds, (str, list, dict)): # dict for BERTScore class compute call preds = list(preds) if not isinstance(target, (str, list, dict)): # dict for BERTScore class compute call target = list(target) if not isinstance(idf, bool): - raise ValueError(f"The value of idf must be a boolean. Value passed:{idf=}") + raise ValueError(f"Expected argument `idf` to be a boolean, but got {idf}.") if verbose and (not _TQDM_AVAILABLE): raise ModuleNotFoundError( diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index dfd6d60a0e5..99605b69cc2 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -174,10 +174,7 @@ def test_bertscore_differentiability( @skip_on_connection_issues() @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") -@pytest.mark.parametrize( - "idf", - [(False,), (True,)], -) +@pytest.mark.parametrize("idf", [True, False]) def test_bertscore_sorting(idf: bool): """Test that BERTScore is invariant to the order of the inputs.""" short = "Short text"