Conversation

@ClarkChin08
No description provided.

  class TiledCopyK_ = void, // Optional TiledCopy for loading K
- class TiledCopyV_ = void> // Optional TiledCopy for loading V
+ class TiledCopyV_ = void, // Optional TiledCopy for loading V
+ class SubgroupLayoutQK_ = void> // Optional SubgroupLayout for QK

@ClarkChin08 -- You can derive SubgroupLayoutQK_ from TiledMMAQK_, so the user doesn't have to pass it in.

Comment on lines +301 to +313
int item_id = get_sub_group().get_local_id()[0];
int base_col = item_id + K * get<1>(TileShapeQK{});
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < shape<2>(tSrS.shape()); ++n) {
  int col_idx = base_col + n * get<1>(MmaAtomShapeQK());
  CUTLASS_PRAGMA_UNROLL
  for (int m = 0; m < shape<0>(tSrS.shape()); ++m) {
    int row_idx = seq_coord + m;
    if (col_idx - full_tile_offset > row_idx - discard_seq_coord) {
      tSrS(m, 0, n) = ElementS(-INFINITY);
    }
  }
}

@petercad petercad Nov 3, 2025

We should avoid making assumptions about the layout of MMA blocks within tSrS, as they can easily be broken by different choices for TiledMMAQK. Also, this code assumes size<1>(tSrS) == 1.

Instead, I'd suggest using coordinate tensors to identify the coordinates of thread-owned data -- something like this (code below untested):

auto cS_thread = thr_mma_qk.partition_C(cP);     /* local S coordinates, within WG tile */
for (int i = 0; i < tSrS.size(); i++)
    if (get<1>(cS_thread(i)) >= seq_len - base_col)
        tSrS(i) = ElementS(-INFINITY);

Also, by using this method, you can avoid referring to SubgroupLayoutQK at all, because cS_thread is already aware of the subgroup tile's position within the workgroup tile.
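
To illustrate the idea outside of CuTe, here is a minimal, self-contained C++ sketch. `Coord`, `mask_by_coordinate`, and `kv_limit` are hypothetical stand-ins, not CUTLASS or CuTe APIs. The `coords` array plays the role of `cS_thread`: it records which (row, column) each thread-owned value came from, so the masking loop never needs to know how MMA blocks are arranged inside the fragment.

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical stand-in for one entry of a coordinate tensor:
// the (row, col) position of a thread-owned value within the tile.
struct Coord {
    int row;
    int col;
};

// Mask every thread-owned value whose column is at or past `kv_limit`.
// values[i] and coords[i] describe the same element, mirroring how
// tSrS(i) and cS_thread(i) line up in the suggestion above.
inline void mask_by_coordinate(std::vector<float>& values,
                               const std::vector<Coord>& coords,
                               int kv_limit) {
    for (std::size_t i = 0; i < values.size(); ++i) {
        if (coords[i].col >= kv_limit) {
            values[i] = -std::numeric_limits<float>::infinity();
        }
    }
}
```

Because the loop reads only the coordinates, reordering the values owned by a thread (as a different tiled MMA might) changes the contents of `coords` but not the masking code.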

Author

Thanks @petercad, let me give it a try.

Copilot AI left a comment

Pull Request Overview

This PR adds support for causal masking in the flash attention implementation by introducing a new SubgroupLayoutQK template parameter and implementing the causal mask logic in the mainloop.

Key Changes:

  • Added SubgroupLayoutQK template parameter to the collective mainloop and kernel interfaces
  • Implemented causal masking logic that applies -INFINITY to attention scores beyond the causal boundary
  • Updated the example runner to conditionally instantiate causal or non-causal configurations based on user options
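
The last point, selecting a compile-time configuration from a runtime option, can be sketched in plain C++ as follows. `FmhaConfigSketch` and `select_config` are hypothetical names used only for illustration; the actual runner instantiates CUTLASS kernel types that are not reproduced here.

```cpp
#include <string>

// Hypothetical stand-in for a kernel configuration templated on a
// compile-time causal flag.
template <bool IsCausal>
struct FmhaConfigSketch {
    static std::string name() {
        return IsCausal ? "causal" : "non-causal";
    }
};

// Map the runtime option onto one of two compile-time instantiations.
inline std::string select_config(bool is_causal) {
    if (is_causal) {
        return FmhaConfigSketch<true>::name();
    }
    return FmhaConfigSketch<false>::name();
}
```

Both instantiations are compiled, and the runtime flag only picks between them, so the masking branch can be resolved at compile time inside each kernel.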

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp | Implements causal mask logic and removes the static assertion that previously blocked causal mask usage |
| applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp | Adds subgroup layout type alias and computes sequence coordinates for causal masking |
| examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp | Adds SubgroupLayoutQK template parameter to mainloop type |
| examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp | Conditionally selects causal or non-causal kernel based on is_causal option |

auto discard_seq_coord = s.seq_len_qo - offset;
auto full_tile_offset = s.seq_len_kv - offset;

int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ));

Copilot AI Nov 5, 2025

This calculation is overly complex and difficult to understand. Consider breaking it into intermediate variables with descriptive names to clarify the computation of tile offset within the subgroup layout.

Suggested change
- int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ));
+ // Break down the seq_coord calculation for clarity
+ int tile_shape_qk_0 = get<0>(TileShapeQK{});
+ int subgroup_layout_qk_1 = get<1>(shape(SubgroupLayoutQK{}));
+ int blk_q_offset = blk_q * tile_shape_qk_0;
+ int subgroup_tile_offset = (sub_group_id / subgroup_layout_qk_1) * SGTileQ;
+ int raw_seq_coord = blk_q_offset + subgroup_tile_offset;
+ int seq_coord = cute::min(s.seq_len_qo, raw_seq_coord);
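
For readers who want to check the arithmetic by hand, the same computation can be written with plain integers. This is a sketch only: `first_query_row` and its parameter names are hypothetical, and the mapping to `TileShapeQK`, `SubgroupLayoutQK`, and `SGTileQ` given in the comments is an assumption based on the expression above.

```cpp
#include <algorithm>

// Hypothetical plain-integer version of the seq_coord computation.
//   tile_q    : query rows per workgroup tile (get<0>(TileShapeQK{}))
//   sg_cols   : extent of mode 1 of the subgroup layout
//   sg_tile_q : query rows per subgroup tile (SGTileQ)
constexpr int first_query_row(int seq_len_qo, int blk_q, int tile_q,
                              int sub_group_id, int sg_cols, int sg_tile_q) {
    const int workgroup_row_offset = blk_q * tile_q;
    const int subgroup_row = sub_group_id / sg_cols;  // integer division
    const int subgroup_row_offset = subgroup_row * sg_tile_q;
    // Clamp so the starting row never exceeds the query sequence length.
    return std::min(seq_len_qo, workgroup_row_offset + subgroup_row_offset);
}
```

With `seq_len_qo = 1024`, `blk_q = 2`, `tile_q = 128`, `sub_group_id = 5`, `sg_cols = 2`, and `sg_tile_q = 16`, the workgroup offset is 256, the subgroup row is 2, the subgroup offset is 32, and the result is 288.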

int col_idx = base_col + n * get<1>(MmaAtomShapeQK());
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < shape<0>(tSrS.shape()); ++m) {
int row_idx = seq_coord + m;

Copilot AI Nov 5, 2025

The causal mask condition logic lacks explanation. Add a comment describing what this inequality represents in terms of the attention matrix and why these specific offsets are used.

Suggested change
- int row_idx = seq_coord + m;
+ int row_idx = seq_coord + m;
+ // Causal mask: For each (row_idx, col_idx) in the attention matrix,
+ // set positions where the query (row) would attend to future keys (col) to -INFINITY.
+ // The offsets (full_tile_offset, discard_seq_coord) adjust for tiling and indexing,
+ // ensuring that only positions where col_idx > row_idx (i.e., future positions)
+ // are masked out, preserving causality in the attention mechanism.
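
As a worked check of what the inequality means, the sketch below takes `discard_seq_coord = seq_len_qo - offset` and `full_tile_offset = seq_len_kv - offset` from the kernel snippet earlier in this review and compares the condition as written against a simplified form. The function names are hypothetical, and `offset` is passed in as a parameter because its definition is not shown in this thread. Subtracting the two offsets makes `offset` cancel, leaving `col_idx - row_idx > seq_len_kv - seq_len_qo`, which reduces to `col_idx > row_idx` when the two sequence lengths are equal.

```cpp
// The condition as written in the PR, using the two offsets from the kernel.
constexpr bool masked_as_written(int row_idx, int col_idx,
                                 int seq_len_qo, int seq_len_kv, int offset) {
    const int discard_seq_coord = seq_len_qo - offset;
    const int full_tile_offset = seq_len_kv - offset;
    return col_idx - full_tile_offset > row_idx - discard_seq_coord;
}

// Algebraically equivalent form: `offset` cancels on both sides, so a key
// is masked when it lies past the diagonal that ends at the bottom-right
// corner of the (seq_len_qo x seq_len_kv) score matrix.
constexpr bool masked_simplified(int row_idx, int col_idx,
                                 int seq_len_qo, int seq_len_kv) {
    return col_idx - row_idx > seq_len_kv - seq_len_qo;
}
```

For example, with `seq_len_qo = 4` and `seq_len_kv = 6`, query row 0 may attend to key columns 0 through 2 and is masked from column 3 onward, while the last query row (3) sees all six keys.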
