diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index 07727c4a..e3265692 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -28,8 +28,12 @@ jobs: - name: Install dependencies run: | pip install -r prereq.txt - pip install --upgrade pip - pip install .[all] - - name: Test with pytest - run: pytest -vvvs --durations=50 + - name: Test Core + run: | + pip install .[testing] + pytest -vvvs --durations=50 + - name: Test GOGGLE + run: | + pip install .[testing,goggle] + pytest -vvvs -k goggle --durations=50 diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index 35a8b2d0..2c907433 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -1,4 +1,4 @@ -name: Tests Python +name: Tests Fast Python on: push: @@ -55,8 +55,12 @@ jobs: - name: Install dependencies run: | pip install -r prereq.txt - pip install --upgrade pip - pip install .[all] - - name: Test with pytest - run: pytest -vvvs -m "not slow" --durations=50 + - name: Test Core + run: | + pip install .[testing] + pytest -vvvsx -m "not slow" --durations=50 + - name: Test GOGGLE + run: | + pip install .[testing,goggle] + pytest -vvvsx -m "not slow" -k goggle diff --git a/prereq.txt b/prereq.txt index dbdb5914..0d7eb1f0 100644 --- a/prereq.txt +++ b/prereq.txt @@ -1,3 +1,3 @@ numpy -torch +torch<2.0 tsai diff --git a/setup.cfg b/setup.cfg index de0ac067..96a67f09 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,8 +34,8 @@ python_requires = >=3.7 install_requires = scikit-learn>=1.0 nflows>=0.14 - pandas>=1.3 - torch>=1.10.0 + pandas>=1.3,<2.0 + torch>=1.10,<2.0 numpy>=1.20 lifelines>=0.27 opacus>=1.3 diff --git a/src/synthcity/plugins/core/plugin.py b/src/synthcity/plugins/core/plugin.py index 33529764..d4049bed 100644 --- a/src/synthcity/plugins/core/plugin.py +++ b/src/synthcity/plugins/core/plugin.py @@ -548,13 +548,16 @@ def __init__(self, plugins: list, expected_type: Type, categories: list) -> None self._available_plugins = {} for plugin in plugins: stem = Path(plugin).stem.split("plugin_")[-1] + cls = self._load_single_plugin_impl(plugin) + if cls is None: + continue self._available_plugins[stem] = plugin self._expected_type = expected_type self._categories = categories @validate_arguments - def _load_single_plugin(self, plugin_name: str) -> None: - """Helper for loading a single plugin""" + def _load_single_plugin_impl(self, plugin_name: str) -> Optional[Type]: + """Helper for loading a single plugin implementation""" plugin = Path(plugin_name) name = plugin.stem ptype = plugin.parent.name @@ -579,6 +582,10 @@ def _load_single_plugin(self, plugin_name: str) -> None: spec.loader.exec_module(mod) cls = mod.plugin + if cls is None: + log.critical(f"module disabled: {plugin_name}") + return None + failed = False break except BaseException as e: @@ -587,10 +594,19 @@ def _load_single_plugin(self, plugin_name: str) -> None: if failed: log.critical(f"module {name} load failed") - return + return None + + return cls + + @validate_arguments + def _load_single_plugin(self, plugin_name: str) -> bool: + """Helper for loading a single plugin""" + cls = self._load_single_plugin_impl(plugin_name) + if cls is None: + return False - log.debug(f"Loaded plugin {cls.type()} - {cls.name()}") self.add(cls.name(), cls) + return True def list(self) -> List[str]: """Get all the available plugins.""" diff --git a/src/synthcity/plugins/generic/plugin_goggle.py b/src/synthcity/plugins/generic/plugin_goggle.py index e829cd3e..3e982f06 100644 --- a/src/synthcity/plugins/generic/plugin_goggle.py +++ b/src/synthcity/plugins/generic/plugin_goggle.py @@ -23,11 +23,18 @@ FloatDistribution, IntegerDistribution, ) -from synthcity.plugins.core.models.tabular_goggle import TabularGoggle from synthcity.plugins.core.plugin import Plugin from synthcity.plugins.core.schema import Schema from synthcity.utils.constants import DEVICE +try: + # synthcity absolute + from synthcity.plugins.core.models.tabular_goggle import TabularGoggle + + module_disabled = False +except ImportError: + module_disabled = True + class GOGGLEPlugin(Plugin): """ @@ -248,4 +255,7 @@ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> pd.DataFra return self._safe_generate(self.model.generate, count, syn_schema) -plugin = GOGGLEPlugin +if module_disabled: + plugin = None +else: + plugin = GOGGLEPlugin diff --git a/tests/plugins/generic/generic_helpers.py b/tests/plugins/generic/generic_helpers.py index 1b9453e0..9abd23e9 100644 --- a/tests/plugins/generic/generic_helpers.py +++ b/tests/plugins/generic/generic_helpers.py @@ -1,5 +1,5 @@ # stdlib -from typing import Dict, List, Type +from typing import Dict, List, Optional, Type # third party import pandas as pd @@ -10,15 +10,20 @@ from synthcity.utils.serialization import load, save -def generate_fixtures(name: str, plugin: Type, plugin_args: Dict = {}) -> List: +def generate_fixtures( + name: str, plugin: Optional[Type], plugin_args: Dict = {} +) -> List: + if plugin is None: + return [] + def from_api() -> Plugin: return Plugins().get(name, **plugin_args) def from_module() -> Plugin: - return plugin(**plugin_args) + return plugin(**plugin_args) # type: ignore def from_serde() -> Plugin: - buff = save(plugin(**plugin_args)) + buff = save(plugin(**plugin_args)) # type: ignore return load(buff) return [from_api(), from_module(), from_serde()] diff --git a/tests/plugins/generic/test_goggle.py b/tests/plugins/generic/test_goggle.py index 0fd89099..9b194ae0 100644 --- a/tests/plugins/generic/test_goggle.py +++ b/tests/plugins/generic/test_goggle.py @@ -1,5 +1,6 @@ # third party import numpy as np +import pkg_resources import pytest from generic_helpers import generate_fixtures from sklearn.datasets import load_diabetes, load_iris @@ -12,35 +13,60 @@ from synthcity.plugins.generic.plugin_goggle import plugin from synthcity.utils.serialization import load, save +is_missing_goggle_deps = plugin is None + plugin_name = "goggle" plugin_args = { "n_iter": 10, "device": "cpu", } +if not is_missing_goggle_deps: + goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"} + installed = {pkg.key for pkg in pkg_resources.working_set} + is_missing_goggle_deps = len(goggle_dependencies - installed) > 0 + -@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin), +) def test_plugin_sanity(test_plugin: Plugin) -> None: assert test_plugin is not None -@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin), +) def test_plugin_name(test_plugin: Plugin) -> None: assert test_plugin.name() == plugin_name -@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin), +) def test_plugin_type(test_plugin: Plugin) -> None: assert test_plugin.type() == "generic" -@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin), +) def test_plugin_hyperparams(test_plugin: Plugin) -> None: assert len(test_plugin.hyperparameter_space()) == 9 +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + "test_plugin", + generate_fixtures(plugin_name, plugin, plugin_args), ) def test_plugin_fit(test_plugin: Plugin) -> None: Xraw, y = load_diabetes(return_X_y=True, as_frame=True) @@ -48,8 +74,10 @@ def test_plugin_fit(test_plugin: Plugin) -> None: test_plugin.fit(GenericDataLoader(Xraw)) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + "test_plugin", + generate_fixtures(plugin_name, plugin, plugin_args), ) @pytest.mark.parametrize("serialize", [True, False]) def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: @@ -78,8 +106,10 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: assert (X_gen1.numpy() != X_gen3.numpy()).any() +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.parametrize( - "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) + "test_plugin", + generate_fixtures(plugin_name, plugin, plugin_args), ) def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None: X, y = load_iris(as_frame=True, return_X_y=True) @@ -105,12 +135,15 @@ def test_plugin_generate_constraints_goggle(test_plugin: Plugin) -> None: assert list(X_gen.columns) == list(X.columns) +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") def test_sample_hyperparams() -> None: + assert plugin is not None for i in range(100): args = plugin.sample_hyperparameters() assert plugin(**args) is not None +@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed") @pytest.mark.slow @pytest.mark.parametrize( "compress_dataset,decoder_arch", @@ -130,6 +163,7 @@ def test_eval_performance_goggle(compress_dataset: bool, decoder_arch: str) -> N Xraw["target"] = y X = GenericDataLoader(Xraw) + assert plugin is not None for retry in range(2): test_plugin = plugin( n_iter=5000,