From 9c588b76d7822c23c6890ef6b3c35a036c325d8f Mon Sep 17 00:00:00 2001
From: Elad Cohen <78862769+elad-c@users.noreply.github.com>
Date: Tue, 5 Dec 2023 16:34:25 +0200
Subject: [PATCH] Add TF matmul with const to Dense layer substitution (#880)

---
 .../core/keras/constants.py                   |   2 +
 .../substitutions/matmul_substitution.py      | 108 ++++++++++++++++++
 .../core/keras/keras_implementation.py        |   3 +
 .../feature_networks/bn_folding_test.py       |   2 +-
 .../matmul_substitution_test.py               |  78 +++++++++++++
 .../test_features_runner.py                   |  10 +-
 6 files changed, 199 insertions(+), 4 deletions(-)
 create mode 100644 model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
 create mode 100644 tests/keras_tests/feature_networks_tests/feature_networks/matmul_substitution_test.py

diff --git a/model_compression_toolkit/core/keras/constants.py b/model_compression_toolkit/core/keras/constants.py
index 5b4c20de3..d3c67d0a2 100644
--- a/model_compression_toolkit/core/keras/constants.py
+++ b/model_compression_toolkit/core/keras/constants.py
@@ -53,6 +53,8 @@ TARGET_SHAPE = 'target_shape'
 TRANSPOSE_A = 'transpose_a'
 TRANSPOSE_B = 'transpose_b'
+ADJOINT_A = 'adjoint_a'
+ADJOINT_B = 'adjoint_b'
 DEPTH_MULTIPLIER = 'depth_multiplier'
 DEPTHWISE_INITIALIZER = 'depthwise_initializer'
 DEPTHWISE_REGULARIZER = 'depthwise_regularizer'
diff --git a/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
new file mode 100644
index 000000000..a376f8752
--- /dev/null
+++ b/model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
@@ -0,0 +1,108 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+import tensorflow as tf
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.common.graph.base_graph import Graph
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common.graph.base_node import BaseNode
+from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
+from model_compression_toolkit.core.keras.constants import TRANSPOSE_A, TRANSPOSE_B, \
+    ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL
+
+
+class MatmulToDenseSubstitution(common.BaseSubstitution):
+    """
+    Replace a tf.linalg.matmul node, which multiplies an activation tensor by a constant,
+    with an equivalent Dense layer that holds the constant as its kernel.
+ """ + + def __init__(self): + """ + Matches: tf.linalg.matmul + """ + super().__init__(matcher_instance=NodeOperationMatcher(tf.linalg.matmul)) + + def substitute(self, + graph: Graph, + matmul_node: FunctionalNode) -> Graph: + """ + Replace tf.linalg.matmul with Tensor and const with Dense layer + + Args: + graph: Graph we apply the substitution on. + matmul_node: Node to replace. + + Returns: + Graph after applying the substitution. + """ + + if len(graph.get_prev_nodes(matmul_node)) > 1: + # matmul of 2 activation tensors -> can't replace with Dense layer + return graph + + if matmul_node.framework_attr.get(ADJOINT_A, False) or matmul_node.framework_attr.get(ADJOINT_B, False): + # MCT doesn't support complex tensors + return graph + + if matmul_node.framework_attr.get(TRANSPOSE_A, False): + # first input should be an activation tensor with batch axis, that shouldn't be transposed + return graph + + # read const from matmul inputs + if len(matmul_node.op_call_args) > 0: + w = matmul_node.op_call_args[0] + elif 'b' in matmul_node.op_call_kwargs: + w = matmul_node.op_call_kwargs['b'] + else: + Logger.error(f"Matmul substitution: can't locate weight for node {matmul_node.name}") # pragma: no cover + + # Convert weight const to numpy array + if isinstance(w, tf.Tensor): + w = w.numpy() + elif isinstance(w, list): + w = np.array(w) + elif not isinstance(w, np.ndarray): + Logger.error(f'Unable to convert constant to numpy array: {matmul_node.name}') # pragma: no cover + + if len(w.shape) != 2: + # weight tensor should be of shape (Cin, Cout) + return graph + + # transpose const if "transpose_b" flag is True + if matmul_node.op_call_kwargs.get(TRANSPOSE_B, False) or ( + len(matmul_node.op_call_args) >= 3 and matmul_node.op_call_args[2]): + w = w.transpose() + + dense_node = BaseNode(matmul_node.name, + {UNITS: w.shape[1], USE_BIAS: False}, + matmul_node.input_shape, matmul_node.output_shape, + {KERNEL: w}, tf.keras.layers.Dense, + reuse=matmul_node.reuse, reuse_group=matmul_node.reuse_group) + + graph.add_node(dense_node) + graph.reconnect_in_edges(current_node=matmul_node, + new_node=dense_node) + graph.reconnect_out_edges(current_node=matmul_node, + new_node=dense_node) + graph.replace_output_node(current_node=matmul_node, + new_node=dense_node) + graph.remove_node(matmul_node) + + return graph + + diff --git a/model_compression_toolkit/core/keras/keras_implementation.py b/model_compression_toolkit/core/keras/keras_implementation.py index dac938ee5..1f6af001f 100644 --- a/model_compression_toolkit/core/keras/keras_implementation.py +++ b/model_compression_toolkit/core/keras/keras_implementation.py @@ -68,6 +68,8 @@ from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \ ActivationDecomposition +from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \ + MatmulToDenseSubstitution from model_compression_toolkit.core.keras.graph_substitutions.substitutions.softmax_shift import \ keras_softmax_shift from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_folding import \ @@ -260,6 +262,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List """ return [SeparableConvDecomposition(), + MatmulToDenseSubstitution(), MultiHeadAttentionDecomposition(), ActivationDecomposition(), DwconvToConv()] diff --git 
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/bn_folding_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/bn_folding_test.py
index ff8d5a9e3..581765574 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/bn_folding_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/bn_folding_test.py
@@ -147,7 +147,7 @@ def create_networks(self):
         return tf.keras.models.Model(inputs=inputs, outputs=x)
 
 
-class Conv2DBNConcatnFoldingTest(BaseBatchNormalizationFolding):
+class Conv2DBNConcatFoldingTest(BaseBatchNormalizationFolding):
     def __init__(self, unit_test):
         super().__init__(unit_test, linear_layer=layers.Conv2D)
 
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/matmul_substitution_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/matmul_substitution_test.py
new file mode 100644
index 000000000..57f98d513
--- /dev/null
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/matmul_substitution_test.py
@@ -0,0 +1,78 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from packaging import version
+import tensorflow as tf
+if version.parse(tf.__version__) >= version.parse("2.13"):
+    from keras.src.layers.core import TFOpLambda
+else:
+    from keras.layers.core import TFOpLambda
+
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc
+import model_compression_toolkit as mct
+
+from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
+from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
+import numpy as np
+from tests.common_tests.helpers.tensors_compare import cosine_similarity
+
+
+class MatmulToDenseSubstitutionTest(BaseKerasFeatureNetworkTest):
+    def __init__(self, unit_test):
+        super().__init__(unit_test, input_shape=(8,))
+
+    def get_tpc(self):
+        tp = generate_test_tp_model({'weights_n_bits': 16,
+                                     'activation_n_bits': 16,
+                                     'enable_weights_quantization': False,
+                                     'enable_activation_quantization': False})
+        return generate_keras_tpc(name="no_quantization", tp_model=tp)
+
+    def get_quantization_config(self):
+        return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING,
+                                           mct.core.QuantizationErrorMethod.NOCLIPPING,
+                                           False, False, True)
+
+    def create_networks(self):
+        inputs = tf.keras.layers.Input(shape=self.get_input_shapes()[0][1:])
+        x = tf.matmul(inputs, b=tf.random.normal((8, 10)))
+        x = tf.keras.layers.ReLU()(x)
+        x = tf.matmul(x, np.random.normal(size=(10, 16)))
+        x = tf.keras.layers.ReLU()(x)
+        x = tf.matmul(x, np.random.normal(size=(16, 32)).tolist())
+        x = tf.matmul(tf.reshape(x, (-1, 8, 4)),
+                      tf.reshape(x, (-1, 4, 8)))
+        x = tf.keras.layers.ReLU()(tf.reshape(x, (-1, 64)))
+        x = tf.matmul(x, tf.random.normal((11, 64)), transpose_b=True)
+        x = tf.keras.layers.ReLU()(x)
+        x = tf.matmul(x, tf.random.normal((10, 11)), False, True)
+        x = tf.keras.layers.ReLU()(x)
+        return tf.keras.models.Model(inputs=inputs, outputs=x)
+
+    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
+        # check the output didn't change
+        y = float_model(input_x).numpy()
+        y_hat = quantized_model(input_x).numpy()
+        cs = cosine_similarity(y, y_hat)
+        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}')
+
+        num_matmuls = 0
+        for layer in quantized_model.layers:
+            if isinstance(layer, TFOpLambda) and layer.function is tf.matmul:
+                num_matmuls += 1
+
+        # check that all matmuls were replaced except the one multiplying two activation tensors
+        self.unit_test.assertTrue(num_matmuls == 1, msg='Only one matmul should remain in the quantized model')
diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py
index bbe8e5235..765541ac2 100644
--- a/tests/keras_tests/feature_networks_tests/test_features_runner.py
+++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -31,7 +31,7 @@ BiasCorrectionDepthwiseTest
 from tests.keras_tests.feature_networks_tests.feature_networks.bn_folding_test import Conv2DBNFoldingTest, \
     DepthwiseConv2DBNFoldingTest, DepthwiseConv2DBNFoldingHighMultiplierTest, Conv2DTransposeBNFoldingTest, \
-    Conv2DBNConcatnFoldingTest, SeparableConv2DBNFoldingTest, BNForwardFoldingTest
+    Conv2DBNConcatFoldingTest, SeparableConv2DBNFoldingTest, BNForwardFoldingTest
 from tests.keras_tests.feature_networks_tests.feature_networks.conv_bn_relu_residual_test import ConvBnReluResidualTest
 from tests.keras_tests.feature_networks_tests.feature_networks.decompose_separable_conv_test import \
     DecomposeSeparableConvTest
@@ -124,6 +124,7 @@ MixedPercisionSearchLastLayerDistanceTest, MixedPercisionSearchActivationKPINonConfNodesTest, \
     MixedPercisionSearchTotalKPINonConfNodesTest, MixedPercisionSearchPartWeightsLayersTest, MixedPercisionCombinedNMSTest
 from tests.keras_tests.feature_networks_tests.feature_networks.old_api_test import OldApiTest
+from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import MatmulToDenseSubstitutionTest
 from model_compression_toolkit.qat.common.qat_config import TrainingMethod
 
 layers = tf.keras.layers
@@ -467,8 +468,11 @@ def test_activation_decomposition(self):
     def test_experimental_exporter(self):
         ExportableModelTest(self).run_test()
 
-    def test_conv2d_bn_concant(self):
-        Conv2DBNConcatnFoldingTest(self).run_test()
+    def test_matmul_dense_substitution(self):
+        MatmulToDenseSubstitutionTest(self).run_test()
+
+    def test_conv2d_bn_concat(self):
+        Conv2DBNConcatFoldingTest(self).run_test()
 
     def test_activation_scaling_relu6(self):
         ReLUBoundToPOTNetTest(self).run_test()
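
Reviewer note (not part of the patch): a minimal standalone sketch of the equivalence this substitution relies on. Multiplying an activation tensor by a constant of shape (Cin, Cout) computes exactly what a Dense layer with that constant as its kernel and no bias computes; the shapes and variable names below are illustrative only.

import numpy as np
import tensorflow as tf

# constant weight of shape (Cin, Cout) -- the only shape the substitution accepts
w = np.random.normal(size=(8, 10)).astype(np.float32)
x = tf.random.normal((4, 8))  # batch of activations

# pattern before the substitution: functional matmul with a constant second operand
y_matmul = tf.matmul(x, w)

# pattern after the substitution: a Dense layer holding the constant as its kernel
dense = tf.keras.layers.Dense(units=w.shape[1], use_bias=False)
dense.build(x.shape)
dense.set_weights([w])
y_dense = dense(x)

assert np.allclose(y_matmul.numpy(), y_dense.numpy(), atol=1e-6)

# with transpose_b=True the substitution transposes the constant first, so
# tf.matmul(x, w.T, transpose_b=True) maps to the same Dense kernel w
y_transposed = tf.matmul(x, w.T, transpose_b=True)
assert np.allclose(y_transposed.numpy(), y_dense.numpy(), atol=1e-6)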