Commit

fix lint
lzhangzz committed Feb 28, 2024
1 parent 1fde87a commit dcad091
Showing 12 changed files with 65 additions and 53 deletions.
19 changes: 12 additions & 7 deletions src/turbomind/kernels/attention/decoding.cu
@@ -95,7 +95,6 @@ void dispatchDecoding(const AttentionParams<T>& params)
FT_CHECK(0);
}


template<>
void dispatchDecoding(const AttentionParams<nv_bfloat16>& params)
{
@@ -113,27 +112,33 @@ void dispatchDecoding(const AttentionParams<nv_bfloat16>& params)
if (params.arch >= 80) {
if (0) {}
else if (query_group_sz % 2 == 0) {
- return invokeDecoding<typename DecodingConfig<arch::Sm80, nv_bfloat16, int8_t, 2, kHeadDim>::Kernel>(params);
+ return invokeDecoding<typename DecodingConfig<arch::Sm80, nv_bfloat16, int8_t, 2, kHeadDim>::Kernel>(
+     params);
}
else {
- return invokeDecoding<typename DecodingConfig<arch::Sm80, nv_bfloat16, int8_t, 1, kHeadDim>::Kernel>(params);
+ return invokeDecoding<typename DecodingConfig<arch::Sm80, nv_bfloat16, int8_t, 1, kHeadDim>::Kernel>(
+     params);
}
}
}
else {
if (params.arch >= 80) {
if (0) {}
else if (query_group_sz % 8 == 0) {
- return invokeDecoding<typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 8, kHeadDim>::Kernel>(params);
+ return invokeDecoding<
+     typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 8, kHeadDim>::Kernel>(params);
}
else if (query_group_sz % 4 == 0) {
- return invokeDecoding<typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 4, kHeadDim>::Kernel>(params);
+ return invokeDecoding<
+     typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 4, kHeadDim>::Kernel>(params);
}
else if (query_group_sz % 2 == 0) {
- return invokeDecoding<typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 2, kHeadDim>::Kernel>(params);
+ return invokeDecoding<
+     typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 2, kHeadDim>::Kernel>(params);
}
else {
- return invokeDecoding<typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 1, kHeadDim>::Kernel>(params);
+ return invokeDecoding<
+     typename DecodingConfig<arch::Sm80, nv_bfloat16, nv_bfloat16, 1, kHeadDim>::Kernel>(params);
}
}
}
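Note: the hunk above only re-wraps these dispatch calls, but the pattern they implement — choosing a compile-time kernel specialization from a runtime query-group size — is worth seeing on its own. Below is a minimal self-contained sketch of that pattern, with hypothetical names standing in for DecodingConfig and invokeDecoding (it is not the repository's actual API).

#include <cstdio>

// Hypothetical stand-in for DecodingConfig<...>::Kernel: the group size is baked
// into the type as a compile-time parameter.
template<int GroupSize>
struct KernelConfig {
    static void run() { std::printf("kernel specialized for group size %d\n", GroupSize); }
};

// Hypothetical stand-in for invokeDecoding<Kernel>(params).
template<class Kernel>
void invoke() { Kernel::run(); }

// Mirrors the `query_group_sz % N == 0` ladder in the diff: pick the largest
// specialization whose group size divides the runtime value.
void dispatch(int query_group_sz) {
    if (query_group_sz % 8 == 0) { return invoke<KernelConfig<8>>(); }
    else if (query_group_sz % 4 == 0) { return invoke<KernelConfig<4>>(); }
    else if (query_group_sz % 2 == 0) { return invoke<KernelConfig<2>>(); }
    else { return invoke<KernelConfig<1>>(); }
}

int main() { dispatch(6); }  // selects the GroupSize == 2 kernel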
18 changes: 9 additions & 9 deletions src/turbomind/kernels/attention/decoding_simt.h
@@ -208,7 +208,7 @@ struct Impl<Sm70_Simt, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S
for (int m = 0; m < K_M; ++m) { // Q
const int hi = m * OP_H;
const int ri = threadIdx.x;
- ((Func&&)func)(hi, 0, ri, frag_M[m][0], frag_L[m][0]);
+ ((Func &&) func)(hi, 0, ri, frag_M[m][0], frag_L[m][0]);
}
}

@@ -226,7 +226,7 @@ struct Impl<Sm70_Simt, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S
const int hi = m * OP_H + warp_id_h * WARP_H;
const int si = n * OP_S + lane_id / 8 + warp_id_s * WARP_S;
const int ri = lane_id % 8;
- ((Func&&)func)(hi, /*qi*/ 0, si, ri, S[m][n][0]);
+ ((Func &&) func)(hi, /*qi*/ 0, si, ri, S[m][n][0]);
}
}
}
@@ -272,7 +272,7 @@ struct Impl<Sm70_Simt, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S
smem_K.Load(frag_K[k + 1], k + 1, offset);
}
else {
- ((Preload&&)preload)();
+ ((Preload &&) preload)();
}

PRAGMA_UNROLL
@@ -291,10 +291,10 @@ struct Impl<Sm70_Simt, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S
}
}
if (k < K_K - 1) {
- ((Prefetch&&)prefetch)(k);
+ ((Prefetch &&) prefetch)(k);
}
if (k == K_K - 2) {
- ((Prefetch&&)prefetch)(K_K - 1);
+ ((Prefetch &&) prefetch)(K_K - 1);
}
}

@@ -328,7 +328,7 @@ struct Impl<Sm70_Simt, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S
smem_V.Load(frag_V[k + 1], k + 1, offset);
}
else {
- ((Preload&&)preload)();
+ ((Preload &&) preload)();
}

PRAGMA_UNROLL
@@ -347,10 +347,10 @@ struct Impl<Sm70_Simt, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S
}
}
if (k < V_K - 1) {
- ((Prefetch&&)prefetch)(k);
+ ((Prefetch &&) prefetch)(k);
}
if (k == V_K - 2) {
- ((Prefetch&&)prefetch)(V_K - 1);
+ ((Prefetch &&) prefetch)(V_K - 1);
}
}
}
@@ -552,7 +552,7 @@ struct Impl<Sm70_Simt, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S
// for (int i = 0; i < 8; ++i) {
// printf("O %4d %4d %f\n", hi + blockIdx.x * CTA_H, di + i, frag_O[m][n][i]);
// }
- ((Func&&)func)(hi, 0, di, frag_O[m][n]);
+ ((Func &&) func)(hi, 0, di, frag_O[m][n]);
}
}
}
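Every change in this file is whitespace inside casts such as ((Func&&)func). For context, that cast is a C-style spelling of perfect forwarding: with Func deduced as a forwarding reference, (Func&&)func yields the same value category as std::forward<Func>(func). A small illustrative sketch follows (assumed example code, not taken from the repository).

#include <cstdio>
#include <utility>

// `func` and `args` are forwarding references, so the C-style casts below
// forward them exactly like std::forward would.
template<class Func, class... Args>
void call(Func&& func, Args&&... args)
{
    ((Func&&) func)(((Args&&) args)...);  // cast-style forwarding, as in the kernels
    // Equivalent standard-library spelling:
    // std::forward<Func>(func)(std::forward<Args>(args)...);
}

int main()
{
    call([](int x) { std::printf("%d\n", x); }, 42);  // prints 42
}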
18 changes: 9 additions & 9 deletions src/turbomind/kernels/attention/decoding_sm80.h
@@ -202,7 +202,7 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
for (int q = 0; q < 2; ++q) {
const int si = m * OP_M + lane_id / 4 * 1 + s * 8 + warp_id * WARP_S;
const int hi = n * OP_N + lane_id % 4 * 2 + q * 1;
- ((Func&&)func)(hi, /*qi*/ 0, si, /*ri*/ 0, S[m][n][s * 2 + q]);
+ ((Func &&) func)(hi, /*qi*/ 0, si, /*ri*/ 0, S[m][n][s * 2 + q]);
}
}
}
@@ -219,7 +219,7 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
for (int q = 0; q < 2; ++q) {
const int hi = lane_id % 4 * 2 + n * OP_N + q * 1;
const int ri = lane_id / 4 * 1;
- ((Func&&)func)(hi, /*qi*/ 0, ri, frag_M[n][q], frag_L[n][q]);
+ ((Func &&) func)(hi, /*qi*/ 0, ri, frag_M[n][q], frag_L[n][q]);
}
}
}
@@ -257,7 +257,7 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
smem_K.Load(frag_K[k + 1], k + 1, offset);
}
else {
- ((Preload&&)preload)();
+ ((Preload &&) preload)();
}
PRAGMA_UNROLL
for (int m = 0; m < K_M; ++m) {
Expand All @@ -267,10 +267,10 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
}
}
if (k < K_K - 1) {
- ((Prefetch&&)prefetch)(k);
+ ((Prefetch &&) prefetch)(k);
}
if (k == K_K - 2) {
- ((Prefetch&&)prefetch)(K_K - 1);
+ ((Prefetch &&) prefetch)(K_K - 1);
}
}
}
@@ -292,7 +292,7 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
smem_V.Load(frag_V[m + 1], m + 1, offset);
}
else {
- ((Preload&&)preload)();
+ ((Preload &&) preload)();
}
PRAGMA_UNROLL
for (int k = 0; k < V_K; ++k) {
Expand All @@ -302,10 +302,10 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
}
}
if (m < V_M - 1) {
- ((Prefetch&&)prefetch)(m);
+ ((Prefetch &&) prefetch)(m);
}
if (m == V_M - 2) {
- ((Prefetch&&)prefetch)(V_M - 1);
+ ((Prefetch &&) prefetch)(V_M - 1);
}
}
}
@@ -575,7 +575,7 @@ struct Impl<Sm80_81616, T_, Tkv_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_
const int hi = offset.y + s * Map::kDeltaS;
const int di = offset.x + c * Map::kDeltaC;
Load(tmp_O[s][c], &storage.O1[hi][di]);
- ((Func&&)func)(hi, 0, di, tmp_O[s][c]);
+ ((Func &&) func)(hi, 0, di, tmp_O[s][c]);
}
}
}
18 changes: 12 additions & 6 deletions src/turbomind/kernels/attention/impl.h
@@ -14,15 +14,20 @@ struct Arch {
}
};

- struct Sm80_16816: Arch<80> {};
+ struct Sm80_16816: Arch<80> {
+ };

- struct Sm80_81616: Arch<80> {};
+ struct Sm80_81616: Arch<80> {
+ };

- struct Sm75_1688: Arch<75, 80> {};
+ struct Sm75_1688: Arch<75, 80> {
+ };

- struct Sm70_884: Arch<70, 75> {};
+ struct Sm70_884: Arch<70, 75> {
+ };

- struct Sm70_Simt: Arch<70> {};
+ struct Sm70_Simt: Arch<70> {
+ };

template<class Tag,
class T,
@@ -35,7 +40,8 @@ template<class Tag,
int WARP_S,
int HeadDim,
int Stages = 2>
- struct Impl {};
+ struct Impl {
+ };

} // namespace attention
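The empty structs in this hunk are dispatch tags: the primary Impl template stays empty, and files such as decoding_simt.h and impl_sm80.h (see the hunks above and below) provide partial specializations per tag. A reduced sketch of that mechanism, with illustrative names only, might look like:

#include <cstdio>

// Empty tag types, standing in for Sm70_Simt, Sm80_16816, ... above.
struct TagA {};
struct TagB {};

// Primary template left deliberately empty, like `struct Impl {};`.
template<class Tag, int HeadDim>
struct Algo {};

// Each architecture contributes its own partial specialization.
template<int HeadDim>
struct Algo<TagA, HeadDim> {
    static void run() { std::printf("TagA path, HeadDim = %d\n", HeadDim); }
};

template<int HeadDim>
struct Algo<TagB, HeadDim> {
    static void run() { std::printf("TagB path, HeadDim = %d\n", HeadDim); }
};

int main()
{
    Algo<TagA, 128>::run();  // the tag selects the specialization at compile time
    Algo<TagB, 64>::run();
}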

10 changes: 5 additions & 5 deletions src/turbomind/kernels/attention/impl_sm70.h
@@ -151,7 +151,7 @@ struct Impl<Sm70_884, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S, H
using T = T_;
using Tkv = T_;

- using Arch = Sm70_884;
+ using Arch = Sm70_884;

static constexpr int CTA_H = CTA_H_;
static constexpr int CTA_Q = CTA_Q_;
@@ -297,7 +297,7 @@ struct Impl<Sm70_884, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S, H
for (int s0 = 0; s0 < 2; ++s0) {
const int qi = m * OP_M + (lane_id & 8) + (lane_id & 1) + lane_id / 16 * 4 + q * 2;
const int si = n * OP_N + (lane_id & 4) * 2 + (lane_id & 2) + s1 * 4 + s0;
- ((Func&&)func)(0, warp_id * WARP_Q + qi, si, /*ri*/ 0, S[m][n][s1 * 4 + q * 2 + s0]);
+ ((Func &&) func)(0, warp_id * WARP_Q + qi, si, /*ri*/ 0, S[m][n][s1 * 4 + q * 2 + s0]);
}
}
}
@@ -347,7 +347,7 @@ struct Impl<Sm70_884, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S, H
// }
}
else {
- ((Preload&&)preload)();
+ ((Preload &&) preload)();
}
PRAGMA_UNROLL
for (int m = 0; m < K_M; ++m) {
@@ -387,7 +387,7 @@ struct Impl<Sm70_884, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S, H
// }
}
else {
- ((Preload&&)preload)();
+ ((Preload &&) preload)();
}
PRAGMA_UNROLL
for (int m = 0; m < V_M; ++m) {
@@ -544,7 +544,7 @@ struct Impl<Sm70_884, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H_, WARP_Q, WARP_S, H
frag_O[m][n][d1 * 4 + q * 2 + d0] *= inv_L[m][q];
}
}
- ((Func&&)func)(0, qi, di, (Array<float, 2>&)frag_O[m][n][d1 * 4 + q * 2]);
+ ((Func &&) func)(0, qi, di, (Array<float, 2>&)frag_O[m][n][d1 * 4 + q * 2]);
}
}
}
12 changes: 6 additions & 6 deletions src/turbomind/kernels/attention/impl_sm80.h
@@ -224,7 +224,7 @@ struct Impl<Sm80_16816, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H, WARP_Q, WARP_S,
smem_K.Load(frag_K[k + 1], k + 1, offset);
}
else {
- ((Preload&&)preload)();
+ ((Preload &&) preload)();
}
PRAGMA_UNROLL
for (int m = 0; m < K_M; ++m) {
Expand All @@ -235,10 +235,10 @@ struct Impl<Sm80_16816, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H, WARP_Q, WARP_S,
}
}
if (k < K_K - 1) {
- ((Prefetch&&)prefetch)(k);
+ ((Prefetch &&) prefetch)(k);
}
if (k == K_K - 2) {
- ((Prefetch&&)prefetch)(K_K - 1);
+ ((Prefetch &&) prefetch)(K_K - 1);
}
}
}
@@ -260,7 +260,7 @@ struct Impl<Sm80_16816, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H, WARP_Q, WARP_S,
smem_V.Load(frag_V[k + 1], k + 1, offset);
}
else {
- ((Preload&&)preload)();
+ ((Preload &&) preload)();
}
PRAGMA_UNROLL
for (int m = 0; m < V_M; ++m) {
Expand All @@ -271,10 +271,10 @@ struct Impl<Sm80_16816, T_, T_, CTA_H_, CTA_Q_, CTA_S_, WARP_H, WARP_Q, WARP_S,
}
}
if (k < V_K - 1) {
- ((Prefetch&&)prefetch)(k);
+ ((Prefetch &&) prefetch)(k);
}
if (k == V_K - 2) {
- ((Prefetch&&)prefetch)(V_K - 1);
+ ((Prefetch &&) prefetch)(V_K - 1);
}
}
}
9 changes: 5 additions & 4 deletions src/turbomind/kernels/attention/mainloop_sm80.h
@@ -11,7 +11,8 @@
namespace turbomind::attention {

template<int Stages>
- struct Sm80_CpAsync {};
+ struct Sm80_CpAsync {
+ };

template<int Stages, class Impl_>
struct Mainloop<Sm80_CpAsync<Stages>, Impl_> {
@@ -56,17 +57,17 @@ struct Mainloop<Sm80_CpAsync<Stages>, Impl_> {
template<class... Args>
__device__ void operator()(Args&&... args)
{
- Run(Sm80_CpAsync<Stages>{}, ((Args&&)args)...);
+ Run(Sm80_CpAsync<Stages>{}, ((Args &&) args)...);
}

template<int Idx, class A, class B>
__device__ static decltype(auto) Select(A&& a, B&& b)
{
if constexpr (Idx) {
- return (B&&)b;
+ return (B &&) b;
}
else {
- return (A&&)a;
+ return (A &&) a;
}
}

3 changes: 2 additions & 1 deletion src/turbomind/kernels/attention/reference.h
@@ -16,7 +16,8 @@ void invokeApplyRotaryEmbedding(
template<class T>
class Reference {
public:
- enum Type {
+ enum Type
+ {
kUNFUSED,
kFLASH_ATTENTION
};
2 changes: 1 addition & 1 deletion src/turbomind/kernels/attention/test_attention.cu
@@ -206,7 +206,7 @@ int test_attention()
constexpr size_t kSequenceLen = 0;
constexpr int kMaxSplitK = 1;

- constexpr int kBlockSz = 128;
+ constexpr int kBlockSz = 128;

#endif

5 changes: 2 additions & 3 deletions src/turbomind/kernels/attention/test_utils.cu
@@ -56,10 +56,10 @@ template void Compare(const half* src, const half* ref, size_t stride, int m, in
template void
Compare(const float* src, const float* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
#if ENABLE_BF16
- template void Compare(const nv_bfloat16* src, const nv_bfloat16* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
+ template void
+ Compare(const nv_bfloat16* src, const nv_bfloat16* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
#endif


void LoadBinary(const std::string& path, size_t size, void* dst)
{
std::ifstream ifs(path, std::ios::binary | std::ios::in);
@@ -181,7 +181,6 @@ template void RNG::GenerateNormal(float* out, size_t count, float scale, float s
template void RNG::GenerateNormal(nv_bfloat16* out, size_t count, float scale, float shift);
#endif


template<typename T>
struct SATypeConverter {
using Type = T;
@@ -15,7 +15,8 @@
namespace turbomind {

template<typename T>
- struct ToCutlassType_ {};
+ struct ToCutlassType_ {
+ };

template<>
struct ToCutlassType_<float> {