diff --git a/model_compression_toolkit/core/common/model_collector.py b/model_compression_toolkit/core/common/model_collector.py
index b65bbbfda..e48b49f0d 100644
--- a/model_compression_toolkit/core/common/model_collector.py
+++ b/model_compression_toolkit/core/common/model_collector.py
@@ -158,7 +158,7 @@ def infer(self, inputs_list: List[np.ndarray]):
         for td, sc in zip(tensor_data, self.stats_containers_list):
             if isinstance(sc, (list, tuple)):
                 if not isinstance(td, (list, tuple)):
-                    Logger.critical('\'tensor_data\' must be a list or a tuple if \'stats_containers_list\' contains lists or tuples.')  # pragma: no cover
+                    Logger.critical(f"'tensor_data' is of type {type(td)} but must be of the same type as 'stats_containers_list', which is of type {type(sc)}")  # pragma: no cover
                 if len(sc) != len(td):
                     Logger.critical('\'tensor_data\' and \'stats_containers_list\' must have matching lengths')  # pragma: no cover
                 for tdi, sci in zip(td, sc):
diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
new file mode 100644
index 000000000..ed4b9ec5c
--- /dev/null
+++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
@@ -0,0 +1,231 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch.nn as nn
+import torch
+import math
+from copy import copy
+import numpy as np
+from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
+from model_compression_toolkit.core.common import BaseSubstitution
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
+from model_compression_toolkit.core.pytorch.constants import DIM
+from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
+
+
+class ScaledDotProductDecomposition(BaseSubstitution):
+    """
+    Decompose torch.nn.functional.scaled_dot_product_attention into its base operators:
+        Transpose (over k)
+        MatMul (over q and the transposed k)
+        Mul (for scaling)
+        Add (for masking; optional, used when attn_mask is given)
+        Dropout
+        Softmax
+        MatMul.
+    """
+
+    def __init__(self):
+        """
+        Matches the scaled_dot_product_attention node.
+        """
+        super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
+
+    def _get_input_by_name(self, attention_node: FunctionalNode, input_name: str,
+                           input_index: int, default_value: any) -> any:
+        """
+        Searches for an attention_node input value in op_call_kwargs (using input_name) and in op_call_args
+        (using input_index). If the input is not given, returns its default_value.
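+        :param attention_node: the scaled_dot_product_attention node whose inputs are queried
+        :param input_name: keyword name of the input in op_call_kwargs
+        :param input_index: positional index of the input in op_call_args
+        :param default_value: value to return when the input was not passed
+        :return: the input value if found, otherwise default_value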
+        """
+        if input_name in attention_node.op_call_kwargs:
+            return attention_node.op_call_kwargs[input_name]
+        elif len(attention_node.op_call_args) > input_index:  # input order: [attn_mask, dropout_p, is_causal]
+            return attention_node.op_call_args[input_index]
+        return default_value
+
+    def _get_attention_input_nodes(self, graph: Graph, attention_node: FunctionalNode) -> dict:
+        q, k, v = 0, 1, 2
+        prev_nodes = graph.get_prev_nodes(attention_node, sink_index_sorted=True)
+        q_node, k_node, v_node = prev_nodes[q], prev_nodes[k], prev_nodes[v]
+        return {"q": q_node, "k": k_node, "v": v_node}
+
+    def _get_transpose_k_node(self, attention_node_name: str, key_node: BaseNode) -> BaseNode:
+        input_shape, output_shape = copy(key_node.output_shape[0]), copy(key_node.output_shape[0])
+        output_shape[-2], output_shape[-1] = input_shape[-1], input_shape[-2]
+        transpose_node = FunctionalNode(name=f"{attention_node_name}_{key_node.name}_transpose",
+                                        framework_attr={},
+                                        input_shape=input_shape,
+                                        output_shape=output_shape,
+                                        weights={},
+                                        layer_class=torch.transpose,
+                                        op_call_args=[-1, -2],  # axes to transpose
+                                        op_call_kwargs={},
+                                        functional_op=torch.transpose)
+        return transpose_node
+
+    def _get_scale_node(self, attention_node: FunctionalNode, q_node: BaseNode, matmul_node: BaseNode) -> FunctionalNode:
+        """
+        :return: a Mul node that multiplies the attention scores by the scale factor
+        """
+        scale_name = f'{attention_node.name}_scale'
+        q_embd_axis = -1
+        input_scale = self._get_input_by_name(attention_node, "scale", 3, None)
+        scale_factor = input_scale if input_scale else (1 / math.sqrt(q_node.output_shape[0][q_embd_axis]))
+        scale_node = FunctionalNode(name=scale_name,
+                                    framework_attr={},
+                                    input_shape=(matmul_node.output_shape),
+                                    output_shape=matmul_node.output_shape,
+                                    weights={},
+                                    layer_class=torch.mul,
+                                    op_call_args=[scale_factor],
+                                    op_call_kwargs={},
+                                    functional_op=torch.mul)
+        return scale_node
+
+    def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transposed_k_node: BaseNode) -> BaseNode:
+        matmul1_output_shape = copy(q_node.output_shape[0])
+        matmul1_output_shape[-2] = q_node.output_shape[0][-2]
+        matmul1_output_shape[-1] = transposed_k_node.output_shape[-1]
+        matmul_name = f'{attention_node_name}_matmul1'
+        return FunctionalNode(name=matmul_name,
+                              framework_attr={},
+                              input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
+                              output_shape=tuple(matmul1_output_shape),
+                              weights={},
+                              layer_class=torch.matmul,
+                              op_call_args=[],
+                              op_call_kwargs={},
+                              functional_op=torch.matmul)
+
+    def _get_mask_node(self, attention_node: FunctionalNode, scale_node: FunctionalNode) -> FunctionalNode:
+        """
+        :return: an Add node whose second input is the mask tensor. If there is no mask tensor, returns None.
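+        :param attention_node: the original scaled_dot_product_attention node, used to fetch the attn_mask input
+        :param scale_node: the preceding Mul (scale) node, whose output shape the Add node reuses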
+        """
+        attention_mask_tensor = self._get_attention_mask_tensor(attention_node)
+        if attention_mask_tensor is None:
+            return None
+        mask_node_name = f'{attention_node.name}_mask'
+        return FunctionalNode(name=mask_node_name,
+                              framework_attr={},
+                              input_shape=(scale_node.output_shape),
+                              output_shape=scale_node.output_shape,
+                              weights={},
+                              layer_class=torch.add,
+                              op_call_args=[],
+                              op_call_kwargs={'other': attention_mask_tensor},
+                              functional_op=torch.add)
+
+    def _get_softmax_node(self, attention_node_name: str, in_out_shape: tuple) -> BaseNode:
+        softmax_name = f'{attention_node_name}_softmax'
+        return BaseNode(name=softmax_name,
+                        framework_attr={DIM: -1},
+                        input_shape=in_out_shape,
+                        output_shape=in_out_shape,
+                        weights={},
+                        layer_class=nn.Softmax)
+
+    def _get_matmul2_node(self, attention_node_name: str, softmax_node: BaseNode, v_node: BaseNode) -> FunctionalNode:
+        matmul2_output_shape = list(copy(softmax_node.output_shape))
+        matmul2_output_shape[-2] = softmax_node.output_shape[-2]
+        matmul2_output_shape[-1] = v_node.output_shape[0][-1]
+        matmul2_name = f'{attention_node_name}_matmul2'
+        return FunctionalNode(name=matmul2_name,
+                              framework_attr={},
+                              input_shape=(tuple(softmax_node.output_shape), tuple(v_node.output_shape[0])),
+                              output_shape=tuple(matmul2_output_shape),
+                              weights={},
+                              layer_class=torch.matmul,
+                              op_call_args=[],
+                              op_call_kwargs={},
+                              functional_op=torch.matmul)
+
+    def _get_attention_mask_tensor(self, attention_node: FunctionalNode) -> torch.Tensor:
+        """
+        :return: the mask tensor given as part of the attention node input.
+        Since MCT doesn't support infinite values, we don't support is_causal (a
+        torch.nn.functional.scaled_dot_product_attention argument) or a boolean mask tensor,
+        as both require -inf values.
+        """
+        device = get_working_device()
+        is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
+        if is_causal:
+            raise NotImplementedError("scaled_dot_product_attention is_causal feature is not implemented.")
+        input_weights = list(attention_node.weights.values())
+        attn_mask = input_weights[0] if len(input_weights) > 0 else None
+        if attn_mask is not None and (attn_mask.dtype == "bool"):
+            raise NotImplementedError(
+                "scaled_dot_product_attention attn_mask is of type boolean, which is not supported.")
+        if attn_mask is not None and (not np.isfinite(attn_mask).all()):
+            raise NotImplementedError(
+                "scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
+        return torch.from_numpy(attn_mask).to(device) if attn_mask is not None else None
+
+    def _get_dropout_node(self, attention_node: FunctionalNode, in_out_shape: tuple) -> BaseNode:
+        dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
+        dropout_name = f'{attention_node.name}_dropout'
+        return BaseNode(name=dropout_name,
+                        framework_attr={"p": dropout_p},
+                        input_shape=in_out_shape,
+                        output_shape=in_out_shape,
+                        weights={},
+                        layer_class=nn.Dropout)
+
+    def substitute(self, graph: Graph, attention_node: FunctionalNode) -> Graph:
+        """
+        Removes a scaled_dot_product_attention node from the graph and replaces it with an equivalent subgraph that
+        consists of:
+        Transpose (over k)
+        MatMul (over q and the transposed k)
+        Mul (for scaling)
+        Add (for masking; optional, used when attn_mask is given)
+        Dropout
+        Softmax
+        MatMul.
+        :param graph: a Graph to apply the substitution on
+        :param attention_node: the node to replace
+        :return: the graph after the substitution
+        """
+        input_nodes = self._get_attention_input_nodes(graph, attention_node)
+        q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
+        transpose_k_node = self._get_transpose_k_node(attention_node.name, k_node)
+        matmul_node = self._get_matmul_node(attention_node.name, q_node, transpose_k_node)
+        scale_node = self._get_scale_node(attention_node, q_node, matmul_node)
+        mask_node = self._get_mask_node(attention_node, scale_node)
+        softmax_node = self._get_softmax_node(attention_node.name, matmul_node.output_shape)
+        dropout_node = self._get_dropout_node(attention_node, softmax_node.output_shape)
+        matmul2_node = self._get_matmul2_node(attention_node.name, softmax_node, v_node)
+
+        graph.add_node_with_in_edges(transpose_k_node, [k_node])
+        graph.add_node_with_in_edges(matmul_node, [q_node, transpose_k_node])
+        graph.add_node_with_in_edges(scale_node, [matmul_node])
+        if mask_node:
+            graph.add_node_with_in_edges(mask_node, [scale_node])
+        graph.add_node_with_in_edges(softmax_node, [mask_node if mask_node else scale_node])
+        graph.add_node_with_in_edges(dropout_node, [softmax_node])
+        graph.add_node_with_in_edges(matmul2_node, [dropout_node if dropout_node else softmax_node, v_node])
+
+        graph_outputs = graph.get_outputs()
+        for i, g_out in enumerate(graph_outputs):
+            if g_out.node == attention_node:
+                graph_outputs[i] = OutTensor(node=matmul2_node, node_out_index=g_out.node_out_index)
+
+        graph.reconnect_out_edges(current_node=attention_node, new_node=matmul2_node)
+        graph.remove_edge(q_node, attention_node)
+        graph.remove_edge(k_node, attention_node)
+        graph.remove_edge(v_node, attention_node)
+        graph.remove_node(attention_node, new_graph_outputs=graph_outputs)
+        return graph
diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py
index dd7c715ec..04cd86b15 100644
--- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py
+++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -53,6 +53,8 @@
     pytorch_linear_collapsing
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
     import MultiHeadAttentionDecomposition
+from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.scaled_dot_product_attention import \
+    ScaledDotProductDecomposition
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.transform_function_call_method import \
     TransformFunctionCallMethod
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.const_holder_conv import \
@@ -237,6 +239,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
         """
         return [ReshapeWithStaticShapes(),
                 MultiHeadAttentionDecomposition(),
+                ScaledDotProductDecomposition(),
                 TransformFunctionCallMethod(),
                 FunctionalConvSubstitution(fw_info),
                 FunctionalBatchNorm(),
diff --git a/tests/common_tests/base_test.py b/tests/common_tests/base_test.py
index 5b08bc613..ade673b5a 100644
--- a/tests/common_tests/base_test.py
+++ b/tests/common_tests/base_test.py
@@ -12,13 +12,20 @@
     def __init__(self, unit_test,
                  val_batch_size=1,
                  num_calibration_iter=1,
                  num_of_inputs=1,
-                 input_shape=(8, 8, 3)):
+                 input_shape=(8, 8, 3),
+                 use_is_close_validation=False
+                 ):
+        """
+        :param use_is_close_validation: Allow similar (instead of exact) outputs when comparing the original float
+        model output against the no_quantization model output.
+        """
         self.unit_test = unit_test
         self.val_batch_size = val_batch_size
         self.num_calibration_iter = num_calibration_iter
         self.num_of_inputs = num_of_inputs
         self.input_shape = (val_batch_size,) + input_shape
+        self.use_is_close_validation = use_is_close_validation
 
     def generate_inputs(self):
         return [np.random.randn(*in_shape) for in_shape in self.get_input_shapes()]
diff --git a/tests/pytorch_tests/model_tests/base_pytorch_test.py b/tests/pytorch_tests/model_tests/base_pytorch_test.py
index 5eeb81b79..b7cc402f5 100644
--- a/tests/pytorch_tests/model_tests/base_pytorch_test.py
+++ b/tests/pytorch_tests/model_tests/base_pytorch_test.py
@@ -95,8 +95,10 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
                 # Check if we have a BatchNorm or MultiheadAttention layer in the model.
                 # If so, the outputs will not be the same, since the sqrt function in the
                 # Decomposition is not exactly like the sqrt in the C implementation of PyTorch.
-                if torch.nn.BatchNorm2d or torch.nn.MultiheadAttention in [type(module) for name, module in float_model.named_modules()]:
-                    self.unit_test.assertTrue(np.all(np.isclose(torch_tensor_to_numpy(f), torch_tensor_to_numpy(q),
+                float_model_operators = [type(module) for name, module in float_model.named_modules()]
+                if (torch.nn.BatchNorm2d in float_model_operators or
+                        torch.nn.MultiheadAttention in float_model_operators or self.use_is_close_validation):
+                    self.unit_test.assertTrue(np.all(np.isclose(torch_tensor_to_numpy(f), torch_tensor_to_numpy(q),
                                                                 atol=self.float_reconstruction_error)))
                 else:
                     self.unit_test.assertTrue(torch_tensor_to_numpy(torch.sum(torch.abs(f - q))) == 0,
diff --git a/tests/pytorch_tests/model_tests/feature_models/bn_function_test.py b/tests/pytorch_tests/model_tests/feature_models/bn_function_test.py
index bf9073144..c5e2d34fb 100644
--- a/tests/pytorch_tests/model_tests/feature_models/bn_function_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/bn_function_test.py
@@ -44,6 +44,7 @@
 class BNFNetTest(BasePytorchTest):
     def __init__(self, unit_test):
         super().__init__(unit_test)
+        self.use_is_close_validation = True  # because the net contains BN layer
 
     def create_inputs_shape(self):
         return [[self.val_batch_size, 3, 32, 32], [self.val_batch_size, 3, 32, 32]]
diff --git a/tests/pytorch_tests/model_tests/feature_models/layer_norm_net_test.py b/tests/pytorch_tests/model_tests/feature_models/layer_norm_net_test.py
index ce7d0a9be..3056db173 100644
--- a/tests/pytorch_tests/model_tests/feature_models/layer_norm_net_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/layer_norm_net_test.py
@@ -59,6 +59,7 @@ def __init__(self, unit_test, has_weight=None, has_bias=None):
         super().__init__(unit_test)
         self.has_weight = has_weight
         self.has_bias = has_bias
+        self.use_is_close_validation = True
 
     def create_inputs_shape(self):
         return [[self.val_batch_size, 3, 32, 32]]
diff --git a/tests/pytorch_tests/model_tests/feature_models/scaled_dot_product_attention_test.py b/tests/pytorch_tests/model_tests/feature_models/scaled_dot_product_attention_test.py
new file mode 100644
index 000000000..d7f1eed4f
--- /dev/null
+++ b/tests/pytorch_tests/model_tests/feature_models/scaled_dot_product_attention_test.py
@@ -0,0 +1,106 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest
+from torch import nn
+from packaging import version
+import torch
+
+
+class ScaledDotProductAttentionNet(nn.Module):
+    def __init__(self, dropout_p=0.0, scale=None, attn_mask=None, is_causal=False):
+        super().__init__()
+        self.dropout_p = dropout_p
+        self.scale = scale
+        self.attn_mask = attn_mask
+        self.is_causal = is_causal
+
+    def forward(self, q, k, v):
+        x = nn.functional.scaled_dot_product_attention(q, k, v,
+                                                       attn_mask=self.attn_mask,
+                                                       dropout_p=self.dropout_p,
+                                                       is_causal=self.is_causal,
+                                                       # scale=self.scale
+                                                       )
+        return x
+
+
+class ScaledDotProductAttentionTest(BasePytorchTest):
+    """
+    This test checks the scaled_dot_product_attention (SDPA) substitution using a single SDPA layer.
+    """
+
+    def __init__(self, unit_test, batch_size: int, q_and_k_embd_size: int, v_embd_size: int, source_seq_len: int,
+                 target_seq_len: int, dropout_p: float = 0.0, scale: float = None, attn_mask: float = None,
+                 is_causal: bool = False):
+
+        super().__init__(unit_test)
+        self.batch_size = batch_size
+        self.q_and_k_embd_size = q_and_k_embd_size
+        self.v_embd_size = v_embd_size
+        self.source_seq_len = source_seq_len
+        self.target_seq_len = target_seq_len
+        self.use_is_close_validation = True  # because SDPA contains a sqrt operation, which leads to slightly different output values compared to the original torch model
+        self.dropout_p = dropout_p
+        self.scale = scale
+        self.attn_mask = attn_mask
+        self.is_causal = is_causal
+
+    def create_feature_network(self, input_shape) -> nn.Module:
+        if version.parse(torch.__version__) >= version.parse("2.1"):
+            return ScaledDotProductAttentionNet(dropout_p=self.dropout_p,
+                                                attn_mask=self.attn_mask,
+                                                is_causal=self.is_causal,
+                                                scale=self.scale)
+        else:  # older torch versions don't have the scale argument
+            return ScaledDotProductAttentionNet(dropout_p=self.dropout_p,
+                                                attn_mask=self.attn_mask,
+                                                is_causal=self.is_causal)
+
+    def create_inputs_shape(self) -> list:
+        q_shape = [self.batch_size, self.target_seq_len, self.q_and_k_embd_size]
+        k_shape = [self.batch_size, self.source_seq_len, self.q_and_k_embd_size]
+        v_shape = [self.batch_size, self.source_seq_len, self.v_embd_size]
+        return [q_shape, k_shape, v_shape]
+
+    def _test_substitution_structure_output(self, post_substitution_nodes) -> None:
+        """
+        :param post_substitution_nodes: the graph nodes after the SDPA substitution
+        Raises an Exception in case post_substitution_nodes does not match expected_nodes_counter.
+        """
+        expected_nodes_counter = {
+            'DummyPlaceHolder': 3,
+            "transpose": 1,
+            "matmul": 2,
+            "mul": 1,  # scale operator
+            "Softmax": 1,
+            "Dropout": 1,
+            "add": 0 if self.attn_mask is None else 1  # mask operator
+        }
+
+        for node in post_substitution_nodes:
+            operator_name = node.layer_class.__name__
+            if operator_name not in expected_nodes_counter:
+                raise Exception(f"Post substitution graph contains unexpected node: {operator_name}")
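+            # Count down each expected operator; after the loop, every counter must reach exactly zero.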
+            expected_nodes_counter[operator_name] -= 1
+
+        counter_results = set(expected_nodes_counter.values())
+        if not (len(counter_results) == 1 and 0 in counter_results):  # validate that all values are zeros
+            raise Exception(f"Post substitution graph has unexpected node counts for: {[k for k, v in expected_nodes_counter.items() if v != 0]}")
+
+    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None) -> None:
+        super().compare(quantized_models, float_model, input_x, quantization_info)
+        post_substitution_nodes = quantized_models['no_quantization'].node_sort
+        self._test_substitution_structure_output(post_substitution_nodes)
diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
index 229992a7a..2210fbc21 100644
--- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py
+++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -70,6 +70,7 @@
     ConstantConvReuseSubstitutionTest, ConstantConvTransposeSubstitutionTest
 from tests.pytorch_tests.model_tests.feature_models.multi_head_attention_test import MHALayerNetTest, \
     MHALayerNetFeatureTest
+from tests.pytorch_tests.model_tests.feature_models.scaled_dot_product_attention_test import ScaledDotProductAttentionTest
 from tests.pytorch_tests.model_tests.feature_models.scale_equalization_test import \
     ScaleEqualizationWithZeroPadNetTest, ScaleEqualizationNetTest, \
     ScaleEqualizationReluFuncNetTest, ScaleEqualizationReluFuncWithZeroPadNetTest, \
@@ -108,6 +109,7 @@
 from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest
 from tests.pytorch_tests.model_tests.feature_models.activation_16bit_test import Activation16BitTest, \
     Activation16BitMixedPrecisionTest
+from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
 
 
 class FeatureModelsTestRunner(unittest.TestCase):
@@ -590,6 +592,26 @@ def test_mha_layer_test(self):
         """
         MHALayerNetFeatureTest(self, num_heads[0], q_seq_len[0], qdim[0] * num_heads[0],
                                kv_seq_len[0], kdim[0], vdim[0], bias=True, add_bias_kv=True).run_test()
 
+    def test_scaled_dot_product_attention_layer(self):
+        """
+        This test checks the ScaledDotProductDecomposition substitution feature.
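+        Each configuration is exercised with default arguments, with an explicit scale, and with an additive
+        attn_mask input.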
+        """
+
+        batch_size = [3, 1, 5]
+        q_and_k_embd_size = [8, 9, 3]
+        v_embd_size = [19, 2, 6]
+        source_seq_len = [21, 4, 15]
+        target_seq_len = [13, 12, 9]
+        for i in range(len(batch_size)):
+            ScaledDotProductAttentionTest(self, batch_size[i], q_and_k_embd_size[i], v_embd_size[i], source_seq_len[i],
+                                          target_seq_len[i]).run_test(seed=3)
+            ScaledDotProductAttentionTest(self, batch_size[i], q_and_k_embd_size[i], v_embd_size[i], source_seq_len[i],
+                                          target_seq_len[i], dropout_p=0.0, scale=5).run_test(seed=3)
+            attn_mask = torch.ones(target_seq_len[i], source_seq_len[i]).to(get_working_device())
+            ScaledDotProductAttentionTest(self, batch_size[i], q_and_k_embd_size[i], v_embd_size[i], source_seq_len[i],
+                                          target_seq_len[i], attn_mask=attn_mask).run_test(seed=3)
+
+
     def test_gptq(self):
         """
         This test checks the GPTQ feature.
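
For reference (not part of the patch), a minimal numerical sketch of the operator chain the substitution builds; it checks that transpose -> matmul -> mul -> softmax -> matmul reproduces torch.nn.functional.scaled_dot_product_attention when no mask, dropout, or causal flag is used. Shapes are arbitrary illustration values:

    import math
    import torch
    import torch.nn as nn

    q = torch.randn(3, 13, 8)   # (batch, target_seq_len, q/k embedding)
    k = torch.randn(3, 21, 8)   # (batch, source_seq_len, q/k embedding)
    v = torch.randn(3, 21, 19)  # (batch, source_seq_len, v embedding)

    # Same base operators the substitution inserts into the graph.
    scores = torch.matmul(q, torch.transpose(k, -1, -2))
    scores = torch.mul(scores, 1 / math.sqrt(q.shape[-1]))  # default scale factor
    attn = nn.Softmax(dim=-1)(scores)                       # nn.Dropout(p=0.0) would be an identity here
    out = torch.matmul(attn, v)

    ref = nn.functional.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(out, ref, atol=1e-5)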