[Feature Request] Shape inference for GroupQueryAttention Op #23189

Open
peishenyan opened this issue Dec 24, 2024 · 0 comments
Labels
ep:WebNN WebNN execution provider feature request request for unsupported feature or enhancement

Comments

@peishenyan
Contributor

Describe the feature request

For the WebNN EP, the graph builder does not accept inputs or outputs with dynamic shapes, so after FreeDimensionOverride all shapes / dims are expected to be static.
A shape inference function for the GroupQueryAttention op already exists: BaseGroupQueryAttentionTypeAndShapeInference() in onnxruntime/core/graph/contrib_ops/bert_defs.cc. However, its use_max_past_present_buffer parameter is always set to -1, as in the following code:

void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
  // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
  constexpr int use_max_past_present_buffer = -1;
  BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer);
}
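
To make the effect of that constant concrete, here is a minimal, self-contained sketch of how the three values could drive the present_key / present_value shape. It is a simplified model, not the code in bert_defs.cc: `Dim`, `Shape`, `PresentBufferMode`, `InferPresentShape` and `IsFullyStatic` are hypothetical names, the (batch_size, kv_num_heads, sequence_length, head_size) layout is an assumption, and the meaning given to 0 (concatenation) and 1 (shared max-length buffer) is my reading of the parameter and should be checked against the actual function.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// A dimension is either a known static value or unknown (dynamic).
using Dim = std::optional<std::int64_t>;
using Shape = std::vector<Dim>;

// Hypothetical stand-in for the values of use_max_past_present_buffer.
enum class PresentBufferMode : int {
  kUnknown = -1,      // do not propagate any present shape (today's behavior)
  kConcatenate = 0,   // assumed: present length = past length + new kv length
  kSharedBuffer = 1,  // assumed: past and present share one max-length buffer
};

// Returns the inferred present shape, or std::nullopt when nothing can be inferred.
std::optional<Shape> InferPresentShape(const Shape& past, Dim kv_sequence_length,
                                       PresentBufferMode mode) {
  // Assumed layout: (batch_size, kv_num_heads, sequence_length, head_size).
  assert(past.size() == 4);
  switch (mode) {
    case PresentBufferMode::kSharedBuffer:
      // Shared buffer: present has exactly the same shape as past.
      return past;
    case PresentBufferMode::kConcatenate: {
      if (!past[2].has_value() || !kv_sequence_length.has_value()) {
        return std::nullopt;  // sequence length is not statically known
      }
      Shape present = past;
      present[2] = *past[2] + *kv_sequence_length;
      return present;
    }
    case PresentBufferMode::kUnknown:
      return std::nullopt;
  }
  return std::nullopt;
}

// True when every dimension has a static value, which is what the WebNN EP needs.
bool IsFullyStatic(const Shape& shape) {
  for (const Dim& d : shape) {
    if (!d.has_value()) return false;
  }
  return true;
}
```

In this model, a static past shape combined with the shared-buffer mode yields a fully static present shape, while -1 leaves the outputs without any inferred shape, which is the situation described in this issue.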

Would it be possible to pass an argument/flag so that shape inference can actually run, at least for EPs that use the shared buffer?

Describe scenario use case

When an EP uses a shared buffer for the key / value cache, it would pass the flag/argument to set use_max_past_present_buffer to 1, which would enable shape inference for GroupQueryAttention ops.
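
As a rough sketch of that scenario, the requested flag could be mapped to the existing parameter as shown here. `ResolveUseMaxPastPresentBuffer` and `ep_uses_shared_kv_buffer` are hypothetical names used only for illustration; how an EP would actually surface such a flag to the schema's inference function is exactly the open question of this request.

```cpp
// Hypothetical helper: chooses the value to pass as use_max_past_present_buffer.
// Only the two values mentioned in this issue are used:
//    1 -> the EP shares one buffer for past and present key / value
//   -1 -> keep the current behavior (no shape propagation)
constexpr int ResolveUseMaxPastPresentBuffer(bool ep_uses_shared_kv_buffer) {
  return ep_uses_shared_kv_buffer ? 1 : -1;
}

// Compile-time check of the mapping described above.
static_assert(ResolveUseMaxPastPresentBuffer(true) == 1, "shared buffer enables inference");
static_assert(ResolveUseMaxPastPresentBuffer(false) == -1, "default keeps today's behavior");
```

Defaulting to -1 keeps the behavior unchanged for every EP that does not opt in.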

@peishenyan peishenyan added the feature request request for unsupported feature or enhancement label Dec 24, 2024
@github-actions github-actions bot added the ep:WebNN WebNN execution provider label Dec 24, 2024