Augmentation benchmark (#150)
* Add new benchmark code

* Merge main into branch

* Augmentation benchmark added

* Clean up

* Clean up

* Remove unnecessary tutorial file

* Clean up

* clean up

* Debug test and clean up

* Added new tests for augmentation benchmark

* Added new metric api tests for augmentation

* clean up

* clean up

* version bumped and clean up

* clean up docstrings
robsdavis authored Mar 15, 2023
1 parent b82baca commit cf6ea56
Showing 13 changed files with 610 additions and 45 deletions.
104 changes: 98 additions & 6 deletions src/synthcity/benchmark/__init__.py
@@ -3,6 +3,7 @@
import json
import platform
import random
from copy import copy
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

@@ -14,6 +15,7 @@

# synthcity absolute
import synthcity.logger as log
from synthcity.benchmark.utils import augment_data
from synthcity.metrics import Metrics
from synthcity.metrics.scores import ScoreEvaluator
from synthcity.plugins import Plugins
@@ -48,10 +50,16 @@ def evaluate(
synthetic_constraints: Optional[Constraints] = None,
synthetic_cache: bool = True,
synthetic_reuse_if_exists: bool = True,
augmented_reuse_if_exists: bool = True,
task_type: str = "classification", # classification, regression, survival_analysis, time_series
workspace: Path = Path("workspace"),
augmentation_rule: str = "equal",
strict_augmentation: bool = False,
ad_hoc_augment_vals: Optional[Dict] = None,
use_metric_cache: bool = True,
**generate_kwargs: Any,
) -> pd.DataFrame:

"""Benchmark the performance of several algorithms.
Args:
Expand Down Expand Up @@ -80,11 +88,21 @@ def evaluate(
synthetic_cache: bool
Enable experiment caching
synthetic_reuse_if_exists: bool
- If the current synthetic dataset is cached, it will be reused for the experiments.
+ If the current synthetic dataset is cached, it will be reused for the experiments. Defaults to True.
augmented_reuse_if_exists: bool
If the current augmented dataset is cached, it will be reused for the experiments. Defaults to True.
task_type: str
The type of problem. Relevant for evaluating the downstream models with the correct metrics. Valid tasks are: "classification", "regression", "survival_analysis", "time_series", "time_series_survival".
workspace: Path
Path for caching experiments. Default: "workspace".
augmentation_rule: str
The rule used to achieve the desired proportion of records with each value in the fairness column. Possible values are: 'equal', 'log', and 'ad-hoc'. Defaults to "equal".
strict_augmentation: bool
Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False.
ad_hoc_augment_vals: Dict
A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
use_metric_cache: bool
If the current metric has been previously run and is cached, it will be reused for the experiments. Defaults to True.
generate_kwargs:
Optional kwargs forwarded to the generator's `generate` call. Per-plugin kwargs, e.g. {"adsgan": {"n_iter": 10}}, are supplied via the `tests` tuples.
"""
Expand Down Expand Up @@ -115,6 +133,17 @@ def evaluate(
hash_object = hashlib.sha256(kwargs_hash_raw)
kwargs_hash = hash_object.hexdigest()

augmentation_arguments = {
"augmentation_rule": augmentation_rule,
"strict_augmentation": strict_augmentation,
"ad_hoc_augment_vals": ad_hoc_augment_vals,
}
augmentation_arguments_hash_raw = json.dumps(
copy(augmentation_arguments), sort_keys=True
).encode()
augmentation_hash_object = hashlib.sha256(augmentation_arguments_hash_raw)
augmentation_hash = augmentation_hash_object.hexdigest()

repeats_list = list(range(repeats))
random.shuffle(repeats_list)

@@ -126,14 +155,22 @@

clear_cache()

- cache_file = (
+ X_syn_cache_file = (
workspace
/ f"{experiment_name}_{testcase}_{plugin}_{kwargs_hash}_{platform.python_version()}_{repeat}.bkp"
)
generator_file = (
workspace
/ f"{experiment_name}_{testcase}_{plugin}_{kwargs_hash}_{platform.python_version()}_generator_{repeat}.bkp"
)
X_augment_cache_file = (
workspace
/ f"{experiment_name}_{testcase}_{plugin}_augmentation_{augmentation_hash}_{kwargs_hash}_{platform.python_version()}_{repeat}.bkp"
)
augment_generator_file = (
workspace
/ f"{experiment_name}_{testcase}_{plugin}_augmentation_{augmentation_hash}_{kwargs_hash}_{platform.python_version()}_generator_{repeat}.bkp"
)

log.info(
f"[testcase] Experiment repeat: {repeat} task type: {task_type} Train df hash = {experiment_name}"
@@ -152,8 +189,8 @@
if synthetic_cache:
save_to_file(generator_file, generator)

- if cache_file.exists() and synthetic_reuse_if_exists:
-     X_syn = load_from_file(cache_file)
+ if X_syn_cache_file.exists() and synthetic_reuse_if_exists:
+     X_syn = load_from_file(X_syn_cache_file)
else:
try:
X_syn = generator.generate(
@@ -168,13 +205,68 @@
continue

if synthetic_cache:
- save_to_file(cache_file, X_syn)
+ save_to_file(X_syn_cache_file, X_syn)

# Augmentation
if metrics and any(
"augmentation" in metric
for metric in [x for v in metrics.values() for x in v]
):
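# Illustration (not part of the committed file): with, say,
#   metrics = {"performance": ["linear_model_augmentation", "xgb"]}  # metric names assumed
# the comprehension flattens the dict values to
#   ["linear_model_augmentation", "xgb"]
# and the substring test enables the augmentation pathway for the first entry.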
if augment_generator_file.exists() and augmented_reuse_if_exists:
augment_generator = load_from_file(augment_generator_file)
else:
augment_generator = Plugins(categories=plugin_cats).get(
plugin,
**kwargs,
)
try:
if not X.get_fairness_column():
raise ValueError(
"To use the augmentation metrics, `fairness_column` must be set to a string representing the name of a column in the DataLoader."
)
augment_generator.fit(
X.train(),
cond=X.train()[X.get_fairness_column()],
)
except BaseException as e:
log.critical(
f"[{plugin}][take {repeat}] failed to fit augmentation generator: {e}"
)
continue
if synthetic_cache:
save_to_file(augment_generator_file, augment_generator)

if X_augment_cache_file.exists() and augmented_reuse_if_exists:
X_augmented = load_from_file(X_augment_cache_file)
else:
try:
X_augmented = augment_data(
X.train(),
augment_generator,
rule=augmentation_rule,
strict=strict_augmentation,
ad_hoc_augment_vals=ad_hoc_augment_vals,
**generate_kwargs,
)
if len(X_augmented) == 0:
raise RuntimeError("Plugin failed to generate data")
except BaseException as e:
log.critical(
f"[{plugin}][take {repeat}] failed to generate augmentation data: {e}"
)
continue
if synthetic_cache:
save_to_file(X_augment_cache_file, X_augmented)
else:
X_augmented = None
evaluation = Metrics.evaluate(
- X_test if X_test is not None else X,
+ X_test if X_test is not None else X.test(),
X_syn,
X_augmented,
metrics=metrics,
task_type=task_type,
workspace=workspace,
use_cache=use_metric_cache,
)

mean_score = evaluation["mean"].to_dict()
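Below is a minimal usage sketch of the extended `Benchmarks.evaluate` call, added for orientation. It is illustrative only: the dataset, column names, plugin, the augmentation metric key, and the `fairness_column` argument to `GenericDataLoader` are assumptions, not taken from this diff.

# Hedged sketch, not part of this commit. Dataset, column names, plugin,
# and the augmentation metric key below are assumptions for illustration.
import pandas as pd

from synthcity.benchmark import Benchmarks
from synthcity.plugins.core.dataloader import GenericDataLoader

df = pd.read_csv("adult.csv")  # assumed: has a categorical "sex" column
loader = GenericDataLoader(
    df,
    target_column="income",  # assumed label column
    fairness_column="sex",   # required by the augmentation metrics
)

results = Benchmarks.evaluate(
    [("aug_test", "adsgan", {"n_iter": 100})],
    loader,
    task_type="classification",
    metrics={"performance": ["xgb_augmentation"]},  # assumed metric key
    augmentation_rule="log",  # or "equal" / "ad-hoc"
    strict_augmentation=False,
    repeats=2,
)
print(results)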
193 changes: 193 additions & 0 deletions src/synthcity/benchmark/utils.py
@@ -0,0 +1,193 @@
# stdlib
import math
from copy import copy
from typing import Any, Dict, Optional

# third party
import numpy as np
import pandas as pd
from pydantic import validate_arguments
from typing_extensions import Literal

# synthcity absolute
from synthcity.plugins.core.constraints import Constraints
from synthcity.plugins.core.dataloader import DataLoader


def calculate_fair_aug_sample_size(
X_train: pd.DataFrame,
fairness_column: Optional[str], # a categorical column of K levels
rule: Literal[
"equal", "log", "ad-hoc"
],  # TODO: Confirm whether there are any more methods to include
ad_hoc_augment_vals: Optional[
Dict[Any, int]
] = None, # Only required for rule == "ad-hoc"
) -> Dict:
"""Calculate how many samples to augment.
Args:
X_train (pd.DataFrame): The real dataset to be augmented.
fairness_column (str): The name of the column with respect to which the fairness of a downstream model is tested.
rule (Literal["equal", "log", "ad-hoc"]): The rule used to achieve the desired proportion of records with each value in the fairness column. Defaults to "equal".
ad_hoc_augment_vals (Dict[Any, int], optional): A dictionary containing the number of each class to augment the real data with. If using rule="ad-hoc" this function returns ad_hoc_augment_vals; otherwise this parameter is ignored. Defaults to None.
Returns:
Dict: A dictionary containing the number of each class to augment the real data with.
"""

# the majority class is unchanged
if rule == "equal":
# the number of samples will be the same for each value in the fairness column after augmentation
# N_aug(i) = N_aug(j) for all values i and j in the fairness column
fairness_col_counts = X_train[fairness_column].value_counts()
majority_size = fairness_col_counts.max()
augmentation_counts = {
fair_col_val: (majority_size - fairness_col_counts.loc[fair_col_val])
for fair_col_val in fairness_col_counts.index
}
elif rule == "log":
# the number of samples to generate for each level decreases with its log frequency in the real data.
# Note: taking the log makes the resulting distribution more even.
# N_aug(i) = majority_size - round(log(N_real(i)) * majority_size / log(majority_size))
fairness_col_counts = X_train[fairness_column].value_counts()
majority_size = fairness_col_counts.max()
log_coefficient = majority_size / math.log(majority_size)

augmentation_counts = {
fair_col_val: (
majority_size - round(math.log(fair_col_count) * log_coefficient)
)
for fair_col_val, fair_col_count in fairness_col_counts.items()
}
elif rule == "ad-hoc":
# use user-specified values to augment
if not ad_hoc_augment_vals:
raise ValueError(
"When augmenting with an `ad-hoc` method, ad_hoc_augment_vals must be a dictionary, where the dictionary keys are the values of the fairness_column and the dictionary values are the number of records to augment."
)
else:
if not set(ad_hoc_augment_vals.keys()).issubset(
set(X_train[fairness_column].values)
):
raise ValueError(
"ad_hoc_augment_vals must be a dictionary, where the dictionary keys are the values of the fairness_column and the dictionary values are the number of records to augment."
)
elif set(X_train[fairness_column].values) != set(
ad_hoc_augment_vals.keys()
):
ad_hoc_augment_vals = {
k: v
for k, v in ad_hoc_augment_vals.items()
if k in set(X_train[fairness_column].values)
}

augmentation_counts = ad_hoc_augment_vals

return augmentation_counts
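# Worked example (illustrative; not part of the committed file): suppose
# X_train[fairness_column].value_counts() gives {"A": 100, "B": 50, "C": 10}.
#   rule="equal": majority_size = 100, so
#       augmentation_counts == {"A": 0, "B": 50, "C": 90}
#     i.e. every level is topped up to the majority count.
#   rule="log": log_coefficient = 100 / log(100) ≈ 21.71, so
#       N_aug("B") = 100 - round(log(50) * 21.71) = 100 - 85 = 15
#       N_aug("C") = 100 - round(log(10) * 21.71) = 100 - 50 = 50
#     rare levels are still boosted, but less aggressively than with "equal".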


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _generate_synthetic_data(
X_train: DataLoader,
augment_generator: Any,
strict: bool = True,
rule: Literal["equal", "log", "ad-hoc"] = "equal",
ad_hoc_augment_vals: Optional[
Dict[Any, int]
] = None, # Only required for rule == "ad-hoc"
synthetic_constraints: Optional[Constraints] = None,
**generate_kwargs: Any,
) -> pd.DataFrame:
"""Generates synthetic data
Args:
X_train (DataLoader): The dataset used to train the downstream model.
augment_generator (Any): The synthetic model to be used to generate the synthetic portion of the augmented dataset.
strict (bool, optional): Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to True.
rule (Literal["equal", "log", "ad-hoc"]): The rule used to achieve the desired proportion of records with each value in the fairness column. Defaults to "equal".
ad_hoc_augment_vals (Dict[Any, int], optional): A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
synthetic_constraints (Optional[Constraints]): Constraints placed on the generation of the synthetic data. Defaults to None.
Returns:
pd.DataFrame: The generated synthetic data.
"""
augmentation_counts = calculate_fair_aug_sample_size(
X_train.dataframe(),
X_train.get_fairness_column(),
rule,
ad_hoc_augment_vals=ad_hoc_augment_vals,
)
if not strict:
# set count equal to the total number of records required according to calculate_fair_aug_sample_size
count = sum(augmentation_counts.values())
cond = pd.Series(
np.repeat(
list(augmentation_counts.keys()), list(augmentation_counts.values())
)
)
syn_data = augment_generator.generate(
count=count,
cond=cond,
constraints=synthetic_constraints,
**generate_kwargs,
).dataframe()
else:
syn_data_list = []
for fairness_value, count in augmentation_counts.items():
if count > 0:
constraints = Constraints(
rules=[(X_train.get_fairness_column(), "==", fairness_value)]
)
syn_data_list.append(
augment_generator.generate(
count=count, constraints=constraints
).dataframe()
)
syn_data = pd.concat(syn_data_list)
return syn_data
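# Illustration (not part of the committed file): with augmentation_counts
# {"B": 2, "C": 3}, the non-strict branch issues a single call with
#     count = 5, cond = pd.Series(["B", "B", "C", "C", "C"])
# and trusts the generator's conditioning, while the strict branch makes one
# generate() call per level under a hard Constraints rule such as
# (fairness_column, "==", "B"), which guarantees the requested counts exactly.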


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def augment_data(
X_train: DataLoader,
augment_generator: Any,
strict: bool = False,
rule: Literal["equal", "log", "ad-hoc"] = "equal",
ad_hoc_augment_vals: Optional[
Dict[Any, int]
] = None, # Only required for rule == "ad-hoc"
synthetic_constraints: Optional[Constraints] = None,
**generate_kwargs: Any,
) -> DataLoader:
"""Augment the real data with generated synthetic data
Args:
X_train (DataLoader): The ground truth DataLoader to augment with synthetic data.
augment_generator (Any): The synthetic model to be used to generate the synthetic portion of the augmented dataset.
strict (bool, optional): Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False.
rule (Literal["equal", "log", "ad-hoc"]): The rule used to achieve the desired proportion of records with each value in the fairness column. Defaults to "equal".
ad_hoc_augment_vals (Dict[Any, int], optional): A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
synthetic_constraints (Optional[Constraints]): Constraints placed on the generation of the synthetic data. Defaults to None.
Returns:
DataLoader: The augmented dataset and labels.
"""
syn_data = _generate_synthetic_data(
X_train,
augment_generator,
strict=strict,
rule=rule,
ad_hoc_augment_vals=ad_hoc_augment_vals,
synthetic_constraints=synthetic_constraints,
**generate_kwargs,
)

augmented_data_loader = copy(X_train)
augmented_data_loader.data = pd.concat(
[
X_train.data,
syn_data,
]
)

return augmented_data_loader
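For completeness, here is a hedged sketch of calling the new helper directly with a fitted conditional generator; the plugin name and the existence of a prepared `loader` with its fairness column set are assumptions for illustration.

# Hedged sketch, not part of this commit: direct use of augment_data.
from synthcity.benchmark.utils import augment_data
from synthcity.plugins import Plugins

# Assumed: `loader` is a DataLoader whose fairness_column is already set.
X_train = loader.train()
generator = Plugins().get("ctgan")  # any plugin that supports conditional fit
generator.fit(X_train, cond=X_train[X_train.get_fairness_column()])

augmented = augment_data(
    X_train,
    generator,
    rule="equal",   # top every fairness level up to the majority count
    strict=False,   # one conditional generate() call instead of per-level calls
)
# `augmented` is a DataLoader holding the real rows plus the synthetic top-up.
print(len(augmented.dataframe()))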