From 7e8d857d1ed6be32046324bf8f424690f116b4e9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 31 Oct 2024 14:54:05 -0600 Subject: [PATCH] Add ONNX export support for granite models (#2043) * feat(exporters/onnx): Add GraniteOnnxConfig and task support list Branch: OnnxGranite Signed-off-by: Gabe Goodhart * feat: Add granite's normalized config for inference Branch: OnnxGranite Signed-off-by: Gabe Goodhart * feat(onnx opt): Add onnx optimization support for granite Branch: OnnxGranite Signed-off-by: Gabe Goodhart * fix(onnx/granite): Use LlamaOnnxConfig as the base for GraniteOnnxConfig Branch: OnnxGranite Signed-off-by: Gabe Goodhart * fix(onnxruntime): Add "granite" to list of model types with grouped attention Branch: OnnxGranite Signed-off-by: Gabe Goodhart * fix: Add granite to the list of models that require position_ids Branch: OnnxGranite Signed-off-by: Gabe Goodhart * fix(granite): Add MIN_TORCH_VERSION for recently fixed torch bug https://github.com/huggingface/optimum/pull/2043#issuecomment-2427975461 Branch: OnnxGranite Signed-off-by: Gabe Goodhart * test(granite): Add tiny random granite test for onnx exporter Branch: OnnxGranite Signed-off-by: Gabe Goodhart * tests(onnxruntime): Add granite to onnxruntime tests Branch: OnnxGranite Signed-off-by: Gabe Goodhart --------- Signed-off-by: Gabe Goodhart --- optimum/exporters/onnx/model_configs.py | 5 +++++ optimum/exporters/onnx/utils.py | 1 + optimum/exporters/tasks.py | 7 +++++++ optimum/onnxruntime/modeling_decoder.py | 2 +- optimum/onnxruntime/utils.py | 1 + optimum/utils/normalized_config.py | 1 + tests/exporters/exporters_utils.py | 1 + tests/onnxruntime/test_modeling.py | 1 + tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 9 files changed, 19 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 9e57128c272..cc752779d30 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -298,6 +298,11 @@ class GemmaOnnxConfig(LlamaOnnxConfig): pass +class GraniteOnnxConfig(LlamaOnnxConfig): + MIN_TRANSFORMERS_VERSION = version.parse("4.45.0") + MIN_TORCH_VERSION = version.parse("2.5.0") + + class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1. NORMALIZED_CONFIG_CLASS = NormalizedTextConfig diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 56249bbf5c3..19e24f88743 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -86,6 +86,7 @@ "phi", "phi3", "qwen2", + "granite", } diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index a489f34fb06..fdc8bfcb539 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -915,6 +915,13 @@ class TasksManager: "text-classification", onnx="LlamaOnnxConfig", ), + "granite": supported_tasks_mapping( + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + onnx="GraniteOnnxConfig", + ), "pegasus": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 984d7f22ebf..8f1d062221a 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -340,7 +340,7 @@ def prepare_past_key_values( if self.model_type == "gemma": num_attention_heads = self.normalized_config.num_key_value_heads embed_size_per_head = self.normalized_config.head_dim - elif self.model_type in {"mistral", "llama", "qwen2"}: + elif self.model_type in {"mistral", "llama", "qwen2", "granite"}: num_attention_heads = self.normalized_config.num_key_value_heads else: num_attention_heads = self.normalized_config.num_attention_heads diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 128e2406f11..9e92e0bd325 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -128,6 +128,7 @@ class ORTConfigManager: "gpt-neo": "gpt2", "gpt-neox": "gpt2", "gptj": "gpt2", + "granite": "gpt2", # longt5 with O4 results in segmentation fault "longt5": "bert", "llama": "gpt2", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 81207b76496..9ceed24c2dd 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -281,6 +281,7 @@ class NormalizedConfigManager: "xlm-roberta": NormalizedTextConfig, "yolos": NormalizedVisionConfig, "qwen2": NormalizedTextConfig, + "granite": NormalizedTextConfigWithGQA, } @classmethod diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index c8a33b0be35..ccccb5510bf 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -100,6 +100,7 @@ "gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", + "granite": "hf-internal-testing/tiny-random-GraniteForCausalLM", "groupvit": "hf-internal-testing/tiny-random-groupvit", "ibert": "hf-internal-testing/tiny-random-IBertModel", "imagegpt": "hf-internal-testing/tiny-random-ImageGPTModel", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 597eb581e2a..a335e014478 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2324,6 +2324,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): "gpt_neo", "gpt_neox", "gptj", + "granite", "llama", "mistral", "mpt", diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index e3d54237857..9f200e69b3d 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -104,6 +104,7 @@ "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM", + "granite": "hf-internal-testing/tiny-random-GraniteForCausalLM", "groupvit": "hf-internal-testing/tiny-random-groupvit", "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-IBertModel",