Commit afbf280: Merge branch 'master' into mengni/calib_fix

mengniwang95 authored Mar 12, 2024
2 parents 6413a5c + d8e60b8
Showing 9 changed files with 195 additions and 53 deletions.
@@ -88,7 +88,10 @@ def export_onnx_model(args, model):
             config=AutoConfig.from_pretrained(args.input_model))

     if args.input_model == 'Intel/bart-large-mrpc':
-        import os
-        os.system('python -m transformers.onnx --model=Intel/bart-large-mrpc --feature=sequence-classification --export_with_transformers bart-large-mrpc/')
+        import shutil
+        from optimum.exporters.onnx import main_export
+
+        main_export(args.input_model, output="bart-large-mrpc", task="text-classification")
+        shutil.move("bart-large-mrpc/model.onnx", args.output_model)
     else:
         export_onnx_model(args, model)
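
The new path replaces the shelled-out `python -m transformers.onnx` call with optimum's programmatic exporter. A minimal standalone sketch of that API, assuming optimum with its ONNX exporter dependencies is installed (paths here are illustrative):

    # Export a Hub model to ONNX via optimum's programmatic exporter.
    from optimum.exporters.onnx import main_export

    main_export(
        "Intel/bart-large-mrpc",     # model id on the Hugging Face Hub, or a local path
        output="bart-large-mrpc",    # directory that will receive model.onnx
        task="text-classification",  # successor of the legacy "sequence-classification" feature name
    )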
@@ -8,4 +8,5 @@ sympy
 onnxruntime-extensions; python_version < '3.11'
 numpy==1.23.5
 sentencepiece
 protobuf<=3.20.3
+optimum
@@ -87,7 +87,10 @@ def export_onnx_model(args, model):
             config=AutoConfig.from_pretrained(args.input_model))

     if args.input_model == 'Intel/bart-large-mrpc':
-        import os
-        os.system('python -m transformers.onnx --model=Intel/bart-large-mrpc --feature=sequence-classification --export_with_transformers bart-large-mrpc/')
+        import shutil
+        from optimum.exporters.onnx import main_export
+
+        main_export(args.input_model, output="bart-large-mrpc", task="text-classification")
+        shutil.move("bart-large-mrpc/model.onnx", args.output_model)
     else:
         export_onnx_model(args, model)
@@ -9,4 +9,4 @@ onnxruntime-extensions; python_version < '3.11'
 numpy==1.23.5
 sentencepiece
 protobuf<=3.20.3
-optimum[exporters]
+optimum
@@ -687,7 +687,10 @@ def eval_func(model):
         example_inputs = get_example_inputs(model, eval_dataloader)
         model = ipex.optimize(model)
         with torch.no_grad():
-            model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False)
+            if isinstance(example_inputs, dict):
+                model = torch.jit.trace(model, example_kwarg_inputs=example_inputs, strict=False)
+            else:
+                model = torch.jit.trace(model, example_inputs, strict=False)
         model = torch.jit.freeze(model)

     if model_args.benchmark or model_args.accuracy_only:
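
The added branch mirrors `torch.jit.trace`'s two calling conventions: tensor or tuple example inputs are passed positionally, while dict inputs must go through the `example_kwarg_inputs` keyword (available since PyTorch 2.0). A runnable toy sketch, using a stand-in module rather than the example's real model:

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1

    m = Toy().eval()
    with torch.no_grad():
        # dict example inputs are traced as keyword arguments
        traced_kw = torch.jit.trace(m, example_kwarg_inputs={"x": torch.ones(2)}, strict=False)
        # tensor/tuple example inputs are traced positionally
        traced_pos = torch.jit.trace(m, torch.ones(2), strict=False)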
@@ -263,10 +263,14 @@ def do_transformation(self):

             weight_node = self.graph_info[new_node.input[1]].node
             bias_node = self.graph_info[new_node.input[2]].node
-            max_input_node = self.graph_info[last_node.input[-1]].node
-            min_input_node = self.graph_info[last_node.input[-2]].node
+            max_input_node = None
+            min_input_node = None
+            if last_node.op.find("Requantize") != -1 or last_node.op.find("QuantizeV2") != -1:
+                max_input_node = self.graph_info[last_node.input[-1]].node
+                min_input_node = self.graph_info[last_node.input[-2]].node

-            if max_input_node.op == "Enter":  # pragma: no cover
+            if max_input_node and max_input_node.op == "Enter":  # pragma: no cover
                 min_input_parent_name = Helper.node_name_from_input(min_input_node.input[0])
                 max_input_parent_name = Helper.node_name_from_input(max_input_node.input[0])
                 min_input_parent_node = self.graph_info[min_input_parent_name].node
171 changes: 131 additions & 40 deletions neural_compressor/model/tensorflow_model.py
@@ -310,7 +310,45 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names):
     return opt, input_tensor_names, output_tensor_names


+def _get_graph_from_saved_model_v3(model, input_tensor_names, output_tensor_names):
+    """The version 3 function that gets the graph from a saved_model.
+
+    Args:
+        model (string or tf.keras.Model): model path or tf.keras.Model object.
+        input_tensor_names (list of string): input tensor names of the model.
+        output_tensor_names (list of string): output tensor names of the model.
+
+    Returns:
+        graph_def (tf.compat.v1.GraphDef): the parsed GraphDef object.
+        inputs (list of string): validated input names.
+        outputs (list of string): validated output names.
+    """
+    from neural_compressor.adaptor.tf_utils.util import parse_saved_model
+
+    if isinstance(model, tf.keras.Model):
+        tmp_dir = cfg.default_workspace + "/saved_model"
+        model.save(tmp_dir)
+        model = tmp_dir
+    graph_def, _, _, _, input_names, output_names = parse_saved_model(
+        model, True, input_tensor_names, output_tensor_names
+    )
+
+    return graph_def, input_names, output_names
+
+
 def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names):
+    """The version 2 function that gets the graph from a saved_model directory.
+
+    Args:
+        saved_model_dir (string): model path of a temporary saved_model.
+        input_tensor_names (list of string): input tensor names of the model.
+        output_tensor_names (list of string): output tensor names of the model.
+
+    Returns:
+        graph_def (tf.compat.v1.GraphDef): the parsed GraphDef object.
+        input_names (list of string): validated input names.
+        output_names (list of string): validated output names.
+    """
     from tensorflow.python.saved_model import signature_constants, tag_constants

     saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
@@ -319,7 +357,17 @@ def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names):
     return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names)


-def _get_graph_from_original_keras_v2(model, output_dir):
+def _get_graph_from_original_keras_v2(model):
+    """The version 2 function that gets the graph from the original keras model.
+
+    Args:
+        model (string or tf.keras.Model): model path or tf.keras.Model object.
+
+    Returns:
+        graph_def (tf.compat.v1.GraphDef): the parsed GraphDef object.
+        input_names (list of string): validated input names.
+        output_names (list of string): validated output names.
+    """
     from tensorflow.lite.python.convert import OpsSet
     from tensorflow.lite.python.util import (
         get_grappler_config,
@@ -364,6 +412,17 @@ def _get_graph_from_original_keras_v2(model, output_dir):


 def _check_keras_format(model, saved_model_dir):
+    """Decide which method will be used to get the graph from the saved_model.
+
+    Args:
+        model (string or tf.keras.Model): model path or tf.keras.Model object.
+        saved_model_dir (string): the path to save a temporary saved_model.
+
+    Returns:
+        keras_format (string): which loading path to use for this model
+            (e.g. "saved_model_v2" or "trackable_object").
+    """
     from tensorflow.python import saved_model
     from tensorflow.python.saved_model import save_options
     from tensorflow.python.saved_model.load import load
@@ -384,6 +443,16 @@ def _check_keras_format(model, saved_model_dir):


 def _get_graph_from_saved_model_v1(model):
+    """The version 1 function that gets the graph from a saved_model.
+
+    Args:
+        model (string or tf.keras.Model): model path or tf.keras.Model object.
+
+    Returns:
+        graph_def (tf.compat.v1.GraphDef): the parsed GraphDef object.
+        inputs (list of string): validated input names.
+        outputs (list of string): validated output names.
+    """
     from tensorflow.lite.python.convert_saved_model import get_inputs_outputs, get_meta_graph_def, get_signature_def
     from tensorflow.python.client import session
     from tensorflow.python.framework import ops
@@ -424,6 +493,51 @@ def _get_graph_from_saved_model_v1(model):
     return graph_def, inputs, outputs


+def try_loading_keras(model, input_tensor_names, output_tensor_names):
+    """Try different ways of loading keras models.
+
+    Args:
+        model (string or tf.keras.Model): model path or tf.keras.Model object.
+        input_tensor_names (list of string): input tensor names of the model.
+        output_tensor_names (list of string): output tensor names of the model.
+
+    Returns:
+        graph_def (tf.compat.v1.GraphDef): the parsed GraphDef object.
+        input_names (list of string): validated input names.
+        output_names (list of string): validated output names.
+    """
+    temp_dir = tempfile.mkdtemp()
+    if not isinstance(model, tf.keras.Model):
+        model = tf.keras.models.load_model(model)
+    keras_format = _check_keras_format(model, temp_dir)
+
+    if keras_format == "saved_model_v2":
+        try:
+            graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
+                temp_dir, input_tensor_names, output_tensor_names
+            )
+            if "_FusedBatchNormEx" in [node.op for node in graph_def.node]:
+                keras_format = "trackable_object"
+        except:
+            keras_format = "trackable_object"
+
+    if keras_format == "trackable_object":
+        try:
+            graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model)
+        except:
+            keras_format = "saved_model_v1"
+
+    if keras_format == "saved_model_v1":  # pragma: no cover
+        try:
+            tf.keras.backend.set_learning_phase(0)
+            graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
+        except:
+            raise ValueError("Not supported keras model type...")
+
+    shutil.rmtree(temp_dir, True)
+    return graph_def, input_names, output_names
+
+
 def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
     """Build session with keras model.
@@ -434,49 +548,19 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs):
     Returns:
         sess (tf.compat.v1.Session): tf.compat.v1.Session object.
         input_tensor_names (list of string): validated input_tensor_names.
         output_tensor_names (list of string): validated output_tensor_names.
     """
-    temp_dir = tempfile.mkdtemp()
     if tf.version.VERSION > "2.1.0":
-        if not isinstance(model, tf.keras.Model):
-            model = tf.keras.models.load_model(model)
-        keras_format = _check_keras_format(model, temp_dir)
-        if keras_format == "saved_model_v2":
-            try:
-                graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
-                    temp_dir, input_tensor_names, output_tensor_names
-                )
-                if "_FusedBatchNormEx" in [node.op for node in graph_def.node]:
-                    keras_format = "trackable_object"
-            except:
-                keras_format = "trackable_object"
-        if keras_format == "trackable_object":
-            try:
-                graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model, temp_dir)
-            except:
-                keras_format = "saved_model_v1"
-        if keras_format == "saved_model_v1":  # pragma: no cover
-            try:
-                tf.keras.backend.set_learning_phase(0)
-                graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
-            except:
-                keras_format = "saved_model_general"
-        if keras_format == "saved_model_general":  # pragma: no cover
-            try:
-                from neural_compressor.adaptor.tf_utils.util import parse_saved_model
-
-                graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model(
-                    temp_dir, True, input_tensor_names, output_tensor_names
-                )
-            except:
-                raise ValueError("Not supported keras model type...")
-
+        try:
+            graph_def, input_names, output_names = _get_graph_from_saved_model_v3(
+                model, input_tensor_names, output_tensor_names
+            )
+        except:
+            graph_def, input_names, output_names = try_loading_keras(model, input_tensor_names, output_tensor_names)
     # tensorflow 1.x use v1 convert method
     else:
         tf.keras.backend.set_learning_phase(0)
         graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
-    shutil.rmtree(temp_dir, True)

     return graph_def_session(graph_def, input_names, output_names, **kwargs)
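
With this change, keras_session first attempts the unified v3 parser and only then falls back to try_loading_keras, which itself walks the saved_model_v2 / trackable_object / saved_model_v1 paths shown above. A runnable toy sketch of that fallback-chain shape, with stand-in loaders rather than the real ones:

    # Stand-ins for _get_graph_from_saved_model_v3 and try_loading_keras.
    def v3_parser(path):
        raise ValueError("v3 parser could not handle this model")

    def keras_fallback(path):
        return f"graph_def for {path}"

    def load_graph(path):
        try:
            return v3_parser(path)
        except Exception:
            return keras_fallback(path)

    print(load_graph("my_model"))  # -> graph_def for my_model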


@@ -645,12 +729,19 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs):
         output_tensor_names (list of string): validated output_tensor_names.
     """
     try:
-        graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
+        graph_def, input_names, output_names = _get_graph_from_saved_model_v3(
             model, input_tensor_names, output_tensor_names
         )
     except:
-        graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
+        try:
+            graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
+                model, input_tensor_names, output_tensor_names
+            )
+        except:
+            graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
+
+    assert graph_def is not None, "Can not parse the saved model..."

     return graph_def_session(graph_def, input_names, output_names, **kwargs)


@@ -179,7 +179,8 @@ def qdq_quantize(
     if ipex_ver.release > Version("2.1.0").release:
         update_sq_scale(ipex_config_path, smoothquant_scale_info)
         model.load_qconf_summary(qconf_summary=ipex_config_path)
-    _ipex_post_quant_process(model, example_inputs, inplace=inplace)
+    model.save_qconf_summary(qconf_summary=ipex_config_path)
+    model = _ipex_post_quant_process(model, example_inputs, inplace=inplace)

     with open(ipex_config_path, "r") as f:
         model.tune_cfg = json.load(f)
36 changes: 36 additions & 0 deletions test/3x/torch/quantization/test_smooth_quant.py
@@ -82,3 +82,39 @@ def run_fn(model):
         output1 = fp32_model(example_inputs)
         output2 = q_model(example_inputs)
         assert torch.allclose(output1, output2, atol=2e-2), "Accuracy gap atol > 0.02 is unexpected. Please check."
+
+    @pytest.mark.skipif(not is_ipex_available(), reason="Requires IPEX")
+    def test_sq_ipex_save_load(self):
+        from intel_extension_for_pytorch.quantization import convert, prepare
+
+        example_inputs = torch.zeros([1, 3])
+        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
+        user_model = copy.deepcopy(model)
+        user_model = prepare(user_model.eval(), qconfig, example_inputs=example_inputs, inplace=True)
+
+        def run_fn(model):
+            model(example_inputs)
+
+        run_fn(user_model)
+        with torch.no_grad():
+            user_model = convert(user_model.eval(), inplace=True).eval()
+            user_model(example_inputs)
+            user_model = torch.jit.trace(user_model.eval(), example_inputs, strict=False)
+            user_model = torch.jit.freeze(user_model.eval())
+            user_model(example_inputs)
+            user_model(example_inputs)
+        ipex_out = user_model(example_inputs)
+
+        fp32_model = copy.deepcopy(model)
+        quant_config = get_default_sq_config()
+        q_model = quantize(fp32_model, quant_config=quant_config, run_fn=run_fn, example_inputs=example_inputs)
+        assert q_model is not None, "Quantization failed!"
+        inc_out = q_model(example_inputs)
+        q_model.save("saved")
+
+        # load
+        loaded_model = torch.jit.load("saved")
+        loaded_out = loaded_model(example_inputs)
+        assert torch.allclose(inc_out, ipex_out, atol=1e-05), "Unexpected result. Please double check."
+
+        assert torch.allclose(inc_out, loaded_out, atol=1e-05), "Unexpected result. Please double check."
