Skip to content

Commit

Permalink
use input dims
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Sep 2, 2024
1 parent 5366877 commit d957547
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def prepare_past_key_values(
dtype = constructor.float16 if self.use_fp16 else constructor.float32

# TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
if self.model_type == "bloom" and not check_if_transformers_greater("4.44"):
if self.__class__.__name__ == "ORTBloomForCausalLM":
shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0)
key = constructor.zeros(shape_key, dtype=dtype)
Expand Down Expand Up @@ -533,9 +533,9 @@ def _from_pretrained(

# Since https://github.com/huggingface/optimum/pull/871/
# changed axis notation/naming during export, we need to update the dims
for dim in input_dims.keys():
if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length":
input_dims[dim][2] = "past_sequence_length"
for input_name in input_dims.keys():
if "past" in input_name and input_dims[input_name][2] == "past_sequence_length + sequence_length":
input_dims[input_name][2] = "past_sequence_length"
override_dims = True

if override_dims:
Expand All @@ -558,6 +558,12 @@ def _from_pretrained(
size_threshold=0,
)

# Since transformers 4.44, the bloom model has been updated to use the standard cache format
use_old_bloom_modeling = not check_if_transformers_greater("4.44")
for input_name in input_dims.keys():
if input_dims[input_name][0] == "batch_size x num_heads":
use_old_bloom_modeling = True

del onnx_model

model = ORTModel.load_model(
Expand All @@ -567,7 +573,7 @@ def _from_pretrained(
provider_options=provider_options,
)

if config.model_type == "bloom" and not check_if_transformers_greater("4.44"):
if config.model_type == "bloom" and use_old_bloom_modeling:
init_cls = ORTBloomForCausalLM
elif config.model_type == "falcon":
init_cls = ORTFalconForCausalLM
Expand Down

0 comments on commit d957547

Please sign in to comment.