Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full tests #294

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/test_all_tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/test_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
ref: main
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand All @@ -30,18 +30,18 @@ jobs:
run: |
python -m pip install -U pip
pip install -r prereq.txt
- name: Limit OpenMP threads
run: |
echo "OMP_NUM_THREADS=2" >> $GITHUB_ENV
- name: Test Core - slow part one
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "slow_1"
- name: Test Core - slow part two
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "slow_2"
- name: Test Core - fast
timeout-minutes: 1000
run: |
pip install .[testing]
pytest -vvvs --durations=50 -m "not slow"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name: Tests Fast Python
on:
push:
branches: [main, release]
pull_request:
types: [opened, synchronize, reopened]
# pull_request:
# types: [opened, synchronize, reopened]
workflow_dispatch:

jobs:
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/test_tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ name: PR Tutorials
on:
push:
branches: [main, release]
pull_request:
types: [opened, synchronize, reopened]
# pull_request:
# types: [opened, synchronize, reopened]
schedule:
- cron: "2 3 * * 4"
workflow_dispatch:
Expand All @@ -20,7 +20,7 @@ jobs:
- uses: actions/checkout@v3
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
ref: main
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand All @@ -40,4 +40,4 @@ jobs:
python -m pip install ipykernel
python -m ipykernel install --user
- name: Run the tutorials
run: python tests/nb_eval.py --nb_dir tutorials/ --tutorial_tests minimal_tests
run: python tests/nb_eval.py --nb_dir tutorials/ --tutorial_tests minimal_tests --timeout 3600
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ install_requires =
tenacity
tqdm
loguru
pydantic<2.0
pydantic>=2.0
cloudpickle
scipy
xgboost<3.0.0
Expand Down
4 changes: 4 additions & 0 deletions src/synthcity/benchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def evaluate(
strict_augmentation: bool = False,
ad_hoc_augment_vals: Optional[Dict] = None,
use_metric_cache: bool = True,
n_eval_folds: int = 5,
**generate_kwargs: Any,
) -> pd.DataFrame:
"""Benchmark the performance of several algorithms.
Expand Down Expand Up @@ -102,6 +103,8 @@ def evaluate(
A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
use_metric_cache: bool
If the current metric has been previously run and is cached, it will be reused for the experiments. Defaults to True.
n_eval_folds: int
the KFolds used by MetricEvaluators in the benchmarks. Defaults to 5.
plugin_kwargs:
Optional kwargs for each algorithm. Example {"adsgan": {"n_iter": 10}},
"""
Expand Down Expand Up @@ -295,6 +298,7 @@ def evaluate(
task_type=task_type,
workspace=workspace,
use_cache=use_metric_cache,
n_folds=n_eval_folds,
)

mean_score = evaluation["mean"].to_dict()
Expand Down
6 changes: 3 additions & 3 deletions src/synthcity/metrics/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def f() -> None:
"epoch": epoch,
},
workspace / "DomiasMIA_bnaf_checkpoint.pt",
)
) # nosec B614

return f

Expand All @@ -348,7 +348,7 @@ def f() -> None:

log.info("Loading model..")
if (workspace / "checkpoint.pt").exists():
checkpoint = torch.load(workspace / "checkpoint.pt")
checkpoint = torch.load(workspace / "checkpoint.pt") # nosec B614
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

Expand Down Expand Up @@ -453,7 +453,7 @@ def train(
"epoch": epoch,
},
workspace / "checkpoint.pt",
)
) # nosec B614
log.debug(
f"""
###### Stop training after {epoch + 1} epochs!
Expand Down
18 changes: 13 additions & 5 deletions src/synthcity/metrics/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def evaluate(
random_state: int = 0,
workspace: Path = Path("workspace"),
use_cache: bool = True,
n_folds: int = 5,
) -> pd.DataFrame:
"""Core evaluation logic for the metrics

Expand Down Expand Up @@ -202,12 +203,16 @@ def evaluate(

"""
We need to encode the categorical data in the real and synthetic data.
To ensure each category in the two datasets are mapped to the same one hot vector, we merge X_syn into X_gt for computing the encoder.
TODO: Check whether the optional datasets also need to be taken into account when getting the encoder.
To ensure each category in the two datasets is mapped to the same one-hot vector, we merge all available datasets for computing the encoder.
"""
X_gt_df = X_gt.dataframe()
X_syn_df = X_syn.dataframe()
X_enc = create_from_info(pd.concat([X_gt_df, X_syn_df]), X_gt.info())
all_df = pd.concat([X_gt.dataframe(), X_syn.dataframe()])
if X_train:
all_df = pd.concat([all_df, X_train.dataframe()])
if X_ref_syn:
all_df = pd.concat([all_df, X_ref_syn.dataframe()])
if X_augmented:
all_df = pd.concat([all_df, X_augmented.dataframe()])
X_enc = create_from_info(all_df, X_gt.info())
_, encoders = X_enc.encode()

# now we encode the data
Expand Down Expand Up @@ -238,6 +243,7 @@ def evaluate(
random_state=random_state,
workspace=workspace,
use_cache=use_cache,
n_folds=n_folds,
),
X_gt,
X_augmented,
Expand All @@ -251,6 +257,7 @@ def evaluate(
random_state=random_state,
workspace=workspace,
use_cache=use_cache,
n_folds=n_folds,
),
X_gt,
X_syn,
Expand All @@ -267,6 +274,7 @@ def evaluate(
random_state=random_state,
workspace=workspace,
use_cache=use_cache,
n_folds=n_folds,
),
X_gt.sample(eval_cnt),
X_syn.sample(eval_cnt),
Expand Down
10 changes: 6 additions & 4 deletions src/synthcity/plugins/core/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
# third party
import numpy as np
import pandas as pd
from pydantic import BaseModel, validate_arguments, validator
from pydantic import BaseModel, field_validator, validate_arguments

# synthcity absolute
import synthcity.logger as log

Rule = Tuple[str, str, Any] # Define a type alias for clarity


class Constraints(BaseModel):
"""
Expand Down Expand Up @@ -41,10 +43,10 @@ class Constraints(BaseModel):
and thresh is the threshold or data type.
"""

rules: list = []
rules: list[Rule] = []

@validator("rules")
def _validate_rules(cls: Any, rules: List, values: dict, **kwargs: Any) -> List:
@field_validator("rules", mode="before")
def _validate_rules(cls: Any, rules: List) -> List:
supported_ops: list = [
"<",
">=",
Expand Down
Loading
Loading