-
Notifications
You must be signed in to change notification settings - Fork 314
[BugFix] Refactor attention kernel to handle OOB positions by filling with -inf instead of clearing accumulators.
#1222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughThe PR replaces blanket clears of accumulator buffers with conditional per-element initialization across multiple flash-attention examples, setting out-of-bounds entries to -inf when the k-index exceeds sequence length, otherwise 0; it also adds comments in some backward kernels, a small CLI default tweak, and updates test imports to *_bshd variants. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
actor Host
participant Kernel
Note over Host,Kernel: Forward kernel (non-causal) — per-k block initialization
Host->>Kernel: launch forward kernel (loop over k)
loop for each k
Kernel->>Kernel: parallel init acc_s[i,j]
alt (k*block_N + j) >= seq_len (OOB)
Kernel-->>Kernel: acc_s[i,j] = -inf
else (in-bounds)
Kernel-->>Kernel: acc_s[i,j] = 0
end
Kernel->>Kernel: GEMM accumulate using acc_s
end
Kernel-->>Host: return results
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20–30 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
📜 Recent review detailsConfiguration used: CodeRabbit UI Review profile: CHILL Plan: Pro 📒 Files selected for processing (13)
🚧 Files skipped from review as they are similar to previous changes (13)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
…nf` instead of clearing accumulators.
…orward examples for better clarity and consistency.
Summary by CodeRabbit
Bug Fixes
Chores