Skip to content

Commit

Permalink
Synthetic experiments: power-tie, power-dp (#5)
Browse files Browse the repository at this point in the history
* add experiment config for power-tie analysis

* add experiment config for power-dp analysis

* replace print with `logging` messages
  • Loading branch information
honghaoli42 authored Jan 23, 2024
1 parent f94bc5d commit d0efbad
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 30 deletions.
66 changes: 66 additions & 0 deletions experiments/config/experiment/power_and_type_one_error_dp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# @package _global_
name: "Power and type one error analyses"

# initial_seed is used to generate seed for each run
initial_seed: 123

data:
ndim: 10
scale_t: 10.0
shape_t: 3.0
propensity: "linear"
standardize_features: False

defaults:
- /model@models.IPTW_robust: pooled_iptw
- /model@models.IPTW_naive: pooled_iptw
- /model@models.FedECA_robust: fl_iptw
- /model@models.FedECA_naive: fl_iptw
- _self_

models:
IPTW_robust:
variance_method: "robust"
IPTW_naive:
variance_method: "naive"
FedECA_robust:
ndim: ${data.ndim}
num_rounds_list: [10, 10]
fedeca_path: "/home/owkin/fedeca/"
robust: True
FedECA_naive:
ndim: ${data.ndim}
num_rounds_list: [10, 10]
fedeca_path: "/home/owkin/fedeca/"
robust: False

# config fit FedECA
fit_fedeca:
n_clients: 3
split_method: "split_control_over_centers"
split_method_kwargs: {"treatment_info": "treatment_allocation"}
dp_max_grad_norm: 1.
dp_target_delta: 0.001
dp_propensity_model_training_params: {"batch_size": 100, "num_updates": 100}
dp_propensity_model_optimizer_kwargs: {"lr": 1e-2}
backend_type: "simu"

models_common:
treated_col: "treatment_allocation"
event_col: "event"
duration_col: "time"

parameters:
n_samples: 10000
n_reps: 1000
return_propensities: False
return_weights: False

hydra:
sweep:
dir: "/home/owkin/project/results_experiments/power_and_type_one_error_dp"
sweeper:
params:
data.cate: 1.0,0.4
data.overlap: -1,3
+fit_fedeca.dp_target_epsilon: 0.1, 10, 20, 30, 40, 50
62 changes: 62 additions & 0 deletions experiments/config/experiment/power_and_type_one_error_n_ties.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# @package _global_
name: "Power and type one error analyses"

# initial_seed is used to generate seed for each run
initial_seed: 42

data:
ndim: 10
scale_t: 10.0
shape_t: 3.0
propensity: "linear"
standardize_features: False

defaults:
- /model@models.IPTW_robust: pooled_iptw
- /model@models.IPTW_naive: pooled_iptw
- /model@models.FedECA_robust: fl_iptw
- /model@models.FedECA_naive: fl_iptw
- _self_

models:
IPTW_robust:
variance_method: "robust"
IPTW_naive:
variance_method: "naive"
FedECA_robust:
ndim: ${data.ndim}
num_rounds_list: [10, 10]
fedeca_path: "/home/owkin/fedeca/"
robust: True
FedECA_naive:
ndim: ${data.ndim}
num_rounds_list: [10, 10]
fedeca_path: "/home/owkin/fedeca/"
robust: False

# config fit FedECA
fit_fedeca:
n_clients: 3
split_method: "split_control_over_centers"
split_method_kwargs: {"treatment_info": "treatment_allocation"}
backend_type: "simu"

models_common:
treated_col: "treatment_allocation"
event_col: "event"
duration_col: "time"

parameters:
n_samples: 700
n_reps: 1000
return_propensities: False
return_weights: False

hydra:
sweep:
dir: "/home/owkin/project/results_experiments/power_and_type_one_error_n_ties"
sweeper:
params:
data.cate: 1.0,0.4
data.overlap: -1,3
++data.percent_ties: null, 0.05, 0.1, 0.25, 0.5, 0.8
47 changes: 24 additions & 23 deletions fedeca/fedeca_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Federate causal inference on distributed data."""
import logging
import sys
import time
from collections.abc import Callable
Expand Down Expand Up @@ -34,6 +35,8 @@
from fedeca.utils.substrafl_utils import get_outmodel_function
from fedeca.utils.survival_utils import BaseSurvivalEstimator, CoxPHModelTorch

logger = logging.getLogger(__name__)


class FedECA(Experiment, BaseSurvivalEstimator):
"""FedECA class that performs Federated IPTW."""
Expand Down Expand Up @@ -308,15 +311,15 @@ def check_cp_status(self, idx=0):
model_name = "Robust Variance"
training_type = "estimation"

print(f"Waiting on {model_name} {training_type} to finish...")
logger.info(f"Waiting on {model_name} {training_type} to finish...")
t1 = time.time()
t2 = t1
while (t2 - t1) < self.timeout:
status = self.ds_client.get_compute_plan(
self.compute_plan_keys[idx].key
).status
if status == ComputePlanStatus.done:
print(
logger.info(
f"""Compute plan {self.compute_plan_keys[0].key} of {model_name} has
finished !"""
)
Expand All @@ -336,7 +339,7 @@ def check_cp_status(self, idx=0):
):
pass
else:
print(
logger.warning(
f"""Compute plan status is {status}, this shouldn't happen, sleeping
{self.time_sleep} and retrying until timeout {self.timeout}"""
)
Expand Down Expand Up @@ -518,7 +521,7 @@ def fit(
if backend_type != "remote" and (
urls is not None or server_org_id is not None or tokens is not None
):
print(
logger.warning(
"urls, server_org_id and tokens are ignored if backend_type is "
"not remote; Make sure that you launched the fit with the right"
" combination of parameters."
Expand Down Expand Up @@ -598,9 +601,7 @@ def __init__(self):
)
# We put WebDisco in "robust" mode in the sense that we ask it
# to store all needed quantities for robust variance estimation
self.strategies[
1
].algo._robust = True # not sufficient for serialization
self.strategies[1].algo._robust = True # not sufficient for serialization
# possible only because we added robust as a kwargs
self.strategies[1].algo.kwargs.update({"robust": True})
# We need those two lines for the zip to consider all 3
Expand All @@ -616,9 +617,9 @@ def __init__(self):
def run(self, targets: Union[pd.DataFrame, None] = None):
"""Run the federated iptw algorithms."""
del targets
print("Careful for now the argument target is ignored completely")
logger.info("Careful for now the argument target is ignored completely")
# We first run the propensity model
print("Fitting the propensity model...")
logger.info("Fitting the propensity model...")
t1 = time.time()
super().run(1)

Expand All @@ -629,11 +630,11 @@ def run(self, targets: Union[pd.DataFrame, None] = None):
)
else:
self.performances_propensity_model = self.performances_strategies[0]
print(self.performances_propensity_model)
logger.info(self.performances_propensity_model)
t2 = time.time()
self.propensity_model_fit_time = t2 - t1
print(f"Time to fit Propensity model {self.propensity_model_fit_time}s")
print("Finished, recovering the final propensity model from substra")
logger.info(f"Time to fit Propensity model {self.propensity_model_fit_time}s")
logger.info("Finished, recovering the final propensity model from substra")
# TODO to add the opportunity to use the targets you have to either:
# give the full targets to every client as a kwargs of their Algo
# so effectively one would need to reinstantiate algos objects or to
Expand Down Expand Up @@ -665,16 +666,16 @@ def run(self, targets: Union[pd.DataFrame, None] = None):
for t in self.train_data_nodes:
t.keep_intermediate_states = True

print("Fitting propensity weighted Cox model...")
logger.info("Fitting propensity weighted Cox model...")
t1 = time.time()
super().run(1)

if not self.simu_mode:
self.check_cp_status(idx=1)
t2 = time.time()
self.webdisco_fit_time = t2 - t1
print(f"Time to fit WebDisco {self.webdisco_fit_time}s")
print("Finished fitting weighted Cox model.")
logger.info(f"Time to fit WebDisco {self.webdisco_fit_time}s")
logger.info("Finished fitting weighted Cox model.")
self.total_fit_time = self.propensity_model_fit_time + self.webdisco_fit_time
self.print_summary()

Expand All @@ -683,19 +684,19 @@ def print_summary(self):
assert (
len(self.compute_plan_keys) == 2
), "You need to run the run method before getting the summary"
print("Evolution of performance of propensity model:")
print(self.performances_propensity_model)
print("Checking if the Cox model has converged:")
logger.info("Evolution of performance of propensity model:")
logger.info(self.performances_propensity_model)
logger.info("Checking if the Cox model has converged:")
self.get_final_cox_model()
print("Computing summary...")
logger.info("Computing summary...")
self.compute_summary()
print("Final partial log-likelihood:")
print(self.ll)
print(self.results_)
logger.info("Final partial log-likelihood:")
logger.info(self.ll)
logger.info(self.results_)

def get_final_cox_model(self):
"""Retrieve final cox model."""
print("Retrieving final hessian and log-likelihood")
logger.info("Retrieving final hessian and log-likelihood")
if not self.simu_mode:
cp = self.compute_plan_keys[1].key
else:
Expand Down
19 changes: 12 additions & 7 deletions fedeca/utils/substrafl_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utils functions for Substra."""
import logging
import os
import pickle
import tempfile
Expand Down Expand Up @@ -35,6 +36,8 @@
_check_environment_compatibility,
)

logger = logging.getLogger(__name__)


class Experiment:
"""Experiment class."""
Expand Down Expand Up @@ -108,11 +111,11 @@ def __init__(
if metrics_dicts_list and not all(
[len(t.metric_functions) == 0 for t in self.test_data_nodes]
):
print(
logger.warning(
"""WARNING: you are passing metrics to test data nodes with existing
metric_functions this will overwrite them"""
)
print(
logger.warning(
[
(f"Client {i}", t.metric_functions)
for i, t in enumerate(self.test_data_nodes)
Expand Down Expand Up @@ -253,7 +256,7 @@ def run(self, num_strategies_to_run=None):

# If no AggregationNode is given we take the first one
if self.aggregation_node is None:
print("Using the first client as a server.")
logger.info("Using the first client as a server.")
kwargs_agg_node = {
"organization_id": self.train_data_nodes[0].organization_id
}
Expand Down Expand Up @@ -315,12 +318,12 @@ def run(self, num_strategies_to_run=None):
scores = [t.scores for t in self.test_data_nodes]
robust_cox_variance = False
for idx, s in enumerate(scores):
print(f"====Client {idx}====")
logger.info(f"====Client {idx}====")
try:
print(s[-1])
logger.info(s[-1])
except IndexError:
robust_cox_variance = True
print("No metric")
logger.info("No metric")
# TODO Check that it is well formatted it's probably not
self.performances_strategies.append(pd.DataFrame(xp_output))
# Hacky hacky hack
Expand Down Expand Up @@ -515,7 +518,9 @@ def make_substrafl_torch_dataset_class(
[t in [event_col, duration_col] for t in target_cols]
)
if len(target_cols) == 1:
print(f"Making a dataset class to fit a model to predict {target_cols[0]}")
logger.info(
f"Making a dataset class to fit a model to predict {target_cols[0]}"
)
columns_to_drop = [event_col, duration_col]
elif len(target_cols) == 2:
assert set(target_cols) == set(
Expand Down

0 comments on commit d0efbad

Please sign in to comment.