Skip to content

Commit

Permalink
Skipif goggle tests (#170)
Browse files Browse the repository at this point in the history
* Skip goggle tests if dependencies not installed and pin torch<2.0

* Pass pre-commit

* install goggle dependencies in workflows

* match pytorch version in prereq to version in setup.cfg

* more depends issues

* cleanup prereq

* reorganize workflows

* handle modules that cannot be loaded

* fix linting errors

* more linting

* rename PR tests

* cleanup workflows

* dependency mess

---------

Co-authored-by: Bogdan Cebere <[email protected]>
  • Loading branch information
robsdavis and bcebere authored Apr 4, 2023
1 parent e4a7a22 commit 8ae952a
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 29 deletions.
12 changes: 8 additions & 4 deletions .github/workflows/test_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ jobs:
- name: Install dependencies
run: |
pip install -r prereq.txt
pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: pytest -vvvs --durations=50
- name: Test Core
run: |
pip install .[testing]
pytest -vvvs --durations=50
- name: Test GOGGLE
run: |
pip install .[testing,goggle]
pytest -vvvs -k goggle --durations=50
14 changes: 9 additions & 5 deletions .github/workflows/test_pr.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Tests Python
name: Tests Fast Python

on:
push:
Expand Down Expand Up @@ -55,8 +55,12 @@ jobs:
- name: Install dependencies
run: |
pip install -r prereq.txt
pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: pytest -vvvs -m "not slow" --durations=50
- name: Test Core
run: |
pip install .[testing]
pytest -vvvsx -m "not slow" --durations=50
- name: Test GOGGLE
run: |
pip install .[testing,goggle]
pytest -vvvsx -m "not slow" -k goggle
2 changes: 1 addition & 1 deletion prereq.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
numpy
torch
torch<2.0
tsai
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ python_requires = >=3.7
install_requires =
scikit-learn>=1.0
nflows>=0.14
pandas>=1.3
torch>=1.10.0
pandas>=1.3,<2.0
torch>=1.10,<2.0
numpy>=1.20
lifelines>=0.27
opacus>=1.3
Expand Down
24 changes: 20 additions & 4 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,16 @@ def __init__(self, plugins: list, expected_type: Type, categories: list) -> None
self._available_plugins = {}
for plugin in plugins:
stem = Path(plugin).stem.split("plugin_")[-1]
cls = self._load_single_plugin_impl(plugin)
if cls is None:
continue
self._available_plugins[stem] = plugin
self._expected_type = expected_type
self._categories = categories

@validate_arguments
def _load_single_plugin(self, plugin_name: str) -> None:
"""Helper for loading a single plugin"""
def _load_single_plugin_impl(self, plugin_name: str) -> Optional[Type]:
"""Helper for loading a single plugin implementation"""
plugin = Path(plugin_name)
name = plugin.stem
ptype = plugin.parent.name
Expand All @@ -579,6 +582,10 @@ def _load_single_plugin(self, plugin_name: str) -> None:

spec.loader.exec_module(mod)
cls = mod.plugin
if cls is None:
log.critical(f"module disabled: {plugin_name}")
return None

failed = False
break
except BaseException as e:
Expand All @@ -587,10 +594,19 @@ def _load_single_plugin(self, plugin_name: str) -> None:

if failed:
log.critical(f"module {name} load failed")
return
return None

return cls

@validate_arguments
def _load_single_plugin(self, plugin_name: str) -> bool:
"""Helper for loading a single plugin"""
cls = self._load_single_plugin_impl(plugin_name)
if cls is None:
return False

log.debug(f"Loaded plugin {cls.type()} - {cls.name()}")
self.add(cls.name(), cls)
return True

def list(self) -> List[str]:
"""Get all the available plugins."""
Expand Down
14 changes: 12 additions & 2 deletions src/synthcity/plugins/generic/plugin_goggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,18 @@
FloatDistribution,
IntegerDistribution,
)
from synthcity.plugins.core.models.tabular_goggle import TabularGoggle
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
from synthcity.utils.constants import DEVICE

try:
# synthcity absolute
from synthcity.plugins.core.models.tabular_goggle import TabularGoggle

module_disabled = False
except ImportError:
module_disabled = True


class GOGGLEPlugin(Plugin):
"""
Expand Down Expand Up @@ -248,4 +255,7 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> pd.DataFra
return self._safe_generate(self.model.generate, count, syn_schema)


plugin = GOGGLEPlugin
if module_disabled:
plugin = None
else:
plugin = GOGGLEPlugin
13 changes: 9 additions & 4 deletions tests/plugins/generic/generic_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# stdlib
from typing import Dict, List, Type
from typing import Dict, List, Optional, Type

# third party
import pandas as pd
Expand All @@ -10,15 +10,20 @@
from synthcity.utils.serialization import load, save


def generate_fixtures(name: str, plugin: Type, plugin_args: Dict = {}) -> List:
def generate_fixtures(
name: str, plugin: Optional[Type], plugin_args: Dict = {}
) -> List:
if plugin is None:
return []

def from_api() -> Plugin:
return Plugins().get(name, **plugin_args)

def from_module() -> Plugin:
return plugin(**plugin_args)
return plugin(**plugin_args) # type: ignore

def from_serde() -> Plugin:
buff = save(plugin(**plugin_args))
buff = save(plugin(**plugin_args)) # type: ignore
return load(buff)

return [from_api(), from_module(), from_serde()]
Expand Down
48 changes: 41 additions & 7 deletions tests/plugins/generic/test_goggle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# third party
import numpy as np
import pkg_resources
import pytest
from generic_helpers import generate_fixtures
from sklearn.datasets import load_diabetes, load_iris
Expand All @@ -12,44 +13,71 @@
from synthcity.plugins.generic.plugin_goggle import plugin
from synthcity.utils.serialization import load, save

is_missing_goggle_deps = plugin is None

plugin_name = "goggle"
plugin_args = {
"n_iter": 10,
"device": "cpu",
}

if not is_missing_goggle_deps:
goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"}
installed = {pkg.key for pkg in pkg_resources.working_set}
is_missing_goggle_deps = len(goggle_dependencies - installed) > 0


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin",
generate_fixtures(plugin_name, plugin),
)
def test_plugin_sanity(test_plugin: Plugin) -> None:
assert test_plugin is not None


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin",
generate_fixtures(plugin_name, plugin),
)
def test_plugin_name(test_plugin: Plugin) -> None:
assert test_plugin.name() == plugin_name


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin",
generate_fixtures(plugin_name, plugin),
)
def test_plugin_type(test_plugin: Plugin) -> None:
assert test_plugin.type() == "generic"


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin",
generate_fixtures(plugin_name, plugin),
)
def test_plugin_hyperparams(test_plugin: Plugin) -> None:
assert len(test_plugin.hyperparameter_space()) == 9


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
"test_plugin",
generate_fixtures(plugin_name, plugin, plugin_args),
)
def test_plugin_fit(test_plugin: Plugin) -> None:
Xraw, y = load_diabetes(return_X_y=True, as_frame=True)
Xraw["target"] = y
test_plugin.fit(GenericDataLoader(Xraw))


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
"test_plugin",
generate_fixtures(plugin_name, plugin, plugin_args),
)
@pytest.mark.parametrize("serialize", [True, False])
def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
Expand Down Expand Up @@ -78,8 +106,10 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
assert (X_gen1.numpy() != X_gen3.numpy()).any()


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
"test_plugin",
generate_fixtures(plugin_name, plugin, plugin_args),
)
def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None:
X, y = load_iris(as_frame=True, return_X_y=True)
Expand All @@ -105,12 +135,15 @@ def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None:
assert list(X_gen.columns) == list(X.columns)


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
def test_sample_hyperparams() -> None:
assert plugin is not None
for i in range(100):
args = plugin.sample_hyperparameters()
assert plugin(**args) is not None


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
@pytest.mark.slow
@pytest.mark.parametrize(
"compress_dataset,decoder_arch",
Expand All @@ -130,6 +163,7 @@ def test_eval_performance_goggle(compress_dataset: bool, decoder_arch: str) -> N
Xraw["target"] = y
X = GenericDataLoader(Xraw)

assert plugin is not None
for retry in range(2):
test_plugin = plugin(
n_iter=5000,
Expand Down

0 comments on commit 8ae952a

Please sign in to comment.