
Commit

add test
echarlaix committed Oct 2, 2023
1 parent b5fbac1 commit 6c4dc3d
Showing 4 changed files with 22 additions and 12 deletions.
17 changes: 12 additions & 5 deletions optimum/onnxruntime/base.py
@@ -216,8 +216,12 @@ def prepare_inputs_for_merged(
         # Generate dummy past for the first forward if uses a merged decoder
         if self.parent_model.use_merged and past_key_values is None:
             batch_size = input_ids.shape[0]
-            num_attention_heads = self.normalized_config.num_attention_heads
-            embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+
+            if self.normalized_config.config.model_type != "mistral":
+                num_attention_heads = self.normalized_config.num_attention_heads
+            else:
+                num_attention_heads = self.normalized_config.num_key_value_heads
+            embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

             dtype = constructor.float16 if self.use_fp16 else constructor.float32
             # TODO: find a way to better handle this controlflow, this is EXTREMELY ugly
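Context for the change above: Mistral uses grouped-query attention, so its cached key/value tensors carry num_key_value_heads heads rather than num_attention_heads, while the per-head size is still hidden_size // num_attention_heads. A minimal standalone sketch of building such dummy past key values (hypothetical helper and plain config object, not optimum's normalized config; the dummy sequence length optimum actually uses is outside this hunk):

    import torch

    def build_dummy_past_key_values(config, batch_size, dtype=torch.float32):
        # Grouped-query attention (e.g. Mistral) keeps fewer key/value heads than query heads.
        if getattr(config, "model_type", None) == "mistral":
            num_kv_heads = config.num_key_value_heads
        else:
            num_kv_heads = config.num_attention_heads
        # The head size is still derived from the number of *query* heads.
        head_dim = config.hidden_size // config.num_attention_heads

        # One (key, value) pair per layer; sequence length 0 is only a placeholder here.
        shape = (batch_size, num_kv_heads, 0, head_dim)
        return tuple(
            (torch.zeros(shape, dtype=dtype), torch.zeros(shape, dtype=dtype))
            for _ in range(config.num_hidden_layers)
        )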
@@ -277,8 +281,11 @@ def compute_past_key_values_output_shapes(
             `Dict[str, List[int]]`: The dictionary mapping each past key value output name to its corresponding shape.
         """
         batch_size = input_ids.size(0)
-        num_attention_heads = self.normalized_config.num_attention_heads
-        embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+        if self.normalized_config.config.model_type != "mistral":
+            num_attention_heads = self.normalized_config.num_attention_heads
+        else:
+            num_attention_heads = self.normalized_config.num_key_value_heads
+        embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

         sequence_length = input_ids.size(1)
         if past_key_values is not None and use_cache_branch is not False:
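As a worked shape check (using Mistral-7B's published configuration for illustration, not the tiny test checkpoint): with hidden_size=4096, num_attention_heads=32, and num_key_value_heads=8, each past key/value tensor should come out as [batch_size, 8, sequence_length, 128] rather than the [batch_size, 32, sequence_length, 128] a non-GQA model of the same width would use:

    hidden_size = 4096            # Mistral-7B values, for illustration only
    num_attention_heads = 32
    num_key_value_heads = 8

    embed_size_per_head = hidden_size // num_attention_heads   # 128
    past_kv_shape = [2, num_key_value_heads, 10, embed_size_per_head]
    print(past_kv_shape)          # [2, 8, 10, 128] for batch_size=2, sequence_length=10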
@@ -527,7 +534,7 @@ def compute_past_key_values_output_shapes(
     ) -> Dict[str, int]:
         batch_size = input_ids.size(0)
         num_attention_heads = self.normalized_config.num_attention_heads
-        embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+        embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

         sequence_length = input_ids.size(1)
         encoder_sequence_length = encoder_hidden_states.size(1)
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
@@ -91,6 +91,7 @@
     "m2m-100": "hf-internal-testing/tiny-random-m2m_100",
     "marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
     "mbart": "hf-internal-testing/tiny-random-mbart",
+    "mistral": "hf-internal-testing/tiny-random-MistralModel",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
     "mobilenet-v1": "google/mobilenet_v1_0.75_192",
15 changes: 8 additions & 7 deletions tests/onnxruntime/test_modeling.py
@@ -1955,6 +1955,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         "gpt_neox",
         "gptj",
         "llama",
+        "mistral",
         "mpt",
     ]

@@ -2083,7 +2084,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
         set_seed(SEED)
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
         transformers_model = transformers_model.eval()
-        tokenizer = get_preprocessor(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True)
         tokens = tokenizer(
             "This is a sample output",
             return_tensors="pt",
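The test above checks output parity between the PyTorch model and its ONNX Runtime export; a standalone sketch of the same idea for the tiny Mistral checkpoint (assuming export=True is supported by ORTModelForCausalLM in the installed optimum version, and using the slow tokenizer as the test does for Mistral):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from optimum.onnxruntime import ORTModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"

    # Slow tokenizer, mirroring the use_fast=False branch used for Mistral in the tests.
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    tokens = tokenizer("This is a sample output", return_tensors="pt")

    pt_model = AutoModelForCausalLM.from_pretrained(model_id).eval()
    ort_model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

    with torch.no_grad():
        pt_logits = pt_model(**tokens).logits
    ort_logits = ort_model(**tokens).logits

    # Logits should agree within a small tolerance if the export is correct.
    print(torch.allclose(pt_logits, ort_logits, atol=1e-4))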
@@ -2147,7 +2148,7 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo
             use_io_binding=use_io_binding,
         )

-        tokenizer = get_preprocessor(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True)
         pipe = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer)
         text = "My Name is Philipp and i live"
         outputs = pipe(text)
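The exported model also drops into a transformers text-generation pipeline, which is what this test exercises; a minimal usage sketch (model id and generation length are illustrative):

    from transformers import AutoTokenizer, pipeline
    from optimum.onnxruntime import ORTModelForCausalLM

    model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    onnx_model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

    pipe = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer)
    print(pipe("My Name is Philipp and i live", max_new_tokens=10))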
@@ -2178,7 +2179,7 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool)
         model_id = MODEL_NAMES[model_arch]
         onnx_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name])

-        tokenizer = get_preprocessor(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True)
         pipe = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer, device=0)
         text = "My Name is Philipp and i live"
         outputs = pipe(text)
@@ -2209,7 +2210,7 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st
             use_cache=use_cache,
         )

-        tokenizer = get_preprocessor(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True)

         # build engine for a short sequence
         text = ["short"]
@@ -2253,7 +2254,7 @@ def test_compare_with_and_without_past_key_values(self, model_arch):
         self._setup(model_args)

         model_id = MODEL_NAMES[model_arch]
-        tokenizer = get_preprocessor(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True)
         text = "My Name is Philipp and i live"
         tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)

@@ -2304,7 +2305,7 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode
         self._setup(model_args)

         model_id = MODEL_NAMES[model_arch]
-        tokenizer = get_preprocessor(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True)
         text = "My Name is Philipp and i live"
         tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)

@@ -2349,7 +2350,7 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
             self.onnx_model_dirs[test_name], use_cache=use_cache, use_io_binding=True
         ).to("cuda")

-        tokenizer = get_preprocessor(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False if model_arch == "mistral" else True)
         tokens = tokenizer(["This is a sample output"] * 2, return_tensors="pt").to("cuda")

         position_ids = None
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -74,6 +74,7 @@
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
     "mbart": "hf-internal-testing/tiny-random-mbart",
+    "mistral": "hf-internal-testing/tiny-random-MistralForCausalLM",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet_v1": "google/mobilenet_v1_0.75_192",
     "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
