diff --git a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index b2c802da4b..833217488b 100644 --- a/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -171,6 +171,7 @@ struct FMHAFwdMainloop, CausalMask_, QVCoord blk_qv, // WG tile indices: (Q,V) int blk_k0, // K block range: [K0,K1) int blk_k1, + int total_blk, // Total # of K blocks int thr_id) { // Work-item ID using namespace sycl::ext::oneapi::this_work_item; @@ -289,7 +290,7 @@ struct FMHAFwdMainloop, CausalMask_, prefetch(prefetch_v, pVgV(_,_,_,K)); /* k masking for remainder tiles */ - if (check_remainder_k && K == blk_k1 - 1) { + if (check_remainder_k && K == total_blk - 1) { FragSRow k_rem_mask; int k = get<0>(tKgK(0,0,0,K,0)) + get_sub_group().get_local_id()[0]; CUTLASS_PRAGMA_UNROLL diff --git a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index fced70ee84..d74ee2fa26 100644 --- a/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -38,6 +38,7 @@ #include "flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp" #include "flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp" +#include "flash_attention_v2/kernel/xe_tile_scheduler.hpp" namespace cutlass::fmha::kernel { @@ -216,7 +217,7 @@ class XeFMHAFwdKernel { K(_,_,head,idx_b), V(_,_,head,idx_b), tArA, tA_max, tA_sum, - blk_qv, 0, k_blocks, + blk_qv, 0, k_blocks, k_blocks, thr_id); if constexpr (!is_empty_v && !is_empty_v) { @@ -232,4 +233,421 @@ class XeFMHAFwdKernel { } }; +template +class XeFMHAFwdDynamicSplitKernel { + +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + + using ElementQ = typename CollectiveMainloop::TensorQ::element_type; + using ElementK = typename CollectiveMainloop::TensorK::element_type; + using ElementV = typename CollectiveMainloop::TensorV::element_type; + + using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); + using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); + using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); + + using SGPerWG = typename CollectiveMainloop::SGPerWG; + + using FragA = typename CollectiveMainloop::FragA; + using SingleFragA = typename CollectiveMainloop::SingleFragA; + using FragARow = typename CollectiveMainloop::FragARow; + // element dtype for MmaPV results + using ElementA = typename CollectiveMainloop::ElementA; + + // Tile scheduler derived types + static_assert(is_same_v); + using TileScheduler = TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileShapeO = typename 
CollectiveEpilogue::TileShapeO;
+  using ElementO = typename CollectiveEpilogue::TensorO::element_type;
+  using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{}));
+
+  // Kernel level shared memory storage
+  using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage;
+  using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage;
+  union SharedStorage {
+    MainloopSharedStorage mainloop;
+    EpilogueSharedStorage epilogue;
+  };
+
+  static constexpr int SharedStorageSize = is_empty_v<SharedStorage> ? size_t(0)
+                                                                     : sizeof(SharedStorage);
+
+  // Important: keep each per-thread copy a multiple of 16 elements.
+  // This is the per-thread footprint for storing partial results from different KV partitions.
+  static constexpr int num_elem_per_thread = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;
+  // FIXME: could the number of partitions per batch-head exceed max_num_partitions?
+  static const int max_num_partitions = 8;
+
+  // Device side arguments
+  struct KernelArguments {
+    ProblemShape shape;
+    const ElementQ *Q;
+    StrideQ dQ;
+    const ElementK *K;
+    StrideK dK;
+    const ElementV *V;
+    StrideV dV;
+    ElementO *O;
+    StrideO dO;
+  };
+  using KernelParams = KernelArguments;
+
+  struct Arguments {
+    KernelArguments kernel{};
+    MainloopArguments mainloop{};
+    EpilogueArguments epilogue{};
+    KernelHardwareInfo hw_info{};
+  };
+
+  // Kernel entry point API
+  struct Params {
+    KernelParams kernel;
+    MainloopParams mainloop;
+    EpilogueParams epilogue;
+    TileSchedulerParams scheduler;
+    // workspace for storing partial results of different KV partitions
+    ElementA *partial_results_ptr = nullptr;
+    // for atomic add
+    int32_t *atomic_reduce_cnt_ptr = nullptr;
+  };
+
+  //
+  // Methods
+  //
+
+  static Params to_underlying_arguments(Arguments const &args, void *workspace) {
+    int num_batch_heads = args.kernel.shape.batch * args.kernel.shape.num_heads_q;
+    int32_t *atomic_reduce_cnt_ptr = reinterpret_cast<int32_t *>(workspace);
+    ElementA *partial_results_ptr = reinterpret_cast<ElementA *>(atomic_reduce_cnt_ptr + num_batch_heads);
+    return {args.kernel,
+            CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace),
+            CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace),
+            TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}),
+            partial_results_ptr, atomic_reduce_cnt_ptr
+    };
+  }
+
+  static bool can_implement(Arguments const &args) {
+    // current kernel only supports decode (seq_len_qo == 1)
+    if (args.kernel.shape.seq_len_qo > 1) {
+      return false;
+    }
+    // current kernel requires the number of batch-heads to be no more than the total XeCore count
+    if (args.kernel.shape.batch * args.kernel.shape.num_heads_q > args.hw_info.sm_count) {
+      return false;
+    }
+    return CollectiveMainloop::can_implement(args.mainloop)
+        && CollectiveEpilogue::can_implement(args.epilogue);
+  }
+
+  static int get_workspace_size(Arguments const &args) {
+    int ws_size = 0;
+    int num_batch_heads = args.kernel.shape.batch * args.kernel.shape.num_heads_q;
+    const int wg_size = SGPerWG::value * intel::sg_size;
+
+    // partial attention outputs, exp sums and max logits
+    ws_size += (max_num_partitions * num_batch_heads) * wg_size * num_elem_per_thread * sizeof(ElementA);
+    // atomic counters
+    ws_size += num_batch_heads * sizeof(int32_t);
+    return ws_size;
+  }
+
+  static cutlass::Status initialize_workspace(Arguments const &args, void *workspace = nullptr,
+                                              cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) {
+    int num_batch_heads = args.kernel.shape.batch * args.kernel.shape.num_heads_q;
+    compat::fill(reinterpret_cast<int32_t *>(workspace), (int32_t)0,
num_batch_heads); + auto partial_ws_count = (get_workspace_size(args) - num_batch_heads * sizeof(int32_t)) / sizeof(ElementA); + auto* partial_results_ptr = reinterpret_cast(reinterpret_cast(workspace) + num_batch_heads); + compat::fill(partial_results_ptr, (ElementA)0, partial_ws_count); + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(SGPerWG::value * intel::sg_size, 1, 1); } + + CUTLASS_DEVICE + int get_partition_id(const int cur_wg_id, const int batch_head_id, const int num_blocks_per_wg, const int local_k_blocks) { + int partition_id = 0; + if (batch_head_id == 0) { + return cur_wg_id; + } + int start_wg_id = batch_head_id * local_k_blocks / num_blocks_per_wg; + partition_id = cur_wg_id - start_wg_id; + return partition_id; + } + + CUTLASS_DEVICE + int get_num_partitions(const int batch_head_id, const int num_blocks_per_wg, const int local_k_blocks) { + int num_partitions = 1; + int start_wg_id = batch_head_id * local_k_blocks / num_blocks_per_wg; + int end_wg_id = (batch_head_id + 1) * local_k_blocks / num_blocks_per_wg; + num_partitions = end_wg_id - start_wg_id + 1; + // end_wg_id is the starting wg id of next batch head id + if (((batch_head_id + 1) * local_k_blocks) % num_blocks_per_wg == 0) { + num_partitions -= 1; + } + return num_partitions; + } + + template + CUTLASS_DEVICE + void reduce_split2(const Params ¶ms, FragA &out1, FragARow& max_val1, FragARow& exp_sum_val1, FragA &out2, FragARow& max_val2, FragARow& exp_sum_val2) { + // global max value + FragARow max_prev1 = max_val1; + FragARow max_prev2 = max_val2; + + auto scale = params.mainloop.scale; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < max_val1.size(); i++) { + max_val1(i) = sycl::max(max_val1(i), max_val2(i)); + } + + FragARow rescale1, rescale2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < max_val1.size(); i++) { + rescale1(i) = sycl::native::exp2(max_prev1(i) - max_val1(i)); + rescale2(i) = sycl::native::exp2(max_prev2(i) - max_val1(i)); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < exp_sum_val1.size(); i++) { + exp_sum_val1(i) = exp_sum_val1(i) * rescale1(i) + exp_sum_val2(i) * rescale2(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < out1.size(); i++) + out1(i) = out1(i) * broadcast<0>(rescale1, out1, i) + out2(i) * broadcast<0>(rescale2, out2, i); + } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) + { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto &p = params.kernel; + ProblemShape const& s = p.shape; + int head_group_q = s.num_heads_q / s.num_heads_kv; + + int thr_id = int(ThreadIdxX()); + int wg_id = int(BlockIdxZ()); + + int sg_id = thr_id / intel::sg_size; + int tid_in_sg = thr_id % intel::sg_size; + int num_batch_heads = s.batch * s.num_heads_q; + + int local_k_blocks = cute::ceil_div(s.seq_len_kv, get<1>(TileShapeQK{})); + // total number of blocks need to be processed across all wgs + int total_k_blocks = local_k_blocks * num_batch_heads; + // to guarantee all wg process similar number of blocks of KV + int num_blocks_per_wg = cute::ceil_div(total_k_blocks, GridDimZ()); + + TileScheduler tile_scheduler{params.scheduler, get<1>(TileShapeQK{}), local_k_blocks, num_batch_heads}; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + // head_q, idx_b from tile scheduler will not be used + // auto [blk_q, 
blk_v, head_q_unused, idx_b_unused] = tile_scheduler.get_block_coord(); // (Q,V,h,b)
+      auto [blk_q, blk_v, start_batch_head_id] = tile_scheduler.get_block_coord(); // (Q,V, batch_head_idx)
+      auto blk_qv = make_coord(blk_q, blk_v);
+
+      auto shape_Q = make_shape(s.seq_len_qo, s.head_size_qk, s.num_heads_q, s.batch);
+      auto shape_K = make_shape(s.seq_len_kv, s.head_size_qk, s.num_heads_kv, s.batch);
+      auto shape_V = make_shape(s.head_size_vo, s.seq_len_kv, s.num_heads_kv, s.batch);
+      auto shape_O = make_shape(s.seq_len_qo, s.head_size_vo, s.num_heads_kv, s.batch);
+
+      auto dcQ = const_cast<ElementQ *>(p.Q); // de-const these for uniformity
+      auto dcK = const_cast<ElementK *>(p.K);
+      auto dcV = const_cast<ElementV *>(p.V);
+
+      Tensor Q = make_tensor(make_gmem_ptr(dcQ), make_layout(shape_Q, p.dQ)); // (q,d,h,b)
+      Tensor K = make_tensor(make_gmem_ptr(dcK), make_layout(shape_K, p.dK)); // (k,d,h,b)
+      Tensor V = make_tensor(make_gmem_ptr(dcV), make_layout(shape_V, p.dV)); // (v,k,h,b)
+      Tensor O = make_tensor(make_gmem_ptr(p.O), make_layout(shape_O, p.dO)); // (q,v,h,b)
+
+      // O accumulators
+      FragA tArA;
+      FragARow tA_max, tA_sum;
+
+      // number of K blocks of the starting batch-head id already covered by preceding WGs
+      int num_computed_blocks = (start_batch_head_id == 0) ? (wg_id * num_blocks_per_wg) : (wg_id * num_blocks_per_wg - start_batch_head_id * local_k_blocks);
+      int start_blk, end_blk, head_q, idx_b, head_kv;
+      // A leader WG reduces the partial results in addition to computing its own;
+      // the remaining worker WGs only compute partial results.
+      bool is_leader_wg = wg_id < num_batch_heads;
+
+      if (thr_id == 0 && is_leader_wg) {
+        // reset atomic counter before computation
+        *(params.atomic_reduce_cnt_ptr + wg_id) = 0;
+      }
+
+      // Main loop
+      CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop);
+
+      // remaining block budget for this WG
+      int block_budget_remained = num_blocks_per_wg;
+      int batch_head_id = start_batch_head_id;
+      bool is_update_batch_head_id = false;
+      while (block_budget_remained > 0) {
+        int num_new_blocks = local_k_blocks - num_computed_blocks;
+        if (num_new_blocks <= block_budget_remained) {
+          // budget covers the rest of the current batch head id
+          start_blk = num_computed_blocks;
+          end_blk = start_blk + num_new_blocks;
+
+          // update states
+          num_computed_blocks = 0;
+          block_budget_remained -= num_new_blocks;
+          is_update_batch_head_id = true;
+        } else {
+          // budget is not enough to finish the current batch head id
+          start_blk = num_computed_blocks;
+          end_blk = start_blk + block_budget_remained;
+
+          block_budget_remained = 0;
+          is_update_batch_head_id = false;
+        }
+
+        head_q = batch_head_id % s.num_heads_q;
+        idx_b = batch_head_id / s.num_heads_q;
+        head_kv = head_q / head_group_q;
+        // mainloop
+        mainloop(Q(_,_,head_q,idx_b),
+                 K(_,_,head_kv,idx_b),
+                 V(_,_,head_kv,idx_b),
+                 tArA, tA_max, tA_sum,
+                 blk_qv, start_blk, end_blk, local_k_blocks,
+                 thr_id);
+
+        // partition id of the current batch head id within this WG
+        int partition_id = get_partition_id(wg_id, batch_head_id, num_blocks_per_wg, local_k_blocks);
+
+        // store partial result: tArA, tA_max and tA_sum
+        int offset = batch_head_id * max_num_partitions * num_elem_per_thread * SGPerWG::value * intel::sg_size +
+                     partition_id * num_elem_per_thread * SGPerWG::value * intel::sg_size +
+                     sg_id * intel::sg_size * num_elem_per_thread +
+                     tid_in_sg * num_elem_per_thread;
+        Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thread>{}));
+        Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread>{});
+
+        CUTLASS_PRAGMA_UNROLL
+        for(int i = 0; i < size(FragA{}.shape()); ++i) {
+          merged_res(i) =
tArA(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(FragARow{}.shape()); ++i) { + merged_res(2 * i + size(FragA{}.shape())) = tA_max(i); + merged_res(2 * i + 1 + size(FragA{}.shape())) = tA_sum(i); + } + copy(merged_res, tPartial); + + // after store, set atomic cnt + if (thr_id == 0) { + atomicAdd(params.atomic_reduce_cnt_ptr + batch_head_id, 1); + } + + // advance to next batch head id + if (is_update_batch_head_id) { + batch_head_id += 1; + if (batch_head_id >= num_batch_heads) { + break; + } + } + } + + if (is_leader_wg) { + int num_partitions = get_num_partitions(wg_id, num_blocks_per_wg, local_k_blocks); + + // check atomic to wait for partial results ready + while(atomicLoad(params.atomic_reduce_cnt_ptr + wg_id) != num_partitions) {} + + clear(tArA); + clear(tA_max); + clear(tA_sum); + + for (int i = 0; i < num_partitions; ++i) { + int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thread + + i * SGPerWG::value * intel::sg_size * num_elem_per_thread + + sg_id * intel::sg_size * num_elem_per_thread + + tid_in_sg * num_elem_per_thread; + Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int{})); + Tensor merged_res = make_tensor(Int{}); + copy(tPartial, merged_res); + + if (i == 0) { + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size(FragA{}.shape()); ++i) { + tArA(i) = merged_res(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(FragARow{}.shape()); ++i) { + tA_max(i) = merged_res(2 * i + size(FragA{}.shape())); + tA_sum(i) = merged_res(2 * i + 1 + size(FragA{}.shape())); + } + + continue; + } + + FragA tArA_2; + FragARow tA_max_2, tA_sum_2; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size(FragA{}.shape()); ++i) { + tArA_2(i) = merged_res(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(FragARow{}.shape()); ++i) { + tA_max_2(i) = merged_res(2 * i + size(FragA{}.shape())); + tA_sum_2(i) = merged_res(2 * i + 1 + size(FragA{}.shape())); + } + + reduce_split2(params, tArA, tA_max, tA_sum, tArA_2, tA_max_2, tA_sum_2); + } + + // require group barrier if using SLM + if constexpr (!is_empty_v && !is_empty_v) { + sycl::group_barrier(get_work_group<3>()); + } + + head_q = wg_id % s.num_heads_q; + idx_b = wg_id / s.num_heads_q; + head_kv = head_q / head_group_q; + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + epilogue(O(_,_,head_q,idx_b), + tArA, tA_max, tA_sum, + blk_qv, thr_id); + } + } + } +}; + } // namespace cutlass::fmha::kernel diff --git a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp index a14d6db482..fc106cd34d 100644 --- a/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -92,4 +92,73 @@ struct XeFHMAIndividualTileScheduler { } }; +struct XeFHMAIndividualPersistentTileScheduler { + + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + int kv_tile_size_; + // num of kv blocks for each head + int local_num_kv_blocks_; + int num_batch_heads_; + + CUTLASS_DEVICE + XeFHMAIndividualPersistentTileScheduler(Params const& params, int kv_tile_size, + int local_num_kv_blocks, int num_batch_heads) + : params(params), kv_tile_size_(kv_tile_size), local_num_kv_blocks_(local_num_kv_blocks), num_batch_heads_(num_batch_heads) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, KernelHardwareInfo hw_info, + 
TileShape const& tile_shape) + { + using namespace cute; + + dim3 grid(size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V + size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + int num_heads = shape.num_heads_q; + grid.z = hw_info.sm_count; + + return Params{grid, {num_heads}}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int wg_id = BlockIdxZ(); + int head; + + // total number of blocks need to be processed across all wgs + int total_num_kv_blocks = local_num_kv_blocks_ * num_batch_heads_; + // guarantee all wg process similar number of blocks of KV (load balance) + int num_blocks_per_wg = cute::ceil_div(total_num_kv_blocks, GridDimZ()); + + // compute start batch head id for current wg + int start_batch_head_id = wg_id * num_blocks_per_wg / local_num_kv_blocks_; + + return make_coord(BlockIdxY(), BlockIdxX(), start_batch_head_id); + } + + CUTLASS_DEVICE + XeFHMAIndividualPersistentTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + } // namespace cutlass::fmha::kernel diff --git a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp index d19d2bbd1a..2d9c8c35a4 100644 --- a/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp +++ b/examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp @@ -107,36 +107,45 @@ int main(int argc, const char **argv) { #endif #elif defined(DECODE) + +#if PERSISTENT +#define NUM_SG _16 +#define KV_TILE_SIZE _256 +#else +#define NUM_SG _8 +#define KV_TILE_SIZE _512 +#endif + #if HEAD_DIM == 16 /* Tiny config for testing */ using ShapeQK = Shape<_1, _16, _16>; // (q,k,d) using ShapePV = Shape<_1, _16, _16>; // (q,v,k) using ShapeOut = Shape<_1, _16>; // (q,v) - using SubgroupLayoutQK = Layout>; + using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 64 using ShapeQK = Shape<_1, _512, _64>; using ShapePV = Shape<_1, _32, _512>; using ShapeOut = Shape<_1, _64>; - using SubgroupLayoutQK = Layout>; + using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 96 using ShapeQK = Shape<_1, _512, _64>; using ShapePV = Shape<_1, _32, _512>; using ShapeOut = Shape<_1, _96>; - using SubgroupLayoutQK = Layout>; + using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 128 - using ShapeQK = Shape<_1, _512, _64>; - using ShapePV = Shape<_1, _32, _512>; + using ShapeQK = Shape<_1, KV_TILE_SIZE, _64>; + using ShapePV = Shape<_1, _32, KV_TILE_SIZE>; using ShapeOut = Shape<_1, _128>; - using SubgroupLayoutQK = Layout>; + using SubgroupLayoutQK = Layout>; #elif HEAD_DIM == 192 using ShapeQK = Shape<_1, _512, _64>; using ShapePV = Shape<_1, _32, _512>; using ShapeOut = Shape<_1, _192>; - using SubgroupLayoutQK = Layout>; + using SubgroupLayoutQK = Layout>; #endif #else #error Either DECODE or PREFILL should be defined. 
@@ -148,5 +157,9 @@ int main(int argc, const char **argv) { constexpr int PipelineStages = 2; #endif - return FMHAConfig::run(options); +#if PERSISTENT + return FMHAConfig::run(options); +#else + return FMHAConfig::run(options); +#endif } diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt index 5ccc5f30cd..17d144b327 100644 --- a/examples/06_bmg_flash_attention/CMakeLists.txt +++ b/examples/06_bmg_flash_attention/CMakeLists.txt @@ -44,6 +44,12 @@ foreach(HEAD_DIM 64 96 128 192) 06_xe_fmha_fwd.cpp ) + # specific test for persistent kernel + cutlass_example_add_executable( + 06_xe_fmha_fwd_decode_persistent_hdim${HEAD_DIM} + 06_xe_fmha_fwd.cpp + ) + cutlass_example_add_executable( 06_bmg_prefill_attention_hdim${HEAD_DIM} 06_bmg_prefill_attention.cpp @@ -84,4 +90,5 @@ foreach(HEAD_DIM 64 96 128 192) target_compile_definitions(06_bmg_decode_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) target_compile_definitions(06_xe_fmha_fwd_prefill_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} PREFILL SHOW_DIFF=1) target_compile_definitions(06_xe_fmha_fwd_decode_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE SHOW_DIFF=1) + target_compile_definitions(06_xe_fmha_fwd_decode_persistent_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE PERSISTENT SHOW_DIFF=1) endforeach() diff --git a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp index 6ce6e9e95b..3140c637db 100644 --- a/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp +++ b/examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp @@ -514,6 +514,7 @@ template default */ int PipelineStages, + bool persistent, typename ElementQ = bfloat16_t, typename ElementK = bfloat16_t, typename ElementV = bfloat16_t, @@ -546,6 +547,7 @@ struct FMHAConfig { // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This // information is used by the underlying kernel. cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; @@ -583,9 +585,12 @@ struct FMHAConfig { GmemTiledCopyO >; - using FMHAKernel = cutlass::fmha::kernel::XeFMHAFwdKernel< - ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler - >; + using FMHAKernel = conditional_t, + cutlass::fmha::kernel::XeFMHAFwdDynamicSplitKernel< + ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler>, + cutlass::fmha::kernel::XeFMHAFwdKernel< + ProblemShapeType, CollectiveMainloop, CollectiveEpilogue, Scheduler> + >; ExampleRunner runner; @@ -594,6 +599,7 @@ struct FMHAConfig { } static int run(const Options &options) { - return run(options); + return persistent ? 
run(options) : + run(options); } }; diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index adc0882e91..69b582e1ec 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -365,6 +365,15 @@ CUTLASS_DEVICE T atomicAdd(T *address, T val) { return static_cast(0); } +template +CUTLASS_DEVICE T atomicSub(T *address, T val) { +#if defined(__SYCL_DEVICE_ONLY__) + return compat::atomic_fetch_sub(address, val); +#endif + return static_cast(0); +} + + CUTLASS_DEVICE int atomicCAS(int *address, int compare, int val) { int result = 0; #if defined(__SYCL_DEVICE_ONLY__) @@ -373,6 +382,15 @@ CUTLASS_DEVICE int atomicCAS(int *address, int compare, int val) { return result; } +CUTLASS_DEVICE int atomicLoad(int *address) { + int result = 0; +#if defined(__SYCL_DEVICE_ONLY__) + auto atm = sycl::atomic_ref(address[0]); + result = atm.load(); +#endif + return result; +} + // Error using cudaError_t = unsigned int; constexpr cudaError_t cudaSuccess = 0;
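
Note for reviewers: the split-KV reduction above relies on the usual online-softmax merge identity. The following standalone scalar sketch is not part of the patch; the names Partial and merge are illustrative, and it uses std::exp on plain floats whereas the kernel works in the base-2 domain (sycl::native::exp2) with the softmax scale already folded into the running max. It models what reduce_split2 does per fragment element when the leader WG folds one partition's partial result into another.

    // Scalar model of merging two split-KV softmax partials (illustrative only).
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    struct Partial {          // per-row partial softmax state for one KV partition
      std::vector<float> acc; // unnormalized sum of exp(score - row_max) * V
      float row_max;          // max score seen in this partition
      float exp_sum;          // sum of exp(score - row_max) in this partition
    };

    // Merge p2 into p1: rescale both accumulators and exp-sums to the combined
    // maximum, then add. This mirrors reduce_split2()'s rescale1/rescale2 factors.
    void merge(Partial &p1, const Partial &p2) {
      float new_max = std::max(p1.row_max, p2.row_max);
      float r1 = std::exp(p1.row_max - new_max);
      float r2 = std::exp(p2.row_max - new_max);
      p1.exp_sum = p1.exp_sum * r1 + p2.exp_sum * r2;
      for (size_t i = 0; i < p1.acc.size(); ++i)
        p1.acc[i] = p1.acc[i] * r1 + p2.acc[i] * r2;
      p1.row_max = new_max;
    }

    int main() {
      // Two KV partitions of the same query row, computed independently.
      Partial a{{1.0f, 2.0f}, 0.5f, 1.3f};
      Partial b{{0.3f, 0.9f}, 1.2f, 0.8f};
      merge(a, b);
      // The epilogue normalizes by the merged exp-sum to get the final output.
      for (float v : a.acc) std::printf("%f ", v / a.exp_sum);
      std::printf("\n");
      return 0;
    }

In the kernel, the leader WG applies this pairwise over up to max_num_partitions partial results stored for its batch-head, and the epilogue performs the final division by the merged exp-sum.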