-
Notifications
You must be signed in to change notification settings - Fork 239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Fixed]: Issues resolved raised by mypy for issue Make #22313 #2438
base: develop
Are you sure you want to change the base?
Changes from all commits
206aafc
c8ef1ea
57935f7
d4a3bdc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,13 +10,11 @@ | |
# limitations under the License. | ||
from abc import ABC | ||
from abc import abstractmethod | ||
from itertools import islice | ||
from typing import Any, Dict, TypeVar | ||
|
||
from nncf.common import factory | ||
from nncf.common.graph.graph import NNCFGraph | ||
from nncf.common.graph.transformations.layout import TransformationLayout | ||
from nncf.common.logging.track_progress import track | ||
from nncf.common.tensor import NNCFTensor | ||
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer | ||
from nncf.data.dataset import Dataset | ||
|
@@ -30,9 +28,9 @@ class StatisticsAggregator(ABC): | |
Base class for statistics collection. | ||
""" | ||
|
||
def __init__(self, dataset: Dataset): | ||
def __init__(self, dataset: Dataset[int, int]): | ||
self.dataset = dataset | ||
self.stat_subset_size = None | ||
self.stat_subset_size = 0 | ||
self.statistic_points = StatisticPointsContainer() | ||
|
||
def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None: | ||
|
@@ -50,21 +48,15 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None: | |
|
||
merged_statistics = self._get_merged_statistic_points(self.statistic_points, model, graph) | ||
transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics) | ||
model_with_outputs = model_transformer.transform(transformation_layout) | ||
model_with_outputs: TModel = model_transformer.transform(transformation_layout) | ||
engine = factory.EngineFactory.create(model_with_outputs) | ||
|
||
dataset_length = self.dataset.get_length() | ||
total = ( | ||
min(dataset_length or self.stat_subset_size, self.stat_subset_size) | ||
if self.stat_subset_size is not None | ||
else None | ||
) | ||
Comment on lines
-56
to
-61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was logic before: if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic to set In this part of the code, it iterates over the registered statistic points and their associated tensor collectors. For each tensor collector, it checks if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got you. But despite the fact that this should work in every case, I prefer explicit double check |
||
empty_statistics = True | ||
for input_data in track( | ||
islice(self.dataset.get_inference_data(), self.stat_subset_size), | ||
total=total, | ||
description="Statistics collection", | ||
): | ||
data_iterable = iter([self.dataset.get_inference_data()]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please keep |
||
for input_data in data_iterable: | ||
outputs = engine.infer(input_data) | ||
processed_outputs = self._process_outputs(outputs) | ||
self._register_statistics(processed_outputs, merged_statistics) | ||
Comment on lines
+56
to
+59
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code duplicate |
||
outputs = engine.infer(input_data) | ||
processed_outputs = self._process_outputs(outputs) | ||
self._register_statistics(processed_outputs, merged_statistics) | ||
|
@@ -87,7 +79,7 @@ def register_statistic_points(self, statistic_points: StatisticPointsContainer) | |
|
||
for _, _statistic_points in self.statistic_points.items(): | ||
for _statistic_point in _statistic_points: | ||
for _, tensor_collectors in _statistic_point.algorithm_to_tensor_collectors.items(): | ||
for tensor_collectors in _statistic_point.algorithm_to_tensor_collectors.values(): | ||
for tensor_collector in tensor_collectors: | ||
Comment on lines
+82
to
83
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this a real issue? |
||
if self.stat_subset_size is None: | ||
self.stat_subset_size = tensor_collector.num_samples | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
target_node_name
parameter is valid for all inherited classes except TFTargetPoint. I see three possible solutions here:target_node_name
attribute as a required init attribute.target_node_name
attribute and use it in correspondent typehintslayer_name
andop_name
totarget_node_name
in the TFTargetPoints.I suggest to implement the third solution
CC: @alexsu52