diff --git a/tools/onnx-optimize.py b/tools/onnx-optimize.py index cd263a8e3..ed7436be8 100644 --- a/tools/onnx-optimize.py +++ b/tools/onnx-optimize.py @@ -13,7 +13,7 @@ import logging import onnx -from onnx import helper +from onnx import helper, shape_inference from tf2onnx.graph import GraphUtil from tf2onnx import logging, optimizer, constants @@ -46,6 +46,12 @@ def load_graph(fname, target): return g, model_proto +def model_shape_inference(onnx_model_proto): + inferred_model = shape_inference.infer_shapes(onnx_model_proto) + onnx.checker.check_model(inferred_model) + return inferred_model + + def main(): args = get_args() @@ -64,10 +70,12 @@ def main(): model_proto = helper.make_model(onnx_graph, **kwargs) + model_proto_inferred = model_shape_inference(model_proto) + # write onnx graph if args.output: with open(args.output, "wb") as f: - f.write(model_proto.SerializeToString()) + f.write(model_proto_inferred.SerializeToString()) if __name__ == "__main__":