Skip to content

Commit

Permalink
fix: robust attribute overwriting logic (#4)
Browse files Browse the repository at this point in the history
* fix: `robust` attribute overwriting logic

Only overwrite the `robust` attribute if the argument `robust` is not None

* fix: proper `robust` check
  • Loading branch information
honghaoli42 authored Jan 19, 2024
1 parent 804e3e2 commit e7be88b
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions fedeca/fedeca_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,33 +581,34 @@ def fit(
self.strategies[0] = self.propensity_model_strategy
self.strategies[1] = self.webdisco_strategy

if robust != self.robust:
if robust is not None and robust != self.robust:
self.robust = robust
if robust:

class MockAlgo:
def __init__(self):
self.strategies = ["Robust Cox Variance"]
if self.robust:

mock_algo = MockAlgo()
self.strategies.append(
RobustCoxVariance(
algo=mock_algo,
)
class MockAlgo:
def __init__(self):
self.strategies = ["Robust Cox Variance"]

mock_algo = MockAlgo()
self.strategies.append(
RobustCoxVariance(
algo=mock_algo,
)
# We put WebDisco in "robust" mode in the sense that we ask it
# to store all needed quantities for robust variance estimation
self.strategies[
1
].algo._robust = True # not sufficient for serialization
# possible only because we added robust as a kwargs
self.strategies[1].algo.kwargs.update({"robust": True})
# We need those two lines for the zip to consider all 3
# strategies
self.metrics_dicts_list.append({})
self.num_rounds_list.append(sys.maxsize)
else:
self.strategies = self.strategies[:2]
)
# We put WebDisco in "robust" mode in the sense that we ask it
# to store all needed quantities for robust variance estimation
self.strategies[
1
].algo._robust = True # not sufficient for serialization
# possible only because we added robust as a kwargs
self.strategies[1].algo.kwargs.update({"robust": True})
# We need those two lines for the zip to consider all 3
# strategies
self.metrics_dicts_list.append({})
self.num_rounds_list.append(sys.maxsize)
else:
self.strategies = self.strategies[:2]

self.run(targets=targets)
self.propensity_scores_, self.weights_ = self.compute_propensity_scores(data)
Expand Down

0 comments on commit e7be88b

Please sign in to comment.