Skip to content

Commit

Permalink
more iterations locally
Browse files Browse the repository at this point in the history
  • Loading branch information
emdeeweegio committed Mar 15, 2024
1 parent 966040d commit 36fa9d1
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions v6_logistic_regression_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def master(
classes: List[str],
max_iter: int = 15,
delta: float = 0.01,
n_local_iterations: int = 1,
org_ids: List[int] = None,
trash_kwargs=dict(
survival_column="Survival.time",
Expand Down Expand Up @@ -89,7 +90,7 @@ def master(
model_attrs = export_model(global_model, MODEL_ATTRIBUTE_KEYS)
input_ = {
'method': 'logistic_regression_partial',
'kwargs': {'model_attributes': model_attrs, 'predictors': predictors, 'outcome': outcome, 'trash_kwargs': trash_kwargs}
'kwargs': {'model_attributes': model_attrs, 'predictors': predictors, 'outcome': outcome, 'trash_kwargs': trash_kwargs, "n_local_iterations": n_local_iterations}
}
partial_results = coordinate_task(client, input_, ids)

Expand Down Expand Up @@ -133,7 +134,8 @@ def logistic_regression_partial(
model_attributes: Dict[str, List[float]],
predictors: List[str],
outcome: str,
trash_kwargs: dict
trash_kwargs: dict,
n_local_iterations: int = 1,
) -> Dict[str, any]:
"""
Fits logistic regression model on local dataset.
Expand Down Expand Up @@ -166,7 +168,7 @@ def logistic_regression_partial(

# Create local LogisticRegression estimator object
model_kwargs = dict(
max_iter=1, # local epoch
max_iter=n_local_iterations, # local epoch
warm_start=True, # prevent refreshing weights when fitting
)
model = initialize_model(LogisticRegression, model_attributes=model_attributes, **model_kwargs)
Expand Down

0 comments on commit 36fa9d1

Please sign in to comment.