Skip to content

Commit

Permalink
test passing
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Aug 12, 2024
1 parent 85cdf39 commit 31fa318
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 0 deletions.
153 changes: 153 additions & 0 deletions fedeca/tests/strategies/test_smd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Module testing substraFL moments strategy."""
import os
import subprocess
import sys
import unittest
from pathlib import Path

import git
import numpy as np
import pandas as pd
import torch
from substrafl.dependency import Dependency
from substrafl.experiment import execute_experiment
from substrafl.model_loading import download_aggregate_shared_state
from substrafl.nodes import AggregationNode
from torch import nn

import fedeca
from fedeca.fedeca_core import LogisticRegressionTorch
from fedeca.metrics.metrics import standardized_mean_diff
from fedeca.strategies.fed_smd import FedSMD
from fedeca.tests.common import TestTempDir
from fedeca.utils.data_utils import split_dataframe_across_clients
from fedeca.utils.survival_utils import CoxData


class TestSMD(TestTempDir):
    """Test substrafl computation of SMD.

    Tests that the federated computation of the standardized mean difference
    (SMD) matches the pandas-pooled version on the same data.
    """

    def setUp(self, backend_type="subprocess", ndim=10) -> None:
        """Set up the quantities needed for the tests.

        Parameters
        ----------
        backend_type : str
            Substra backend used to run the experiment; "subprocess" keeps
            everything local.
        ndim : int
            Number of covariates in the simulated data; must match the input
            dimension of the propensity model built below.
        """
        # Generate 1000 data samples with `ndim` covariates.
        data = CoxData(seed=42, n_samples=1000, ndim=ndim)
        self.df = data.generate_dataframe()

        # Drop the true propensity score: the test estimates weights from
        # the (random) propensity model below instead.
        # NOTE: `axis=1` was redundant alongside `columns=` and is removed.
        self.df = self.df.drop(columns=["propensity_scores"])

        self.clients, self.train_data_nodes, _, _, _ = split_dataframe_across_clients(
            self.df,
            n_clients=4,
            split_method="split_control_over_centers",
            split_method_kwargs={"treatment_info": "treatment"},
            data_path=Path(self.test_dir) / "data",
            backend_type=backend_type,
        )
        kwargs_agg_node = {"organization_id": self.train_data_nodes[0].organization_id}
        self.aggregation_node = AggregationNode(**kwargs_agg_node)

        # Package the right dependencies: build a fresh fedeca wheel so the
        # substra tasks can install the current working-tree code.
        fedeca_path = fedeca.__path__[0]
        repo_folder = Path(
            git.Repo(fedeca_path, search_parent_directories=True).working_dir
        ).resolve()
        wheel_folder = repo_folder / "temp"
        wheel_folder.mkdir(parents=True, exist_ok=True)
        for stale_wheel in wheel_folder.glob("fedeca*.whl"):
            stale_wheel.unlink()
        # Use subprocess.run with an argument list instead of
        # Popen(shell=True, stdout=PIPE).wait(): the latter can deadlock if
        # the build output fills the pipe buffer, and the shell string is
        # fragile w.r.t. quoting. sys.executable guarantees the same
        # interpreter; check=True raises CalledProcessError on build failure.
        subprocess.run(
            [
                sys.executable,
                "-m",
                "build",
                "--wheel",
                "--outdir",
                str(wheel_folder),
                str(repo_folder),
            ],
            check=True,
            capture_output=True,
        )
        self.wheel_path = next(wheel_folder.glob("fedeca*.whl"))
        self.ds_client = self.clients[self.train_data_nodes[0].organization_id]

        # Random (untrained) propensity model in float64: enough for the test,
        # which only needs weighted and unweighted SMD to differ.
        self.propensity_model = LogisticRegressionTorch(ndim=ndim)
        self.propensity_model.fc1.weight.data = nn.parameter.Parameter(
            torch.randn(
                size=self.propensity_model.fc1.weight.data.shape, dtype=torch.float64
            )
        )
        self.propensity_model.fc1.bias.data = nn.parameter.Parameter(
            torch.randn(
                size=self.propensity_model.fc1.bias.data.shape, dtype=torch.float64
            )
        )

    def test_end_to_end(self):
        """Compare FL and pooled computations of SMD.

        The data are simulated Cox data from ``setUp``; the federated result is
        produced by the :class:`FedSMD` strategy and compared against a pooled
        pandas computation with identical inverse-propensity weights.
        """
        # Get fl_results.
        strategy = FedSMD(
            treated_col="treatment",
            duration_col="time",
            event_col="event",
            propensity_model=self.propensity_model,
            client_identifier="center",
        )

        compute_plan = execute_experiment(
            client=self.ds_client,
            strategy=strategy,
            train_data_nodes=self.train_data_nodes,
            evaluation_strategy=None,
            aggregation_node=self.aggregation_node,
            num_rounds=1,
            experiment_folder=str(Path(self.test_dir) / "experiment_summaries"),
            dependencies=Dependency(
                local_installable_dependencies=[self.wheel_path]
            ),
        )

        fl_results = download_aggregate_shared_state(
            client=self.ds_client,
            compute_plan_key=compute_plan.key,
            round_idx=0,
        )

        # Sanity check: weighting must actually change the SMD, otherwise the
        # equality assertions below would be vacuous.
        assert not fl_results["weighted_smd"].equals(fl_results["unweighted_smd"])

        # Pooled reference computation.
        X = self.df.drop(columns=["time", "event", "treatment"])
        covariates = X.columns
        Xprop = torch.from_numpy(X.values).type(self.propensity_model.fc1.weight.dtype)
        with torch.no_grad():
            self.propensity_model.eval()
            propensity_scores = self.propensity_model(Xprop)

        propensity_scores = propensity_scores.detach().numpy().flatten()
        # Inverse-propensity weights: 1/e(x) for treated, 1/(1 - e(x)) for
        # controls.
        weights = self.df["treatment"] * 1.0 / propensity_scores + (
            1 - self.df["treatment"]
        ) * 1.0 / (1.0 - propensity_scores)
        weights = weights.values

        X_weighted = (Xprop * weights[:, np.newaxis]).numpy()
        X_weighted_df = pd.DataFrame(X_weighted, columns=covariates)
        X_df = pd.DataFrame(Xprop.numpy(), columns=covariates)

        # standardized_mean_diff returns percentages; divide by 100 to match
        # the FL output scale.
        standardized_mean_diff_pooled_weighted = standardized_mean_diff(
            X_weighted_df,
            self.df["treatment"] == 1,
        ).div(100.0)
        standardized_mean_diff_pooled_unweighted = standardized_mean_diff(
            X_df,
            self.df["treatment"] == 1,
        ).div(100.0)

        # We check equality of FL computation and pooled results.
        pd.testing.assert_series_equal(
            standardized_mean_diff_pooled_weighted,
            fl_results["weighted_smd"],
            rtol=1e-2,
        )
        pd.testing.assert_series_equal(
            standardized_mean_diff_pooled_unweighted,
            fl_results["unweighted_smd"],
            rtol=1e-2,
        )
1 change: 1 addition & 0 deletions fedeca/utils/survival_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,7 @@ def compute_X_y_and_propensity_weights_function(
np.in1d(np.unique(treated.astype("uint8"))[0], [0, 1])
), "The treated column should have all its values in set([0, 1])"
Xprop = torch.from_numpy(Xprop)

with torch.no_grad():
propensity_scores = propensity_model(Xprop)

Expand Down

0 comments on commit 31fa318

Please sign in to comment.