From f4836cc2c87653d10e5f42bbedb86abad8245d45 Mon Sep 17 00:00:00 2001 From: siege Date: Wed, 13 Sep 2023 23:44:34 -0700 Subject: [PATCH] Inline the deprecated tf.layers import. In principle this still exists in Keras, but it's a pain to import and it seemed simpler to just inline the few functions. PiperOrigin-RevId: 565273232 --- .../python/layers/conv_variational.py | 127 ++++++++++++++++-- .../python/layers/conv_variational_test.py | 9 +- 2 files changed, 122 insertions(+), 14 deletions(-) diff --git a/tensorflow_probability/python/layers/conv_variational.py b/tensorflow_probability/python/layers/conv_variational.py index 88e9fc1c8e..a4af7a9be5 100644 --- a/tensorflow_probability/python/layers/conv_variational.py +++ b/tensorflow_probability/python/layers/conv_variational.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.internal import docstring_util from tensorflow_probability.python.layers import util as tfp_layers_util from tensorflow_probability.python.util.seed_stream import SeedStream -from tensorflow.python.layers import utils as tf_layers_util # pylint: disable=g-direct-tensorflow-import from tensorflow.python.ops import nn_ops # pylint: disable=g-direct-tensorflow-import @@ -149,12 +148,12 @@ def __init__( **kwargs) self.rank = rank self.filters = filters - self.kernel_size = tf_layers_util.normalize_tuple( + self.kernel_size = normalize_tuple( kernel_size, rank, 'kernel_size') - self.strides = tf_layers_util.normalize_tuple(strides, rank, 'strides') - self.padding = tf_layers_util.normalize_padding(padding) - self.data_format = tf_layers_util.normalize_data_format(data_format) - self.dilation_rate = tf_layers_util.normalize_tuple( + self.strides = normalize_tuple(strides, rank, 'strides') + self.padding = normalize_padding(padding) + self.data_format = normalize_data_format(data_format) + self.dilation_rate = normalize_tuple( dilation_rate, rank, 'dilation_rate') self.activation = tf.keras.activations.get(activation) self.input_spec = tf.keras.layers.InputSpec(ndim=self.rank + 2) @@ -216,7 +215,7 @@ def build(self, input_shape): dilation_rate=self.dilation_rate, strides=self.strides, padding=self.padding.upper(), - data_format=tf_layers_util.convert_data_format( + data_format=convert_data_format( self.data_format, self.rank + 2)) self.built = True @@ -256,7 +255,7 @@ def compute_output_shape(self, input_shape): space = input_shape[1:-1] new_space = [] for i in range(len(space)): - new_dim = tf_layers_util.conv_output_length( + new_dim = conv_output_length( space[i], self.kernel_size[i], padding=self.padding, @@ -268,7 +267,7 @@ def compute_output_shape(self, input_shape): space = input_shape[2:] new_space = [] for i in range(len(space)): - new_dim = tf_layers_util.conv_output_length( + new_dim = conv_output_length( space[i], self.kernel_size[i], padding=self.padding, @@ -1581,3 +1580,113 @@ def __init__( Convolution1DFlipout = Conv1DFlipout Convolution2DFlipout = Conv2DFlipout Convolution3DFlipout = Conv3DFlipout + + +def convert_data_format(data_format, ndim): # pylint: disable=missing-function-docstring + if data_format == 'channels_last': + if ndim == 3: + return 'NWC' + elif ndim == 4: + return 'NHWC' + elif ndim == 5: + return 'NDHWC' + else: + raise ValueError(f'Input rank: {ndim} not supported. We only support ' + 'input rank 3, 4 or 5.') + elif data_format == 'channels_first': + if ndim == 3: + return 'NCW' + elif ndim == 4: + return 'NCHW' + elif ndim == 5: + return 'NCDHW' + else: + raise ValueError(f'Input rank: {ndim} not supported. We only support ' + 'input rank 3, 4 or 5.') + else: + raise ValueError(f'Invalid data_format: {data_format}. We only support ' + '"channels_first" or "channels_last"') + + +def normalize_tuple(value, n, name): + """Transforms a single integer or iterable of integers into an integer tuple. + + Args: + value: The value to validate and convert. Could an int, or any iterable + of ints. + n: The size of the tuple to be returned. + name: The name of the argument being validated, e.g. "strides" or + "kernel_size". This is only used to format error messages. + + Returns: + A tuple of n integers. + + Raises: + ValueError: If something else than an int/long or iterable thereof was + passed. + """ + if isinstance(value, int): + return (value,) * n + else: + try: + value_tuple = tuple(value) + except TypeError: + raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} ' + f'integers. Received: {str(value)}') from None + if len(value_tuple) != n: + raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} ' + f'integers. Received: {str(value)}') + for single_value in value_tuple: + try: + int(single_value) + except (ValueError, TypeError): + raise ValueError(f'Argument `{name}` must be a tuple of {str(n)} ' + f'integers. Received: {str(value)} including element ' + f'{str(single_value)} of type ' + f'{str(type(single_value))}') from None + return value_tuple + + +def normalize_data_format(value): + data_format = value.lower() + if data_format not in {'channels_first', 'channels_last'}: + raise ValueError('The `data_format` argument must be one of ' + '"channels_first", "channels_last". Received: ' + f'{str(value)}.') + return data_format + + +def normalize_padding(value): + padding = value.lower() + if padding not in {'valid', 'same'}: + raise ValueError('The `padding` argument must be one of "valid", "same". ' + f'Received: {str(padding)}.') + return padding + + +def conv_output_length(input_length, filter_size, padding, stride, dilation=1): + """Determines output length of a convolution given input length. + + Args: + input_length: integer. + filter_size: integer. + padding: one of "same", "valid", "full". + stride: integer. + dilation: dilation rate, integer. + + Returns: + The output length (integer). + """ + if input_length is None: + return None + assert padding in {'same', 'valid', 'full'} + dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1) + if padding == 'same': + output_length = input_length + elif padding == 'valid': + output_length = input_length - dilated_filter_size + 1 + elif padding == 'full': + output_length = input_length + dilated_filter_size - 1 + else: + raise ValueError(f'Invalid padding: {padding}') + return (output_length + stride - 1) // stride diff --git a/tensorflow_probability/python/layers/conv_variational_test.py b/tensorflow_probability/python/layers/conv_variational_test.py index d842f45c12..a942c808b7 100644 --- a/tensorflow_probability/python/layers/conv_variational_test.py +++ b/tensorflow_probability/python/layers/conv_variational_test.py @@ -30,7 +30,6 @@ from tensorflow_probability.python.layers import util from tensorflow_probability.python.random import random_ops from tensorflow_probability.python.util import seed_stream -from tensorflow.python.layers import utils as tf_layers_util from tensorflow.python.ops import nn_ops @@ -369,13 +368,13 @@ def _testConvReparameterization(self, layer_class): # pylint: disable=invalid-n tf.TensorShape(inputs.shape), filter_shape=tf.TensorShape(kernel_shape), padding='SAME', - data_format=tf_layers_util.convert_data_format( + data_format=conv_variational.convert_data_format( self.data_format, inputs.shape.rank)) expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) expected_outputs = tf.nn.bias_add( expected_outputs, bias_posterior.result_sample, - data_format=tf_layers_util.convert_data_format(self.data_format, 4)) + data_format=conv_variational.convert_data_format(self.data_format, 4)) [ expected_outputs_, actual_outputs_, @@ -435,7 +434,7 @@ def _testConvFlipout(self, layer_class): # pylint: disable=invalid-name tf.TensorShape(inputs.shape), filter_shape=tf.TensorShape(kernel_shape), padding='SAME', - data_format=tf_layers_util.convert_data_format( + data_format=conv_variational.convert_data_format( self.data_format, inputs.shape.rank)) expected_kernel_posterior_affine = normal.Normal( @@ -483,7 +482,7 @@ def _testConvFlipout(self, layer_class): # pylint: disable=invalid-name expected_outputs = tf.nn.bias_add( expected_outputs, bias_posterior.result_sample, - data_format=tf_layers_util.convert_data_format(self.data_format, 4)) + data_format=conv_variational.convert_data_format(self.data_format, 4)) [ expected_outputs_, actual_outputs_,