From 9c5f3cc24ba79297bb5ecba7a1559c6074c7ee01 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:51:01 +0800 Subject: [PATCH 01/33] Add `quantize` to `Dense` Add `quantize` to `Layer` Add `quantizers` --- keras/layers/core/dense.py | 91 ++++++++++++++++++++- keras/layers/layer.py | 51 +++++++++++- keras/layers/preprocessing/index_lookup.py | 2 + keras/models/model.py | 31 ++++++++ keras/ops/operation.py | 15 +++- keras/quantizers/__init__.py | 51 ++++++++++++ keras/quantizers/quantizers.py | 93 ++++++++++++++++++++++ 7 files changed, 329 insertions(+), 5 deletions(-) create mode 100644 keras/quantizers/quantizers.py diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 487700583b7..76733015b85 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -4,6 +4,7 @@ from keras import constraints from keras import initializers from keras import ops +from keras import quantizers from keras import regularizers from keras.api_export import keras_export from keras.layers.input_spec import InputSpec @@ -120,6 +121,10 @@ def build(self, input_shape): self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) + if self.quantization_mode: + self.quantize( + self.quantization_mode, self.quantization_trainable, input_shape + ) @property def kernel(self): @@ -141,6 +146,26 @@ def call(self, inputs): x = self.activation(x) return x + def dynamic_int8_call(self, inputs): + inputs, inputs_scale = self.inputs_quantizer(inputs) + if self.kernel_quantizer is not None: + kernel, kernel_scale = self.kernel_quantizer(self._kernel) + else: + kernel = self.kernel + kernel_scale = self.kernel_scale + x = ops.matmul(inputs, kernel) + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) + + # We need to explicitly add lora_kernels instead of directly using + # self.kernel because we want `ops.matmul` to be computed in int8 format + if self.lora_enabled: + x = ops.add(x, ops.matmul(self.lora_kernel_a, self.lora_kernel_b)) + + if self.activation is not None: + x = self.activation(x) + return x + def compute_output_shape(self, input_shape): output_shape = list(input_shape) output_shape[-1] = self.units @@ -177,10 +202,74 @@ def enable_lora( initializer=initializers.get(b_initializer), regularizer=self.kernel_regularizer, ) - self.kernel.trainable = False + self._kernel.trainable = False self._tracker.lock() self.lora_enabled = True + def quantize(self, mode, trainable=False, input_shape=None): + self._check_quantize_args(mode, trainable) + if input_shape is None: + if self._build_shapes_dict is None: + raise ValueError( + "If no `input_shape` is provided, you must first build the" + "layer before applying the quantization." 
+ ) + input_shape = list(tuple(self._build_shapes_dict.values())[0]) + + if mode == "dynamic_int8": + inputs_quantizer_axes = list(range(len(input_shape))) + if len(inputs_quantizer_axes) > 2: + inputs_quantizer_axes.pop(-2) + self.inputs_quantizer = quantizers.AbsMaxQuantizer( + axis=inputs_quantizer_axes + ) + if trainable is False: + # Merge lora-related parameters to make use of fully int8 kernel + self._merge_lora_into_kernel() + kernel_value, kernel_scale = quantizers.abs_max_quantize( + self._kernel, axis=0 + ) + + self._tracker.unlock() + self._untrack_variable(self._kernel) + self._kernel = self.add_weight( + name="kernel", + shape=self._kernel.shape, + initializer=initializers.Constant(kernel_value), + dtype="int8", + trainable=False, + ) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale.shape, + initializer=initializers.Constant(kernel_scale), + dtype=self.compute_dtype, + trainable=False, + ) + self._tracker.lock() + self.kernel_quantizer = None + else: + self.kernel_scale = None + self.kernel_quantizer = quantizers.AbsMaxQuantizer(axis=0) + else: + NotImplementedError() + self._quantization_mode = mode + self._quantization_trainable = trainable + + def _merge_lora_into_kernel(self): + if not self.lora_enabled: + return + + # Merge lora-enabled kernel into kernel + self._kernel.assign(self.kernel) + + # Untrack lora parameters + self._tracker.unlock() + self.lora_kernel_a = self._untrack_variable(self.lora_kernel_a) + self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) + self._tracker.lock() + self.lora_enabled = False + def save_own_variables(self, store): if not self.lora_enabled: return super().save_own_variables(store) diff --git a/keras/layers/layer.py b/keras/layers/layer.py index a93e304a8ee..f86b927713b 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -261,6 +261,14 @@ def __init__( stacklevel=2, ) self._input_shape_arg = input_shape_arg + # Quantization parameters + self._quantization_mode = kwargs.pop("quantization_mode", None) + self._quantization_trainable = kwargs.pop( + "quantization_trainable", False + ) + self._check_quantize_args( + self.quantization_mode, self.quantization_trainable + ) if kwargs: raise ValueError( "Unrecognized keyword arguments " @@ -625,6 +633,14 @@ def non_trainable_weights(self): return self.weights return [v for v in self.weights if not v.trainable] + @property + def quantization_mode(self): + return self._quantization_mode + + @property + def quantization_trainable(self): + return self._quantization_trainable + @property def metrics_variables(self): """List of all metric variables.""" @@ -789,8 +805,12 @@ def maybe_convert(x): ) kwargs[expected_mask_arg_name] = mask + ############################## + # 7. Populate quantization argument(s) + kwargs["quantization_mode"] = self.quantization_mode + #################### - # 7. Call the layer. + # 8. Call the layer. try: with backend.name_scope(self.name, caller=self): current_scope = backend.get_autocast_scope() @@ -862,6 +882,9 @@ def call(self, *args, **kwargs): "method implemented." 
) + def dynamic_int8_call(self, *args, **kwargs): + return self.call(*args, **kwargs) + @traceback_utils.filter_traceback def stateless_call( self, @@ -951,7 +974,10 @@ def stateless_call( with backend.StatelessScope( state_mapping=mapping, collect_losses=return_losses ) as scope: - outputs = self.call(*args, **kwargs) + if self.quantization_mode == "dynamic_int8": + outputs = self.dynamic_int8_call(*args, **kwargs) + else: + outputs = self.call(*args, **kwargs) if return_losses: losses = self.losses @@ -1094,6 +1120,25 @@ def _clear_losses(self): for layer in self._layers: layer._clear_losses() + def quantize(self, mode, trainable=False, input_shape=None): + self._check_quantize_args(mode, trainable) + warnings.warn( + "`quantize` is not implemented for class " + f"'{self.__class__.__name__}' so the quantization is skipped." + ) + + def _check_quantize_args(self, mode, trainable): + if mode not in (None, "dynamic_int8"): + raise ValueError( + "Currently, `quantize` must be one of " + f"(`None`, 'dynamic_int8'). Received: mode={mode}" + ) + if not isinstance(trainable, bool): + raise TypeError( + "`trainable` must be boolean. " + f"Received: trainable={trainable} of type '{type(trainable)}'" + ) + def save_own_variables(self, store): """Saves the state of the layer. @@ -1362,6 +1407,8 @@ def get_config(self): config = { "trainable": self.trainable, "dtype": self.dtype_policy.name, + "quantization_mode": self.quantization_mode, + "quantization_trainable": self.quantization_trainable, } return {**base_config, **config} diff --git a/keras/layers/preprocessing/index_lookup.py b/keras/layers/preprocessing/index_lookup.py index a99651f62ea..7484c29ead0 100644 --- a/keras/layers/preprocessing/index_lookup.py +++ b/keras/layers/preprocessing/index_lookup.py @@ -188,6 +188,8 @@ def __init__( ) kwargs.pop("trainable", None) kwargs.pop("dtype", None) + kwargs.pop("quantization_mode", None) + kwargs.pop("quantization_trainable", None) if kwargs: raise ValueError(f"Unrecognized keyword argument(s): {kwargs}") diff --git a/keras/models/model.py b/keras/models/model.py index 3132350d156..c94c2f050b9 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -361,6 +361,37 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): self, filepath, skip_mismatch=skip_mismatch, **kwargs ) + def quantize(self, mode, trainable): + """Quantize the weights of the model. + + Note that the model must be built first before calling this method. + Quantization will be skipped if the layer doesn't implement `quantize` + function. + + Args: + mode: The mode of the quantization. The supported modes are + `"dynamic_int8"`. + trainable: Boolean, whether to enable training after the + quantization. This is useful for finetuning lora weights of the + model. + """ + if not self.built: + raise ValueError( + "The model must be built first before calling `quantize`." 
+ ) + mode_changed = False + for layer in self.layers: + original_mode = layer.quantization_mode + layer.quantize(mode, trainable) + if layer.quantization_mode != original_mode: + mode_changed = True + # We need to set these functions to `None` to remake them for changed + # call function + if mode_changed: + self.train_function = None + self.test_function = None + self.predict_function = None + def build_from_config(self, config): if not config: return diff --git a/keras/ops/operation.py b/keras/ops/operation.py index 2932bc86254..f26f9e6a699 100644 --- a/keras/ops/operation.py +++ b/keras/ops/operation.py @@ -29,12 +29,16 @@ def __init__(self, name=None): @traceback_utils.filter_traceback def __call__(self, *args, **kwargs): + _quantization_mode = kwargs.pop("quantization_mode", None) if traceback_utils.is_traceback_filtering_enabled(): # Wrap self.call to provide helpful info in case of exception if any_symbolic_tensors(args, kwargs): call_fn = self.symbolic_call else: - call_fn = self.call + if _quantization_mode == "dynamic_int8": + call_fn = self.dynamic_int8_call + else: + call_fn = self.call call_fn = traceback_utils.inject_argument_info_in_traceback( call_fn, object_name=(f"{self.__class__.__name__}.call()"), @@ -44,7 +48,10 @@ def __call__(self, *args, **kwargs): # Plain flow. if any_symbolic_tensors(args, kwargs): return self.symbolic_call(*args, **kwargs) - return self.call(*args, **kwargs) + if _quantization_mode == "dynamic_int8": + return self.dynamic_int8_call(*args, **kwargs) + else: + return self.call(*args, **kwargs) def symbolic_call(self, *args, **kwargs): # Perform shape/dtype inference. @@ -63,6 +70,10 @@ def symbolic_call(self, *args, **kwargs): def call(self, *args, **kwargs): raise NotImplementedError + def dynamic_int8_call(self, *args, **kwargs): + # Note that `dynamic_int8_call` defaults to `call` if not implemented. 
+ return self.call(*args, **kwargs) + def compute_output_spec(self, *args, **kwargs): try: return backend.compute_output_spec(self.call, *args, **kwargs) diff --git a/keras/quantizers/__init__.py b/keras/quantizers/__init__.py index e69de29bb2d..8da3a15cf3d 100644 --- a/keras/quantizers/__init__.py +++ b/keras/quantizers/__init__.py @@ -0,0 +1,51 @@ +import inspect + +from keras.api_export import keras_export +from keras.quantizers.quantizers import AbsMaxQuantizer +from keras.quantizers.quantizers import Quantizer +from keras.quantizers.quantizers import abs_max_quantize +from keras.saving import serialization_lib +from keras.utils.naming import to_snake_case + +ALL_OBJECTS = {Quantizer, AbsMaxQuantizer} +ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS} +ALL_OBJECTS_DICT.update( + {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS} +) + + +@keras_export("keras.quantizers.serialize") +def serialize(initializer): + return serialization_lib.serialize_keras_object(initializer) + + +@keras_export("keras.quantizers.deserialize") +def deserialize(config, custom_objects=None): + """Return a Keras quantizer object via its config.""" + return serialization_lib.deserialize_keras_object( + config, + module_objects=ALL_OBJECTS_DICT, + custom_objects=custom_objects, + ) + + +@keras_export("keras.quantizers.get") +def get(identifier): + """Retrieve a Keras quantizer object via an identifier.""" + if identifier is None: + return None + if isinstance(identifier, dict): + obj = deserialize(identifier) + elif isinstance(identifier, str): + obj = ALL_OBJECTS_DICT.get(identifier, None) + else: + obj = identifier + + if callable(obj): + if inspect.isclass(obj): + obj = obj() + return obj + else: + raise ValueError( + f"Could not interpret quantizer identifier: {identifier}" + ) diff --git a/keras/quantizers/quantizers.py b/keras/quantizers/quantizers.py new file mode 100644 index 00000000000..544bc6e143f --- /dev/null +++ b/keras/quantizers/quantizers.py @@ -0,0 +1,93 @@ +from keras import backend +from keras import ops +from keras.api_export import keras_export + + +@keras_export(["keras.quantizers.abs_max_quantize"]) +def abs_max_quantize( + inputs, + axis, + value_range=(-127, 127), + dtype="int8", + epsilon=backend.epsilon(), +): + scale = ops.divide( + value_range[1], + ops.max(ops.abs(inputs), axis=axis, keepdims=True) + epsilon, + ) + outputs = ops.multiply(inputs, scale) + outputs = ops.clip(ops.round(outputs), value_range[0], value_range[1]) + outputs = ops.cast(outputs, dtype) + return outputs, scale + + +@keras_export(["keras.Quantizer", "keras.quantizers.Quantizer"]) +class Quantizer: + def __init__(self, output_dtype="int8"): + self.output_dtype = output_dtype + + def __call__(self, x): + """Compute a quantized output from an input tensor.""" + return x + + @classmethod + def from_config(cls, config): + """Creates a quantizer from its config. + + This method is the reverse of `get_config`, + capable of instantiating the same quantizer from the config + dictionary. + + This method is used by Keras `model_to_estimator`, saving and + loading models to HDF5 formats, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. + + Args: + config: A Python dictionary, typically the output of get_config. + + Returns: + A quantizer instance. + """ + return cls(**config) + + def get_config(self): + """Returns the config of the quantizer. + + An quantizer config is a Python dictionary (serializable) + containing all configuration parameters of the quantizer. 
+ The same quantizer can be reinstantiated later + (without any saved state) from this configuration. + + This method is optional if you are just training and executing models, + exporting to and from SavedModels, or using weight checkpoints. + + This method is required for Keras `model_to_estimator`, saving and + loading models to HDF5 formats, Keras model cloning, some visualization + utilities, and exporting models to and from JSON. + + Returns: + Python dictionary. + """ + raise NotImplementedError(f"{self} does not implement get_config()") + + +class AbsMaxQuantizer(Quantizer): + def __init__( + self, + axis, + value_range=(-127, 127), + epsilon=backend.epsilon(), + output_dtype="int8", + ): + Quantizer.__init__(self, output_dtype=output_dtype) + if isinstance(axis, int): + axis = (axis,) + self.axis = tuple(axis) + self.value_range = value_range + self.epsilon = epsilon + + def __call__(self, x): + quantized_x, scale = abs_max_quantize( + x, self.axis, self.value_range, self.output_dtype, self.epsilon + ) + return quantized_x, scale From ef48715e157265c1650c0f911d56c3c6581b0e15 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 26 Feb 2024 18:21:08 +0800 Subject: [PATCH 02/33] Fix `dynamic_int8_call` --- keras/layers/core/dense.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 76733015b85..79bf18a5b9c 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -141,12 +141,13 @@ def kernel(self): def call(self, inputs): x = ops.matmul(inputs, self.kernel) if self.bias is not None: - x = x + self.bias + x = ops.add(x, self.bias) if self.activation is not None: x = self.activation(x) return x def dynamic_int8_call(self, inputs): + ori_inputs = inputs inputs, inputs_scale = self.inputs_quantizer(inputs) if self.kernel_quantizer is not None: kernel, kernel_scale = self.kernel_quantizer(self._kernel) @@ -160,8 +161,12 @@ def dynamic_int8_call(self, inputs): # We need to explicitly add lora_kernels instead of directly using # self.kernel because we want `ops.matmul` to be computed in int8 format if self.lora_enabled: - x = ops.add(x, ops.matmul(self.lora_kernel_a, self.lora_kernel_b)) - + lora_x = ops.matmul( + ori_inputs, ops.matmul(self.lora_kernel_a, self.lora_kernel_b) + ) + x = ops.add(x, lora_x) + if self.bias is not None: + x = ops.add(x, self.bias) if self.activation is not None: x = self.activation(x) return x From 3456c66de3bf8e6b4e3fad9e3f7e5bc0a8a7ea17 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 26 Feb 2024 22:26:48 +0800 Subject: [PATCH 03/33] Cleanup for unused `quantization_trainable` --- keras/layers/core/dense.py | 80 +++++++++------------- keras/layers/layer.py | 23 ++----- keras/layers/preprocessing/index_lookup.py | 1 - keras/models/model.py | 7 +- 4 files changed, 37 insertions(+), 74 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 79bf18a5b9c..304977a9e43 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -122,9 +122,7 @@ def build(self, input_shape): if self.lora_rank: self.enable_lora(self.lora_rank) if self.quantization_mode: - self.quantize( - self.quantization_mode, self.quantization_trainable, input_shape - ) + self.quantize(self.quantization_mode, input_shape) @property def kernel(self): @@ -147,24 +145,13 @@ def call(self, inputs): return x def dynamic_int8_call(self, inputs): - 
ori_inputs = inputs + if self.lora_enabled: + raise ValueError("`dynamic_int8_call` doesn't support lora weights") + inputs, inputs_scale = self.inputs_quantizer(inputs) - if self.kernel_quantizer is not None: - kernel, kernel_scale = self.kernel_quantizer(self._kernel) - else: - kernel = self.kernel - kernel_scale = self.kernel_scale - x = ops.matmul(inputs, kernel) + x = ops.matmul(inputs, self.kernel) x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale)) - - # We need to explicitly add lora_kernels instead of directly using - # self.kernel because we want `ops.matmul` to be computed in int8 format - if self.lora_enabled: - lora_x = ops.matmul( - ori_inputs, ops.matmul(self.lora_kernel_a, self.lora_kernel_b) - ) - x = ops.add(x, lora_x) + x = ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale)) if self.bias is not None: x = ops.add(x, self.bias) if self.activation is not None: @@ -211,8 +198,8 @@ def enable_lora( self._tracker.lock() self.lora_enabled = True - def quantize(self, mode, trainable=False, input_shape=None): - self._check_quantize_args(mode, trainable) + def quantize(self, mode, input_shape=None): + self._check_quantize_args(mode) if input_shape is None: if self._build_shapes_dict is None: raise ValueError( @@ -228,38 +215,33 @@ def quantize(self, mode, trainable=False, input_shape=None): self.inputs_quantizer = quantizers.AbsMaxQuantizer( axis=inputs_quantizer_axes ) - if trainable is False: - # Merge lora-related parameters to make use of fully int8 kernel - self._merge_lora_into_kernel() - kernel_value, kernel_scale = quantizers.abs_max_quantize( - self._kernel, axis=0 - ) + # Merge lora-related parameters to make use of fully int8 kernel + self._merge_lora_into_kernel() + kernel_value, kernel_scale = quantizers.abs_max_quantize( + self._kernel, axis=0 + ) - self._tracker.unlock() - self._untrack_variable(self._kernel) - self._kernel = self.add_weight( - name="kernel", - shape=self._kernel.shape, - initializer=initializers.Constant(kernel_value), - dtype="int8", - trainable=False, - ) - self.kernel_scale = self.add_weight( - name="kernel_scale", - shape=kernel_scale.shape, - initializer=initializers.Constant(kernel_scale), - dtype=self.compute_dtype, - trainable=False, - ) - self._tracker.lock() - self.kernel_quantizer = None - else: - self.kernel_scale = None - self.kernel_quantizer = quantizers.AbsMaxQuantizer(axis=0) + self._tracker.unlock() + self._untrack_variable(self._kernel) + self._kernel = self.add_weight( + name="kernel", + shape=self._kernel.shape, + initializer=initializers.Constant(kernel_value), + dtype="int8", + trainable=False, + ) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale.shape, + initializer=initializers.Constant(kernel_scale), + dtype=self.compute_dtype, + trainable=False, + ) + self._tracker.lock() + self.kernel_quantizer = None else: NotImplementedError() self._quantization_mode = mode - self._quantization_trainable = trainable def _merge_lora_into_kernel(self): if not self.lora_enabled: diff --git a/keras/layers/layer.py b/keras/layers/layer.py index f86b927713b..2840951b08e 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -263,12 +263,7 @@ def __init__( self._input_shape_arg = input_shape_arg # Quantization parameters self._quantization_mode = kwargs.pop("quantization_mode", None) - self._quantization_trainable = kwargs.pop( - "quantization_trainable", False - ) - self._check_quantize_args( - self.quantization_mode, 
self.quantization_trainable - ) + self._check_quantize_args(self.quantization_mode) if kwargs: raise ValueError( "Unrecognized keyword arguments " @@ -637,10 +632,6 @@ def non_trainable_weights(self): def quantization_mode(self): return self._quantization_mode - @property - def quantization_trainable(self): - return self._quantization_trainable - @property def metrics_variables(self): """List of all metric variables.""" @@ -1120,24 +1111,19 @@ def _clear_losses(self): for layer in self._layers: layer._clear_losses() - def quantize(self, mode, trainable=False, input_shape=None): - self._check_quantize_args(mode, trainable) + def quantize(self, mode, input_shape=None): + self._check_quantize_args(mode) warnings.warn( "`quantize` is not implemented for class " f"'{self.__class__.__name__}' so the quantization is skipped." ) - def _check_quantize_args(self, mode, trainable): + def _check_quantize_args(self, mode): if mode not in (None, "dynamic_int8"): raise ValueError( "Currently, `quantize` must be one of " f"(`None`, 'dynamic_int8'). Received: mode={mode}" ) - if not isinstance(trainable, bool): - raise TypeError( - "`trainable` must be boolean. " - f"Received: trainable={trainable} of type '{type(trainable)}'" - ) def save_own_variables(self, store): """Saves the state of the layer. @@ -1408,7 +1394,6 @@ def get_config(self): "trainable": self.trainable, "dtype": self.dtype_policy.name, "quantization_mode": self.quantization_mode, - "quantization_trainable": self.quantization_trainable, } return {**base_config, **config} diff --git a/keras/layers/preprocessing/index_lookup.py b/keras/layers/preprocessing/index_lookup.py index 7484c29ead0..4f422d80a4e 100644 --- a/keras/layers/preprocessing/index_lookup.py +++ b/keras/layers/preprocessing/index_lookup.py @@ -189,7 +189,6 @@ def __init__( kwargs.pop("trainable", None) kwargs.pop("dtype", None) kwargs.pop("quantization_mode", None) - kwargs.pop("quantization_trainable", None) if kwargs: raise ValueError(f"Unrecognized keyword argument(s): {kwargs}") diff --git a/keras/models/model.py b/keras/models/model.py index c94c2f050b9..48dd91b48ba 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -361,7 +361,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): self, filepath, skip_mismatch=skip_mismatch, **kwargs ) - def quantize(self, mode, trainable): + def quantize(self, mode): """Quantize the weights of the model. Note that the model must be built first before calling this method. @@ -371,9 +371,6 @@ def quantize(self, mode, trainable): Args: mode: The mode of the quantization. The supported modes are `"dynamic_int8"`. - trainable: Boolean, whether to enable training after the - quantization. This is useful for finetuning lora weights of the - model. 
""" if not self.built: raise ValueError( @@ -382,7 +379,7 @@ def quantize(self, mode, trainable): mode_changed = False for layer in self.layers: original_mode = layer.quantization_mode - layer.quantize(mode, trainable) + layer.quantize(mode) if layer.quantization_mode != original_mode: mode_changed = True # We need to set these functions to `None` to remake them for changed From 5d397692978fde7997c9b1b141f620b51c4a2cc1 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 26 Feb 2024 22:27:23 +0800 Subject: [PATCH 04/33] Add demo script --- check_dynamic_int8.py | 129 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 check_dynamic_int8.py diff --git a/check_dynamic_int8.py b/check_dynamic_int8.py new file mode 100644 index 00000000000..00acc0d3a9c --- /dev/null +++ b/check_dynamic_int8.py @@ -0,0 +1,129 @@ +import time + +import jax +import numpy as np +import tensorflow as tf + +import keras +from keras import backend +from keras import layers +from keras import models +from keras import ops +from keras import saving + +# Model / data parameters +num_classes = 10 +input_shape = (28, 28, 1) +epochs = 1 + +# Load the data and split it between train and test sets +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() +x_train = x_train.astype("float32") / 255 +x_test = x_test.astype("float32") / 255 +x_train = np.expand_dims(x_train, -1) +x_test = np.expand_dims(x_test, -1) +y_train = keras.utils.to_categorical(y_train, num_classes) +y_test = keras.utils.to_categorical(y_test, num_classes) + + +def build_model(num_layers=32, units=1024): + inputs = layers.Input([28, 28]) + x = layers.Flatten()(inputs) + for _ in range(num_layers): + x = layers.Dense(units)(x) + x = layers.BatchNormalization()(x) + x = layers.ReLU()(x) + outputs = layers.Dense(10, use_bias=True, activation="softmax")(x) + model = models.Model(inputs, outputs) + return model + + +def enable_lora(model): + for layer in model.layers: + if hasattr(layer, "enable_lora"): + layer.enable_lora(2) + + +def benchmark(model, batch_size=1024, input_shape=(28, 28), iterations=200): + def fn(x): + return model(x, training=False) + + if backend.backend() == "tensorflow": + jit_fn = tf.function(fn, jit_compile=True) + elif backend.backend() == "jax": + jit_fn = jax.jit(fn) + else: + jit_fn = fn + + # warmup + x = ops.ones([batch_size, *input_shape]) + for _ in range(10): + _ = ops.convert_to_numpy(jit_fn(x)) + + times = [] + for _ in range(iterations): + t0 = time.time() + _ = ops.convert_to_numpy(jit_fn(x)) + t1 = time.time() + times.append(t1 - t0) + avg_time = sum(times) / len(times) + return avg_time + + +ENABLE_LORA = False + +model = build_model(num_layers=32, units=1024) +model.compile( + loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] +) + +"""Train float model""" +print("Start training:") +model.fit(x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1) +print("Performance of float32:") +score = model.evaluate(x_test, y_test, verbose=0) +print(f"Test accuracy: {score[1]:.5f}") +avg_time = benchmark(model) +print(f"Avg. 
time (batch_size=1024): {avg_time:.5f}s") +model.save("model_fp32.keras") + +if ENABLE_LORA: + """Enable lora""" + enable_lora(model) + + """Fine-tuning lora weights with trainable quantization""" + model.compile( + loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] + ) + model.fit( + x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1 + ) + print("Performance of fine-tuning lora with trainable quantization:") + score = model.evaluate(x_test, y_test, verbose=0) + print(f"Test accuracy: {score[1]:.5f}") + avg_time = benchmark(model) + print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") + + """Quantize to int8 weights""" + model.quantize(mode="dynamic_int8") + int8_model = model + int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"]) + score = int8_model.evaluate(x_test, y_test, verbose=0) + print(f"Test accuracy: {score[1]:.5f}") + avg_time = benchmark(int8_model) + print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") +else: + """Quantization""" + model.quantize(mode="dynamic_int8") + int8_model = model + int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"]) + score = int8_model.evaluate(x_test, y_test, verbose=0) + print(f"Test accuracy: {score[1]:.5f}") + avg_time = benchmark(int8_model) + print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") + +"""Saving & loading""" +int8_model.save("model_int8.keras") +reloaded_int8_model = saving.load_model("model_int8.keras") +reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) +print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") From f549268029b30551bf33a7d043f655460a1da41d Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 26 Feb 2024 22:38:43 +0800 Subject: [PATCH 05/33] Update script --- check_dynamic_int8.py | 108 ++++++++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/check_dynamic_int8.py b/check_dynamic_int8.py index 00acc0d3a9c..f07abf4ef10 100644 --- a/check_dynamic_int8.py +++ b/check_dynamic_int8.py @@ -70,60 +70,74 @@ def fn(x): return avg_time -ENABLE_LORA = False - -model = build_model(num_layers=32, units=1024) -model.compile( - loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] -) - -"""Train float model""" -print("Start training:") -model.fit(x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1) -print("Performance of float32:") -score = model.evaluate(x_test, y_test, verbose=0) -print(f"Test accuracy: {score[1]:.5f}") -avg_time = benchmark(model) -print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") -model.save("model_fp32.keras") - -if ENABLE_LORA: - """Enable lora""" - enable_lora(model) - - """Fine-tuning lora weights with trainable quantization""" +for enable_rola in (True, False): + model = build_model(num_layers=32, units=1024) model.compile( loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] ) + + """Train float model""" + print("Start training float model:") model.fit( x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1 ) - print("Performance of fine-tuning lora with trainable quantization:") + print("Performance of float32:") score = model.evaluate(x_test, y_test, verbose=0) print(f"Test accuracy: {score[1]:.5f}") avg_time = benchmark(model) print(f"Avg. 
time (batch_size=1024): {avg_time:.5f}s") - - """Quantize to int8 weights""" - model.quantize(mode="dynamic_int8") - int8_model = model - int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"]) - score = int8_model.evaluate(x_test, y_test, verbose=0) - print(f"Test accuracy: {score[1]:.5f}") - avg_time = benchmark(int8_model) - print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") -else: - """Quantization""" - model.quantize(mode="dynamic_int8") - int8_model = model - int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"]) - score = int8_model.evaluate(x_test, y_test, verbose=0) - print(f"Test accuracy: {score[1]:.5f}") - avg_time = benchmark(int8_model) - print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") - -"""Saving & loading""" -int8_model.save("model_int8.keras") -reloaded_int8_model = saving.load_model("model_int8.keras") -reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) -print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") + model.save("model_fp32.keras") + + if enable_rola: + """Enable lora""" + print("Enable lora weights") + enable_lora(model) + + """Fine-tuning lora weights""" + model.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"], + ) + model.fit( + x_train, + y_train, + batch_size=128, + epochs=epochs, + validation_split=0.1, + ) + print("Performance of fine-tuned lora weights:") + score = model.evaluate(x_test, y_test, verbose=0) + print(f"Test accuracy: {score[1]:.5f}") + avg_time = benchmark(model) + print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") + + """Quantize to int8 weights""" + model.quantize(mode="dynamic_int8") + int8_model = model + int8_model.compile( + loss="categorical_crossentropy", metrics=["accuracy"] + ) + print("Performance of quantized model:") + score = int8_model.evaluate(x_test, y_test, verbose=0) + print(f"Test accuracy: {score[1]:.5f}") + avg_time = benchmark(int8_model) + print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") + else: + """Quantization""" + model.quantize(mode="dynamic_int8") + int8_model = model + int8_model.compile( + loss="categorical_crossentropy", metrics=["accuracy"] + ) + print("Performance of quantized model:") + score = int8_model.evaluate(x_test, y_test, verbose=0) + print(f"Test accuracy: {score[1]:.5f}") + avg_time = benchmark(int8_model) + print(f"Avg. 
time (batch_size=1024): {avg_time:.5f}s") + + """Saving & loading""" + int8_model.save("model_int8.keras") + reloaded_int8_model = saving.load_model("model_int8.keras") + reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) + print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") From 047d05644145da510a2951de9e651cc5a07f9a31 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 1 Mar 2024 17:28:45 +0800 Subject: [PATCH 06/33] Update dtype policy --- keras/layers/core/dense.py | 15 +++++++++------ keras/layers/layer.py | 29 +++++++++++++++-------------- keras/models/model.py | 6 +++--- keras/ops/operation.py | 14 +++++++------- keras/quantizers/quantizers.py | 2 +- 5 files changed, 35 insertions(+), 31 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 304977a9e43..72e9bb2ac54 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -2,6 +2,7 @@ from keras import activations from keras import constraints +from keras import dtype_policies from keras import initializers from keras import ops from keras import quantizers @@ -121,8 +122,8 @@ def build(self, input_shape): self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) - if self.quantization_mode: - self.quantize(self.quantization_mode, input_shape) + if self.dtype_policy.quantization_mode: + self.quantize(self.dtype_policy.quantization_mode, input_shape) @property def kernel(self): @@ -144,9 +145,9 @@ def call(self, inputs): x = self.activation(x) return x - def dynamic_int8_call(self, inputs): + def int8_call(self, inputs): if self.lora_enabled: - raise ValueError("`dynamic_int8_call` doesn't support lora weights") + raise ValueError("`int8_call` doesn't support lora weights") inputs, inputs_scale = self.inputs_quantizer(inputs) x = ops.matmul(inputs, self.kernel) @@ -208,7 +209,7 @@ def quantize(self, mode, input_shape=None): ) input_shape = list(tuple(self._build_shapes_dict.values())[0]) - if mode == "dynamic_int8": + if mode == "quantized_int8": inputs_quantizer_axes = list(range(len(input_shape))) if len(inputs_quantizer_axes) > 2: inputs_quantizer_axes.pop(-2) @@ -241,7 +242,9 @@ def quantize(self, mode, input_shape=None): self.kernel_quantizer = None else: NotImplementedError() - self._quantization_mode = mode + + quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" + self.dtype_policy = dtype_policies.get(quantized_dtype) def _merge_lora_into_kernel(self): if not self.lora_enabled: diff --git a/keras/layers/layer.py b/keras/layers/layer.py index 2840951b08e..7a88dde9170 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -261,9 +261,6 @@ def __init__( stacklevel=2, ) self._input_shape_arg = input_shape_arg - # Quantization parameters - self._quantization_mode = kwargs.pop("quantization_mode", None) - self._check_quantize_args(self.quantization_mode) if kwargs: raise ValueError( "Unrecognized keyword arguments " @@ -628,10 +625,6 @@ def non_trainable_weights(self): return self.weights return [v for v in self.weights if not v.trainable] - @property - def quantization_mode(self): - return self._quantization_mode - @property def metrics_variables(self): """List of all metric variables.""" @@ -798,7 +791,7 @@ def maybe_convert(x): ############################## # 7. Populate quantization argument(s) - kwargs["quantization_mode"] = self.quantization_mode + kwargs["is_quantized_int8"] = self.dtype_policy.is_quantized_int8 #################### # 8. Call the layer. 
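[Note on this commit: the int8 code path is now selected from the layer's dtype policy (e.g. "quantized_int8_from_float32") instead of a user-facing keyword argument; a later commit in this series moves the flag onto the operation itself. Below is a minimal, self-contained sketch of the dispatch pattern — illustrative only, using a simplified TinyOp stand-in rather than the real Operation/Layer classes.]

    class TinyOp:
        def __init__(self, is_quantized_int8=False):
            # In Keras this flag is derived from the dtype policy name,
            # e.g. "quantized_int8_from_float32".
            self.is_quantized_int8 = is_quantized_int8

        def call(self, x):
            return x  # float path

        def int8_call(self, x):
            # Falls back to the float path; layers such as Dense override it.
            return self.call(x)

        def __call__(self, x):
            fn = self.int8_call if self.is_quantized_int8 else self.call
            return fn(x)

    # Usage: TinyOp()(3) runs `call`, TinyOp(True)(3) runs `int8_call`.
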
@@ -873,7 +866,7 @@ def call(self, *args, **kwargs): "method implemented." ) - def dynamic_int8_call(self, *args, **kwargs): + def int8_call(self, *args, **kwargs): return self.call(*args, **kwargs) @traceback_utils.filter_traceback @@ -965,8 +958,8 @@ def stateless_call( with backend.StatelessScope( state_mapping=mapping, collect_losses=return_losses ) as scope: - if self.quantization_mode == "dynamic_int8": - outputs = self.dynamic_int8_call(*args, **kwargs) + if self.dtype_policy.is_quantized_int8: + outputs = self.int8_call(*args, **kwargs) else: outputs = self.call(*args, **kwargs) if return_losses: @@ -1119,10 +1112,19 @@ def quantize(self, mode, input_shape=None): ) def _check_quantize_args(self, mode): - if mode not in (None, "dynamic_int8"): + if mode not in (None, "quantized_int8"): raise ValueError( "Currently, `quantize` must be one of " - f"(`None`, 'dynamic_int8'). Received: mode={mode}" + f"(`None`, 'quantized_int8'). Received: mode={mode}" + ) + if ( + mode == "quantized_int8" + and self.dtype_policy.compute_dtype == "float16" + ): + raise ValueError( + f"mode='{mode}' doesn't work well with " + "compute_dtype='float16'. Consider loading model/layer with " + "other dtype policy before calling `quantize`." ) def save_own_variables(self, store): @@ -1393,7 +1395,6 @@ def get_config(self): config = { "trainable": self.trainable, "dtype": self.dtype_policy.name, - "quantization_mode": self.quantization_mode, } return {**base_config, **config} diff --git a/keras/models/model.py b/keras/models/model.py index 48dd91b48ba..921d5441bd6 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -370,7 +370,7 @@ def quantize(self, mode): Args: mode: The mode of the quantization. The supported modes are - `"dynamic_int8"`. + `"quantized_int8"`. """ if not self.built: raise ValueError( @@ -378,9 +378,9 @@ def quantize(self, mode): ) mode_changed = False for layer in self.layers: - original_mode = layer.quantization_mode + original_mode = layer.dtype_policy.quantization_mode layer.quantize(mode) - if layer.quantization_mode != original_mode: + if layer.dtype_policy.quantization_mode != original_mode: mode_changed = True # We need to set these functions to `None` to remake them for changed # call function diff --git a/keras/ops/operation.py b/keras/ops/operation.py index f26f9e6a699..e744a52e61f 100644 --- a/keras/ops/operation.py +++ b/keras/ops/operation.py @@ -29,14 +29,14 @@ def __init__(self, name=None): @traceback_utils.filter_traceback def __call__(self, *args, **kwargs): - _quantization_mode = kwargs.pop("quantization_mode", None) + is_quantized_int8 = kwargs.pop("is_quantized_int8", None) if traceback_utils.is_traceback_filtering_enabled(): # Wrap self.call to provide helpful info in case of exception if any_symbolic_tensors(args, kwargs): call_fn = self.symbolic_call else: - if _quantization_mode == "dynamic_int8": - call_fn = self.dynamic_int8_call + if is_quantized_int8: + call_fn = self.int8_call else: call_fn = self.call call_fn = traceback_utils.inject_argument_info_in_traceback( @@ -48,8 +48,8 @@ def __call__(self, *args, **kwargs): # Plain flow. 
if any_symbolic_tensors(args, kwargs): return self.symbolic_call(*args, **kwargs) - if _quantization_mode == "dynamic_int8": - return self.dynamic_int8_call(*args, **kwargs) + if is_quantized_int8: + return self.int8_call(*args, **kwargs) else: return self.call(*args, **kwargs) @@ -70,8 +70,8 @@ def symbolic_call(self, *args, **kwargs): def call(self, *args, **kwargs): raise NotImplementedError - def dynamic_int8_call(self, *args, **kwargs): - # Note that `dynamic_int8_call` defaults to `call` if not implemented. + def int8_call(self, *args, **kwargs): + # Note that `int8_call` defaults to `call` if not implemented. return self.call(*args, **kwargs) def compute_output_spec(self, *args, **kwargs): diff --git a/keras/quantizers/quantizers.py b/keras/quantizers/quantizers.py index 544bc6e143f..2051d61c2a2 100644 --- a/keras/quantizers/quantizers.py +++ b/keras/quantizers/quantizers.py @@ -13,7 +13,7 @@ def abs_max_quantize( ): scale = ops.divide( value_range[1], - ops.max(ops.abs(inputs), axis=axis, keepdims=True) + epsilon, + ops.add(ops.max(ops.abs(inputs), axis=axis, keepdims=True), epsilon), ) outputs = ops.multiply(inputs, scale) outputs = ops.clip(ops.round(outputs), value_range[0], value_range[1]) From a7209f421266193a02ee194a042ed627312133df Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 1 Mar 2024 17:29:14 +0800 Subject: [PATCH 07/33] Update dtype_policy --- keras/dtype_policies/dtype_policy.py | 34 ++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/keras/dtype_policies/dtype_policy.py b/keras/dtype_policies/dtype_policy.py index 5dad6d093f1..c77b7167ea5 100644 --- a/keras/dtype_policies/dtype_policy.py +++ b/keras/dtype_policies/dtype_policy.py @@ -62,6 +62,7 @@ def __init__(self, name): f"Received: name={name} (of type {type(name)})" ) self._name = name + self._quantization_mode = None self._compute_dtype, self._variable_dtype = self._parse_name(name) # TODO: check that the current hardware supports the provided # dtype policy and raise/warn otherwise. @@ -79,6 +80,19 @@ def _parse_name(self, name): return "float16", "float32" elif name == "mixed_bfloat16": return "bfloat16", "float32" + elif "quantized_int8" in name: + if "_from_" not in name: + raise ValueError( + f"Cannot convert '{name}' to a quantized DTypePolicy. " + "Valid policies are in the pattern of " + "'quantized_int8_from_(name)' such as " + "'quantized_int8_from_mixed_bfloat16'." + ) + # "quantized_int8_from_float32" indicates that the layer + # or model is quantized from float32 dtype policy. + ori_name = name.split("_from_")[-1] + self._quantization_mode = "quantized_int8" + return self._parse_name(ori_name) try: dtype = backend.standardize_dtype(name) return dtype, dtype @@ -127,6 +141,26 @@ def compute_dtype(self): """ return self._compute_dtype + @property + def quantization_mode(self): + """The quantization mode of this policy. + + Returns: + The quantization mode of this policy, as a string. `None` if no + quantization. + """ + return self._quantization_mode + + @property + def is_quantized_int8(self): + """Whether this policy is quantized to `'int8'`. + + Returns: + The boolean value indicating whether this policy is quantized to + `'int8'`. 
+ """ + return "quantized_int8" in self._name + @property def name(self): """Returns the name of this policy.""" From ee5da63624c26a89fb5f59224e077514c0e08e07 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 1 Mar 2024 17:29:29 +0800 Subject: [PATCH 08/33] Update demo script --- check_dynamic_int8.py | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/check_dynamic_int8.py b/check_dynamic_int8.py index f07abf4ef10..ed4a81ba61d 100644 --- a/check_dynamic_int8.py +++ b/check_dynamic_int8.py @@ -1,3 +1,4 @@ +import os import time import jax @@ -6,11 +7,17 @@ import keras from keras import backend +from keras import dtype_policies from keras import layers from keras import models from keras import ops from keras import saving +# Set dtype policy +dtype = "mixed_bfloat16" +dtype_policies.dtype_policy.set_dtype_policy(dtype) +print(f"Global dtype policy: {dtype_policies.dtype_policy.dtype_policy()}") + # Model / data parameters num_classes = 10 input_shape = (28, 28, 1) @@ -77,20 +84,20 @@ def fn(x): ) """Train float model""" - print("Start training float model:") + print("=====Start training float model=====") model.fit( x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1 ) - print("Performance of float32:") + print(f"Performance of {dtype}:") score = model.evaluate(x_test, y_test, verbose=0) - print(f"Test accuracy: {score[1]:.5f}") + print(f" Test accuracy: {score[1]:.5f}") avg_time = benchmark(model) - print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") + print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") model.save("model_fp32.keras") if enable_rola: """Enable lora""" - print("Enable lora weights") + print("=====Enable lora weights=====") enable_lora(model) """Fine-tuning lora weights""" @@ -108,36 +115,44 @@ def fn(x): ) print("Performance of fine-tuned lora weights:") score = model.evaluate(x_test, y_test, verbose=0) - print(f"Test accuracy: {score[1]:.5f}") + print(f" Test accuracy: {score[1]:.5f}") avg_time = benchmark(model) - print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") + print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") """Quantize to int8 weights""" - model.quantize(mode="dynamic_int8") + model.quantize(mode="quantized_int8") int8_model = model int8_model.compile( loss="categorical_crossentropy", metrics=["accuracy"] ) print("Performance of quantized model:") score = int8_model.evaluate(x_test, y_test, verbose=0) - print(f"Test accuracy: {score[1]:.5f}") + print(f" Test accuracy: {score[1]:.5f}") avg_time = benchmark(int8_model) - print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") + print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") else: + print("=====No lora weights=====") """Quantization""" - model.quantize(mode="dynamic_int8") + model.quantize(mode="quantized_int8") int8_model = model int8_model.compile( loss="categorical_crossentropy", metrics=["accuracy"] ) print("Performance of quantized model:") score = int8_model.evaluate(x_test, y_test, verbose=0) - print(f"Test accuracy: {score[1]:.5f}") + print(f" Test accuracy: {score[1]:.5f}") avg_time = benchmark(int8_model) - print(f"Avg. time (batch_size=1024): {avg_time:.5f}s") + print(f" Avg. 
inference time (batch_size=1024): {avg_time:.5f}s") """Saving & loading""" int8_model.save("model_int8.keras") reloaded_int8_model = saving.load_model("model_int8.keras") reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") + print("Size of saved model:") + print(f" fp32: {os.path.getsize('model_fp32.keras') >> 20}MB") + print(f" int8: {os.path.getsize('model_int8.keras') >> 20}MB") + +"""Cleanup""" +os.remove("model_fp32.keras") +os.remove("model_int8.keras") From 723e69a3181de0ce28c09e13bbcf8fd7ba965685 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sat, 2 Mar 2024 14:25:59 +0800 Subject: [PATCH 09/33] Update Dense --- keras/layers/core/dense.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 72e9bb2ac54..551fc52e124 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -201,21 +201,8 @@ def enable_lora( def quantize(self, mode, input_shape=None): self._check_quantize_args(mode) - if input_shape is None: - if self._build_shapes_dict is None: - raise ValueError( - "If no `input_shape` is provided, you must first build the" - "layer before applying the quantization." - ) - input_shape = list(tuple(self._build_shapes_dict.values())[0]) - if mode == "quantized_int8": - inputs_quantizer_axes = list(range(len(input_shape))) - if len(inputs_quantizer_axes) > 2: - inputs_quantizer_axes.pop(-2) - self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=inputs_quantizer_axes - ) + self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() kernel_value, kernel_scale = quantizers.abs_max_quantize( From b36560a0f570d4e8df11bebeeea69ac44356ae3c Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 3 Mar 2024 22:24:26 +0800 Subject: [PATCH 10/33] Remove unused `input_shape` --- keras/layers/core/dense.py | 4 ++-- keras/layers/layer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 551fc52e124..ef2c45ab7b4 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -123,7 +123,7 @@ def build(self, input_shape): if self.lora_rank: self.enable_lora(self.lora_rank) if self.dtype_policy.quantization_mode: - self.quantize(self.dtype_policy.quantization_mode, input_shape) + self.quantize(self.dtype_policy.quantization_mode) @property def kernel(self): @@ -199,7 +199,7 @@ def enable_lora( self._tracker.lock() self.lora_enabled = True - def quantize(self, mode, input_shape=None): + def quantize(self, mode): self._check_quantize_args(mode) if mode == "quantized_int8": self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) diff --git a/keras/layers/layer.py b/keras/layers/layer.py index 7a88dde9170..6d3656d1513 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -1104,7 +1104,7 @@ def _clear_losses(self): for layer in self._layers: layer._clear_losses() - def quantize(self, mode, input_shape=None): + def quantize(self, mode): self._check_quantize_args(mode) warnings.warn( "`quantize` is not implemented for class " From ad6dd6b64ad03e6b369cc68aad926130ad2edbcc Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 4 Mar 2024 22:10:12 +0800 Subject: 
[PATCH 11/33] Add `self.is_quantized_int8` to `Operation` and some minor updates --- keras/layers/core/dense.py | 9 +++++---- keras/layers/layer.py | 10 ++++------ keras/models/model.py | 6 +++--- keras/ops/operation.py | 15 ++++++++++++--- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index ef2c45ab7b4..91bc6dedaf0 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -148,9 +148,9 @@ def call(self, inputs): def int8_call(self, inputs): if self.lora_enabled: raise ValueError("`int8_call` doesn't support lora weights") - inputs, inputs_scale = self.inputs_quantizer(inputs) x = ops.matmul(inputs, self.kernel) + # De-scale outputs x = ops.cast(x, self.compute_dtype) x = ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale)) if self.bias is not None: @@ -202,13 +202,14 @@ def enable_lora( def quantize(self, mode): self._check_quantize_args(mode) if mode == "quantized_int8": - self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() + # Configure `self.inputs_quantizer` + self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + # Quantize `self._kernel` to int8 and compute corresponding scale kernel_value, kernel_scale = quantizers.abs_max_quantize( self._kernel, axis=0 ) - self._tracker.unlock() self._untrack_variable(self._kernel) self._kernel = self.add_weight( @@ -226,12 +227,12 @@ def quantize(self, mode): trainable=False, ) self._tracker.lock() - self.kernel_quantizer = None else: NotImplementedError() quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" self.dtype_policy = dtype_policies.get(quantized_dtype) + self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 def _merge_lora_into_kernel(self): if not self.lora_enabled: diff --git a/keras/layers/layer.py b/keras/layers/layer.py index 6d3656d1513..ce7d99f7d4f 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -290,6 +290,8 @@ def __init__( self._convert_input_args = True # Whether to allow non-tensors as positional arguments in `call()`. self._allow_non_tensor_positional_args = False + # Whether to set `is_quantized_int8` + self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 # Dict of shapes that were used to call `build()`. self._build_shapes_dict = None self._initializer_tracker() @@ -789,12 +791,8 @@ def maybe_convert(x): ) kwargs[expected_mask_arg_name] = mask - ############################## - # 7. Populate quantization argument(s) - kwargs["is_quantized_int8"] = self.dtype_policy.is_quantized_int8 - #################### - # 8. Call the layer. + # 7. Call the layer. try: with backend.name_scope(self.name, caller=self): current_scope = backend.get_autocast_scope() @@ -958,7 +956,7 @@ def stateless_call( with backend.StatelessScope( state_mapping=mapping, collect_losses=return_losses ) as scope: - if self.dtype_policy.is_quantized_int8: + if self.is_quantized_int8: outputs = self.int8_call(*args, **kwargs) else: outputs = self.call(*args, **kwargs) diff --git a/keras/models/model.py b/keras/models/model.py index 921d5441bd6..c460a40c768 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -365,8 +365,8 @@ def quantize(self, mode): """Quantize the weights of the model. Note that the model must be built first before calling this method. - Quantization will be skipped if the layer doesn't implement `quantize` - function. 
+ `quantize` will recursively call `quantize(mode)` in all layers and + will be skipped if the layer doesn't implement the function. Args: mode: The mode of the quantization. The supported modes are @@ -377,7 +377,7 @@ def quantize(self, mode): "The model must be built first before calling `quantize`." ) mode_changed = False - for layer in self.layers: + for layer in self._flatten_layers(include_self=False, recursive=True): original_mode = layer.dtype_policy.quantization_mode layer.quantize(mode) if layer.dtype_policy.quantization_mode != original_mode: diff --git a/keras/ops/operation.py b/keras/ops/operation.py index e744a52e61f..c04fa135523 100644 --- a/keras/ops/operation.py +++ b/keras/ops/operation.py @@ -26,16 +26,16 @@ def __init__(self, name=None): self.name = name self._inbound_nodes = [] self._outbound_nodes = [] + self._is_quantized_int8 = False @traceback_utils.filter_traceback def __call__(self, *args, **kwargs): - is_quantized_int8 = kwargs.pop("is_quantized_int8", None) if traceback_utils.is_traceback_filtering_enabled(): # Wrap self.call to provide helpful info in case of exception if any_symbolic_tensors(args, kwargs): call_fn = self.symbolic_call else: - if is_quantized_int8: + if self.is_quantized_int8: call_fn = self.int8_call else: call_fn = self.call @@ -48,7 +48,7 @@ def __call__(self, *args, **kwargs): # Plain flow. if any_symbolic_tensors(args, kwargs): return self.symbolic_call(*args, **kwargs) - if is_quantized_int8: + if self.is_quantized_int8: return self.int8_call(*args, **kwargs) else: return self.call(*args, **kwargs) @@ -238,6 +238,15 @@ def output(self): """ return self._get_node_attribute_at_index(0, "output_tensors", "output") + @property + def is_quantized_int8(self): + """Whether the operation is quantized to int8.""" + return self._is_quantized_int8 + + @is_quantized_int8.setter + def is_quantized_int8(self, value): + self._is_quantized_int8 = value + def _get_node_attribute_at_index(self, node_index, attr, attr_name): """Private utility to retrieves an attribute (e.g. inputs) from a node. 
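[Note before the EinsumDense commit below: a quick numeric sanity check of the arithmetic used by `Dense.int8_call`. This is a standalone sketch, not part of the patch; it re-implements the abs-max scheme in plain NumPy with an assumed epsilon, and shows why dividing the int8 matmul output by `inputs_scale * kernel_scale` approximately recovers the float result.]

    import numpy as np

    def abs_max_quantize(t, axis, eps=1e-7):
        # Mirrors keras.quantizers.abs_max_quantize: scale maps the abs-max to 127.
        scale = 127.0 / (np.max(np.abs(t), axis=axis, keepdims=True) + eps)
        q = np.clip(np.round(t * scale), -127, 127).astype("int8")
        return q, scale

    x = np.random.randn(4, 8).astype("float32")    # activations
    w = np.random.randn(8, 16).astype("float32")   # Dense kernel
    q_x, s_x = abs_max_quantize(x, axis=-1)        # per-row scale, shape (4, 1)
    q_w, s_w = abs_max_quantize(w, axis=0)         # per-column scale, shape (1, 16)

    # int8 matmul accumulated in int32, then de-scaled as in `Dense.int8_call`.
    y_int = q_x.astype("int32") @ q_w.astype("int32")
    y_approx = y_int.astype("float32") / (s_x * s_w)
    print(np.abs(y_approx - x @ w).max())          # small quantization error
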
From 1ebf41e6474e8799a1cfc7b27fbdd4f0f7688faa Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 4 Mar 2024 22:10:25 +0800 Subject: [PATCH 12/33] Add `quantize` to `EinsumDense` --- keras/layers/core/einsum_dense.py | 239 +++++++++++++++++++++++++++++- 1 file changed, 237 insertions(+), 2 deletions(-) diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index a950ac4992e..17dfc1eda9d 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -1,13 +1,17 @@ import re +import string import numpy as np from keras import activations from keras import constraints +from keras import dtype_policies from keras import initializers from keras import ops +from keras import quantizers from keras import regularizers from keras.api_export import keras_export +from keras.layers.input_spec import InputSpec from keras.layers.layer import Layer @@ -170,9 +174,12 @@ def build(self, input_shape): ) else: self.bias = None - super().build(input_shape) + self.input_spec = InputSpec(shape=input_shape) + self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) + if self.dtype_policy.quantization_mode: + self.quantize(self.dtype_policy.quantization_mode) @property def kernel(self): @@ -222,6 +229,30 @@ def call(self, inputs): x = self.activation(x) return x + def int8_call(self, inputs): + if self.lora_enabled: + raise ValueError("`int8_call` doesn't support lora weights") + inputs, inputs_scale = self.inputs_quantizer(inputs) + x = ops.einsum(self.equation, inputs, self.kernel) + # Deal with `inputs_scale` + inputs_scale = ops.transpose(inputs_scale, self._input_transpose_axes) + if self._input_expand_axes: + inputs_scale = ops.expand_dims( + inputs_scale, axis=self._input_expand_axes + ) + if self._input_squeeze_axes: + inputs_scale = ops.squeeze( + inputs_scale, axis=self._input_squeeze_axes + ) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale)) + if self.bias is not None: + x += self.bias + if self.activation is not None: + x = self.activation(x) + return x + def enable_lora( self, rank, a_initializer="he_uniform", b_initializer="zeros" ): @@ -253,10 +284,87 @@ def enable_lora( initializer=initializers.get(b_initializer), regularizer=self.kernel_regularizer, ) - self.kernel.trainable = False + self._kernel.trainable = False self._tracker.lock() self.lora_enabled = True + def quantize(self, mode): + self._check_quantize_args(mode) + if mode == "quantized_int8": + # Merge lora-related parameters to make use of fully int8 kernel + self._merge_lora_into_kernel() + + if self.input_spec is None: + raise ValueError( + f"Cannot quantize {self.name} that isn't yet built." 
+ ) + ( + self._input_reduced_axes, + self._kernel_reduced_axes, + self._input_transpose_axes, + self._kernel_transpose_axes, + self._input_expand_axes, + self._kernel_expand_axes, + self._input_squeeze_axes, + self._kernel_squeeze_axes, + ) = _analyze_quantization_info(self.equation, self.input_spec.ndim) + # Configure `self.inputs_quantizer` + self.inputs_quantizer = quantizers.AbsMaxQuantizer( + axis=self._input_reduced_axes + ) + # Quantize `self._kernel` to int8 and compute corresponding scale + kernel_value, kernel_scale = quantizers.abs_max_quantize( + self._kernel, axis=self._kernel_reduced_axes + ) + kernel_scale = ops.transpose( + kernel_scale, self._kernel_transpose_axes + ) + if self._kernel_expand_axes: + kernel_scale = ops.expand_dims( + kernel_scale, axis=self._kernel_expand_axes + ) + if self._kernel_squeeze_axes: + kernel_scale = ops.squeeze( + kernel_scale, axis=self._kernel_squeeze_axes + ) + self._tracker.unlock() + self._untrack_variable(self._kernel) + self._kernel = self.add_weight( + name="kernel", + shape=self._kernel.shape, + initializer=initializers.Constant(kernel_value), + dtype="int8", + trainable=False, + ) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale.shape, + initializer=initializers.Constant(kernel_scale), + dtype=self.compute_dtype, + trainable=False, + ) + self._tracker.lock() + else: + NotImplementedError() + + quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" + self.dtype_policy = dtype_policies.get(quantized_dtype) + self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 + + def _merge_lora_into_kernel(self): + if not self.lora_enabled: + return + + # Merge lora-enabled kernel into kernel + self._kernel.assign(self.kernel) + + # Untrack lora parameters + self._tracker.unlock() + self.lora_kernel_a = self._untrack_variable(self.lora_kernel_a) + self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) + self._tracker.lock() + self.lora_enabled = False + def save_own_variables(self, store): if not self.lora_enabled: return super().save_own_variables(store) @@ -423,3 +531,130 @@ def _analyze_split_string( bias_shape = None return weight_shape, bias_shape, output_shape + + +def _analyze_quantization_info(equation, input_shape): + + def get_specs(equation, input_shape): + possible_labels = string.ascii_letters + dot_replaced_string = re.sub(r"\.\.\.", "0", equation) + + # This is the case where no ellipses are present in the string. + split_string = re.match( + "([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + return input_spec, weight_spec, output_spec + + # This is the case where ellipses are present on the left. + split_string = re.match( + "0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + elided = len(input_shape) - len(input_spec) + possible_labels = sorted( + set(possible_labels) + - set(input_spec) + - set(weight_spec) + - set(output_spec) + ) + # Pad labels on the left to `input_spec` and `output_spec` + for i in range(elided): + input_spec = possible_labels[i] + input_spec + output_spec = possible_labels[i] + output_spec + return input_spec, weight_spec, output_spec + + # This is the case where ellipses are present on the right. 
+ split_string = re.match( + "([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", dot_replaced_string + ) + if split_string is not None: + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + elided = len(input_shape) - len(input_spec) + possible_labels = sorted( + set(possible_labels) + - set(input_spec) + - set(weight_spec) + - set(output_spec) + ) + # Pad labels on the right to `input_spec` and `output_spec` + for i in range(elided): + input_spec = input_spec + possible_labels[i] + output_spec = output_spec + possible_labels[i] + return input_spec, weight_spec, output_spec + + raise ValueError( + f"Invalid einsum equation '{equation}'. Equations must be in the " + "form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." + ) + + input_spec, weight_spec, output_spec = get_specs(equation, input_shape) + + # Determine the axes that should be reduced by the quantizer + input_reduced_axes = [] + weight_reduced_axes = [] + for i, label in enumerate(input_spec): + index = output_spec.find(label) + if index == -1: + input_reduced_axes.append(i) + for i, label in enumerate(weight_spec): + index = output_spec.find(label) + if index == -1: + weight_reduced_axes.append(i) + + # Determine the axes of `ops.expand_dims` + input_expand_axes = [] + weight_expand_axes = [] + for i, label in enumerate(output_spec): + index_input = input_spec.find(label) + index_weight = weight_spec.find(label) + if index_input == -1: + input_expand_axes.append(i) + if index_weight == -1: + weight_expand_axes.append(i) + + # Determine the axes of `ops.transpose` + input_transpose_axes = [] + weight_transpose_axes = [] + for i, label in enumerate(output_spec): + index_input = input_spec.find(label) + index_weight = weight_spec.find(label) + if index_input != -1: + input_transpose_axes.append(index_input) + if index_weight != -1: + weight_transpose_axes.append(index_weight) + # Postprocess the information: + # 1. Add dummy axes (1) to transpose_axes + # 2. Add axis to squeeze_axes if 1. failed + input_squeeze_axes = [] + weight_squeeze_axes = [] + for ori_index in input_reduced_axes: + try: + index = input_expand_axes.pop(0) + except IndexError: + input_squeeze_axes.append(ori_index) + input_transpose_axes.insert(index, ori_index) + for ori_index in weight_reduced_axes: + try: + index = weight_expand_axes.pop(0) + except IndexError: + weight_squeeze_axes.append(ori_index) + weight_transpose_axes.insert(index, ori_index) + return ( + input_reduced_axes, + weight_reduced_axes, + input_transpose_axes, + weight_transpose_axes, + input_expand_axes, + weight_expand_axes, + input_squeeze_axes, + weight_squeeze_axes, + ) From 443a6b1533cc598c5dfd02be27c3a1e1d1769525 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:02:50 +0800 Subject: [PATCH 13/33] Add `ab,bc->ac` custom ops --- keras/backend/tensorflow/numpy.py | 3 + keras/ops/numpy_test.py | 117 +++++++++++++++++++----------- 2 files changed, 77 insertions(+), 43 deletions(-) diff --git a/keras/backend/tensorflow/numpy.py b/keras/backend/tensorflow/numpy.py index 1e181be04f6..1c652a5c17e 100644 --- a/keras/backend/tensorflow/numpy.py +++ b/keras/backend/tensorflow/numpy.py @@ -103,6 +103,7 @@ def is_valid_for_custom_ops(subscripts, *operands): # `None`. 
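To make the axis bookkeeping concrete, the values below are what `_analyze_quantization_info` should return for the plain `"ab,bc->ac"` equation, traced by hand from the code above; importing the private helper is for illustration only, and the second argument is treated as the input rank, matching how the layer calls it.

```python
from keras.layers.core.einsum_dense import _analyze_quantization_info

(
    input_reduced, kernel_reduced,
    input_transpose, kernel_transpose,
    input_expand, kernel_expand,
    input_squeeze, kernel_squeeze,
) = _analyze_quantization_info("ab,bc->ac", 2)

# The contracted label "b" is axis 1 of the input and axis 0 of the kernel,
# so the input is quantized per row and the kernel per output column.
print(input_reduced, kernel_reduced)      # [1] [0]
# For this equation the scale layout already matches the output,
# so no expand/squeeze of the scales is needed.
print(input_transpose, kernel_transpose)  # [0, 1] [0, 1]
print(input_expand, input_squeeze)        # [] []
```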
if subscripts in [ "a,b->ab", + "ab,bc->ac", "abc,cd->abd", "abcd,abed->abce", "abcd,adbe->acbe", @@ -158,6 +159,8 @@ def use_custom_ops(subscripts, *operands, output_type): x = tf.expand_dims(x, axis=-1) y = tf.expand_dims(y, axis=0) return tf.matmul(x, y, output_type=output_type) + elif subscripts == "ab,bc->ac": + return tf.matmul(x, y, output_type=output_type) elif subscripts == "abc,cd->abd": return tf.matmul(x, y, output_type=output_type) elif subscripts == "abc,cde->abde": diff --git a/keras/ops/numpy_test.py b/keras/ops/numpy_test.py index 72279d9419e..b1f8fb59d30 100644 --- a/keras/ops/numpy_test.py +++ b/keras/ops/numpy_test.py @@ -2161,7 +2161,11 @@ def test_einsum(self): ) self.assertAllClose(knp.Einsum(",ijk")(5, y), np.einsum(",ijk", 5, y)) - def test_einsum_with_custom_ops(self): + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason=f"{backend.backend()} doesn't implement custom ops for einsum.", + ) + def test_einsum_custom_ops_for_tensorflow(self): subscripts = "a,b->ab" x = np.arange(2).reshape([2]).astype("float32") y = np.arange(3).reshape([3]).astype("float32") @@ -2169,6 +2173,13 @@ def test_einsum_with_custom_ops(self): knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) ) + subscripts = "ab,bc->ac" + x = np.arange(6).reshape([2, 3]).astype("float32") + y = np.arange(12).reshape([3, 4]).astype("float32") + self.assertAllClose( + knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y) + ) + subscripts = "abc,cd->abd" x = np.arange(24).reshape([2, 3, 4]).astype("float32") y = np.arange(20).reshape([4, 5]).astype("float32") @@ -5634,49 +5645,69 @@ def get_input_shapes(subscripts): expected_dtype, ) - # Test custom implementation of einsum for tensorflow - if backend.backend() == "tensorflow": - for subscripts in [ - "a,b->ab", - "abc,cd->abd", - "abc,cde->abde", - "abc,dce->abde", - "abcd,abed->abce", - "abcd,adbe->acbe", - "abcd,aecd->acbe", - "abcd,aecd->aceb", - "abcd,cde->abe", - "abcde,aebf->adbcf", - "abcde,afce->acdbf", - ]: - x1_shape, x2_shape = get_input_shapes(subscripts) - x1 = knp.ones(x1_shape, dtype=dtype1) - x2 = knp.ones(x2_shape, dtype=dtype2) - x1_jax = jnp.ones(x1_shape, dtype=dtype1) - x2_jax = jnp.ones(x2_shape, dtype=dtype2) - if dtype1 == "int8" and dtype2 == "int8": - preferred_element_type = "int32" - else: - preferred_element_type = None - expected_dtype = standardize_dtype( - jnp.einsum( - subscripts, - x1_jax, - x2_jax, - preferred_element_type=preferred_element_type, - ).dtype - ) + @parameterized.named_parameters( + named_product( + dtypes=list(itertools.combinations(ALL_DTYPES, 2)) + + [("int8", "int8")] + ) + ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason=f"{backend.backend()} doesn't implement custom ops for einsum.", + ) + def test_einsum_custom_ops_for_tensorflow(self, dtypes): + import jax.numpy as jnp - self.assertEqual( - standardize_dtype(knp.einsum(subscripts, x1, x2).dtype), - expected_dtype, - ) - self.assertEqual( - standardize_dtype( - knp.Einsum(subscripts).symbolic_call(x1, x2).dtype - ), - expected_dtype, - ) + def get_input_shapes(subscripts): + x1_labels = subscripts.split(",")[0] + x2_labels = subscripts.split("->")[0][len(x1_labels) + 1 :] + x1_shape = [1] * len(x1_labels) + x2_shape = [1] * len(x2_labels) + return x1_shape, x2_shape + + dtype1, dtype2 = dtypes + for subscripts in [ + "a,b->ab", + "ab,bc->ac", + "abc,cd->abd", + "abc,cde->abde", + "abc,dce->abde", + "abcd,abed->abce", + "abcd,adbe->acbe", + "abcd,aecd->acbe", + "abcd,aecd->aceb", + "abcd,cde->abe", + 
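The new `"ab,bc->ac"` special case exists because a plain matmul can accumulate int8 operands into int32 via `output_type`, which generic `tf.einsum` cannot. A small sketch; int8 matmul support may depend on the TensorFlow version and device.

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.randint(-127, 128, size=(2, 3)), dtype=tf.int8)
y = tf.constant(np.random.randint(-127, 128, size=(3, 4)), dtype=tf.int8)

# Equivalent to einsum("ab,bc->ac", x, y), but with int32 accumulation.
z = tf.matmul(x, y, output_type=tf.int32)
print(z.dtype, z.shape)  # <dtype: 'int32'> (2, 4)
```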
"abcde,aebf->adbcf", + "abcde,afce->acdbf", + ]: + x1_shape, x2_shape = get_input_shapes(subscripts) + x1 = knp.ones(x1_shape, dtype=dtype1) + x2 = knp.ones(x2_shape, dtype=dtype2) + x1_jax = jnp.ones(x1_shape, dtype=dtype1) + x2_jax = jnp.ones(x2_shape, dtype=dtype2) + if dtype1 == "int8" and dtype2 == "int8": + preferred_element_type = "int32" + else: + preferred_element_type = None + expected_dtype = standardize_dtype( + jnp.einsum( + subscripts, + x1_jax, + x2_jax, + preferred_element_type=preferred_element_type, + ).dtype + ) + + self.assertEqual( + standardize_dtype(knp.einsum(subscripts, x1, x2).dtype), + expected_dtype, + ) + self.assertEqual( + standardize_dtype( + knp.Einsum(subscripts).symbolic_call(x1, x2).dtype + ), + expected_dtype, + ) @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_empty(self, dtype): From d774cd35683de44892b76be39bcb63e275778346 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:03:40 +0800 Subject: [PATCH 14/33] Update --- check_dynamic_int8.py | 158 ------------------------------ check_quantized_int8.py | 143 +++++++++++++++++++++++++++ keras/layers/core/dense.py | 16 +-- keras/layers/core/einsum_dense.py | 16 +-- keras/layers/layer.py | 2 +- 5 files changed, 160 insertions(+), 175 deletions(-) delete mode 100644 check_dynamic_int8.py create mode 100644 check_quantized_int8.py diff --git a/check_dynamic_int8.py b/check_dynamic_int8.py deleted file mode 100644 index ed4a81ba61d..00000000000 --- a/check_dynamic_int8.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -import time - -import jax -import numpy as np -import tensorflow as tf - -import keras -from keras import backend -from keras import dtype_policies -from keras import layers -from keras import models -from keras import ops -from keras import saving - -# Set dtype policy -dtype = "mixed_bfloat16" -dtype_policies.dtype_policy.set_dtype_policy(dtype) -print(f"Global dtype policy: {dtype_policies.dtype_policy.dtype_policy()}") - -# Model / data parameters -num_classes = 10 -input_shape = (28, 28, 1) -epochs = 1 - -# Load the data and split it between train and test sets -(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() -x_train = x_train.astype("float32") / 255 -x_test = x_test.astype("float32") / 255 -x_train = np.expand_dims(x_train, -1) -x_test = np.expand_dims(x_test, -1) -y_train = keras.utils.to_categorical(y_train, num_classes) -y_test = keras.utils.to_categorical(y_test, num_classes) - - -def build_model(num_layers=32, units=1024): - inputs = layers.Input([28, 28]) - x = layers.Flatten()(inputs) - for _ in range(num_layers): - x = layers.Dense(units)(x) - x = layers.BatchNormalization()(x) - x = layers.ReLU()(x) - outputs = layers.Dense(10, use_bias=True, activation="softmax")(x) - model = models.Model(inputs, outputs) - return model - - -def enable_lora(model): - for layer in model.layers: - if hasattr(layer, "enable_lora"): - layer.enable_lora(2) - - -def benchmark(model, batch_size=1024, input_shape=(28, 28), iterations=200): - def fn(x): - return model(x, training=False) - - if backend.backend() == "tensorflow": - jit_fn = tf.function(fn, jit_compile=True) - elif backend.backend() == "jax": - jit_fn = jax.jit(fn) - else: - jit_fn = fn - - # warmup - x = ops.ones([batch_size, *input_shape]) - for _ in range(10): - _ = ops.convert_to_numpy(jit_fn(x)) - - times = [] - for _ in range(iterations): - t0 = time.time() - _ = ops.convert_to_numpy(jit_fn(x)) - t1 = time.time() - 
times.append(t1 - t0) - avg_time = sum(times) / len(times) - return avg_time - - -for enable_rola in (True, False): - model = build_model(num_layers=32, units=1024) - model.compile( - loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] - ) - - """Train float model""" - print("=====Start training float model=====") - model.fit( - x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1 - ) - print(f"Performance of {dtype}:") - score = model.evaluate(x_test, y_test, verbose=0) - print(f" Test accuracy: {score[1]:.5f}") - avg_time = benchmark(model) - print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") - model.save("model_fp32.keras") - - if enable_rola: - """Enable lora""" - print("=====Enable lora weights=====") - enable_lora(model) - - """Fine-tuning lora weights""" - model.compile( - loss="categorical_crossentropy", - optimizer="adam", - metrics=["accuracy"], - ) - model.fit( - x_train, - y_train, - batch_size=128, - epochs=epochs, - validation_split=0.1, - ) - print("Performance of fine-tuned lora weights:") - score = model.evaluate(x_test, y_test, verbose=0) - print(f" Test accuracy: {score[1]:.5f}") - avg_time = benchmark(model) - print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") - - """Quantize to int8 weights""" - model.quantize(mode="quantized_int8") - int8_model = model - int8_model.compile( - loss="categorical_crossentropy", metrics=["accuracy"] - ) - print("Performance of quantized model:") - score = int8_model.evaluate(x_test, y_test, verbose=0) - print(f" Test accuracy: {score[1]:.5f}") - avg_time = benchmark(int8_model) - print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") - else: - print("=====No lora weights=====") - """Quantization""" - model.quantize(mode="quantized_int8") - int8_model = model - int8_model.compile( - loss="categorical_crossentropy", metrics=["accuracy"] - ) - print("Performance of quantized model:") - score = int8_model.evaluate(x_test, y_test, verbose=0) - print(f" Test accuracy: {score[1]:.5f}") - avg_time = benchmark(int8_model) - print(f" Avg. 
inference time (batch_size=1024): {avg_time:.5f}s") - - """Saving & loading""" - int8_model.save("model_int8.keras") - reloaded_int8_model = saving.load_model("model_int8.keras") - reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) - print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") - print("Size of saved model:") - print(f" fp32: {os.path.getsize('model_fp32.keras') >> 20}MB") - print(f" int8: {os.path.getsize('model_int8.keras') >> 20}MB") - -"""Cleanup""" -os.remove("model_fp32.keras") -os.remove("model_int8.keras") diff --git a/check_quantized_int8.py b/check_quantized_int8.py new file mode 100644 index 00000000000..aef5dd274d8 --- /dev/null +++ b/check_quantized_int8.py @@ -0,0 +1,143 @@ +import os +import time + +import jax +import numpy as np +import tensorflow as tf + +import keras +from keras import backend +from keras import dtype_policies +from keras import layers +from keras import models +from keras import ops +from keras import saving + +# Set dtype policy +dtype = "mixed_bfloat16" # float32 +dtype_policies.dtype_policy.set_dtype_policy(dtype) +print(f"Global dtype policy: {dtype_policies.dtype_policy.dtype_policy()}") + +# Model / data parameters +use_einsum = True +num_classes = 10 +input_shape = (28, 28, 1) +epochs = 1 + +# Load the data and split it between train and test sets +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() +x_train = x_train.astype("float32") / 255 +x_test = x_test.astype("float32") / 255 +x_train = np.expand_dims(x_train, -1) +x_test = np.expand_dims(x_test, -1) +y_train = keras.utils.to_categorical(y_train, num_classes) +y_test = keras.utils.to_categorical(y_test, num_classes) + + +def build_model(num_layers=32, units=1024, use_einsum=False): + inputs = layers.Input([28, 28]) + x = layers.Flatten()(inputs) + for _ in range(num_layers): + if use_einsum: + x = layers.EinsumDense("ab,bc->ac", output_shape=[units])(x) + else: + x = layers.Dense(units)(x) + x = layers.BatchNormalization()(x) + x = layers.ReLU()(x) + outputs = layers.Dense(10, use_bias=True, activation="softmax")(x) + model = models.Model(inputs, outputs) + return model + + +def enable_lora(model): + for layer in model.layers: + if hasattr(layer, "enable_lora"): + layer.enable_lora(2) + + +def benchmark(model, batch_size=1024, input_shape=(28, 28), iterations=200): + def fn(x): + return model(x, training=False) + + if backend.backend() == "tensorflow": + jit_fn = tf.function(fn, jit_compile=True) + elif backend.backend() == "jax": + jit_fn = jax.jit(fn) + else: + jit_fn = fn + + # warmup + x = ops.ones([batch_size, *input_shape]) + for _ in range(10): + _ = ops.convert_to_numpy(jit_fn(x)) + + times = [] + for _ in range(iterations): + t0 = time.time() + _ = ops.convert_to_numpy(jit_fn(x)) + t1 = time.time() + times.append(t1 - t0) + avg_time = sum(times) / len(times) + return avg_time + + +model = build_model(num_layers=32, units=1024, use_einsum=use_einsum) +model.compile( + loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] +) + +"""Train float model""" +print("=====Start training float model=====") +model.fit(x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1) +print(f"Performance of {dtype}:") +score = model.evaluate(x_test, y_test, verbose=0) +print(f" Test accuracy: {score[1]:.5f}") +avg_time = benchmark(model) +print(f" Avg. 
inference time (batch_size=1024): {avg_time:.5f}s") +model.save("model_fp32.keras") + +"""Enable lora""" +print("=====Enable lora weights=====") +enable_lora(model) + +"""Fine-tuning lora weights""" +model.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"], +) +model.fit( + x_train, + y_train, + batch_size=128, + epochs=epochs, + validation_split=0.1, +) +print("Performance of fine-tuned lora weights:") +score = model.evaluate(x_test, y_test, verbose=0) +print(f" Test accuracy: {score[1]:.5f}") +avg_time = benchmark(model) +print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") + +"""Quantize to int8 weights""" +model.quantize(mode="quantized_int8") +int8_model = model +int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"]) +print("Performance of quantized model:") +score = int8_model.evaluate(x_test, y_test, verbose=0) +print(f" Test accuracy: {score[1]:.5f}") +avg_time = benchmark(int8_model) +print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") + +"""Saving & loading""" +int8_model.save("model_int8.keras") +reloaded_int8_model = saving.load_model("model_int8.keras") +reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) +print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") +print("Size of saved model:") +print(f" fp32: {os.path.getsize('model_fp32.keras') >> 20}MB") +print(f" int8: {os.path.getsize('model_int8.keras') >> 20}MB") + +"""Cleanup""" +os.remove("model_fp32.keras") +os.remove("model_int8.keras") diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 91bc6dedaf0..b39b09edb27 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -198,6 +198,7 @@ def enable_lora( self._kernel.trainable = False self._tracker.lock() self.lora_enabled = True + self.lora_rank = rank def quantize(self, mode): self._check_quantize_args(mode) @@ -234,19 +235,18 @@ def quantize(self, mode): self.dtype_policy = dtype_policies.get(quantized_dtype) self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 - def _merge_lora_into_kernel(self): + def _merge_lora_into_kernel(self, untrack=False): if not self.lora_enabled: return - # Merge lora-enabled kernel into kernel self._kernel.assign(self.kernel) - - # Untrack lora parameters - self._tracker.unlock() - self.lora_kernel_a = self._untrack_variable(self.lora_kernel_a) - self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) - self._tracker.lock() self.lora_enabled = False + if untrack: + self._tracker.unlock() + self.lora_kernel_a = self._untrack_variable(self.lora_kernel_a) + self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) + self._tracker.lock() + self.lora_rank = None def save_own_variables(self, store): if not self.lora_enabled: diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index 17dfc1eda9d..49499e11bce 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -287,6 +287,7 @@ def enable_lora( self._kernel.trainable = False self._tracker.lock() self.lora_enabled = True + self.lora_rank = rank def quantize(self, mode): self._check_quantize_args(mode) @@ -351,19 +352,18 @@ def quantize(self, mode): self.dtype_policy = dtype_policies.get(quantized_dtype) self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 - def _merge_lora_into_kernel(self): + def _merge_lora_into_kernel(self, untrack=False): if not self.lora_enabled: return - # Merge lora-enabled kernel into kernel self._kernel.assign(self.kernel) 
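The `_merge_lora_into_kernel` refactor relies on the identity that the effective kernel is the frozen base kernel plus the low-rank update. A minimal NumPy sketch, assuming the update is `A @ B` with no extra scaling:

```python
import numpy as np

rng = np.random.default_rng(0)
kernel = rng.standard_normal((8, 4)).astype("float32")  # frozen base kernel
lora_a = rng.standard_normal((8, 2)).astype("float32")  # rank-2 factors
lora_b = rng.standard_normal((2, 4)).astype("float32")
x = rng.standard_normal((5, 8)).astype("float32")

# Merged kernel that replaces the base variable before int8 quantization;
# the factors can then be untracked or kept for further fine-tuning.
merged = kernel + lora_a @ lora_b

# Merging preserves the layer output (up to float rounding).
assert np.allclose(x @ merged, x @ kernel + (x @ lora_a) @ lora_b, atol=1e-5)
```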
- - # Untrack lora parameters - self._tracker.unlock() - self.lora_kernel_a = self._untrack_variable(self.lora_kernel_a) - self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) - self._tracker.lock() self.lora_enabled = False + if untrack: + self._tracker.unlock() + self.lora_kernel_a = self._untrack_variable(self.lora_kernel_a) + self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) + self._tracker.lock() + self.lora_rank = None def save_own_variables(self, store): if not self.lora_enabled: diff --git a/keras/layers/layer.py b/keras/layers/layer.py index ce7d99f7d4f..c5be38c5122 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -290,7 +290,7 @@ def __init__( self._convert_input_args = True # Whether to allow non-tensors as positional arguments in `call()`. self._allow_non_tensor_positional_args = False - # Whether to set `is_quantized_int8` + # Propagate `self.dtype_policy.is_quantized_int8` self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 # Dict of shapes that were used to call `build()`. self._build_shapes_dict = None From 5c27a44ee355f8286411f8669ec671ffdf435c6e Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:46:12 +0800 Subject: [PATCH 15/33] Update RESULT.md --- RESULT.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 RESULT.md diff --git a/RESULT.md b/RESULT.md new file mode 100644 index 00000000000..c8ed33e473a --- /dev/null +++ b/RESULT.md @@ -0,0 +1,24 @@ +# Result + +Configuration: + +- MNIST +- model + - `Dense` or `EinsumDense` + - `BatchNormalization` + - `ReLU` +- fine-tuning with `enable_lora(rank=2)` +- inference time: batch size=1024 + - float: `self.lora_enabled=True` + - int8: `self.lora_enabled=False` (merged) + +|backend|dtype_policy|layer|float acc.|int8 acc.|float inference time|int8 inference time|inference time ratio| +|-|-|-|-|-|-|-|-| +|tensorflow|float32|`Dense`|0.95990|0.96000|0.00395s|0.00198s|0.501| +|tensorflow|mixed_bfloat16|`Dense`|0.96110|0.96110|0.00265s|0.00200s|0.755| +|jax|float32|`Dense`|0.96130|0.96160|0.00304s|0.00132s|0.434| +|jax|mixed_bfloat16|`Dense`|0.95290|0.95300|0.00177s|0.00133s|0.751| +|tensorflow|float32|`EinsumDense`|0.95950|0.95920|0.00384s|0.00188s|0.490| +|tensorflow|mixed_bfloat16|`EinsumDense`|0.95980|0.95970|0.00258s|0.00200s|0.775| +|jax|float32|`EinsumDense`|0.96170|0.96160|0.00302s|0.00132s|0.437| +|jax|mixed_bfloat16|`EinsumDense`|0.95720|0.95680|0.00176s|0.00125s|0.710| From 0571370fd0a4a890603484646d3d9d3b09898cfc Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 5 Mar 2024 11:58:51 +0800 Subject: [PATCH 16/33] Update RESULT.md --- RESULT.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/RESULT.md b/RESULT.md index c8ed33e473a..adce5efc75b 100644 --- a/RESULT.md +++ b/RESULT.md @@ -16,9 +16,9 @@ Configuration: |-|-|-|-|-|-|-|-| |tensorflow|float32|`Dense`|0.95990|0.96000|0.00395s|0.00198s|0.501| |tensorflow|mixed_bfloat16|`Dense`|0.96110|0.96110|0.00265s|0.00200s|0.755| -|jax|float32|`Dense`|0.96130|0.96160|0.00304s|0.00132s|0.434| -|jax|mixed_bfloat16|`Dense`|0.95290|0.95300|0.00177s|0.00133s|0.751| |tensorflow|float32|`EinsumDense`|0.95950|0.95920|0.00384s|0.00188s|0.490| |tensorflow|mixed_bfloat16|`EinsumDense`|0.95980|0.95970|0.00258s|0.00200s|0.775| +|jax|float32|`Dense`|0.96130|0.96160|0.00304s|0.00132s|0.434| +|jax|mixed_bfloat16|`Dense`|0.95290|0.95300|0.00177s|0.00133s|0.751| 
|jax|float32|`EinsumDense`|0.96170|0.96160|0.00302s|0.00132s|0.437| |jax|mixed_bfloat16|`EinsumDense`|0.95720|0.95680|0.00176s|0.00125s|0.710| From 515e193fef3b3bb6a7df008c4b5699849f1ad3e3 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 5 Mar 2024 13:12:34 +0800 Subject: [PATCH 17/33] Update `InputSpec` --- keras/layers/core/einsum_dense.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index 49499e11bce..1ed3eed404f 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -174,7 +174,7 @@ def build(self, input_shape): ) else: self.bias = None - self.input_spec = InputSpec(shape=input_shape) + self.input_spec = InputSpec(ndim=len(input_shape)) self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) From 885174fedbe1240163ffbd214a793dbefab4b77e Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:42:38 +0800 Subject: [PATCH 18/33] Raise NotImplementedError in `int8_call` and `quantize` --- keras/layers/layer.py | 19 +++++++++---------- keras/models/model.py | 9 +++++---- keras/ops/operation.py | 3 +-- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/keras/layers/layer.py b/keras/layers/layer.py index c5be38c5122..417669f61c5 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -865,7 +865,10 @@ def call(self, *args, **kwargs): ) def int8_call(self, *args, **kwargs): - return self.call(*args, **kwargs) + raise NotImplementedError( + f"Layer {self.__class__.__name__} does not have a `int8_call()` " + "method implemented." + ) @traceback_utils.filter_traceback def stateless_call( @@ -1103,22 +1106,18 @@ def _clear_losses(self): layer._clear_losses() def quantize(self, mode): - self._check_quantize_args(mode) - warnings.warn( - "`quantize` is not implemented for class " - f"'{self.__class__.__name__}' so the quantization is skipped." + raise NotImplementedError( + f"Layer {self.__class__.__name__} does not have a `quantize()` " + "method implemented." ) - def _check_quantize_args(self, mode): + def _check_quantize_args(self, mode, compute_dtype): if mode not in (None, "quantized_int8"): raise ValueError( "Currently, `quantize` must be one of " f"(`None`, 'quantized_int8'). Received: mode={mode}" ) - if ( - mode == "quantized_int8" - and self.dtype_policy.compute_dtype == "float16" - ): + if mode == "quantized_int8" and compute_dtype == "float16": raise ValueError( f"mode='{mode}' doesn't work well with " "compute_dtype='float16'. Consider loading model/layer with " diff --git a/keras/models/model.py b/keras/models/model.py index c460a40c768..28434246ced 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -374,14 +374,15 @@ def quantize(self, mode): """ if not self.built: raise ValueError( - "The model must be built first before calling `quantize`." + "The model must be built first before calling `quantize()`." 
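With the change below, `Model.quantize()` no longer compares dtype-policy modes before and after; it simply tries each sub-layer and downgrades `NotImplementedError` to a warning. A minimal sketch of what happens for a layer that opts out; the layer here is a hypothetical example, and the mode string follows the spelling used at this point in the series.

```python
import warnings
import keras
from keras import layers

class Scale(layers.Layer):
    """A toy layer that does not implement `quantize()`."""

    def call(self, x):
        return x * 2.0

layer = Scale()
try:
    layer.quantize("quantized_int8")
except NotImplementedError as e:
    # `Model.quantize()` does the same thing for each sub-layer.
    warnings.warn(str(e))
```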
) mode_changed = False for layer in self._flatten_layers(include_self=False, recursive=True): - original_mode = layer.dtype_policy.quantization_mode - layer.quantize(mode) - if layer.dtype_policy.quantization_mode != original_mode: + try: + layer.quantize(mode) mode_changed = True + except NotImplementedError as e: + warnings.warn(str(e)) # We need to set these functions to `None` to remake them for changed # call function if mode_changed: diff --git a/keras/ops/operation.py b/keras/ops/operation.py index c04fa135523..e0846f0517a 100644 --- a/keras/ops/operation.py +++ b/keras/ops/operation.py @@ -71,8 +71,7 @@ def call(self, *args, **kwargs): raise NotImplementedError def int8_call(self, *args, **kwargs): - # Note that `int8_call` defaults to `call` if not implemented. - return self.call(*args, **kwargs) + raise NotImplementedError def compute_output_spec(self, *args, **kwargs): try: From 32277fcdeb86abdcdffdf203610db6765c08ac8a Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:43:14 +0800 Subject: [PATCH 19/33] Fix variable creation issue in `quantize` --- keras/layers/core/dense.py | 8 +++++--- keras/layers/core/einsum_dense.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index b39b09edb27..7fc50fd5d2c 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -201,7 +201,7 @@ def enable_lora( self.lora_rank = rank def quantize(self, mode): - self._check_quantize_args(mode) + self._check_quantize_args(mode, self.compute_dtype) if mode == "quantized_int8": # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() @@ -216,17 +216,19 @@ def quantize(self, mode): self._kernel = self.add_weight( name="kernel", shape=self._kernel.shape, - initializer=initializers.Constant(kernel_value), + initializer="zeros", dtype="int8", trainable=False, ) + self._kernel.assign(kernel_value) self.kernel_scale = self.add_weight( name="kernel_scale", shape=kernel_scale.shape, - initializer=initializers.Constant(kernel_scale), + initializer="zeros", dtype=self.compute_dtype, trainable=False, ) + self.kernel_scale.assign(kernel_scale) self._tracker.lock() else: NotImplementedError() diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index 1ed3eed404f..fc9ea7b92d1 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -290,7 +290,7 @@ def enable_lora( self.lora_rank = rank def quantize(self, mode): - self._check_quantize_args(mode) + self._check_quantize_args(mode, self.compute_dtype) if mode == "quantized_int8": # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() @@ -333,17 +333,19 @@ def quantize(self, mode): self._kernel = self.add_weight( name="kernel", shape=self._kernel.shape, - initializer=initializers.Constant(kernel_value), + initializer="zeros", dtype="int8", trainable=False, ) + self._kernel.assign(kernel_value) self.kernel_scale = self.add_weight( name="kernel_scale", shape=kernel_scale.shape, - initializer=initializers.Constant(kernel_scale), + initializer="zeros", dtype=self.compute_dtype, trainable=False, ) + self.kernel_scale.assign(kernel_scale) self._tracker.lock() else: NotImplementedError() From 95cfe0b5f1d0b37afb4d9e44fb792347df63d02a Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 6 Mar 2024 17:23:36 +0800 Subject: 
[PATCH 20/33] Introduce `FloatDtypePolicy` and `QuantizedDtypePolicy` --- keras/dtype_policies/__init__.py | 10 +- keras/dtype_policies/dtype_policy.py | 134 ++++++++++++++++----------- keras/layers/core/dense.py | 15 +-- keras/layers/core/einsum_dense.py | 15 +-- keras/layers/layer.py | 17 ++-- keras/ops/operation.py | 26 ++---- 6 files changed, 126 insertions(+), 91 deletions(-) diff --git a/keras/dtype_policies/__init__.py b/keras/dtype_policies/__init__.py index 871491d5c7d..f885a4c0139 100644 --- a/keras/dtype_policies/__init__.py +++ b/keras/dtype_policies/__init__.py @@ -1,9 +1,12 @@ from keras import backend from keras.dtype_policies import dtype_policy -from keras.saving import serialization_lib +from keras.dtype_policies.dtype_policy import FloatDTypePolicy +from keras.dtype_policies.dtype_policy import QuantizedDTypePolicy def get(identifier): + from keras.saving import serialization_lib + if identifier is None: return dtype_policy.dtype_policy() if isinstance(identifier, dtype_policy.DTypePolicy): @@ -11,7 +14,10 @@ def get(identifier): if isinstance(identifier, dict): return serialization_lib.deserialize_keras_object(identifier) if isinstance(identifier, str): - return dtype_policy.DTypePolicy(identifier) + if "quantized" in identifier: + return dtype_policy.QuantizedDTypePolicy(identifier) + else: + return dtype_policy.FloatDTypePolicy(identifier) try: return dtype_policy.DTypePolicy(backend.standardize_dtype(identifier)) except: diff --git a/keras/dtype_policies/dtype_policy.py b/keras/dtype_policies/dtype_policy.py index c77b7167ea5..b89f4c8cc94 100644 --- a/keras/dtype_policies/dtype_policy.py +++ b/keras/dtype_policies/dtype_policy.py @@ -62,10 +62,8 @@ def __init__(self, name): f"Received: name={name} (of type {type(name)})" ) self._name = name - self._quantization_mode = None - self._compute_dtype, self._variable_dtype = self._parse_name(name) - # TODO: check that the current hardware supports the provided - # dtype policy and raise/warn otherwise. + self._compute_dtype = None + self._variable_dtype = None def _parse_name(self, name): """Parses a `DTypePolicy` name into a compute and variable dtype. @@ -76,32 +74,7 @@ def _parse_name(self, name): Returns: The `(compute_dtype, variable_dtype)` pair. """ - if name == "mixed_float16": - return "float16", "float32" - elif name == "mixed_bfloat16": - return "bfloat16", "float32" - elif "quantized_int8" in name: - if "_from_" not in name: - raise ValueError( - f"Cannot convert '{name}' to a quantized DTypePolicy. " - "Valid policies are in the pattern of " - "'quantized_int8_from_(name)' such as " - "'quantized_int8_from_mixed_bfloat16'." - ) - # "quantized_int8_from_float32" indicates that the layer - # or model is quantized from float32 dtype policy. - ori_name = name.split("_from_")[-1] - self._quantization_mode = "quantized_int8" - return self._parse_name(ori_name) - try: - dtype = backend.standardize_dtype(name) - return dtype, dtype - except ValueError: - raise ValueError( - f"Cannot convert '{name}' to a mixed precision DTypePolicy." - " Valid policies include 'mixed_float16', 'mixed_bfloat16', " - "and the name of any dtype such as 'float32'." - ) + raise NotImplementedError @property def variable_dtype(self): @@ -141,26 +114,6 @@ def compute_dtype(self): """ return self._compute_dtype - @property - def quantization_mode(self): - """The quantization mode of this policy. - - Returns: - The quantization mode of this policy, as a string. `None` if no - quantization. 
- """ - return self._quantization_mode - - @property - def is_quantized_int8(self): - """Whether this policy is quantized to `'int8'`. - - Returns: - The boolean value indicating whether this policy is quantized to - `'int8'`. - """ - return "quantized_int8" in self._name - @property def name(self): """Returns the name of this policy.""" @@ -199,6 +152,80 @@ def from_config(cls, config): return cls(**config) +@keras_export( + ["keras.FloatDTypePolicy", "keras.dtype_policies.FloatDTypePolicy"] +) +class FloatDTypePolicy(DTypePolicy): + def __init__(self, name): + super().__init__(name) + self._compute_dtype, self._variable_dtype = self._parse_name(name) + # TODO: check that the current hardware supports the provided + # dtype policy and raise/warn otherwise. + + def _parse_name(self, name): + if name == "mixed_float16": + return "float16", "float32" + elif name == "mixed_bfloat16": + return "bfloat16", "float32" + try: + dtype = backend.standardize_dtype(name) + return dtype, dtype + except ValueError: + raise ValueError( + f"Cannot convert '{name}' to a mixed precision DTypePolicy." + " Valid policies include 'mixed_float16', 'mixed_bfloat16', " + "and the name of any dtype such as 'float32'." + ) + + def __repr__(self): + return f'' + + +@keras_export( + ["keras.QuantizedDTypePolicy", "keras.dtype_policies.QuantizedDTypePolicy"] +) +class QuantizedDTypePolicy(DTypePolicy): + def __init__(self, name): + super().__init__(name) + self._compute_dtype, self._variable_dtype = self._parse_name(name) + + def _parse_name(self, name): + if "_from_" in name: + quantization_mode, from_name = name.split("_from_") + self._quantization_mode = quantization_mode + if from_name == "mixed_float16": + return "float16", "float32" + elif from_name == "mixed_bfloat16": + return "bfloat16", "float32" + try: + dtype = backend.standardize_dtype(name) + return dtype, dtype + except ValueError: + raise ValueError( + f"Cannot convert '{name}' to a mixed precision DTypePolicy." + " Valid policies include 'mixed_float16', 'mixed_bfloat16'," + " and the name of any dtype such as 'float32'." + ) + raise ValueError( + f"Cannot convert '{name}' to a quantized DTypePolicy. " + "Valid policies are in the pattern of 'quantized_int8_from_(name)' " + "such as 'quantized_int8_from_mixed_bfloat16'." + ) + + @property + def quantization_mode(self): + """The quantization mode of this policy. + + Returns: + The quantization mode of this policy, as a string. `None` if no + quantization. + """ + return self._quantization_mode + + def __repr__(self): + return f'' + + @keras_export( [ "keras.config.set_dtype_policy", @@ -215,7 +242,10 @@ def set_dtype_policy(policy): """ if not isinstance(policy, DTypePolicy): if isinstance(policy, str): - policy = DTypePolicy(policy) + if "quantized" in policy: + policy = QuantizedDTypePolicy(policy) + else: + policy = FloatDTypePolicy(policy) else: raise ValueError( "Invalid `policy` argument. 
" @@ -238,6 +268,6 @@ def dtype_policy(): """Returns the current default dtype policy object.""" policy = global_state.get_global_attribute("dtype_policy", None) if policy is None: - policy = DTypePolicy(backend.floatx()) + policy = FloatDTypePolicy(backend.floatx()) set_dtype_policy(policy) return policy diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 7fc50fd5d2c..6c0563a876d 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -122,7 +122,7 @@ def build(self, input_shape): self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) - if self.dtype_policy.quantization_mode: + if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): self.quantize(self.dtype_policy.quantization_mode) @property @@ -145,9 +145,9 @@ def call(self, inputs): x = self.activation(x) return x - def int8_call(self, inputs): + def quantized_call(self, inputs): if self.lora_enabled: - raise ValueError("`int8_call` doesn't support lora weights") + raise ValueError("`quantized_call` doesn't support lora weights") inputs, inputs_scale = self.inputs_quantizer(inputs) x = ops.matmul(inputs, self.kernel) # De-scale outputs @@ -233,9 +233,12 @@ def quantize(self, mode): else: NotImplementedError() - quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" - self.dtype_policy = dtype_policies.get(quantized_dtype) - self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 + # Set new dtype policy + if not isinstance( + self.dtype_policy, dtype_policies.QuantizedDTypePolicy + ): + quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" + self.dtype_policy = dtype_policies.get(quantized_dtype) def _merge_lora_into_kernel(self, untrack=False): if not self.lora_enabled: diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index fc9ea7b92d1..b0f42e16fe5 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -178,7 +178,7 @@ def build(self, input_shape): self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) - if self.dtype_policy.quantization_mode: + if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): self.quantize(self.dtype_policy.quantization_mode) @property @@ -229,9 +229,9 @@ def call(self, inputs): x = self.activation(x) return x - def int8_call(self, inputs): + def quantized_call(self, inputs): if self.lora_enabled: - raise ValueError("`int8_call` doesn't support lora weights") + raise ValueError("`quantized_call` doesn't support lora weights") inputs, inputs_scale = self.inputs_quantizer(inputs) x = ops.einsum(self.equation, inputs, self.kernel) # Deal with `inputs_scale` @@ -350,9 +350,12 @@ def quantize(self, mode): else: NotImplementedError() - quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" - self.dtype_policy = dtype_policies.get(quantized_dtype) - self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 + # Set new dtype policy + if not isinstance( + self.dtype_policy, dtype_policies.QuantizedDTypePolicy + ): + quantized_dtype = f"{mode}_from_{self.dtype_policy.name}" + self.dtype_policy = dtype_policies.get(quantized_dtype) def _merge_lora_into_kernel(self, untrack=False): if not self.lora_enabled: diff --git a/keras/layers/layer.py b/keras/layers/layer.py index 417669f61c5..a0a50431ee9 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -245,7 +245,7 @@ def __init__( ): BackendLayer.__init__(self) self._lock = False - Operation.__init__(self, name=name) + Operation.__init__(self, dtype=dtype, name=name) 
self.activity_regularizer = regularizers.get(activity_regularizer) input_dim_arg = kwargs.pop("input_dim", None) if input_dim_arg is not None: @@ -268,7 +268,6 @@ def __init__( ) self.built = False - self.dtype_policy = dtype_policies.get(dtype) self.autocast = autocast self._input_spec = None self._called = False @@ -290,8 +289,6 @@ def __init__( self._convert_input_args = True # Whether to allow non-tensors as positional arguments in `call()`. self._allow_non_tensor_positional_args = False - # Propagate `self.dtype_policy.is_quantized_int8` - self.is_quantized_int8 = self.dtype_policy.is_quantized_int8 # Dict of shapes that were used to call `build()`. self._build_shapes_dict = None self._initializer_tracker() @@ -864,10 +861,10 @@ def call(self, *args, **kwargs): "method implemented." ) - def int8_call(self, *args, **kwargs): + def quantized_call(self, *args, **kwargs): raise NotImplementedError( - f"Layer {self.__class__.__name__} does not have a `int8_call()` " - "method implemented." + f"Layer {self.__class__.__name__} does not have a " + "`quantized_call()` method implemented." ) @traceback_utils.filter_traceback @@ -959,8 +956,10 @@ def stateless_call( with backend.StatelessScope( state_mapping=mapping, collect_losses=return_losses ) as scope: - if self.is_quantized_int8: - outputs = self.int8_call(*args, **kwargs) + if isinstance( + self.dtype_policy, dtype_policies.QuantizedDTypePolicy + ): + outputs = self.quantized_call(*args, **kwargs) else: outputs = self.call(*args, **kwargs) if return_losses: diff --git a/keras/ops/operation.py b/keras/ops/operation.py index e0846f0517a..bb5f26f669a 100644 --- a/keras/ops/operation.py +++ b/keras/ops/operation.py @@ -4,6 +4,7 @@ import tree from keras import backend +from keras import dtype_policies from keras.api_export import keras_export from keras.backend.common.keras_tensor import any_symbolic_tensors from keras.ops.node import Node @@ -14,7 +15,7 @@ @keras_export("keras.Operation") class Operation: - def __init__(self, name=None): + def __init__(self, dtype=None, name=None): if name is None: name = auto_name(self.__class__.__name__) if not isinstance(name, str) or "/" in name: @@ -23,10 +24,10 @@ def __init__(self, name=None): "cannot contain character `/`. " f"Received: name={name} (of type {type(name)})" ) + self.dtype_policy = dtype_policies.get(dtype) self.name = name self._inbound_nodes = [] self._outbound_nodes = [] - self._is_quantized_int8 = False @traceback_utils.filter_traceback def __call__(self, *args, **kwargs): @@ -35,8 +36,10 @@ def __call__(self, *args, **kwargs): if any_symbolic_tensors(args, kwargs): call_fn = self.symbolic_call else: - if self.is_quantized_int8: - call_fn = self.int8_call + if isinstance( + self.dtype_policy, dtype_policies.QuantizedDTypePolicy + ): + call_fn = self.quantized_call else: call_fn = self.call call_fn = traceback_utils.inject_argument_info_in_traceback( @@ -48,8 +51,8 @@ def __call__(self, *args, **kwargs): # Plain flow. 
if any_symbolic_tensors(args, kwargs): return self.symbolic_call(*args, **kwargs) - if self.is_quantized_int8: - return self.int8_call(*args, **kwargs) + if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): + return self.quantized_call(*args, **kwargs) else: return self.call(*args, **kwargs) @@ -70,7 +73,7 @@ def symbolic_call(self, *args, **kwargs): def call(self, *args, **kwargs): raise NotImplementedError - def int8_call(self, *args, **kwargs): + def quantized_call(self, *args, **kwargs): raise NotImplementedError def compute_output_spec(self, *args, **kwargs): @@ -237,15 +240,6 @@ def output(self): """ return self._get_node_attribute_at_index(0, "output_tensors", "output") - @property - def is_quantized_int8(self): - """Whether the operation is quantized to int8.""" - return self._is_quantized_int8 - - @is_quantized_int8.setter - def is_quantized_int8(self, value): - self._is_quantized_int8 = value - def _get_node_attribute_at_index(self, node_index, attr, attr_name): """Private utility to retrieves an attribute (e.g. inputs) from a node. From 7370914df38fec32c48304c7834759cc76ad1ee5 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 6 Mar 2024 17:44:45 +0800 Subject: [PATCH 21/33] Defaults to `backend.floatx()` to `self._compute_dtype` and `self._variable_dtype` --- keras/dtype_policies/dtype_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/dtype_policies/dtype_policy.py b/keras/dtype_policies/dtype_policy.py index b89f4c8cc94..7ee5a304db7 100644 --- a/keras/dtype_policies/dtype_policy.py +++ b/keras/dtype_policies/dtype_policy.py @@ -62,8 +62,8 @@ def __init__(self, name): f"Received: name={name} (of type {type(name)})" ) self._name = name - self._compute_dtype = None - self._variable_dtype = None + self._compute_dtype = backend.floatx() + self._variable_dtype = backend.floatx() def _parse_name(self, name): """Parses a `DTypePolicy` name into a compute and variable dtype. 
From 43f09f78015a3f9a275f1b6ac146a61612261b24 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 6 Mar 2024 17:47:15 +0800 Subject: [PATCH 22/33] Remove unused code --- keras/layers/preprocessing/index_lookup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/layers/preprocessing/index_lookup.py b/keras/layers/preprocessing/index_lookup.py index 4f422d80a4e..a99651f62ea 100644 --- a/keras/layers/preprocessing/index_lookup.py +++ b/keras/layers/preprocessing/index_lookup.py @@ -188,7 +188,6 @@ def __init__( ) kwargs.pop("trainable", None) kwargs.pop("dtype", None) - kwargs.pop("quantization_mode", None) if kwargs: raise ValueError(f"Unrecognized keyword argument(s): {kwargs}") From 64d624164f4e8c2438e21cef93a55065d5f31191 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 11:20:23 +0800 Subject: [PATCH 23/33] Rename `mode=quantized_int8` to `mode=int8` and update dtype policy for backwards compatibility --- keras/dtype_policies/__init__.py | 10 +- keras/dtype_policies/dtype_policy.py | 76 ++++++----- keras/dtype_policies/dtype_policy_test.py | 154 ++++++++++++++++++++-- keras/layers/core/dense.py | 11 +- keras/layers/core/dense_test.py | 2 +- keras/layers/core/einsum_dense.py | 11 +- keras/layers/core/einsum_dense_test.py | 2 +- keras/layers/layer.py | 7 +- keras/models/model.py | 6 +- 9 files changed, 219 insertions(+), 60 deletions(-) diff --git a/keras/dtype_policies/__init__.py b/keras/dtype_policies/__init__.py index f885a4c0139..027fa1dd092 100644 --- a/keras/dtype_policies/__init__.py +++ b/keras/dtype_policies/__init__.py @@ -9,17 +9,17 @@ def get(identifier): if identifier is None: return dtype_policy.dtype_policy() - if isinstance(identifier, dtype_policy.DTypePolicy): + if isinstance(identifier, (FloatDTypePolicy, QuantizedDTypePolicy)): return identifier if isinstance(identifier, dict): return serialization_lib.deserialize_keras_object(identifier) if isinstance(identifier, str): - if "quantized" in identifier: - return dtype_policy.QuantizedDTypePolicy(identifier) + if "int8" in identifier: + return QuantizedDTypePolicy(identifier) else: - return dtype_policy.FloatDTypePolicy(identifier) + return FloatDTypePolicy(identifier) try: - return dtype_policy.DTypePolicy(backend.standardize_dtype(identifier)) + return FloatDTypePolicy(backend.standardize_dtype(identifier)) except: raise ValueError( "Cannot interpret `dtype` argument. Expected a string " diff --git a/keras/dtype_policies/dtype_policy.py b/keras/dtype_policies/dtype_policy.py index 7ee5a304db7..150a5f34a4f 100644 --- a/keras/dtype_policies/dtype_policy.py +++ b/keras/dtype_policies/dtype_policy.py @@ -1,3 +1,5 @@ +import warnings + from keras import backend from keras import ops from keras.api_export import keras_export @@ -55,12 +57,28 @@ class DTypePolicy: to explicitly construct a `DTypePolicy` object. """ - def __init__(self, name): + def __new__(cls, name): if not isinstance(name, str): raise TypeError( "'name' must be a string, such as 'mixed_float16'. " f"Received: name={name} (of type {type(name)})" ) + # For backwards compatibility + # TODO: We should consider deprecating this behavior + if cls is __class__: + warnings.warn( + "Consider using the subclass of DTypePolicy to initialize the " + "dtype policy such as FloatDTypePolicy and " + "QuantizedDTypePolicy." 
+ ) + else: + return super().__new__(cls) + if "int8" in name: + return QuantizedDTypePolicy(name) + else: + return FloatDTypePolicy(name) + + def __init__(self, name): self._name = name self._compute_dtype = backend.floatx() self._variable_dtype = backend.floatx() @@ -119,9 +137,6 @@ def name(self): """Returns the name of this policy.""" return self._name - def __repr__(self): - return f'' - def convert_input(self, x, autocast, dtype): dtype = backend.standardize_dtype(dtype) if backend.is_tensor(x): @@ -172,9 +187,9 @@ def _parse_name(self, name): return dtype, dtype except ValueError: raise ValueError( - f"Cannot convert '{name}' to a mixed precision DTypePolicy." - " Valid policies include 'mixed_float16', 'mixed_bfloat16', " - "and the name of any dtype such as 'float32'." + f"Cannot convert '{name}' to a mixed precision " + "FloatDTypePolicy. Valid policies include 'mixed_float16', " + "'mixed_bfloat16', and the name of any dtype such as 'float32'." ) def __repr__(self): @@ -187,38 +202,39 @@ def __repr__(self): class QuantizedDTypePolicy(DTypePolicy): def __init__(self, name): super().__init__(name) - self._compute_dtype, self._variable_dtype = self._parse_name(name) + self._quantization_mode, self._compute_dtype, self._variable_dtype = ( + self._parse_name(name) + ) def _parse_name(self, name): - if "_from_" in name: - quantization_mode, from_name = name.split("_from_") - self._quantization_mode = quantization_mode - if from_name == "mixed_float16": - return "float16", "float32" - elif from_name == "mixed_bfloat16": - return "bfloat16", "float32" - try: - dtype = backend.standardize_dtype(name) - return dtype, dtype - except ValueError: - raise ValueError( - f"Cannot convert '{name}' to a mixed precision DTypePolicy." - " Valid policies include 'mixed_float16', 'mixed_bfloat16'," - " and the name of any dtype such as 'float32'." - ) - raise ValueError( - f"Cannot convert '{name}' to a quantized DTypePolicy. " - "Valid policies are in the pattern of 'quantized_int8_from_(name)' " - "such as 'quantized_int8_from_mixed_bfloat16'." + error_msg = ( + f"Cannot convert '{name}' to a QuantizedDTypePolicy. " + "Valid policies include " + "'int8_from_float32', 'int8_from_float16', 'int8_from_bfloat16', " + "'int8_from_mixed_float16', 'int8_from_mixed_bfloat16'." ) + split_name = name.split("_from_") + if len(split_name) != 2: + raise ValueError(error_msg) + mode, from_name = split_name + if mode not in ("int8",): + raise ValueError(error_msg) + if from_name == "mixed_float16": + return mode, "float16", "float32" + elif from_name == "mixed_bfloat16": + return mode, "bfloat16", "float32" + try: + dtype = backend.standardize_dtype(from_name) + return mode, dtype, dtype + except ValueError: + raise ValueError(error_msg) @property def quantization_mode(self): """The quantization mode of this policy. Returns: - The quantization mode of this policy, as a string. `None` if no - quantization. + The quantization mode of this policy, as a string. 
""" return self._quantization_mode diff --git a/keras/dtype_policies/dtype_policy_test.py b/keras/dtype_policies/dtype_policy_test.py index f543226fe8f..dfd09fcf02a 100644 --- a/keras/dtype_policies/dtype_policy_test.py +++ b/keras/dtype_policies/dtype_policy_test.py @@ -1,4 +1,6 @@ from keras.dtype_policies.dtype_policy import DTypePolicy +from keras.dtype_policies.dtype_policy import FloatDTypePolicy +from keras.dtype_policies.dtype_policy import QuantizedDTypePolicy from keras.dtype_policies.dtype_policy import dtype_policy from keras.dtype_policies.dtype_policy import set_dtype_policy from keras.testing import test_case @@ -48,7 +50,7 @@ def test_properties(self): def test_repr(self): """Test __repr__ method.""" policy = DTypePolicy("mixed_float16") - self.assertEqual(repr(policy), '') + self.assertEqual(repr(policy), '') def test_get_config_from_config(self): """Test get_config and from_config methods.""" @@ -60,6 +62,120 @@ def test_get_config_from_config(self): self.assertEqual(new_policy.name, "mixed_float16") +class FloatDTypePolicyTest(test_case.TestCase): + def test_initialization_valid_name(self): + """Test initialization with a valid name.""" + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_initialization_invalid_name(self): + """Test initialization with an invalid name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("invalid_name") + + def test_initialization_non_string_name(self): + """Test initialization with a non-string name.""" + with self.assertRaisesRegex(TypeError, "'name' must be a string"): + FloatDTypePolicy(123) + + def test_properties_mixed_float16(self): + """Test properties for 'mixed_float16'.""" + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_properties_mixed_bfloat16(self): + """Test properties for 'mixed_bfloat16'.""" + policy = FloatDTypePolicy("mixed_bfloat16") + self.assertEqual(policy.compute_dtype, "bfloat16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_initialization_with_invalid_name_behaviour(self): + """Test initialization behavior with an invalid name.""" + with self.assertRaisesRegex(ValueError, "Cannot convert"): + FloatDTypePolicy("invalid_name") + + def test_properties(self): + """Test variable_dtype, compute_dtype, and name properties.""" + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(policy.variable_dtype, "float32") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.name, "mixed_float16") + + def test_repr(self): + """Test __repr__ method.""" + policy = FloatDTypePolicy("mixed_float16") + self.assertEqual(repr(policy), '') + + def test_get_config_from_config(self): + """Test get_config and from_config methods.""" + policy = FloatDTypePolicy("mixed_float16") + config = policy.get_config() + self.assertEqual(config, {"name": "mixed_float16"}) + + new_policy = FloatDTypePolicy.from_config(config) + self.assertEqual(new_policy.name, "mixed_float16") + + +class QuantizedDTypePolicyTest(test_case.TestCase): + def test_initialization_valid_name(self): + """Test initialization with a valid name.""" + policy = QuantizedDTypePolicy("int8_from_mixed_float16") + self.assertEqual(policy.compute_dtype, "float16") + self.assertEqual(policy.variable_dtype, "float32") + + def test_initialization_invalid_name(self): + """Test 
initialization with an invalid name."""
+        with self.assertRaisesRegex(ValueError, "Cannot convert"):
+            QuantizedDTypePolicy("invalid_name")
+
+    def test_initialization_non_string_name(self):
+        """Test initialization with a non-string name."""
+        with self.assertRaisesRegex(TypeError, "'name' must be a string"):
+            QuantizedDTypePolicy(123)
+
+    def test_properties_mixed_float16(self):
+        """Test properties for 'mixed_float16'."""
+        policy = QuantizedDTypePolicy("int8_from_mixed_float16")
+        self.assertEqual(policy.compute_dtype, "float16")
+        self.assertEqual(policy.variable_dtype, "float32")
+
+    def test_properties_mixed_bfloat16(self):
+        """Test properties for 'mixed_bfloat16'."""
+        policy = QuantizedDTypePolicy("int8_from_mixed_bfloat16")
+        self.assertEqual(policy.compute_dtype, "bfloat16")
+        self.assertEqual(policy.variable_dtype, "float32")
+
+    def test_initialization_with_invalid_name_behaviour(self):
+        """Test initialization behavior with an invalid name."""
+        with self.assertRaisesRegex(ValueError, "Cannot convert"):
+            QuantizedDTypePolicy("invalid_name")
+
+    def test_properties(self):
+        """Test variable_dtype, compute_dtype, and name properties."""
+        policy = QuantizedDTypePolicy("int8_from_mixed_float16")
+        self.assertEqual(policy.variable_dtype, "float32")
+        self.assertEqual(policy.compute_dtype, "float16")
+        self.assertEqual(policy.name, "int8_from_mixed_float16")
+
+    def test_repr(self):
+        """Test __repr__ method."""
+        policy = QuantizedDTypePolicy("int8_from_mixed_float16")
+        self.assertEqual(
+            repr(policy), '<QuantizedDTypePolicy "int8_from_mixed_float16">'
+        )
+
+    def test_get_config_from_config(self):
+        """Test get_config and from_config methods."""
+        policy = QuantizedDTypePolicy("int8_from_mixed_float16")
+        config = policy.get_config()
+        self.assertEqual(config, {"name": "int8_from_mixed_float16"})
+
+        new_policy = QuantizedDTypePolicy.from_config(config)
+        self.assertEqual(new_policy.name, "int8_from_mixed_float16")
+
+
 class DTypePolicyGlobalFunctionsTest(test_case.TestCase):
     def setUp(self):
         """Reset the global dtype policy before each test."""
@@ -72,8 +188,8 @@ def test_set_dtype_policy_valid_string(self):
         self.assertEqual(policy.name, "mixed_float16")
 
     def test_set_dtype_policy_valid_policy(self):
-        """Test set_dtype_policy with a valid DTypePolicy object."""
-        policy_obj = DTypePolicy("mixed_float16")
+        """Test set_dtype_policy with a valid FloatDTypePolicy object."""
+        policy_obj = FloatDTypePolicy("mixed_float16")
         set_dtype_policy(policy_obj)
         policy = dtype_policy()
         self.assertEqual(policy.name, "mixed_float16")
@@ -89,26 +205,48 @@ def test_dtype_policy_default(self):
         self.assertEqual(policy.name, "float32")
 
 
-class DTypePolicyEdgeCasesTest(test_case.TestCase):
+class FloatDTypePolicyEdgeCasesTest(test_case.TestCase):
+    def test_empty_name(self):
+        """Test initialization with an empty name."""
+        with self.assertRaisesRegex(ValueError, "Cannot convert"):
+            FloatDTypePolicy("")
+
+    def test_special_character_name(self):
+        """Test initialization with special characters in the name."""
+        with self.assertRaisesRegex(ValueError, "Cannot convert"):
+            FloatDTypePolicy("@mixed_float16!")
+
+    def test_very_long_name(self):
+        """Test initialization with a very long name."""
+        with self.assertRaisesRegex(ValueError, "Cannot convert"):
+            FloatDTypePolicy("mixed_float16" * 100)
+
+    def test_almost_valid_name(self):
+        """Test initialization with a name close to a valid one."""
+        with self.assertRaisesRegex(ValueError, "Cannot convert"):
+            FloatDTypePolicy("mixed_float15")
+
+
+class 
QuantizedDTypePolicyEdgeCasesTest(test_case.TestCase): def test_empty_name(self): """Test initialization with an empty name.""" with self.assertRaisesRegex(ValueError, "Cannot convert"): - DTypePolicy("") + QuantizedDTypePolicy("") def test_special_character_name(self): """Test initialization with special characters in the name.""" with self.assertRaisesRegex(ValueError, "Cannot convert"): - DTypePolicy("@mixed_float16!") + QuantizedDTypePolicy("@int8_from_mixed_float16!") def test_very_long_name(self): """Test initialization with a very long name.""" with self.assertRaisesRegex(ValueError, "Cannot convert"): - DTypePolicy("mixed_float16" * 100) + QuantizedDTypePolicy("int8_from_mixed_float16" * 100) def test_almost_valid_name(self): """Test initialization with a name close to a valid one.""" with self.assertRaisesRegex(ValueError, "Cannot convert"): - DTypePolicy("mixed_float15") + QuantizedDTypePolicy("int7_from_mixed_float16") class DTypePolicyGlobalFunctionsEdgeCasesTest(test_case.TestCase): diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 6c0563a876d..c23cd58354f 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -202,7 +202,7 @@ def enable_lora( def quantize(self, mode): self._check_quantize_args(mode, self.compute_dtype) - if mode == "quantized_int8": + if mode == "int8": # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() # Configure `self.inputs_quantizer` @@ -211,24 +211,25 @@ def quantize(self, mode): kernel_value, kernel_scale = quantizers.abs_max_quantize( self._kernel, axis=0 ) + kernel_scale = ops.cast(kernel_scale, self.compute_dtype) self._tracker.unlock() self._untrack_variable(self._kernel) self._kernel = self.add_weight( name="kernel", shape=self._kernel.shape, - initializer="zeros", + # Prevent adding a large constant to the computation graph + initializer=lambda shape, dtype: kernel_value, dtype="int8", trainable=False, ) - self._kernel.assign(kernel_value) self.kernel_scale = self.add_weight( name="kernel_scale", shape=kernel_scale.shape, - initializer="zeros", + # Prevent adding a large constant to the computation graph + initializer=lambda shape, dtype: kernel_scale, dtype=self.compute_dtype, trainable=False, ) - self.kernel_scale.assign(kernel_scale) self._tracker.lock() else: NotImplementedError() diff --git a/keras/layers/core/dense_test.py b/keras/layers/core/dense_test.py index 042db4bb0b5..7a94c78eb16 100644 --- a/keras/layers/core/dense_test.py +++ b/keras/layers/core/dense_test.py @@ -241,7 +241,7 @@ def test_enable_lora(self): model.save(temp_filepath) new_model = saving.load_model(temp_filepath) - self.assertFalse(new_model.layers[0].lora_enabled) + self.assertTrue(new_model.layers[0].lora_enabled) self.assertAllClose(model.predict(x), new_model.predict(x)) # Try saving and reloading the model's weights only diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index b0f42e16fe5..544861a8f0e 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -291,7 +291,7 @@ def enable_lora( def quantize(self, mode): self._check_quantize_args(mode, self.compute_dtype) - if mode == "quantized_int8": + if mode == "int8": # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() @@ -317,6 +317,7 @@ def quantize(self, mode): kernel_value, kernel_scale = quantizers.abs_max_quantize( self._kernel, axis=self._kernel_reduced_axes ) + kernel_scale = ops.cast(kernel_scale, 
self.compute_dtype) kernel_scale = ops.transpose( kernel_scale, self._kernel_transpose_axes ) @@ -333,19 +334,19 @@ def quantize(self, mode): self._kernel = self.add_weight( name="kernel", shape=self._kernel.shape, - initializer="zeros", + # Prevent adding a large constant to the computation graph + initializer=lambda shape, dtype: kernel_value, dtype="int8", trainable=False, ) - self._kernel.assign(kernel_value) self.kernel_scale = self.add_weight( name="kernel_scale", shape=kernel_scale.shape, - initializer="zeros", + # Prevent adding a large constant to the computation graph + initializer=lambda shape, dtype: kernel_scale, dtype=self.compute_dtype, trainable=False, ) - self.kernel_scale.assign(kernel_scale) self._tracker.lock() else: NotImplementedError() diff --git a/keras/layers/core/einsum_dense_test.py b/keras/layers/core/einsum_dense_test.py index b064f4fb02a..99f9c8aa54f 100644 --- a/keras/layers/core/einsum_dense_test.py +++ b/keras/layers/core/einsum_dense_test.py @@ -326,7 +326,7 @@ def test_enable_lora(self): model.save(temp_filepath) new_model = saving.load_model(temp_filepath) - self.assertFalse(new_model.layers[0].lora_enabled) + self.assertTrue(new_model.layers[0].lora_enabled) self.assertAllClose(model.predict(x), new_model.predict(x)) # Try saving and reloading the model's weights only diff --git a/keras/layers/layer.py b/keras/layers/layer.py index a0a50431ee9..349a7587178 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -1111,12 +1111,11 @@ def quantize(self, mode): ) def _check_quantize_args(self, mode, compute_dtype): - if mode not in (None, "quantized_int8"): + if mode not in ("int8",): raise ValueError( - "Currently, `quantize` must be one of " - f"(`None`, 'quantized_int8'). Received: mode={mode}" + f"`quantize` must be one of ('int8'). Received: mode={mode}" ) - if mode == "quantized_int8" and compute_dtype == "float16": + if mode == "int8" and compute_dtype == "float16": raise ValueError( f"mode='{mode}' doesn't work well with " "compute_dtype='float16'. Consider loading model/layer with " diff --git a/keras/models/model.py b/keras/models/model.py index 28434246ced..75703e60f57 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -370,12 +370,16 @@ def quantize(self, mode): Args: mode: The mode of the quantization. The supported modes are - `"quantized_int8"`. + ('int8'). """ if not self.built: raise ValueError( "The model must be built first before calling `quantize()`." ) + if mode not in ("int8",): + raise ValueError( + f"`quantize` must be one of ('int8'). 
Received: mode={mode}" + ) mode_changed = False for layer in self._flatten_layers(include_self=False, recursive=True): try: From 0ff27ed21f527fb1f4bd092a17bfe0bde24b13e5 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 11:20:48 +0800 Subject: [PATCH 24/33] Update demo script --- check_quantized_int8.py | 177 ++++++++++++++++++++++------------------ 1 file changed, 99 insertions(+), 78 deletions(-) diff --git a/check_quantized_int8.py b/check_quantized_int8.py index aef5dd274d8..2a22952f54a 100644 --- a/check_quantized_int8.py +++ b/check_quantized_int8.py @@ -1,3 +1,4 @@ +import argparse import os import time @@ -13,25 +14,16 @@ from keras import ops from keras import saving -# Set dtype policy -dtype = "mixed_bfloat16" # float32 -dtype_policies.dtype_policy.set_dtype_policy(dtype) -print(f"Global dtype policy: {dtype_policies.dtype_policy.dtype_policy()}") -# Model / data parameters -use_einsum = True -num_classes = 10 -input_shape = (28, 28, 1) -epochs = 1 - -# Load the data and split it between train and test sets -(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() -x_train = x_train.astype("float32") / 255 -x_test = x_test.astype("float32") / 255 -x_train = np.expand_dims(x_train, -1) -x_test = np.expand_dims(x_test, -1) -y_train = keras.utils.to_categorical(y_train, num_classes) -y_test = keras.utils.to_categorical(y_test, num_classes) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--dtype-policy", + default="float32", + choices=["float32", "bfloat16", "mixed_bfloat16"], + ) + parser.add_argument("--use-einsum", action="store_true") + return parser.parse_args() def build_model(num_layers=32, units=1024, use_einsum=False): @@ -81,63 +73,92 @@ def fn(x): return avg_time -model = build_model(num_layers=32, units=1024, use_einsum=use_einsum) -model.compile( - loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] -) - -"""Train float model""" -print("=====Start training float model=====") -model.fit(x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1) -print(f"Performance of {dtype}:") -score = model.evaluate(x_test, y_test, verbose=0) -print(f" Test accuracy: {score[1]:.5f}") -avg_time = benchmark(model) -print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") -model.save("model_fp32.keras") - -"""Enable lora""" -print("=====Enable lora weights=====") -enable_lora(model) - -"""Fine-tuning lora weights""" -model.compile( - loss="categorical_crossentropy", - optimizer="adam", - metrics=["accuracy"], -) -model.fit( - x_train, - y_train, - batch_size=128, - epochs=epochs, - validation_split=0.1, -) -print("Performance of fine-tuned lora weights:") -score = model.evaluate(x_test, y_test, verbose=0) -print(f" Test accuracy: {score[1]:.5f}") -avg_time = benchmark(model) -print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") - -"""Quantize to int8 weights""" -model.quantize(mode="quantized_int8") -int8_model = model -int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"]) -print("Performance of quantized model:") -score = int8_model.evaluate(x_test, y_test, verbose=0) -print(f" Test accuracy: {score[1]:.5f}") -avg_time = benchmark(int8_model) -print(f" Avg. 
inference time (batch_size=1024): {avg_time:.5f}s") - -"""Saving & loading""" -int8_model.save("model_int8.keras") -reloaded_int8_model = saving.load_model("model_int8.keras") -reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) -print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") -print("Size of saved model:") -print(f" fp32: {os.path.getsize('model_fp32.keras') >> 20}MB") -print(f" int8: {os.path.getsize('model_int8.keras') >> 20}MB") - -"""Cleanup""" -os.remove("model_fp32.keras") -os.remove("model_int8.keras") +def main(): + args = get_args() + + # Set dtype policy + dtype = args.dtype_policy + dtype_policies.dtype_policy.set_dtype_policy(dtype) + print(f"Global dtype policy: {dtype_policies.dtype_policy.dtype_policy()}") + + # Model / data parameters + use_einsum = args.use_einsum + num_classes = 10 + input_shape = (28, 28, 1) + epochs = 1 + + # Load the data and split it between train and test sets + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + y_train = keras.utils.to_categorical(y_train, num_classes) + y_test = keras.utils.to_categorical(y_test, num_classes) + + model = build_model(num_layers=32, units=1024, use_einsum=use_einsum) + model.compile( + loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] + ) + + """Train float model""" + print("=====Start training float model=====") + model.fit( + x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1 + ) + print(f"Performance of {dtype}:") + score = model.evaluate(x_test, y_test, verbose=0) + print(f" Test accuracy: {score[1]:.5f}") + avg_time = benchmark(model, input_shape=input_shape) + print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") + model.save("model_fp32.keras") + + """Enable lora""" + print("=====Enable lora weights=====") + enable_lora(model) + + """Fine-tuning lora weights""" + model.compile( + loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"], + ) + model.fit( + x_train, + y_train, + batch_size=128, + epochs=epochs, + validation_split=0.1, + ) + print("Performance of fine-tuned lora weights:") + score = model.evaluate(x_test, y_test, verbose=0) + print(f" Test accuracy: {score[1]:.5f}") + avg_time = benchmark(model, input_shape=input_shape) + print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") + + """Quantize to int8 weights""" + model.quantize(mode="int8") + int8_model = model + int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"]) + print("Performance of quantized model:") + score = int8_model.evaluate(x_test, y_test, verbose=0) + print(f" Test accuracy: {score[1]:.5f}") + avg_time = benchmark(int8_model, input_shape=input_shape) + print(f" Avg. 
inference time (batch_size=1024): {avg_time:.5f}s") + + """Saving & loading""" + int8_model.save("model_int8.keras") + reloaded_int8_model = saving.load_model("model_int8.keras") + reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) + print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") + print("Size of saved model:") + print(f" fp32: {os.path.getsize('model_fp32.keras') >> 20}MB") + print(f" int8: {os.path.getsize('model_int8.keras') >> 20}MB") + + """Cleanup""" + os.remove("model_fp32.keras") + os.remove("model_int8.keras") + + +if __name__ == "__main__": + main() From 499296ceb979da319030836eae9c8d5004c7d032 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 11:24:38 +0800 Subject: [PATCH 25/33] Delete demo script and result --- RESULT.md | 24 ------ check_quantized_int8.py | 164 ---------------------------------------- 2 files changed, 188 deletions(-) delete mode 100644 RESULT.md delete mode 100644 check_quantized_int8.py diff --git a/RESULT.md b/RESULT.md deleted file mode 100644 index adce5efc75b..00000000000 --- a/RESULT.md +++ /dev/null @@ -1,24 +0,0 @@ -# Result - -Configuration: - -- MNIST -- model - - `Dense` or `EinsumDense` - - `BatchNormalization` - - `ReLU` -- fine-tuning with `enable_lora(rank=2)` -- inference time: batch size=1024 - - float: `self.lora_enabled=True` - - int8: `self.lora_enabled=False` (merged) - -|backend|dtype_policy|layer|float acc.|int8 acc.|float inference time|int8 inference time|inference time ratio| -|-|-|-|-|-|-|-|-| -|tensorflow|float32|`Dense`|0.95990|0.96000|0.00395s|0.00198s|0.501| -|tensorflow|mixed_bfloat16|`Dense`|0.96110|0.96110|0.00265s|0.00200s|0.755| -|tensorflow|float32|`EinsumDense`|0.95950|0.95920|0.00384s|0.00188s|0.490| -|tensorflow|mixed_bfloat16|`EinsumDense`|0.95980|0.95970|0.00258s|0.00200s|0.775| -|jax|float32|`Dense`|0.96130|0.96160|0.00304s|0.00132s|0.434| -|jax|mixed_bfloat16|`Dense`|0.95290|0.95300|0.00177s|0.00133s|0.751| -|jax|float32|`EinsumDense`|0.96170|0.96160|0.00302s|0.00132s|0.437| -|jax|mixed_bfloat16|`EinsumDense`|0.95720|0.95680|0.00176s|0.00125s|0.710| diff --git a/check_quantized_int8.py b/check_quantized_int8.py deleted file mode 100644 index 2a22952f54a..00000000000 --- a/check_quantized_int8.py +++ /dev/null @@ -1,164 +0,0 @@ -import argparse -import os -import time - -import jax -import numpy as np -import tensorflow as tf - -import keras -from keras import backend -from keras import dtype_policies -from keras import layers -from keras import models -from keras import ops -from keras import saving - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--dtype-policy", - default="float32", - choices=["float32", "bfloat16", "mixed_bfloat16"], - ) - parser.add_argument("--use-einsum", action="store_true") - return parser.parse_args() - - -def build_model(num_layers=32, units=1024, use_einsum=False): - inputs = layers.Input([28, 28]) - x = layers.Flatten()(inputs) - for _ in range(num_layers): - if use_einsum: - x = layers.EinsumDense("ab,bc->ac", output_shape=[units])(x) - else: - x = layers.Dense(units)(x) - x = layers.BatchNormalization()(x) - x = layers.ReLU()(x) - outputs = layers.Dense(10, use_bias=True, activation="softmax")(x) - model = models.Model(inputs, outputs) - return model - - -def enable_lora(model): - for layer in model.layers: - if hasattr(layer, "enable_lora"): - layer.enable_lora(2) - - -def benchmark(model, batch_size=1024, input_shape=(28, 28), 
iterations=200): - def fn(x): - return model(x, training=False) - - if backend.backend() == "tensorflow": - jit_fn = tf.function(fn, jit_compile=True) - elif backend.backend() == "jax": - jit_fn = jax.jit(fn) - else: - jit_fn = fn - - # warmup - x = ops.ones([batch_size, *input_shape]) - for _ in range(10): - _ = ops.convert_to_numpy(jit_fn(x)) - - times = [] - for _ in range(iterations): - t0 = time.time() - _ = ops.convert_to_numpy(jit_fn(x)) - t1 = time.time() - times.append(t1 - t0) - avg_time = sum(times) / len(times) - return avg_time - - -def main(): - args = get_args() - - # Set dtype policy - dtype = args.dtype_policy - dtype_policies.dtype_policy.set_dtype_policy(dtype) - print(f"Global dtype policy: {dtype_policies.dtype_policy.dtype_policy()}") - - # Model / data parameters - use_einsum = args.use_einsum - num_classes = 10 - input_shape = (28, 28, 1) - epochs = 1 - - # Load the data and split it between train and test sets - (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() - x_train = x_train.astype("float32") / 255 - x_test = x_test.astype("float32") / 255 - x_train = np.expand_dims(x_train, -1) - x_test = np.expand_dims(x_test, -1) - y_train = keras.utils.to_categorical(y_train, num_classes) - y_test = keras.utils.to_categorical(y_test, num_classes) - - model = build_model(num_layers=32, units=1024, use_einsum=use_einsum) - model.compile( - loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"] - ) - - """Train float model""" - print("=====Start training float model=====") - model.fit( - x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1 - ) - print(f"Performance of {dtype}:") - score = model.evaluate(x_test, y_test, verbose=0) - print(f" Test accuracy: {score[1]:.5f}") - avg_time = benchmark(model, input_shape=input_shape) - print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") - model.save("model_fp32.keras") - - """Enable lora""" - print("=====Enable lora weights=====") - enable_lora(model) - - """Fine-tuning lora weights""" - model.compile( - loss="categorical_crossentropy", - optimizer="adam", - metrics=["accuracy"], - ) - model.fit( - x_train, - y_train, - batch_size=128, - epochs=epochs, - validation_split=0.1, - ) - print("Performance of fine-tuned lora weights:") - score = model.evaluate(x_test, y_test, verbose=0) - print(f" Test accuracy: {score[1]:.5f}") - avg_time = benchmark(model, input_shape=input_shape) - print(f" Avg. inference time (batch_size=1024): {avg_time:.5f}s") - - """Quantize to int8 weights""" - model.quantize(mode="int8") - int8_model = model - int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"]) - print("Performance of quantized model:") - score = int8_model.evaluate(x_test, y_test, verbose=0) - print(f" Test accuracy: {score[1]:.5f}") - avg_time = benchmark(int8_model, input_shape=input_shape) - print(f" Avg. 
inference time (batch_size=1024): {avg_time:.5f}s") - - """Saving & loading""" - int8_model.save("model_int8.keras") - reloaded_int8_model = saving.load_model("model_int8.keras") - reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0) - print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}") - print("Size of saved model:") - print(f" fp32: {os.path.getsize('model_fp32.keras') >> 20}MB") - print(f" int8: {os.path.getsize('model_int8.keras') >> 20}MB") - - """Cleanup""" - os.remove("model_fp32.keras") - os.remove("model_int8.keras") - - -if __name__ == "__main__": - main() From f6493cb6a5da8b1bf29b63e8556a807b62a2ddbf Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 12:29:40 +0800 Subject: [PATCH 26/33] Add `quantize` tests for Dense and EinsumDense --- keras/layers/core/dense.py | 5 ++ keras/layers/core/dense_test.py | 57 +++++++++++++++++++++ keras/layers/core/einsum_dense.py | 5 ++ keras/layers/core/einsum_dense_test.py | 71 ++++++++++++++++++++++++++ keras/layers/layer.py | 2 + 5 files changed, 140 insertions(+) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index c23cd58354f..27b6b1c7f7b 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -1,6 +1,7 @@ import numpy as np from keras import activations +from keras import backend from keras import constraints from keras import dtype_policies from keras import initializers @@ -203,6 +204,8 @@ def enable_lora( def quantize(self, mode): self._check_quantize_args(mode, self.compute_dtype) if mode == "int8": + if backend.standardize_dtype(self._kernel.dtype) == "int8": + raise ValueError("`quantize` can only be done once per layer.") # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() # Configure `self.inputs_quantizer` @@ -230,6 +233,8 @@ def quantize(self, mode): dtype=self.compute_dtype, trainable=False, ) + if self.bias is not None: + self.bias.trainable = False self._tracker.lock() else: NotImplementedError() diff --git a/keras/layers/core/dense_test.py b/keras/layers/core/dense_test.py index 7a94c78eb16..208d0e47588 100644 --- a/keras/layers/core/dense_test.py +++ b/keras/layers/core/dense_test.py @@ -304,3 +304,60 @@ def test_enable_lora_when_already_enabled(self): layer.enable_lora(rank=2) with self.assertRaisesRegex(ValueError, "lora is already enabled"): layer.enable_lora(rank=2) + + def test_quantize_int8(self): + layer = layers.Dense(units=16) + layer.build((None, 8)) + layer.quantize("int8") + + # Try eager call + x = np.random.random((2, 8)) + _ = layer(x) + + # Try saving and reloading the model + model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.weights.h5" + ) + model.save_weights(temp_filepath) + + @pytest.mark.requires_trainable_backend + def test_quantize_dtype_argument(self): + self.run_layer_test( + layers.Dense, + init_kwargs={ + "units": 5, + "dtype": "int8_from_mixed_bfloat16", + }, + input_shape=(2, 3, 4), + expected_output_shape=(2, 3, 5), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=3, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=True, + ) + + def 
test_quantize_on_unbuilt_layer(self): + layer = layers.Dense(units=2) + with self.assertRaisesRegex( + ValueError, "Cannot quantize on a layer that isn't yet built." + ): + layer.quantize("int8") + + def test_quantize_when_already_quantized(self): + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.quantize("int8") + with self.assertRaisesRegex( + ValueError, "`quantize` can only be done once per layer." + ): + layer.quantize("int8") diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index 544861a8f0e..e7609e75539 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -4,6 +4,7 @@ import numpy as np from keras import activations +from keras import backend from keras import constraints from keras import dtype_policies from keras import initializers @@ -292,6 +293,8 @@ def enable_lora( def quantize(self, mode): self._check_quantize_args(mode, self.compute_dtype) if mode == "int8": + if backend.standardize_dtype(self._kernel.dtype) == "int8": + raise ValueError("`quantize` can only be done once per layer.") # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() @@ -347,6 +350,8 @@ def quantize(self, mode): dtype=self.compute_dtype, trainable=False, ) + if self.bias is not None: + self.bias.trainable = False self._tracker.lock() else: NotImplementedError() diff --git a/keras/layers/core/einsum_dense_test.py b/keras/layers/core/einsum_dense_test.py index 99f9c8aa54f..c196af1fd48 100644 --- a/keras/layers/core/einsum_dense_test.py +++ b/keras/layers/core/einsum_dense_test.py @@ -372,3 +372,74 @@ def test_lora_rank_argument(self): expected_num_losses=0, supports_masking=False, ) + + def test_quantize_int8(self): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + layer.quantize("int8") + + # Try eager call + x = np.random.random((2, 3)) + _ = layer(x) + + # Try saving and reloading the model + model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Try saving and reloading the model's weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.weights.h5" + ) + model.save_weights(temp_filepath) + + @pytest.mark.requires_trainable_backend + def test_quantize_dtype_argument(self): + self.run_layer_test( + layers.EinsumDense, + init_kwargs={ + "equation": "ab,bcd->acd", + "output_shape": (8, 32), + "bias_axes": "d", + "dtype": "int8_from_mixed_bfloat16", + }, + input_shape=(2, 3), + expected_output_shape=(2, 8, 32), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=3, + expected_num_seed_generators=0, + expected_num_losses=0, + supports_masking=False, + ) + + def test_quantize_on_unbuilt_layer(self): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + with self.assertRaisesRegex( + ValueError, "Cannot quantize on a layer that isn't yet built." + ): + layer.quantize("int8") + + def test_quantize_when_already_quantized(self): + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + layer.quantize("int8") + with self.assertRaisesRegex( + ValueError, "`quantize` can only be done once per layer." 
+ ): + layer.quantize("int8") diff --git a/keras/layers/layer.py b/keras/layers/layer.py index 349a7587178..9ec39cb6d05 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -1111,6 +1111,8 @@ def quantize(self, mode): ) def _check_quantize_args(self, mode, compute_dtype): + if not self.built: + raise ValueError("Cannot quantize on a layer that isn't yet built.") if mode not in ("int8",): raise ValueError( f"`quantize` must be one of ('int8'). Received: mode={mode}" From 8b717bea0a79d0b6b147a885b25011f6e51a2beb Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 12:52:24 +0800 Subject: [PATCH 27/33] Add `quantize` tests for Model --- keras/models/model_test.py | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/keras/models/model_test.py b/keras/models/model_test.py index 09f168b94b7..40ae4d2f8d7 100644 --- a/keras/models/model_test.py +++ b/keras/models/model_test.py @@ -561,3 +561,43 @@ def test_functional_list_outputs_invalid_nested_list_losses(self): "it should have as many entries as the model has outputs", ): model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + + def test_quantize(self): + model = _get_model() + x1 = np.random.rand(2, 3) + x2 = np.random.rand(2, 3) + model.quantize("int8") + _ = model((x1, x2)) + + for layer in model._flatten_layers(): + if isinstance(layer, (layers.Dense, layers.EinsumDense)): + self.assertEqual(layer.dtype_policy.name, "int8_from_float32") + self.assertEqual(layer.dtype_policy.quantization_mode, "int8") + + def test_quantize_unbuilt(self): + class MyModel(Model): + def __init__(self): + super().__init__() + self.dense1 = layers.Dense(32, activation="relu") + self.dense2 = layers.Dense(5, activation="softmax") + self.dropout = layers.Dropout(0.5) + + def call(self, inputs, training=False): + x = self.dense1(inputs) + x = self.dropout(x, training=training) + return self.dense2(x) + + model = MyModel() + with self.assertRaisesRegex( + ValueError, "The model must be built first before calling" + ): + model.quantize("int8") + + x = np.random.rand(2, 3) + _ = model(x) + model.quantize("int8") + + def test_quantize_invalid_args(self): + model = _get_model() + with self.assertRaisesRegex(ValueError, "`quantize` must be one of"): + model.quantize("abc") From 792545af7f64067eabd5a7a2339845727d3bf8a4 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 13:20:31 +0800 Subject: [PATCH 28/33] Add tests for `keras.quantizers` --- keras/quantizers/quantizers.py | 9 +++++++++ keras/quantizers/quantizers_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 keras/quantizers/quantizers_test.py diff --git a/keras/quantizers/quantizers.py b/keras/quantizers/quantizers.py index 2051d61c2a2..b9fc111f886 100644 --- a/keras/quantizers/quantizers.py +++ b/keras/quantizers/quantizers.py @@ -71,6 +71,7 @@ def get_config(self): raise NotImplementedError(f"{self} does not implement get_config()") +@keras_export(["keras.AbsMaxQuantizer", "keras.quantizers.AbsMaxQuantizer"]) class AbsMaxQuantizer(Quantizer): def __init__( self, @@ -91,3 +92,11 @@ def __call__(self, x): x, self.axis, self.value_range, self.output_dtype, self.epsilon ) return quantized_x, scale + + def get_config(self): + return { + "axis": self.axis, + "value_range": self.value_range, + "epsilon": self.epsilon, + "output_dtype": self.output_dtype, + } diff --git a/keras/quantizers/quantizers_test.py 
b/keras/quantizers/quantizers_test.py new file mode 100644 index 00000000000..d646f612b9d --- /dev/null +++ b/keras/quantizers/quantizers_test.py @@ -0,0 +1,24 @@ +from keras import ops +from keras import quantizers +from keras import random +from keras import testing + + +class QuantizersTest(testing.TestCase): + def test_abs_max_quantizer(self): + values = random.uniform([3, 4, 5], minval=-1, maxval=1) + quantizer = quantizers.AbsMaxQuantizer(axis=-1) + + # Test quantize + quantized_values, scale = quantizer(values) + self.assertEqual(quantized_values.shape, [3, 4, 5]) + self.assertEqual(scale.shape, [3, 4, 1]) + self.assertLessEqual(ops.max(quantized_values), 127) + self.assertGreaterEqual(ops.min(quantized_values), -127) + + # Test dequantize + dequantized_values = ops.divide(quantized_values, scale) + self.assertAllClose(values, dequantized_values, atol=1) + + # Test serialization + self.run_class_serialization_test(quantizer) From 1d124bb2014c4576ba5025b360ff8a96d9aa63cf Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 13:28:47 +0800 Subject: [PATCH 29/33] Update `quantizers` tests --- keras/quantizers/quantizers_test.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/keras/quantizers/quantizers_test.py b/keras/quantizers/quantizers_test.py index d646f612b9d..61ae58a8428 100644 --- a/keras/quantizers/quantizers_test.py +++ b/keras/quantizers/quantizers_test.py @@ -9,16 +9,19 @@ def test_abs_max_quantizer(self): values = random.uniform([3, 4, 5], minval=-1, maxval=1) quantizer = quantizers.AbsMaxQuantizer(axis=-1) - # Test quantize + # Test quantizing quantized_values, scale = quantizer(values) - self.assertEqual(quantized_values.shape, [3, 4, 5]) - self.assertEqual(scale.shape, [3, 4, 1]) + self.assertEqual(tuple(quantized_values.shape), (3, 4, 5)) + self.assertEqual(tuple(scale.shape), (3, 4, 1)) self.assertLessEqual(ops.max(quantized_values), 127) self.assertGreaterEqual(ops.min(quantized_values), -127) - # Test dequantize + # Test dequantizing dequantized_values = ops.divide(quantized_values, scale) - self.assertAllClose(values, dequantized_values, atol=1) + rmse = ops.sqrt( + ops.mean(ops.square(ops.subtract(values, dequantized_values))) + ) + self.assertLess(rmse, 1e-1) # loose assertion # Test serialization self.run_class_serialization_test(quantizer) From f7d7cd7497f9bc600b191069e841e19db8630a0c Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:25:35 +0800 Subject: [PATCH 30/33] Improve test coverage --- keras/backend/torch/numpy.py | 6 +++++- keras/dtype_policies/dtype_policy.py | 2 +- keras/dtype_policies/dtype_policy_test.py | 13 +++++++++++++ keras/layers/core/dense_test.py | 8 ++++++++ keras/layers/core/einsum_dense_test.py | 12 ++++++++++++ keras/quantizers/__init__.py | 4 ++-- keras/quantizers/quantizers_test.py | 10 ++++++++++ 7 files changed, 51 insertions(+), 4 deletions(-) diff --git a/keras/backend/torch/numpy.py b/keras/backend/torch/numpy.py index 6b80a8378ed..60499e60dce 100644 --- a/keras/backend/torch/numpy.py +++ b/keras/backend/torch/numpy.py @@ -35,8 +35,12 @@ def einsum(subscripts, *operands, **kwargs): # the behavior of jax. 
dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands)) if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8": + compute_dtype = "int32" + if get_device() == "cuda": + # TODO: torch.einsum doesn't support int32 when using cuda + compute_dtype = config.floatx() # prevent overflow - operands = [cast(operand, "int32") for operand in operands] + operands = [cast(operand, compute_dtype) for operand in operands] return cast(torch.einsum(subscripts, *operands), "int32") return torch.einsum(subscripts, *operands) diff --git a/keras/dtype_policies/dtype_policy.py b/keras/dtype_policies/dtype_policy.py index 150a5f34a4f..ee8692f9192 100644 --- a/keras/dtype_policies/dtype_policy.py +++ b/keras/dtype_policies/dtype_policy.py @@ -258,7 +258,7 @@ def set_dtype_policy(policy): """ if not isinstance(policy, DTypePolicy): if isinstance(policy, str): - if "quantized" in policy: + if "int8" in policy: policy = QuantizedDTypePolicy(policy) else: policy = FloatDTypePolicy(policy) diff --git a/keras/dtype_policies/dtype_policy_test.py b/keras/dtype_policies/dtype_policy_test.py index dfd09fcf02a..8b3d6d43b6c 100644 --- a/keras/dtype_policies/dtype_policy_test.py +++ b/keras/dtype_policies/dtype_policy_test.py @@ -187,6 +187,12 @@ def test_set_dtype_policy_valid_string(self): policy = dtype_policy() self.assertEqual(policy.name, "mixed_float16") + def test_set_dtype_policy_valid_string_quantized(self): + """Test set_dtype_policy with a valid string.""" + set_dtype_policy("int8_from_mixed_float16") + policy = dtype_policy() + self.assertEqual(policy.name, "int8_from_mixed_float16") + def test_set_dtype_policy_valid_policy(self): """Test set_dtype_policy with a valid FloatDTypePolicy object.""" policy_obj = FloatDTypePolicy("mixed_float16") @@ -194,6 +200,13 @@ def test_set_dtype_policy_valid_policy(self): policy = dtype_policy() self.assertEqual(policy.name, "mixed_float16") + def test_set_dtype_policy_valid_policy_quantized(self): + """Test set_dtype_policy with a valid FloatDTypePolicy object.""" + policy_obj = QuantizedDTypePolicy("int8_from_mixed_float16") + set_dtype_policy(policy_obj) + policy = dtype_policy() + self.assertEqual(policy.name, "int8_from_mixed_float16") + def test_set_dtype_policy_invalid(self): """Test set_dtype_policy with an invalid input.""" with self.assertRaisesRegex(ValueError, "Invalid `policy` argument"): diff --git a/keras/layers/core/dense_test.py b/keras/layers/core/dense_test.py index 208d0e47588..fcfffed1653 100644 --- a/keras/layers/core/dense_test.py +++ b/keras/layers/core/dense_test.py @@ -329,6 +329,14 @@ def test_quantize_int8(self): ) model.save_weights(temp_filepath) + # Try lora + layer = layers.Dense(units=16) + layer.build((None, 8)) + layer.enable_lora(4) + layer.quantize("int8") + x = np.random.random((2, 8)) + _ = layer(x) + @pytest.mark.requires_trainable_backend def test_quantize_dtype_argument(self): self.run_layer_test( diff --git a/keras/layers/core/einsum_dense_test.py b/keras/layers/core/einsum_dense_test.py index c196af1fd48..e7126ed3f3c 100644 --- a/keras/layers/core/einsum_dense_test.py +++ b/keras/layers/core/einsum_dense_test.py @@ -401,6 +401,18 @@ def test_quantize_int8(self): ) model.save_weights(temp_filepath) + # Try lora + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + ) + layer.build((None, 3)) + layer.enable_lora(2) + layer.quantize("int8") + x = np.random.random((2, 3)) + _ = layer(x) + @pytest.mark.requires_trainable_backend def 
test_quantize_dtype_argument(self): self.run_layer_test( diff --git a/keras/quantizers/__init__.py b/keras/quantizers/__init__.py index 8da3a15cf3d..6139476a1fc 100644 --- a/keras/quantizers/__init__.py +++ b/keras/quantizers/__init__.py @@ -30,7 +30,7 @@ def deserialize(config, custom_objects=None): @keras_export("keras.quantizers.get") -def get(identifier): +def get(identifier, **kwargs): """Retrieve a Keras quantizer object via an identifier.""" if identifier is None: return None @@ -43,7 +43,7 @@ def get(identifier): if callable(obj): if inspect.isclass(obj): - obj = obj() + obj = obj(kwargs) return obj else: raise ValueError( diff --git a/keras/quantizers/quantizers_test.py b/keras/quantizers/quantizers_test.py index 61ae58a8428..1483fb7cc9c 100644 --- a/keras/quantizers/quantizers_test.py +++ b/keras/quantizers/quantizers_test.py @@ -5,6 +5,16 @@ class QuantizersTest(testing.TestCase): + def test_get_method(self): + quantizer = quantizers.get("abs_max_quantizer", axis=-1) + self.assertTrue(quantizer, quantizers.AbsMaxQuantizer) + + quantizer = quantizers.get(None) + self.assertEqual(quantizer, None) + + with self.assertRaises(ValueError): + quantizers.get("typo") + def test_abs_max_quantizer(self): values = random.uniform([3, 4, 5], minval=-1, maxval=1) quantizer = quantizers.AbsMaxQuantizer(axis=-1) From eb76c902f5cc823d2a6f04518c8a1d1a63a2c00a Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 20:47:56 +0800 Subject: [PATCH 31/33] Resolve comments --- keras/dtype_policies/dtype_policy.py | 18 ++++---------- keras/layers/core/dense.py | 5 +++- keras/layers/layer.py | 3 ++- keras/models/model.py | 7 +++--- keras/quantizers/quantizers.py | 36 ++++++++++++++-------------- 5 files changed, 33 insertions(+), 36 deletions(-) diff --git a/keras/dtype_policies/dtype_policy.py b/keras/dtype_policies/dtype_policy.py index ee8692f9192..26e063df677 100644 --- a/keras/dtype_policies/dtype_policy.py +++ b/keras/dtype_policies/dtype_policy.py @@ -1,5 +1,3 @@ -import warnings - from keras import backend from keras import ops from keras.api_export import keras_export @@ -66,17 +64,10 @@ def __new__(cls, name): # For backwards compatibility # TODO: We should consider deprecating this behavior if cls is __class__: - warnings.warn( - "Consider using the subclass of DTypePolicy to initialize the " - "dtype policy such as FloatDTypePolicy and " - "QuantizedDTypePolicy." - ) - else: - return super().__new__(cls) - if "int8" in name: - return QuantizedDTypePolicy(name) - else: + if "int8" in name: + return QuantizedDTypePolicy(name) return FloatDTypePolicy(name) + return super().__new__(cls) def __init__(self, name): self._name = name @@ -189,7 +180,8 @@ def _parse_name(self, name): raise ValueError( f"Cannot convert '{name}' to a mixed precision " "FloatDTypePolicy. Valid policies include 'mixed_float16', " - "'mixed_bfloat16', and the name of any dtype such as 'float32'." + "'mixed_bfloat16', and the name of any float dtype such as " + "'float32'." ) def __repr__(self): diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index 27b6b1c7f7b..a2210c83241 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -237,7 +237,10 @@ def quantize(self, mode): self.bias.trainable = False self._tracker.lock() else: - NotImplementedError() + NotImplementedError( + "Invalid quantization mode. Expected 'int8'. 
" + f"Received: mode={mode}" + ) # Set new dtype policy if not isinstance( diff --git a/keras/layers/layer.py b/keras/layers/layer.py index 9ec39cb6d05..b41ead720e1 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -1121,7 +1121,8 @@ def _check_quantize_args(self, mode, compute_dtype): raise ValueError( f"mode='{mode}' doesn't work well with " "compute_dtype='float16'. Consider loading model/layer with " - "other dtype policy before calling `quantize`." + "other dtype policy such as 'mixed_bfloat16' before calling " + "`quantize`." ) def save_own_variables(self, store): diff --git a/keras/models/model.py b/keras/models/model.py index 75703e60f57..3d51bf868a4 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -369,8 +369,8 @@ def quantize(self, mode): will be skipped if the layer doesn't implement the function. Args: - mode: The mode of the quantization. The supported modes are - ('int8'). + mode: The mode of the quantization. Only 'int8' is supported at this + time. """ if not self.built: raise ValueError( @@ -378,7 +378,8 @@ def quantize(self, mode): ) if mode not in ("int8",): raise ValueError( - f"`quantize` must be one of ('int8'). Received: mode={mode}" + "Invalid quantization mode. Expected 'int8'. " + f"Received: mode={mode}" ) mode_changed = False for layer in self._flatten_layers(include_self=False, recursive=True): diff --git a/keras/quantizers/quantizers.py b/keras/quantizers/quantizers.py index b9fc111f886..e21f10bcc38 100644 --- a/keras/quantizers/quantizers.py +++ b/keras/quantizers/quantizers.py @@ -3,24 +3,6 @@ from keras.api_export import keras_export -@keras_export(["keras.quantizers.abs_max_quantize"]) -def abs_max_quantize( - inputs, - axis, - value_range=(-127, 127), - dtype="int8", - epsilon=backend.epsilon(), -): - scale = ops.divide( - value_range[1], - ops.add(ops.max(ops.abs(inputs), axis=axis, keepdims=True), epsilon), - ) - outputs = ops.multiply(inputs, scale) - outputs = ops.clip(ops.round(outputs), value_range[0], value_range[1]) - outputs = ops.cast(outputs, dtype) - return outputs, scale - - @keras_export(["keras.Quantizer", "keras.quantizers.Quantizer"]) class Quantizer: def __init__(self, output_dtype="int8"): @@ -71,6 +53,24 @@ def get_config(self): raise NotImplementedError(f"{self} does not implement get_config()") +@keras_export(["keras.quantizers.abs_max_quantize"]) +def abs_max_quantize( + inputs, + axis, + value_range=(-127, 127), + dtype="int8", + epsilon=backend.epsilon(), +): + scale = ops.divide( + value_range[1], + ops.add(ops.max(ops.abs(inputs), axis=axis, keepdims=True), epsilon), + ) + outputs = ops.multiply(inputs, scale) + outputs = ops.clip(ops.round(outputs), value_range[0], value_range[1]) + outputs = ops.cast(outputs, dtype) + return outputs, scale + + @keras_export(["keras.AbsMaxQuantizer", "keras.quantizers.AbsMaxQuantizer"]) class AbsMaxQuantizer(Quantizer): def __init__( From 243445f10592af1daaf3f0d3546495eea7ebd13b Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 7 Mar 2024 20:58:38 +0800 Subject: [PATCH 32/33] Fix test --- keras/models/model_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras/models/model_test.py b/keras/models/model_test.py index 40ae4d2f8d7..f650bf79c96 100644 --- a/keras/models/model_test.py +++ b/keras/models/model_test.py @@ -599,5 +599,7 @@ def call(self, inputs, training=False): def test_quantize_invalid_args(self): model = _get_model() - with self.assertRaisesRegex(ValueError, "`quantize` 
must be one of"): + with self.assertRaisesRegex( + ValueError, "Invalid quantization mode. Expected 'int8'." + ): model.quantize("abc") From 50c6e74f6c675c5c44d79e511341b65385163d27 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 8 Mar 2024 10:37:53 +0800 Subject: [PATCH 33/33] Improve `Model.quantize` to identify the leaves of the model --- keras/models/model.py | 14 ++++++++------ keras/models/model_test.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/keras/models/model.py b/keras/models/model.py index 3d51bf868a4..2adb88bd7c5 100644 --- a/keras/models/model.py +++ b/keras/models/model.py @@ -382,12 +382,14 @@ def quantize(self, mode): f"Received: mode={mode}" ) mode_changed = False - for layer in self._flatten_layers(include_self=False, recursive=True): - try: - layer.quantize(mode) - mode_changed = True - except NotImplementedError as e: - warnings.warn(str(e)) + for layer in self._flatten_layers(): + list_of_sublayers = list(layer._flatten_layers()) + if len(list_of_sublayers) == 1: # leaves of the model + try: + layer.quantize(mode) + mode_changed = True + except NotImplementedError as e: + warnings.warn(str(e)) # We need to set these functions to `None` to remake them for changed # call function if mode_changed: diff --git a/keras/models/model_test.py b/keras/models/model_test.py index f650bf79c96..263a200ba28 100644 --- a/keras/models/model_test.py +++ b/keras/models/model_test.py @@ -2,6 +2,7 @@ import pytest from absl.testing import parameterized +from keras import backend from keras import layers from keras import testing from keras.layers.core.input_layer import Input @@ -603,3 +604,40 @@ def test_quantize_invalid_args(self): ValueError, "Invalid quantization mode. Expected 'int8'." ): model.quantize("abc") + + def test_quantize_nested_model(self): + class NestedLayer(layers.Layer): + def __init__(self, units): + super().__init__() + self.dense = layers.Dense(units) + + def call(self, x): + x = self.dense(x) + return x + + class DoubleNestedLayer(layers.Layer): + def __init__(self, units): + super().__init__() + self.nested_dense1 = NestedLayer(units) + self.nested_dense2 = NestedLayer(units) + self.dense = layers.Dense(units) + + def call(self, x): + x = self.nested_dense1(x) + x = self.nested_dense2(x) + x = self.dense(x) + return x + + inputs = layers.Input([3]) + outputs = DoubleNestedLayer(8)(inputs) + model = Model(inputs, outputs) + model.quantize("int8") + + kernel_count = 0 + for weight in model.weights: + if weight.name == "kernel": + kernel_count += 1 + self.assertEqual( + backend.standardize_dtype(weight.dtype), "int8" + ) + self.assertEqual(kernel_count, 3)
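
Taken together, the patches above give `Dense`, `EinsumDense`, and `Model` a post-training int8 quantization path driven by `quantize("int8")` and the new `int8_from_*` dtype policies. The following end-to-end sketch shows how the resulting API is exercised; the layer sizes, the toy input data, and the optional `enable_lora` step are illustrative choices, not part of the patches themselves.

import numpy as np

import keras
from keras import layers

# Build a float model as usual (toy dimensions, for illustration only).
inputs = layers.Input([16])
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs)

# Optional: fine-tune low-rank adapters first. `quantize` merges the LoRA
# deltas back into each kernel before converting it to int8.
for layer in model.layers:
    if hasattr(layer, "enable_lora"):
        layer.enable_lora(2)

# Post-training quantization: kernels become int8 variables with a
# per-output-channel `kernel_scale`, and inputs are quantized dynamically
# by `AbsMaxQuantizer` at call time.
model.quantize("int8")
preds = model(np.random.rand(4, 16))

# Affected layers now carry a quantized dtype policy, e.g.
# "int8_from_float32" (or "int8_from_mixed_bfloat16" under mixed precision).
print(model.layers[1].dtype_policy.name)
print(model.layers[1].dtype_policy.quantization_mode)  # "int8"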