tensor core GQA dispatch for [4,5,6,8] (#1258)
lzhangzz authored Mar 11, 2024
1 parent 60d9bfd commit 331858b
Showing 6 changed files with 49 additions and 57 deletions.
14 changes: 10 additions & 4 deletions src/turbomind/kernels/attention/attention_universal.h
@@ -436,6 +436,12 @@ struct AttentionUniversal {
}
}

+ __device__ bool check_h(int hi)
+ {
+ /// FIXME: for tensor core decoding `CTA_H == 1` fails (currently CTA_H > 2 are used for TC)
+ return CTA_H == 1 || hi < CTA_H;
+ }
+
__device__ void StoreO(FragO& frag_O,
FragL& frag_L,
int qi_begin,
@@ -445,7 +451,7 @@ struct AttentionUniversal {
SharedStorage& storage)
{
Impl::StoreO<true>(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {
- if (qi_begin + qi < qi_end) {
+ if (qi_begin + qi < qi_end && check_h(hi)) {
const int offset = (qi_begin + qi) * params.num_heads * kHeadDim + (head_idx + hi) * kHeadDim + di;
Store(&params.out[offset], cast<T>(vec));
}
@@ -462,7 +468,7 @@ struct AttentionUniversal {
Impl::ForeachS(frag_S, [&](int hi, int qi, int si, int ri, float score) {
qi += query_idx;
si += offset_K;
- if (qi < params.max_q_len && si < max_context_len) {
+ if (qi < params.max_q_len && si < max_context_len && check_h(hi)) {
params.qk[batch_idx * params.num_heads * params.max_q_len * max_context_len
+ (head_idx + hi) * params.max_q_len * max_context_len + qi * max_context_len + si] =
score;
@@ -488,14 +494,14 @@ struct AttentionUniversal {
};

Impl::StoreO<false>(frag_O, frag_L, storage, [&](int hi, int qi, int di, const auto& vec) {
- if (qi_begin + qi < qi_end) {
+ if (qi_begin + qi < qi_end && check_h(hi)) {
Store(&params.partial_O[get_index(hi, qi) * kHeadDim + di], vec);
}
});

Impl::ForeachML(frag_M, frag_L, [&](int hi, int qi, int ri, float M, float L) {
const int index = get_index(hi, qi);
- if (qi_begin + qi < qi_end && ri == 0) {
+ if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) {
// printf("ML %2d %2d %f %f\n", split_idx, head_idx + hi, M, L);
params.partial_M[index] = M;
params.partial_L[index] = L;
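For context: check_h() is needed because the tensor-core decoding path pads the head tile from CTA_H real heads up to CTA_H1 (see decoding_sm80.h below), so the store callbacks can be invoked for padded rows that must not be written. A standalone host-side illustration with assumed example values, not library code:

#include <cstdio>

// Stand-in for the guard added above; CTA_H == 1 is the SIMT path, which has no padding.
bool check_h(int hi, int CTA_H)
{
    return CTA_H == 1 || hi < CTA_H;
}

int main()
{
    const int CTA_H = 5, CTA_H1 = 8; // e.g. a GQA group of 5 padded to an assumed 8-wide MMA tile
    for (int hi = 0; hi < CTA_H1; ++hi)
        std::printf("hi=%d -> %s\n", hi, check_h(hi, CTA_H) ? "store" : "skip (padding)");
}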
21 changes: 14 additions & 7 deletions src/turbomind/kernels/attention/decoding.cu
@@ -68,12 +68,15 @@ void dispatchDecoding(const AttentionParams<T>& params)
else if (query_group_sz % 8 == 0) {
return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 8, kHeadDim>::Kernel>(params);
}
+ else if (query_group_sz % 6 == 0) {
+ return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 6, kHeadDim>::Kernel>(params);
+ }
+ else if (query_group_sz % 5 == 0) {
+ return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 5, kHeadDim>::Kernel>(params);
+ }
else if (query_group_sz % 4 == 0) {
return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 4, kHeadDim>::Kernel>(params);
}
- else if (query_group_sz % 2 == 0) {
- return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 2, kHeadDim>::Kernel>(params);
- }
else {
return invokeDecoding<typename DecodingConfig<arch::Sm80, T, T, 1, kHeadDim>::Kernel>(params);
}
@@ -128,13 +131,17 @@ void dispatchDecoding(const AttentionParams<nv_bfloat16>& params)
return invokeDecoding<
typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 8, kHeadDim>::Kernel>(params);
}
- else if (query_group_sz % 4 == 0) {
+ else if (query_group_sz % 6 == 0) {
return invokeDecoding<
- typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 4, kHeadDim>::Kernel>(params);
+ typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 6, kHeadDim>::Kernel>(params);
}
- else if (query_group_sz % 2 == 0) {
+ else if (query_group_sz % 5 == 0) {
return invokeDecoding<
- typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 2, kHeadDim>::Kernel>(params);
+ typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 5, kHeadDim>::Kernel>(params);
}
+ else if (query_group_sz % 4 == 0) {
+ return invokeDecoding<
+ typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 4, kHeadDim>::Kernel>(params);
+ }
else {
return invokeDecoding<
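For reference, the dispatch above can be summarized by a small standalone sketch (illustrative only, not part of the library): the largest supported group size that divides query_group_sz is chosen, and everything else falls back to the Qh = 1 SIMT kernel.

#include <cstdio>
#include <initializer_list>

// Mirrors the if/else chain in dispatchDecoding(); 8, 6, 5 and 4 are the tensor-core group sizes.
int select_qh(int query_group_sz)
{
    for (int qh : {8, 6, 5, 4}) {
        if (query_group_sz % qh == 0)
            return qh;
    }
    return 1; // SIMT fallback
}

int main()
{
    for (int g : {1, 2, 3, 4, 5, 6, 7, 8, 12, 16})
        std::printf("query_group_sz=%2d -> Qh=%d\n", g, select_qh(g));
}

So a GQA ratio of 7, for example, still decodes with the Qh = 1 kernel after this change.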
9 changes: 6 additions & 3 deletions src/turbomind/kernels/attention/decoding_128_bf16_sm80.cu
@@ -10,12 +10,15 @@ using namespace attention;
using sm80_bf16_bf16_g1_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 1, 128>;
template void invokeDecoding<sm80_bf16_bf16_g1_d128>(const typename sm80_bf16_bf16_g1_d128::ParamType& params);

- using sm80_bf16_f16_g2_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 2, 128>;
- template void invokeDecoding<sm80_bf16_f16_g2_d128>(const typename sm80_bf16_f16_g2_d128::ParamType& params);
-
using sm80_bf16_bf16_g4_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 4, 128>;
template void invokeDecoding<sm80_bf16_bf16_g4_d128>(const typename sm80_bf16_bf16_g4_d128::ParamType& params);

+ using sm80_bf16_bf16_g5_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 5, 128>;
+ template void invokeDecoding<sm80_bf16_bf16_g5_d128>(const typename sm80_bf16_bf16_g5_d128::ParamType& params);
+
+ using sm80_bf16_bf16_g6_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 6, 128>;
+ template void invokeDecoding<sm80_bf16_bf16_g6_d128>(const typename sm80_bf16_bf16_g6_d128::ParamType& params);
+
using sm80_bf16_bf16_g8_d128 = Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 8, 128>;
template void invokeDecoding<sm80_bf16_bf16_g8_d128>(const typename sm80_bf16_bf16_g8_d128::ParamType& params);

9 changes: 6 additions & 3 deletions src/turbomind/kernels/attention/decoding_128_f16_sm80.cu
@@ -10,12 +10,15 @@ using namespace attention;
using sm80_f16_f16_g1_d128 = Decoding<arch::Sm80, half, half, 1, 128>;
template void invokeDecoding<sm80_f16_f16_g1_d128>(const typename sm80_f16_f16_g1_d128::ParamType& params);

- using sm80_f16_f16_g2_d128 = Decoding<arch::Sm80, half, half, 2, 128>;
- template void invokeDecoding<sm80_f16_f16_g2_d128>(const typename sm80_f16_f16_g2_d128::ParamType& params);
-
using sm80_f16_f16_g4_d128 = Decoding<arch::Sm80, half, half, 4, 128>;
template void invokeDecoding<sm80_f16_f16_g4_d128>(const typename sm80_f16_f16_g4_d128::ParamType& params);

+ using sm80_f16_f16_g5_d128 = Decoding<arch::Sm80, half, half, 5, 128>;
+ template void invokeDecoding<sm80_f16_f16_g5_d128>(const typename sm80_f16_f16_g5_d128::ParamType& params);
+
+ using sm80_f16_f16_g6_d128 = Decoding<arch::Sm80, half, half, 6, 128>;
+ template void invokeDecoding<sm80_f16_f16_g6_d128>(const typename sm80_f16_f16_g6_d128::ParamType& params);
+
using sm80_f16_f16_g8_d128 = Decoding<arch::Sm80, half, half, 8, 128>;
template void invokeDecoding<sm80_f16_f16_g8_d128>(const typename sm80_f16_f16_g8_d128::ParamType& params);

42 changes: 7 additions & 35 deletions src/turbomind/kernels/attention/decoding_config.h
@@ -14,24 +14,24 @@

namespace turbomind::attention {

- template<class Arch, class T, class Tkv, int Qh, int HeadDim>
+ template<class Arch, class T, class Tkv, int Qh, int HeadDim, class SFINAE = void>
struct DecodingConfig {
static_assert(sizeof(T) == 0, "config not found");
};

template<class Arch, class T, class Tkv, int Qh, int HeadDim>
using Decoding = typename DecodingConfig<Arch, T, Tkv, Qh, HeadDim>::Kernel;

- template<class T, class Tkv, int Qh, int HeadDim>
- struct DecodingConfig<arch::Sm80, T, Tkv, Qh, HeadDim> {
- using Attention = Impl<Sm70_Simt, T, Tkv, Qh, 1, 64, Qh, 1, 16, HeadDim, 3>;
+ template<class T, int Qh, int HeadDim>
+ struct DecodingConfig<arch::Sm80, T, T, Qh, HeadDim, std::enable_if_t<(Qh <= 2)>> {
+ using Attention = Impl<Sm70_Simt, T, T, Qh, 1, 64, Qh, 1, 16, HeadDim, 3>;
using Mainloop = Mainloop<Sm80_CpAsync<3>, Attention>;
using Kernel = AttentionUniversal<Mainloop, int, DecodingCtaMap>;
};

- template<class T, class Tkv, int HeadDim>
- struct DecodingConfig<arch::Sm80, T, Tkv, 8, HeadDim> {
- using Attention = Impl<Sm80_81616, T, Tkv, 8, 1, 64, 8, 1, 16, HeadDim, 3>;
+ template<class T, int Qh, int HeadDim>
+ struct DecodingConfig<arch::Sm80, T, T, Qh, HeadDim, std::enable_if_t<(Qh > 2)>> {
+ using Attention = Impl<Sm80_81616, T, T, Qh, 1, 64, Qh, 1, 16, HeadDim, 3>;
using Mainloop = Mainloop<Sm80_CpAsync<3>, Attention>;
using Kernel = AttentionUniversal<Mainloop, int, DecodingCtaMap>;
};
@@ -57,32 +57,4 @@ struct DecodingConfig<arch::Sm70, T, int8_t, Qh, HeadDim> {
using Kernel = AttentionUniversal<Mainloop, int, DecodingCtaMap>;
};

- // template<class T, class Tkv, class BlockSeqLen, int HeadDim>
- // struct DecodingConfig<arch::Sm80, T, Tkv, BlockSeqLen, HeadDim> {
- // using Attention = Impl<Sm70_Simt, T, Tkv, 1, 64, 1, 8, HeadDim>;
- // using Mainloop = Mainloop<Sm80_CpAsync<7>, Attention>;
- // using Kernel = AttentionUniversal<Mainloop, BlockSeqLen, DecodingCtaMap>;
- // };
-
- // template<class T, class Tkv, class BlockSeqLen, int HeadDim>
- // struct DecodingConfig<arch::Sm80, T, Tkv, BlockSeqLen, HeadDim> {
- // using Attention = Impl<Sm70_Simt, T, Tkv, 1, 128, 1, 16, HeadDim>;
- // using Mainloop = Mainloop<Sm80_CpAsync<5>, Attention>;
- // using Kernel = AttentionUniversal<Mainloop, BlockSeqLen, DecodingCtaMap>;
- // };
-
- // template<class T, class Tkv, class BlockSeqLen, int HeadDim>
- // struct DecodingConfig<arch::Sm70, T, Tkv, BlockSeqLen, HeadDim> {
- // using Attention = Impl<Sm70_Simt, T, Tkv, 1, 1, 64, 1, 1, 16, HeadDim, 1>;
- // using Mainloop = Mainloop<arch::Sm70, Attention>;
- // using Kernel = AttentionUniversal<Mainloop, BlockSeqLen, DecodingCtaMap>;
- // };
-
- // template<class T, class Tkv, class BlockSeqLen, int HeadDim>
- // struct DecodingConfig<arch::Sm70, T, Tkv, BlockSeqLen, HeadDim> {
- // using Attention = Impl<Sm70_Simt, T, Tkv, 1, 128, 1, 8, HeadDim, 1>;
- // using Mainloop = Mainloop<arch::Sm70, Attention>;
- // using Kernel = AttentionUniversal<Mainloop, BlockSeqLen, DecodingCtaMap>;
- // };
-
} // namespace turbomind::attention
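The trailing SFINAE parameter added to DecodingConfig is what lets one pair of partial specializations cover every group size: Qh <= 2 keeps the SIMT implementation (Sm70_Simt), Qh > 2 selects the Sm80_81616 tensor-core implementation. A minimal standalone sketch of the pattern, with illustrative names rather than the library's types (assumes C++17):

#include <cstdio>
#include <type_traits>

template<int Qh, class SFINAE = void>
struct Config {
    static constexpr const char* path = "config not found";
};

// Chosen when the predicate is well-formed, i.e. Qh <= 2.
template<int Qh>
struct Config<Qh, std::enable_if_t<(Qh <= 2)>> {
    static constexpr const char* path = "SIMT decoding";
};

// Chosen for Qh > 2; the two predicates are disjoint, so there is no ambiguity.
template<int Qh>
struct Config<Qh, std::enable_if_t<(Qh > 2)>> {
    static constexpr const char* path = "tensor-core decoding";
};

int main()
{
    std::printf("Qh=1 -> %s\n", Config<1>::path);
    std::printf("Qh=5 -> %s\n", Config<5>::path);
    std::printf("Qh=8 -> %s\n", Config<8>::path);
}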
11 changes: 6 additions & 5 deletions src/turbomind/kernels/attention/decoding_sm80.h
@@ -98,7 +98,6 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
static constexpr int CTA_Q = CTA_Q_;
static constexpr int CTA_S = CTA_S_;

- static_assert(CTA_H == 8);
static_assert(CTA_Q == 1);

static constexpr int WARP_H = WARP_H_;
@@ -145,7 +144,9 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
static constexpr bool kUseSmemQ = false;
static constexpr bool kUseSmemP = false;

- using SmemLayoutQ = SmemLayoutV2<CTA_H, HeadDim, CTA_H, HeadDim, Swizzle<3, 3, 4>>;
+ static constexpr int CTA_H1 = (CTA_H + OP_N - 1) / OP_N * OP_N;
+
+ using SmemLayoutQ = SmemLayoutV2<CTA_H1, HeadDim, CTA_H1, HeadDim, Swizzle<3, 3, 4>>;
using SmemLayoutK = SmemLayoutV2<CTA_S, HeadDim, 16, 64, Swizzle<3, 3, 3>>;
using SmemLayoutV = SmemLayoutV2<CTA_S, HeadDim, 16, 64, Swizzle<3, 3, 3>>;

@@ -169,12 +170,12 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
__align__(16) SmemO O;
};

- __align__(16) float O1[CTA_H][kHeadDim];
+ __align__(16) float O1[CTA_H1][kHeadDim];

T P[1];
};

- using ThreadMapQ = RakedThreadMap<HeadDim, CTA_H, 8, kWarpCount>;
+ using ThreadMapQ = RakedThreadMap<HeadDim, CTA_H1, 8, kWarpCount>;
using ThreadMapKV = RakedThreadMap<HeadDim, CTA_S, 8, kWarpCount>;

static constexpr int kBatchK = ThreadMapKV::kIterS;
@@ -565,7 +566,7 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_

__syncthreads();

- using Map = RakedThreadMap<kHeadDim, CTA_H, 4, kWarpCount>;
+ using Map = RakedThreadMap<kHeadDim, CTA_H1, 4, kWarpCount>;
Array<float, 4> tmp_O[Map::kIterS][Map::kIterC];
const int2 offset = Map::get_offset(warp_id, lane_id);
PRAGMA_UNROLL
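The CTA_H1 rounding introduced above pads the head extent of the Q tile up to the next multiple of OP_N, so that group sizes 4, 5 and 6 fill whole MMA tiles just as 8 does. A quick worked illustration; OP_N = 8 is an assumed value for the example, the real one comes from the Sm80_81616 MMA configuration:

#include <cstdio>
#include <initializer_list>

constexpr int round_up(int x, int step)
{
    return (x + step - 1) / step * step; // same arithmetic as CTA_H1 above
}

int main()
{
    constexpr int OP_N = 8; // assumed for illustration
    for (int cta_h : {1, 2, 4, 5, 6, 8})
        std::printf("CTA_H=%d -> CTA_H1=%d\n", cta_h, round_up(cta_h, OP_N));
}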
