diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h
index c724c1e4de..82556c44fb 100644
--- a/src/turbomind/kernels/attention/attention_universal.h
+++ b/src/turbomind/kernels/attention/attention_universal.h
@@ -436,6 +436,12 @@ struct AttentionUniversal {
         }
     }
 
+    __device__ bool check_h(int hi)
+    {
+        /// FIXME: for tensor core decoding `CTA_H == 1` fails (currently CTA_H > 2 are used for TC)
+        return CTA_H == 1 || hi < CTA_H;
+    }
+
     __device__ void StoreO(FragO& frag_O,
                            FragL& frag_L,
                            int    qi_begin,
@@ -445,7 +451,7 @@ struct AttentionUniversal {
                            SharedStorage& storage)
     {
         Impl::StoreO<true>(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {
-            if (qi_begin + qi < qi_end) {
+            if (qi_begin + qi < qi_end && check_h(hi)) {
                 const int offset = (qi_begin + qi) * params.num_heads * kHeadDim + (head_idx + hi) * kHeadDim + di;
                 Store(&params.out[offset], cast<T>(vec));
             }
@@ -462,7 +468,7 @@ struct AttentionUniversal {
         Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float score) {
            qi += query_idx;
            si += offset_K;
-            if (qi < params.max_q_len && si < max_context_len) {
+            if (qi < params.max_q_len && si < max_context_len && check_h(hi)) {
                params.qk[batch_idx * params.num_heads * params.max_q_len * max_context_len
                          + (head_idx + hi) * params.max_q_len * max_context_len
                          + qi * max_context_len + si] = score;
@@ -488,14 +494,14 @@ struct AttentionUniversal {
         };
 
         Impl::StoreO<false>(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {
-            if (qi_begin + qi < qi_end) {
+            if (qi_begin + qi < qi_end && check_h(hi)) {
                 Store(&params.partial_O[get_index(hi, qi) * kHeadDim + di], vec);
             }
         });
 
         Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) {
             const int index = get_index(hi, qi);
-            if (qi_begin + qi < qi_end && ri == 0) {
+            if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) {
                 // printf("ML %2d %2d %f %f\n", split_idx, head_idx + hi, M, L);
                 params.partial_M[index] = M;
                 params.partial_L[index] = L;
diff --git a/src/turbomind/kernels/attention/decoding.cu b/src/turbomind/kernels/attention/decoding.cu
index 7d6100864d..6ad50c5712 100644
--- a/src/turbomind/kernels/attention/decoding.cu
+++ b/src/turbomind/kernels/attention/decoding.cu
@@ -68,12 +68,15 @@ void dispatchDecoding(const AttentionParams<T>& params)
     else if (query_group_sz % 8 == 0) {
         return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 8, kHeadDim>::Kernel>(params);
     }
+    else if (query_group_sz % 6 == 0) {
+        return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 6, kHeadDim>::Kernel>(params);
+    }
+    else if (query_group_sz % 5 == 0) {
+        return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 5, kHeadDim>::Kernel>(params);
+    }
     else if (query_group_sz % 4 == 0) {
         return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 4, kHeadDim>::Kernel>(params);
     }
-    else if (query_group_sz % 2 == 0) {
-        return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 2, kHeadDim>::Kernel>(params);
-    }
     else {
         return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 1, kHeadDim>::Kernel>(params);
     }
@@ -128,13 +131,17 @@ void dispatchDecoding(const AttentionParams<T>& params)
             return invokeDecoding<
                 typename DecodingConfig<arch::Sm80, T, Tkv, 8, kHeadDim>::Kernel>(params);
         }
-        else if (query_group_sz % 4 == 0) {
+        else if (query_group_sz % 6 == 0) {
             return invokeDecoding<
-                typename DecodingConfig<arch::Sm80, T, Tkv, 4, kHeadDim>::Kernel>(params);
+                typename DecodingConfig<arch::Sm80, T, Tkv, 6, kHeadDim>::Kernel>(params);
         }
-        else if (query_group_sz % 2 == 0) {
+        else if (query_group_sz % 5 == 0) {
             return invokeDecoding<
-                typename DecodingConfig<arch::Sm80, T, Tkv, 2, kHeadDim>::Kernel>(params);
+                typename DecodingConfig<arch::Sm80, T, Tkv, 5, kHeadDim>::Kernel>(params);
+        }
+        else if (query_group_sz % 4 == 0) {
+            return invokeDecoding<
+                typename DecodingConfig<arch::Sm80, T, Tkv, 4, kHeadDim>::Kernel>(params);
         }
         else {
             return invokeDecoding<
diff --git a/src/turbomind/kernels/attention/decoding_128_bf16_sm80.cu b/src/turbomind/kernels/attention/decoding_128_bf16_sm80.cu
index 70ceb21ba3..f657116b76 100644
--- a/src/turbomind/kernels/attention/decoding_128_bf16_sm80.cu
+++ b/src/turbomind/kernels/attention/decoding_128_bf16_sm80.cu
@@ -10,12 +10,15 @@ using namespace attention;
 
 using sm80_bf16_bf16_g1_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 1, 128>;
 template void invokeDecoding<sm80_bf16_bf16_g1_d128>(const typename sm80_bf16_bf16_g1_d128::ParamType& params);
 
-using sm80_bf16_f16_g2_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 2, 128>;
-template void invokeDecoding<sm80_bf16_f16_g2_d128>(const typename sm80_bf16_f16_g2_d128::ParamType& params);
-
 using sm80_bf16_bf16_g4_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 4, 128>;
 template void invokeDecoding<sm80_bf16_bf16_g4_d128>(const typename sm80_bf16_bf16_g4_d128::ParamType& params);
 
+using sm80_bf16_bf16_g5_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 5, 128>;
+template void invokeDecoding<sm80_bf16_bf16_g5_d128>(const typename sm80_bf16_bf16_g5_d128::ParamType& params);
+
+using sm80_bf16_bf16_g6_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 6, 128>;
+template void invokeDecoding<sm80_bf16_bf16_g6_d128>(const typename sm80_bf16_bf16_g6_d128::ParamType& params);
+
 using sm80_bf16_bf16_g8_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 8, 128>;
 template void invokeDecoding<sm80_bf16_bf16_g8_d128>(const typename sm80_bf16_bf16_g8_d128::ParamType& params);
diff --git a/src/turbomind/kernels/attention/decoding_128_f16_sm80.cu b/src/turbomind/kernels/attention/decoding_128_f16_sm80.cu
index aca02ab302..e304151779 100644
--- a/src/turbomind/kernels/attention/decoding_128_f16_sm80.cu
+++ b/src/turbomind/kernels/attention/decoding_128_f16_sm80.cu
@@ -10,12 +10,15 @@ using namespace attention;
 
 using sm80_f16_f16_g1_d128 = Decoding<arch::Sm80, half, half, 1, 128>;
 template void invokeDecoding<sm80_f16_f16_g1_d128>(const typename sm80_f16_f16_g1_d128::ParamType& params);
 
-using sm80_f16_f16_g2_d128 = Decoding<arch::Sm80, half, half, 2, 128>;
-template void invokeDecoding<sm80_f16_f16_g2_d128>(const typename sm80_f16_f16_g2_d128::ParamType& params);
-
 using sm80_f16_f16_g4_d128 = Decoding<arch::Sm80, half, half, 4, 128>;
 template void invokeDecoding<sm80_f16_f16_g4_d128>(const typename sm80_f16_f16_g4_d128::ParamType& params);
 
+using sm80_f16_f16_g5_d128 = Decoding<arch::Sm80, half, half, 5, 128>;
+template void invokeDecoding<sm80_f16_f16_g5_d128>(const typename sm80_f16_f16_g5_d128::ParamType& params);
+
+using sm80_f16_f16_g6_d128 = Decoding<arch::Sm80, half, half, 6, 128>;
+template void invokeDecoding<sm80_f16_f16_g6_d128>(const typename sm80_f16_f16_g6_d128::ParamType& params);
+
 using sm80_f16_f16_g8_d128 = Decoding<arch::Sm80, half, half, 8, 128>;
 template void invokeDecoding<sm80_f16_f16_g8_d128>(const typename sm80_f16_f16_g8_d128::ParamType& params);
diff --git a/src/turbomind/kernels/attention/decoding_config.h b/src/turbomind/kernels/attention/decoding_config.h
index 3197a71825..1a408d7c94 100644
--- a/src/turbomind/kernels/attention/decoding_config.h
+++ b/src/turbomind/kernels/attention/decoding_config.h
@@ -14,7 +14,7 @@
 
 namespace turbomind::attention {
 
-template<class Arch, class T, class Tkv, int Qh, int HeadDim>
+template<class Arch, class T, class Tkv, int Qh, int HeadDim, class SFINAE = void>
 struct DecodingConfig {
     static_assert(sizeof(T) == 0, "config not found");
 };
@@ -22,16 +22,16 @@ struct DecodingConfig {
 template<class Arch, class T, class Tkv, int Qh, int HeadDim>
 using Decoding = typename DecodingConfig<Arch, T, Tkv, Qh, HeadDim>::Kernel;
 
-template<class T, class Tkv, int HeadDim>
-struct DecodingConfig<arch::Sm80, T, Tkv, 2, HeadDim> {
-    using Attention = Impl<MMA_SIMT, T, Tkv, 2, 1, 64, HeadDim>;
+template<class T, class Tkv, int Qh, int HeadDim>
+struct DecodingConfig<arch::Sm80, T, Tkv, Qh, HeadDim, std::enable_if_t<(Qh <= 2)>> {
+    using Attention = Impl<MMA_SIMT, T, Tkv, Qh, 1, 64, HeadDim>;
     using Mainloop  = Mainloop<Sm80_CpAsync<2>, Attention>;
     using Kernel    = AttentionUniversal<Mainloop, DecodingCtaMap>;
 };
 
-template<class T, class Tkv, int HeadDim>
-struct DecodingConfig<arch::Sm80, T, Tkv, 4, HeadDim> {
-    using Attention = Impl<MMA_81616, T, Tkv, 4, 1, 64, HeadDim>;
+template<class T, class Tkv, int Qh, int HeadDim>
+struct DecodingConfig<arch::Sm80, T, Tkv, Qh, HeadDim, std::enable_if_t<(Qh > 2)>> {
+    using Attention = Impl<MMA_81616, T, Tkv, Qh, 1, 64, HeadDim>;
     using Mainloop  = Mainloop<Sm80_CpAsync<2>, Attention>;
     using Kernel    = AttentionUniversal<Mainloop, DecodingCtaMap>;
 };
@@ -57,32 +57,4 @@ struct DecodingConfig {
     using Kernel    = AttentionUniversal<Mainloop, DecodingCtaMap>;
 };
 
-// template<class T, class Tkv, int HeadDim>
-// struct DecodingConfig<arch::Sm80, T, Tkv, 8, HeadDim> {
-//     using Attention = Impl<MMA_81616, T, Tkv, 8, 1, 64, HeadDim>;
-//     using Mainloop  = Mainloop<Sm80_CpAsync<2>, Attention>;
-//     using Kernel    = AttentionUniversal<Mainloop, DecodingCtaMap>;
-// };
-
-// template<class T, class Tkv, int HeadDim>
-// struct DecodingConfig<arch::Sm80, T, Tkv, 16, HeadDim> {
-//     using Attention = Impl<MMA_81616, T, Tkv, 16, 1, 64, HeadDim>;
-//     using Mainloop  = Mainloop<Sm80_CpAsync<2>, Attention>;
-//     using Kernel    = AttentionUniversal<Mainloop, DecodingCtaMap>;
-// };
-
-// template<class T, class Tkv, int HeadDim>
-// struct DecodingConfig<arch::Sm70, T, Tkv, 1, HeadDim> {
-//     using Attention = Impl<MMA_SIMT, T, Tkv, 1, 1, 64, HeadDim>;
-//     using Mainloop  = Mainloop<arch::Sm70, Attention>;
-//     using Kernel    = AttentionUniversal<Mainloop, DecodingCtaMap>;
-// };
-
-// template<class T, class Tkv, int HeadDim>
-// struct DecodingConfig<arch::Sm70, T, Tkv, 2, HeadDim> {
-//     using Attention = Impl<MMA_SIMT, T, Tkv, 2, 1, 64, HeadDim>;
-//     using Mainloop  = Mainloop<arch::Sm70, Attention>;
-//     using Kernel    = AttentionUniversal<Mainloop, DecodingCtaMap>;
-// };
-
 } // namespace turbomind::attention
diff --git a/src/turbomind/kernels/attention/decoding_sm80.h b/src/turbomind/kernels/attention/decoding_sm80.h
index 39b7d69d2e..844407b4bb 100644
--- a/src/turbomind/kernels/attention/decoding_sm80.h
+++ b/src/turbomind/kernels/attention/decoding_sm80.h
@@ -98,7 +98,6 @@ struct Impl
 
-    static_assert(CTA_H % OP_N == 0);
-
-    using SmemLayoutQ = SmemLayoutV2<CTA_H, kHeadDim>;
-
+    static constexpr int CTA_H1 = (CTA_H + OP_N - 1) / OP_N * OP_N;
+
+    using SmemLayoutQ = SmemLayoutV2<CTA_H1, kHeadDim>;
     using SmemLayoutK = SmemLayoutV2<CTA_S, kHeadDim>;
     using SmemLayoutV = SmemLayoutV2<CTA_S, kHeadDim>;
@@ -169,12 +170,12 @@ struct Impl
 
     static constexpr bool kUseSmemQ = false;
    static constexpr bool kUseSmemP = false;
 
     static constexpr int kWarpCntH = 1;
     static constexpr int kWarpCntS = kWarpCount;
 
-    using ThreadMapQ = RakedThreadMap<kHeadDim, CTA_H, 8, kWarpCount>;
+    using ThreadMapQ = RakedThreadMap<kHeadDim, CTA_H1, 8, kWarpCount>;
 
     using ThreadMapKV = RakedThreadMap<kHeadDim, CTA_S, 8, kWarpCount>;
 
     static constexpr int kBatchK = ThreadMapKV::kIterS;
@@ -565,7 +566,7 @@ struct Impl
         const int warp_id = threadIdx.x / WARP_SIZE;
         const int lane_id = threadIdx.x % WARP_SIZE;
 
-        using Map = RakedThreadMap<kHeadDim, CTA_H, 8, kWarpCount>;
+        using Map = RakedThreadMap<kHeadDim, CTA_H1, 8, kWarpCount>;
         Array<T, 8> tmp_O[Map::kIterS][Map::kIterC];
         const int2 offset = Map::get_offset(warp_id, lane_id);
         PRAGMA_UNROLL
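
For reference, here is a host-side sketch of how the padding and masking in this patch fit together. It is not part of the patch: the OP_N and CTA_H values are illustrative assumptions, CTA_H1 mirrors the rounding added in decoding_sm80.h, and check_h is copied from the AttentionUniversal change.

#include <cstdio>

// Illustrative stand-ins for the kernel's compile-time constants.
constexpr int OP_N  = 8; // assumed MMA tile extent along the head dimension
constexpr int CTA_H = 6; // heads per CTA, e.g. when query_group_sz % 6 == 0

// Same rounding as the new CTA_H1: pad the head tile up to a multiple of
// OP_N so the tensor-core MMA shape is always fully populated.
constexpr int CTA_H1 = (CTA_H + OP_N - 1) / OP_N * OP_N;

// Same predicate as AttentionUniversal::check_h: head indices in the padded
// range [CTA_H, CTA_H1) hold garbage and must not be written back.
bool check_h(int hi)
{
    return CTA_H == 1 || hi < CTA_H;
}

int main()
{
    printf("CTA_H = %d, padded CTA_H1 = %d\n", CTA_H, CTA_H1);
    for (int hi = 0; hi < CTA_H1; ++hi) {
        printf("  hi = %d -> %s\n", hi, check_h(hi) ? "store" : "skip (padding)");
    }
    return 0;
}

Padding keeps the MMA shape legal for group sizes such as 5 and 6 at the cost of a few idle fragment lanes; check_h is what keeps those padded lanes from leaking into params.out, params.qk, and the split-k partials.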