From 263dc8ab5754e64b8d45769a4745f11b0358085a Mon Sep 17 00:00:00 2001 From: Rob Davis Date: Tue, 3 Sep 2024 17:20:16 +0100 Subject: [PATCH] skip mnist download if not linux for speed --- .github/workflows/test_pr.yml | 6 ------ tests/metrics/test_detection.py | 7 +++---- tests/metrics/test_performance.py | 6 ++---- tests/metrics/test_privacy.py | 7 +++---- tests/metrics/test_sanity.py | 7 +++---- tests/metrics/test_statistical.py | 7 +++---- tests/plugins/core/test_dataloader.py | 13 +++++-------- tests/plugins/images/test_image_adsgan.py | 16 +++++++++++----- tests/plugins/images/test_image_cgan.py | 16 +++++++++++----- 9 files changed, 41 insertions(+), 44 deletions(-) diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 16dc9651..6bbcbfec 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -55,12 +55,6 @@ jobs: run: | python -m pip install -U pip pip install -r prereq.txt - - name: Pre-test setup - run: | - curl -o ./train-images-idx3-ubyte.gz https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz - curl -o ./train-labels-idx1-ubyte.gz https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz - curl -o ./t10k-images-idx3-ubyte.gz https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz - curl -o ./t10k-labels-idx1-ubyte.gz https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz - name: Test Core run: | pip install .[testing] diff --git a/tests/metrics/test_detection.py b/tests/metrics/test_detection.py index 39919b11..bfa04629 100644 --- a/tests/metrics/test_detection.py +++ b/tests/metrics/test_detection.py @@ -1,4 +1,5 @@ # stdlib +import sys from typing import Type # third party @@ -154,13 +155,11 @@ def test_detect_synth_timeseries(test_plugin: Plugin, evaluator_t: Type) -> None assert evaluator.direction() == "minimize" +@pytest.mark.skipif(sys.platform == "linux", reason="Linux only for faster results") @pytest.mark.slow_1 @pytest.mark.slow def test_image_support_detection() -> None: - try: - dataset = datasets.MNIST(".", download=False) - except RuntimeError: - dataset = datasets.MNIST(".", download=True) + dataset = datasets.MNIST(".", download=True) X1 = ImageDataLoader(dataset).sample(100) X2 = ImageDataLoader(dataset).sample(100) diff --git a/tests/metrics/test_performance.py b/tests/metrics/test_performance.py index 1f0e8ab6..c8adf9a7 100644 --- a/tests/metrics/test_performance.py +++ b/tests/metrics/test_performance.py @@ -477,13 +477,11 @@ def test_evaluate_performance_time_series_survival( assert def_score == good_score["syn_id.c_index"] - good_score["syn_id.brier_score"] +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") @pytest.mark.slow_1 @pytest.mark.slow def test_image_support_perf() -> None: - try: - dataset = datasets.MNIST(".", download=False) - except RuntimeError: - dataset = datasets.MNIST(".", download=True) + dataset = datasets.MNIST(".", download=True) X1 = ImageDataLoader(dataset).sample(100) X2 = ImageDataLoader(dataset).sample(100) diff --git a/tests/metrics/test_privacy.py b/tests/metrics/test_privacy.py index 3a15679e..356ae819 100644 --- a/tests/metrics/test_privacy.py +++ b/tests/metrics/test_privacy.py @@ -1,4 +1,5 @@ # stdlib +import sys from typing import Type # third party @@ -80,11 +81,9 @@ def test_evaluator(evaluator_t: Type, test_plugin: Plugin) -> None: assert isinstance(def_score, (float, int)) +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") def test_image_support() -> None: - try: - dataset = datasets.MNIST(".", download=False) - except RuntimeError: - dataset = datasets.MNIST(".", download=True) + dataset = datasets.MNIST(".", download=True) X1 = ImageDataLoader(dataset).sample(100) X2 = ImageDataLoader(dataset).sample(100) diff --git a/tests/metrics/test_sanity.py b/tests/metrics/test_sanity.py index 27c8703f..553892e9 100644 --- a/tests/metrics/test_sanity.py +++ b/tests/metrics/test_sanity.py @@ -1,4 +1,5 @@ # stdlib +import sys from typing import Callable, Tuple # third party @@ -194,11 +195,9 @@ def test_evaluate_distant_values(test_plugin: Plugin) -> None: assert isinstance(def_score, float) +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") def test_image_support() -> None: - try: - dataset = datasets.MNIST(".", download=False) - except RuntimeError: - dataset = datasets.MNIST(".", download=True) + dataset = datasets.MNIST(".", download=True) X1 = ImageDataLoader(dataset).sample(100) X2 = ImageDataLoader(dataset).sample(100) diff --git a/tests/metrics/test_statistical.py b/tests/metrics/test_statistical.py index 4f138872..11b31ce4 100644 --- a/tests/metrics/test_statistical.py +++ b/tests/metrics/test_statistical.py @@ -1,4 +1,5 @@ # stdlib +import sys from typing import Any, Tuple, Type # third party @@ -283,11 +284,9 @@ def test_evaluate_survival_km_distance(test_plugin: Plugin) -> None: assert SurvivalKMDistance.direction() == "minimize" +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") def test_image_support() -> None: - try: - dataset = datasets.MNIST(".", download=False) - except RuntimeError: - dataset = datasets.MNIST(".", download=True) + dataset = datasets.MNIST(".", download=True) X1 = ImageDataLoader(dataset).sample(100) X2 = ImageDataLoader(dataset).sample(100) diff --git a/tests/plugins/core/test_dataloader.py b/tests/plugins/core/test_dataloader.py index d8db3a6d..c01481c0 100644 --- a/tests/plugins/core/test_dataloader.py +++ b/tests/plugins/core/test_dataloader.py @@ -1,4 +1,5 @@ # stdlib +import sys from datetime import datetime from typing import Any @@ -635,13 +636,11 @@ def test_time_series_survival_pack_unpack_padding(as_numpy: bool) -> None: assert len(unp_observation_times[idx]) == max_window_len +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") @pytest.mark.parametrize("height", [55, 64]) @pytest.mark.parametrize("width", [32, 22]) def test_image_dataloader_sanity(height: int, width: int) -> None: - try: - dataset = datasets.MNIST(".", download=False) - except RuntimeError: - dataset = datasets.MNIST(".", download=True) + dataset = datasets.MNIST(".", download=True) loader = ImageDataLoader( data=dataset, @@ -680,11 +679,9 @@ def test_image_dataloader_sanity(height: int, width: int) -> None: assert loader.unpack().labels().shape == (len(loader),) +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") def test_image_dataloader_create_from_info() -> None: - try: - dataset = datasets.MNIST(".", download=False) - except RuntimeError: - dataset = datasets.MNIST(".", download=True) + dataset = datasets.MNIST(".", download=True) loader = ImageDataLoader( data=dataset, diff --git a/tests/plugins/images/test_image_adsgan.py b/tests/plugins/images/test_image_adsgan.py index fead8127..a1b6414f 100644 --- a/tests/plugins/images/test_image_adsgan.py +++ b/tests/plugins/images/test_image_adsgan.py @@ -1,3 +1,6 @@ +# stdlib +import sys + # third party import numpy as np import pytest @@ -11,11 +14,6 @@ plugin_name = "image_adsgan" -try: - dataset = datasets.MNIST(".", download=False) -except RuntimeError: - dataset = datasets.MNIST(".", download=True) - @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) def test_plugin_sanity(test_plugin: Plugin) -> None: @@ -37,7 +35,9 @@ def test_plugin_hyperparams(test_plugin: Plugin) -> None: assert len(test_plugin.hyperparameter_space()) == 6 +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") def test_plugin_fit() -> None: + dataset = datasets.MNIST(".", download=True) test_plugin = plugin(n_iter=5) X = ImageDataLoader(dataset).sample(100) @@ -45,7 +45,9 @@ def test_plugin_fit() -> None: test_plugin.fit(X) +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") def test_plugin_generate() -> None: + dataset = datasets.MNIST(".", download=True) test_plugin = plugin(n_iter=10, n_units_latent=13) X = ImageDataLoader(dataset).sample(100) @@ -60,9 +62,11 @@ def test_plugin_generate() -> None: assert len(X_gen) == 50 +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") @pytest.mark.slow_2 @pytest.mark.slow def test_plugin_generate_with_conditional() -> None: + dataset = datasets.MNIST(".", download=True) test_plugin = plugin(n_iter=10, n_units_latent=13) X = ImageDataLoader(dataset).sample(100) @@ -75,9 +79,11 @@ def test_plugin_generate_with_conditional() -> None: assert len(X_gen) == 50 +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") @pytest.mark.slow_2 @pytest.mark.slow def test_plugin_generate_with_stop_conditional() -> None: + dataset = datasets.MNIST(".", download=True) test_plugin = plugin(n_iter=10, n_units_latent=13, n_iter_print=2) X = ImageDataLoader(dataset).sample(100) diff --git a/tests/plugins/images/test_image_cgan.py b/tests/plugins/images/test_image_cgan.py index 787d99dd..fc30f84f 100644 --- a/tests/plugins/images/test_image_cgan.py +++ b/tests/plugins/images/test_image_cgan.py @@ -1,3 +1,6 @@ +# stdlib +import sys + # third party import numpy as np import pytest @@ -11,11 +14,6 @@ plugin_name = "image_cgan" -try: - dataset = datasets.MNIST(".", download=False) -except RuntimeError: - dataset = datasets.MNIST(".", download=True) - @pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) def test_plugin_sanity(test_plugin: Plugin) -> None: @@ -37,10 +35,12 @@ def test_plugin_hyperparams(test_plugin: Plugin) -> None: assert len(test_plugin.hyperparameter_space()) == 6 +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") @pytest.mark.parametrize("height", [32, 64, 128]) @pytest.mark.slow_2 @pytest.mark.slow def test_plugin_fit(height: int) -> None: + dataset = datasets.MNIST(".", download=True) test_plugin = plugin(n_iter=5) X = ImageDataLoader(dataset, height=height).sample(100) @@ -48,7 +48,9 @@ def test_plugin_fit(height: int) -> None: test_plugin.fit(X) +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") def test_plugin_generate() -> None: + dataset = datasets.MNIST(".", download=True) test_plugin = plugin(n_iter=10, n_units_latent=13) X = ImageDataLoader(dataset).sample(100) @@ -63,7 +65,9 @@ def test_plugin_generate() -> None: assert len(X_gen) == 50 +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") def test_plugin_generate_with_conditional() -> None: + dataset = datasets.MNIST(".", download=True) test_plugin = plugin(n_iter=10, n_units_latent=13) X = ImageDataLoader(dataset).sample(100) @@ -76,9 +80,11 @@ def test_plugin_generate_with_conditional() -> None: assert len(X_gen) == 50 +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results") @pytest.mark.slow_2 @pytest.mark.slow def test_plugin_generate_with_stop_conditional() -> None: + dataset = datasets.MNIST(".", download=True) test_plugin = plugin(n_iter=10, n_units_latent=13, n_iter_print=2) X = ImageDataLoader(dataset).sample(100)