diff --git a/nncf/common/graph/transformations/commands.py b/nncf/common/graph/transformations/commands.py
index c40183dd94e..f08ef831e98 100644
--- a/nncf/common/graph/transformations/commands.py
+++ b/nncf/common/graph/transformations/commands.py
@@ -140,6 +140,7 @@ def __init__(self, target_type: TargetType):
         :param target_type: Type of the target point.
         """
         self._target_type = target_type
+        self.target_node_name: str = ""
 
     @property
     def type(self) -> TargetType:
diff --git a/nncf/common/logging/track_progress.py b/nncf/common/logging/track_progress.py
index b82a95f32bf..21f4873431e 100644
--- a/nncf/common/logging/track_progress.py
+++ b/nncf/common/logging/track_progress.py
@@ -146,6 +146,10 @@ def __iter__(self) -> Iterable[ProgressType]:
             self.sequence, total=self.total, description=self.description, update_period=self.update_period
         )
 
+    def __next__(self) -> ProgressType:
+        # The progress display lifecycle is managed by __enter__/__exit__, not per item.
+        return next(self.iterator)
+
     def __enter__(self):
         self.progress.start()
         self.task = self.progress.add_task(self.description, total=self.total)
diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py
index 6ca5178bbf3..3e5a5b0fa82 100644
--- a/nncf/common/tensor_statistics/aggregator.py
+++ b/nncf/common/tensor_statistics/aggregator.py
@@ -10,13 +10,13 @@
 # limitations under the License.
 from abc import ABC
 from abc import abstractmethod
 from itertools import islice
-from typing import Any, Dict, TypeVar
+from typing import Any, Dict, Optional, TypeVar
 
 from nncf.common import factory
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.transformations.layout import TransformationLayout
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.data.dataset import Dataset
@@ -30,9 +30,9 @@ class StatisticsAggregator(ABC):
     Base class for statistics collection.
     """
 
-    def __init__(self, dataset: Dataset):
+    def __init__(self, dataset: Dataset[Any, Any]):
         self.dataset = dataset
-        self.stat_subset_size = None
+        self.stat_subset_size: Optional[int] = None
         self.statistic_points = StatisticPointsContainer()
 
     def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
@@ -50,7 +50,7 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
         merged_statistics = self._get_merged_statistic_points(self.statistic_points, model, graph)
         transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
 
-        model_with_outputs = model_transformer.transform(transformation_layout)
+        model_with_outputs: TModel = model_transformer.transform(transformation_layout)
         engine = factory.EngineFactory.create(model_with_outputs)
 
         dataset_length = self.dataset.get_length()
@@ -87,7 +87,7 @@ def register_statistic_points(self, statistic_points: StatisticPointsContainer)
 
         for _, _statistic_points in self.statistic_points.items():
             for _statistic_point in _statistic_points:
-                for _, tensor_collectors in _statistic_point.algorithm_to_tensor_collectors.items():
+                for tensor_collectors in _statistic_point.algorithm_to_tensor_collectors.values():
                     for tensor_collector in tensor_collectors:
                         if self.stat_subset_size is None:
                             self.stat_subset_size = tensor_collector.num_samples
diff --git a/nncf/common/tensor_statistics/collectors.py b/nncf/common/tensor_statistics/collectors.py
index e1b01310216..9c58f8b3f6d 100644
--- a/nncf/common/tensor_statistics/collectors.py
+++ b/nncf/common/tensor_statistics/collectors.py
@@ -12,22 +12,22 @@
 from abc import ABC
 from abc import abstractmethod
 from collections import deque
-from typing import List, Optional, Tuple, Union
+from typing import Any, Deque, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
 from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor import TensorElementsType
 from nncf.common.tensor import TensorType
 from nncf.common.tensor_statistics.reduction import get_per_channel_history
 
-ReductionAxes = Tuple[int]
+ReductionAxes = Tuple[int, ...]
 class TensorStatisticCollectorBase(ABC):
     """Collector estimate statistics at the quantization point based on the provided reduction shape."""
 
-    def __init__(self, reduction_shape: Optional[ReductionAxes] = None, num_samples: Optional[int] = None):
+    def __init__(
+        self,
+        reduction_shape: Optional[ReductionAxes] = None,
+        num_samples: Optional[int] = None,
+    ):
         """
         Initializes Tensor Statistic Collector
 
@@ -40,10 +45,10 @@ def __init__(
         self._num_samples = num_samples
 
     @property
-    def num_samples(self) -> int:
+    def num_samples(self) -> Optional[int]:
         return self._num_samples
 
     def register_input(self, x: TensorType) -> TensorType:
         """Registers input tensor"""
         if not self._enabled:
             return x
@@ -56,32 +61,32 @@ def register_input(self, x: TensorType) -> TensorType:
         return x
 
     @abstractmethod
-    def _register_input(self, x: TensorType):
+    def _register_input(self, x: TensorType) -> None:
         pass
 
-    def get_statistics(self):
+    def get_statistics(self) -> Any:
         """Returns collected statistics, if present."""
         if self._collected_samples == 0:
             raise StatisticsNotCollectedError()
         return self._get_statistics()
 
     @abstractmethod
-    def _get_statistics(self):
+    def _get_statistics(self) -> Any:
        pass
 
-    def enable(self):
+    def enable(self) -> None:
         self._enabled = True
 
-    def disable(self):
+    def disable(self) -> None:
         self._enabled = False
 
-    def reset(self):
+    def reset(self) -> None:
         """Resets all the statistics in the collector."""
         self._collected_samples = 0
         self._reset()
 
     @abstractmethod
-    def _reset(self):
+    def _reset(self) -> None:
         pass
 
     def collected_samples(self) -> int:
@@ -100,12 +105,15 @@ class OfflineTensorStatisticCollector(TensorStatisticCollectorBase):
     """Collects statistics in offline regime by storing the data and aggregating it afterwards."""
 
     def __init__(
-        self, reduction_shape: Optional[ReductionAxes] = None, num_samples: int = None, window_size: int = None
+        self,
+        reduction_shape: Optional[ReductionAxes] = None,
+        num_samples: Optional[int] = None,
+        window_size: Optional[int] = None,
     ):
         super().__init__(reduction_shape, num_samples)
-        self._samples = deque(maxlen=window_size)
+        self._samples: Deque[NNCFTensor] = deque(maxlen=window_size)
 
-    def _reset(self):
+    def _reset(self) -> None:
         self._samples.clear()
 
 
@@ -116,7 +124,11 @@ class NNCFCollectorTensorProcessor(ABC):
 
     @staticmethod
     @abstractmethod
-    def reduce_min(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
+    def reduce_min(
+        x: NNCFTensor,
+        axis: Union[int, Tuple[int, ...], List[int]],
+        keepdims: bool = False,
+    ) -> NNCFTensor:
         """
         Computes minimum of elements across dimensions of NNCFTensor.
 
@@ -129,7 +141,11 @@ def reduce_min(
 
     @staticmethod
     @abstractmethod
-    def reduce_max(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims: bool = False) -> NNCFTensor:
+    def reduce_max(
+        x: NNCFTensor,
+        axis: Union[int, Tuple[int, ...], List[int]],
+        keepdims: bool = False,
+    ) -> NNCFTensor:
         """
         Computes maximum of elements across dimensions of NNCFTensor.
@@ -174,7 +190,11 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:
 
     @staticmethod
     @abstractmethod
-    def mean(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
+    def mean(
+        x: NNCFTensor,
+        axis: Union[int, Tuple[int, ...], List[int]],
+        keepdims: bool = False,
+    ) -> NNCFTensor:
         """
         Computes the mean of elements across given dimensions of NNCFTensor.
 
@@ -187,7 +207,11 @@ def mean(
 
     @staticmethod
     @abstractmethod
-    def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims=False) -> NNCFTensor:
+    def median(
+        x: NNCFTensor,
+        axis: Union[int, Tuple[int, ...], List[int]],
+        keepdims: bool = False,
+    ) -> NNCFTensor:
         """
         Computes the median of elements across given dimensions of NNCFTensor.
 
@@ -201,7 +225,11 @@ def median(
     @classmethod
     @abstractmethod
     def masked_mean(
-        cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
+        cls,
+        x: NNCFTensor,
+        axis: Union[int, Tuple[int, ...], List[int]],
+        mask: NNCFTensor,
+        keepdims: bool = False,
     ) -> NNCFTensor:
         """
         Computes the masked mean of elements across given dimensions of NNCFTensor.
@@ -218,7 +246,11 @@ def masked_mean(
     @classmethod
     @abstractmethod
     def masked_median(
-        cls, x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], mask: NNCFTensor, keepdims=False
+        cls,
+        x: NNCFTensor,
+        axis: Union[int, Tuple[int, ...], List[int]],
+        mask: NNCFTensor,
+        keepdims: bool = False,
     ) -> NNCFTensor:
         """
         Computes the masked median of elements across given dimensions of NNCFTensor.
@@ -363,6 +395,7 @@ def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
         """
 
     @staticmethod
+    @abstractmethod
     def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
         """
         Computes the element-wise logical OR of the given input tensors.
@@ -374,6 +407,7 @@ def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
         """
 
     @staticmethod
+    @abstractmethod
     def less(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
         """
         Return the truth value of (x1 < x2) element-wise.
@@ -406,20 +440,20 @@ def __init__(self, use_abs_max: bool, reduction_shape: ReductionAxes, num_sample
         super().__init__(reduction_shape, num_samples)
         self._use_abs_max = use_abs_max
         self._tensor_processor = self._get_processor()
-        self._min_values = None
-        self._max_values = None
+        self._min_values: Optional[NNCFTensor] = None
+        self._max_values: Optional[NNCFTensor] = None
 
     @staticmethod
     @abstractmethod
-    def _get_processor():
+    def _get_processor() -> NNCFCollectorTensorProcessor:
         pass
 
-    def _register_input_common(self, x: NNCFTensor):
+    def _register_input_common(self, x: NNCFTensor) -> None:
         min_reduced = self._tensor_processor.reduce_min(x, self._reduction_shape)
         if self._use_abs_max:
             x = self._tensor_processor.abs(x)
         max_reduced = self._tensor_processor.reduce_max(x, self._reduction_shape)
 
         if self._min_values is None:
             self._min_values = min_reduced
@@ -431,7 +465,7 @@ def _register_input_common(self, x: NNCFTensor):
         else:
             self._max_values = self._tensor_processor.max(max_reduced, self._max_values)
 
-    def _reset(self):
+    def _reset(self) -> None:
         self._min_values = None
         self._max_values = None
 
@@ -455,21 +489,21 @@ def __init__(
         self._use_abs_max = use_abs_max
         self._tensor_processor = self._get_processor()
 
-        self._all_min_values = deque(maxlen=window_size)
-        self._all_max_values = deque(maxlen=window_size)
+        self._all_min_values: Deque[NNCFTensor] = deque(maxlen=window_size)
+        self._all_max_values: Deque[NNCFTensor] = deque(maxlen=window_size)
 
     @staticmethod
     @abstractmethod
-    def _get_processor():
+    def _get_processor() -> NNCFCollectorTensorProcessor:
         pass
 
-    def _register_input_common(self, x: NNCFTensor):
+    def _register_input_common(self, x: NNCFTensor) -> None:
         min_reduced = self._tensor_processor.reduce_min(x, self._reduction_shape)
         if self._use_abs_max:
             x = self._tensor_processor.abs(x)
         max_reduced = self._tensor_processor.reduce_max(x, self._reduction_shape)
 
         if self._use_per_sample_stats:
             self._all_min_values.extend(self._tensor_processor.unstack(min_reduced))
             self._all_max_values.extend(self._tensor_processor.unstack(max_reduced))
         else:
             self._all_min_values.append(min_reduced)
@@ -477,14 +511,14 @@ def _register_input_common(self, x: NNCFTensor):
             self._all_max_values.append(max_reduced)
 
     @abstractmethod
-    def _min_aggregate(self):
+    def _min_aggregate(self) -> NNCFTensor:
         pass
 
     @abstractmethod
-    def _max_aggregate(self):
+    def _max_aggregate(self) -> NNCFTensor:
         pass
 
-    def _reset(self):
+    def _reset(self) -> None:
         self._all_min_values.clear()
         self._all_max_values.clear()
 
@@ -508,17 +542,17 @@ def __init__(
         self._use_means_of_mins = use_means_of_mins
         self._use_means_of_maxs = use_means_of_maxs
 
-    def _min_aggregate(self):
+    def _min_aggregate(self) -> NNCFTensor:
         stacked_min = self._tensor_processor.stack(self._all_min_values)
         if self._use_means_of_mins:
             return self._tensor_processor.mean(stacked_min, axis=0)
         return self._tensor_processor.reduce_min(stacked_min, axis=0)
 
-    def _max_aggregate(self):
+    def _max_aggregate(self) -> NNCFTensor:
         stacked_max = self._tensor_processor.stack(self._all_max_values)
         if self._use_means_of_maxs:
             return self._tensor_processor.mean(stacked_max, axis=0)
         return self._tensor_processor.reduce_max(stacked_max, axis=0)
 
 
 class MeanMinMaxStatisticCollector(MinMaxOfflineStatisticCollectorBase):
@@ -526,13 +560,13 @@ class MeanMinMaxStatisticCollector(MinMaxOfflineStatisticCollectorBase):
     Collector aggregates mean of minimum values and mean of maximum values.
     """
 
-    def _min_aggregate(self):
+    def _min_aggregate(self) -> NNCFTensor:
         stacked_min = self._tensor_processor.stack(self._all_min_values)
         return self._tensor_processor.mean(stacked_min, axis=0)
 
-    def _max_aggregate(self):
+    def _max_aggregate(self) -> NNCFTensor:
         stacked_max = self._tensor_processor.stack(self._all_max_values)
         return self._tensor_processor.mean(stacked_max, axis=0)
 
 
 class MeanStatisticCollector(OfflineTensorStatisticCollector):
@@ -540,7 +574,12 @@ class MeanStatisticCollector(OfflineTensorStatisticCollector):
     Collector that aggregates statistics as mean along a pre-assigned axis.
     """
 
-    def __init__(self, channel_axis: int, num_samples: Optional[int] = None, window_size: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        channel_axis: int,
+        num_samples: Optional[int] = None,
+        window_size: Optional[int] = None,
+    ) -> None:
         """
         :param channel_axis: The main axis for the reduction while statistics collection.
         :param num_samples: Optional parameter for statistic collection that regulates
@@ -550,30 +589,30 @@ def __init__(
         super().__init__(num_samples=num_samples)
         self._channel_axis = channel_axis
         self._tensor_processor = self._get_processor()
-        self._all_values = deque(maxlen=window_size)
-        self._all_shapes = deque(maxlen=window_size)
+        self._all_values: Deque[NNCFTensor] = deque(maxlen=window_size)
+        self._all_shapes: Deque[List[int]] = deque(maxlen=window_size)
 
     @staticmethod
     @abstractmethod
-    def _get_processor():
+    def _get_processor() -> NNCFCollectorTensorProcessor:
         pass
 
-    def _register_input_common(self, x: NNCFTensor):
+    def _register_input_common(self, x: NNCFTensor) -> None:
         if self._channel_axis == 0:
             self._all_values.append(self._tensor_processor.batch_mean(x))
         else:
             self._all_values.append(self._tensor_processor.mean_per_channel(x, self._channel_axis))
         self._all_shapes.append(x.shape)
 
-    def _reset(self):
+    def _reset(self) -> None:
         self._all_values.clear()
         self._all_shapes.clear()
 
-    def _mean_aggregate(self):
+    def _mean_aggregate(self) -> NNCFTensor:
         all_values_stack = self._tensor_processor.stack(self._all_values)
         return self._tensor_processor.mean(all_values_stack, 0)
 
-    def _shape(self):
+    def _shape(self) -> List[int]:
         return self._all_shapes[0]
 
@@ -589,17 +628,17 @@ def __init__(self, num_samples: Optional[int] = None) -> None:
         the number of samples that will be processed.
         """
         super().__init__(num_samples=num_samples)
-        self._all_values = []
+        self._all_values: List[Any] = []
 
     @staticmethod
     @abstractmethod
-    def _get_processor():
+    def _get_processor() -> NNCFCollectorTensorProcessor:
         pass
 
-    def _register_input_common(self, x: NNCFTensor):
+    def _register_input_common(self, x: NNCFTensor) -> None:
         self._all_values.append(x.tensor)
 
-    def _reset(self):
+    def _reset(self) -> None:
         self._all_values.clear()
 
 
@@ -608,15 +647,15 @@ class MedianMADStatisticCollector(OfflineTensorStatisticCollector):
     Collector estimates median and median absolute deviation (MAD).
""" - def _prepare_statistics(self): - per_channel_history = get_per_channel_history(self._samples, list(self._reduction_shape), discard_zeros=True) + def _prepare_statistics(self) -> np.ndarray[float, Any]: + per_channel_history = get_per_channel_history( + deque(np.array(tensor) for tensor in self._samples), + list(*self._reduction_shape), + discard_zeros=True, + ) per_channel_median = [np.median(channel_hist) for channel_hist in per_channel_history] per_channel_mad = [] for idx, median in enumerate(per_channel_median): per_channel_mad.append(np.median(abs(per_channel_history[idx] - median))) numpy_median = np.asarray(per_channel_median) numpy_mad = np.asarray(per_channel_mad) - return numpy_median, numpy_mad + return np.concatenate([numpy_median, numpy_mad], axis=1) class PercentileStatisticCollector(OfflineTensorStatisticCollector): @@ -634,8 +690,11 @@ def __init__( super().__init__(reduction_shape, num_samples, window_size) self._percentiles_to_collect = percentiles_to_collect - def _prepare_statistics(self): - per_channel_history = get_per_channel_history(self._samples, list(self._reduction_shape)) + def _prepare_statistics(self) -> dict[float, np.ndarray[float, Any]]: + per_channel_history = get_per_channel_history( + deque(np.array(tensor) for tensor in self._samples), + list(*self._reduction_shape), + ) percentile_vs_values_dict = {} for pc in self._percentiles_to_collect: per_channel_percentiles = [np.percentile(channel_hist, pc) for channel_hist in per_channel_history] @@ -657,10 +716,10 @@ def __init__( window_size: int = None, ): super().__init__(reduction_shape, num_samples, window_size) - self._all_pct_values = {} + self._all_pct_values: Dict[float, deque[torch.Tensor]] = {} for pc in percentiles_to_collect: self._all_pct_values[pc] = deque(maxlen=window_size) - def _reset(self): + def _reset(self) -> None: for _, val in self._all_pct_values.items(): val.clear() diff --git a/nncf/common/tensor_statistics/reduction.py b/nncf/common/tensor_statistics/reduction.py index a1cdf89e8b9..949123e67f8 100644 --- a/nncf/common/tensor_statistics/reduction.py +++ b/nncf/common/tensor_statistics/reduction.py @@ -10,9 +10,10 @@ # limitations under the License. 
 from collections import deque
-from typing import List, Tuple
+from typing import Deque, List, Optional, Tuple
 
 import numpy as np
 
 
 def get_channel_count_and_dim_idx(scale_shape: List[int]) -> Tuple[int, int]:
@@ -37,9 +37,13 @@ def split_into_channels(input_: np.ndarray, scale_shape: List[int]) -> List[np.n
     return ret_list
 
 
-def get_per_channel_history(raw_input_history: deque, scale_shape: List[int], discard_zeros=False) -> List:
+def get_per_channel_history(
+    raw_input_history: Deque[np.ndarray],
+    scale_shape: List[int],
+    discard_zeros: bool = False,
+) -> List[np.ndarray]:
     channel_count, _ = get_channel_count_and_dim_idx(scale_shape)
-    per_channel_history = [None for i in range(channel_count)]
+    per_channel_history: List[Optional[np.ndarray]] = [None for _ in range(channel_count)]
     for _ in range(len(raw_input_history)):
         entry = raw_input_history.popleft()
         split = split_into_channels(entry, scale_shape)
@@ -59,7 +63,7 @@ def get_per_channel_history(
     return per_channel_history
 
 
-def np_percentile_reduce_like(input_: np.array, ref_tensor_shape: Tuple[int], q: float) -> np.array:
+def np_percentile_reduce_like(input_: np.ndarray, ref_tensor_shape: Tuple[int, ...], q: float) -> np.ndarray:
     numel = np.prod(ref_tensor_shape)
     if numel == 1:
         return np.array([np.percentile(input_, q)])
diff --git a/nncf/common/tensor_statistics/statistic_point.py b/nncf/common/tensor_statistics/statistic_point.py
index 681eaaa429e..b7a91c4a5bf 100644
--- a/nncf/common/tensor_statistics/statistic_point.py
+++ b/nncf/common/tensor_statistics/statistic_point.py
@@ -10,10 +10,10 @@
 # limitations under the License.
 
 from collections import UserDict
-from typing import Callable, Generator, Optional, Tuple
+from typing import Any, Callable, Generator, List, Optional, Tuple
 
 from nncf.common.graph.transformations.commands import TargetPoint
 from nncf.common.tensor import TensorType
 from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
@@ -25,23 +25,28 @@ class StatisticPoint:
     algorithm implies on what algorithm nedeed this statistics.
""" - def __init__(self, target_point: TargetPoint, tensor_collector: TensorStatisticCollectorBase, algorithm: str): + def __init__( + self, + target_point: TargetPoint, + tensor_collector: TensorStatisticCollectorBase, + algorithm: str, + ): self.target_point = target_point self.algorithm_to_tensor_collectors = {algorithm: [tensor_collector]} - def __eq__(self, other): - return ( + def __eq__(self, other: Any) -> bool: + result: bool = ( self.target_point == other.target_point and self.algorithm_to_tensor_collectors == other.self.algorithm_to_tensor_collectors ) + return result - def register_tensor(self, x: TensorType): + def register_tensor(self, x: torch.Tensor) -> None: for tensor_collectors in self.algorithm_to_tensor_collectors.values(): for tensor_collector in tensor_collectors: tensor_collector.register_input(x) -class StatisticPointsContainer(UserDict): +class StatisticPointsContainer(UserDict[str, List[StatisticPoint]]): """ Container with iteration interface for handling a composition of StatisticPoint. """ @@ -73,7 +80,7 @@ def add_statistic_point(self, statistic_point: StatisticPoint) -> None: def iter_through_statistic_points_in_target_node( self, target_node_name: str, filter_fn: Callable[[StatisticPoint], bool] - ) -> StatisticPoint: + ) -> Generator[StatisticPoint, None, None]: """ Returns iterable through all statistic points in node with target_node_name. @@ -98,14 +105,17 @@ def get_tensor_collectors( """ if filter_fn is None: - def default_filter_fn(stat_point: StatisticPoint): + def default_filter_fn(stat_point: StatisticPoint) -> bool: return True filter_fn = default_filter_fn for target_node_name in self.data: for statistic_point in self.iter_through_statistic_points_in_target_node(target_node_name, filter_fn): - for algorithm, tensor_collectors in statistic_point.algorithm_to_tensor_collectors.items(): + for ( + algorithm, + tensor_collectors, + ) in statistic_point.algorithm_to_tensor_collectors.items(): for tensor_collector in tensor_collectors: yield algorithm, statistic_point, tensor_collector diff --git a/nncf/common/tensor_statistics/statistics.py b/nncf/common/tensor_statistics/statistics.py index 17a5e26a3f0..41245ea8f71 100644 --- a/nncf/common/tensor_statistics/statistics.py +++ b/nncf/common/tensor_statistics/statistics.py @@ -12,7 +12,7 @@ from abc import ABC from abc import abstractmethod from collections import Counter -from typing import TypeVar +from typing import Dict, List, TypeVar TensorType = TypeVar("TensorType") @@ -24,11 +24,11 @@ class TensorStatistic(ABC): @staticmethod @abstractmethod - def tensor_eq(tensor1: TensorType, tensor2: TensorType, rtol=1e-6) -> bool: + def tensor_eq(tensor1: TensorType, tensor2: TensorType, rtol: float = 1e-6) -> bool: pass @abstractmethod - def __eq__(self, other): + def __eq__(self, other: object) -> bool: pass @@ -36,11 +36,13 @@ class MinMaxTensorStatistic(TensorStatistic): MIN_STAT = "min_values" MAX_STAT = "max_values" - def __init__(self, min_values, max_values): + def __init__(self, min_values: TensorType, max_values: TensorType): self.min_values = min_values self.max_values = max_values - def __eq__(self, other: "MinMaxTensorStatistic") -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, MinMaxTensorStatistic): + return NotImplemented return self.tensor_eq(self.min_values, other.min_values) and self.tensor_eq(self.max_values, other.max_values) @@ -52,7 +54,7 @@ class MeanTensorStatistic(TensorStatistic): Base class for the statistics that collects as mean per-axis """ - def 
-    def __init__(self, mean_values, shape):
+    def __init__(self, mean_values: TensorType, shape: List[int]) -> None:
         """
         :param mean_values: Collected mean per-axis values.
         :param shape: The shape of the collected statistics.
@@ -60,7 +62,9 @@ def __init__(self, mean_values, shape):
         self.mean_values = mean_values
         self.shape = shape
 
-    def __eq__(self, other: "MeanTensorStatistic") -> bool:
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, MeanTensorStatistic):
+            return NotImplemented
         return self.tensor_eq(self.mean_values, other.mean_values) and self.tensor_eq(self.shape, other.shape)
 
 
@@ -68,11 +72,13 @@ class MedianMADTensorStatistic(TensorStatistic):
     MEDIAN_VALUES_STAT = "median_values"
     MAD_VALUES_STAT = "mad_values"
 
-    def __init__(self, median_values, mad_values):
+    def __init__(self, median_values: TensorType, mad_values: TensorType) -> None:
         self.median_values = median_values
         self.mad_values = mad_values
 
-    def __eq__(self, other: "MedianMADTensorStatistic") -> bool:
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, MedianMADTensorStatistic):
+            return NotImplemented
         return self.tensor_eq(self.median_values, other.median_values) and self.tensor_eq(
             self.mad_values, other.mad_values
         )
@@ -81,14 +87,19 @@ def __eq__(self, other: object) -> bool:
 class PercentileTensorStatistic(TensorStatistic):
     PERCENTILE_VS_VALUE_DICT = "percentile_vs_values_dict"
 
-    def __init__(self, percentile_vs_values_dict):
+    def __init__(self, percentile_vs_values_dict: Dict[float, TensorType]) -> None:
         self.percentile_vs_values_dict = percentile_vs_values_dict
 
-    def __eq__(self, other: "PercentileTensorStatistic", rtol=1e-9) -> bool:
+    def __eq__(self, other: object, rtol: float = 1e-9) -> bool:
+        if not isinstance(other, PercentileTensorStatistic):
+            return NotImplemented
         if Counter(self.percentile_vs_values_dict.keys()) != Counter(other.percentile_vs_values_dict.keys()):
             return False
         for pct in self.percentile_vs_values_dict:
-            if not self.tensor_eq(self.percentile_vs_values_dict[pct], other.percentile_vs_values_dict[pct]):
+            if not self.tensor_eq(
+                self.percentile_vs_values_dict[pct],
+                other.percentile_vs_values_dict[pct],
+            ):
                 return False
         return True
 
@@ -100,11 +111,13 @@ class RawTensorStatistic(TensorStatistic):
     Base class for the raw statistics, without any aggregation.
     """
 
-    def __init__(self, values):
+    def __init__(self, values: TensorType) -> None:
         """
         :param values: Collected raw values.
         """
         self.values = values
 
-    def __eq__(self, other: "RawTensorStatistic") -> bool:
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, RawTensorStatistic):
+            return NotImplemented
         return self.tensor_eq(self.values, other.values)
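
A note on the __eq__ pattern adopted throughout statistics.py: the methods now accept "other: object" and return NotImplemented for foreign types instead of raising an AttributeError or returning False. Below is a minimal, self-contained Python sketch of why this convention matters; the Stat class is purely illustrative and is not part of NNCF.

class Stat:
    """Illustrative stand-in for a TensorStatistic subclass."""

    def __init__(self, values: int) -> None:
        self.values = values

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented (rather than False) lets Python try the
        # reflected comparison on `other`, so mixed-type comparisons keep
        # sane semantics and subclasses of `other` can take precedence.
        if not isinstance(other, Stat):
            return NotImplemented
        return self.values == other.values


assert Stat(1) == Stat(1)
assert Stat(1) != "not a stat"  # Falls back to the default identity check; no TypeError.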