diff --git a/keras/backend/tensorflow/numpy.py b/keras/backend/tensorflow/numpy.py index 1e181be04f6..1c652a5c17e 100644 --- a/keras/backend/tensorflow/numpy.py +++ b/keras/backend/tensorflow/numpy.py @@ -103,6 +103,7 @@ def is_valid_for_custom_ops(subscripts, *operands): # `None`. if subscripts in [ "a,b->ab", + "ab,bc->ac", "abc,cd->abd", "abcd,abed->abce", "abcd,adbe->acbe", @@ -158,6 +159,8 @@ def use_custom_ops(subscripts, *operands, output_type): x = tf.expand_dims(x, axis=-1) y = tf.expand_dims(y, axis=0) return tf.matmul(x, y, output_type=output_type) + elif subscripts == "ab,bc->ac": + return tf.matmul(x, y, output_type=output_type) elif subscripts == "abc,cd->abd": return tf.matmul(x, y, output_type=output_type) elif subscripts == "abc,cde->abde": diff --git a/keras/backend/torch/numpy.py b/keras/backend/torch/numpy.py index 6b80a8378ed..60499e60dce 100644 --- a/keras/backend/torch/numpy.py +++ b/keras/backend/torch/numpy.py @@ -35,8 +35,12 @@ def einsum(subscripts, *operands, **kwargs): # the behavior of jax. dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands)) if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": + compute_dtype = "int32" + if get_device() == "cuda": + # TODO: torch.einsum doesn't support int32 when using cuda + compute_dtype = config.floatx() # prevent overflow - operands = [cast(operand, "int32") for operand in operands] + operands = [cast(operand, compute_dtype) for operand in operands] return cast(torch.einsum(subscripts, *operands), "int32") return torch.einsum(subscripts, *operands) diff --git a/keras/dtype_policies/__init__.py b/keras/dtype_policies/__init__.py index 871491d5c7d..027fa1dd092 100644 --- a/keras/dtype_policies/__init__.py +++ b/keras/dtype_policies/__init__.py @@ -1,19 +1,25 @@ from keras import backend from keras.dtype_policies import dtype_policy -from keras.saving import serialization_lib +from keras.dtype_policies.dtype_policy import FloatDTypePolicy +from keras.dtype_policies.dtype_policy import QuantizedDTypePolicy def get(identifier): + from keras.saving import serialization_lib + if identifier is None: return dtype_policy.dtype_policy() - if isinstance(identifier, dtype_policy.DTypePolicy): + if isinstance(identifier, (FloatDTypePolicy, QuantizedDTypePolicy)): return identifier if isinstance(identifier, dict): return serialization_lib.deserialize_keras_object(identifier) if isinstance(identifier, str): - return dtype_policy.DTypePolicy(identifier) + if "int8" in identifier: + return QuantizedDTypePolicy(identifier) + else: + return FloatDTypePolicy(identifier) try: - return dtype_policy.DTypePolicy(backend.standardize_dtype(identifier)) + return FloatDTypePolicy(backend.standardize_dtype(identifier)) except: raise ValueError( "Cannot interpret `dtype` argument. Expected a string " diff --git a/keras/dtype_policies/dtype_policy.py b/keras/dtype_policies/dtype_policy.py index 5dad6d093f1..26e063df677 100644 --- a/keras/dtype_policies/dtype_policy.py +++ b/keras/dtype_policies/dtype_policy.py @@ -55,16 +55,24 @@ class DTypePolicy: to explicitly construct a `DTypePolicy` object. """ - def __init__(self, name): + def __new__(cls, name): if not isinstance(name, str): raise TypeError( "'name' must be a string, such as 'mixed_float16'. 
" f"Received: name={name} (of type {type(name)})" ) + # For backwards compatibility + # TODO: We should consider deprecating this behavior + if cls is __class__: + if "int8" in name: + return QuantizedDTypePolicy(name) + return FloatDTypePolicy(name) + return super().__new__(cls) + + def __init__(self, name): self._name = name - self._compute_dtype, self._variable_dtype = self._parse_name(name) - # TODO: check that the current hardware supports the provided - # dtype policy and raise/warn otherwise. + self._compute_dtype = backend.floatx() + self._variable_dtype = backend.floatx() def _parse_name(self, name): """Parses a `DTypePolicy` name into a compute and variable dtype. @@ -75,19 +83,7 @@ def _parse_name(self, name): Returns: The `(compute_dtype, variable_dtype)` pair. """ - if name == "mixed_float16": - return "float16", "float32" - elif name == "mixed_bfloat16": - return "bfloat16", "float32" - try: - dtype = backend.standardize_dtype(name) - return dtype, dtype - except ValueError: - raise ValueError( - f"Cannot convert '{name}' to a mixed precision DTypePolicy." - " Valid policies include 'mixed_float16', 'mixed_bfloat16', " - "and the name of any dtype such as 'float32'." - ) + raise NotImplementedError @property def variable_dtype(self): @@ -132,9 +128,6 @@ def name(self): """Returns the name of this policy.""" return self._name - def __repr__(self): - return f'' - def convert_input(self, x, autocast, dtype): dtype = backend.standardize_dtype(dtype) if backend.is_tensor(x): @@ -165,6 +158,82 @@ def from_config(cls, config): return cls(**config) +@keras_export( + ["keras.FloatDTypePolicy", "keras.dtype_policies.FloatDTypePolicy"] +) +class FloatDTypePolicy(DTypePolicy): + def __init__(self, name): + super().__init__(name) + self._compute_dtype, self._variable_dtype = self._parse_name(name) + # TODO: check that the current hardware supports the provided + # dtype policy and raise/warn otherwise. + + def _parse_name(self, name): + if name == "mixed_float16": + return "float16", "float32" + elif name == "mixed_bfloat16": + return "bfloat16", "float32" + try: + dtype = backend.standardize_dtype(name) + return dtype, dtype + except ValueError: + raise ValueError( + f"Cannot convert '{name}' to a mixed precision " + "FloatDTypePolicy. Valid policies include 'mixed_float16', " + "'mixed_bfloat16', and the name of any float dtype such as " + "'float32'." + ) + + def __repr__(self): + return f'' + + +@keras_export( + ["keras.QuantizedDTypePolicy", "keras.dtype_policies.QuantizedDTypePolicy"] +) +class QuantizedDTypePolicy(DTypePolicy): + def __init__(self, name): + super().__init__(name) + self._quantization_mode, self._compute_dtype, self._variable_dtype = ( + self._parse_name(name) + ) + + def _parse_name(self, name): + error_msg = ( + f"Cannot convert '{name}' to a QuantizedDTypePolicy. " + "Valid policies include " + "'int8_from_float32', 'int8_from_float16', 'int8_from_bfloat16', " + "'int8_from_mixed_float16', 'int8_from_mixed_bfloat16'." + ) + split_name = name.split("_from_") + if len(split_name) != 2: + raise ValueError(error_msg) + mode, from_name = split_name + if mode not in ("int8",): + raise ValueError(error_msg) + if from_name == "mixed_float16": + return mode, "float16", "float32" + elif from_name == "mixed_bfloat16": + return mode, "bfloat16", "float32" + try: + dtype = backend.standardize_dtype(from_name) + return mode, dtype, dtype + except ValueError: + raise ValueError(error_msg) + + @property + def quantization_mode(self): + """The quantization mode of this policy. 
+ + Returns: + The quantization mode of this policy, as a string. + """ + return self._quantization_mode + + def __repr__(self): + return f'' + + @keras_export( [ "keras.config.set_dtype_policy", @@ -181,7 +250,10 @@ def set_dtype_policy(policy): """ if not isinstance(policy, DTypePolicy): if isinstance(policy, str): - policy = DTypePolicy(policy) + if "int8" in policy: + policy = QuantizedDTypePolicy(policy) + else: + policy = FloatDTypePolicy(policy) else: raise ValueError( "Invalid `policy` argument. " @@ -204,6 +276,6 @@ def dtype_policy(): """Returns the current default dtype policy object.""" policy = global_state.get_global_attribute("dtype_policy", None) if policy is None: - policy = DTypePolicy(backend.floatx()) + policy = FloatDTypePolicy(backend.floatx()) set_dtype_policy(policy) return policy diff --git a/keras/dtype_policies/dtype_policy_test.py b/keras/dtype_policies/dtype_policy_test.py index f543226fe8f..8b3d6d43b6c 100644 --- a/keras/dtype_policies/dtype_policy_test.py +++ b/keras/dtype_policies/dtype_policy_test.py @@ -1,4 +1,6 @@ from keras.dtype_policies.dtype_policy import DTypePolicy +from keras.dtype_policies.dtype_policy import FloatDTypePolicy +from keras.dtype_policies.dtype_policy import QuantizedDTypePolicy from keras.dtype_policies.dtype_policy import dtype_policy from keras.dtype_policies.dtype_policy import set_dtype_policy from keras.testing import test_case @@ -48,7 +50,7 @@ def test_properties(self): def test_repr(self): """Test __repr__ method.""" policy = DTypePolicy("mixed_float16") - self.assertEqual(repr(policy), '') + self.assertEqual(repr(policy), '') def test_get_config_from_config(self): """Test get_config and from_config methods.""" @@ -60,6 +62,120 @@ def test_get_config_from_config(self): self.assertEqual(new_policy.name, "mixed_float16") +class FloatDTypePolicyTest(test_case.TestCase): + def test_initialization_valid_name(self): + """Test initialization with a valid name.""" + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_initialization_invalid_name(self): + """Test initialization with an invalid name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("invalid_name") + + def test_initialization_non_string_name(self): + """Test initialization with a non-string name.""" + with self.assertRaisesRegex(TypeError, "'name' must be a string"): + FloatDTypePolicy(123) + + def test_properties_mixed_float16(self): + """Test properties for 'mixed_float16'.""" + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_properties_mixed_bfloat16(self): + """Test properties for 'mixed_bfloat16'.""" + policy = FloatDTypePolicy("mixed_bfloat16") + self.assertEqual(policy.compute_dtype, "bfloat16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_initialization_with_invalid_name_behaviour(self): + """Test initialization behavior with an invalid name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("invalid_name") + + def test_properties(self): + """Test variable_dtype, compute_dtype, and name properties.""" + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.variable_dtype, "float32") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.name, "mixed_float16") + + def test_repr(self): + """Test __repr__ method.""" + 
policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(repr(policy), '') + + def test_get_config_from_config(self): + """Test get_config and from_config methods.""" + policy = FloatDTypePolicy("mixed_float16") + config = policy.get_config() + self.assertEqual(config, {"name": "mixed_float16"}) + + new_policy = FloatDTypePolicy.from_config(config) + self.assertEqual(new_policy.name, "mixed_float16") + + +class QuantizedDTypePolicyTest(test_case.TestCase): + def test_initialization_valid_name(self): + """Test initialization with a valid name.""" + policy = QuantizedDTypePolicy("int8_from_mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_initialization_invalid_name(self): + """Test initialization with an invalid name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + QuantizedDTypePolicy("invalid_name") + + def test_initialization_non_string_name(self): + """Test initialization with a non-string name.""" + with self.assertRaisesRegex(TypeError, "'name' must be a string"): + QuantizedDTypePolicy(123) + + def test_properties_mixed_float16(self): + """Test properties for 'mixed_float16'.""" + policy = QuantizedDTypePolicy("int8_from_mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_properties_mixed_bfloat16(self): + """Test properties for 'mixed_bfloat16'.""" + policy = QuantizedDTypePolicy("int8_from_mixed_bfloat16") + self.assertEqual(policy.compute_dtype, "bfloat16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_initialization_with_invalid_name_behaviour(self): + """Test initialization behavior with an invalid name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + QuantizedDTypePolicy("invalid_name") + + def test_properties(self): + """Test variable_dtype, compute_dtype, and name properties.""" + policy = QuantizedDTypePolicy("int8_from_mixed_float16") + self.assertEqual(policy.variable_dtype, "float32") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.name, "int8_from_mixed_float16") + + def test_repr(self): + """Test __repr__ method.""" + policy = QuantizedDTypePolicy("int8_from_mixed_float16") + self.assertEqual( + repr(policy), '' + ) + + def test_get_config_from_config(self): + """Test get_config and from_config methods.""" + policy = QuantizedDTypePolicy("int8_from_mixed_float16") + config = policy.get_config() + self.assertEqual(config, {"name": "int8_from_mixed_float16"}) + + new_policy = QuantizedDTypePolicy.from_config(config) + self.assertEqual(new_policy.name, "int8_from_mixed_float16") + + class DTypePolicyGlobalFunctionsTest(test_case.TestCase): def setUp(self): """Reset the global dtype policy before each test.""" @@ -71,13 +187,26 @@ def test_set_dtype_policy_valid_string(self): policy = dtype_policy() self.assertEqual(policy.name, "mixed_float16") + def test_set_dtype_policy_valid_string_quantized(self): + """Test set_dtype_policy with a valid string.""" + set_dtype_policy("int8_from_mixed_float16") + policy = dtype_policy() + self.assertEqual(policy.name, "int8_from_mixed_float16") + def test_set_dtype_policy_valid_policy(self): - """Test set_dtype_policy with a valid DTypePolicy object.""" - policy_obj = DTypePolicy("mixed_float16") + """Test set_dtype_policy with a valid FloatDTypePolicy object.""" + policy_obj = FloatDTypePolicy("mixed_float16") set_dtype_policy(policy_obj) policy = dtype_policy() 
self.assertEqual(policy.name, "mixed_float16") + def test_set_dtype_policy_valid_policy_quantized(self): + """Test set_dtype_policy with a valid FloatDTypePolicy object.""" + policy_obj = QuantizedDTypePolicy("int8_from_mixed_float16") + set_dtype_policy(policy_obj) + policy = dtype_policy() + self.assertEqual(policy.name, "int8_from_mixed_float16") + def test_set_dtype_policy_invalid(self): """Test set_dtype_policy with an invalid input.""" with self.assertRaisesRegex(ValueError, "Invalid `policy` argument"): @@ -89,26 +218,48 @@ def test_dtype_policy_default(self): self.assertEqual(policy.name, "float32") -class DTypePolicyEdgeCasesTest(test_case.TestCase): +class FloatDTypePolicyEdgeCasesTest(test_case.TestCase): + def test_empty_name(self): + """Test initialization with an empty name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("") + + def test_special_character_name(self): + """Test initialization with special characters in the name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("@mixed_float16!") + + def test_very_long_name(self): + """Test initialization with a very long name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("mixed_float16" * 100) + + def test_almost_valid_name(self): + """Test initialization with a name close to a valid one.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("mixed_float15") + + +class QuantizedDTypePolicyEdgeCasesTest(test_case.TestCase): def test_empty_name(self): """Test initialization with an empty name.""" with self.assertRaisesRegex(ValueError, "Cannot convert"): - DTypePolicy("") + QuantizedDTypePolicy("") def test_special_character_name(self): """Test initialization with special characters in the name.""" with self.assertRaisesRegex(ValueError, "Cannot convert"): - DTypePolicy("@mixed_float16!") + QuantizedDTypePolicy("@int8_from_mixed_float16!") def test_very_long_name(self): """Test initialization with a very long name.""" with self.assertRaisesRegex(ValueError, "Cannot convert"): - DTypePolicy("mixed_float16" * 100) + QuantizedDTypePolicy("int8_from_mixed_float16" * 100) def test_almost_valid_name(self): """Test initialization with a name close to a valid one.""" with self.assertRaisesRegex(ValueError, "Cannot convert"): - DTypePolicy("mixed_float15") + QuantizedDTypePolicy("int7_from_mixed_float16") class DTypePolicyGlobalFunctionsEdgeCasesTest(test_case.TestCase): diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 487700583b7..a2210c83241 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -1,9 +1,12 @@ import numpy as np from keras import activations +from keras import backend from keras import constraints +from keras import dtype_policies from keras import initializers from keras import ops +from keras import quantizers from keras import regularizers from keras.api_export import keras_export from keras.layers.input_spec import InputSpec @@ -120,6 +123,8 @@ def build(self, input_shape): self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) + if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): + self.quantize(self.dtype_policy.quantization_mode) @property def kernel(self): @@ -136,7 +141,21 @@ def kernel(self): def call(self, inputs): x = ops.matmul(inputs, self.kernel) if self.bias is not None: - x = x + self.bias + x = ops.add(x, self.bias) + if self.activation is not None: + x = self.activation(x) + return x + + 
def quantized_call(self, inputs): + if self.lora_enabled: + raise ValueError("`quantized_call` doesn't support lora weights") + inputs, inputs_scale = self.inputs_quantizer(inputs) + x = ops.matmul(inputs, self.kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale)) + if self.bias is not None: + x = ops.add(x, self.bias) if self.activation is not None: x = self.activation(x) return x @@ -177,9 +196,71 @@ def enable_lora( initializer=initializers.get(b_initializer), regularizer=self.kernel_regularizer, ) - self.kernel.trainable = False + self._kernel.trainable = False self._tracker.lock() self.lora_enabled = True + self.lora_rank = rank + + def quantize(self, mode): + self._check_quantize_args(mode, self.compute_dtype) + if mode == "int8": + if backend.standardize_dtype(self._kernel.dtype) == "int8": + raise ValueError("`quantize` can only be done once per layer.") + # Merge lora-related parameters to make use of fully int8 kernel + self._merge_lora_into_kernel() + # Configure `self.inputs_quantizer` + self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + # Quantize `self._kernel` to int8 and compute corresponding scale + kernel_value, kernel_scale = quantizers.abs_max_quantize( + self._kernel, axis=0 + ) + kernel_scale = ops.cast(kernel_scale, self.compute_dtype) + self._tracker.unlock() + self._untrack_variable(self._kernel) + self._kernel = self.add_weight( + name="kernel", + shape=self._kernel.shape, + # Prevent adding a large constant to the computation graph + initializer=lambda shape, dtype: kernel_value, + dtype="int8", + trainable=False, + ) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale.shape, + # Prevent adding a large constant to the computation graph + initializer=lambda shape, dtype: kernel_scale, + dtype=self.compute_dtype, + trainable=False, + ) + if self.bias is not None: + self.bias.trainable = False + self._tracker.lock() + else: + NotImplementedError( + "Invalid quantization mode. Expected 'int8'. 
" + f"Received: mode={mode}" + ) + + # Set new dtype policy + if not isinstance( + self.dtype_policy, dtype_policies.QuantizedDTypePolicy + ): + quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" + self.dtype_policy = dtype_policies.get(quantized_dtype) + + def _merge_lora_into_kernel(self, untrack=False): + if not self.lora_enabled: + return + # Merge lora-enabled kernel into kernel + self._kernel.assign(self.kernel) + self.lora_enabled = False + if untrack: + self._tracker.unlock() + self.lora_kernel_a = self._untrack_variable(self.lora_kernel_a) + self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) + self._tracker.lock() + self.lora_rank = None def save_own_variables(self, store): if not self.lora_enabled: diff --git a/keras/layers/core/dense_test.py b/keras/layers/core/dense_test.py index 042db4bb0b5..fcfffed1653 100644 --- a/keras/layers/core/dense_test.py +++ b/keras/layers/core/dense_test.py @@ -241,7 +241,7 @@ def test_enable_lora(self): model.save(temp_filepath) new_model = saving.load_model(temp_filepath) - self.assertFalse(new_model.layers[0].lora_enabled) + self.assertTrue(new_model.layers[0].lora_enabled) self.assertAllClose(model.predict(x), new_model.predict(x)) # Try saving and reloading the model's weights only @@ -304,3 +304,68 @@ def test_enable_lora_when_already_enabled(self): layer.enable_lora(rank=2) with self.assertRaisesRegex(ValueError, "lora is already enabled"): layer.enable_lora(rank=2) + + def test_quantize_int8(self): + layer = layers.Dense(units=16) + layer.build((None, 8)) + layer.quantize("int8") + + # Try eager call + x = np.random.random((2, 8)) + _ = layer(x) + + # Try saving and reloading the model + model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.weights.h5" + ) + model.save_weights(temp_filepath) + + # Try lora + layer = layers.Dense(units=16) + layer.build((None, 8)) + layer.enable_lora(4) + layer.quantize("int8") + x = np.random.random((2, 8)) + _ = layer(x) + + @pytest.mark.requires_trainable_backend + def test_quantize_dtype_argument(self): + self.run_layer_test( + layers.Dense, + init_kwargs={ + "units": 5, + "dtype": "int8_from_mixed_bfloat16", + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=3, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + def test_quantize_on_unbuilt_layer(self): + layer = layers.Dense(units=2) + with self.assertRaisesRegex( + ValueError, "Cannot quantize on a layer that isn't yet built." + ): + layer.quantize("int8") + + def test_quantize_when_already_quantized(self): + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.quantize("int8") + with self.assertRaisesRegex( + ValueError, "`quantize` can only be done once per layer." 
+ ): + layer.quantize("int8") diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index a950ac4992e..e7609e75539 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -1,13 +1,18 @@ import re +import string import numpy as np from keras import activations +from keras import backend from keras import constraints +from keras import dtype_policies from keras import initializers from keras import ops +from keras import quantizers from keras import regularizers from keras.api_export import keras_export +from keras.layers.input_spec import InputSpec from keras.layers.layer import Layer @@ -170,9 +175,12 @@ def build(self, input_shape): ) else: self.bias = None - super().build(input_shape) + self.input_spec = InputSpec(ndim=len(input_shape)) + self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) + if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): + self.quantize(self.dtype_policy.quantization_mode) @property def kernel(self): @@ -222,6 +230,30 @@ def call(self, inputs): x = self.activation(x) return x + def quantized_call(self, inputs): + if self.lora_enabled: + raise ValueError("`quantized_call` doesn't support lora weights") + inputs, inputs_scale = self.inputs_quantizer(inputs) + x = ops.einsum(self.equation, inputs, self.kernel) + # Deal with `inputs_scale` + inputs_scale = ops.transpose(inputs_scale, self._input_transpose_axes) + if self._input_expand_axes: + inputs_scale = ops.expand_dims( + inputs_scale, axis=self._input_expand_axes + ) + if self._input_squeeze_axes: + inputs_scale = ops.squeeze( + inputs_scale, axis=self._input_squeeze_axes + ) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale)) + if self.bias is not None: + x += self.bias + if self.activation is not None: + x = self.activation(x) + return x + def enable_lora( self, rank, a_initializer="he_uniform", b_initializer="zeros" ): @@ -253,9 +285,96 @@ def enable_lora( initializer=initializers.get(b_initializer), regularizer=self.kernel_regularizer, ) - self.kernel.trainable = False + self._kernel.trainable = False self._tracker.lock() self.lora_enabled = True + self.lora_rank = rank + + def quantize(self, mode): + self._check_quantize_args(mode, self.compute_dtype) + if mode == "int8": + if backend.standardize_dtype(self._kernel.dtype) == "int8": + raise ValueError("`quantize` can only be done once per layer.") + # Merge lora-related parameters to make use of fully int8 kernel + self._merge_lora_into_kernel() + + if self.input_spec is None: + raise ValueError( + f"Cannot quantize {self.name} that isn't yet built." 
+ ) + ( + self._input_reduced_axes, + self._kernel_reduced_axes, + self._input_transpose_axes, + self._kernel_transpose_axes, + self._input_expand_axes, + self._kernel_expand_axes, + self._input_squeeze_axes, + self._kernel_squeeze_axes, + ) = _analyze_quantization_info(self.equation, self.input_spec.ndim) + # Configure `self.inputs_quantizer` + self.inputs_quantizer = quantizers.AbsMaxQuantizer( + axis=self._input_reduced_axes + ) + # Quantize `self._kernel` to int8 and compute corresponding scale + kernel_value, kernel_scale = quantizers.abs_max_quantize( + self._kernel, axis=self._kernel_reduced_axes + ) + kernel_scale = ops.cast(kernel_scale, self.compute_dtype) + kernel_scale = ops.transpose( + kernel_scale, self._kernel_transpose_axes + ) + if self._kernel_expand_axes: + kernel_scale = ops.expand_dims( + kernel_scale, axis=self._kernel_expand_axes + ) + if self._kernel_squeeze_axes: + kernel_scale = ops.squeeze( + kernel_scale, axis=self._kernel_squeeze_axes + ) + self._tracker.unlock() + self._untrack_variable(self._kernel) + self._kernel = self.add_weight( + name="kernel", + shape=self._kernel.shape, + # Prevent adding a large constant to the computation graph + initializer=lambda shape, dtype: kernel_value, + dtype="int8", + trainable=False, + ) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale.shape, + # Prevent adding a large constant to the computation graph + initializer=lambda shape, dtype: kernel_scale, + dtype=self.compute_dtype, + trainable=False, + ) + if self.bias is not None: + self.bias.trainable = False + self._tracker.lock() + else: + NotImplementedError() + + # Set new dtype policy + if not isinstance( + self.dtype_policy, dtype_policies.QuantizedDTypePolicy + ): + quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" + self.dtype_policy = dtype_policies.get(quantized_dtype) + + def _merge_lora_into_kernel(self, untrack=False): + if not self.lora_enabled: + return + # Merge lora-enabled kernel into kernel + self._kernel.assign(self.kernel) + self.lora_enabled = False + if untrack: + self._tracker.unlock() + self.lora_kernel_a = self._untrack_variable(self.lora_kernel_a) + self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) + self._tracker.lock() + self.lora_rank = None def save_own_variables(self, store): if not self.lora_enabled: @@ -423,3 +542,130 @@ def _analyze_split_string( bias_shape = None return weight_shape, bias_shape, output_shape + + +def _analyze_quantization_info(equation, input_shape): + + def get_specs(equation, input_shape): + possible_labels = string.ascii_letters + dot_replaced_string = re.sub(r"\.\.\.", "0", equation) + + # This is the case where no ellipses are present in the string. + split_string = re.match( + "([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + return input_spec, weight_spec, output_spec + + # This is the case where ellipses are present on the left. 
+ split_string = re.match( + "0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + elided = len(input_shape) - len(input_spec) + possible_labels = sorted( + set(possible_labels) + - set(input_spec) + - set(weight_spec) + - set(output_spec) + ) + # Pad labels on the left to `input_spec` and `output_spec` + for i in range(elided): + input_spec = possible_labels[i] + input_spec + output_spec = possible_labels[i] + output_spec + return input_spec, weight_spec, output_spec + + # This is the case where ellipses are present on the right. + split_string = re.match( + "([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + elided = len(input_shape) - len(input_spec) + possible_labels = sorted( + set(possible_labels) + - set(input_spec) + - set(weight_spec) + - set(output_spec) + ) + # Pad labels on the right to `input_spec` and `output_spec` + for i in range(elided): + input_spec = input_spec + possible_labels[i] + output_spec = output_spec + possible_labels[i] + return input_spec, weight_spec, output_spec + + raise ValueError( + f"Invalid einsum equation '{equation}'. Equations must be in the " + "form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." + ) + + input_spec, weight_spec, output_spec = get_specs(equation, input_shape) + + # Determine the axes that should be reduced by the quantizer + input_reduced_axes = [] + weight_reduced_axes = [] + for i, label in enumerate(input_spec): + index = output_spec.find(label) + if index == -1: + input_reduced_axes.append(i) + for i, label in enumerate(weight_spec): + index = output_spec.find(label) + if index == -1: + weight_reduced_axes.append(i) + + # Determine the axes of `ops.expand_dims` + input_expand_axes = [] + weight_expand_axes = [] + for i, label in enumerate(output_spec): + index_input = input_spec.find(label) + index_weight = weight_spec.find(label) + if index_input == -1: + input_expand_axes.append(i) + if index_weight == -1: + weight_expand_axes.append(i) + + # Determine the axes of `ops.transpose` + input_transpose_axes = [] + weight_transpose_axes = [] + for i, label in enumerate(output_spec): + index_input = input_spec.find(label) + index_weight = weight_spec.find(label) + if index_input != -1: + input_transpose_axes.append(index_input) + if index_weight != -1: + weight_transpose_axes.append(index_weight) + # Postprocess the information: + # 1. Add dummy axes (1) to transpose_axes + # 2. Add axis to squeeze_axes if 1. 
failed + input_squeeze_axes = [] + weight_squeeze_axes = [] + for ori_index in input_reduced_axes: + try: + index = input_expand_axes.pop(0) + except IndexError: + input_squeeze_axes.append(ori_index) + input_transpose_axes.insert(index, ori_index) + for ori_index in weight_reduced_axes: + try: + index = weight_expand_axes.pop(0) + except IndexError: + weight_squeeze_axes.append(ori_index) + weight_transpose_axes.insert(index, ori_index) + return ( + input_reduced_axes, + weight_reduced_axes, + input_transpose_axes, + weight_transpose_axes, + input_expand_axes, + weight_expand_axes, + input_squeeze_axes, + weight_squeeze_axes, + ) diff --git a/keras/layers/core/einsum_dense_test.py b/keras/layers/core/einsum_dense_test.py index b064f4fb02a..e7126ed3f3c 100644 --- a/keras/layers/core/einsum_dense_test.py +++ b/keras/layers/core/einsum_dense_test.py @@ -326,7 +326,7 @@ def test_enable_lora(self): model.save(temp_filepath) new_model = saving.load_model(temp_filepath) - self.assertFalse(new_model.layers[0].lora_enabled) + self.assertTrue(new_model.layers[0].lora_enabled) self.assertAllClose(model.predict(x), new_model.predict(x)) # Try saving and reloading the model's weights only @@ -372,3 +372,86 @@ def test_lora_rank_argument(self): expected_num_losses=0, supports_masking=False, ) + + def test_quantize_int8(self): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + layer.quantize("int8") + + # Try eager call + x = np.random.random((2, 3)) + _ = layer(x) + + # Try saving and reloading the model + model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.weights.h5" + ) + model.save_weights(temp_filepath) + + # Try lora + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + layer.enable_lora(2) + layer.quantize("int8") + x = np.random.random((2, 3)) + _ = layer(x) + + @pytest.mark.requires_trainable_backend + def test_quantize_dtype_argument(self): + self.run_layer_test( + layers.EinsumDense, + init_kwargs={ + "equation": "ab,bcd->acd", + "output_shape": (8, 32), + "bias_axes": "d", + "dtype": "int8_from_mixed_bfloat16", + }, + input_shape=(2, 3), + expected_output_shape=(2, 8, 32), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=3, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_quantize_on_unbuilt_layer(self): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + with self.assertRaisesRegex( + ValueError, "Cannot quantize on a layer that isn't yet built." + ): + layer.quantize("int8") + + def test_quantize_when_already_quantized(self): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + layer.quantize("int8") + with self.assertRaisesRegex( + ValueError, "`quantize` can only be done once per layer." 
+ ): + layer.quantize("int8") diff --git a/keras/layers/layer.py b/keras/layers/layer.py index a93e304a8ee..b41ead720e1 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -245,7 +245,7 @@ def __init__( ): BackendLayer.__init__(self) self._lock = False - Operation.__init__(self, name=name) + Operation.__init__(self, dtype=dtype, name=name) self.activity_regularizer = regularizers.get(activity_regularizer) input_dim_arg = kwargs.pop("input_dim", None) if input_dim_arg is not None: @@ -268,7 +268,6 @@ def __init__( ) self.built = False - self.dtype_policy = dtype_policies.get(dtype) self.autocast = autocast self._input_spec = None self._called = False @@ -862,6 +861,12 @@ def call(self, *args, **kwargs): "method implemented." ) + def quantized_call(self, *args, **kwargs): + raise NotImplementedError( + f"Layer {self.__class__.__name__} does not have a " + "`quantized_call()` method implemented." + ) + @traceback_utils.filter_traceback def stateless_call( self, @@ -951,7 +956,12 @@ def stateless_call( with backend.StatelessScope( state_mapping=mapping, collect_losses=return_losses ) as scope: - outputs = self.call(*args, **kwargs) + if isinstance( + self.dtype_policy, dtype_policies.QuantizedDTypePolicy + ): + outputs = self.quantized_call(*args, **kwargs) + else: + outputs = self.call(*args, **kwargs) if return_losses: losses = self.losses @@ -1094,6 +1104,27 @@ def _clear_losses(self): for layer in self._layers: layer._clear_losses() + def quantize(self, mode): + raise NotImplementedError( + f"Layer {self.__class__.__name__} does not have a `quantize()` " + "method implemented." + ) + + def _check_quantize_args(self, mode, compute_dtype): + if not self.built: + raise ValueError("Cannot quantize on a layer that isn't yet built.") + if mode not in ("int8",): + raise ValueError( + f"`quantize` must be one of ('int8'). Received: mode={mode}" + ) + if mode == "int8" and compute_dtype == "float16": + raise ValueError( + f"mode='{mode}' doesn't work well with " + "compute_dtype='float16'. Consider loading model/layer with " + "other dtype policy such as 'mixed_bfloat16' before calling " + "`quantize`." + ) + def save_own_variables(self, store): """Saves the state of the layer. diff --git a/keras/models/model.py b/keras/models/model.py index 3132350d156..2adb88bd7c5 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -361,6 +361,42 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): self, filepath, skip_mismatch=skip_mismatch, **kwargs ) + def quantize(self, mode): + """Quantize the weights of the model. + + Note that the model must be built first before calling this method. + `quantize` will recursively call `quantize(mode)` in all layers and + will be skipped if the layer doesn't implement the function. + + Args: + mode: The mode of the quantization. Only 'int8' is supported at this + time. + """ + if not self.built: + raise ValueError( + "The model must be built first before calling `quantize()`." + ) + if mode not in ("int8",): + raise ValueError( + "Invalid quantization mode. Expected 'int8'. 
" + f"Received: mode={mode}" + ) + mode_changed = False + for layer in self._flatten_layers(): + list_of_sublayers = list(layer._flatten_layers()) + if len(list_of_sublayers) == 1: # leaves of the model + try: + layer.quantize(mode) + mode_changed = True + except NotImplementedError as e: + warnings.warn(str(e)) + # We need to set these functions to `None` to remake them for changed + # call function + if mode_changed: + self.train_function = None + self.test_function = None + self.predict_function = None + def build_from_config(self, config): if not config: return diff --git a/keras/models/model_test.py b/keras/models/model_test.py index 09f168b94b7..263a200ba28 100644 --- a/keras/models/model_test.py +++ b/keras/models/model_test.py @@ -2,6 +2,7 @@ import pytest from absl.testing import parameterized +from keras import backend from keras import layers from keras import testing from keras.layers.core.input_layer import Input @@ -561,3 +562,82 @@ def test_functional_list_outputs_invalid_nested_list_losses(self): "it should have as many entries as the model has outputs", ): model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_quantize(self): + model = _get_model() + x1 = np.random.rand(2, 3) + x2 = np.random.rand(2, 3) + model.quantize("int8") + _ = model((x1, x2)) + + for layer in model._flatten_layers(): + if isinstance(layer, (layers.Dense, layers.EinsumDense)): + self.assertEqual(layer.dtype_policy.name, "int8_from_float32") + self.assertEqual(layer.dtype_policy.quantization_mode, "int8") + + def test_quantize_unbuilt(self): + class MyModel(Model): + def __init__(self): + super().__init__() + self.dense1 = layers.Dense(32, activation="relu") + self.dense2 = layers.Dense(5, activation="softmax") + self.dropout = layers.Dropout(0.5) + + def call(self, inputs, training=False): + x = self.dense1(inputs) + x = self.dropout(x, training=training) + return self.dense2(x) + + model = MyModel() + with self.assertRaisesRegex( + ValueError, "The model must be built first before calling" + ): + model.quantize("int8") + + x = np.random.rand(2, 3) + _ = model(x) + model.quantize("int8") + + def test_quantize_invalid_args(self): + model = _get_model() + with self.assertRaisesRegex( + ValueError, "Invalid quantization mode. Expected 'int8'." 
+ ): + model.quantize("abc") + + def test_quantize_nested_model(self): + class NestedLayer(layers.Layer): + def __init__(self, units): + super().__init__() + self.dense = layers.Dense(units) + + def call(self, x): + x = self.dense(x) + return x + + class DoubleNestedLayer(layers.Layer): + def __init__(self, units): + super().__init__() + self.nested_dense1 = NestedLayer(units) + self.nested_dense2 = NestedLayer(units) + self.dense = layers.Dense(units) + + def call(self, x): + x = self.nested_dense1(x) + x = self.nested_dense2(x) + x = self.dense(x) + return x + + inputs = layers.Input([3]) + outputs = DoubleNestedLayer(8)(inputs) + model = Model(inputs, outputs) + model.quantize("int8") + + kernel_count = 0 + for weight in model.weights: + if weight.name == "kernel": + kernel_count += 1 + self.assertEqual( + backend.standardize_dtype(weight.dtype), "int8" + ) + self.assertEqual(kernel_count, 3) diff --git a/keras/ops/numpy_test.py b/keras/ops/numpy_test.py index 72279d9419e..b1f8fb59d30 100644 --- a/keras/ops/numpy_test.py +++ b/keras/ops/numpy_test.py @@ -2161,7 +2161,11 @@ def test_einsum(self): ) self.assertAllClose(knp.Einsum(",ijk")(5, y), np.einsum(",ijk", 5, y)) - def test_einsum_with_custom_ops(self): + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason=f"{backend.backend()} doesn't implement custom ops for einsum.", + ) + def test_einsum_custom_ops_for_tensorflow(self): subscripts = "a,b->ab" x = np.arange(2).reshape([2]).astype("float32") y = np.arange(3).reshape([3]).astype("float32") @@ -2169,6 +2173,13 @@ def test_einsum_with_custom_ops(self): knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) ) + subscripts = "ab,bc->ac" + x = np.arange(6).reshape([2, 3]).astype("float32") + y = np.arange(12).reshape([3, 4]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + subscripts = "abc,cd->abd" x = np.arange(24).reshape([2, 3, 4]).astype("float32") y = np.arange(20).reshape([4, 5]).astype("float32") @@ -5634,49 +5645,69 @@ def get_input_shapes(subscripts): expected_dtype, ) - # Test custom implementation of einsum for tensorflow - if backend.backend() == "tensorflow": - for subscripts in [ - "a,b->ab", - "abc,cd->abd", - "abc,cde->abde", - "abc,dce->abde", - "abcd,abed->abce", - "abcd,adbe->acbe", - "abcd,aecd->acbe", - "abcd,aecd->aceb", - "abcd,cde->abe", - "abcde,aebf->adbcf", - "abcde,afce->acdbf", - ]: - x1_shape, x2_shape = get_input_shapes(subscripts) - x1 = knp.ones(x1_shape, dtype=dtype1) - x2 = knp.ones(x2_shape, dtype=dtype2) - x1_jax = jnp.ones(x1_shape, dtype=dtype1) - x2_jax = jnp.ones(x2_shape, dtype=dtype2) - if dtype1 == "int8" and dtype2 == "int8": - preferred_element_type = "int32" - else: - preferred_element_type = None - expected_dtype = standardize_dtype( - jnp.einsum( - subscripts, - x1_jax, - x2_jax, - preferred_element_type=preferred_element_type, - ).dtype - ) + @parameterized.named_parameters( + named_product( + dtypes=list(itertools.combinations(ALL_DTYPES, 2)) + + [("int8", "int8")] + ) + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason=f"{backend.backend()} doesn't implement custom ops for einsum.", + ) + def test_einsum_custom_ops_for_tensorflow(self, dtypes): + import jax.numpy as jnp - self.assertEqual( - standardize_dtype(knp.einsum(subscripts, x1, x2).dtype), - expected_dtype, - ) - self.assertEqual( - standardize_dtype( - knp.Einsum(subscripts).symbolic_call(x1, x2).dtype - ), - expected_dtype, - ) + def get_input_shapes(subscripts): + x1_labels = 
subscripts.split(",")[0] + x2_labels = subscripts.split("->")[0][len(x1_labels) + 1 :] + x1_shape = [1] * len(x1_labels) + x2_shape = [1] * len(x2_labels) + return x1_shape, x2_shape + + dtype1, dtype2 = dtypes + for subscripts in [ + "a,b->ab", + "ab,bc->ac", + "abc,cd->abd", + "abc,cde->abde", + "abc,dce->abde", + "abcd,abed->abce", + "abcd,adbe->acbe", + "abcd,aecd->acbe", + "abcd,aecd->aceb", + "abcd,cde->abe", + "abcde,aebf->adbcf", + "abcde,afce->acdbf", + ]: + x1_shape, x2_shape = get_input_shapes(subscripts) + x1 = knp.ones(x1_shape, dtype=dtype1) + x2 = knp.ones(x2_shape, dtype=dtype2) + x1_jax = jnp.ones(x1_shape, dtype=dtype1) + x2_jax = jnp.ones(x2_shape, dtype=dtype2) + if dtype1 == "int8" and dtype2 == "int8": + preferred_element_type = "int32" + else: + preferred_element_type = None + expected_dtype = standardize_dtype( + jnp.einsum( + subscripts, + x1_jax, + x2_jax, + preferred_element_type=preferred_element_type, + ).dtype + ) + + self.assertEqual( + standardize_dtype(knp.einsum(subscripts, x1, x2).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.Einsum(subscripts).symbolic_call(x1, x2).dtype + ), + expected_dtype, + ) @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_empty(self, dtype): diff --git a/keras/ops/operation.py b/keras/ops/operation.py index 2932bc86254..bb5f26f669a 100644 --- a/keras/ops/operation.py +++ b/keras/ops/operation.py @@ -4,6 +4,7 @@ import tree from keras import backend +from keras import dtype_policies from keras.api_export import keras_export from keras.backend.common.keras_tensor import any_symbolic_tensors from keras.ops.node import Node @@ -14,7 +15,7 @@ @keras_export("keras.Operation") class Operation: - def __init__(self, name=None): + def __init__(self, dtype=None, name=None): if name is None: name = auto_name(self.__class__.__name__) if not isinstance(name, str) or "/" in name: @@ -23,6 +24,7 @@ def __init__(self, name=None): "cannot contain character `/`. " f"Received: name={name} (of type {type(name)})" ) + self.dtype_policy = dtype_policies.get(dtype) self.name = name self._inbound_nodes = [] self._outbound_nodes = [] @@ -34,7 +36,12 @@ def __call__(self, *args, **kwargs): if any_symbolic_tensors(args, kwargs): call_fn = self.symbolic_call else: - call_fn = self.call + if isinstance( + self.dtype_policy, dtype_policies.QuantizedDTypePolicy + ): + call_fn = self.quantized_call + else: + call_fn = self.call call_fn = traceback_utils.inject_argument_info_in_traceback( call_fn, object_name=(f"{self.__class__.__name__}.call()"), @@ -44,7 +51,10 @@ def __call__(self, *args, **kwargs): # Plain flow. if any_symbolic_tensors(args, kwargs): return self.symbolic_call(*args, **kwargs) - return self.call(*args, **kwargs) + if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): + return self.quantized_call(*args, **kwargs) + else: + return self.call(*args, **kwargs) def symbolic_call(self, *args, **kwargs): # Perform shape/dtype inference. 
@@ -63,6 +73,9 @@ def symbolic_call(self, *args, **kwargs): def call(self, *args, **kwargs): raise NotImplementedError + def quantized_call(self, *args, **kwargs): + raise NotImplementedError + def compute_output_spec(self, *args, **kwargs): try: return backend.compute_output_spec(self.call, *args, **kwargs) diff --git a/keras/quantizers/__init__.py b/keras/quantizers/__init__.py index e69de29bb2d..6139476a1fc 100644 --- a/keras/quantizers/__init__.py +++ b/keras/quantizers/__init__.py @@ -0,0 +1,51 @@ +import inspect + +from keras.api_export import keras_export +from keras.quantizers.quantizers import AbsMaxQuantizer +from keras.quantizers.quantizers import Quantizer +from keras.quantizers.quantizers import abs_max_quantize +from keras.saving import serialization_lib +from keras.utils.naming import to_snake_case + +ALL_OBJECTS = {Quantizer, AbsMaxQuantizer} +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} +ALL_OBJECTS_DICT.update( + {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} +) + + +@keras_export("keras.quantizers.serialize") +def serialize(initializer): + return serialization_lib.serialize_keras_object(initializer) + + +@keras_export("keras.quantizers.deserialize") +def deserialize(config, custom_objects=None): + """Return a Keras quantizer object via its config.""" + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.quantizers.get") +def get(identifier, **kwargs): + """Retrieve a Keras quantizer object via an identifier.""" + if identifier is None: + return None + if isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + obj = ALL_OBJECTS_DICT.get(identifier, None) + else: + obj = identifier + + if callable(obj): + if inspect.isclass(obj): + obj = obj(kwargs) + return obj + else: + raise ValueError( + f"Could not interpret quantizer identifier: {identifier}" + ) diff --git a/keras/quantizers/quantizers.py b/keras/quantizers/quantizers.py new file mode 100644 index 00000000000..e21f10bcc38 --- /dev/null +++ b/keras/quantizers/quantizers.py @@ -0,0 +1,102 @@ +from keras import backend +from keras import ops +from keras.api_export import keras_export + + +@keras_export(["keras.Quantizer", "keras.quantizers.Quantizer"]) +class Quantizer: + def __init__(self, output_dtype="int8"): + self.output_dtype = output_dtype + + def __call__(self, x): + """Compute a quantized output from an input tensor.""" + return x + + @classmethod + def from_config(cls, config): + """Creates a quantizer from its config. + + This method is the reverse of `get_config`, + capable of instantiating the same quantizer from the config + dictionary. + + This method is used by Keras `model_to_estimator`, saving and + loading models to HDF5 formats, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. + + Args: + config: A Python dictionary, typically the output of get_config. + + Returns: + A quantizer instance. + """ + return cls(**config) + + def get_config(self): + """Returns the config of the quantizer. + + An quantizer config is a Python dictionary (serializable) + containing all configuration parameters of the quantizer. + The same quantizer can be reinstantiated later + (without any saved state) from this configuration. + + This method is optional if you are just training and executing models, + exporting to and from SavedModels, or using weight checkpoints. 
+ + This method is required for Keras `model_to_estimator`, saving and + loading models to HDF5 formats, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. + + Returns: + Python dictionary. + """ + raise NotImplementedError(f"{self} does not implement get_config()") + + +@keras_export(["keras.quantizers.abs_max_quantize"]) +def abs_max_quantize( + inputs, + axis, + value_range=(-127, 127), + dtype="int8", + epsilon=backend.epsilon(), +): + scale = ops.divide( + value_range[1], + ops.add(ops.max(ops.abs(inputs), axis=axis, keepdims=True), epsilon), + ) + outputs = ops.multiply(inputs, scale) + outputs = ops.clip(ops.round(outputs), value_range[0], value_range[1]) + outputs = ops.cast(outputs, dtype) + return outputs, scale + + +@keras_export(["keras.AbsMaxQuantizer", "keras.quantizers.AbsMaxQuantizer"]) +class AbsMaxQuantizer(Quantizer): + def __init__( + self, + axis, + value_range=(-127, 127), + epsilon=backend.epsilon(), + output_dtype="int8", + ): + Quantizer.__init__(self, output_dtype=output_dtype) + if isinstance(axis, int): + axis = (axis,) + self.axis = tuple(axis) + self.value_range = value_range + self.epsilon = epsilon + + def __call__(self, x): + quantized_x, scale = abs_max_quantize( + x, self.axis, self.value_range, self.output_dtype, self.epsilon + ) + return quantized_x, scale + + def get_config(self): + return { + "axis": self.axis, + "value_range": self.value_range, + "epsilon": self.epsilon, + "output_dtype": self.output_dtype, + } diff --git a/keras/quantizers/quantizers_test.py b/keras/quantizers/quantizers_test.py new file mode 100644 index 00000000000..1483fb7cc9c --- /dev/null +++ b/keras/quantizers/quantizers_test.py @@ -0,0 +1,37 @@ +from keras import ops +from keras import quantizers +from keras import random +from keras import testing + + +class QuantizersTest(testing.TestCase): + def test_get_method(self): + quantizer = quantizers.get("abs_max_quantizer", axis=-1) + self.assertTrue(quantizer, quantizers.AbsMaxQuantizer) + + quantizer = quantizers.get(None) + self.assertEqual(quantizer, None) + + with self.assertRaises(ValueError): + quantizers.get("typo") + + def test_abs_max_quantizer(self): + values = random.uniform([3, 4, 5], minval=-1, maxval=1) + quantizer = quantizers.AbsMaxQuantizer(axis=-1) + + # Test quantizing + quantized_values, scale = quantizer(values) + self.assertEqual(tuple(quantized_values.shape), (3, 4, 5)) + self.assertEqual(tuple(scale.shape), (3, 4, 1)) + self.assertLessEqual(ops.max(quantized_values), 127) + self.assertGreaterEqual(ops.min(quantized_values), -127) + + # Test dequantizing + dequantized_values = ops.divide(quantized_values, scale) + rmse = ops.sqrt( + ops.mean(ops.square(ops.subtract(values, dequantized_values))) + ) + self.assertLess(rmse, 1e-1) # loose assertion + + # Test serialization + self.run_class_serialization_test(quantizer)
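
Reviewer note (not part of the patch): a minimal end-to-end sketch of the workflow this diff enables, pieced together from the tests above; the model architecture, layer sizes, and variable names are illustrative assumptions, not code taken from the PR.

import numpy as np

import keras
from keras import layers, quantizers

# Post-training quantization of an already-built model, as exercised in
# keras/models/model_test.py: Dense/EinsumDense kernels are rewritten to
# int8 weights plus per-output-channel scales; layers without a quantize()
# implementation only trigger a warning.
model = keras.Sequential([layers.Dense(16, activation="relu"), layers.Dense(4)])
model.build((None, 8))
model.quantize("int8")
y = model.predict(np.random.random((2, 8)))

# Per-layer route via the new quantized dtype policies: passing an
# "int8_from_*" policy makes build() call quantize() automatically
# (see Dense.build / EinsumDense.build in this diff).
layer = layers.Dense(16, dtype="int8_from_mixed_bfloat16")
layer.build((None, 8))

# The underlying quantizer is also exposed directly
# (keras/quantizers/quantizers.py): returns the int8 tensor and its scale.
quantizer = quantizers.AbsMaxQuantizer(axis=-1)
x_q, x_scale = quantizer(np.random.random((2, 8)).astype("float32"))

At inference time no user-facing change is needed: per the keras/ops/operation.py and keras/layers/layer.py hunks, __call__ and stateless_call dispatch to quantized_call() whenever the layer's dtype policy is a QuantizedDTypePolicy.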