v0.3.0 (#20)
* Update link to experiments instructions (issue #16).

* Tolerate `ema_params` or `polyak_params` being `None` (a hypothetical sketch of this fix follows the change list below).

* Update to v0.3.0

Main changes:
* Ability to use jax_privacy without jaxline.
* Multi-label image classification to reproduce results on CheXpert / MIMIC-CXR.
* Typing improvements.
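
A minimal sketch of what the `None`-tolerance fix might look like (hypothetical code; the function name and signature are illustrative, not the repository's actual API):

```
import jax

def update_averaged_params(avg_params, params, decay=0.999):
  # Hypothetical EMA/Polyak-style update that tolerates an uninitialised
  # average: if no averaged parameters exist yet (avg_params is None),
  # start from the current parameters instead of failing.
  if avg_params is None:
    return params
  return jax.tree_util.tree_map(
      lambda a, p: decay * a + (1.0 - decay) * p, avg_params, params)
```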

Co-authored-by: Borja Balle <[email protected]>
Co-authored-by: Sahra Ghalebikesabi <[email protected]>
Co-authored-by: Aneesh Pappu <[email protected]>
Co-authored-by: Robert Stanforth <[email protected]>

5 people authored Dec 20, 2023
1 parent 95c5b00 commit 8c284ce
Showing 60 changed files with 5,144 additions and 2,102 deletions.
6 changes: 5 additions & 1 deletion CITATION.cff
@@ -7,13 +7,17 @@ authors:
given-names: "Leonard"
- family-names: "De"
given-names: "Soham"
- family-names: "Ghalebikesabi"
given-names: "Sahra"
- family-names: "Hayes"
given-names: "Jamie"
- family-names: "Pappu"
given-names: "Aneesh"
- family-names: "Smith"
given-names: "Samuel L"
- family-names: "Stanforth"
given-names: "Robert"
title: "JAX-Privacy"
version: 0.1.0
date-released: 2022-04-28
url: "https://github.com/deepmind/jax_privacy"
url: "https://github.com/google-deepmind/jax_privacy"
21 changes: 13 additions & 8 deletions README.md
@@ -23,7 +23,7 @@ codebase without modifying them.
The package can be installed by running the following command:

```
pip install git+https://github.com/deepmind/jax_privacy
pip install git+https://github.com/google-deepmind/jax_privacy
```
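
A quick sanity check of the installation (assuming the top-level package is importable as `jax_privacy`):

```
python -c "import jax_privacy"
```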

### Option 2: Local Installation (Allowing Edits) <a id="install-option2"></a>
@@ -34,7 +34,7 @@ our results.
* The first step is to clone the repository:

```
git clone https://github.com/deepmind/jax_privacy
git clone https://github.com/google-deepmind/jax_privacy
```

* Then the code can be installed so that local modifications to the code are
@@ -49,20 +49,25 @@ pip install -e .

### Unlocking High-Accuracy Differentially Private Image Classification through Scale

Instructions detailed in [experiments/image_classification](jax_privacy/experiments/image_classification).
* Instructions: [experiments/image_classification](jax_privacy/experiments/image_classification).
* arXiv link: https://arxiv.org/abs/2204.13650.
* Bibtex reference: [link](https://github.com/google-deepmind/jax_privacy/blob/main/bibtex/de2022unlocking.bib).

This work is available on arXiv at [this link](https://arxiv.org/abs/2204.13650).
If you use it, please cite the following [bibtex reference](https://github.com/deepmind/jax_privacy/blob/main/bibtex/de2022unlocking.bib).
### Unlocking Accuracy and Fairness in Differentially Private Image Classification

* Instructions: [experiments/image_classification](jax_privacy/experiments/image_classification).
* arXiv link: https://arxiv.org/abs/2308.10888.
* Bibtex reference: [link](https://github.com/google-deepmind/jax_privacy/blob/main/bibtex/berrada2023unlocking.bib).

## How to Cite This Repository <a id="citing"></a>
If you use code from this repository, please cite the following reference:

```
@software{jax-privacy2022github,
author = {Balle, Borja and Berrada, Leonard and De, Soham and Hayes, Jamie and Smith, Samuel L and Stanforth, Robert},
author = {Balle, Borja and Berrada, Leonard and De, Soham and Ghalebikesabi, Sahra and Hayes, Jamie and Pappu, Aneesh and Smith, Samuel L and Stanforth, Robert},
title = {{JAX}-{P}rivacy: Algorithms for Privacy-Preserving Machine Learning in JAX},
url = {http://github.com/deepmind/jax_privacy},
version = {0.2.0},
url = {http://github.com/google-deepmind/jax_privacy},
version = {0.3.0},
year = {2022},
}
```
6 changes: 6 additions & 0 deletions bibtex/berrada2023unlocking.bib
@@ -0,0 +1,6 @@
@article{berrada2023unlocking,
title={{Unlocking Accuracy and Fairness in Differentially Private Image Classification}},
author={Leonard Berrada and Soham De and Judy Hanwen Shen and Jamie Hayes and Robert Stanforth and David Stutz and Pushmeet Kohli and Samuel L. Smith and Borja Balle},
journal={arXiv preprint arXiv:2308.10888},
year={2023}
}
15 changes: 0 additions & 15 deletions jax_privacy/experiments/__init__.py

This file was deleted.

18 changes: 9 additions & 9 deletions jax_privacy/experiments/image_classification/README.md
@@ -3,18 +3,18 @@
Reproducing experiments of the paper "Unlocking High-Accuracy Differentially Private Image Classification through Scale"

This work is available on arXiv at [this link](https://arxiv.org/abs/2204.13650).
If you use it, please cite the following [bibtex reference](https://github.com/deepmind/jax_privacy/blob/main/bibtex/de2022unlocking.bib).
If you use it, please cite the following [bibtex reference](https://github.com/google-deepmind/jax_privacy/blob/main/bibtex/de2022unlocking.bib).

The following instructions assume that our package has been installed through
[option 2](https://github.com/deepmind/jax_privacy#install-option2).
[option 2](https://github.com/google-deepmind/jax_privacy#install-option2).

## Intro


- An experiment can be run by executing the following command from this directory:

```
python run_experiment.py --config=<relative/path/to/config.py> --jaxline_mode=train_eval_multithreaded
python run_experiment_loop.py --config=<relative/path/to/config.py>
```

where the config file contains all relevant hyper-parameters for the experiment.
@@ -27,8 +27,8 @@ where the config file contains all relevant hyper-parameters for the experiment.
- Model definition: `config.experiment_kwargs.config.model`
- Noise multiplier sigma: `config.experiment_kwargs.config.training.dp.noise_multiplier`
- Number of updates: `config.experiment_kwargs.config.num_updates`
- Privacy budget (delta): `config.experiment_kwargs.config.dp.target_delta`
- Privacy budget (epsilon): `config.experiment_kwargs.config.dp.stop_training_at_epsilon`
- Privacy budget (delta): `config.experiment_kwargs.config.dp.delta`
- Privacy budget (epsilon): `config.experiment_kwargs.config.dp.auto_tune_target_epsilon`
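
For example, assuming `run_experiment_loop.py` exposes the configuration through `ml_collections` config flags (an assumption, not confirmed by this diff), a hyper-parameter such as the noise multiplier could plausibly be overridden from the command line:

```
python run_experiment_loop.py --config=configs/cifar10_wrn_16_4_eps1.py \
  --config.experiment_kwargs.config.training.dp.noise_multiplier=21.1
```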

Note: we provide examples of configurations for various experiments. To
reproduce the results of our paper, please refer to the hyper-parameters listed
@@ -37,27 +37,27 @@ reproduce the results of our paper, please refer to the hyper-parameters listed
## Training from Scratch on CIFAR-10

```
python run_experiment.py --config=configs/cifar10_wrn_16_4_eps1.py --jaxline_mode=train_eval_multithreaded
python run_experiment_loop.py --config=configs/cifar10_wrn_16_4_eps1.py
```


## Training from Scratch on ImageNet

```
python run_experiment.py --config=configs/imagenet_nf_resnet_50_eps8.py --jaxline_mode=train_eval_multithreaded
python run_experiment_loop.py --config=configs/imagenet_nf_resnet_50_eps8.py
```

## Fine-tuning on CIFAR

```
python run_experiment.py --config=configs/cifar100_wrn_28_10_eps1_finetune.py --jaxline_mode=train_eval_multithreaded
python run_experiment_loop.py --config=configs/cifar100_wrn_28_10_eps1_finetune.py
```

See `jax_privacy/experiments/image_classification/config_base.py` for the available pre-trained models.

## Additional Details

- Training and evaluation accuracies throughout training will be printed to the console. At the moment JAXline does not have the capability to save model checkpoints to disk.
- Training and evaluation accuracies throughout training will be printed to the console.
- If you observe out-of-memory errors with the default configs, consider reducing `config.experiment_kwargs.config.training.batch_size.per_device_per_step` so that the number of examples processed at each step fits in memory. This might make training slower, but will not change the effective batch-size used for each model update. Note that `config.experiment_kwargs.config.training.batch_size.init_value` should be divisible by `config.experiment_kwargs.config.training.batch_size.per_device_per_step` times the number of accelerators in your machine (a worked example follows this list).
- The number of updates given in the config is ignored if `stop_training_at_epsilon` is specified, in which case the training automatically stops when the total privacy budget has been spent.
- The `auto_tune` feature in the config can be used to, for example, calibrate the noise multiplier under a pre-specified privacy budget, number of iterations and batch-size.
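
A worked example of the batch-size constraint above (hypothetical numbers; a machine with 8 accelerators is assumed, the batch sizes are taken from the configs in this commit):

```
# Hypothetical sanity check for the batch-size settings.
num_devices = 8            # accelerators in the machine (assumed)
per_device_per_step = 16   # ...training.batch_size.per_device_per_step
total_batch_size = 16384   # ...training.batch_size.init_value
step_size = per_device_per_step * num_devices  # 128 examples per step
assert total_batch_size % step_size == 0
# Each model update then accumulates gradients over
# total_batch_size // step_size = 128 steps.
```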
75 changes: 22 additions & 53 deletions jax_privacy/experiments/image_classification/config_base.py
@@ -15,85 +15,51 @@

"""Base configuration."""

from collections.abc import Mapping
import dataclasses
import random
from typing import Any, Mapping

from jax_privacy.experiments import image_data as data
from jax_privacy.experiments.image_classification.models import base
from jax_privacy.src.training import auto_tune
from jax_privacy.src.training import averaging as averaging_py
from jax_privacy.src.training import experiment_config as experiment_config_py
from jax_privacy.src.training import optimizer_config
from jaxline import base_config as jaxline_base_config
import ml_collections


MODEL_CKPT = ml_collections.FrozenConfigDict({
'WRN_40_4_CIFAR100': 'WRN_40_4_CIFAR100.dill',
'WRN_40_4_IMAGENET32': 'WRN_40_4_IMAGENET32.dill',
'WRN_28_10_IMAGENET32': 'WRN_28_10_IMAGENET32.dill',
})


@dataclasses.dataclass(kw_only=True, slots=True)
class ModelRestoreConfig:
"""Configuration for restoring the model.
Attributes:
path: Path to the model to restore.
params_key: (dictionary) Key identifying the parameters in the checkpoint to
restore.
network_state_key: (dictionary) Key identifying the model state in the
checkpoint to restore.
layer_to_reset: Optional identifying name of the layer to reset when loading
the checkpoint (useful for resetting the classification layer to use a
different number of classes for example).
"""

path: str | None = None
params_key: str | None = None
network_state_key: str | None = None
layer_to_reset: str | None = None


@dataclasses.dataclass(kw_only=True, slots=True)
class ModelConfig:
"""Config for the model.
Attributes:
name: Identifying name of the model.
kwargs: Keyword arguments to construct the model.
restore: Configuration for restoring the model.
"""
name: str
kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
restore: ModelRestoreConfig = dataclasses.field(
default_factory=ModelRestoreConfig)


@dataclasses.dataclass(kw_only=True, slots=True)
class ExperimentConfig:
"""Configuration for the experiment.
Attributes:
num_updates: Number of updates for the experiment.
optimizer: Optimizer configuration.
model: Model configuration.
training: Training configuration.
label_smoothing: parameter within [0, 1] to smooth the labels. The default
value of 0 corresponds to no smoothing.
averaging: Averaging configuration.
evaluation: Evaluation configuration.
eval_disparity: Whether to compute disparity at evaluation time.
data_train: Training data configuration.
data_eval: Eval data configuration.
data_eval_additional: Configuration for an (optional) additional evaluation
dataset.
random_seed: Random seed (automatically changed from the default value).
"""

num_updates: int
optimizer: optimizer_config.OptimizerConfig
model: ModelConfig
model: base.ModelConfig
training: experiment_config_py.TrainingConfig
averaging: experiment_config_py.AveragingConfig
label_smoothing: float = 0.0
averaging: Mapping[str, averaging_py.AveragingConfig] = dataclasses.field(
default_factory=dict)
evaluation: experiment_config_py.EvaluationConfig
eval_disparity: bool = False
data_train: data.DataLoader
data_eval: data.DataLoader
data_eval_additional: data.DataLoader | None = None
random_seed: int = 0


@@ -114,9 +80,9 @@ def build_jaxline_config(

# Intervals can be measured in 'steps' or 'secs'.
config.interval_type = 'steps'
config.log_train_data_interval = 100
config.log_tensors_interval = 100
config.save_checkpoint_interval = 250
config.log_train_data_interval = 10
config.log_tensors_interval = 10
config.save_checkpoint_interval = 50
config.eval_specific_checkpoint_dir = ''

config.experiment_kwargs = ml_collections.ConfigDict()
@@ -130,7 +96,10 @@
# noise injected in DP-SGD will be invalid otherwise.
assert config.random_mode_train == 'same_host_same_device'

if config.experiment_kwargs.config.training.dp.auto_tune:
config = auto_tune.dp_auto_tune_config(config)
if config.experiment_kwargs.config.training.dp.auto_tune_field:
config.experiment_kwargs.config.training = auto_tune.dp_auto_tune_config(
config.experiment_kwargs.config.training,
config.experiment_kwargs.config.data_train.config.num_samples,
)

return config
@@ -17,6 +17,8 @@

from jax_privacy.experiments import image_data
from jax_privacy.experiments.image_classification import config_base
from jax_privacy.experiments.image_classification.models import models
from jax_privacy.src.training import averaging
from jax_privacy.src.training import experiment_config
from jax_privacy.src.training import optimizer_config
import ml_collections
@@ -26,24 +28,21 @@ def get_config() -> ml_collections.ConfigDict:
"""Experiment config."""

config = config_base.ExperimentConfig(
num_updates=250,
optimizer=optimizer_config.sgd_config(
lr=optimizer_config.constant_lr_config(1.0),
),
model=config_base.ModelConfig(
name='wideresnet',
kwargs={
'depth': 28,
'width': 10,
},
restore=config_base.ModelRestoreConfig(
path=config_base.MODEL_CKPT.WRN_28_10_IMAGENET32,
params_key='params',
network_state_key='network_state',
layer_to_reset='wide_res_net/Softmax',
model=models.WithRestoreModelConfig(
path=models.Registry.WRN_28_10_IMAGENET32.path,
params_key='params',
network_state_key='network_state',
layer_to_ignore='wide_res_net/Softmax',
model=models.WideResNetConfig(
depth=28,
width=10,
),
),
training=experiment_config.TrainingConfig(
num_updates=250,
batch_size=experiment_config.BatchSizeTrainConfig(
total=16384,
per_device_per_step=16,
@@ -53,19 +52,18 @@
dp=experiment_config.DPConfig(
delta=1e-5,
clipping_norm=1.0,
stop_training_at_epsilon=1.0,
auto_tune_target_epsilon=1.0,
rescale_to_unit_norm=True,
noise_multiplier=21.1,
auto_tune=None, # 'num_updates',
auto_tune_field=None, # 'num_updates',
),
logging=experiment_config.LoggingConfig(
grad_clipping=True,
grad_alignment=False,
snr_global=True, # signal-to-noise ratio across layers
snr_per_layer=False, # signal-to-noise ratio per layer
),
),
averaging=experiment_config.AveragingConfig(ema_coefficient=0.9,),
averaging={'ema': averaging.ExponentialMovingAveragingConfig(decay=0.9)},
data_train=image_data.Cifar100Loader(
config=image_data.Cifar100TrainValidConfig(
preprocess_name='standardise',