From 4007769e203e775b5f69fc7d53e001700452562b Mon Sep 17 00:00:00 2001
From: Carl
Date: Mon, 6 Jan 2025 15:17:31 +0000
Subject: [PATCH] fix unbound variable error when not using torch or tf
 backend (#2)

Refactor to use backend-specific gradient functions in tests and merge the
logic into a single function.
---
 .../_tf_keras/keras/quantizers/__init__.py |   7 +-
 keras/api/quantizers/__init__.py           |   7 +-
 keras/src/quantizers/__init__.py           |   7 +-
 keras/src/quantizers/quantizers.py         | 310 ++-----
 keras/src/quantizers/quantizers_test.py    | 803 ++++--------
 5 files changed, 257 insertions(+), 877 deletions(-)

diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py
index d6330c22335..86ca1a24039 100644
--- a/keras/api/_tf_keras/keras/quantizers/__init__.py
+++ b/keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -13,13 +13,8 @@
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
 from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
-from keras.src.quantizers.quantizers import (
-    fake_quant_with_min_max_args_gradient,
-)
+from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
 from keras.src.quantizers.quantizers import (
     fake_quant_with_min_max_vars_per_channel,
 )
-from keras.src.quantizers.quantizers import (
-    fake_quant_with_min_max_vars_per_channel_gradient,
-)
 from keras.src.quantizers.quantizers import quantize_and_dequantize
diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py
index d6330c22335..86ca1a24039 100644
--- a/keras/api/quantizers/__init__.py
+++ b/keras/api/quantizers/__init__.py
@@ -13,13 +13,8 @@
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
 from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
-from keras.src.quantizers.quantizers import (
-    fake_quant_with_min_max_args_gradient,
-)
+from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
 from keras.src.quantizers.quantizers import (
     fake_quant_with_min_max_vars_per_channel,
 )
-from keras.src.quantizers.quantizers import (
-    fake_quant_with_min_max_vars_per_channel_gradient,
-)
 from keras.src.quantizers.quantizers import quantize_and_dequantize
diff --git a/keras/src/quantizers/__init__.py b/keras/src/quantizers/__init__.py
index 461b6ef4596..22e997cfd06 100644
--- a/keras/src/quantizers/__init__.py
+++ b/keras/src/quantizers/__init__.py
@@ -7,15 +7,10 @@
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
 from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
-from keras.src.quantizers.quantizers import (
-    fake_quant_with_min_max_args_gradient,
-)
+from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
 from keras.src.quantizers.quantizers import (
     fake_quant_with_min_max_vars_per_channel,
 )
-from keras.src.quantizers.quantizers import (
-    fake_quant_with_min_max_vars_per_channel_gradient,
-)
 from keras.src.quantizers.quantizers import quantize_and_dequantize
 from keras.src.saving import serialization_lib
 from keras.src.utils.naming import to_snake_case
diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py
index 565ff76147a..fdf73dc3edd 100644
--- a/keras/src/quantizers/quantizers.py
+++ b/keras/src/quantizers/quantizers.py
@@ -133,12 +133,11 @@ def
adjust_and_nudge(min_range, max_range, num_bits, narrow_range): raise ValueError("num_bits must be >= 2") n_steps = ops.cast(2**num_bits - 1, "float32") - if narrow_range: - n_steps -= 1.0 + n_steps = n_steps if not narrow_range else n_steps - 1.0 # Handle the case where min and max are too close - if abs(max_range - min_range) < 1e-10: - return min_range, max_range, 1.0 + # if abs(max_range - min_range) < 1e-10: + # return min_range, max_range, 1.0 # Calculate the step size step_size = (max_range - min_range) / n_steps @@ -192,108 +191,6 @@ def adjust_and_nudge(min_range, max_range, num_bits, narrow_range): ) # Returning nudged values and scale -@keras_export("keras.quantizers.fake_quant_with_min_max_args") -def fake_quant_with_min_max_args( - inputs, - min_range=-6.0, - max_range=6.0, - num_bits=8, - narrow_range=False, -): - """Fake quantization operation matching TensorFlow's implementation.""" - - inputs = ops.convert_to_tensor(inputs) - - @ops.custom_gradient - def _fake_quant_with_min_max_args(x): - quant_min, quant_max, step_size = adjust_and_nudge( - min_range, max_range, num_bits, narrow_range - ) - - n_steps = 2**num_bits - 1 - if narrow_range: - n_steps -= 1 - - # Clip and nudge input to the range - x_clipped = ops.clip(x, quant_min, quant_max) - x_norm = (x_clipped - quant_min) / step_size - x_quantized = ops.round(x_norm) - x_quantized = ops.clip(x_quantized, 0.0, n_steps) - result = x_quantized * step_size + quant_min - - def grad(*args, upstream=None): - if upstream is None: - (upstream,) = args - # Gradient mask: valid within the range - mask = ops.cast( - (x >= quant_min) & (x <= quant_max), dtype=upstream.dtype - ) - return ops.multiply(upstream, mask) - - return result, grad - - return _fake_quant_with_min_max_args(inputs) - - -@keras_export("keras.quantizers.fake_quant_with_min_max_vars") -def fake_quant_with_min_max_vars( - inputs, - min_range=-6.0, - max_range=6.0, - num_bits=8, - narrow_range=False, -): - """Fake quantization operation matching TensorFlow's implementation.""" - return fake_quant_with_min_max_args( - inputs, min_range, max_range, num_bits, narrow_range - ) - - -@keras_export("keras.quantizers.fake_quant_with_min_max_args_gradient") -def fake_quant_with_min_max_args_gradient( - gradients, - inputs, - min_range=-6.0, - max_range=6.0, - num_bits=8, - narrow_range=False, -): - """Fake quantization operation with gradient, - matching TensorFlow's implementation.""" - - inputs = ops.convert_to_tensor(inputs) - - def _fake_quant_with_min_max_args_gradient(x): - quant_min, quant_max, step_size = adjust_and_nudge( - min_range, max_range, num_bits, narrow_range - ) - - n_steps = 2**num_bits - 1 - if narrow_range: - n_steps -= 1 - - # Clip and nudge input to the range - x_clipped = ops.clip(x, quant_min, quant_max) - x_norm = (x_clipped - quant_min) / step_size - x_quantized = ops.round(x_norm) - x_quantized = ops.clip(x_quantized, 0.0, n_steps) - result = x_quantized * step_size + quant_min - - def grad(*args, upstream=None): - if upstream is None: - (upstream,) = args - # Gradient mask: valid within the range - mask = ops.cast( - (x >= quant_min) & (x <= quant_max), dtype=upstream.dtype - ) - return ops.multiply(upstream, mask) - - return result, grad - - output, grad = _fake_quant_with_min_max_args_gradient(inputs) - return output, grad(gradients) - - @keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel") def fake_quant_with_min_max_vars_per_channel( inputs, @@ -303,7 +200,8 @@ def fake_quant_with_min_max_vars_per_channel( narrow_range, 
): """ - Perform per-channel fake quantization with custom gradient. + Perform per-channel fake quantization with custom gradient using vectorized + operations. Args: inputs: Input tensor of float type @@ -315,165 +213,93 @@ def fake_quant_with_min_max_vars_per_channel( Returns: Fake-quantized tensor """ - inputs = ops.convert_to_tensor(inputs) min_vals = ops.convert_to_tensor(min_vals) max_vals = ops.convert_to_tensor(max_vals) @ops.custom_gradient def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val): - # Determine the number of channels - num_channels = min_val.shape[-1] - - # Initialize an empty list to store quantized values for each channel - quantized_channels = [] - masks = [] - - # Iterate over each channel - for i in range(num_channels): - # Extract min/max values for current channel - current_min = min_val[..., i] - current_max = max_val[..., i] - - # Calculate step size and quantized min/max using _adjust_range - qnt_min, qnt_max, step_size = adjust_and_nudge( - current_min, current_max, num_bits, narrow_range - ) - # Calculate the number of steps - n_steps = 2**num_bits - 1 - if narrow_range: - n_steps -= 1 - - # Clip and nudge input to the range for the current channel - x_clipped = ops.clip(x[..., i], qnt_min, qnt_max) - x_norm = (x_clipped - qnt_min) / step_size - x_quantized = ops.round(x_norm) - x_quantized = ops.clip(x_quantized, 0.0, n_steps) - result_channel = x_quantized * step_size + qnt_min - - quantized_channels.append(result_channel) - mask = ops.cast( - (x[..., i] >= qnt_min) & (x[..., i] <= qnt_max), - dtype=np.float32, - ) - masks.append(mask) - - # Concatenate quantized channels - result = ops.stack(quantized_channels, axis=-1) + # Calculate quantization parameters for all channels at once + qnt_min, qnt_max, step_size = adjust_and_nudge( + min_val, max_val, num_bits, narrow_range + ) + + # Calculate number of steps + n_steps = 2**num_bits - 1 + if narrow_range: + n_steps -= 1 + + # Expand dimensions to allow broadcasting + qnt_min = ops.expand_dims(qnt_min, axis=list(range(len(x.shape) - 1))) + qnt_max = ops.expand_dims(qnt_max, axis=list(range(len(x.shape) - 1))) + step_size = ops.expand_dims( + step_size, axis=list(range(len(x.shape) - 1)) + ) + + # Clip and quantize all channels simultaneously + x_clipped = ops.clip(x, qnt_min, qnt_max) + x_norm = (x_clipped - qnt_min) / step_size + x_quantized = ops.round(x_norm) + x_quantized = ops.clip(x_quantized, 0.0, n_steps) + result = x_quantized * step_size + qnt_min + + # Create gradient mask for all channels + masks = ops.cast( + (x >= qnt_min) & (x <= qnt_max), + dtype=np.float32, + ) def grad(*args, upstream=None): if upstream is None: (upstream,) = args - # Gradient mask: valid within the range - return ops.multiply(upstream, mask) + # Gradient for x + dx = ops.multiply(upstream, masks) + + # Gradient for min_val + # When x is clipped to min, the gradient flows to min_val + min_mask = ops.cast(x <= qnt_min, dtype=np.float32) + dims_to_reduce = list(range(len(x.shape) - 1)) + grad_min = ops.sum(upstream * min_mask, axis=dims_to_reduce) + + # Gradient for max_val + # When x is clipped to max, the gradient flows to max_val + max_mask = ops.cast(x >= qnt_max, dtype=np.float32) + grad_max = ops.sum(upstream * max_mask, axis=dims_to_reduce) + + return dx, grad_min, grad_max return result, grad return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals) -@keras_export( - "keras.quantizers.fake_quant_with_min_max_vars_per_channel_gradient" -) -def 
fake_quant_with_min_max_vars_per_channel_gradient( - gradients, +@keras_export("keras.quantizers.fake_quant_with_min_max_args") +def fake_quant_with_min_max_args( inputs, min_vals, max_vals, - num_bits, - narrow_range, + num_bits=8, + narrow_range=False, ): - """ - Perform per-channel fake quantization with custom gradient. - - Args: - inputs: Input tensor of float type - min_vals: Per-channel minimum values - max_vals: Per-channel maximum values - num_bits: Quantization bit width (2-16) - narrow_range: Whether to use narrow quantization range - - Returns: - Fake-quantized tensor - """ - - if isinstance(inputs, np.ndarray): - inputs = ops.convert_to_tensor(inputs) - min_vals = ops.convert_to_tensor(min_vals) - max_vals = ops.convert_to_tensor(max_vals) - - # @ops.custom_gradient - def _fake_quant_with_min_max_vars_per_channel_gradient(x, min_val, max_val): - # Determine the number of channels - num_channels = min_val.shape[-1] - - # Initialize an empty list to store quantized values for each channel - quantized_channels = [] - between_min_max_masks = [] - below_min_masks = [] - above_max_masks = [] - - # Iterate over each channel - for i in range(num_channels): - # Extract min/max values for current channel - current_min = min_val[..., i] - current_max = max_val[..., i] - - # Calculate step size and quantized min/max using _adjust_range - qnt_min, qnt_max, step_size = adjust_and_nudge( - current_min, current_max, num_bits, narrow_range - ) - - # Calculate the number of steps - n_steps = 2**num_bits - 1 - if narrow_range: - n_steps -= 1 - - # Clip and nudge input to the range for the current channel - x_clipped = ops.clip(x[..., i], qnt_min, qnt_max) - x_norm = (x_clipped - qnt_min) / step_size - x_quantized = ops.round(x_norm) - x_quantized = ops.clip(x_quantized, 0.0, n_steps) - result_channel = x_quantized * step_size + qnt_min - between_min_max_mask = ops.cast( - (x[..., i] >= qnt_min) & (x[..., i] <= qnt_max), - dtype=np.float32, - ) - below_min_mask = ops.cast((x[..., i] < qnt_min), dtype=np.float32) - above_max_mask = ops.cast((x[..., i] > qnt_max), dtype=np.float32) - between_min_max_masks.append(between_min_max_mask) - below_min_masks.append(below_min_mask) - above_max_masks.append(above_max_mask) - quantized_channels.append(result_channel) - - # Concatenate quantized channels - result = ops.stack(quantized_channels, axis=-1) - between_min_max_masks = ops.stack(between_min_max_masks, axis=-1) - below_min_masks = ops.stack(below_min_masks, axis=-1) - above_max_masks = ops.stack(above_max_masks, axis=-1) - - def grad(*args, upstream=None): - if upstream is None: - (upstream,) = args - backprops_wrt_input = ops.multiply(upstream, between_min_max_masks) - backprops_wrt_min = ops.sum( - ops.multiply(upstream, below_min_masks), axis=0 - ) - backprops_wrt_max = ops.sum( - ops.multiply(upstream, above_max_masks), axis=0 - ) - - return backprops_wrt_input, backprops_wrt_min, backprops_wrt_max + """Fake quantization operation matching TensorFlow's implementation.""" + return fake_quant_with_min_max_vars_per_channel( + inputs, min_vals, max_vals, num_bits, narrow_range + ) - return result, grad - output, grad = _fake_quant_with_min_max_vars_per_channel_gradient( - inputs, min_vals, max_vals +@keras_export("keras.quantizers.fake_quant_with_min_max_vars") +def fake_quant_with_min_max_vars( + inputs, + min_vals, + max_vals, + num_bits=8, + narrow_range=False, +): + """Fake quantization operation matching TensorFlow's implementation.""" + return fake_quant_with_min_max_vars_per_channel( + inputs, 
min_vals, max_vals, num_bits, narrow_range ) - backprops_wrt_input, backprops_wrt_min, backprops_wrt_max = grad(gradients) - - return output, backprops_wrt_input, backprops_wrt_min, backprops_wrt_max """Float8-related methods""" diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py index 81543552627..095e3f6265c 100644 --- a/keras/src/quantizers/quantizers_test.py +++ b/keras/src/quantizers/quantizers_test.py @@ -104,151 +104,6 @@ def test_quantize_and_dequantize(self): def _TestOp( self, - op, - input_min, - input_max, - num_bits, - narrow_range, - expected_nudged_input_min, - expected_nudged_input_max, - expected_step, - ): - inputs = ops.array( - [ - expected_nudged_input_min - expected_step, - expected_nudged_input_min - 0.01, - expected_nudged_input_min, - expected_nudged_input_min + 0.01, - expected_nudged_input_min + expected_step - 0.01, - expected_nudged_input_min + expected_step, - expected_nudged_input_min + expected_step + 0.01, - expected_nudged_input_max - 0.01, - expected_nudged_input_max, - expected_nudged_input_max + 0.01, - expected_nudged_input_max + expected_step, - ], - dtype="float32", - ) - expected = ops.array( - [ - expected_nudged_input_min, - expected_nudged_input_min, - expected_nudged_input_min, - expected_nudged_input_min, - expected_nudged_input_min + expected_step, - expected_nudged_input_min + expected_step, - expected_nudged_input_min + expected_step, - expected_nudged_input_max, - expected_nudged_input_max, - expected_nudged_input_max, - expected_nudged_input_max, - ], - dtype="float32", - ) - initial_gradients = ops.arange(1, len(inputs) + 1, dtype="float32") - expected_backprops = ops.array( - [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], - dtype=float, - ) - if backend.backend() == "tensorflow": - import tensorflow as tf - - @tf.function(jit_compile=True) - def test_op(inputs, input_min, input_max, num_bits, narrow_range): - with tf.GradientTape() as tape: - tape.watch(inputs) - result = op( - inputs, input_min, input_max, num_bits, narrow_range - ) - return initial_gradients * tape.gradient(result, inputs) - - gradients = test_op( - inputs, input_min, input_max, num_bits, narrow_range - ) - - if backend.backend() == "torch": - import torch - - def test_op(inputs, input_min, input_max, num_bits, narrow_range): - # Create tensor and enable gradient tracking - inputs = torch.tensor( - inputs, dtype=torch.float32, requires_grad=True - ) - - # Apply the quantization operation - result = op( - inputs, input_min, input_max, num_bits, narrow_range - ) - - # Compute gradients - result.backward(torch.ones_like(result)) # Compute gradient - - # Multiply gradients by initial_gradients - gradients = initial_gradients * inputs.grad - - return gradients - - gradients = test_op( - inputs, input_min, input_max, num_bits, narrow_range - ) - - # test gradients - self.assertAllClose(gradients, expected_backprops) - - outputs = op( - inputs, - input_min, - input_max, - num_bits=num_bits, - narrow_range=narrow_range, - ) - self.assertAllClose(outputs, expected) - - def _TestGradOp( - self, - grad_op, - input_min, - input_max, - num_bits, - narrow_range, - expected_nudged_input_min, - expected_nudged_input_max, - expected_step, - ): - inputs = ops.array( - [ - expected_nudged_input_min - expected_step, - expected_nudged_input_min - 0.01, - expected_nudged_input_min, - expected_nudged_input_min + 0.01, - expected_nudged_input_min + expected_step - 0.01, - expected_nudged_input_min + expected_step, - 
expected_nudged_input_min + expected_step + 0.01, - expected_nudged_input_max - 0.01, - expected_nudged_input_max, - expected_nudged_input_max + 0.01, - expected_nudged_input_max + expected_step, - ], - dtype="float32", - ) - initial_gradients = ops.arange(1, len(inputs) + 1, dtype="float32") - expected_backprops = ops.array( - [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], - dtype=float, - ) - _, gradients = grad_op( - initial_gradients, - inputs, - input_min, - input_max, - num_bits=num_bits, - narrow_range=narrow_range, - ) - self.assertAllClose(gradients, expected_backprops) - - def _TestChannelsOp( - self, - op, input_mins, input_maxs, num_bits, @@ -260,6 +115,8 @@ def _TestChannelsOp( num_channels = len(input_mins) inputs_list = [] expected_list = [] + initial_gradients_list = [] + expected_backprops_wrt_input_list = [] for i in range(num_channels): expected_nudged_input_min = expected_nudged_input_mins[i] expected_nudged_input_max = expected_nudged_input_maxs[i] @@ -295,95 +152,95 @@ def _TestChannelsOp( expected_nudged_input_max, ] ) - inputs = ops.transpose(ops.array(inputs_list, dtype="float32")) - expected = ops.transpose(ops.array(expected_list, dtype="float32")) - input_min = ops.array(input_mins, dtype="float32") - input_max = ops.array(input_maxs, dtype="float32") - outputs = op( - inputs, - input_min, - input_max, - num_bits=num_bits, - narrow_range=narrow_range, - ) - self.assertAllClose(outputs, expected) - - def _TestChannelsGradOp( - self, - op, - input_mins, - input_maxs, - num_bits, - narrow_range, - expected_nudged_input_mins, - expected_nudged_input_maxs, - expected_steps, - ): - num_channels = len(input_mins) - inputs_list = [] - gradients_list = [] - expected_list = [] - expected_backprops_wrt_input_list = [] - expected_backprops_wrt_min_list = [] - expected_backprops_wrt_max_list = [] - for i in range(num_channels): - expected_nudged_input_min = expected_nudged_input_mins[i] - expected_nudged_input_max = expected_nudged_input_maxs[i] - expected_step = expected_steps[i] - inputs = [ - expected_nudged_input_min - expected_step, - expected_nudged_input_min - 0.01, - expected_nudged_input_min, - expected_nudged_input_min + 0.01, - expected_nudged_input_min + expected_step - 0.01, - expected_nudged_input_min + expected_step, - expected_nudged_input_min + expected_step + 0.01, - expected_nudged_input_max - 0.01, - expected_nudged_input_max, - expected_nudged_input_max + 0.01, - expected_nudged_input_max + expected_step, - ] - inputs_list.append(inputs) - gradients_list.append(list(range(1, len(inputs) + 1))) + initial_gradients_list.append( + list(range(1, len(inputs_list[-1]) + 1)) + ) expected_backprops_wrt_input_list.append( [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0] ) - expected_backprops_wrt_min_list.append(1.0 + 2.0) - expected_backprops_wrt_max_list.append(10.0 + 11.0) - expected_list.append( - [ - expected_nudged_input_min, - expected_nudged_input_min, - expected_nudged_input_min, - expected_nudged_input_min, - expected_nudged_input_min + expected_step, - expected_nudged_input_min + expected_step, - expected_nudged_input_min + expected_step, - expected_nudged_input_max, - expected_nudged_input_max, - expected_nudged_input_max, - expected_nudged_input_max, - ] - ) - expected = ops.transpose(ops.array(expected_list, dtype="float32")) - inputs = ops.transpose(ops.array(inputs_list, dtype="float32")) - input_gradients = ops.transpose( - ops.array(gradients_list, dtype="float32") - ) + expected = ops.transpose(ops.array(expected_list, 
dtype="float32")) expected_backprops_wrt_input = ops.transpose( ops.array(expected_backprops_wrt_input_list, dtype="float32") ) - expected_backprops_wrt_min = ops.array( - expected_backprops_wrt_min_list, dtype="float32" - ) - expected_backprops_wrt_max = ops.array( - expected_backprops_wrt_max_list, dtype="float32" - ) input_min = ops.array(input_mins, dtype="float32") input_max = ops.array(input_maxs, dtype="float32") - outputs, backprops_wrt_input, backprops_wrt_min, backprops_wrt_max = op( - input_gradients, + initial_gradients = ops.transpose( + ops.array(initial_gradients_list, dtype="float32") + ) + if backend.backend() == "tensorflow": + import tensorflow as tf + + @tf.function(jit_compile=True) + def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range): + with tf.GradientTape() as tape: + tape.watch(inputs) + result = ( + quantizers.fake_quant_with_min_max_vars_per_channel( + inputs, + input_mins, + input_maxs, + num_bits, + narrow_range, + ) + ) + return initial_gradients * tape.gradient(result, inputs) + + gradients = test_op( + inputs, input_mins, input_maxs, num_bits, narrow_range + ) + # test gradients + self.assertAllClose(gradients, expected_backprops_wrt_input) + + if backend.backend() == "torch": + import torch + + def test_op(inputs, input_min, input_max, num_bits, narrow_range): + # Create tensor and enable gradient tracking + inputs = torch.tensor( + inputs, dtype=torch.float32, requires_grad=True + ) + + # Apply the quantization operation + result = quantizers.fake_quant_with_min_max_vars_per_channel( + inputs, input_mins, input_maxs, num_bits, narrow_range + ) + + # Compute gradients + result.backward(torch.ones_like(result)) + + return initial_gradients * inputs.grad + + gradients = test_op( + inputs, input_min, input_max, num_bits, narrow_range + ) + # test gradients + self.assertAllClose(gradients, expected_backprops_wrt_input) + + if backend.backend() == "jax": + import jax + + def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range): + # Define the function to compute gradients for + def quantize_fn(x): + return quantizers.fake_quant_with_min_max_vars_per_channel( + x, input_mins, input_maxs, num_bits, narrow_range + ) + + # Get the gradient function + grad_fn = jax.jit(jax.grad(lambda x: ops.sum(quantize_fn(x)))) + + # Compute gradients + input_gradients = grad_fn(inputs) + + return initial_gradients * input_gradients + + gradients = test_op( + inputs, input_min, input_max, num_bits, narrow_range + ) + # test gradients + self.assertAllClose(gradients, expected_backprops_wrt_input) + outputs = quantizers.fake_quant_with_min_max_vars_per_channel( inputs, input_min, input_max, @@ -392,473 +249,188 @@ def _TestChannelsGradOp( ) self.assertAllClose(outputs, expected) - self.assertAllClose(backprops_wrt_input, expected_backprops_wrt_input) - self.assertAllClose(expected_backprops_wrt_min, backprops_wrt_min) - self.assertAllClose(expected_backprops_wrt_max, backprops_wrt_max) - - def test_fakeQuantWithMinMaxArgs_with8BitsNoSclngNoNdgng(self): + def test_fakeQuantWithMinMax_8BitsNoSclngNoNdgng(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - 0.0, - 255.0, + [0.0], + [255.0], 8, False, - 0.0, - 255.0, - 1.0, + [0.0], + [255.0], + [1.0], ) - def test_fakeQuantWithMinMaxArgs_with8BitsSclngAndNdgngDown(self): + def test_fakeQuantWithMinMax_8BitsSclngAndNdgngDown(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - 0.5, - 128.0, + [0.5], + [128.0], 8, False, - 0.0, - 127.5, - 0.5, + [0.0], + [127.5], + [0.5], ) - def 
test_fakeQuantWithMinMaxArgs_with8BitsSclngAndNdgngUp(self): + def test_fakeQuantWithMinMax_8BitsSclngAndNdgngUp(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - -128.0, - -0.5, + [-128.0], + [-0.5], 8, False, - -127.5, - 0.0, - 0.5, + [-127.5], + [0.0], + [0.5], ) - def test_fakeQuantWithMinMaxArgs_with8BitsSclngAndNdgngBtwn(self): + def test_fakeQuantWithMinMax_8BitsSclngAndNdgngBtwn(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - -0.1, - 127.4, + [-0.1], + [127.4], 8, False, - 0.0, - 127.5, - 0.5, + [0.0], + [127.5], + [0.5], ) # 8 bits, narrow range. - def test_fakeQuantWithMinMaxArgs_with8BitsNrrwRangeNoSclngNoNdgng(self): + def test_fakeQuantWithMinMax_8BitsNrrwRangeNoSclngNoNdgng(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - 0.0, - 254.0, + [0.0], + [254.0], 8, True, - 0.0, - 254.0, - 1.0, + [0.0], + [254.0], + [1.0], ) - def test_fakeQuantWithMinMaxArgs_with8BitsNrrwRangeSclngAndNdgngDown(self): + def test_fakeQuantWithMinMax_8BitsNrrwRangeSclngAndNdgngDown(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - 0.1, - 127.1, + [0.1], + [127.1], 8, True, - 0.0, - 127.0, - 0.5, + [0.0], + [127.0], + [0.5], ) - def test_fakeQuantWithMinMaxArgs_with8BitsNrrwRangeSclngAndNdgngUp(self): + def test_fakeQuantWithMinMax_8BitsNrrwRangeSclngAndNdgngUp(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - -127.1, - -0.1, + [-127.1], + [-0.1], 8, True, - -127.0, - 0.0, - 0.5, + [-127.0], + [0.0], + [0.5], ) - def test_fakeQuantWithMinMaxArgs_with8BitsNrrwRangeSclngAndNdgngBtwn(self): + def test_fakeQuantWithMinMax_8BitsNrrwRangeSclngAndNdgngBtwn(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - -0.1, - 126.9, + [-0.1], + [126.9], 8, True, - 0.0, - 127.0, - 0.5, + [0.0], + [127.0], + [0.5], ) # 7 bits, wide range. - def test_fakeQuantWithMinMaxArgs_with7BitsNoSclngNoNdgng(self): + def test_fakeQuantWithMinMax_7BitsNoSclngNoNdgng(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - 0.0, - 127.0, + [0.0], + [127.0], 7, False, - 0.0, - 127.0, - 1.0, + [0.0], + [127.0], + [1.0], ) - def test_fakeQuantWithMinMaxArgs_with7BitsSclngAndNdgngDown(self): + def test_fakeQuantWithMinMax_7BitsSclngAndNdgngDown(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - 0.5, - 64.0, + [0.5], + [64.0], 7, False, - 0.0, - 63.5, - 0.5, + [0.0], + [63.5], + [0.5], ) - def test_fakeQuantWithMinMaxArgs_with7BitsSclngAndNdgngUp(self): + def test_fakeQuantWithMinMax_7BitsSclngAndNdgngUp(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - -64.0, - -0.5, + [-64.0], + [-0.5], 7, False, - -63.5, - 0.0, - 0.5, + [-63.5], + [0.0], + [0.5], ) - def test_fakeQuantWithMinMaxArgs_with7BitsSclngAndNdgngBtwn(self): + def test_fakeQuantWithMinMax_7BitsSclngAndNdgngBtwn(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - -0.1, - 63.4, + [-0.1], + [63.4], 7, False, - 0.0, - 63.5, - 0.5, + [0.0], + [63.5], + [0.5], ) # 7 bits, narrow range. 
- def test_fakeQuantWithMinMaxArgs_with7BitsNrrwRangeNoSclngNoNdgng(self): + def test_fakeQuantWithMinMax_7BitsNrrwRangeNoSclngNoNdgng(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - 0.0, - 126.0, + [0.0], + [126.0], 7, True, - 0.0, - 126.0, - 1.0, + [0.0], + [126.0], + [1.0], ) - def test_fakeQuantWithMinMaxArgs_with7BitsNrrwRangeSclngAndNdgngDown(self): + def test_fakeQuantWithMinMax_7BitsNrrwRangeSclngAndNdgngDown(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - 0.1, - 63.1, + [0.1], + [63.1], 7, True, - 0.0, - 63.0, - 0.5, + [0.0], + [63.0], + [0.5], ) - def test_fakeQuantWithMinMaxArgs_with7BitsNrrwRangeSclngAndNdgngUp(self): + def test_fakeQuantWithMinMax_7BitsNrrwRangeSclngAndNdgngUp(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - -63.1, - -0.1, + [-63.1], + [-0.1], 7, True, - -63.0, - 0.0, - 0.5, + [-63.0], + [0.0], + [0.5], ) - def test_fakeQuantWithMinMaxArgs_with7BitsNrrwRangeSclngAndNdgngBtwn(self): + def test_fakeQuantWithMinMax_7BitsNrrwRangeSclngAndNdgngBtwn(self): self._TestOp( - quantizers.fake_quant_with_min_max_args, - -0.1, - 62.9, + [-0.1], + [62.9], 7, True, - 0.0, - 63.0, - 0.5, + [0.0], + [63.0], + [0.5], ) # 8 bits, wide range. - def test_fakeQuantWithMinMaxArgsGrad_with8BitsNoSclngNoNdgng(self): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - 0.0, - 255.0, - 8, - False, - 0.0, - 255.0, - 1.0, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with8BitsSclngAndNdgngDown(self): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - 0.5, - 128.0, - 8, - False, - 0.0, - 127.5, - 0.5, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with8BitsSclngAndNdgngUp(self): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - -128.0, - -0.5, - 8, - False, - -127.5, - 0.0, - 0.5, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with8BitsSclngAndNdgngBtwn(self): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - -0.1, - 127.4, - 8, - False, - 0.0, - 127.5, - 0.5, - ) - - # 8 bits, narrow range. - def test_fakeQuantWithMinMaxArgsGrad_with8BitsNrrwRangeNoSclngNoNdgng( - self, - ): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - 0.0, - 254.0, - 8, - True, - 0.0, - 254.0, - 1.0, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with8BitsNrrwRangeSclngAndNdgngDown( - self, - ): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - 0.1, - 127.1, - 8, - True, - 0.0, - 127.0, - 0.5, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with8BitsNrrwRangeSclngAndNdgngUp( - self, - ): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - -127.1, - -0.1, - 8, - True, - -127.0, - 0.0, - 0.5, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with8BitsNrrwRangeSclngAndNdgngBtwn( - self, - ): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - -0.1, - 126.9, - 8, - True, - 0.0, - 127.0, - 0.5, - ) - - # 7 bits, wide range. 
- def test_fakeQuantWithMinMaxArgsGrad_with7BitsNoSclngNoNdgng(self): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - 0.0, - 127.0, - 7, - False, - 0.0, - 127.0, - 1.0, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with7BitsSclngAndNdgngDown(self): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - 0.5, - 64.0, - 7, - False, - 0.0, - 63.5, - 0.5, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with7BitsSclngAndNdgngUp(self): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - -64.0, - -0.5, - 7, - False, - -63.5, - 0.0, - 0.5, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with7BitsSclngAndNdgngBtwn(self): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - -0.1, - 63.4, - 7, - False, - 0.0, - 63.5, - 0.5, - ) - - # 7 bits, narrow range. - def test_fakeQuantWithMinMaxArgsGrad_with7BitsNrrwRangeNoSclngNoNdgng( - self, - ): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - 0.0, - 126.0, - 7, - True, - 0.0, - 126.0, - 1.0, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with7BitsNrrwRangeSclngAndNdgngDown( - self, - ): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - 0.1, - 63.1, - 7, - True, - 0.0, - 63.0, - 0.5, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with7BitsNrrwRangeSclngAndNdgngUp( - self, - ): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - -63.1, - -0.1, - 7, - True, - -63.0, - 0.0, - 0.5, - ) - - def test_fakeQuantWithMinMaxArgsGrad_with7BitsNrrwRangeSclngAndNdgngBtwn( - self, - ): - self._TestGradOp( - quantizers.fake_quant_with_min_max_args_gradient, - -0.1, - 62.9, - 7, - True, - 0.0, - 63.0, - 0.5, - ) - - # 8 bits, wide range. - def test_fakeQuantWithMinMaxVarsPerChannel_with8Bits(self): - self._TestChannelsOp( - quantizers.fake_quant_with_min_max_vars_per_channel, - [0.0, 0.5, -128.0, -0.1], - [255.0, 128.0, -0.5, 127.4], - 8, - False, - [0.0, 0.0, -127.5, 0.0], - [255.0, 127.5, 0.0, 127.5], - [1.0, 0.5, 0.5, 0.5], - ) - - # 8 bits, narrow range. - def test_fakeQuantWithMinMaxVarsPerChannel_with8BitsNarrowRange(self): - self._TestChannelsOp( - quantizers.fake_quant_with_min_max_vars_per_channel, - [0.0, 0.1, -127.1, -0.1], - [254.0, 127.1, -0.1, 126.9], - 8, - True, - [0.0, 0.0, -127.0, 0.0], - [254.0, 127.0, 0.0, 127.0], - [1.0, 0.5, 0.5, 0.5], - ) - - # 7 bits, wide range. - def test_fakeQuantWithMinMaxVarsPerChannel_with7Bits(self): - self._TestChannelsOp( - quantizers.fake_quant_with_min_max_vars_per_channel, - [0.0, 0.5, -64.0, -0.1], - [127.0, 64.0, -0.5, 63.4], - 7, - False, - [0.0, 0.0, -63.5, 0.0], - [127.0, 63.5, 0.0, 63.5], - [1.0, 0.5, 0.5, 0.5], - ) - - # 7 bits, narrow range. - def test_fakeQuantWithMinMaxVarsPerChannel_with7BitsNarrowRange(self): - self._TestChannelsOp( - quantizers.fake_quant_with_min_max_vars_per_channel, - [0.0, 0.1, -63.1, -0.1], - [126.0, 63.1, -0.1, 62.9], - 7, - True, - [0.0, 0.0, -63.0, 0.0], - [126.0, 63.0, 0.0, 63.0], - [1.0, 0.5, 0.5, 0.5], - ) - - # 8 bits, wide range. - def test_fakeQuantWithMinMaxVarsPerChannelGradient_with8Bits(self): - self._TestChannelsGradOp( - quantizers.fake_quant_with_min_max_vars_per_channel_gradient, + def test_fakeQuantWithMinMax_8Bits(self): + self._TestOp( [0.0, 0.5, -128.0, -0.1], [255.0, 128.0, -0.5, 127.4], 8, @@ -869,9 +441,8 @@ def test_fakeQuantWithMinMaxVarsPerChannelGradient_with8Bits(self): ) # 8 bits, narrow range. 
- def test_fakeQuantWithMinMaxVarsPerChannelGrad_with8BitsNrrwRange(self): - self._TestChannelsGradOp( - quantizers.fake_quant_with_min_max_vars_per_channel_gradient, + def test_fakeQuantWithMinMax_8BitsNarrowRange(self): + self._TestOp( [0.0, 0.1, -127.1, -0.1], [254.0, 127.1, -0.1, 126.9], 8, @@ -882,9 +453,8 @@ def test_fakeQuantWithMinMaxVarsPerChannelGrad_with8BitsNrrwRange(self): ) # 7 bits, wide range. - def test_fakeQuantWithMinMaxVarsPerChannelGradient_with7Bits(self): - self._TestChannelsGradOp( - quantizers.fake_quant_with_min_max_vars_per_channel_gradient, + def test_fakeQuantWithMinMax_7Bits(self): + self._TestOp( [0.0, 0.5, -64.0, -0.1], [127.0, 64.0, -0.5, 63.4], 7, @@ -895,9 +465,8 @@ def test_fakeQuantWithMinMaxVarsPerChannelGradient_with7Bits(self): ) # 7 bits, narrow range. - def test_fakeQuantWithMinMaxVarsPerChannelGradient_with7BitsNrrwRange(self): - self._TestChannelsGradOp( - quantizers.fake_quant_with_min_max_vars_per_channel_gradient, + def test_fakeQuantWithMinMax_7BitsNarrowRange(self): + self._TestOp( [0.0, 0.1, -63.1, -0.1], [126.0, 63.1, -0.1, 62.9], 7,