Optimize onnx quantsim init data type inference (#3747)
Signed-off-by: Michael Tuttle <[email protected]>
quic-mtuttle authored Jan 27, 2025
1 parent 084abb8 commit 2ca89b2
Showing 4 changed files with 140 additions and 9 deletions.
48 changes: 39 additions & 9 deletions TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -42,6 +42,7 @@
from pathlib import Path
import os
from typing import Dict, List, Union, Tuple, Optional
import itertools
import json
import warnings
import numpy as np
@@ -66,7 +67,8 @@
from aimet_onnx.meta.connectedgraph import ConnectedGraph
from aimet_onnx.qc_quantize_op import QcQuantizeOp, OpMode, TensorQuantizerParams, GroupedBlockQuantizeDequantize
from aimet_onnx.quantsim_config.quantsim_config import QuantSimConfigurator
from aimet_onnx.utils import make_dummy_input, add_hook_to_get_activation, remove_activation_hooks
from aimet_onnx.utils import make_dummy_input, save_model_with_external_weights, add_hook_to_get_activation, \
remove_activation_hooks

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)

@@ -270,7 +272,11 @@ def _get_activations_to_quantize(self, dummy_input: Dict[str, np.ndarray]):
:param dummy_input: Sample input to be run through the model
"""
self.fill_activation_dtypes(dummy_input)
try:
self.activation_dtypes = self._infer_activation_dtypes()
except onnx.shape_inference.InferenceError:
self.activation_dtypes = self._observe_activation_dtypes(dummy_input)

self.input_name_to_nodes = self.model.input_name_to_nodes()
self.output_name_to_node = self.model.output_name_to_node()
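The hunk above is the core of the optimization: activation data types are now derived statically through ONNX shape inference, and the model is only executed (the previous, memory-hungry path) when inference raises, for example on graphs containing custom ops with no registered type-inference function. A minimal sketch of the same pattern outside AIMET, with the two helpers left hypothetical, might look like:

import onnx

def get_activation_dtypes(model: onnx.ModelProto, dummy_input: dict) -> dict:
    """Sketch: prefer static type inference, fall back to executing the model."""
    try:
        # Static pass: no activation tensors are materialized.
        return dtypes_from_shape_inference(model)          # hypothetical helper
    except onnx.shape_inference.InferenceError:
        # e.g. custom/contrib ops that defeat shape inference.
        return dtypes_from_execution(model, dummy_input)   # hypothetical helper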

@@ -366,9 +372,32 @@ def _check_matmul_add_patten(self, node: onnx.NodeProto) -> bool:
return True
return False

def fill_activation_dtypes(self, dummy_input: Dict[str, np.ndarray]):
def _infer_activation_dtypes(self):
"""
Get the data type for each activation through shape inference
"""
if self.model.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
with tempfile.TemporaryDirectory(dir=self._path) as tempdir:
save_path = os.path.join(tempdir, "inferred_model.onnx")
save_model_with_external_weights(self.model.model, save_path, location=Path(save_path).name + ".data")
onnx.shape_inference.infer_shapes_path(save_path)
# Do not load the weights for the shape inference model, we only need to access the graph's `value_info`
inferred_model = onnx.load(save_path, load_external_data=False)
else:
inferred_model = onnx.shape_inference.infer_shapes(self.model.model)

activation_dtypes = {}
for val_info in itertools.chain(inferred_model.graph.value_info,
inferred_model.graph.input,
inferred_model.graph.output):
act_name = val_info.name
dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[val_info.type.tensor_type.elem_type]
activation_dtypes[act_name] = dtype
return activation_dtypes
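For reference, the value_info lookup used above can be exercised on a tiny standalone graph. This is an illustrative sketch, not part of the commit:

from onnx import TensorProto, helper, mapping, shape_inference

# Toy graph: x -> Relu -> Cast(double) -> y
nodes = [
    helper.make_node("Relu", inputs=["x"], outputs=["relu_out"]),
    helper.make_node("Cast", inputs=["relu_out"], outputs=["y"], to=TensorProto.DOUBLE),
]
graph = helper.make_graph(
    nodes, "toy",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.DOUBLE, [1, 4])],
)
inferred = shape_inference.infer_shapes(helper.make_model(graph))
for vi in inferred.graph.value_info:
    # Expected: relu_out resolves to float32 without ever running the model
    print(vi.name, mapping.TENSOR_TYPE_TO_NP_TYPE[vi.type.tensor_type.elem_type])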

def _observe_activation_dtypes(self, dummy_input: Dict[str, np.ndarray]):
"""
Get the data type for each activation
Get the data type for each activation by returning all activations
:param dummy_input: Sample input to run through the model
"""
@@ -379,11 +408,14 @@ def fill_activation_dtypes(self, dummy_input: Dict[str, np.ndarray]):
sess = QuantizationSimModel.build_session(self.model.model, ['CPUExecutionProvider'],
user_onnx_libs=self._user_onnx_libs, path=self._path)
outputs = sess.run(None, dummy_input)

activation_dtypes = {}
for idx in range(len(self.model.graph().output)):
act_name = self.model.graph().output[idx].name
dtype = outputs[idx].dtype
self.activation_dtypes[act_name] = dtype
activation_dtypes[act_name] = dtype
remove_activation_hooks(self.model.model, hooks)
return activation_dtypes
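The fallback keeps the previous behaviour: intermediate tensors are temporarily registered as graph outputs (via add_hook_to_get_activation) so a single forward pass exposes every activation's dtype. A hedged, self-contained sketch of the same idea without the AIMET utilities:

import onnx
import onnxruntime as ort

def observe_dtypes(model: onnx.ModelProto, dummy_input: dict) -> dict:
    # Work on a copy so the caller's model is not mutated.
    patched = onnx.ModelProto()
    patched.CopyFrom(model)
    existing = {out.name for out in patched.graph.output}
    for node in patched.graph.node:
        for name in node.output:
            if name and name not in existing:
                # Expose the intermediate tensor as an extra (untyped) graph output.
                patched.graph.output.append(onnx.ValueInfoProto(name=name))
    sess = ort.InferenceSession(patched.SerializeToString(),
                                providers=["CPUExecutionProvider"])
    # Note: this materializes every activation, which is exactly the memory
    # cost the shape-inference path avoids.
    outputs = sess.run(None, dummy_input)
    return {meta.name: arr.dtype for meta, arr in zip(sess.get_outputs(), outputs)}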

def _add_quantization_nodes(self):
"""
@@ -524,8 +556,7 @@ def build_session(model: onnx.ModelProto, providers: List, user_onnx_libs: List[
output_path = os.path.join(path, 'model.onnx')
if save_as_external_data:
# Note: Saving as external data mutates the saved model, removing all initializer data
onnx.save_model(model, output_path, save_as_external_data=True, location=Path(output_path).name + ".data")
onnx.load_external_data_for_model(model, base_dir=path)
save_model_with_external_weights(model, output_path, location=Path(output_path).name + ".data")

path_or_bytes = output_path if save_as_external_data else model.SerializeToString()
session = InferenceSession(
@@ -778,8 +809,7 @@ def export(self, path: str, filename_prefix: str):
self.remove_quantization_nodes()
if self.model.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
# Note: Saving as external data mutates the saved model, removing all initializer data
self.model.save_model_to_file(os.path.join(path, filename_prefix) + '.onnx', use_external_data_format=True)
onnx.load_external_data_for_model(self.model.model, base_dir=path)
save_model_with_external_weights(self.model.model, os.path.join(path, filename_prefix) + '.onnx')
else:
self.model.save_model_to_file(os.path.join(path, filename_prefix) + '.onnx')

12 changes: 12 additions & 0 deletions TrainingExtensions/onnx/src/python/aimet_onnx/utils.py
@@ -403,6 +403,18 @@ def retrieve_constant_input(node: NodeProto, model: ModelProto, index: int
transposed = True
return weight, transposed

def save_model_with_external_weights(model: onnx.ModelProto, f: str, **kwargs):
"""
Saves an onnx model with external weights without mutating the original model
:param model: ONNX ModelProto object to save
:param f: filename to save the model to
:param kwargs: Additional keyword arguments to pass to :func:`onnx.save_model`
"""
onnx.save_model(model, f, save_as_external_data=True, **kwargs)
# Load back weights which are removed when saving as external data
onnx.load_external_data_for_model(model, os.path.dirname(f))
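Unlike a bare onnx.save_model(..., save_as_external_data=True), the helper reloads the externalized initializers so the caller's in-memory ModelProto stays intact. Illustrative usage, with the model object and paths as placeholders:

import os
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    out_path = os.path.join(tmpdir, "model.onnx")
    save_model_with_external_weights(model, out_path, location="model.onnx.data")
    # `model` still carries its initializer data here; without the reload step,
    # save_as_external_data=True would have stripped it from the proto.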


class CachedDataset:
"""
43 changes: 43 additions & 0 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -52,6 +52,7 @@
from torch.nn.modules.batchnorm import _BatchNorm
from onnx import helper, numpy_helper, OperatorSetIdProto, TensorProto, load_model
from onnxruntime.quantization.onnx_quantizer import ONNXModel
from onnxruntime_extensions import PyOp, onnx_op

from aimet_common import libquant_info
from .mobilenet import MockMobileNetV1, MockMobileNetV11
@@ -2709,3 +2710,45 @@ def squeezenet1_0(tmpdir):
input_names=["input"], output_names=["output"])
model = onnx.load(filepath)
return ONNXModel(model)

@onnx_op(op_type="CustomAdd",
inputs=[PyOp.dt_float, PyOp.dt_float],
outputs=[PyOp.dt_float])
def add_op(x, y):
return x + y

def custom_op_model():
model = helper.make_model(
graph=helper.make_graph(
name="CustomAddModel",
inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10])],
outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10])],
initializer=[],
nodes=[
helper.make_node(
"Relu",
inputs=["model_input"],
outputs=["y"],
),
helper.make_node(
"CustomAdd",
inputs=["model_input", "y"],
outputs=["z"],
domain="ai.onnx.contrib"
),
helper.make_node(
"CustomAdd",
inputs=["z", "y"],
outputs=["output"],
domain="ai.onnx.contrib"
),
helper.make_node(
"Exp",
inputs=["output"],
outputs=["model_output"]
)
],
),
opset_imports=[helper.make_operatorsetid('ai.onnx.contrib', 1)]
)
return model
46 changes: 46 additions & 0 deletions TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -40,6 +40,8 @@
import json
import os
import tempfile
import tracemalloc

import onnx.numpy_helper
import torch
import numpy as np
@@ -653,6 +655,44 @@ def test_multiple_output_quantsim(self):
path=tempdir)
sim.session.run(None, {'input': sample_input})

def test_quantsim_init_memory_usage(self):
"""
When: Instantiate a quantsim model with high activation memory usage
Then: Memory usage should not spike
"""
num_layers = 2 ** 9
activation_dim = 2 ** 13
batch_size = 2 ** 8
total_act_memory = num_layers * activation_dim * batch_size

# Create a model with very high total activation memory usage
layers = [
onnx.helper.make_node("Constant", inputs=[], outputs=["shape"], name="shape",
value=onnx.numpy_helper.from_array(np.array([batch_size, activation_dim], dtype=np.dtype("int64")))),
onnx.helper.make_node("Expand", inputs=["input", "shape"], outputs=["act0"], name="reshape"),
]
for idx in range(num_layers):
layers.append(
onnx.helper.make_node("Sigmoid", inputs=[f"act{idx}"], outputs=[f"act{idx + 1}"],
name=f"layer_{idx}")
)

input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 1])
output_tensor = onnx.helper.make_tensor_value_info(f"act{num_layers}", onnx.TensorProto.FLOAT,
[batch_size, activation_dim])
graph = onnx.helper.make_graph(layers, "graph", initializer=[], inputs=[input_tensor],
outputs=[output_tensor])
model = onnx.helper.make_model(graph)

with tempfile.TemporaryDirectory() as tempdir:
tracemalloc.start()
sim = QuantizationSimModel(model, path=tempdir)
current_mem, peak_mem = tracemalloc.get_traced_memory()
tracemalloc.stop()

assert peak_mem < current_mem + 0.25 * total_act_memory
assert peak_mem < current_mem * 5
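# For scale: num_layers * activation_dim * batch_size = 2**30 activation elements
# (roughly 4 GiB as float32) would be materialized by the old run-the-model path,
# so these assertions bound peak traced allocation during quantsim construction
# to a small fraction of that.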

@pytest.mark.skip(reason="test requires exact version of torch that the code has built against.")
def test_model_with_custom_ops(self):
custom_ops_path = os.path.dirname(libquant_info.__file__)
@@ -1690,3 +1730,9 @@ def test_identity_conv_perchannel(self):
config_file=get_path_for_per_channel_config())
assert sim.qc_quantize_op_dict["identity.input"].quant_info.usePerChannelMode
assert sim.qc_quantize_op_dict["identity.input"].quant_info.channelAxis == 0

def test_customop_model(self):
from onnxruntime_extensions import get_library_path
model = models_for_tests.custom_op_model()
sim = QuantizationSimModel(model, user_onnx_libs=[get_library_path()])
assert {"model_input", "output", "model_output", "y", "z"} == sim.qc_quantize_op_dict.keys()
