diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py index cfdb8c743b4..c073ac6e74f 100644 --- a/src/torchmetrics/functional/text/bert.py +++ b/src/torchmetrics/functional/text/bert.py @@ -349,11 +349,16 @@ def bert_score( """ if len(preds) != len(target): - raise ValueError("Number of predicted and reference sententes must be the same!") + raise ValueError( + "Expected number of predicted and reference sententes to be the same, but got" + f"{len(preds)} and {len(target)}" + ) if not isinstance(preds, (str, list, dict)): # dict for BERTScore class compute call preds = list(preds) if not isinstance(target, (str, list, dict)): # dict for BERTScore class compute call target = list(target) + if not isinstance(idf, bool): + raise ValueError(f"Expected argument `idf` to be a boolean, but got {idf}.") if verbose and (not _TQDM_AVAILABLE): raise ModuleNotFoundError( diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index dfd6d60a0e5..99605b69cc2 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -174,10 +174,7 @@ def test_bertscore_differentiability( @skip_on_connection_issues() @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") -@pytest.mark.parametrize( - "idf", - [(False,), (True,)], -) +@pytest.mark.parametrize("idf", [True, False]) def test_bertscore_sorting(idf: bool): """Test that BERTScore is invariant to the order of the inputs.""" short = "Short text"