diff --git a/tests/ignite/contrib/metrics/test_average_precision.py b/tests/ignite/contrib/metrics/test_average_precision.py index 3d06998d16e..7a943ae855e 100644 --- a/tests/ignite/contrib/metrics/test_average_precision.py +++ b/tests/ignite/contrib/metrics/test_average_precision.py @@ -63,91 +63,82 @@ def test_check_shape(): ap._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1))) -def test_binary_and_multilabel_inputs(): +@pytest.fixture(params=[item for item in range(8)]) +def test_data_binary_and_multilabel(request): + return [ + # Binary input data of shape (N,) or (N, 1) + (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 1), + (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 1), + # updated batches + (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16), + (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16), + # Binary input data of shape (N, L) + (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 1), + (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 1), + # updated batches + (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16), + (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_binary_and_multilabel_inputs(n_times, test_data_binary_and_multilabel): + y_pred, y, batch_size = test_data_binary_and_multilabel ap = AveragePrecision() + ap.reset() + if batch_size > 1: + n_iters = y.shape[0] // batch_size + 1 + for i in range(n_iters): + idx = i * batch_size + ap.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + else: + ap.update((y_pred, y)) - def _test(y_pred, y, batch_size): - ap.reset() - if batch_size > 1: - n_iters = y.shape[0] // batch_size + 1 - for i in range(n_iters): - idx = i * batch_size - ap.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) - else: - ap.update((y_pred, y)) - - np_y = y.numpy() - np_y_pred = y_pred.numpy() - - res = ap.compute() - assert isinstance(res, float) - assert average_precision_score(np_y, np_y_pred) == pytest.approx(res) - - def get_test_cases(): - test_cases = [ - # Binary input data of shape (N,) or (N, 1) - (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 1), - (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 1), - # updated batches - (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16), - (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16), - # Binary input data of shape (N, L) - (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 1), - (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 1), - # updated batches - (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16), - (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16), - ] + np_y = y.numpy() + np_y_pred = y_pred.numpy() - return test_cases + res = ap.compute() + assert isinstance(res, float) + assert average_precision_score(np_y, np_y_pred) == pytest.approx(res) - for _ in range(5): - # check multiple random inputs as random exact occurencies are rare - test_cases = get_test_cases() - for y_pred, y, batch_size in test_cases: - _test(y_pred, y, batch_size) +@pytest.fixture(params=[item for item in range(4)]) +def test_data_integration_binary_and_multilabel(request): + return [ + # Binary input data of shape (N,) or (N, 1) + (torch.randint(0, 2, size=(100,)).long(), torch.randint(0, 2, size=(100,)).long(), 10), + (torch.randint(0, 2, size=(100, 1)).long(), torch.randint(0, 2, size=(100, 1)).long(), 10), + # Binary input data of shape (N, L) + (torch.randint(0, 2, size=(100, 3)).long(), torch.randint(0, 2, size=(100, 3)).long(), 10), + (torch.randint(0, 2, size=(100, 4)).long(), torch.randint(0, 2, size=(100, 4)).long(), 10), + ][request.param] -def test_integration_binary_and_mulitlabel_inputs(): - def _test(y_pred, y, batch_size): - def update_fn(engine, batch): - idx = (engine.state.iteration - 1) * batch_size - y_true_batch = np_y[idx : idx + batch_size] - y_pred_batch = np_y_pred[idx : idx + batch_size] - return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) - engine = Engine(update_fn) +@pytest.mark.parametrize("n_times", range(5)) +def test_integration_binary_and_mulitlabel_inputs(n_times, test_data_integration_binary_and_multilabel): + y_pred, y, batch_size = test_data_integration_binary_and_multilabel - ap_metric = AveragePrecision() - ap_metric.attach(engine, "ap") + def update_fn(engine, batch): + idx = (engine.state.iteration - 1) * batch_size + y_true_batch = np_y[idx : idx + batch_size] + y_pred_batch = np_y_pred[idx : idx + batch_size] + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) - np_y = y.numpy() - np_y_pred = y_pred.numpy() + engine = Engine(update_fn) - np_ap = average_precision_score(np_y, np_y_pred) + ap_metric = AveragePrecision() + ap_metric.attach(engine, "ap") - data = list(range(y_pred.shape[0] // batch_size)) - ap = engine.run(data, max_epochs=1).metrics["ap"] + np_y = y.numpy() + np_y_pred = y_pred.numpy() - assert isinstance(ap, float) - assert np_ap == pytest.approx(ap) + np_ap = average_precision_score(np_y, np_y_pred) - def get_test_cases(): - test_cases = [ - # Binary input data of shape (N,) or (N, 1) - (torch.randint(0, 2, size=(100,)).long(), torch.randint(0, 2, size=(100,)).long(), 10), - (torch.randint(0, 2, size=(100, 1)).long(), torch.randint(0, 2, size=(100, 1)).long(), 10), - # Binary input data of shape (N, L) - (torch.randint(0, 2, size=(100, 3)).long(), torch.randint(0, 2, size=(100, 3)).long(), 10), - (torch.randint(0, 2, size=(100, 4)).long(), torch.randint(0, 2, size=(100, 4)).long(), 10), - ] - return test_cases + data = list(range(y_pred.shape[0] // batch_size)) + ap = engine.run(data, max_epochs=1).metrics["ap"] - for _ in range(5): - # check multiple random inputs as random exact occurencies are rare - test_cases = get_test_cases() - for y_pred, y, batch_size in test_cases: - _test(y_pred, y, batch_size) + assert isinstance(ap, float) + assert np_ap == pytest.approx(ap) def _test_distrib_binary_and_multilabel_inputs(device): diff --git a/tests/ignite/contrib/metrics/test_cohen_kappa.py b/tests/ignite/contrib/metrics/test_cohen_kappa.py index 4d5c89406cb..fa73a84cdfa 100644 --- a/tests/ignite/contrib/metrics/test_cohen_kappa.py +++ b/tests/ignite/contrib/metrics/test_cohen_kappa.py @@ -71,43 +71,38 @@ def test_cohen_kappa_wrong_weights_type(): ck = CohenKappa(weights="dd") -@pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) -def test_binary_input(weights): - ck = CohenKappa(weights) - - def _test(y_pred, y, batch_size): - ck.reset() - if batch_size > 1: - n_iters = y.shape[0] // batch_size + 1 - for i in range(n_iters): - idx = i * batch_size - ck.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) - else: - ck.update((y_pred, y)) +@pytest.fixture(params=range(4)) +def test_data_binary(request): + return [ + # Binary input data of shape (N,) or (N, 1) + (torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1), + (torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1), + # updated batches + (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16), + (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16), + ][request.param] - np_y = y.numpy() - np_y_pred = y_pred.numpy() - res = ck.compute() - assert isinstance(res, float) - assert cohen_kappa_score(np_y, np_y_pred, weights=weights) == pytest.approx(res) +@pytest.mark.parametrize("n_times", range(5)) +@pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) +def test_binary_input(n_times, weights, test_data_binary): + y_pred, y, batch_size = test_data_binary + ck = CohenKappa(weights) + ck.reset() + if batch_size > 1: + n_iters = y.shape[0] // batch_size + 1 + for i in range(n_iters): + idx = i * batch_size + ck.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + else: + ck.update((y_pred, y)) - def get_test_cases(): - test_cases = [ - # Binary input data of shape (N,) or (N, 1) - (torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1), - (torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1), - # updated batches - (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16), - (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16), - ] - return test_cases + np_y = y.numpy() + np_y_pred = y_pred.numpy() - for _ in range(5): - # check multiple random inputs as random exact occurencies are rare - test_cases = get_test_cases() - for y_pred, y, batch_size in test_cases: - _test(y_pred, y, batch_size) + res = ck.compute() + assert isinstance(res, float) + assert cohen_kappa_score(np_y, np_y_pred, weights=weights) == pytest.approx(res) def test_multilabel_inputs(): @@ -129,44 +124,41 @@ def test_multilabel_inputs(): ck.compute() -@pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) -def test_integration_binary_input(weights): - def _test(y_pred, y, batch_size): - def update_fn(engine, batch): - idx = (engine.state.iteration - 1) * batch_size - y_true_batch = np_y[idx : idx + batch_size] - y_pred_batch = np_y_pred[idx : idx + batch_size] - return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) +@pytest.fixture(params=range(2)) +def test_data_integration_binary(request): + return [ + # Binary input data of shape (N,) or (N, 1) + (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 10), + (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 10), + ][request.param] - engine = Engine(update_fn) - ck_metric = CohenKappa(weights=weights) - ck_metric.attach(engine, "ck") +@pytest.mark.parametrize("n_times", range(5)) +@pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) +def test_integration_binary_input(n_times, weights, test_data_integration_binary): + y_pred, y, batch_size = test_data_integration_binary - np_y = y.numpy() - np_y_pred = y_pred.numpy() + def update_fn(engine, batch): + idx = (engine.state.iteration - 1) * batch_size + y_true_batch = np_y[idx : idx + batch_size] + y_pred_batch = np_y_pred[idx : idx + batch_size] + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) - np_ck = cohen_kappa_score(np_y, np_y_pred, weights=weights) + engine = Engine(update_fn) - data = list(range(y_pred.shape[0] // batch_size)) - ck = engine.run(data, max_epochs=1).metrics["ck"] + ck_metric = CohenKappa(weights=weights) + ck_metric.attach(engine, "ck") - assert isinstance(ck, float) - assert np_ck == pytest.approx(ck) + np_y = y.numpy() + np_y_pred = y_pred.numpy() - def get_test_cases(): - test_cases = [ - # Binary input data of shape (N,) or (N, 1) - (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 10), - (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 10), - ] - return test_cases + np_ck = cohen_kappa_score(np_y, np_y_pred, weights=weights) - for _ in range(5): - # check multiple random inputs as random exact occurencies are rare - test_cases = get_test_cases() - for y_pred, y, batch_size in test_cases: - _test(y_pred, y, batch_size) + data = list(range(y_pred.shape[0] // batch_size)) + ck = engine.run(data, max_epochs=1).metrics["ck"] + + assert isinstance(ck, float) + assert np_ck == pytest.approx(ck) def _test_distrib_binary_input(device): diff --git a/tests/ignite/contrib/metrics/test_roc_auc.py b/tests/ignite/contrib/metrics/test_roc_auc.py index c5357072090..dcc14aaba30 100644 --- a/tests/ignite/contrib/metrics/test_roc_auc.py +++ b/tests/ignite/contrib/metrics/test_roc_auc.py @@ -64,48 +64,43 @@ def test_check_shape(): roc_auc._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1))) -def test_binary_and_multilabel_inputs(): +@pytest.fixture(params=range(8)) +def test_data_binary_and_multilabel(request): + return [ + # Binary input data of shape (N,) or (N, 1) + (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 1), + (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 1), + # updated batches + (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16), + (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16), + # Binary input data of shape (N, L) + (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 1), + (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 1), + # updated batches + (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16), + (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_binary_and_multilabel_inputs(n_times, test_data_binary_and_multilabel): + y_pred, y, batch_size = test_data_binary_and_multilabel roc_auc = ROC_AUC() + roc_auc.reset() + if batch_size > 1: + n_iters = y.shape[0] // batch_size + 1 + for i in range(n_iters): + idx = i * batch_size + roc_auc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + else: + roc_auc.update((y_pred, y)) - def _test(y_pred, y, batch_size): - roc_auc.reset() - if batch_size > 1: - n_iters = y.shape[0] // batch_size + 1 - for i in range(n_iters): - idx = i * batch_size - roc_auc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) - else: - roc_auc.update((y_pred, y)) + np_y = y.numpy() + np_y_pred = y_pred.numpy() - np_y = y.numpy() - np_y_pred = y_pred.numpy() - - res = roc_auc.compute() - assert isinstance(res, float) - assert roc_auc_score(np_y, np_y_pred) == pytest.approx(res) - - def get_test_cases(): - test_cases = [ - # Binary input data of shape (N,) or (N, 1) - (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 1), - (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 1), - # updated batches - (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16), - (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16), - # Binary input data of shape (N, L) - (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 1), - (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 1), - # updated batches - (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16), - (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16), - ] - return test_cases - - for _ in range(5): - test_cases = get_test_cases() - # check multiple random inputs as random exact occurencies are rare - for y_pred, y, batch_size in test_cases: - _test(y_pred, y, batch_size) + res = roc_auc.compute() + assert isinstance(res, float) + assert roc_auc_score(np_y, np_y_pred) == pytest.approx(res) def test_check_compute_fn(): @@ -124,46 +119,43 @@ def test_check_compute_fn(): em.update(output) -def test_integration_binary_and_multilabel_inputs(): - def _test(y_pred, y, batch_size): - def update_fn(engine, batch): - idx = (engine.state.iteration - 1) * batch_size - y_true_batch = np_y[idx : idx + batch_size] - y_pred_batch = np_y_pred[idx : idx + batch_size] - return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) +@pytest.fixture(params=range(4)) +def test_data_integration_binary_and_multilabel(request): + return [ + # Binary input data of shape (N,) or (N, 1) + (torch.randint(0, 2, size=(100,)).long(), torch.randint(0, 2, size=(100,)).long(), 10), + (torch.randint(0, 2, size=(100, 1)).long(), torch.randint(0, 2, size=(100, 1)).long(), 10), + # Binary input data of shape (N, L) + (torch.randint(0, 2, size=(100, 3)).long(), torch.randint(0, 2, size=(100, 3)).long(), 10), + (torch.randint(0, 2, size=(100, 4)).long(), torch.randint(0, 2, size=(100, 4)).long(), 10), + ][request.param] - engine = Engine(update_fn) - roc_auc_metric = ROC_AUC() - roc_auc_metric.attach(engine, "roc_auc") +@pytest.mark.parametrize("n_times", range(5)) +def test_integration_binary_and_multilabel_inputs(n_times, test_data_integration_binary_and_multilabel): + y_pred, y, batch_size = test_data_integration_binary_and_multilabel - np_y = y.numpy() - np_y_pred = y_pred.numpy() + def update_fn(engine, batch): + idx = (engine.state.iteration - 1) * batch_size + y_true_batch = np_y[idx : idx + batch_size] + y_pred_batch = np_y_pred[idx : idx + batch_size] + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) - np_roc_auc = roc_auc_score(np_y, np_y_pred) + engine = Engine(update_fn) - data = list(range(y_pred.shape[0] // batch_size)) - roc_auc = engine.run(data, max_epochs=1).metrics["roc_auc"] + roc_auc_metric = ROC_AUC() + roc_auc_metric.attach(engine, "roc_auc") - assert isinstance(roc_auc, float) - assert np_roc_auc == pytest.approx(roc_auc) + np_y = y.numpy() + np_y_pred = y_pred.numpy() - def get_test_cases(): - test_cases = [ - # Binary input data of shape (N,) or (N, 1) - (torch.randint(0, 2, size=(100,)).long(), torch.randint(0, 2, size=(100,)).long(), 10), - (torch.randint(0, 2, size=(100, 1)).long(), torch.randint(0, 2, size=(100, 1)).long(), 10), - # Binary input data of shape (N, L) - (torch.randint(0, 2, size=(100, 3)).long(), torch.randint(0, 2, size=(100, 3)).long(), 10), - (torch.randint(0, 2, size=(100, 4)).long(), torch.randint(0, 2, size=(100, 4)).long(), 10), - ] - return test_cases + np_roc_auc = roc_auc_score(np_y, np_y_pred) - for _ in range(5): - # check multiple random inputs as random exact occurencies are rare - test_cases = get_test_cases() - for y_pred, y, batch_size in test_cases: - _test(y_pred, y, batch_size) + data = list(range(y_pred.shape[0] // batch_size)) + roc_auc = engine.run(data, max_epochs=1).metrics["roc_auc"] + + assert isinstance(roc_auc, float) + assert np_roc_auc == pytest.approx(roc_auc) def _test_distrib_binary_and_multilabel_inputs(device):