diff --git a/nn_meter/builder/backend_meta/fusion_rule_tester/generate_testcase.py b/nn_meter/builder/backend_meta/fusion_rule_tester/generate_testcase.py
index 9a9633aa..e0c48860 100644
--- a/nn_meter/builder/backend_meta/fusion_rule_tester/generate_testcase.py
+++ b/nn_meter/builder/backend_meta/fusion_rule_tester/generate_testcase.py
@@ -79,7 +79,10 @@ def generate_testcases():
         if op1 in d1_required_layers or op2 in d1_required_layers:
             input_shape = [config['SHAPE_1D']]
         else:
-            input_shape = [config['HW'], config['HW'], config['CIN']]
+            if implement == "tensorflow":
+                input_shape = [config['HW'], config['HW'], config['CIN']]
+            else:
+                input_shape = [config['CIN'], config['HW'], config['HW']]
         bf_cls = type(class_name, (BasicFusion,), {
             'name': name,
             'cases': cases,
diff --git a/nn_meter/builder/backend_meta/fusion_rule_tester/interface.py b/nn_meter/builder/backend_meta/fusion_rule_tester/interface.py
index b63dbc00..855ae1f3 100644
--- a/nn_meter/builder/backend_meta/fusion_rule_tester/interface.py
+++ b/nn_meter/builder/backend_meta/fusion_rule_tester/interface.py
@@ -6,6 +6,7 @@
 from ..utils import read_profiled_results
 from nn_meter.builder.utils import merge_info
 from nn_meter.builder.backend_meta.utils import Latency
+logging = logging.getLogger("nn-Meter")
 
 
 class BaseTestCase:
diff --git a/nn_meter/builder/backend_meta/fusion_rule_tester/test_fusion_rule.py b/nn_meter/builder/backend_meta/fusion_rule_tester/test_fusion_rule.py
index b4fb323f..5b2367de 100644
--- a/nn_meter/builder/backend_meta/fusion_rule_tester/test_fusion_rule.py
+++ b/nn_meter/builder/backend_meta/fusion_rule_tester/test_fusion_rule.py
@@ -53,6 +53,6 @@ def analyze(self, profile_results):
             latency = {key: str(value) for key, value in rule.latency.items()}
             result[name]['latency'] = latency
-            result[name]['obey'] = obey
+            result[name]['obey'] = bool(obey)
 
         return result
 
diff --git a/nn_meter/builder/config_manager.py b/nn_meter/builder/config_manager.py
index 192c027d..915bf719 100644
--- a/nn_meter/builder/config_manager.py
+++ b/nn_meter/builder/config_manager.py
@@ -17,8 +17,9 @@ def copy_to_workspace(backend_type, workspace_path, backendConfigFile = None):
     os.makedirs(os.path.join(workspace_path, 'configs'), exist_ok=True)
 
     # backend config
-    if backend_type == 'customized' and backendConfigFile:
-        copyfile(backendConfigFile, os.path.join(workspace_path, 'configs', 'backend_config.yaml'))
+    if backend_type == 'customized':
+        if backendConfigFile:
+            copyfile(backendConfigFile, os.path.join(workspace_path, 'configs', 'backend_config.yaml'))
     else:
         if backend_type == 'tflite':
             config_name = __backend_tflite_cfg_filename__
diff --git a/nn_meter/builder/nn_modules/torch_networks/operators.py b/nn_meter/builder/nn_modules/torch_networks/operators.py
index a40e46b8..053603e4 100644
--- a/nn_meter/builder/nn_modules/torch_networks/operators.py
+++ b/nn_meter/builder/nn_modules/torch_networks/operators.py
@@ -128,13 +128,13 @@ def forward(self, x):
 
 class FC(BaseOperator):
     def get_model(self):
-        cin = self.input_shape[0]
-        cout = self.input_shape[0] if "COUT" not in self.config else self.config["COUT"]
+        cin = self.input_shape[-1]
+        cout = self.input_shape[-1] if "COUT" not in self.config else self.config["COUT"]
         return nn.Linear(cin, cout)
 
     def get_output_shape(self):
-        cout = self.input_shape[0] if "COUT" not in self.config else self.config["COUT"]
-        return [cout] + self.input_shape[1:]
+        cout = self.input_shape[-1] if "COUT" not in self.config else self.config["COUT"]
+        return self.input_shape[:-1] + [cout]
 
 
 #-------------------- activation function --------------------#
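Two quick sanity checks for the shape-related changes above (an illustrative sketch, not part of the PR; the HW/CIN/COUT values below are made up). PyTorch convolutions expect channels-first (NCHW) input while TensorFlow defaults to channels-last (NHWC), which motivates the per-implement branch in generate_testcase.py; and nn.Linear transforms only the last dimension of its input, which motivates reading cin/cout from input_shape[-1] and reporting input_shape[:-1] + [cout] in operators.py.

# Illustrative sanity checks, not part of this PR; shape values are hypothetical.
import torch
import torch.nn as nn

HW, CIN, COUT = 28, 16, 64

# 1. Channel ordering: torch convolutions are channels-first (NCHW), so the
#    torch branch builds [CIN, HW, HW] rather than the channels-last
#    [HW, HW, CIN] used for the tensorflow implement.
x = torch.randn(1, CIN, HW, HW)               # batch dim + [CIN, HW, HW]
conv = nn.Conv2d(CIN, COUT, kernel_size=3, padding=1)
assert list(conv(x).shape) == [1, COUT, HW, HW]

# 2. FC operator: nn.Linear maps only the LAST dimension, so cin must come
#    from input_shape[-1] and the output keeps every other dimension,
#    i.e. input_shape[:-1] + [cout], as in the fixed get_output_shape().
input_shape = [CIN, HW, HW]
fc = nn.Linear(input_shape[-1], COUT)
y = fc(torch.randn(1, *input_shape))
assert list(y.shape) == [1] + input_shape[:-1] + [COUT]   # [1, 16, 28, 64]

With the old indexing, cin would have been CIN (16) while nn.Linear actually consumes the trailing HW (28), so the module would fail at runtime on any multi-dimensional input; the last-dim version matches how nn.Linear is applied.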