From 6c4dc3ddd916673bf49d0d8f9484768537461011 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 2 Oct 2023 18:34:37 +0200 Subject: [PATCH] add test --- optimum/onnxruntime/base.py | 17 ++++++++++++----- tests/exporters/exporters_utils.py | 1 + tests/onnxruntime/test_modeling.py | 15 ++++++++------- tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index 75cd968e4fa..419d3417c30 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -216,8 +216,12 @@ def prepare_inputs_for_merged( # Generate dummy past for the first forward if uses a merged decoder if self.parent_model.use_merged and past_key_values is None: batch_size = input_ids.shape[0] - num_attention_heads = self.normalized_config.num_attention_heads - embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads + + if self.normalized_config.config.model_type != "mistral": + num_attention_heads = self.normalized_config.num_attention_heads + else: + num_attention_heads = self.normalized_config.num_key_value_heads + embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads dtype = constructor.float16 if self.use_fp16 else constructor.float32 # TODO: find a way to better handle this controlflow, this is EXTREMELY ugly @@ -277,8 +281,11 @@ def compute_past_key_values_output_shapes( `Dict[str, List[int]]`: The dictionary mapping each past key value output name to its corresponding shape. """ batch_size = input_ids.size(0) - num_attention_heads = self.normalized_config.num_attention_heads - embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads + if self.normalized_config.config.model_type != "mistral": + num_attention_heads = self.normalized_config.num_attention_heads + else: + num_attention_heads = self.normalized_config.num_key_value_heads + embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads sequence_length = input_ids.size(1) if past_key_values is not None and use_cache_branch is not False: @@ -527,7 +534,7 @@ def compute_past_key_values_output_shapes( ) -> Dict[str, int]: batch_size = input_ids.size(0) num_attention_heads = self.normalized_config.num_attention_heads - embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads + embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads sequence_length = input_ids.size(1) encoder_sequence_length = encoder_hidden_states.size(1) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 34ecf444212..d409cf4b8ac 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -91,6 +91,7 @@ "m2m-100": "hf-internal-testing/tiny-random-m2m_100", "marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken "mbart": "hf-internal-testing/tiny-random-mbart", + "mistram": "hf-internal-testing/tiny-random-MistralModel", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model", "mobilenet-v1": "google/mobilenet_v1_0.75_192", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index ed5b1b3158c..971cc9faccc 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1955,6 +1955,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): "gpt_neox", "gptj", "llama", + "mistral", "mpt", ] @@ -2083,7 +2084,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach set_seed(SEED) transformers_model = AutoModelForCausalLM.from_pretrained(model_id) transformers_model = transformers_model.eval() - tokenizer = get_preprocessor(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -2147,7 +2148,7 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo use_io_binding=use_io_binding, ) - tokenizer = get_preprocessor(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True) pipe = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer) text = "My Name is Philipp and i live" outputs = pipe(text) @@ -2178,7 +2179,7 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool) model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name]) - tokenizer = get_preprocessor(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True) pipe = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer, device=0) text = "My Name is Philipp and i live" outputs = pipe(text) @@ -2209,7 +2210,7 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st use_cache=use_cache, ) - tokenizer = get_preprocessor(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True) # build engine for a short sequence text = ["short"] @@ -2253,7 +2254,7 @@ def test_compare_with_and_without_past_key_values(self, model_arch): self._setup(model_args) model_id = MODEL_NAMES[model_arch] - tokenizer = get_preprocessor(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True) text = "My Name is Philipp and i live" tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None) @@ -2304,7 +2305,7 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode self._setup(model_args) model_id = MODEL_NAMES[model_arch] - tokenizer = get_preprocessor(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True) text = "My Name is Philipp and i live" tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None) @@ -2349,7 +2350,7 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: self.onnx_model_dirs[test_name], use_cache=use_cache, use_io_binding=True ).to("cuda") - tokenizer = get_preprocessor(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True) tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt").to("cuda") position_ids = None diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 985d0340350..b529b2c9715 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -74,6 +74,7 @@ "m2m_100": "hf-internal-testing/tiny-random-m2m_100", "marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken "mbart": "hf-internal-testing/tiny-random-mbart", + "mistral": "hf-internal-testing/tiny-random-MistralForCausalLM", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mobilenet_v1": "google/mobilenet_v1_0.75_192", "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",