diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index 31140c5b747..12313008b2c 100644 --- a/optimum/onnxruntime/quantization.py +++ b/optimum/onnxruntime/quantization.py @@ -279,6 +279,10 @@ def compute_ranges(self) -> Dict[str, Tuple[float, float]]: ) LOGGER.info("Computing calibration ranges") + + if parse(ort_version) >= Version("1.16.0"): + return self._calibrator.compute_data() + return self._calibrator.compute_range() def quantize( @@ -351,8 +355,13 @@ def quantize( has_subgraphs = True break - if quantization_config.is_static and has_subgraphs: - raise NotImplementedError("Static quantization is currently not supported for models with" " subgraphs.") + if has_subgraphs: + if quantization_config.is_static: + raise NotImplementedError("Static quantization is currently not supported for models with subgraphs.") + if parse(ort_version) == Version("1.16.0"): + raise ValueError( + "ONNX Runtime version v1.16.0 is not compatible with quantization for models with subgraphs, please downgrade to 1.15.1 or upgrade to a higher version. Reference: https://github.com/microsoft/onnxruntime/pull/17651" + ) quantizer_factory = QDQQuantizer if use_qdq else ONNXQuantizer diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 7e94d58db7b..147c0bd258f 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -21,6 +21,9 @@ import unittest from pathlib import Path +from onnxruntime import __version__ as ort_version +from packaging.version import Version, parse + import optimum.commands @@ -84,14 +87,22 @@ def test_quantize_commands(self): export_commands = [ f"optimum-cli export onnx --model hf-internal-testing/tiny-random-BertModel {tempdir}/encoder", f"optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 {tempdir}/decoder", - f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder", + # f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder", ] quantize_commands = [ f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder --avx2 -o {tempdir}/quantized_encoder", f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/decoder --avx2 -o {tempdir}/quantized_decoder", - f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder", + # f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder", ] + if parse(ort_version) != Version("1.16.0"): + export_commands.append( + f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder" + ) + quantize_commands.append( + f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder" + ) + for export, quantize in zip(export_commands, quantize_commands): subprocess.run(export, shell=True, check=True) subprocess.run(quantize, shell=True, check=True) diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index 111c7338808..aff1b51b534 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -19,7 +19,9 @@ from pathlib import Path from onnx import load as onnx_load +from onnxruntime import __version__ as ort_version from onnxruntime.quantization import QuantFormat, QuantizationMode, QuantType +from packaging.version import Version, parse from parameterized import parameterized from transformers import AutoTokenizer @@ -112,9 +114,9 @@ def test_dynamic_quantization(self, model_cls, model_name, expected_quantized_ma self.assertEqual(expected_quantized_matmuls, num_quantized_matmul) gc.collect() + @unittest.skipIf(parse(ort_version) == Version("1.16.0"), "not supported with this onnxruntime version") def test_dynamic_quantization_subgraphs(self): qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=True) - # with tempfile.TemporaryDirectory() as tmp_dir: tmp_dir = tempfile.mkdtemp() output_dir = Path(tmp_dir) model = ORTModelForCausalLM.from_pretrained(