From 2ca89b2f5ff0f24df492377ce9fbfac59b4db446 Mon Sep 17 00:00:00 2001
From: Michael Tuttle
Date: Mon, 27 Jan 2025 09:15:18 -0800
Subject: [PATCH] Optimize onnx quantsim init data type inference (#3747)

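QuantizationSimModel init previously determined every activation's data
type by attaching a hook to each activation and running the model
through an onnxruntime session, which materializes all intermediate
outputs at once and can spike host memory on activation-heavy models.
Activation dtypes are now read from the graph's value_info via ONNX
shape inference, and the session-based observation is kept only as a
fallback for models that shape inference cannot handle (e.g. custom
ops). The save-then-reload pattern for external weights is factored
into a save_model_with_external_weights() utility so that saving never
leaves the in-memory model stripped of its initializer data. Two short
standalone sketches of these mechanisms are appended after the diff.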
model """ @@ -379,11 +408,14 @@ def fill_activation_dtypes(self, dummy_input: Dict[str, np.ndarray]): sess = QuantizationSimModel.build_session(self.model.model, ['CPUExecutionProvider'], user_onnx_libs=self._user_onnx_libs, path=self._path) outputs = sess.run(None, dummy_input) + + activation_dtypes = {} for idx in range(len(self.model.graph().output)): act_name = self.model.graph().output[idx].name dtype = outputs[idx].dtype - self.activation_dtypes[act_name] = dtype + activation_dtypes[act_name] = dtype remove_activation_hooks(self.model.model, hooks) + return activation_dtypes def _add_quantization_nodes(self): """ @@ -524,8 +556,7 @@ def build_session(model: onnx.ModelProto, providers: List, user_onnx_libs: List[ output_path = os.path.join(path, 'model.onnx') if save_as_external_data: # Note: Saving as external data mutates the saved model, removing all initializer data - onnx.save_model(model, output_path, save_as_external_data=True, location=Path(output_path).name + ".data") - onnx.load_external_data_for_model(model, base_dir=path) + save_model_with_external_weights(model, output_path, location=Path(output_path).name + ".data") path_or_bytes = output_path if save_as_external_data else model.SerializeToString() session = InferenceSession( @@ -778,8 +809,7 @@ def export(self, path: str, filename_prefix: str): self.remove_quantization_nodes() if self.model.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: # Note: Saving as external data mutates the saved model, removing all initializer data - self.model.save_model_to_file(os.path.join(path, filename_prefix) + '.onnx', use_external_data_format=True) - onnx.load_external_data_for_model(self.model.model, base_dir=path) + save_model_with_external_weights(self.model.model, os.path.join(path, filename_prefix) + '.onnx') else: self.model.save_model_to_file(os.path.join(path, filename_prefix) + '.onnx') diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py b/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py index 48cd183d594..33136c443bd 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py @@ -403,6 +403,18 @@ def retrieve_constant_input(node: NodeProto, model: ModelProto, index: int transposed = True return weight, transposed +def save_model_with_external_weights(model: onnx.ModelProto, f: str, **kwargs): + """ + Saves an onnx model with external weights without mutating the original model + + :param model: ONNX ModelProto object to save + :param f: filename to save the model to + :param kwargs: Additional keyword arguments to pass to :func:`onnx.save_model` + """ + onnx.save_model(model, f, save_as_external_data=True, **kwargs) + # Load back weights which are removed when saving as external data + onnx.load_external_data_for_model(model, os.path.dirname(f)) + class CachedDataset: """ diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py index 1579410c5bc..9af3bcf7648 100644 --- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py +++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py @@ -52,6 +52,7 @@ from torch.nn.modules.batchnorm import _BatchNorm from onnx import helper, numpy_helper, OperatorSetIdProto, TensorProto, load_model from onnxruntime.quantization.onnx_quantizer import ONNXModel +from onnxruntime_extensions import PyOp, onnx_op from aimet_common import libquant_info from .mobilenet import MockMobileNetV1, MockMobileNetV11 
@@ -2709,3 +2710,45 @@ def squeezenet1_0(tmpdir):
                       input_names=["input"], output_names=["output"])
     model = onnx.load(filepath)
     return ONNXModel(model)
+
+@onnx_op(op_type="CustomAdd",
+         inputs=[PyOp.dt_float, PyOp.dt_float],
+         outputs=[PyOp.dt_float])
+def add_op(x, y):
+    return x + y
+
+def custom_op_model():
+    model = helper.make_model(
+        graph=helper.make_graph(
+            name="CustomAddModel",
+            inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10])],
+            outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10])],
+            initializer=[],
+            nodes=[
+                helper.make_node(
+                    "Relu",
+                    inputs=["model_input"],
+                    outputs=["y"],
+                ),
+                helper.make_node(
+                    "CustomAdd",
+                    inputs=["model_input", "y"],
+                    outputs=["z"],
+                    domain="ai.onnx.contrib"
+                ),
+                helper.make_node(
+                    "CustomAdd",
+                    inputs=["z", "y"],
+                    outputs=["output"],
+                    domain="ai.onnx.contrib"
+                ),
+                helper.make_node(
+                    "Exp",
+                    inputs=["output"],
+                    outputs=["model_output"]
+                )
+            ],
+        ),
+        opset_imports=[helper.make_operatorsetid('ai.onnx.contrib', 1)]
+    )
+    return model
diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py
index 2627de66e17..ae1b902e803 100644
--- a/TrainingExtensions/onnx/test/python/test_quantsim.py
+++ b/TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -40,6 +40,8 @@
 import json
 import os
 import tempfile
+import tracemalloc
+
 import onnx.numpy_helper
 import torch
 import numpy as np
@@ -653,6 +655,44 @@ def test_multiple_output_quantsim(self):
                                    path=tempdir)
         sim.session.run(None, {'input': sample_input})
 
+    def test_quantsim_init_memory_usage(self):
+        """
+        When: Instantiate a quantsim model with high activation memory usage
+        Then: Memory usage should not spike
+        """
+        num_layers = 2 ** 9
+        activation_dim = 2 ** 13
+        batch_size = 2 ** 8
+        total_act_memory = num_layers * activation_dim * batch_size
+
+        # Create a model with very high total activation memory usage
+        layers = [
+            onnx.helper.make_node("Constant", inputs=[], outputs=["shape"], name="shape",
+                                  value=onnx.numpy_helper.from_array(np.array([batch_size, activation_dim], dtype=np.dtype("int64")))),
+            onnx.helper.make_node("Expand", inputs=["input", "shape"], outputs=["act0"], name="reshape"),
+        ]
+        for idx in range(num_layers):
+            layers.append(
+                onnx.helper.make_node("Sigmoid", inputs=[f"act{idx}"], outputs=[f"act{idx + 1}"],
+                                      name=f"layer_{idx}")
+            )
+
+        input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 1])
+        output_tensor = onnx.helper.make_tensor_value_info(f"act{num_layers}", onnx.TensorProto.FLOAT,
+                                                           [batch_size, activation_dim])
+        graph = onnx.helper.make_graph(layers, "graph", initializer=[], inputs=[input_tensor],
+                                       outputs=[output_tensor])
+        model = onnx.helper.make_model(graph)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            tracemalloc.start()
+            sim = QuantizationSimModel(model, path=tempdir)
+            current_mem, peak_mem = tracemalloc.get_traced_memory()
+            tracemalloc.stop()
+
+            assert peak_mem < current_mem + 0.25 * total_act_memory
+            assert peak_mem < current_mem * 5
+
     @pytest.mark.skip(reason="test requires exact version of torch that the code has built against.")
     def test_model_with_custom_ops(self):
         custom_ops_path = os.path.dirname(libquant_info.__file__)
@@ -1690,3 +1730,9 @@ def test_identity_conv_perchannel(self):
                                        config_file=get_path_for_per_channel_config())
         assert sim.qc_quantize_op_dict["identity.input"].quant_info.usePerChannelMode
         assert sim.qc_quantize_op_dict["identity.input"].quant_info.channelAxis == 0
+
+def test_customop_model():
+    from onnxruntime_extensions import get_library_path
+    model = models_for_tests.custom_op_model()
+    sim = QuantizationSimModel(model, user_onnx_libs=[get_library_path()])
+    assert {"model_input", "output", "model_output", "y", "z"} == sim.qc_quantize_op_dict.keys()
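
Appended for illustration, not part of the patch: a minimal sketch of the
mechanism _infer_activation_dtypes() relies on. onnx.shape_inference.infer_shapes
fills in the graph's value_info, so activation dtypes can be read statically
instead of executing the model. The one-node Relu model and all names below
are illustrative only.

    import itertools

    import numpy as np
    import onnx
    from onnx import helper, TensorProto

    # Tiny model: x -> Relu -> y
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])
    model = helper.make_model(
        helper.make_graph([helper.make_node("Relu", ["x"], ["y"])], "g", [x], [y]))

    # Dtypes come from static type inference; the model is never run
    inferred = onnx.shape_inference.infer_shapes(model)
    dtypes = {
        # The patch uses onnx.mapping.TENSOR_TYPE_TO_NP_TYPE; on newer onnx
        # releases helper.tensor_dtype_to_np_dtype is the non-deprecated equivalent
        vi.name: onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type]
        for vi in itertools.chain(inferred.graph.value_info,
                                  inferred.graph.input,
                                  inferred.graph.output)
    }
    assert dtypes == {"x": np.dtype("float32"), "y": np.dtype("float32")}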
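
A second sketch, also not part of the patch, of the side effect that
save_model_with_external_weights() exists to hide: saving with
save_as_external_data=True strips initializer data from the in-memory proto,
which is why the helper (and the call sites in build_session and export)
reloads it afterwards. File names and size_threshold=0, which forces even
small tensors into the external file, are illustrative.

    import os
    import tempfile

    import numpy as np
    import onnx
    from onnx import helper, numpy_helper, TensorProto

    # x @ w, with w stored as an initializer
    w = numpy_helper.from_array(np.ones((8, 8), dtype=np.float32), name="w")
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [8, 8])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [8, 8])
    node = helper.make_node("MatMul", ["x", "w"], ["y"])
    model = helper.make_model(helper.make_graph([node], "g", [x], [y], initializer=[w]))

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "model.onnx")
        onnx.save_model(model, path, save_as_external_data=True,
                        location="weights.data", size_threshold=0)
        # Saving moved the weight data out of the in-memory proto ...
        assert not model.graph.initializer[0].raw_data
        # ... and reloading restores it, which is the helper's second step
        onnx.load_external_data_for_model(model, os.path.dirname(path))
        assert model.graph.initializer[0].raw_data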