Skip to content

Commit 038924d

Browse files
committed
fix index
1 parent 5d64dfb commit 038924d

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,8 @@ class XeFMHAFwdDynamicSplitKernel {
562562
}
563563
CUTLASS_PRAGMA_UNROLL
564564
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
565-
merged_res(i + size(FragA{}.shape())) = tA_max(i);
566-
merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);
565+
merged_res(2 * i + size(FragA{}.shape())) = tA_max(i);
566+
merged_res(2 * i + 1 + size(FragA{}.shape())) = tA_sum(i);
567567
}
568568
copy(merged_res, tPartial);
569569

@@ -608,8 +608,8 @@ class XeFMHAFwdDynamicSplitKernel {
608608

609609
CUTLASS_PRAGMA_UNROLL
610610
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
611-
tA_max(i) = merged_res(i + size(FragA{}.shape()));
612-
tA_sum(i) = merged_res(i + 1 + size(FragA{}.shape()));
611+
tA_max(i) = merged_res(2 * i + size(FragA{}.shape()));
612+
tA_sum(i) = merged_res(2 * i + 1 + size(FragA{}.shape()));
613613
}
614614

615615
continue;

examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ int main(int argc, const char **argv) {
112112
#define NUM_SG _16
113113
#define KV_TILE_SIZE _256
114114
#else
115-
#define NUM_SG _16
115+
#define NUM_SG _8
116116
#define KV_TILE_SIZE _512
117117
#endif
118118

0 commit comments

Comments
 (0)