From 52bb1c598ffd5ca69d964c8199074386e6d870eb Mon Sep 17 00:00:00 2001 From: Elad Cohen <78862769+elad-c@users.noreply.github.com> Date: Sun, 1 Dec 2024 12:01:06 +0200 Subject: [PATCH] Fix duplicate QCOs error: raise error only if duplicate QCOs are not identical. (#1282) --- model_compression_toolkit/core/common/graph/base_node.py | 7 ++++--- .../core/common/graph/functional_node.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index b90bc6a87..77e87c6ea 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -556,9 +556,10 @@ def get_qco(self, tpc: TargetPlatformCapabilities) -> QuantizationConfigOptions: # Extract qco with is_match_type to overcome mismatch of function types in TF 2.15 matching_qcos = [_qco for _type, _qco in tpc.layer2qco.items() if self.is_match_type(_type)] if matching_qcos: - if len(matching_qcos) > 1: - Logger.error('Found duplicate qco types!') - return matching_qcos[0] + if all([_qco == matching_qcos[0] for _qco in matching_qcos]): + return matching_qcos[0] + else: + Logger.critical(f"Found duplicate qco types for node '{self.name}' of type '{self.type}'!") # pragma: no cover return tpc.tp_model.default_qco def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities, diff --git a/model_compression_toolkit/core/common/graph/functional_node.py b/model_compression_toolkit/core/common/graph/functional_node.py index bcf2cc15e..db7b7b6a5 100644 --- a/model_compression_toolkit/core/common/graph/functional_node.py +++ b/model_compression_toolkit/core/common/graph/functional_node.py @@ -85,5 +85,5 @@ def is_match_type(self, _type: Type) -> bool: Whether _type matches the self node type """ - names_match = _type.__name__ == self.type.__name__ if FOUND_TF else False + names_match = _type.__name__ == self.type.__name__ return super().is_match_type(_type) or names_match