
Commit b42a817
removed preprocessing
emdeeweegio committed Mar 21, 2024
1 parent 950d7c3 commit b42a817
Showing 2 changed files with 8 additions and 21 deletions.
27 changes: 7 additions & 20 deletions v6_logistic_regression_py/__init__.py
@@ -18,8 +18,7 @@
     aggregate,
     coordinate_task,
     export_model,
-    initialize_model,
-    trash_outcomes
+    initialize_model
 )
 
 MODEL_ATTRIBUTE_KEYS = ["coef_", "intercept_", "classes_"]
@@ -35,11 +34,7 @@ def master(
     max_iter: int = 15,
     delta: float = 0.01,
     n_local_iterations: int = 1,
-    org_ids: List[int] = None,
-    trash_kwargs=dict(
-        survival_column="Survival.time",
-        event_column="deadstatus.event",
-        threshold=730)
+    org_ids: List[int] = None
 ) -> Dict[str, Any]:
     """
     Orchestrates federated logistic regression training across nodes.
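
A note on what was removed: the deleted trash_kwargs defaults (survival_column
"Survival.time", event_column "deadstatus.event", threshold 730 days) suggest
that trash_outcomes dichotomized survival data at roughly two years. Its body
is not part of this diff, so the sketch below is a hypothetical reconstruction;
every detail is an assumption.

    import pandas as pd

    def trash_outcomes(df: pd.DataFrame, outcome: str,
                       survival_column: str = "Survival.time",
                       event_column: str = "deadstatus.event",
                       threshold: int = 730) -> pd.DataFrame:
        # Hypothetical: label deaths within the threshold as 1, survivors past
        # it as 0, and drop ("trash") rows censored before the threshold,
        # whose status at `threshold` days is unknown.
        died_early = (df[event_column] == 1) & (df[survival_column] <= threshold)
        survived = df[survival_column] > threshold
        keep = died_early | survived
        out = df.loc[keep].copy()
        out[outcome] = died_early.loc[keep].astype(int)
        return out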
@@ -90,13 +85,13 @@ def master(
         model_attrs = export_model(global_model, MODEL_ATTRIBUTE_KEYS)
         input_ = {
             'method': 'logistic_regression_partial',
-            'kwargs': {'model_attributes': model_attrs, 'predictors': predictors, 'outcome': outcome, 'trash_kwargs': trash_kwargs, "n_local_iterations": n_local_iterations}
+            'kwargs': {'model_attributes': model_attrs, 'predictors': predictors, 'outcome': outcome, "n_local_iterations": n_local_iterations}
         }
         partial_results = coordinate_task(client, input_, ids)
 
         # Aggregating updates into the global model and assessing convergence.
         global_model = aggregate(global_model, partial_results, MODEL_AGGREGATION_KEYS)
-        new_loss = compute_global_loss(client, global_model, predictors, outcome, ids, trash_kwargs)
+        new_loss = compute_global_loss(client, global_model, predictors, outcome, ids)
 
         loss_diff = abs(loss - new_loss) if loss is not None else 2 * delta
         loss = new_loss
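
The loop above delegates to aggregate, whose body sits outside this diff. A
minimal sketch of the sample-size-weighted federated averaging it plausibly
performs, assuming each partial result exposes its fitted attributes under
'model_attributes' and its local sample count under 'size' (only the 'size'
key is confirmed by compute_global_loss below):

    import numpy as np

    def aggregate(global_model, partial_results, keys):
        # Weight each node's parameters by its share of the pooled sample size.
        total = sum(res['size'] for res in partial_results)
        for key in keys:  # presumably 'coef_' and 'intercept_'
            weighted = [np.asarray(res['model_attributes'][key]) * res['size'] / total
                        for res in partial_results]
            setattr(global_model, key, np.sum(weighted, axis=0))
        return global_model

Note also the 2 * delta fallback in loss_diff: on the first iteration there is
no previous loss to compare against, so the fallback guarantees the convergence
test cannot fire prematurely.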
@@ -112,14 +107,14 @@
         'iteration': iteration
     }
 
-def compute_global_loss(client, model, predictors, outcome, ids, trash_kwargs):
+def compute_global_loss(client, model, predictors, outcome, ids):
     """
     Helper function to compute global loss, abstracting detailed logging.
     """
     model_attributes = export_model(model, MODEL_ATTRIBUTE_KEYS)
     input_ = {
         'method': 'compute_loss_partial',
-        'kwargs': {'model_attributes': model_attributes, 'predictors': predictors, 'outcome': outcome, 'trash_kwargs': trash_kwargs}
+        'kwargs': {'model_attributes': model_attributes, 'predictors': predictors, 'outcome': outcome}
     }
     results = coordinate_task(client, input_, ids)
     aggregated_sample_size = np.sum([res['size'] for res in results])
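
The hunk cuts off mid-function; a plausible continuation (an assumption, since
the remaining lines are collapsed) is a sample-size-weighted mean of the
partial losses:

    # Assumed: each result also carries a 'loss' computed on its local data.
    weighted = [res['loss'] * res['size'] for res in results]
    return np.sum(weighted) / aggregated_sample_size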
@@ -134,7 +129,6 @@ def logistic_regression_partial(
     model_attributes: Dict[str, List[float]],
     predictors: List[str],
     outcome: str,
-    trash_kwargs: dict,
     n_local_iterations: int = 1,
 ) -> Dict[str, any]:
     """
@@ -159,9 +153,6 @@
     # Drop rows with NaNs
     df = df.dropna(how='any')
 
-    # REMOVE
-    df = trash_outcomes(df, outcome, **trash_kwargs)
-
     # Get features and outcomes
     X = df[predictors].values
     y = df[outcome].values
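
The fitting code that follows is collapsed. One plausible reading of
n_local_iterations is a warm-started scikit-learn fit seeded with the broadcast
global weights; the sketch below is an assumption, not the repository's code:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    def local_fit_sketch(X, y, model_attributes, n_local_iterations=1):
        # warm_start=True makes each fit() resume from the current coef_,
        # so the broadcast global weights seed the local update.
        model = LogisticRegression(max_iter=1, warm_start=True)
        model.coef_ = np.asarray(model_attributes['coef_'])
        model.intercept_ = np.asarray(model_attributes['intercept_'])
        model.classes_ = np.asarray(model_attributes['classes_'])
        for _ in range(n_local_iterations):
            model.fit(X, y)  # one solver step per call, thanks to max_iter=1
        return model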
@@ -192,8 +183,7 @@ def compute_loss_partial(
     df: pd.DataFrame,
     model_attributes: Dict[str, list],
     predictors: List[str],
-    outcome: str,
-    trash_kwargs: dict
+    outcome: str
 ) -> Dict[str, Any]:
     """
     Computes logistic regression model loss on local dataset.
@@ -217,9 +207,6 @@
     # Drop rows with NaNs
     df = df.dropna(how='any')
 
-    # REMOVE
-    df = trash_outcomes(df, outcome, **trash_kwargs)
-
     # Get features and outcomes
     X = df[predictors].values
     y = df[outcome].values
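Both partial functions rebuild a scikit-learn estimator from the broadcast
attributes before touching local data. A self-contained sketch of that pattern
for the loss side, assuming binary cross-entropy via scikit-learn (the actual
loss computation is outside this diff):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import log_loss

    def compute_loss_sketch(X, y, model_attributes):
        # Rebuild the model from the coordinator's broadcast attributes.
        model = LogisticRegression()
        model.coef_ = np.asarray(model_attributes['coef_'])
        model.intercept_ = np.asarray(model_attributes['intercept_'])
        model.classes_ = np.asarray(model_attributes['classes_'])
        # Report the local sample size so the coordinator can weight the loss.
        return {'loss': log_loss(y, model.predict_proba(X)), 'size': len(y)}
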
2 changes: 1 addition & 1 deletion v6_logistic_regression_py/example.py
@@ -32,7 +32,7 @@
         'predictors': ['t', 'n', 'm'],
         'outcome': 'vital_status',
         'classes': ['alive', 'dead'],
-        'max_iter': 10,
+        'max_iter': 100,
         'delta': 0.0001
     }
 },