add ntk scaling and logn scaling
lzhangzz committed Oct 30, 2023
1 parent 64de1cd commit 699b0bf
Showing 22 changed files with 287 additions and 166 deletions.
11 changes: 8 additions & 3 deletions CMakeLists.txt
@@ -198,10 +198,13 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2F

set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
# set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")

if(BUILD_FAST_MATH)
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math")
message("CMAKE_CUDA_FLAGS_RELEASE: ${CMAKE_CUDA_FLAGS_RELEASE}")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --use_fast_math")
message("Release build CUDA flags: ${CMAKE_CUDA_FLAGS_RELEASE}")
endif()

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
@@ -268,11 +271,13 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');"
OUTPUT_VARIABLE USE_CXX11_ABI)
message("-- USE_CXX11_ABI=${USE_CXX11_ABI}")
if (USE_CXX11_ABI)
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
else()
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0")
@@ -1,15 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.

add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv_cache.cu)
target_compile_options(decoder_multihead_attention PRIVATE
--generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
# target_compile_options(decoder_multihead_attention PRIVATE
# --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)

add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
target_compile_options(test_decoder_multihead_attention PRIVATE
--generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
# target_compile_options(test_decoder_multihead_attention PRIVATE
# --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
target_link_libraries(test_decoder_multihead_attention PRIVATE
decoder_multihead_attention
decoder_masked_multihead_attention
66 changes: 43 additions & 23 deletions src/turbomind/kernels/decoder_multihead_attention/array_ops.h
@@ -87,63 +87,83 @@ inline __device__ Array<T, N> operator*(const Array<T, N>& a, const T& b)

} // namespace ops

template<typename To, typename From, int N>
inline __device__ Array<To, N> cast(const Array<From, N>& src)
{
Array<To, N> dst;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
dst[i] = (To)src[i];
}
return dst;
}

template<int N>
struct RotaryEmbedding {

static_assert(N % 2 == 0);

Array<float, N> inv_freqs_;
Array<float, N> cs_;

__device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
const float2 tmp = rotary_embedding_coefficient(offset.x + i, dims, base, timestep);
inv_freqs_[i] = tmp.x;
inv_freqs_[i + 1] = tmp.y;
const float2 tmp = get_coefficient(offset.x + i, dims, base, timestep);
cs_[i] = tmp.x;
cs_[i + 1] = tmp.y;
}
}

inline __device__ float2 rotary_embedding_coefficient(int idx, int dims, float base, int timestep)
static __device__ inline float2 get_coefficient(int idx, int dims, float base, int timestep)
{
const float inv_freq = timestep / powf(base, idx / (float)dims);
return {cos(inv_freq), sin(inv_freq)};
float2 cs;
sincosf(inv_freq, &cs.y, &cs.x);
return cs;
}

template<typename T>
__device__ void apply(Array<T, N>& x)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
float tmp0 = inv_freqs_[i] * (float)x[i] - inv_freqs_[i + 1] * (float)x[i + 1];
float tmp1 = inv_freqs_[i] * (float)x[i + 1] + inv_freqs_[i + 1] * (float)x[i];
float tmp0 = cs_[i] * (float)x[i] - cs_[i + 1] * (float)x[i + 1];
float tmp1 = cs_[i] * (float)x[i + 1] + cs_[i + 1] * (float)x[i];
x[i] = (T)tmp0;
x[i + 1] = (T)tmp1;
}
}
};
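The rewritten RotaryEmbedding keeps the same rotation but renames inv_freqs_ to cs_ (it stores cos/sin coefficients, not frequencies) and computes each pair with a single sincosf() call; note the argument order in get_coefficient: sin is written to cs.y and cos to cs.x. For reference, the host-side sketch below is an illustration only (not part of this commit) and mirrors what get_coefficient() and apply() do for one (even, odd) element pair:

// Illustration only (not part of the commit): reference RoPE rotation for one
// (even, odd) pair, matching RotaryEmbedding::get_coefficient / apply.
#include <cmath>

inline void rope_pair_reference(float base, int dims, int timestep, int idx,
                                float& x0, float& x1)
{
    const float theta = timestep / powf(base, idx / (float)dims);  // same "inv_freq" term as the kernel
    const float c  = cosf(theta);  // cs_[i]     in the kernel
    const float s  = sinf(theta);  // cs_[i + 1] in the kernel
    const float r0 = x0 * c - x1 * s;
    const float r1 = x1 * c + x0 * s;
    x0 = r0;
    x1 = r1;
}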

template<typename VecQk, typename ThreadMap>
struct LogNScaling {
    __device__ void apply(VecQk& x)
    {
        PRAGMA_UNROLL
        for (int i = 0; i < VecQk::kSize; ++i) {
            // TODO:
        }
    }
};

template<typename To, typename From, int N>
inline __device__ Array<To, N> cast(const Array<From, N>& src)
{
    Array<To, N> dst;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        dst[i] = (To)src[i];
    }
    return dst;
}

struct LogNScaling {

    float scale_;

    __device__ static float get_scale(int seq_len, int max_position_embeddings)
    {
        if (seq_len <= max_position_embeddings) {
            return 1.f;
        }
        else {
            return log2(seq_len) / log2(max_position_embeddings);
        }
    }

    __device__ LogNScaling(int seq_len, int max_position_embeddings)
    {
        scale_ = get_scale(seq_len, max_position_embeddings);
    }

    template<typename T, int N>
    __device__ void apply(Array<T, N>& x) const
    {
        PRAGMA_UNROLL
        for (int i = 0; i < N; ++i) {
            x[i] = (T)((float)x[i] * scale_);
        }
    }
};
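The new LogNScaling leaves queries untouched while the sequence still fits in the trained context and otherwise multiplies them by log(seq_len) / log(max_position_embeddings); the base of the logarithm cancels, so log2 works as well as ln. For example, with max_position_embeddings = 2048 and seq_len = 8192 the scale is log2(8192) / log2(2048) = 13 / 11 ≈ 1.18. A small host-side check (illustration only, not part of the commit):

// Illustration only: prints the log-n attention scale for a few lengths,
// assuming a hypothetical trained context of 2048 positions.
#include <cmath>
#include <cstdio>

int main()
{
    const int max_position_embeddings = 2048;
    const int lens[] = {1024, 2048, 4096, 8192};
    for (int seq_len : lens) {
        const float scale = seq_len <= max_position_embeddings
                                ? 1.f
                                : log2f((float)seq_len) / log2f((float)max_position_embeddings);
        std::printf("seq_len=%5d  scale=%.3f\n", seq_len, scale);  // 1.000, 1.000, 1.091, 1.182
    }
    return 0;
}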

template<typename T, int N>
inline __device__ void Store(T* dst, const Array<T, N>& src)
@@ -40,7 +40,7 @@ void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& p

static const size_t kDynSmemSize = Attn::GetDynamicSmemSize();

[[maybe_unused]] static const bool _ = Print<Attn>(kDynSmemSize);
// [[maybe_unused]] static const bool _ = Print<Attn>(kDynSmemSize);

const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen;
const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
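For split-k sizing, invokeDecoderMultiheadAttention divides the longest sequence into slices of kSliceLen and caps params.max_split_k by that slice count. As a worked example (kSliceLen is defined in the kernel; 2048 is assumed here purely for illustration): with max_seq_len = 7306, slice_count = (7306 + 2047) / 2048 = 4, so max_split_k = min(params.max_split_k, 4).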
@@ -22,6 +22,7 @@ struct DecoderMultiHeadAttentionParams {
// sequence-level buffers
const int* __restrict__ per_sample_length;
const bool* __restrict__ finished;
const float* __restrict__ rope_theta;

// kv cache
void** __restrict__ per_sample_k_cache; // [H, S, D]
@@ -50,7 +51,7 @@ struct DecoderMultiHeadAttentionParams {
int rotary_embedding_dim;
float rotary_embedding_base;
int max_position_embeddings;
bool use_dynamic_ntk;
// bool use_dynamic_ntk;

// log(n) attention
bool use_logn_attn;
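With this change the kernel no longer receives a use_dynamic_ntk flag; instead each sequence carries its own RoPE base in the new per-sample rope_theta buffer, and the kernel falls back to the static rotary_embedding_base when the buffer is null (see the kernel change below). How the per-request base is chosen is up to the caller. As an illustration only, one widely used rule is the dynamic NTK-aware rescaling of the base; the helper below is a hypothetical sketch of that rule, not code from this commit or repository:

// Hypothetical host-side helper: pick a per-request RoPE base with the
// dynamic NTK-aware rule popularized for Llama-family models. Illustration
// only -- the commit defines the rope_theta buffer but not how it is filled.
#include <cmath>

float dynamic_ntk_base(float base,                    // e.g. 10000.f
                       int   rotary_embedding_dim,    // e.g. 128
                       int   max_position_embeddings, // trained context length
                       int   seq_len,                 // current sequence length
                       float scaling_factor)          // assumed tuning knob, e.g. 1.f
{
    if (seq_len <= max_position_embeddings) {
        return base;  // inside the trained context, keep the original base
    }
    const float t = scaling_factor * seq_len / max_position_embeddings - (scaling_factor - 1.f);
    return base * powf(t, rotary_embedding_dim / (float)(rotary_embedding_dim - 2));
}

// The result would then be written into rope_theta[batch_idx] for each request
// before the attention kernel is launched.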
@@ -233,16 +233,30 @@ struct DecoderMultiHeadAttentionKernel {
frag_V = frag_V + bias_V;
}

// for (int i = 0; i < kVecQSize; ++i) {
// printf("q[%2d][%3d] = %f\n", (int)head_idx_, (int)(offset.x + i), (float)frag_Q[0][i]);
// }

float rotary_embedding_base =
params_.rope_theta ? params_.rope_theta[batch_idx_] : params_.rotary_embedding_base;

// Apply rotary embedding
RotaryEmbedding<kVecQSize> rotary_emb(
params_.rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset);
RotaryEmbedding<kVecQSize> rotary_emb(rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset);

PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
rotary_emb.apply(frag_Q[s]);
}
rotary_emb.apply(frag_K);

if (params_.use_logn_attn) {
LogNScaling logn_scaling(timestep_ + 1, params_.max_position_embeddings);
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
logn_scaling.apply(frag_Q[s]);
}
}

if (kSplitK && step_begin_) { // Split idx > 0
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
@@ -268,6 +282,7 @@ struct DecoderMultiHeadAttentionKernel {
qk *= params_.inv_sqrt_dh;
smem_M_[qi] = qk;
smem_L_[qi] = 1.f;
// printf("qk[%2d] = %f\n", head_idx_, qk);
}
// write Q and O
Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]);
@@ -467,10 +482,6 @@ struct DecoderMultiHeadAttentionKernel {
/// block synchronization
frag_M = qk_max<MapKv>(frag_M, smem_red_max_, warp_id_, lane_id_);

if (threadIdx.x == 0 && step == timestep_ - kSliceLen) {
// printf("frag_M[%d] = %f\n", head_idx_, (float)frag_M[0]);
}

// wait while smem_red_ is being used.
// __syncthreads();

@@ -488,6 +499,10 @@
}
}

// if (threadIdx.x == 0 && step + iter_length == timestep_) {
// printf("frag_M[%2d] = %f\n", head_idx_, (float)frag_M[0]);
// }

// __syncthreads(); // DEBUG

/////////////////////////////////////////////////////////////////////////////////////////
@@ -506,17 +521,9 @@
}
}

// if (thread0()) {
// printf("frag_L0 = %f\n", (float)frag_L[0]);
// }

/// block synchronization
frag_L = blockSum<kWarpCount>(frag_L, smem_red_sum_, warp_id_, lane_id_);

if (thread0()) {
// printf("frag_L = %f\n", (float)frag_L[0]);
}

for (int qi = 0; qi < kHeadPerCta; ++qi) {
// exp(m1 - m2) * l1
frag_L[qi] += exp_M_diff[qi] * smem_L_[qi];
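In the kernel above, log-n scaling is applied to the query fragments only, after RoPE and before the dot product with K, so keys already written to the cache never have to be rescaled as the scale grows with the sequence. Folding the factor into Q is equivalent to scaling the attention logits themselves, as the simplified sketch below shows (scalar pseudocode, illustration only, not the actual kernel):

// Simplified, scalar illustration (not the kernel code): multiplying Q by the
// log-n factor before the dot product scales the attention logit itself.
float attention_logit(const float* q, const float* k, int head_dim,
                      float inv_sqrt_dh, float logn_scale /* 1.f when disabled */)
{
    float qk = 0.f;
    for (int i = 0; i < head_dim; ++i) {
        qk += (q[i] * logn_scale) * k[i];  // == logn_scale * (q . k)
    }
    return qk * inv_sqrt_dh;  // same 1/sqrt(d) softmax scaling as before
}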
@@ -109,13 +109,13 @@ int main(int argc, char* argv[])
constexpr int kHeadNum = 32;
constexpr int kHeadDim = 128;
constexpr int KvHeadNum = 32;
constexpr int kBatchSize = 1;
constexpr int kContextLen = 1024;
constexpr int kBatchSize = 32;
constexpr int kContextLen = 7306;
// constexpr int kContextLen = 1024;
constexpr int kSequenceLen = kContextLen + 1;
constexpr int kBlockSz = 128;
constexpr int kTestIter = 1;
constexpr int kMaxSplitK = 4;
constexpr int kMaxSplitK = 1;

RNG rng{};

@@ -226,7 +226,7 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
params.hidden_size_per_head = p.size_per_head;
params.rotary_embedding_dim = p.rotary_embedding_dim;
params.max_position_embeddings = p.max_position_embeddings;
params.use_dynamic_ntk = p.use_dynamic_ntk;
params.use_dynamic_ntk = false;
params.use_logn_attn = p.use_logn_attn;

// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
@@ -3,7 +3,6 @@
#pragma once

#include "../gemm_s_f16/common.h"
#include "src/turbomind/kernels/custom_ar_kernels.h"

namespace turbomind {
