Jean/efficient bootstrapping (#3)
* trying

* not working because of weird class vs instance issue

* getting nowhere

* errors start to make more sense

* implicit self

* supposedly working if no reinstantiation

* weird bugs still...

* fixing init of algo

* iterating

* making slow progress

* it runs!

* fixing criterion

* fixing predict

* fixing metric

* more robust equality check but is it needed?

* more strength to the propensity to see actual learning

* starting up tests

* trying to launch tests

* first CP of tests run

* tests don't pass hmm...

* trying to fix bootstrap

* bootstrapping is working

* trying to make linting better

* trying to make linting better

* trying to make linting better

* trying to make linting better

* trying to make linting better

* adding better docstring

* trying to make docs pass

* trying to make docs pass

* adding doc and fixing it

* trying to make sphinx pass

* trying to fix CI

* trying to make CI pass

* upping version

* trying to install the correct version

* fixing data error

* first test working

* fixing linting

* accommodating new update_from_checkpoint function

* Update setup.py

* accommodating substrafl API change

* new API compliance
jeandut authored Jan 23, 2024
1 parent e7be88b commit f94bc5d
Showing 9 changed files with 966 additions and 39 deletions.

docs/source/api/strategies.rst (4 additions & 0 deletions)

@@ -5,4 +5,8 @@ fedeca.strategies
 
 .. autoclass:: fedeca.strategies.WebDisco
 
+.. automodule:: fedeca.strategies.bootstraper
+
+.. automodule:: fedeca.strategies.webdisco_utils
+
 

fedeca/__init__.py (1 addition & 0 deletions)

@@ -1,3 +1,4 @@
 """Top level package for :mod:`fedeca`."""
 from .fedeca_core import FedECA
+from .fedeca_core import LogisticRegressionTorch
 from .competitors import PooledIPTW, MatchingAjudsted, NaiveComparison

fedeca/algorithms/torch_dp_fed_avg_algo.py (13 additions & 11 deletions)

@@ -1,5 +1,6 @@
 """Differentially private algorithm to be used with FedAvg strategy."""
 import logging
+import random
 from typing import Any, Optional
 
 import numpy as np
@@ -371,28 +372,26 @@ def _get_state_to_save(self) -> dict:
 
         return checkpoint
 
-    def _update_from_checkpoint(self, path) -> dict:
+    def _update_from_checkpoint(self, checkpoint: dict) -> None:
         """Set self attributes using saved values.
         Parameters
         ----------
-        path : Path
-            Path towards the checkpoint to use.
+        checkpoint : dict
+            Checkpoint to load.
-        Returns
-        -------
-        dict
-            The emptied checkpoint.
         """
-        # One cannot simply call checkpoint = super()._update_from_checkpoint(path)
+        # One cannot simply call checkpoint = super()._update_from_checkpoint(chkpt)
         # because we have to change the model class if it should be changed
         # (and optimizer) aka if we find a specific key in the checkpoint
-        assert (
-            path.is_file()
-        ), f'Cannot load the model - does not exist {list(path.parent.glob("*"))}'
-        checkpoint = torch.load(path, map_location=self._device)
-
         # For some reason substrafl save and load client before calling train
         if "privacy_accountant_state_dict" in checkpoint:
-
             self.accountant = RDPAccountant()
             self.accountant.load_state_dict(
                 checkpoint.pop("privacy_accountant_state_dict")
@@ -429,10 +428,13 @@ def _update_from_checkpoint(self, path) -> dict:
 
         self._index_generator = checkpoint.pop("index_generator")
 
+        random.setstate(checkpoint.pop("random_rng_state"))
+        np.random.set_state(checkpoint.pop("numpy_rng_state"))
+
         if self._device == torch.device("cpu"):
-            torch.set_rng_state(checkpoint.pop("rng_state").to(self._device))
+            torch.set_rng_state(checkpoint.pop("torch_rng_state").to(self._device))
         else:
-            torch.cuda.set_rng_state(checkpoint.pop("rng_state").to("cpu"))
+            torch.cuda.set_rng_state(checkpoint.pop("torch_rng_state").to("cpu"))
@@ -447,4 +449,4 @@ def _update_from_checkpoint(self, path) -> dict:
 
         for attr in attr_names:
             setattr(self, attr, checkpoint.pop(attr))
 
-        return checkpoint
+        return
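
A note on the RNG handling above: the restore path now pops three separate keys (random_rng_state, numpy_rng_state, torch_rng_state) instead of a single rng_state, so the Python, NumPy, and torch generators are all reset when substrafl reloads a client between rounds, presumably to keep the bootstrap resampling reproducible across save/load cycles. The save side is not shown in this diff; below is a minimal standalone sketch of a matching capture/restore pair, where the key names come from the restore code but the helper functions themselves (capture_rng_states, restore_rng_states) are hypothetical:

import random

import numpy as np
import torch


def capture_rng_states(device: torch.device) -> dict:
    """Collect generator states under the same keys the diff pops."""
    state = {
        "random_rng_state": random.getstate(),
        "numpy_rng_state": np.random.get_state(),
    }
    # torch keeps distinct CPU and CUDA generator states
    if device == torch.device("cpu"):
        state["torch_rng_state"] = torch.get_rng_state()
    else:
        state["torch_rng_state"] = torch.cuda.get_rng_state()
    return state


def restore_rng_states(checkpoint: dict, device: torch.device) -> None:
    """Pop and restore the three states, mirroring the diff's order."""
    random.setstate(checkpoint.pop("random_rng_state"))
    np.random.set_state(checkpoint.pop("numpy_rng_state"))
    if device == torch.device("cpu"):
        torch.set_rng_state(checkpoint.pop("torch_rng_state"))
    else:
        torch.cuda.set_rng_state(checkpoint.pop("torch_rng_state").to("cpu"))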

fedeca/algorithms/torch_webdisco_algo.py (5 additions & 11 deletions)

@@ -3,7 +3,6 @@
 import os
 from copy import deepcopy
 from math import sqrt
-from pathlib import Path
 from typing import Any, List, Optional
 
 # hello
@@ -547,23 +546,18 @@ def _get_state_to_save(self) -> dict:
         checkpoint.update({"global_moments": self.global_moments})
         return checkpoint
 
-    def _update_from_checkpoint(self, path: Path) -> dict:
+    def _update_from_checkpoint(self, checkpoint: dict) -> None:
         """Load the local state from the checkpoint.
         Parameters
         ----------
-        path : pathlib.Path
-            Path where the checkpoint is saved
-        Returns
-        -------
-        dict
-            Checkpoint
+        checkpoint : dict
+            The checkpoint to load.
         """
-        checkpoint = super()._update_from_checkpoint(path=path)
+        super()._update_from_checkpoint(checkpoint=checkpoint)
         self.server_state = checkpoint.pop("server_state")
         self.global_moments = checkpoint.pop("global_moments")
-        return checkpoint
+        return
 
     def summary(self):
         """Summary of the class to be exposed in the experiment summary file.
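
Both algorithm files make the same API shift: _update_from_checkpoint now receives an already-loaded dict instead of a path (the torch.load moved to the caller, per the "new API compliance" commit), each override pops only the keys it owns after delegating to its parent, and nothing is returned because the dict is emptied in place. A minimal self-contained sketch of that layered contract, with hypothetical class names:

class BaseAlgo:
    def __init__(self):
        self.counter = 0

    def _get_state_to_save(self) -> dict:
        return {"counter": self.counter}

    def _update_from_checkpoint(self, checkpoint: dict) -> None:
        self.counter = checkpoint.pop("counter")


class WebDiscoLikeAlgo(BaseAlgo):
    def __init__(self):
        super().__init__()
        self.server_state = {"step": 0}

    def _get_state_to_save(self) -> dict:
        checkpoint = super()._get_state_to_save()
        checkpoint.update({"server_state": self.server_state})
        return checkpoint

    def _update_from_checkpoint(self, checkpoint: dict) -> None:
        super()._update_from_checkpoint(checkpoint=checkpoint)
        self.server_state = checkpoint.pop("server_state")


algo = WebDiscoLikeAlgo()
state = algo._get_state_to_save()
algo._update_from_checkpoint(state)
assert state == {}  # every layer popped its own keys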

fedeca/strategies/__init__.py (1 addition & 0 deletions)

@@ -1,2 +1,3 @@
 """Init file for strategies."""
 from .webdisco import WebDisco
+from .bootstraper import make_bootstrap_metric_function, make_bootstrap_strategy
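
The two exported helpers are the heart of this PR: going by their names and the commit history, make_bootstrap_strategy wraps an existing substrafl strategy so it is rerun on bootstrap resamples, and make_bootstrap_metric_function does the analogous wrapping for evaluation metrics. A hypothetical usage sketch; base_strategy, base_metric, and the n_bootstrap keyword are assumptions, and the real signatures live in fedeca/strategies/bootstraper.py:

from fedeca.strategies import make_bootstrap_metric_function, make_bootstrap_strategy

# base_strategy: a substrafl strategy (e.g. WebDisco) built elsewhere;
# base_metric: a plain metric function. Both are placeholders here, and
# n_bootstrap is an assumed keyword, not a confirmed signature.
bootstrapped_strategy = make_bootstrap_strategy(base_strategy, n_bootstrap=100)
bootstrapped_metric = make_bootstrap_metric_function(base_metric)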
