From 4e76be826d0466087dc4a4fde5a0cb0c7e91f3d5 Mon Sep 17 00:00:00 2001 From: itai-berman Date: Wed, 20 Nov 2024 16:11:06 +0200 Subject: [PATCH] add substitution for functional linear (#1266) add Pytorch substitution for functional linear and related tests --- .../substitutions/functional_linear.py | 83 +++++++++++++++++++ .../core/pytorch/pytorch_implementation.py | 3 + .../feature_models/linear_function_test.py | 51 ++++++++++++ .../model_tests/test_feature_models_runner.py | 7 ++ 4 files changed, 144 insertions(+) create mode 100644 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_linear.py create mode 100644 tests/pytorch_tests/model_tests/feature_models/linear_function_test.py diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_linear.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_linear.py new file mode 100644 index 000000000..41ee11e7d --- /dev/null +++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_linear.py @@ -0,0 +1,83 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from torch import nn +import torch.nn.functional as F + +from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher +from model_compression_toolkit.core.common import BaseNode, Graph, BaseSubstitution +from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode +from model_compression_toolkit.core.pytorch.constants import * +from model_compression_toolkit.logger import Logger + + +class FunctionalLinear(BaseSubstitution): + """ + Replace functional linear with Linear. + """ + + def __init__(self): + """ + Matches: functional linear + """ + func_node = NodeOperationMatcher(F.linear) + super().__init__(matcher_instance=func_node) + + def substitute(self, + graph: Graph, + func_node: FunctionalNode) -> Graph: + """ + Substitute functional.linear and its inputs with Linear. + Args: + graph: Graph we apply the substitution on. + node: node that match the pattern in the substitution init. + + Returns: + Graph after applying the substitution. + """ + + # Create new node of layer Linear + if 1 not in func_node.weights: + Logger.critical(f'Weight input missing for node {func_node.name}.') # pragma: no cover + # Extract index of kernel and bias according to tensor_input_allocs if they were input as kwargs. If + # they were input as args, use their fixed positions. + weight_index = func_node.tensor_input_allocs.index(KERNEL) if KERNEL in func_node.tensor_input_allocs else 1 + bias_index = func_node.tensor_input_allocs.index(BIAS) if BIAS in func_node.tensor_input_allocs else 2 + if weight_index not in func_node.weights: + Logger.critical(f'Mismatch between tensor_input_allocs and weight index in node {func_node.name}.') # pragma: no cover + weight = func_node.weights[weight_index] + bias = func_node.weights.get(bias_index) + + framework_attr = { + IN_FEATURES: func_node.input_shape[0][-1], + OUT_FEATURES: func_node.output_shape[0][-1], + BIAS: bias is not None, + } + + weights = {KERNEL: weight} if bias is None else {KERNEL: weight, BIAS: bias} + + new_node = BaseNode( + name=func_node.name, + framework_attr=framework_attr, + input_shape=func_node.input_shape[0], + output_shape=func_node.output_shape, + weights=weights, + layer_class=nn.Linear, + has_activation=func_node.has_activation, + reuse=func_node.reuse, + reuse_group=func_node.reuse_group + ) + + graph.replace_node(func_node, new_node) + return graph diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py index 5ec26a66d..779892082 100644 --- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py +++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py @@ -50,6 +50,8 @@ FunctionalBatchNorm from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_layer_norm import \ FunctionalLayerNorm +from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.functional_linear import \ + FunctionalLinear from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.linear_collapsing import \ pytorch_linear_collapsing from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \ @@ -266,6 +268,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List FunctionalConvSubstitution(fw_info), FunctionalBatchNorm(), FunctionalLayerNorm(), + FunctionalLinear(), RemoveIdentity()] def get_substitutions_pre_statistics_collection(self, diff --git a/tests/pytorch_tests/model_tests/feature_models/linear_function_test.py b/tests/pytorch_tests/model_tests/feature_models/linear_function_test.py new file mode 100644 index 000000000..30dae8e22 --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/linear_function_test.py @@ -0,0 +1,51 @@ +# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import torch +import torch.nn.functional as F +from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest +from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device + +""" +This test checks the linear functional substitution function. +""" + + +class LinearFNet(torch.nn.Module): + def __init__(self): + super(LinearFNet, self).__init__() + self.fc1 = torch.nn.Linear(in_features=1000, out_features=100, bias=False) + self.fc2 = torch.nn.Linear(in_features=100, out_features=50, bias=True) + self.fc3 = torch.nn.Linear(in_features=50, out_features=10, bias=False) + + def forward(self, x): + x = F.linear(x, self.fc1.weight, self.fc1.bias) + x = F.linear(x, bias=self.fc2.bias, weight=self.fc2.weight) + y = F.linear(x, self.fc3.weight, bias=None) + return y + + +class LinearFNetTest(BasePytorchTest): + """ + This test check the linear functional substitution function. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def create_inputs_shape(self): + return [[self.val_batch_size, 1000]] + + def create_feature_network(self, input_shape): + return LinearFNet() diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index b5c02f350..45c6e8f51 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -55,6 +55,7 @@ from tests.pytorch_tests.model_tests.feature_models.layer_norm_net_test import LayerNormNetTest from tests.pytorch_tests.model_tests.feature_models.linear_collapsing_test import TwoConv2DCollapsingTest, \ ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest +from tests.pytorch_tests.model_tests.feature_models.linear_function_test import LinearFNetTest from tests.pytorch_tests.model_tests.feature_models.lut_quantizer_test import LUTWeightsQuantizerTest, \ LUTActivationQuantizerTest from tests.pytorch_tests.model_tests.feature_models.manual_bit_selection import ManualBitWidthByLayerTypeTest, \ @@ -239,6 +240,12 @@ def test_bn_function(self): """ BNFNetTest(self).run_test() + def test_linear_function(self): + """ + This test check the linear functional substitution function. + """ + LinearFNetTest(self).run_test() + def test_broken_net(self): """ This test checks that the "broken" node (node without output) is being