Add TF matmul with const to Dense layer substitution (sony#880)
elad-c authored Dec 5, 2023
1 parent 12b92f0 commit 9c588b7
Showing 6 changed files with 199 additions and 4 deletions.
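For orientation (an illustrative sketch by the editor, not part of the diff — all shapes below are made up): a tf.matmul of an activation tensor with a 2-D constant computes exactly what a bias-free Dense layer computes when its kernel is that constant, which is the equivalence this substitution exploits.

import numpy as np
import tensorflow as tf

x = tf.random.normal((2, 8))                           # activation (batch, Cin)
w = np.random.normal(size=(8, 10)).astype(np.float32)  # constant kernel (Cin, Cout)

dense = tf.keras.layers.Dense(units=w.shape[1], use_bias=False)
dense.build(x.shape)
dense.set_weights([w])

# The Dense layer reproduces the matmul output (up to float tolerance).
assert np.allclose(tf.matmul(x, w).numpy(), dense(x).numpy(), atol=1e-6)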
2 changes: 2 additions & 0 deletions model_compression_toolkit/core/keras/constants.py
@@ -53,6 +53,8 @@
TARGET_SHAPE = 'target_shape'
TRANSPOSE_A = 'transpose_a'
TRANSPOSE_B = 'transpose_b'
ADJOINT_A = 'adjoint_a'
ADJOINT_B = 'adjoint_b'
DEPTH_MULTIPLIER = 'depth_multiplier'
DEPTHWISE_INITIALIZER = 'depthwise_initializer'
DEPTHWISE_REGULARIZER = 'depthwise_regularizer'
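Aside on the two new constants (an illustrative editor's sketch, not part of the diff): adjoint_a/adjoint_b ask tf.matmul for the conjugate transpose of an operand, which differs from a plain transpose only for complex inputs — this is why the substitution below bails out when either adjoint flag is set.

import numpy as np
import tensorflow as tf

z = tf.constant([[1 + 2j, 3 - 1j], [0 + 1j, 2 + 0j]], dtype=tf.complex64)
v = tf.eye(2, dtype=tf.complex64)

# adjoint_a means conjugate-transpose of the first operand
y = tf.matmul(z, v, adjoint_a=True)
assert np.allclose(y.numpy(), tf.matmul(tf.transpose(z, conjugate=True), v).numpy())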
108 changes: 108 additions & 0 deletions model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py
@@ -0,0 +1,108 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import tensorflow as tf
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from model_compression_toolkit.core.keras.constants import TRANSPOSE_A, TRANSPOSE_B, \
ADJOINT_A, ADJOINT_B, UNITS, USE_BIAS, KERNEL


class MatmulToDenseSubstitution(common.BaseSubstitution):
"""
Replace a linear layer that has an activation function, with two nodes: same linear layer without
an activation function, and a new activation layer to replace the function the linear node had.
"""

def __init__(self):
"""
Matches: tf.linalg.matmul
"""
super().__init__(matcher_instance=NodeOperationMatcher(tf.linalg.matmul))

def substitute(self,
graph: Graph,
matmul_node: FunctionalNode) -> Graph:
"""
Replace a tf.linalg.matmul of an activation tensor and a constant with an equivalent Dense layer
Args:
graph: Graph we apply the substitution on.
matmul_node: Node to replace.
Returns:
Graph after applying the substitution.
"""

if len(graph.get_prev_nodes(matmul_node)) > 1:
# matmul of 2 activation tensors -> can't replace with Dense layer
return graph

if matmul_node.framework_attr.get(ADJOINT_A, False) or matmul_node.framework_attr.get(ADJOINT_B, False):
# MCT doesn't support complex tensors
return graph

if matmul_node.framework_attr.get(TRANSPOSE_A, False):
# the first input should be an activation tensor with a batch axis, which must not be transposed
return graph

# read const from matmul inputs
if len(matmul_node.op_call_args) > 0:
w = matmul_node.op_call_args[0]
elif 'b' in matmul_node.op_call_kwargs:
w = matmul_node.op_call_kwargs['b']
else:
Logger.error(f"Matmul substitution: can't locate weight for node {matmul_node.name}") # pragma: no cover

# Convert weight const to numpy array
if isinstance(w, tf.Tensor):
w = w.numpy()
elif isinstance(w, list):
w = np.array(w)
elif not isinstance(w, np.ndarray):
Logger.error(f'Unable to convert constant to numpy array: {matmul_node.name}') # pragma: no cover

if len(w.shape) != 2:
# weight tensor should be of shape (Cin, Cout)
return graph

# transpose const if "transpose_b" flag is True
if matmul_node.op_call_kwargs.get(TRANSPOSE_B, False) or (
len(matmul_node.op_call_args) >= 3 and matmul_node.op_call_args[2]):
w = w.transpose()

dense_node = BaseNode(matmul_node.name,
{UNITS: w.shape[1], USE_BIAS: False},
matmul_node.input_shape, matmul_node.output_shape,
{KERNEL: w}, tf.keras.layers.Dense,
reuse=matmul_node.reuse, reuse_group=matmul_node.reuse_group)

graph.add_node(dense_node)
graph.reconnect_in_edges(current_node=matmul_node,
new_node=dense_node)
graph.reconnect_out_edges(current_node=matmul_node,
new_node=dense_node)
graph.replace_output_node(current_node=matmul_node,
new_node=dense_node)
graph.remove_node(matmul_node)

return graph
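A quick sanity check on the transpose_b branch above (editor's sketch with made-up shapes, not part of the diff): when transpose_b is set, the constant arrives as (Cout, Cin) and the substitution stores its transpose as the Dense kernel, matching what the original op computes.

import numpy as np
import tensorflow as tf

x = tf.random.normal((2, 64))   # activation (batch, Cin)
w = tf.random.normal((11, 64))  # constant supplied as (Cout, Cin)

y_matmul = tf.matmul(x, w, transpose_b=True)  # shape (2, 11)
y_dense = tf.matmul(x, tf.transpose(w))       # what the substituted Dense computes
assert np.allclose(y_matmul.numpy(), y_dense.numpy(), atol=1e-5)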


3 changes: 3 additions & 0 deletions model_compression_toolkit/core/keras/keras_implementation.py
@@ -68,6 +68,8 @@
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.activation_decomposition import \
ActivationDecomposition
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.matmul_substitution import \
MatmulToDenseSubstitution
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.softmax_shift import \
keras_softmax_shift
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.batchnorm_folding import \
@@ -260,6 +262,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
"""
return [SeparableConvDecomposition(),
MatmulToDenseSubstitution(),
MultiHeadAttentionDecomposition(),
ActivationDecomposition(),
DwconvToConv()]
2 changes: 1 addition & 1 deletion tests/keras_tests/feature_networks_tests/feature_networks/bn_folding_test.py
@@ -147,7 +147,7 @@ def create_networks(self):
return tf.keras.models.Model(inputs=inputs, outputs=x)


-class Conv2DBNConcatnFoldingTest(BaseBatchNormalizationFolding):
+class Conv2DBNConcatFoldingTest(BaseBatchNormalizationFolding):
def __init__(self, unit_test):
super().__init__(unit_test,
linear_layer=layers.Conv2D)
78 changes: 78 additions & 0 deletions tests/keras_tests/feature_networks_tests/feature_networks/matmul_substitution_test.py
@@ -0,0 +1,78 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from packaging import version
import tensorflow as tf
if version.parse(tf.__version__) >= version.parse("2.13"):
from keras.src.layers.core import TFOpLambda
else:
from keras.layers.core import TFOpLambda

from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_keras_tpc
import model_compression_toolkit as mct

from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
import numpy as np
from tests.common_tests.helpers.tensors_compare import cosine_similarity


class MatmulToDenseSubstitutionTest(BaseKerasFeatureNetworkTest):
def __init__(self, unit_test):
super().__init__(unit_test, input_shape=(8,))

def get_tpc(self):
tp = generate_test_tp_model({'weights_n_bits': 16,
'activation_n_bits': 16,
'enable_weights_quantization': False,
'enable_activation_quantization': False})
return generate_keras_tpc(name="no_quantization", tp_model=tp)

def get_quantization_config(self):
return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING,
mct.core.QuantizationErrorMethod.NOCLIPPING,
False, False, True)

def create_networks(self):
inputs = tf.keras.layers.Input(shape=self.get_input_shapes()[0][1:])
x = tf.matmul(inputs, b=tf.random.normal((8, 10)))
x = tf.keras.layers.ReLU()(x)
x = tf.matmul(x, np.random.normal(size=(10, 16)))
x = tf.keras.layers.ReLU()(x)
x = tf.matmul(x, np.random.normal(size=(16, 32)).tolist())
x = tf.matmul(tf.reshape(x, (-1, 8, 4)),
tf.reshape(x, (-1, 4, 8)))
x = tf.keras.layers.ReLU()(tf.reshape(x, (-1, 64)))
x = tf.matmul(x, tf.random.normal((11, 64)), transpose_b=True)
x = tf.keras.layers.ReLU()(x)
x = tf.matmul(x, tf.random.normal((10, 11)), False, True)
x = tf.keras.layers.ReLU()(x)
return tf.keras.models.Model(inputs=inputs, outputs=x)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
# check the output didn't change
y = float_model(input_x).numpy()
y_hat = quantized_model(input_x).numpy()
cs = cosine_similarity(y, y_hat)
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}')

num_matmuls = 0
for layer in quantized_model.layers:
if isinstance(layer, TFOpLambda) and layer.function is tf.matmul:
num_matmuls += 1

# check all "matmul"s were replaced except the one with 2 tensor inputs
self.unit_test.assertTrue(num_matmuls == 1, msg='Only one matmul should remain in the quantized model')
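Aside on the assertion above (editor's sketch, not part of the diff): the single surviving matmul is the batched product of two activation tensors built in create_networks — there is no constant operand for a Dense kernel to absorb, so the substitution correctly leaves it alone.

import tensorflow as tf

a = tf.random.normal((2, 8, 4))  # data-dependent activation
b = tf.random.normal((2, 4, 8))  # data-dependent activation
y = tf.matmul(a, b)              # two-tensor (batched) matmul -> not substitutable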
10 changes: 7 additions & 3 deletions tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -31,7 +31,7 @@
BiasCorrectionDepthwiseTest
from tests.keras_tests.feature_networks_tests.feature_networks.bn_folding_test import Conv2DBNFoldingTest, \
DepthwiseConv2DBNFoldingTest, DepthwiseConv2DBNFoldingHighMultiplierTest, Conv2DTransposeBNFoldingTest, \
-    Conv2DBNConcatnFoldingTest, SeparableConv2DBNFoldingTest, BNForwardFoldingTest
+    Conv2DBNConcatFoldingTest, SeparableConv2DBNFoldingTest, BNForwardFoldingTest
from tests.keras_tests.feature_networks_tests.feature_networks.conv_bn_relu_residual_test import ConvBnReluResidualTest
from tests.keras_tests.feature_networks_tests.feature_networks.decompose_separable_conv_test import \
DecomposeSeparableConvTest
@@ -124,6 +124,7 @@
MixedPercisionSearchLastLayerDistanceTest, MixedPercisionSearchActivationKPINonConfNodesTest, \
MixedPercisionSearchTotalKPINonConfNodesTest, MixedPercisionSearchPartWeightsLayersTest, MixedPercisionCombinedNMSTest
from tests.keras_tests.feature_networks_tests.feature_networks.old_api_test import OldApiTest
from tests.keras_tests.feature_networks_tests.feature_networks.matmul_substitution_test import MatmulToDenseSubstitutionTest
from model_compression_toolkit.qat.common.qat_config import TrainingMethod

layers = tf.keras.layers
@@ -467,8 +468,11 @@ def test_activation_decomposition(self):
def test_experimental_exporter(self):
ExportableModelTest(self).run_test()

-def test_conv2d_bn_concant(self):
-    Conv2DBNConcatnFoldingTest(self).run_test()
+def test_matmul_dense_substitution(self):
+    MatmulToDenseSubstitutionTest(self).run_test()
+
+def test_conv2d_bn_concat(self):
+    Conv2DBNConcatFoldingTest(self).run_test()

def test_activation_scaling_relu6(self):
ReLUBoundToPOTNetTest(self).run_test()