Skip to content

Commit

Permalink
Fix test_bertscore_sorting bug + validate idf arg (#2727)
Browse files Browse the repository at this point in the history
* fix test_bertscore_sorting bug + validate idf arg

* update code

---------

Co-authored-by: Guilherme Paulino-Passos @ DoC-cluster <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
  • Loading branch information
5 people authored Oct 10, 2024
1 parent 6bfb775 commit 0fe772d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
7 changes: 6 additions & 1 deletion src/torchmetrics/functional/text/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,16 @@ def bert_score(
"""
if len(preds) != len(target):
raise ValueError("Number of predicted and reference sententes must be the same!")
raise ValueError(
"Expected number of predicted and reference sententes to be the same, but got"
f"{len(preds)} and {len(target)}"
)
if not isinstance(preds, (str, list, dict)): # dict for BERTScore class compute call
preds = list(preds)
if not isinstance(target, (str, list, dict)): # dict for BERTScore class compute call
target = list(target)
if not isinstance(idf, bool):
raise ValueError(f"Expected argument `idf` to be a boolean, but got {idf}.")

if verbose and (not _TQDM_AVAILABLE):
raise ModuleNotFoundError(
Expand Down
5 changes: 1 addition & 4 deletions tests/unittests/text/test_bertscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,7 @@ def test_bertscore_differentiability(

@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
@pytest.mark.parametrize(
"idf",
[(False,), (True,)],
)
@pytest.mark.parametrize("idf", [True, False])
def test_bertscore_sorting(idf: bool):
"""Test that BERTScore is invariant to the order of the inputs."""
short = "Short text"
Expand Down

0 comments on commit 0fe772d

Please sign in to comment.