forked from sony/model_optimization
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TF matmul with const to Dense layer substitution (sony#880)
- Loading branch information
Showing
6 changed files
with
199 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
108 changes: 108 additions & 0 deletions
108
...l_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from model_compression_toolkit.logger import Logger | ||
from model_compression_toolkit.core import common | ||
from model_compression_toolkit.core.common.graph.base_graph import Graph | ||
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher | ||
from model_compression_toolkit.core.common.graph.base_node import BaseNode | ||
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode | ||
from model_compression_toolkit.core.keras.constants import TRANSPOSE_A, TRANSPOSE_B, \ | ||
ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL | ||
|
||
|
||
class MatmulToDenseSubstitution(common.BaseSubstitution):
    """
    Replace a tf.linalg.matmul node that multiplies an activation tensor by a constant 2D weight
    with a bias-free Dense layer node that holds the constant as its kernel.
    """

    def __init__(self):
        """
        Matches: tf.linalg.matmul
        """
        super().__init__(matcher_instance=NodeOperationMatcher(tf.linalg.matmul))

    def substitute(self,
                   graph: Graph,
                   matmul_node: FunctionalNode) -> Graph:
        """
        Replace tf.linalg.matmul of an activation tensor and a const with a Dense layer.
        The graph is returned unchanged when the node has more than one input node, when an
        adjoint flag or the transpose_a flag is set, or when the const is not a 2D tensor.

        Args:
            graph: Graph we apply the substitution on.
            matmul_node: Node to replace.

        Returns:
            Graph after applying the substitution.
        """

        if len(graph.get_prev_nodes(matmul_node)) > 1:
            # matmul of 2 activation tensors -> can't replace with Dense layer
            return graph

        # NOTE(review): the adjoint/transpose_a flags are read from framework_attr, while transpose_b
        # (below) is read from the op call args/kwargs - confirm both locations are populated.
        if matmul_node.framework_attr.get(ADJOINT_A, False) or matmul_node.framework_attr.get(ADJOINT_B, False):
            # MCT doesn't support complex tensors
            return graph

        if matmul_node.framework_attr.get(TRANSPOSE_A, False):
            # first input should be an activation tensor with a batch axis, so it shouldn't be transposed
            return graph

        # read const from matmul inputs
        # NOTE(review): assumes the activation input is not part of op_call_args, so that index 0
        # is the const and index 2 is the transpose_b flag - confirm against the graph reader.
        if len(matmul_node.op_call_args) > 0:
            w = matmul_node.op_call_args[0]
        elif 'b' in matmul_node.op_call_kwargs:
            w = matmul_node.op_call_kwargs['b']
        else:
            Logger.error(f"Matmul substitution: can't locate weight for node {matmul_node.name}")  # pragma: no cover
            # Never continue without a weight, even if the logger call above doesn't raise.
            return graph  # pragma: no cover

        # Convert weight const to numpy array
        if isinstance(w, tf.Tensor):
            w = w.numpy()
        elif isinstance(w, (list, tuple)):
            # nested python sequences are valid matmul constants as well
            w = np.array(w)
        elif not isinstance(w, np.ndarray):
            Logger.error(f'Unable to convert constant to numpy array: {matmul_node.name}')  # pragma: no cover
            # Never continue with an unconverted const, even if the logger call above doesn't raise.
            return graph  # pragma: no cover

        if len(w.shape) != 2:
            # weight tensor should be of shape (Cin, Cout)
            return graph

        # transpose const if "transpose_b" flag is True
        if matmul_node.op_call_kwargs.get(TRANSPOSE_B, False) or (
                len(matmul_node.op_call_args) >= 3 and matmul_node.op_call_args[2]):
            w = w.transpose()

        # The Dense node takes over the matmul's name, shapes and reuse info; its kernel is the
        # (possibly transposed) const, so the number of units is the const's second dimension.
        dense_node = BaseNode(matmul_node.name,
                              {UNITS: w.shape[1], USE_BIAS: False},
                              matmul_node.input_shape, matmul_node.output_shape,
                              {KERNEL: w}, tf.keras.layers.Dense,
                              reuse=matmul_node.reuse, reuse_group=matmul_node.reuse_group)

        # Swap the nodes in the graph: move all edges and output references to the Dense node
        # before removing the original matmul node.
        graph.add_node(dense_node)
        graph.reconnect_in_edges(current_node=matmul_node,
                                 new_node=dense_node)
        graph.reconnect_out_edges(current_node=matmul_node,
                                  new_node=dense_node)
        graph.replace_output_node(current_node=matmul_node,
                                  new_node=dense_node)
        graph.remove_node(matmul_node)

        return graph
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
tests/keras_tests/feature_networks_tests/feature_networks/matmul_substitution_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
|
||
from packaging import version | ||
import tensorflow as tf | ||
if version.parse(tf.__version__) >= version.parse("2.13"): | ||
from keras.src.layers.core import TFOpLambda | ||
else: | ||
from keras.layers.core import TFOpLambda | ||
|
||
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc | ||
import model_compression_toolkit as mct | ||
|
||
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model | ||
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest | ||
import numpy as np | ||
from tests.common_tests.helpers.tensors_compare import cosine_similarity | ||
|
||
|
||
class MatmulToDenseSubstitutionTest(BaseKerasFeatureNetworkTest):
    """
    Feature test for the matmul-to-Dense substitution: builds a network with six tf.matmul calls,
    five with a constant second operand and one between two tensors, and checks that only the
    latter remains a matmul op in the resulting model while the output is unchanged.
    """

    def __init__(self, unit_test):
        super().__init__(unit_test, input_shape=(8,))

    def get_tpc(self):
        # Quantization is disabled for both weights and activations.
        no_quant_options = {'weights_n_bits': 16,
                            'activation_n_bits': 16,
                            'enable_weights_quantization': False,
                            'enable_activation_quantization': False}
        tp_model = generate_test_tp_model(no_quant_options)
        return generate_keras_tpc(name="no_quantization", tp_model=tp_model)

    def get_quantization_config(self):
        error_method = mct.core.QuantizationErrorMethod.NOCLIPPING
        return mct.core.QuantizationConfig(error_method, error_method, False, False, True)

    def create_networks(self):
        net_input = tf.keras.layers.Input(shape=self.get_input_shapes()[0][1:])

        # const given as a TF tensor through the 'b' keyword
        h = tf.matmul(net_input, b=tf.random.normal((8, 10)))
        h = tf.keras.layers.ReLU()(h)

        # const given as a numpy array
        h = tf.matmul(h, np.random.normal(size=(10, 16)))
        h = tf.keras.layers.ReLU()(h)

        # const given as a nested python list
        h = tf.matmul(h, np.random.normal(size=(16, 32)).tolist())

        # matmul between two activation tensors
        h = tf.matmul(tf.reshape(h, (-1, 8, 4)),
                      tf.reshape(h, (-1, 4, 8)))
        h = tf.keras.layers.ReLU()(tf.reshape(h, (-1, 64)))

        # transpose_b given as a keyword argument
        h = tf.matmul(h, tf.random.normal((11, 64)), transpose_b=True)
        h = tf.keras.layers.ReLU()(h)

        # transpose_b given as a positional argument
        h = tf.matmul(h, tf.random.normal((10, 11)), False, True)
        h = tf.keras.layers.ReLU()(h)

        return tf.keras.models.Model(inputs=net_input, outputs=h)

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        # check the output didn't change
        float_out = float_model(input_x).numpy()
        quantized_out = quantized_model(input_x).numpy()
        similarity = cosine_similarity(float_out, quantized_out)
        self.unit_test.assertTrue(np.isclose(similarity, 1), msg=f'fail cosine similarity check: {similarity}')

        # count the matmul ops that are still present in the quantized model
        remaining_matmuls = sum(1 for layer in quantized_model.layers
                                if isinstance(layer, TFOpLambda) and layer.function is tf.matmul)

        # check all "matmul"s were replaced except the one with 2 tensor inputs
        self.unit_test.assertTrue(remaining_matmuls == 1, msg='Only one matmul should remain in the quantized model')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters