Various bugfixes (#233)
* debugging

* debugging

* bugfixes

* skip great tests on python<3.9

* sorting system

* reload plugin context on pluginloader init

* debugging

* tidy up

* fixing plugin registry and goggle dependency check

* Great test on cpu only

* Automated commit by Keepalive Workflow to keep the repository active

* remove keepalive

* fix permissions

* fix release script permissions

* uncomment registry test

* clean up test_plugin_add

* revert to concat in ts unpack

* reload before test_plugin_add

* Skip great fit for github actions

---------

Co-authored-by: gkr-bot <[email protected]>
robsdavis and gkr-bot authored Jan 5, 2024
1 parent ab47bcf commit 73cfd8c
Show file tree
Hide file tree
Showing 21 changed files with 83 additions and 44 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test_pr.yml
@@ -42,7 +42,6 @@ jobs:
- uses: actions/checkout@v2
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
1 change: 1 addition & 0 deletions setup.cfg
@@ -32,6 +32,7 @@ package_dir =
python_requires = >=3.8

install_requires =
importlib-metadata
pandas>=1.4,<2
torch>=1.10.0,<2.0
scikit-learn>=1.0
2 changes: 0 additions & 2 deletions src/synthcity/metrics/eval_performance.py
@@ -425,12 +425,10 @@ def ts_eval_cbk(
temporal_train_data = id_temporal_gt[train_idx]
observation_times_train_data = id_observation_times_gt[train_idx]
outcome_train_data = id_outcome_gt[train_idx]

static_test_data = id_static_gt[test_idx]
temporal_test_data = id_temporal_gt[test_idx]
observation_times_test_data = id_observation_times_gt[test_idx]
outcome_test_data = id_outcome_gt[test_idx]

real_score = ts_eval_cbk(
static_train_data,
temporal_train_data,
1 change: 1 addition & 0 deletions src/synthcity/plugins/__init__.py
@@ -15,6 +15,7 @@
"time_series",
"domain_adaptation",
"images",
"debug",
]
plugins = {}

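The category list above gains a new `debug` entry. A minimal sketch of how such a list can drive plugin discovery, assuming each category maps to a `synthcity.plugins.<category>` submodule whose import registers its plugins as a side effect (the real loader may differ):

```python
# Assumed discovery loop: each category names a submodule under synthcity.plugins.
import importlib

categories = [
    "generic",
    "privacy",
    "time_series",
    "domain_adaptation",
    "images",
    "debug",
]

for category in categories:
    try:
        # Importing the submodule is assumed to register its plugins as a side effect.
        importlib.import_module(f"synthcity.plugins.{category}")
    except ImportError:
        # Hypothetical fallback; the real package may require every category to import cleanly.
        continue
```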
4 changes: 3 additions & 1 deletion src/synthcity/plugins/core/dataloader.py
@@ -932,7 +932,9 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
longest_observation_seq = max([len(seq) for seq in temporal_data])
return (
np.asarray(static_data),
np.asarray(temporal_data),
np.asarray(
pd.concat(temporal_data)
), # TODO: check this works with time series benchmarks
# masked array to handle variable length sequences
ma.vstack(
[
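The hunk above swaps the plain `np.asarray(temporal_data)` for `np.asarray(pd.concat(temporal_data))` and keeps a masked array for the variable-length sequences. A minimal sketch of the two representations on toy data (illustrative frames, not the library's DataLoader):

```python
import numpy as np
import numpy.ma as ma
import pandas as pd

temporal_data = [
    pd.DataFrame({"hr": [70.0, 72.0]}),        # sequence of length 2
    pd.DataFrame({"hr": [65.0, 66.0, 68.0]}),  # sequence of length 3
]

# Flat view: every observation stacked into one 2-D array.
flat = np.asarray(pd.concat(temporal_data))    # shape (5, 1)

# Padded view: each sequence padded to the longest length, with the padding masked out.
longest_observation_seq = max(len(seq) for seq in temporal_data)
padded_seqs = []
for seq in temporal_data:
    buf = np.full((longest_observation_seq, seq.shape[1]), np.nan)
    buf[: len(seq)] = np.asarray(seq)
    padded_seqs.append(ma.masked_invalid(buf))
masked = ma.vstack(padded_seqs)                # shape (6, 1), padded rows masked

print(flat.shape, masked.shape)
```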
8 changes: 3 additions & 5 deletions src/synthcity/plugins/core/models/aim.py
@@ -126,12 +126,10 @@ def __init__(self, epsilon: float, delta: float):
Base class for a mechanism.
:param epsilon: privacy parameter
:param delta: privacy parameter
:param prng: pseudo random number generator
"""
self.epsilon = epsilon
self.delta = delta
self.rho = 0 if delta == 0 else cdp_rho(epsilon, delta)
self.prng = np.random

def run(self, dataset: Dataset, workload: List[Tuple]) -> Any:
pass
@@ -204,7 +202,7 @@ def exponential_mechanism(
else:
p = softmax(0.5 * epsilon / sensitivity * q + base_measure)

return keys[self.prng.choice(p.size, p=p)]
return keys[np.random.choice(p.size, p=p)]

# def gaussian_noise_scale(self, l2_sensitivity, epsilon, delta):
# """Return the Gaussian noise necessary to attain (epsilon, delta)-DP"""
@@ -223,11 +221,11 @@ def exponential_mechanism(

def gaussian_noise(self, sigma: float, size: Union[int, Tuple]) -> np.ndarray:
"""Generate iid Gaussian noise of a given scale and size"""
return self.prng.normal(0, sigma, size)
return np.random.normal(0, sigma, size)

# def laplace_noise(self, b, size):
# """Generate iid Laplace noise of a given scale and size"""
# return self.prng.laplace(0, b, size)
# return np.random.laplace(0, b, size)

# def best_noise_distribution(self, l1_sensitivity, l2_sensitivity, epsilon, delta):
# """Adaptively determine if Laplace or Gaussian noise will be better, and
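The mechanism above drops its `self.prng` attribute and draws directly from the global `np.random` generator. A self-contained sketch of the exponential-mechanism selection step it performs, with toy quality scores (the real class derives them from marginal errors and a privacy budget):

```python
import numpy as np
from scipy.special import softmax

def exponential_mechanism(keys: np.ndarray, q: np.ndarray, epsilon: float, sensitivity: float = 1.0):
    # Higher-quality candidates get exponentially larger selection probability.
    p = softmax(0.5 * epsilon / sensitivity * np.asarray(q, dtype=float))
    return keys[np.random.choice(p.size, p=p)]

candidates = np.array(["marginal_a", "marginal_b", "marginal_c"])
quality = np.array([0.1, 2.0, 0.5])
print(exponential_mechanism(candidates, quality, epsilon=1.0))
```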
8 changes: 4 additions & 4 deletions src/synthcity/plugins/core/models/mbi/clique_vector.py
@@ -36,21 +36,21 @@ def uniform(domain, cliques):
return CliqueVector({cl: Factor.uniform(domain.project(cl)) for cl in cliques})

@staticmethod
def random(domain, cliques, prng=np.random):
def random(domain, cliques):
# synthcity relative
from .factor import Factor

return CliqueVector(
{cl: Factor.random(domain.project(cl), prng) for cl in cliques}
{cl: Factor.random(domain.project(cl), np.random) for cl in cliques}
)

@staticmethod
def normal(domain, cliques, prng=np.random):
def normal(domain, cliques):
# synthcity relative
from .factor import Factor

return CliqueVector(
{cl: Factor.normal(domain.project(cl), prng) for cl in cliques}
{cl: Factor.normal(domain.project(cl), np.random) for cl in cliques}
)

@staticmethod
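With the `prng` argument removed from `random` and `normal` above, reproducibility now comes from seeding NumPy's global state rather than injecting a generator. A tiny illustration of that consequence (generic draws, not a real `CliqueVector`):

```python
import numpy as np

np.random.seed(0)                      # global seed replaces a per-call prng=... argument
draws_a = np.random.normal(size=3)

np.random.seed(0)
draws_b = np.random.normal(size=3)

assert np.allclose(draws_a, draws_b)   # same seed, same factors
```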
6 changes: 3 additions & 3 deletions src/synthcity/plugins/core/models/tabular_aim.py
@@ -1,5 +1,6 @@
# stdlib
import itertools
from abc import ABCMeta
from typing import Any, Optional, Union

# third party
@@ -17,7 +18,7 @@
from .mbi.domain import Domain


class TabularAIM:
class TabularAIM(metaclass=ABCMeta):
"""
.. inheritance-diagram:: synthcity.plugins.core.models.tabular_aim.TabularAIM
:parts: 1
@@ -68,7 +69,6 @@ def __init__(
self.degree = degree
self.num_marginals = num_marginals
self.max_cells = max_cells
self.prng = np.random

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def fit(
@@ -101,7 +101,7 @@ def fit(
if self.num_marginals is not None:
workload = [
workload[i]
for i in self.prng.choice(
for i in np.random.choice(
len(workload), self.num_marginals, replace=False
)
]
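This wrapper, like the ARF, GOGGLE and GReaT wrappers below, now declares `metaclass=ABCMeta`. A small sketch of what that metaclass provides, using illustrative classes rather than the synthcity ones: on its own it does not block instantiation, but it lets the hierarchy mark methods as abstract and refuse to instantiate until they are overridden.

```python
from abc import ABCMeta, abstractmethod

class BaseTabularModel(metaclass=ABCMeta):
    @abstractmethod
    def fit(self, X):
        ...

class ToyModel(BaseTabularModel):
    def fit(self, X):
        return self

try:
    BaseTabularModel()                    # abstract method fit -> TypeError
except TypeError as err:
    print(err)

print(ToyModel().fit([[1.0, 2.0]]))       # concrete subclass instantiates fine
```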
3 changes: 2 additions & 1 deletion src/synthcity/plugins/core/models/tabular_arf.py
@@ -1,4 +1,5 @@
# stdlib
from abc import ABCMeta
from typing import Any, Union

# third party
@@ -21,7 +22,7 @@
from synthcity.utils.constants import DEVICE


class TabularARF:
class TabularARF(metaclass=ABCMeta):
def __init__(
self,
# ARF parameters
3 changes: 2 additions & 1 deletion src/synthcity/plugins/core/models/tabular_goggle.py
@@ -1,4 +1,5 @@
# stdlib
from abc import ABCMeta
from typing import Any, Optional, Union

# third party
@@ -15,7 +16,7 @@
from .tabular_encoder import TabularEncoder


class TabularGoggle:
class TabularGoggle(metaclass=ABCMeta):
def __init__(
self,
X: pd.DataFrame,
3 changes: 2 additions & 1 deletion src/synthcity/plugins/core/models/tabular_great.py
@@ -1,4 +1,5 @@
# stdlib
from abc import ABCMeta
from typing import Any, Dict, Optional, Union

# third party
@@ -20,7 +21,7 @@
from synthcity.utils.constants import DEVICE


class TabularGReaT:
class TabularGReaT(metaclass=ABCMeta):
"""
.. inheritance-diagram:: synthcity.plugins.core.models.tabular_great.TabularGReaT
:parts: 1
1 change: 1 addition & 0 deletions src/synthcity/plugins/core/models/ts_model.py
@@ -255,6 +255,7 @@ def forward(
raise ValueError("NaNs detected in the temporal horizons")

if self.use_horizon_condition:
# TODO: ADD error handling for len(temporal_data.shape) != 3 or len(observation_times.shape) != 2
temporal_data_merged = torch.cat(
[temporal_data, observation_times.unsqueeze(2)], dim=2
)
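The new TODO above asks for a guard ahead of the horizon concatenation. A sketch of the shape contract it refers to, assuming `temporal_data` is `(batch, seq_len, n_features)` and `observation_times` is `(batch, seq_len)`; the check itself is a suggestion, not code from the repository:

```python
import torch

temporal_data = torch.randn(4, 10, 3)    # batch=4, seq_len=10, features=3
observation_times = torch.rand(4, 10)    # one timestamp per step

if temporal_data.dim() != 3 or observation_times.dim() != 2:
    raise ValueError(
        "expected (batch, seq, feat) and (batch, seq), got "
        f"{tuple(temporal_data.shape)} and {tuple(observation_times.shape)}"
    )

temporal_data_merged = torch.cat(
    [temporal_data, observation_times.unsqueeze(2)], dim=2
)
print(temporal_data_merged.shape)        # torch.Size([4, 10, 4])
```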
10 changes: 9 additions & 1 deletion src/synthcity/plugins/core/plugin.py
@@ -560,6 +560,9 @@ class PluginLoader:

@validate_arguments
def __init__(self, plugins: list, expected_type: Type, categories: list) -> None:
# self.reload()
global PLUGIN_CATEGORY_REGISTRY
PLUGIN_CATEGORY_REGISTRY = {cat: [] for cat in categories}
self._refresh()
self._available_plugins = {}
for plugin in plugins:
@@ -662,6 +665,8 @@ def _add_category(self, category: str, name: str) -> "PluginLoader":

def add(self, name: str, cls: Type) -> "PluginLoader":
"""Add a new plugin"""
global PLUGIN_REGISTRY
global PLUGIN_CATEGORY_REGISTRY
self._refresh()
if name in self._plugins:
log.info(f"Plugin {name} already exists. Overwriting")
@@ -742,5 +747,8 @@ def __getitem__(self, key: str) -> Any:
return self.get(key)

def reload(self) -> "PluginLoader":
self._plugins = {}
global PLUGIN_CATEGORY_REGISTRY
global PLUGIN_REGISTRY
PLUGIN_CATEGORY_REGISTRY = dict()
PLUGIN_REGISTRY = dict()
return self
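The hunks above rebuild `PLUGIN_CATEGORY_REGISTRY` when a `PluginLoader` is constructed and empty both module-level registries in `reload()`. A simplified stand-in (not the real loader, which also scans plugin modules and validates types) showing how those globals behave:

```python
PLUGIN_REGISTRY: dict = {}
PLUGIN_CATEGORY_REGISTRY: dict = {}

class PluginLoader:
    def __init__(self, categories: list) -> None:
        global PLUGIN_CATEGORY_REGISTRY
        PLUGIN_CATEGORY_REGISTRY = {cat: [] for cat in categories}

    def add(self, name: str, category: str) -> "PluginLoader":
        global PLUGIN_REGISTRY, PLUGIN_CATEGORY_REGISTRY
        PLUGIN_REGISTRY[name] = category
        PLUGIN_CATEGORY_REGISTRY.setdefault(category, []).append(name)
        return self

    def reload(self) -> "PluginLoader":
        global PLUGIN_REGISTRY, PLUGIN_CATEGORY_REGISTRY
        PLUGIN_REGISTRY = {}
        PLUGIN_CATEGORY_REGISTRY = {}
        return self

loader = PluginLoader(categories=["generic", "privacy"]).add("aim", "privacy")
print(PLUGIN_REGISTRY)   # {'aim': 'privacy'}
loader.reload()
print(PLUGIN_REGISTRY)   # {}
```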
2 changes: 1 addition & 1 deletion src/synthcity/plugins/privacy/plugin_aim.py
@@ -110,7 +110,7 @@ def name() -> str:

@staticmethod
def type() -> str:
return "generic"
return "privacy"

@staticmethod
def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
1 change: 0 additions & 1 deletion src/synthcity/utils/datasets/time_series/google_stocks.py
@@ -74,7 +74,6 @@ def load(
np.asarray(observation_times),
np.asarray(outcome, dtype=np.float32),
)

return (
pd.DataFrame(np.zeros((len(temporal_data), 0))),
temporal_data,
2 changes: 1 addition & 1 deletion tests/metrics/test_performance.py
@@ -363,7 +363,7 @@ def test_evaluate_performance_custom_labels(


@pytest.mark.slow
@pytest.mark.parametrize("test_plugin", [Plugins().get("marginal_distributions")])
@pytest.mark.parametrize("test_plugin", [Plugins().get("timegan")])
@pytest.mark.parametrize(
"evaluator_t",
[
11 changes: 8 additions & 3 deletions tests/plugins/generic/test_goggle.py
@@ -1,8 +1,8 @@
# third party
import numpy as np
import pkg_resources
import pytest
from generic_helpers import generate_fixtures
from importlib_metadata import PackageNotFoundError, distribution
from sklearn.datasets import load_diabetes, load_iris

# synthcity absolute
@@ -24,8 +24,13 @@

if not is_missing_goggle_deps:
goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"}
installed = {pkg.key for pkg in pkg_resources.working_set}
is_missing_goggle_deps = len(goggle_dependencies - installed) > 0
missing_deps = []
for dep in goggle_dependencies:
try:
distribution(dep)
except PackageNotFoundError:
missing_deps.append(dep)
is_missing_goggle_deps = len(missing_deps) > 0


@pytest.mark.skipif(is_missing_goggle_deps, reason="Goggle dependencies not installed")
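The dependency probe above replaces the deprecated `pkg_resources.working_set` scan with per-package `importlib_metadata.distribution` lookups, matching the new `importlib-metadata` entry in setup.cfg. A generalised sketch of the same pattern (the helper name is illustrative):

```python
from importlib_metadata import PackageNotFoundError, distribution

def missing_packages(names):
    # distribution() raises PackageNotFoundError for anything that is not installed.
    missing = []
    for name in names:
        try:
            distribution(name)
        except PackageNotFoundError:
            missing.append(name)
    return missing

goggle_dependencies = {"dgl", "torch-scatter", "torch-sparse", "torch-geometric"}
is_missing_goggle_deps = bool(missing_packages(goggle_dependencies))
```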
23 changes: 21 additions & 2 deletions tests/plugins/generic/test_great.py
@@ -1,6 +1,7 @@
# stdlib
import os
import random
import sys
from datetime import datetime, timedelta

# third party
@@ -15,9 +16,14 @@
from synthcity.plugins import Plugin
from synthcity.plugins.core.constraints import Constraints
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.plugins.generic.plugin_great import plugin
from synthcity.utils.serialization import load, save

if sys.version_info >= (3, 9):
# synthcity absolute
from synthcity.plugins.generic.plugin_great import plugin
else:
plugin = None

IN_GITHUB_ACTIONS: bool = os.getenv("GITHUB_ACTIONS") == "true"

plugin_name = "great"
@@ -28,34 +34,44 @@
}


@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_sanity(test_plugin: Plugin) -> None:
assert test_plugin is not None


@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_name(test_plugin: Plugin) -> None:
assert test_plugin.name() == plugin_name


@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_type(test_plugin: Plugin) -> None:
assert test_plugin.type() == "generic"


@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_hyperparams(test_plugin: Plugin) -> None:
assert len(test_plugin.hyperparameter_space()) == 1


@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
)
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
)
def test_plugin_fit(test_plugin: Plugin) -> None:
X, _ = load_iris(as_frame=True, return_X_y=True)
test_plugin.fit(GenericDataLoader(X))


@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
@@ -92,6 +108,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:


@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
@@ -102,7 +119,7 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
def test_plugin_generate_constraints_great(test_plugin: Plugin) -> None:
X, y = load_iris(as_frame=True, return_X_y=True)
X["target"] = y
test_plugin.fit(GenericDataLoader(X))
test_plugin.fit(GenericDataLoader(X), device="cpu")

constraints = Constraints(
rules=[
@@ -134,6 +151,7 @@ def test_sample_hyperparams() -> None:
assert plugin(**args) is not None


@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
@@ -168,6 +186,7 @@ def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> d


@pytest.mark.slow
@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(
IN_GITHUB_ACTIONS,
reason="GReaT generate required too much memory to reliably run in GitHub Actions",
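The changes above gate the GReaT tests twice: the plugin import only happens on Python 3.9+, and the memory-hungry fit/generate tests are additionally skipped on GitHub Actions. A condensed sketch of the pattern with an illustrative test body:

```python
import os
import sys

import pytest

if sys.version_info >= (3, 9):
    from synthcity.plugins.generic.plugin_great import plugin
else:
    plugin = None  # keeps module import (and test collection) working on 3.8

IN_GITHUB_ACTIONS: bool = os.getenv("GITHUB_ACTIONS") == "true"

@pytest.mark.skipif(sys.version_info < (3, 9), reason="GReaT requires Python 3.9+")
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Too memory-hungry for GitHub Actions runners")
def test_great_plugin_loads() -> None:
    assert plugin is not None
```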
2 changes: 1 addition & 1 deletion tests/plugins/privacy/test_aim.py
@@ -43,7 +43,7 @@ def test_plugin_name(test_plugin: Plugin) -> None:

@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
def test_plugin_type(test_plugin: Plugin) -> None:
assert test_plugin.type() == "generic"
assert test_plugin.type() == "privacy"


@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))