diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
index acabca76..088294eb 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/BUILD
@@ -46,6 +46,8 @@ py_library(
         ":common_manip_utils",
         ":layer_registry",
         ":type_aliases",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:type_aliases",
     ],
 )
 
@@ -55,7 +57,11 @@ py_test(
     python_version = "PY3",
     shard_count = 8,
     srcs_version = "PY3",
-    deps = [":gradient_clipping_utils"],
+    deps = [
+        ":gradient_clipping_utils",
+        ":layer_registry",
+        "//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
+    ],
 )
 
 py_library(
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
index e31f1781..be7d92f9 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py
@@ -164,6 +164,7 @@ def compute_gradient_norms(
   registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
       tape=tape,
       layer_registry=layer_registry,
+      sparse_noise_layer_registry=None,
       num_microbatches=num_microbatches,
   )
   layer_grad_vars, generator_outputs_list = (
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
index 7b91461c..6ec47941 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py
@@ -132,6 +132,7 @@ def _run_model_forward_backward_pass(
   registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
       tape=tape,
       layer_registry=layer_registry.make_default_layer_registry(),
+      sparse_noise_layer_registry=None,
       num_microbatches=None,
   )
   layer_grad_vars, registry_fn_outputs_list = (
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
index bac323ce..989a69c2 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils.py
@@ -22,6 +22,8 @@
 from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
 from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
 from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import type_aliases as sn_type_aliases
 
 
 @dataclasses.dataclass(frozen=True)
@@ -29,6 +31,9 @@ class RegistryGeneratorFunctionOutput:
   layer_id: str
   layer_vars: Optional[Sequence[tf.Variable]]
   layer_sqr_norm_fn: Optional[type_aliases.SquareNormFunction]
+  varname_to_count_contribution_fn: Optional[
+      dict[str, sn_type_aliases.ContributionCountHistogramFn]
+  ]
   layer_trainable_weights: Optional[Sequence[tf.Variable]]
 
 
@@ -46,6 +51,7 @@ def has_internal_compute_graph(input_object: Any):
 def get_registry_generator_fn(
     tape: tf.GradientTape,
     layer_registry: lr.LayerRegistry,
+    sparse_noise_layer_registry: Optional[snlr.LayerRegistry],
     num_microbatches: Optional[type_aliases.BatchSize] = None,
 ) -> Optional[Callable[..., Tuple[tf.Tensor, RegistryGeneratorFunctionOutput]]]:
   """Creates the generator function for `model_forward_backward_pass()`.
@@ -58,6 +64,10 @@ def get_registry_generator_fn(
       `output` is the pre-activator tensor, `sqr_grad_norms` is related to the
       squared norms of a layer's pre-activation tensor, and `vars` are relevant
       trainable
+    sparse_noise_layer_registry: An optional `LayerRegistry` instance containing
+      functions that help compute contribution counts for sparse noise. See
+      `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
+      more details. If `None`, no contribution count functions are looked up.
     num_microbatches: An optional number or scalar `tf.Tensor` for the number of
       microbatches. If not None, indicates that the loss is grouped into
       num_microbatches (in this case, the batch dimension needs to be a multiple
@@ -83,6 +93,16 @@ def registry_generator_fn(layer_instance, args, kwargs):
             'be used for efficient gradient clipping.'
             % layer_instance.__class__.__name__
         )
+      varname_to_count_contribution_fn = None
+      if sparse_noise_layer_registry and sparse_noise_layer_registry.is_elem(
+          layer_instance
+      ):
+        count_contribution_registry_fn = sparse_noise_layer_registry.lookup(
+            layer_instance
+        )
+        varname_to_count_contribution_fn = count_contribution_registry_fn(
+            layer_instance, args, kwargs, num_microbatches
+        )
       registry_fn = layer_registry.lookup(layer_instance)
       (layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
           layer_instance, args, kwargs, tape, num_microbatches
@@ -91,6 +111,7 @@ def registry_generator_fn(layer_instance, args, kwargs):
           layer_id=str(id(layer_instance)),
           layer_vars=layer_vars,
           layer_sqr_norm_fn=layer_sqr_norm_fn,
+          varname_to_count_contribution_fn=varname_to_count_contribution_fn,
           layer_trainable_weights=layer_instance.trainable_weights,
       )
     else:
diff --git a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
index 7069273d..0535e011 100644
--- a/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
+++ b/tensorflow_privacy/privacy/fast_gradient_clipping/gradient_clipping_utils_test.py
@@ -17,6 +17,8 @@
 from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
+from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry as lr
+from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
 
 
 # ==============================================================================
@@ -175,5 +177,96 @@ def test_new_custom_layer_spec(self):
     )
 
 
+class RegistryGeneratorFnTest(tf.test.TestCase, parameterized.TestCase):
+  """Tests the function created by `get_registry_generator_fn()`."""
+
+  def _get_sparse_layer_registry(self):
+    """Returns a sparse noise registry with a stub for `Embedding` layers."""
+    def count_contribution_fn(_):
+      return None
+
+    def registry_fn(*_):
+      return {'var': count_contribution_fn}
+
+    registry = snlr.LayerRegistry()
+    registry.insert(tf.keras.layers.Embedding, registry_fn)
+    return registry, count_contribution_fn
+
+  def _get_layer_registry(self):
+    """Returns a registry with the same stub for `Embedding` and `Dense`."""
+    var = tf.Variable(1.0)
+    output = tf.ones((1, 1))
+
+    def sqr_norm_fn(_):
+      return None
+
+    def registry_fn(*_):
+      return [var], output, sqr_norm_fn
+
+    registry = lr.LayerRegistry()
+    registry.insert(tf.keras.layers.Embedding, registry_fn)
+    registry.insert(tf.keras.layers.Dense, registry_fn)
+    return registry, var, output, sqr_norm_fn
+
+  def test_registry_generator_fn(self):
+    """Checks outputs for layers with and without a sparse registry entry."""
+    inputs = tf.constant([[0, 1]])
+    model = tf.keras.Sequential([
+        tf.keras.layers.Embedding(10, 1),
+        tf.keras.layers.Dense(1),
+    ])
+
+    sparse_layer_registry, count_contribution_fn = (
+        self._get_sparse_layer_registry()
+    )
+    layer_registry, var, output, sqr_norm_fn = self._get_layer_registry()
+    registry_generator_fn = gradient_clipping_utils.get_registry_generator_fn(
+        tape=tf.GradientTape(),
+        layer_registry=layer_registry,
+        sparse_noise_layer_registry=sparse_layer_registry,
+        num_microbatches=None,
+    )
+    embedding_layer = model.layers[0]
+    out, embedding_registry_generator_fn_output = registry_generator_fn(
+        embedding_layer,
+        [inputs],
+        {},
+    )
+    expected_embedding_registry_generator_fn_output = (
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id=str(id(embedding_layer)),
+            layer_vars=[var],
+            layer_sqr_norm_fn=sqr_norm_fn,
+            varname_to_count_contribution_fn={'var': count_contribution_fn},
+            layer_trainable_weights=embedding_layer.trainable_weights,
+        )
+    )
+    self.assertEqual(
+        embedding_registry_generator_fn_output,
+        expected_embedding_registry_generator_fn_output,
+    )
+    self.assertEqual(out, output)
+    dense_layer = model.layers[1]
+    out, dense_registry_generator_fn_output = registry_generator_fn(
+        dense_layer,
+        [inputs],
+        {},
+    )
+    expected_dense_registry_generator_fn_output = (
+        gradient_clipping_utils.RegistryGeneratorFunctionOutput(
+            layer_id=str(id(dense_layer)),
+            layer_vars=[var],
+            layer_sqr_norm_fn=sqr_norm_fn,
+            varname_to_count_contribution_fn=None,
+            layer_trainable_weights=dense_layer.trainable_weights,
+        )
+    )
+    self.assertEqual(
+        dense_registry_generator_fn_output,
+        expected_dense_registry_generator_fn_output,
+    )
+    self.assertEqual(out, output)
+
+
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
index b7104f4d..472f175f 100644
--- a/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
+++ b/tensorflow_privacy/privacy/keras_models/dp_keras_model.py
@@ -280,6 +280,7 @@ def train_step(self, data):
             gradient_clipping_utils.get_registry_generator_fn(
                 tape=tape,
                 layer_registry=self._layer_registry,
+                sparse_noise_layer_registry=None,
                 num_microbatches=num_microbatches,
             )
         )