Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update differential_privacy module to convert DP events to and from NamedTuples. #463

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions tensorflow_privacy/privacy/dp_query/tree_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
as helper functions for `tree_aggregation_query`. This module and helper
functions are publicly accessible.
"""

import abc
import collections
from typing import Any, Callable, Collection, Optional, Tuple, Union
from typing import Any, Callable, Collection, NamedTuple, Optional, Tuple, Union

import attr
import tensorflow as tf

# TODO(b/192464750): find a proper place for the helper functions, privatize
Expand Down Expand Up @@ -170,8 +170,7 @@ def next(self, state):
return self.value_fn(), state


@attr.s(eq=False, frozen=True, slots=True)
class TreeState(object):
class TreeState(NamedTuple):
"""Class defining state of the tree.

Attributes:
Expand All @@ -183,9 +182,9 @@ class TreeState(object):
for the most recent leaf node.
value_generator_state: State of a stateful `ValueGenerator` for tree node.
"""
level_buffer = attr.ib(type=tf.Tensor)
level_buffer_idx = attr.ib(type=tf.Tensor)
value_generator_state = attr.ib(type=Any)
level_buffer: tf.Tensor
level_buffer_idx: tf.Tensor
value_generator_state: Any


# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.
Expand Down
64 changes: 38 additions & 26 deletions tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
"""

import attr
from typing import Any, NamedTuple

import dp_accounting
import tensorflow as tf

from tensorflow_privacy.privacy.dp_query import dp_query
from tensorflow_privacy.privacy.dp_query import tree_aggregation

Expand Down Expand Up @@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
O(clip_norm*log(T)/eps) to guarantee eps-DP.
"""

@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for Tree sum queries.

Attributes:
Expand All @@ -94,9 +95,9 @@ class GlobalState(object):
clip_value: The clipping value to be passed to clip_fn.
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
"""
tree_state = attr.ib()
clip_value = attr.ib()
samples_cumulative_sum = attr.ib()
tree_state: Any
clip_value: Any
samples_cumulative_sum: Any

def __init__(self,
record_specs,
Expand Down Expand Up @@ -182,10 +183,11 @@ def get_noised_result(self, sample_state, global_state):
global_state.tree_state)
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
cumulative_sum_noise)
new_global_state = attr.evolve(
global_state,
new_global_state = TreeCumulativeSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
samples_cumulative_sum=new_cumulative_sum,
tree_state=new_tree_state)
)
event = dp_accounting.UnsupportedDpEvent()
return noised_cumulative_sum, new_global_state, event

Expand All @@ -206,10 +208,11 @@ def reset_state(self, noised_results, global_state):
state for the next cumulative sum.
"""
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
return TreeCumulativeSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
samples_cumulative_sum=noised_results,
tree_state=new_tree_state)
)

@classmethod
def build_l2_gaussian_query(cls,
Expand Down Expand Up @@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
O(clip_norm*log(T)/eps) to guarantee eps-DP.
"""

@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for Tree sum queries.

Attributes:
Expand All @@ -323,9 +325,9 @@ class GlobalState(object):
previous_tree_noise: Cumulative noise by tree aggregation from the
previous time the query is called on a sample.
"""
tree_state = attr.ib()
clip_value = attr.ib()
previous_tree_noise = attr.ib()
tree_state: Any
clip_value: Any
previous_tree_noise: Any

def __init__(self,
record_specs,
Expand Down Expand Up @@ -426,8 +428,11 @@ def get_noised_result(self, sample_state, global_state):
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
sample_state, tree_noise,
global_state.previous_tree_noise)
new_global_state = attr.evolve(
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
new_global_state = TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
previous_tree_noise=tree_noise,
)
event = dp_accounting.UnsupportedDpEvent()
return noised_sample, new_global_state, event

Expand All @@ -448,21 +453,28 @@ def reset_state(self, noised_results, global_state):
"""
del noised_results
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
return attr.evolve(
global_state,
return TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=global_state.clip_value,
previous_tree_noise=self._zero_initial_noise(),
tree_state=new_tree_state)
)

def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
noise_generator_state = global_state.tree_state.value_generator_state
assert isinstance(self._tree_aggregator.value_generator,
tree_aggregation.GaussianNoiseGenerator)
noise_generator_state = self._tree_aggregator.value_generator.make_state(
noise_generator_state.seeds, stddev)
new_tree_state = attr.evolve(
global_state.tree_state, value_generator_state=noise_generator_state)
return attr.evolve(
global_state, clip_value=clip_norm, tree_state=new_tree_state)
new_tree_state = tree_aggregation.TreeState(
level_buffer=global_state.tree_state.level_buffer,
level_buffer_idx=global_state.tree_state.level_buffer_idx,
value_generator_state=noise_generator_state,
)
return TreeResidualSumQuery.GlobalState(
tree_state=new_tree_state,
clip_value=clip_norm,
previous_tree_noise=global_state.previous_tree_noise,
)

@classmethod
def build_l2_gaussian_query(cls,
Expand Down
10 changes: 4 additions & 6 deletions tensorflow_privacy/privacy/dp_query/tree_range_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@

import distutils
import math
from typing import Optional
from typing import Any, NamedTuple, Optional

import attr
import dp_accounting
import tensorflow as tf
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
Expand Down Expand Up @@ -102,17 +101,16 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
Improves efficiency and reduces noise scale.
"""

@attr.s(frozen=True)
class GlobalState(object):
class GlobalState(NamedTuple):
"""Class defining global state for TreeRangeSumQuery.

Attributes:
arity: The branching factor of the tree (i.e. the number of children each
internal node has).
inner_query_state: The global state of the inner query.
"""
arity = attr.ib()
inner_query_state = attr.ib()
arity: Any
inner_query_state: Any

def __init__(self,
inner_query: dp_query.SumAggregationDPQuery,
Expand Down