Skip to content

Commit

Permalink
skip mnist download if not linux for speed
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdavis committed Sep 3, 2024
1 parent 2e65e01 commit 263dc8a
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 44 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/test_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,6 @@ jobs:
run: |
python -m pip install -U pip
pip install -r prereq.txt
- name: Pre-test setup
run: |
curl -o ./train-images-idx3-ubyte.gz https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
curl -o ./train-labels-idx1-ubyte.gz https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
curl -o ./t10k-images-idx3-ubyte.gz https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
curl -o ./t10k-labels-idx1-ubyte.gz https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
- name: Test Core
run: |
pip install .[testing]
Expand Down
7 changes: 3 additions & 4 deletions tests/metrics/test_detection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
import sys
from typing import Type

# third party
Expand Down Expand Up @@ -154,13 +155,11 @@ def test_detect_synth_timeseries(test_plugin: Plugin, evaluator_t: Type) -> None
assert evaluator.direction() == "minimize"


@pytest.mark.skipif(sys.platform == "linux", reason="Linux only for faster results")
@pytest.mark.slow_1
@pytest.mark.slow
def test_image_support_detection() -> None:
try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)
dataset = datasets.MNIST(".", download=True)

X1 = ImageDataLoader(dataset).sample(100)
X2 = ImageDataLoader(dataset).sample(100)
Expand Down
6 changes: 2 additions & 4 deletions tests/metrics/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,13 +477,11 @@ def test_evaluate_performance_time_series_survival(
assert def_score == good_score["syn_id.c_index"] - good_score["syn_id.brier_score"]


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.slow_1
@pytest.mark.slow
def test_image_support_perf() -> None:
try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)
dataset = datasets.MNIST(".", download=True)

X1 = ImageDataLoader(dataset).sample(100)
X2 = ImageDataLoader(dataset).sample(100)
Expand Down
7 changes: 3 additions & 4 deletions tests/metrics/test_privacy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
import sys
from typing import Type

# third party
Expand Down Expand Up @@ -80,11 +81,9 @@ def test_evaluator(evaluator_t: Type, test_plugin: Plugin) -> None:
assert isinstance(def_score, (float, int))


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
def test_image_support() -> None:
try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)
dataset = datasets.MNIST(".", download=True)

X1 = ImageDataLoader(dataset).sample(100)
X2 = ImageDataLoader(dataset).sample(100)
Expand Down
7 changes: 3 additions & 4 deletions tests/metrics/test_sanity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
import sys
from typing import Callable, Tuple

# third party
Expand Down Expand Up @@ -194,11 +195,9 @@ def test_evaluate_distant_values(test_plugin: Plugin) -> None:
assert isinstance(def_score, float)


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
def test_image_support() -> None:
try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)
dataset = datasets.MNIST(".", download=True)

X1 = ImageDataLoader(dataset).sample(100)
X2 = ImageDataLoader(dataset).sample(100)
Expand Down
7 changes: 3 additions & 4 deletions tests/metrics/test_statistical.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
import sys
from typing import Any, Tuple, Type

# third party
Expand Down Expand Up @@ -283,11 +284,9 @@ def test_evaluate_survival_km_distance(test_plugin: Plugin) -> None:
assert SurvivalKMDistance.direction() == "minimize"


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
def test_image_support() -> None:
try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)
dataset = datasets.MNIST(".", download=True)

X1 = ImageDataLoader(dataset).sample(100)
X2 = ImageDataLoader(dataset).sample(100)
Expand Down
13 changes: 5 additions & 8 deletions tests/plugins/core/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
import sys
from datetime import datetime
from typing import Any

Expand Down Expand Up @@ -635,13 +636,11 @@ def test_time_series_survival_pack_unpack_padding(as_numpy: bool) -> None:
assert len(unp_observation_times[idx]) == max_window_len


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.parametrize("height", [55, 64])
@pytest.mark.parametrize("width", [32, 22])
def test_image_dataloader_sanity(height: int, width: int) -> None:
try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)
dataset = datasets.MNIST(".", download=True)

loader = ImageDataLoader(
data=dataset,
Expand Down Expand Up @@ -680,11 +679,9 @@ def test_image_dataloader_sanity(height: int, width: int) -> None:
assert loader.unpack().labels().shape == (len(loader),)


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
def test_image_dataloader_create_from_info() -> None:
try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)
dataset = datasets.MNIST(".", download=True)

loader = ImageDataLoader(
data=dataset,
Expand Down
16 changes: 11 additions & 5 deletions tests/plugins/images/test_image_adsgan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# stdlib
import sys

# third party
import numpy as np
import pytest
Expand All @@ -11,11 +14,6 @@

plugin_name = "image_adsgan"

try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_sanity(test_plugin: Plugin) -> None:
Expand All @@ -37,15 +35,19 @@ def test_plugin_hyperparams(test_plugin: Plugin) -> None:
assert len(test_plugin.hyperparameter_space()) == 6


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
def test_plugin_fit() -> None:
dataset = datasets.MNIST(".", download=True)
test_plugin = plugin(n_iter=5)

X = ImageDataLoader(dataset).sample(100)

test_plugin.fit(X)


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
def test_plugin_generate() -> None:
dataset = datasets.MNIST(".", download=True)
test_plugin = plugin(n_iter=10, n_units_latent=13)

X = ImageDataLoader(dataset).sample(100)
Expand All @@ -60,9 +62,11 @@ def test_plugin_generate() -> None:
assert len(X_gen) == 50


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.slow_2
@pytest.mark.slow
def test_plugin_generate_with_conditional() -> None:
dataset = datasets.MNIST(".", download=True)
test_plugin = plugin(n_iter=10, n_units_latent=13)

X = ImageDataLoader(dataset).sample(100)
Expand All @@ -75,9 +79,11 @@ def test_plugin_generate_with_conditional() -> None:
assert len(X_gen) == 50


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.slow_2
@pytest.mark.slow
def test_plugin_generate_with_stop_conditional() -> None:
dataset = datasets.MNIST(".", download=True)
test_plugin = plugin(n_iter=10, n_units_latent=13, n_iter_print=2)

X = ImageDataLoader(dataset).sample(100)
Expand Down
16 changes: 11 additions & 5 deletions tests/plugins/images/test_image_cgan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# stdlib
import sys

# third party
import numpy as np
import pytest
Expand All @@ -11,11 +14,6 @@

plugin_name = "image_cgan"

try:
dataset = datasets.MNIST(".", download=False)
except RuntimeError:
dataset = datasets.MNIST(".", download=True)


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_sanity(test_plugin: Plugin) -> None:
Expand All @@ -37,18 +35,22 @@ def test_plugin_hyperparams(test_plugin: Plugin) -> None:
assert len(test_plugin.hyperparameter_space()) == 6


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.parametrize("height", [32, 64, 128])
@pytest.mark.slow_2
@pytest.mark.slow
def test_plugin_fit(height: int) -> None:
dataset = datasets.MNIST(".", download=True)
test_plugin = plugin(n_iter=5)

X = ImageDataLoader(dataset, height=height).sample(100)

test_plugin.fit(X)


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
def test_plugin_generate() -> None:
dataset = datasets.MNIST(".", download=True)
test_plugin = plugin(n_iter=10, n_units_latent=13)

X = ImageDataLoader(dataset).sample(100)
Expand All @@ -63,7 +65,9 @@ def test_plugin_generate() -> None:
assert len(X_gen) == 50


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
def test_plugin_generate_with_conditional() -> None:
dataset = datasets.MNIST(".", download=True)
test_plugin = plugin(n_iter=10, n_units_latent=13)

X = ImageDataLoader(dataset).sample(100)
Expand All @@ -76,9 +80,11 @@ def test_plugin_generate_with_conditional() -> None:
assert len(X_gen) == 50


@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
@pytest.mark.slow_2
@pytest.mark.slow
def test_plugin_generate_with_stop_conditional() -> None:
dataset = datasets.MNIST(".", download=True)
test_plugin = plugin(n_iter=10, n_units_latent=13, n_iter_print=2)

X = ImageDataLoader(dataset).sample(100)
Expand Down

0 comments on commit 263dc8a

Please sign in to comment.