[MatmulLoopPipeline]: Prefetch 2D loads #4051

Merged
2 commits merged into main on Apr 30, 2025

Conversation


@etiotto etiotto commented Apr 29, 2025

Add a check to prefetch only 2D tensor loads. This avoids potential generation of invalid prefetch operations, which would cause assertions in subsequent passes or lead to incorrect code generation.

Signed-off-by: Tiotto, Ettore <[email protected]>
@etiotto etiotto self-assigned this Apr 29, 2025
@Copilot Copilot AI left a comment

Pull Request Overview

This PR adds a safeguard to ensure that only 2D tensor load operations are prefetched in the MatmulLoopPipeline, preventing generation of invalid prefetch operations that could lead to assertions or incorrect code generation.

  • Updated comments to clarify caching behavior.
  • Added a check to skip prefetching for non-2D tensor loads.


@alexbaden alexbaden left a comment

If we are changing this behavior, we should have a test - is this in response to a particular bug or just something we're worried could occur?

Signed-off-by: Tiotto, Ettore <[email protected]>
@etiotto etiotto requested a review from alexbaden April 29, 2025 21:21
@alexbaden alexbaden left a comment

Thanks! Looks good.

etiotto commented Apr 30, 2025

> If we are changing this behavior, we should have a test - is this in response to a particular bug or just something we're worried could occur?

I have added a new lit test. I discovered the problem while working on another PR (#3634), which was failing one of the CI tests.

@etiotto etiotto merged commit ba04707 into main Apr 30, 2025
9 checks passed
@etiotto etiotto deleted the etiotto.matmul_pipeline branch April 30, 2025 14:25
david-hls pushed a commit to david-hls/intel-xpu-backend-for-triton that referenced this pull request Jun 18, 2025
Successfully merging this pull request may close these issues.

[Performance] Enhance the software loop pipelining for tt.load with tensor of pointer
4 participants