Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix forward pass for merged model #1462

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def forward(
loss = None
if self.use_cache:
if past_key_values is not None:
input_ids = input_ids[:, -1:]
input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kunal-vaishnavi Thank you for the PR. Isn't past_key_values expected to be None in the first forward pass? And prepared later at use_cache_branch, past_key_values, known_output_shapes = self.prepare_past_key_values(input_ids, past_key_values, use_torch)?

cc @echarlaix

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or did you mean decoder_model_merged.onnx?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kunal-vaishnavi Thank you for the PR. Isn't past_key_values expected to be None in the first forward pass? And prepared later at use_cache_branch, past_key_values, known_output_shapes = self.prepare_past_key_values(input_ids, past_key_values, use_torch)?

cc @echarlaix

Yes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure when this could happen, could you share an example so that we can reproduce @kunal-vaishnavi ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is with decoder_model_merged.onnx. When running the forward pass, each past_key and past_value in past_key_values has shape (batch_size, num_heads, past_seq_len, head_size). For the prompt step, past_seq_len = 0 so each past_key and past_value input shape is (batch_size, num_heads, 0, head_size).

Since self.use_cache is true for ORTModelForCausalLM models and the merged model always has past_key_values, both if statements are true. Then the input_ids currently get modified for both prompt and token generation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kunal-vaishnavi I think it is not true. While decoder_model_merged.onnx, this is only true at the session.run call, not at the forward call, that is typically called from transformers GenerationMixin.generate. Do you have a different use case?

@echarlaix loading decoder_model_merged.onnx following #1257 prints

The ONNX file decoder_model_merged.onnx is not a regular name used in optimum.onnxruntime that are ['model.onnx', 'model_quantized.onnx', 'model_optimized.onnx', 'decoder_with_past_model.onnx', 'decoder_with_past_model_quantized.onnx', 'decoder_with_past_model_optimized.onnx'], the ORTModelForCausalLM might not behave as expected.

is this expected?

@kunal-vaishnavi Apart from backward compatibility reasons, do you have limiting factors on your side that prevent to simply use the new version of the exported model (with an empty KV cache at the first iteration)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the PR linked below, we want to benchmark exported models using Optimum. To have an equal comparison, we currently evaluate with model(...) and not model.generate(...). We pass both input_ids and past_key_values to the forward call in order for the merged ONNX model to receive both inputs. When calling the forward pass, the input_ids and past_key_values are passed with the right shapes. However, the input_ids shape is changed in Optimum before they are passed with the wrong shape to the ONNX model.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, makes sense indeed.

# Flatten the past_key_values (no need to flatten for models using multi-query attn)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
past_key_values = tuple(
Expand Down