
Conversation

@Rachmanino Rachmanino (Collaborator) commented Nov 11, 2025

Summary by CodeRabbit

  • Bug Fixes

    • Improved out-of-bounds handling in multiple flash-attention examples: accumulator initialization now masks invalid sequence positions to avoid spurious contributions.
  • Chores

    • Updated example CLI defaults (including context size and causal flag) for clarity.
    • Updated test wiring to reference the revised example variants and added clarifying comments about non-causal out-of-bounds behavior.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Nov 11, 2025

Walkthrough

The PR replaces blanket clears of accumulator buffers with conditional per-element initialization across multiple flash-attention examples: entries whose global key index (k * block_N + j) falls at or beyond the sequence length are initialized to -inf, all others to 0. It also adds explanatory comments in some backward kernels, tweaks a CLI default, and updates test imports to the *_bshd variants.
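
A minimal TileLang-style sketch of the pattern (names such as acc_s, block_M, block_N, seq_kv, and accum_dtype follow the examples' conventions; the exact code varies per file):

    # Before: blanket clear. Padded columns start at 0, just like real ones,
    # and can pick up spurious weight in the softmax.
    # T.clear(acc_s)

    # After: conditional per-element init. Out-of-bounds columns start at -inf,
    # so the softmax assigns them exactly zero attention weight.
    for i, j in T.Parallel(block_M, block_N):
        acc_s[i, j] = T.if_then_else(
            k * block_N + j >= seq_kv,  # column lies beyond the real sequence
            -T.infinity(accum_dtype),   # out-of-bounds: mask
            0)                          # in-bounds: normal init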

Changes

Cohort / File(s) / Summary
MHA forward examples
examples/flash_attention/example_mha_fwd_bhsd.py, examples/flash_attention/example_mha_fwd_bshd.py, examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py, examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
Replaced T.clear(acc_s) with an explicit per-element parallel init over (block_M, block_N): set acc_s[i,j] = -inf when k * block_N + j >= seq_kv (or seq_len), otherwise 0. Added explicit --is_causal default in one file.
GQA forward examples
examples/flash_attention/example_gqa_fwd_bshd.py, examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
Replaced blanket clear with conditional per-element initialization guarding out-of-bounds indices with -inf and zeros otherwise.
MHA backward examples
examples/flash_attention/example_mha_bwd_bhsd.py, examples/flash_attention/example_mha_bwd_bshd.py, examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py
In forward kernels: replaced T.clear(acc_s) with conditional per-element initialization. In backward kernels: added explanatory comments about OOB handling for non-causal cases. One script changes --n_ctx default from 1024 to 1048.
GQA backward examples
examples/flash_attention/example_gqa_bwd.py, examples/flash_attention/example_gqa_bwd_tma_reduce.py, examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
Replaced T.clear(acc_s) with per-element parallel initialization that masks out-of-bounds k indices with -inf, else 0.
Tests
examples/flash_attention/test_example_flash_attention.py
Updated imports and test calls to use the *_bshd variants (example_mha_bwd_bshd, example_mha_bwd_bshd_wgmma_pipelined) and adjusted main() references accordingly; a sketch of the rewiring follows this table.
Manifests / metadata
requirements.txt, pyproject.toml
Present in diff manifest; no functional changes reported.
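
A hypothetical sketch of the test rewiring described in the Tests row (only the module names come from this PR; the test function names and bodies are assumptions):

    import example_mha_bwd_bshd
    import example_mha_bwd_bshd_wgmma_pipelined

    def test_example_mha_bwd_bshd():
        example_mha_bwd_bshd.main()

    def test_example_mha_bwd_bshd_wgmma_pipelined():
        example_mha_bwd_bshd_wgmma_pipelined.main()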

Sequence Diagram(s)

sequenceDiagram
    autonumber
    actor Host
    participant Kernel
    Note over Host,Kernel: Forward kernel (non-causal) — per-k block initialization
    Host->>Kernel: launch forward kernel (loop over k)
    loop for each k
        Kernel->>Kernel: parallel init acc_s[i,j]
        alt (k*block_N + j) >= seq_len (OOB)
            Kernel-->>Kernel: acc_s[i,j] = -inf
        else (in-bounds)
            Kernel-->>Kernel: acc_s[i,j] = 0
        end
        Kernel->>Kernel: GEMM accumulate using acc_s
    end
    Kernel-->>Host: return results
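
Why -inf is the right fill value: after the softmax step, exp(-inf) is exactly 0, so masked columns contribute nothing to the weighted sum. A small numpy check of one attention row (illustrative, not from the PR):

    import numpy as np

    seq_kv, block_N, k = 10, 4, 2        # this k-block covers columns 8..11
    j = np.arange(block_N)
    qk = np.random.randn(block_N)        # stand-in for one row of Q @ K^T
    scores = np.where(k * block_N + j >= seq_kv, -np.inf, qk)  # mask cols 10, 11
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    print(weights)                       # last two entries are exactly 0.0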

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20–30 minutes

  • Check consistent use of seq_len vs seq_kv across files.
  • Verify the off-by-one semantics of the boundary expression k * block_N + j >= seq_len (see the boundary check after this list).
  • Review the CLI default (--n_ctx) change and updated test imports for correctness.
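
A quick check of the >= semantics in that expression, using this PR's --n_ctx default of 1048 and a block_N of 128 (the block size is an assumption, not stated in the summary). Column indices are 0-based, so the first masked position should be exactly seq_len:

    seq_len, block_N = 1048, 128                      # 1048 = 8 * 128 + 24: last block is partial
    last_k = (seq_len + block_N - 1) // block_N - 1   # index of the final k-block: 8
    for j in (23, 24):                                # straddle the boundary inside block 8
        col = last_k * block_N + j                    # global column 1047, then 1048
        print(col, "masked" if col >= seq_len else "kept")
    # 1047 kept, 1048 masked: the predicate excludes exactly the padded tail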

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐇
I hopped through code at break of day,
Switched blanket wipes for careful array,
Out-of-bounds wear twilight's shawl of -inf,
In-bounds keep zero, tidy and swift,
A tiny hop, a safer drift ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title accurately describes the main change: replacing blanket accumulator clearing with conditional initialization using negative infinity for out-of-bounds positions across all attention kernel examples.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 063af94 and 70deb79.

📒 Files selected for processing (13)
  • examples/flash_attention/example_gqa_bwd.py (1 hunks)
  • examples/flash_attention/example_gqa_bwd_tma_reduce.py (1 hunks)
  • examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/example_gqa_fwd_bshd.py (1 hunks)
  • examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/example_mha_bwd_bhsd.py (2 hunks)
  • examples/flash_attention/example_mha_bwd_bshd.py (3 hunks)
  • examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py (2 hunks)
  • examples/flash_attention/example_mha_fwd_bhsd.py (2 hunks)
  • examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/example_mha_fwd_bshd.py (1 hunks)
  • examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1 hunks)
  • examples/flash_attention/test_example_flash_attention.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (13)
  • examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
  • examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
  • examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
  • examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
  • examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py
  • examples/flash_attention/example_mha_bwd_bhsd.py
  • examples/flash_attention/test_example_flash_attention.py
  • examples/flash_attention/example_mha_fwd_bhsd.py
  • examples/flash_attention/example_mha_fwd_bshd.py
  • examples/flash_attention/example_gqa_bwd.py
  • examples/flash_attention/example_mha_bwd_bshd.py
  • examples/flash_attention/example_gqa_bwd_tma_reduce.py
  • examples/flash_attention/example_gqa_fwd_bshd.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint


@Rachmanino Rachmanino reopened this Nov 11, 2025