diff --git a/nn_meter/ir_converter/onnx_converter/converter.py b/nn_meter/ir_converter/onnx_converter/converter.py index 17808e3e..b7002943 100644 --- a/nn_meter/ir_converter/onnx_converter/converter.py +++ b/nn_meter/ir_converter/onnx_converter/converter.py @@ -16,7 +16,7 @@ def __init__(self, model): self.graph = inferred_model.graph self.tensors = {} - for tensor in chain(self.graph.input, self.graph.value_info, self.graph.output): + for tensor in chain(self.graph.input, self.graph.value_info, self.graph.initializer, self.graph.output): self.tensors[tensor.name] = { "shape": get_tensor_shape(tensor), "inputs": [], diff --git a/nn_meter/ir_converter/onnx_converter/utils.py b/nn_meter/ir_converter/onnx_converter/utils.py index 172cc82b..e8b55733 100644 --- a/nn_meter/ir_converter/onnx_converter/utils.py +++ b/nn_meter/ir_converter/onnx_converter/utils.py @@ -2,8 +2,12 @@ # Licensed under the MIT license. def get_tensor_shape(tensor): shape = [] - for dim in tensor.type.tensor_type.shape.dim: - shape.append(dim.dim_value) + try: + for dim in tensor.type.tensor_type.shape.dim: + shape.append(dim.dim_value) + except AttributeError: + # initializer + shape += tensor.dims if len(shape) == 4: shape = [shape[0], shape[2], shape[3], shape[1]] return shape diff --git a/nn_meter/utils/graph_tool.py b/nn_meter/utils/graph_tool.py index 35947c3c..28b2324b 100644 --- a/nn_meter/utils/graph_tool.py +++ b/nn_meter/utils/graph_tool.py @@ -28,6 +28,9 @@ def node(self, name, inbound_nodes=None): self.graph[node]["outbounds"].append(name) def refresh(self): + if len(self.graph) <= 1: + return + last_remove_nodes_cnt = -1 while True: for name in self.graph.keys():