
Conversation

@keshavvinayak01 (Contributor) commented Nov 4, 2025

Description

  • Added support for PyTorch's flex_attention higher-order operator (HOP) in torch-mlir.
  • Implemented Torch_AtenFlexAttentionOp with six operands (query, key, value, scale, enable_gqa, return_lse) and two optional attributes (score_mod_fn, mask_mod_fn) that hold function references.
  • The FX importer (_import_hop_flex_attention) extracts the score/mask modification functions from get_attr nodes using module IDs, following the while_loop HOP pattern.
  • Includes TODO markers for the kernel_options performance-tuning parameters.
  • Imports flex_attention from PyTorch FX graphs into valid MLIR (see the sketch below).
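
For context, a minimal sketch of the kind of program this import path targets. The `rel_bias` score_mod is hypothetical and `torch_mlir.fx.export_and_import` is assumed as the entry point; none of this code is from the PR itself:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical score_mod: add a relative-position bias to every attention score.
def rel_bias(score, batch, head, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

class FlexModel(torch.nn.Module):
    def forward(self, q, k, v):
        # In the exported FX graph this call becomes the flex_attention HOP;
        # score_mod is captured as a get_attr submodule that the importer
        # resolves and records on the op via the score_mod_fn attribute.
        return flex_attention(q, k, v, score_mod=rel_bias)

q = k = v = torch.randn(2, 4, 128, 64)

# Assumed torch-mlir entry point; it runs torch.export under the hood and then
# the FX importer described above.
from torch_mlir import fx
module = fx.export_and_import(FlexModel(), q, k, v)
print(module)
```

With this PR, the printed module should contain an aten.flex_attention op whose score_mod_fn attribute references the imported callback.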

keshavvinayak01 and others added 17 commits October 22, 2025 09:41
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Change 1: Converts builtin tensors → Torch tensors when entering the loop body.
Change 2: Converts Torch tensors → builtin tensors when yielding back to the loop condition.
Without these fixes, the conversion would fail whenever a while loop carries tensor values (see the sketch below).

Also updated the FileCheck statements in basic_test.py.
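
Illustrative sketch of a while loop that carries a tensor value — the case these conversion fixes address. The `while_loop` HOP is taken from PyTorch's experimental `torch._higher_order_ops` API and is not code from this PR:

```python
import torch
from torch._higher_order_ops.while_loop import while_loop

class Counter(torch.nn.Module):
    def forward(self, x):
        def cond_fn(i, acc):
            # Loop while the counter is below 5 (returns a 0-dim bool tensor).
            return i < 5

        def body_fn(i, acc):
            # `acc` is a tensor carried across iterations; at the dialect
            # boundary it must be converted builtin tensor -> Torch tensor on
            # entry to the body and back again when yielded to the condition.
            return i + 1, acc * 2.0

        i0 = torch.zeros((), dtype=torch.int64)
        return while_loop(cond_fn, body_fn, (i0, x))

ep = torch.export.export(Counter(), (torch.randn(3, 4),))
print(ep.graph_module.graph)  # the graph contains the while_loop HOP with tensor carries
```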

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. Better documentation for AtenFlexAttentionOp.
2. Function references added as attributes to aten.flex_attention.
3. Updated _import_hop_flex_attention to reflect the latest changes to module import.
4. Removed discardable attributes; score_mod_fn and mask_mod_fn added as OptionalAttr.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Removed the note about method usage for HOPs.
@keshavvinayak01 changed the title from "Keshavvinayak01/torch aten flex attention" to "[TORCH] Added flex_attention hop function" on Nov 4, 2025
Removed TODO note for grouped query attention support in the docstring and comments.
@keshavvinayak01 force-pushed the keshavvinayak01/torch-aten-flex_attention branch from 095cb61 to 5e024f6 on November 6, 2025 09:36
@keshavvinayak01 marked this pull request as ready for review on November 6, 2025 09:37