Skip to content

Commit 75b3d42

Browse files
authored
Fix GQA fusion to produce present key/value (#2634)
Output present key value from the Attention op because past key value is provided. Previously the Attention op created would consume past key/value but not produce present key/value, which is not correct for ORT. <img width="1377" height="1225" alt="image" src="https://github.com/user-attachments/assets/118958b4-bc27-4912-b70b-000549887c0f" /> Replaces #2632 Signed-off-by: Justin Chu <[email protected]>
1 parent 811937c commit 75b3d42

File tree

1 file changed

+2
-1
lines changed
  • onnxscript/rewriter/rules/fusion

1 file changed

+2
-1
lines changed

onnxscript/rewriter/rules/fusion/_gqa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def pattern(
5252
_outputs=["attention_BHSDh"],
5353
)
5454

55-
return attention_BHSDh
55+
return attention_BHSDh, present_key_BHkvStD, present_value_BHkvStD
5656

5757
def check(
5858
self,
@@ -103,6 +103,7 @@ def rewrite(
103103
past_key_BHkvSpD,
104104
past_value_BHkvSpD,
105105
**original_attrs,
106+
_outputs=3,
106107
)
107108

108109

0 commit comments

Comments
 (0)