Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce dynamic int8 quantization API #19263

Merged
merged 36 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9c5f3cc
Add `quantize` to `Dense`
james77777778 Feb 26, 2024
ef48715
Fix `dynamic_int8_call`
james77777778 Feb 26, 2024
3456c66
Cleanup for unused `quantization_trainable`
james77777778 Feb 26, 2024
5d39769
Add demo script
james77777778 Feb 26, 2024
f549268
Update script
james77777778 Feb 26, 2024
26c4585
Merge branch 'keras-team:master' into add-quantize-v2
james77777778 Mar 1, 2024
047d056
Update dtype policy
james77777778 Mar 1, 2024
a7209f4
Update dtype_policy
james77777778 Mar 1, 2024
ee5da63
Update demo script
james77777778 Mar 1, 2024
723e69a
Update Dense
james77777778 Mar 2, 2024
18afecd
Merge branch 'keras-team:master' into add-quantize-v2
james77777778 Mar 2, 2024
b36560a
Remove unused `input_shape`
james77777778 Mar 3, 2024
e4a8996
Merge branch 'keras-team:master' into add-quantize-v2
james77777778 Mar 4, 2024
ad6dd6b
Add `self.is_quantized_int8` to `Operation` and some minor updates
james77777778 Mar 4, 2024
1ebf41e
Add `quantize` to `EinsumDense`
james77777778 Mar 4, 2024
443a6b1
Add `ab,bc->ac` custom ops
james77777778 Mar 5, 2024
d774cd3
Update
james77777778 Mar 5, 2024
5c27a44
Update RESULT.md
james77777778 Mar 5, 2024
0571370
Update RESULT.md
james77777778 Mar 5, 2024
515e193
Update `InputSpec`
james77777778 Mar 5, 2024
885174f
Raise NotImplementedError in `int8_call` and `quantize`
james77777778 Mar 6, 2024
32277fc
Fix variable creation issue in `quantize`
james77777778 Mar 6, 2024
95cfe0b
Introduce `FloatDtypePolicy` and `QuantizedDtypePolicy`
james77777778 Mar 6, 2024
7370914
Defaults to `backend.floatx()` to `self._compute_dtype` and `self._variable_dtype`
james77777778 Mar 6, 2024
43f09f7
Remove unused code
james77777778 Mar 6, 2024
64d6241
Rename `mode=quantized_int8` to `mode=int8` and update dtype policy f…
james77777778 Mar 7, 2024
0ff27ed
Update demo script
james77777778 Mar 7, 2024
499296c
Delete demo script and result
james77777778 Mar 7, 2024
f6493cb
Add `quantize` tests for Dense and EinsumDense
james77777778 Mar 7, 2024
8b717be
Add `quantize` tests for Model
james77777778 Mar 7, 2024
792545a
Add tests for `keras.quantizers`
james77777778 Mar 7, 2024
1d124bb
Update `quantizers` tests
james77777778 Mar 7, 2024
f7d7cd7
Improve test coverage
james77777778 Mar 7, 2024
eb76c90
Resolve comments
james77777778 Mar 7, 2024
243445f
Fix test
james77777778 Mar 7, 2024
50c6e74
Improve `Model.quantize` to identify the leaves of the model
james77777778 Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def is_valid_for_custom_ops(subscripts, *operands):
# `None`.
if subscripts in [
"a,b->ab",
"ab,bc->ac",
"abc,cd->abd",
"abcd,abed->abce",
"abcd,adbe->acbe",
Expand Down Expand Up @@ -158,6 +159,8 @@ def use_custom_ops(subscripts, *operands, output_type):
x = tf.expand_dims(x, axis=-1)
y = tf.expand_dims(y, axis=0)
return tf.matmul(x, y, output_type=output_type)
elif subscripts == "ab,bc->ac":
return tf.matmul(x, y, output_type=output_type)
elif subscripts == "abc,cd->abd":
return tf.matmul(x, y, output_type=output_type)
elif subscripts == "abc,cde->abde":
Expand Down
6 changes: 5 additions & 1 deletion keras/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ def einsum(subscripts, *operands, **kwargs):
# the behavior of jax.
dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands))
if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == "int8":
compute_dtype = "int32"
if get_device() == "cuda":
# TODO: torch.einsum doesn't support int32 when using cuda
compute_dtype = config.floatx()
# prevent overflow
operands = [cast(operand, "int32") for operand in operands]
operands = [cast(operand, compute_dtype) for operand in operands]
return cast(torch.einsum(subscripts, *operands), "int32")
return torch.einsum(subscripts, *operands)

Expand Down
14 changes: 10 additions & 4 deletions keras/dtype_policies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
from keras import backend
from keras.dtype_policies import dtype_policy
from keras.saving import serialization_lib
from keras.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.dtype_policies.dtype_policy import QuantizedDTypePolicy


def get(identifier):
from keras.saving import serialization_lib

if identifier is None:
return dtype_policy.dtype_policy()
if isinstance(identifier, dtype_policy.DTypePolicy):
if isinstance(identifier, (FloatDTypePolicy, QuantizedDTypePolicy)):
return identifier
if isinstance(identifier, dict):
return serialization_lib.deserialize_keras_object(identifier)
if isinstance(identifier, str):
return dtype_policy.DTypePolicy(identifier)
if "int8" in identifier:
return QuantizedDTypePolicy(identifier)
else:
return FloatDTypePolicy(identifier)
try:
return dtype_policy.DTypePolicy(backend.standardize_dtype(identifier))
return FloatDTypePolicy(backend.standardize_dtype(identifier))
except:
raise ValueError(
"Cannot interpret `dtype` argument. Expected a string "
Expand Down
116 changes: 94 additions & 22 deletions keras/dtype_policies/dtype_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,24 @@ class DTypePolicy:
to explicitly construct a `DTypePolicy` object.
"""

def __init__(self, name):
def __new__(cls, name):
if not isinstance(name, str):
raise TypeError(
"'name' must be a string, such as 'mixed_float16'. "
f"Received: name={name} (of type {type(name)})"
)
# For backwards compatibility
# TODO: We should consider deprecating this behavior
if cls is __class__:
james77777778 marked this conversation as resolved.
Show resolved Hide resolved
if "int8" in name:
return QuantizedDTypePolicy(name)
return FloatDTypePolicy(name)
return super().__new__(cls)

def __init__(self, name):
self._name = name
self._compute_dtype, self._variable_dtype = self._parse_name(name)
# TODO: check that the current hardware supports the provided
# dtype policy and raise/warn otherwise.
self._compute_dtype = backend.floatx()
self._variable_dtype = backend.floatx()

def _parse_name(self, name):
"""Parses a `DTypePolicy` name into a compute and variable dtype.
Expand All @@ -75,19 +83,7 @@ def _parse_name(self, name):
Returns:
The `(compute_dtype, variable_dtype)` pair.
"""
if name == "mixed_float16":
return "float16", "float32"
elif name == "mixed_bfloat16":
return "bfloat16", "float32"
try:
dtype = backend.standardize_dtype(name)
return dtype, dtype
except ValueError:
raise ValueError(
f"Cannot convert '{name}' to a mixed precision DTypePolicy."
" Valid policies include 'mixed_float16', 'mixed_bfloat16', "
"and the name of any dtype such as 'float32'."
)
raise NotImplementedError

@property
def variable_dtype(self):
Expand Down Expand Up @@ -132,9 +128,6 @@ def name(self):
"""Returns the name of this policy."""
return self._name

def __repr__(self):
return f'<DTypePolicy "{self._name}">'

def convert_input(self, x, autocast, dtype):
dtype = backend.standardize_dtype(dtype)
if backend.is_tensor(x):
Expand Down Expand Up @@ -165,6 +158,82 @@ def from_config(cls, config):
return cls(**config)


@keras_export(
    ["keras.FloatDTypePolicy", "keras.dtype_policies.FloatDTypePolicy"]
)
class FloatDTypePolicy(DTypePolicy):
    """Dtype policy for floating-point computation.

    Resolves a policy ``name`` into a ``(compute_dtype, variable_dtype)``
    pair: the two ``mixed_*`` policies compute in a reduced-precision dtype
    while keeping variables in ``"float32"``; any other valid dtype name
    uses that dtype for both.
    """

    def __init__(self, name):
        super().__init__(name)
        self._compute_dtype, self._variable_dtype = self._parse_name(name)
        # TODO: check that the current hardware supports the provided
        # dtype policy and raise/warn otherwise.

    def _parse_name(self, name):
        """Return the ``(compute_dtype, variable_dtype)`` pair for `name`.

        Raises:
            ValueError: if `name` is neither a mixed-precision policy name
                nor a dtype understood by `backend.standardize_dtype`.
        """
        # Mixed-precision policies: reduced compute dtype, float32 variables.
        mixed_policies = {
            "mixed_float16": ("float16", "float32"),
            "mixed_bfloat16": ("bfloat16", "float32"),
        }
        if name in mixed_policies:
            return mixed_policies[name]
        try:
            dtype = backend.standardize_dtype(name)
        except ValueError:
            raise ValueError(
                f"Cannot convert '{name}' to a mixed precision "
                "FloatDTypePolicy. Valid policies include 'mixed_float16', "
                "'mixed_bfloat16', and the name of any float dtype such as "
                "'float32'."
            )
        # Plain dtype name: compute and variable dtypes coincide.
        return dtype, dtype

    def __repr__(self):
        return f'<FloatDTypePolicy "{self._name}">'


@keras_export(
    ["keras.QuantizedDTypePolicy", "keras.dtype_policies.QuantizedDTypePolicy"]
)
class QuantizedDTypePolicy(DTypePolicy):
    """Dtype policy for quantized computation.

    Parses names of the form ``"<mode>_from_<float policy>"`` (currently
    only ``mode="int8"``) into a quantization mode plus the
    ``(compute_dtype, variable_dtype)`` pair of the underlying float policy.
    """

    def __init__(self, name):
        super().__init__(name)
        self._quantization_mode, self._compute_dtype, self._variable_dtype = (
            self._parse_name(name)
        )

    def _parse_name(self, name):
        """Return ``(mode, compute_dtype, variable_dtype)`` parsed from `name`.

        Raises:
            ValueError: if `name` does not have exactly one ``"_from_"``
                separator, uses an unsupported mode, or names an invalid
                source dtype.
        """
        error = ValueError(
            f"Cannot convert '{name}' to a QuantizedDTypePolicy. "
            "Valid policies include "
            "'int8_from_float32', 'int8_from_float16', 'int8_from_bfloat16', "
            "'int8_from_mixed_float16', 'int8_from_mixed_bfloat16'."
        )
        parts = name.split("_from_")
        if len(parts) != 2:
            raise error
        mode, from_name = parts
        # Only int8 quantization is supported so far.
        if mode not in ("int8",):
            raise error
        # Mixed-precision sources keep float32 variables, like
        # FloatDTypePolicy does.
        mixed_policies = {
            "mixed_float16": ("float16", "float32"),
            "mixed_bfloat16": ("bfloat16", "float32"),
        }
        if from_name in mixed_policies:
            compute_dtype, variable_dtype = mixed_policies[from_name]
            return mode, compute_dtype, variable_dtype
        try:
            dtype = backend.standardize_dtype(from_name)
        except ValueError:
            raise error
        return mode, dtype, dtype

    @property
    def quantization_mode(self):
        """The quantization mode of this policy.

        Returns:
            The quantization mode of this policy, as a string.
        """
        return self._quantization_mode

    def __repr__(self):
        return f'<QuantizedDTypePolicy "{self._name}">'


@keras_export(
[
"keras.config.set_dtype_policy",
Expand All @@ -181,7 +250,10 @@ def set_dtype_policy(policy):
"""
if not isinstance(policy, DTypePolicy):
if isinstance(policy, str):
policy = DTypePolicy(policy)
if "int8" in policy:
policy = QuantizedDTypePolicy(policy)
else:
policy = FloatDTypePolicy(policy)
else:
raise ValueError(
"Invalid `policy` argument. "
Expand All @@ -204,6 +276,6 @@ def dtype_policy():
"""Returns the current default dtype policy object."""
policy = global_state.get_global_attribute("dtype_policy", None)
if policy is None:
policy = DTypePolicy(backend.floatx())
policy = FloatDTypePolicy(backend.floatx())
set_dtype_policy(policy)
return policy
Loading