update dependencies #288

Merged: 12 commits, Sep 3, 2024
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -11,7 +11,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [macos-latest]

steps:
2 changes: 1 addition & 1 deletion .github/workflows/test_all_tutorials.yml
@@ -10,7 +10,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion .github/workflows/test_full.yml
@@ -10,7 +10,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [macos-latest, ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion .github/workflows/test_tutorials.yml
@@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v2
4 changes: 2 additions & 2 deletions prereq.txt
@@ -1,4 +1,4 @@
numpy>=1.20, <1.24
torch>=1.10.0,<2.0
numpy>=1.20
torch>=2.1, <2.3 # Max due to tsai
tsai
wheel>=0.40
18 changes: 9 additions & 9 deletions setup.cfg
@@ -29,32 +29,33 @@ include_package_data = True
package_dir =
=src

python_requires = >=3.8
python_requires = >=3.9

install_requires =
importlib-metadata
pandas>=1.4,<2
torch>=1.10.0,<2.0
pandas>=2.1 # min due to lifelines
torch>=2.1, <2.3 # Max due to tsai
scikit-learn>=1.2
nflows>=0.14
numpy>=1.20, <1.24
lifelines>=0.27,!= 0.27.5, <0.27.8
numpy>=1.20, <2.0
lifelines>=0.29.0, <0.30.0 # max due to xgbse
opacus>=1.3
networkx>2.0,<3.0
decaf-synthetic-data>=0.1.6
optuna>=3.1
shap
tenacity
tqdm
loguru
pydantic<2.0
cloudpickle
scipy
xgboost<2.0.0
xgboost<3.0.0
geomloss
pgmpy
redis
pycox
xgbse
xgbse>=0.3.1
pykeops
fflows
monai
@@ -96,13 +97,12 @@ testing =
click

goggle =
dgl<2.0
dgl
torch_geometric
torch_sparse
torch_scatter

all =
importlib-metadata;python_version<"3.8"
%(testing)s
%(goggle)s

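An optional sanity check of an installed environment against the new lower bounds; a minimal sketch that assumes the packaging library is available and is not part of this PR:

from importlib.metadata import version

from packaging.version import Version

minimums = {"numpy": "1.20", "pandas": "2.1", "torch": "2.1", "lifelines": "0.29.0"}
for pkg, minimum in minimums.items():
    installed = Version(version(pkg))
    print(pkg, installed, "OK" if installed >= Version(minimum) else "below minimum")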
15 changes: 12 additions & 3 deletions src/synthcity/plugins/core/dataloader.py
@@ -931,11 +931,20 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:

if as_numpy:
longest_observation_seq = max([len(seq) for seq in temporal_data])
padded_temporal_data = np.zeros(
(len(temporal_data), longest_observation_seq, 5)
)
mask = np.ones((len(temporal_data), longest_observation_seq, 5), dtype=bool)
for i, arr in enumerate(temporal_data):
padded_temporal_data[i, : arr.shape[0], :] = arr # Copy the actual data
mask[
i, : arr.shape[0], :
] = False # Set mask to False where actual data is present

masked_temporal_data = ma.masked_array(padded_temporal_data, mask)
return (
np.asarray(static_data),
np.asarray(
temporal_data
), # TODO: check this works with time series benchmarks
masked_temporal_data, # TODO: check this works with time series benchmarks
# masked array to handle variable length sequences
ma.vstack(
[
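For illustration, the padding-plus-mask pattern introduced above can be reproduced in isolation. This is a minimal sketch with made-up sequence lengths, not code from this PR:

import numpy as np
import numpy.ma as ma

# two toy sequences of different lengths, 5 features each
temporal_data = [np.ones((3, 5)), np.ones((7, 5))]
longest = max(len(seq) for seq in temporal_data)

padded = np.zeros((len(temporal_data), longest, 5))
mask = np.ones((len(temporal_data), longest, 5), dtype=bool)
for i, arr in enumerate(temporal_data):
    padded[i, : arr.shape[0], :] = arr   # copy the real observations
    mask[i, : arr.shape[0], :] = False   # False marks real data, True marks padding

masked = ma.masked_array(padded, mask)
print(masked.shape)  # (2, 7, 5)
print(masked.mean())  # 1.0 -- the padded zeros are excluded from the reduction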
Original file line number Diff line number Diff line change
@@ -5,9 +5,16 @@
import numpy as np
import pandas as pd
from lifelines import KaplanMeierFitter
from scipy.integrate import trapz
from xgbse.non_parametric import _get_conditional_probs_from_survival

try:
# third party
from scipy.integrate import trapz
except ImportError:
from numpy import (
trapz,
) # Fallback for environments where scipy.integrate.trapz is unavailable (removed in newer SciPy releases)

# synthcity absolute
from synthcity.plugins.core.models.survival_analysis.third_party.metrics import (
brier_score,
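The same try/except import is repeated in the time-to-event modules below. A minimal standalone sketch (with made-up survival values) showing that the call site is identical for either import:

import numpy as np

try:
    from scipy.integrate import trapz
except ImportError:  # scipy.integrate.trapz was removed in newer SciPy releases
    from numpy import trapz

# hypothetical survival curve S(t) sampled at four time points
t = np.array([0.0, 1.0, 2.0, 3.0])
S = np.array([1.0, 0.9, 0.7, 0.5])
area = trapz(S, t)  # trapezoidal rule
print(area)  # 2.35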
9 changes: 8 additions & 1 deletion src/synthcity/plugins/core/models/time_to_event/tte_aft.py
@@ -5,7 +5,14 @@
import pandas as pd
from lifelines import WeibullAFTFitter
from pydantic import validate_arguments
from scipy.integrate import trapz

try:
# third party
from scipy.integrate import trapz
except ImportError:
from numpy import (
trapz,
) # Fallback for environments where scipy.integrate.trapz is unavailable (removed in newer SciPy releases)

# synthcity absolute
from synthcity.plugins.core.distribution import Distribution, FloatDistribution
9 changes: 8 additions & 1 deletion src/synthcity/plugins/core/models/time_to_event/tte_coxph.py
@@ -5,7 +5,14 @@
import pandas as pd
from lifelines import CoxPHFitter
from pydantic import validate_arguments
from scipy.integrate import trapz

try:
# third party
from scipy.integrate import trapz
except ImportError:
from numpy import (
trapz,
) # Fallback for environments where scipy.integrate.trapz is unavailable (removed in newer SciPy releases)

# synthcity absolute
from synthcity.plugins.core.distribution import Distribution, FloatDistribution
Original file line number Diff line number Diff line change
@@ -8,9 +8,16 @@
import torchtuples as tt
from pycox.models import DeepHitSingle
from pydantic import validate_arguments
from scipy.integrate import trapz
from sklearn.model_selection import train_test_split

try:
# third party
from scipy.integrate import trapz
except ImportError:
from numpy import (
trapz,
) # Fallback for environments where scipy.integrate.trapz is unavailable (removed in newer SciPy releases)

# synthcity absolute
from synthcity.plugins.core.distribution import (
CategoricalDistribution,
9 changes: 8 additions & 1 deletion src/synthcity/plugins/core/models/time_to_event/tte_xgb.py
@@ -5,10 +5,17 @@
import numpy as np
import pandas as pd
from pydantic import validate_arguments
from scipy.integrate import trapz
from xgbse import XGBSEDebiasedBCE, XGBSEKaplanNeighbors, XGBSEStackedWeibull
from xgbse.converters import convert_to_structured

try:
# third party
from scipy.integrate import trapz
except ImportError:
from numpy import (
trapz,
) # Fallback for environments where scipy.integrate.trapz is unavailable (removed in newer SciPy releases)

# synthcity absolute
from synthcity.plugins.core.distribution import (
CategoricalDistribution,
9 changes: 9 additions & 0 deletions src/synthcity/plugins/core/serializable.py
@@ -22,7 +22,11 @@ class Serializable:
"""Utility class for model persistence."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
derived_module_path: Optional[Path] = None
self.fitted = (
False # make sure all serializable objects are not fitted by default
)

search_module = self.__class__.__module__
if not search_module.endswith(".py"):
@@ -58,9 +62,14 @@ def save_dict(self) -> dict:
data = self.__dict__[key]
if isinstance(data, Serializable):
members[key] = self.__dict__[key].save_dict()
elif key == "model":
members[key] = serialize(self.__dict__[key])
else:
members[key] = copy.deepcopy(self.__dict__[key])

if "fitted" not in members:
members["fitted"] = self.fitted # Ensure 'fitted' is always serialized

return {
"source": "synthcity",
"data": members,
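A rough sketch of the new save_dict behaviour, using a hypothetical subclass that is not part of this PR; plain attributes are deep-copied, any attribute named "model" goes through serialize, and "fitted" is always included:

# synthcity absolute
from synthcity.plugins.core.serializable import Serializable


class Example(Serializable):  # hypothetical subclass for illustration
    def __init__(self) -> None:
        super().__init__()
        self.alpha = 0.1


state = Example().save_dict()
print(state["data"]["fitted"])  # False -- set by the new default in __init__
print(state["data"]["alpha"])   # 0.1   -- plain attributes are deep-copied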
145 changes: 136 additions & 9 deletions src/synthcity/utils/serialization.py
@@ -1,19 +1,146 @@
# stdlib
import hashlib
from pathlib import Path
from typing import Any, Union
from typing import Any, List, Union

# third party
import cloudpickle
import pandas as pd


def save(model: Any) -> bytes:
return cloudpickle.dumps(model)


def load(buff: bytes) -> Any:
return cloudpickle.loads(buff)
from opacus import PrivacyEngine

# The list of plugins that are not simply loadable with cloudpickle
unloadable_plugins: List[str] = [
"dpgan", # DP-GAN plugin id not loadable with cloudpickle due to the DPOptimizer
]


# TODO: simplify this function back to just cloudpickle.dumps(model), if possible (i.e. if the DPOptimizer is not needed or becomes loadable with cloudpickle)
def save(custom_model: Any) -> bytes:
"""
Serialize a custom model object that may or may not contain a PyTorch model with a privacy engine.

Args:
custom_model: The custom model object to serialize, potentially containing a PyTorch model with a privacy engine.

Returns:
bytes: Serialized model state as bytes.
"""
# Check whether custom_model is a plugin (by duck typing) without importing Plugin, which would cause a circular import
if not hasattr(custom_model, "name"):
return cloudpickle.dumps(custom_model)

if custom_model.name() not in unloadable_plugins:
return cloudpickle.dumps(custom_model)

# Initialize the checkpoint dictionary
checkpoint = {
"custom_model_state": None,
"pytorch_model_state": None,
"privacy_engine_state": None,
"optimizer_state": None,
"optimizer_class": None,
"optimizer_defaults": None,
}

# Save the state of the custom model object (excluding the PyTorch model and optimizer)
custom_model_state = {
key: value for key, value in custom_model.__dict__.items() if key != "model"
}
checkpoint["custom_model_state"] = cloudpickle.dumps(custom_model_state)

# Check if the custom model contains a PyTorch model
pytorch_model = None
if hasattr(custom_model, "model"):
pytorch_model = getattr(custom_model, "model")

# If a PyTorch model is found, check if it's using Opacus for DP
if pytorch_model:
checkpoint["pytorch_model_state"] = pytorch_model.state_dict()
if hasattr(pytorch_model, "privacy_engine") and isinstance(
pytorch_model.privacy_engine, PrivacyEngine
):
# Handle DP Optimizer
optimizer = pytorch_model.privacy_engine.optimizer

checkpoint.update(
{
"optimizer_state": optimizer.state_dict(),
"privacy_engine_state": pytorch_model.privacy_engine.state_dict(),
"optimizer_class": optimizer.__class__,
"optimizer_defaults": optimizer.defaults,
}
)

# Serialize the entire state with cloudpickle
return cloudpickle.dumps(checkpoint)


# TODO: simplify this function back to just cloudpickle.loads(model), if possible (i.e. if the DPOptimizer is not needed or becomes loadable with cloudpickle)
def load(buff: bytes, custom_model: Any = None) -> Any:
"""
Deserialize a custom model object that may or may not contain a PyTorch model with a privacy engine.

Args:
buff (bytes): Serialized model state as bytes.
custom_model: The custom model instance to load the state into.

Returns:
custom_model: The deserialized custom model with its original state.
"""
# Load the checkpoint
if custom_model is None or custom_model.name() not in unloadable_plugins:
return cloudpickle.loads(buff)

if custom_model is None:
raise ValueError(
f"custom_model must be provided when loading one of the following plugins: {unloadable_plugins}"
)

checkpoint = cloudpickle.loads(buff)
# Restore the custom model's own state (excluding the PyTorch model)
custom_model_state = cloudpickle.loads(checkpoint["custom_model_state"])
for key, value in custom_model_state.items():
setattr(custom_model, key, value)

# Find the PyTorch model inside the custom model if it exists
pytorch_model = None
if hasattr(custom_model, "model"):
pytorch_model = getattr(custom_model, "model")

# Load the states into the PyTorch model if it exists
if pytorch_model and checkpoint["pytorch_model_state"] is not None:
pytorch_model.load_state_dict(checkpoint["pytorch_model_state"])

# Check if the serialized model had a privacy engine
if checkpoint["privacy_engine_state"] is not None:
# If there was a privacy engine, recreate and reattach it
optimizer_class = checkpoint["optimizer_class"]
optimizer_defaults = checkpoint["optimizer_defaults"]

# Ensure the optimizer is correctly created with model's parameters
optimizer = optimizer_class(
pytorch_model.parameters(), **optimizer_defaults
)

# Recreate the privacy engine
privacy_engine = PrivacyEngine(
pytorch_model,
sample_rate=optimizer.defaults.get(
"sample_rate", 0.01
), # Use saved or default values
noise_multiplier=optimizer.defaults.get("noise_multiplier", 1.0),
max_grad_norm=optimizer.defaults.get("max_grad_norm", 1.0),
)
privacy_engine.attach(optimizer)

# Load the saved states
optimizer.load_state_dict(checkpoint["optimizer_state"])
privacy_engine.load_state_dict(checkpoint["privacy_engine_state"])

# Assign back to the PyTorch model (or the appropriate container)
pytorch_model.privacy_engine = privacy_engine

return custom_model


def save_to_file(path: Union[str, Path], model: Any) -> Any:
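A hedged usage sketch of the new save/load pair. The dataset, the n_iter value and the choice of the "dpgan" plugin are illustrative, but for the plugins listed in unloadable_plugins a fresh plugin instance has to be supplied to load:

# third party
import pandas as pd
from sklearn.datasets import load_diabetes

# synthcity absolute
from synthcity.plugins import Plugins
from synthcity.utils.serialization import load, save

X, y = load_diabetes(return_X_y=True, as_frame=True)
df = pd.concat([X, y], axis=1)

model = Plugins().get("dpgan", n_iter=10)  # DP plugin, serialized via the checkpoint path
model.fit(df)

buff = save(model)  # checkpoint dict for "dpgan", plain cloudpickle for everything else
restored = load(buff, custom_model=Plugins().get("dpgan", n_iter=10))
synthetic = restored.generate(count=10)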
2 changes: 1 addition & 1 deletion src/synthcity/version.py
@@ -1,4 +1,4 @@
__version__ = "0.2.10"
__version__ = "0.2.11"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
PATCH_VERSION = __version__.split(".")[-1]
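For reference, with the bumped version string the derived values work out as:

assert ".".join("0.2.11".split(".")[:-1]) == "0.2"  # MAJOR_VERSION
assert "0.2.11".split(".")[-1] == "11"  # PATCH_VERSION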