Skip to content

Commit

Permalink
TST check for nan and inf + single sample for metrics cl… (scikit-lea…
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored and rth committed Oct 4, 2019
1 parent ac72a48 commit a47e914
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 18 deletions.
24 changes: 13 additions & 11 deletions sklearn/metrics/cluster/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from scipy import sparse as sp

from .expected_mutual_info_fast import expected_mutual_information
from ...utils.validation import check_array
from ...utils.validation import check_array, check_consistent_length
from ...utils.fixes import comb, _astype_copy_false


Expand All @@ -36,14 +36,18 @@ def check_clusterings(labels_true, labels_pred):
Parameters
----------
labels_true : int array, shape = [n_samples]
The true labels
labels_true : array-like of shape (n_samples,)
The true labels.
labels_pred : int array, shape = [n_samples]
The predicted labels
labels_pred : array-like of shape (n_samples,)
The predicted labels.
"""
labels_true = np.asarray(labels_true)
labels_pred = np.asarray(labels_pred)
labels_true = check_array(
labels_true, ensure_2d=False, ensure_min_samples=0
)
labels_pred = check_array(
labels_pred, ensure_2d=False, ensure_min_samples=0
)

# input checks
if labels_true.ndim != 1:
Expand All @@ -52,10 +56,8 @@ def check_clusterings(labels_true, labels_pred):
if labels_pred.ndim != 1:
raise ValueError(
"labels_pred must be 1D: shape is %r" % (labels_pred.shape,))
if labels_true.shape != labels_pred.shape:
raise ValueError(
"labels_true and labels_pred must have same size, got %d and %d"
% (labels_true.shape[0], labels_pred.shape[0]))
check_consistent_length(labels_true, labels_pred)

return labels_true, labels_pred


Expand Down
34 changes: 29 additions & 5 deletions sklearn/metrics/cluster/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,7 @@ def test_normalized_output(metric_name):
# 0.22 AMI and NMI changes
@pytest.mark.filterwarnings('ignore::FutureWarning')
@pytest.mark.parametrize(
"metric_name",
dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
"metric_name", dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
)
def test_permute_labels(metric_name):
# All clustering metrics do not change score due to permutations of labels
Expand All @@ -150,11 +149,10 @@ def test_permute_labels(metric_name):
# 0.22 AMI and NMI changes
@pytest.mark.filterwarnings('ignore::FutureWarning')
@pytest.mark.parametrize(
"metric_name",
dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
"metric_name", dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS)
)
# For all clustering metrics Input parameters can be both
# in the form of arrays lists, positive, negetive or string
# in the form of arrays lists, positive, negative or string
def test_format_invariance(metric_name):
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 1, 2, 3, 4, 5, 6, 7]
Expand Down Expand Up @@ -183,3 +181,29 @@ def generate_formats(y):
y_true_gen = generate_formats(y_true)
for (y_true_fmt, fmt_name) in y_true_gen:
assert score_1 == metric(X, y_true_fmt)


@pytest.mark.parametrize("metric", SUPERVISED_METRICS.values())
def test_single_sample(metric):
    """Smoke-test each supervised metric on one-element label lists.

    The metric is called for every combination of a true label and a
    predicted label drawn from {0, 1}; the test passes as long as no call
    raises.
    """
    # Only the supervised metrics support a single sample.
    for label_true in (0, 1):
        for label_pred in (0, 1):
            metric([label_true], [label_pred])


@pytest.mark.parametrize(
    "metric_name, metric_func",
    dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS).items()
)
def test_inf_nan_input(metric_name, metric_func):
    """Check that clustering metrics reject NaN and infinite label values.

    Supervised metrics are called with a valid first label list and an
    invalid second one; the other metrics are called with a random integer
    matrix ``X`` and an invalid label list. Every invalid input must raise
    a ``ValueError`` whose message matches ``'contains NaN, infinity'``.
    """
    if metric_name in SUPERVISED_METRICS:
        invalids = [([0, 1], [np.inf, np.inf]),
                    ([0, 1], [np.nan, np.nan]),
                    ([0, 1], [np.nan, np.inf])]
    else:
        X = np.random.randint(10, size=(2, 10))
        invalids = [(X, [np.inf, np.inf]),
                    (X, [np.nan, np.nan]),
                    (X, [np.nan, np.inf])]
    # Open one ``pytest.raises`` context per input: a single context wrapped
    # around the whole loop exits on the first exception, which would leave
    # the remaining invalid inputs unexercised.
    for args in invalids:
        with pytest.raises(ValueError, match='contains NaN, infinity'):
            metric_func(*args)
4 changes: 2 additions & 2 deletions sklearn/metrics/cluster/tests/test_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
@ignore_warnings(category=FutureWarning)
def test_error_messages_on_wrong_input():
for score_func in score_funcs:
expected = ('labels_true and labels_pred must have same size,'
' got 2 and 3')
expected = (r'Found input variables with inconsistent numbers '
r'of samples: \[2, 3\]')
with pytest.raises(ValueError, match=expected):
score_func([0, 1], [1, 1, 1])

Expand Down

0 comments on commit a47e914

Please sign in to comment.