From 39460e1c46d6b1775a8b76cf5e33069420f4f82e Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Tue, 13 Sep 2022 03:21:49 +0000 Subject: [PATCH 1/3] fix bug in torch fusion rules --- .../backend_meta/fusion_rule_tester/generate_testcase.py | 5 ++++- nn_meter/builder/config_manager.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/nn_meter/builder/backend_meta/fusion_rule_tester/generate_testcase.py b/nn_meter/builder/backend_meta/fusion_rule_tester/generate_testcase.py index 9a9633aa..e0c48860 100644 --- a/nn_meter/builder/backend_meta/fusion_rule_tester/generate_testcase.py +++ b/nn_meter/builder/backend_meta/fusion_rule_tester/generate_testcase.py @@ -79,7 +79,10 @@ def generate_testcases(): if op1 in d1_required_layers or op2 in d1_required_layers: input_shape = [config['SHAPE_1D']] else: - input_shape = [config['HW'], config['HW'], config['CIN']] + if implement == "tensorflow": + input_shape = [config['HW'], config['HW'], config['CIN']] + else: + input_shape = [config['CIN'], config['HW'], config['HW']] bf_cls = type(class_name, (BasicFusion,), { 'name': name, 'cases': cases, diff --git a/nn_meter/builder/config_manager.py b/nn_meter/builder/config_manager.py index 192c027d..915bf719 100644 --- a/nn_meter/builder/config_manager.py +++ b/nn_meter/builder/config_manager.py @@ -17,8 +17,9 @@ def copy_to_workspace(backend_type, workspace_path, backendConfigFile = None): os.makedirs(os.path.join(workspace_path, 'configs'), exist_ok=True) # backend config - if backend_type == 'customized' and backendConfigFile: - copyfile(backendConfigFile, os.path.join(workspace_path, 'configs', 'backend_config.yaml')) + if backend_type == 'customized': + if backendConfigFile: + copyfile(backendConfigFile, os.path.join(workspace_path, 'configs', 'backend_config.yaml')) else: if backend_type == 'tflite': config_name = __backend_tflite_cfg_filename__ From 3bf5a68b78dde370c7bbb4dfc421ca0bb164d3cb Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Tue, 13 Sep 2022 08:50:11 +0000 Subject: [PATCH 2/3] fix bugs of torch fc op --- .../backend_meta/fusion_rule_tester/test_fusion_rule.py | 2 +- nn_meter/builder/nn_modules/torch_networks/operators.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nn_meter/builder/backend_meta/fusion_rule_tester/test_fusion_rule.py b/nn_meter/builder/backend_meta/fusion_rule_tester/test_fusion_rule.py index b4fb323f..5b2367de 100644 --- a/nn_meter/builder/backend_meta/fusion_rule_tester/test_fusion_rule.py +++ b/nn_meter/builder/backend_meta/fusion_rule_tester/test_fusion_rule.py @@ -53,6 +53,6 @@ def analyze(self, profile_results): latency = {key: str(value) for key, value in rule.latency.items()} result[name]['latency'] = latency - result[name]['obey'] = obey + result[name]['obey'] = bool(obey) return result diff --git a/nn_meter/builder/nn_modules/torch_networks/operators.py b/nn_meter/builder/nn_modules/torch_networks/operators.py index a40e46b8..053603e4 100644 --- a/nn_meter/builder/nn_modules/torch_networks/operators.py +++ b/nn_meter/builder/nn_modules/torch_networks/operators.py @@ -128,13 +128,13 @@ def forward(self, x): class FC(BaseOperator): def get_model(self): - cin = self.input_shape[0] - cout = self.input_shape[0] if "COUT" not in self.config else self.config["COUT"] + cin = self.input_shape[-1] + cout = self.input_shape[-1] if "COUT" not in self.config else self.config["COUT"] return nn.Linear(cin, cout) def get_output_shape(self): - cout = self.input_shape[0] if "COUT" not in self.config else self.config["COUT"] - return [cout] + self.input_shape[1:] + cout = self.input_shape[-1] if "COUT" not in self.config else self.config["COUT"] + return self.input_shape[:-1] + [cout] #-------------------- activation function --------------------# From b1fff2387619a9a5ca5c935158945dde75fc3917 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Tue, 13 Sep 2022 08:53:52 +0000 Subject: [PATCH 3/3] add logger in fusion rule test --- nn_meter/builder/backend_meta/fusion_rule_tester/interface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nn_meter/builder/backend_meta/fusion_rule_tester/interface.py b/nn_meter/builder/backend_meta/fusion_rule_tester/interface.py index b63dbc00..855ae1f3 100644 --- a/nn_meter/builder/backend_meta/fusion_rule_tester/interface.py +++ b/nn_meter/builder/backend_meta/fusion_rule_tester/interface.py @@ -6,6 +6,7 @@ from ..utils import read_profiled_results from nn_meter.builder.utils import merge_info from nn_meter.builder.backend_meta.utils import Latency +logging = logging.getLogger("nn-Meter") class BaseTestCase: