
Conversation

@LuFinch (Contributor) commented Nov 12, 2025

This PR moves the sycltla kernels from pytorch/pytorch#167056 into torch-xpu-ops.

This PR is based on #2030. Once the build PR merges, I will rebase this PR.

@EikanWang (Contributor) left a comment


TBH, I cannot quite understand the detailed implementation. I need to take more time to understand the logic.


file(GLOB xpu_cpp "xpu/*.cpp")
-file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
+file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/nested/*.cpp" "native/nested/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp" "native/transformers/xpu/flash_attn/*.cpp")
Contributor


Nit: I think we should also install the header files under flash_attn into PyTorch, as is done at line 42.

Contributor Author


May I know what the purpose of installing the header files is?

Contributor


It gives users a chance to use them in a cpp extension.
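
For illustration, a minimal sketch of what such a cpp extension might look like once the headers are installed; the include path, namespace, and function name below are placeholders for the example, not the actual sycltla API:

#include <ATen/ATen.h>
// Hypothetical include path; the real location depends on how the headers
// under native/transformers/xpu/flash_attn are installed into PyTorch.
#include <ATen/native/transformers/xpu/flash_attn/flash_api.h>

// A cpp-extension style wrapper that forwards to the flash attention kernel.
// sycltla::flash_attention_forward is a placeholder name for illustration.
at::Tensor my_flash_attention(const at::Tensor& q,
                              const at::Tensor& k,
                              const at::Tensor& v) {
  return sycltla::flash_attention_forward(q, k, v);
}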

Contributor Author


I see.

Contributor Author


Done.

Contributor


@guangyey, I think PyTorch does not expose flash_attn because it is the underlying logic of sdpa, which is exposed as a backend. Meanwhile, I don't believe users invoke PyTorch's flash_attn, because dao/flash_attn is a better choice.

Contributor


Meanwhile, the namespace of these functions is sycltla. It is weird to let users invoke sycl-tla-specific functions.



@LuFinch force-pushed the lfq/flash_attention branch from 770035a to 442c445 on November 13, 2025 05:55
out = at::empty({batch_size, numhead_qo, seqlen_qo, headsize_vo}, opts);
} else if (layout == ATTN_TENSOR_LAYOUT::BSHD) {
out = at::empty({batch_size, seqlen_qo, numhead_qo, headsize_vo}, opts)
.permute({0, 2, 1, 3});


Why do we need to permute here?

Contributor Author


The output is initialized as BSHD-contiguous, but its shape should be BHSD for SDPA. Hence the seqlen and numhead dimensions need to be permuted.
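
For illustration, a minimal standalone sketch of that trick (the sizes, dtype, and wrapper function are made up for the example): the buffer is allocated BSHD-contiguous, and the permuted view exposes the BHSD shape SDPA expects without copying any data.

#include <ATen/ATen.h>

// Allocate the output contiguously in BSHD order, then return a permuted
// view so callers see the BHSD shape. permute() does not copy data.
at::Tensor make_bhsd_view_over_bshd_storage() {
  const int64_t batch = 2, seqlen = 128, numhead = 8, headsize = 64;  // arbitrary sizes
  auto opts = at::TensorOptions().dtype(at::kFloat);                  // device omitted for brevity
  at::Tensor out = at::empty({batch, seqlen, numhead, headsize}, opts)  // BSHD storage
                       .permute({0, 2, 1, 3});                          // BHSD view
  // out.sizes() == {batch, numhead, seqlen, headsize}
  // out.strides() still reflect the BSHD memory layout, so out.is_contiguous() is false.
  return out;
}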
