From 699b0bfee07ec2c32889f304f9f734bbf322f705 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Mon, 30 Oct 2023 13:27:54 +0000 Subject: [PATCH] add ntk scaling and logn scaling --- CMakeLists.txt | 11 ++- .../CMakeLists.txt | 8 +- .../decoder_multihead_attention/array_ops.h | 66 +++++++++++------ .../decoder_multihead_attention.cu | 2 +- .../decoder_multihead_attention_params.h | 3 +- .../decoder_multihead_attention_template.h | 35 +++++---- .../test_decoder_multihead_attention.cu | 6 +- .../decoder_multihead_attention/test_utils.cu | 2 +- .../decoder_multihead_attention/thread_map.h | 1 - .../kernels/unfused_attention_kernels.cu | 73 ++++++++++--------- .../kernels/unfused_attention_kernels.h | 1 + src/turbomind/models/llama/LlamaBatch.cc | 53 ++++++++++++-- src/turbomind/models/llama/LlamaBatch.h | 9 ++- .../llama/LlamaContextAttentionLayer.cc | 7 +- .../models/llama/LlamaContextDecoder.cc | 1 + .../llama/LlamaDecoderSelfAttentionLayer.cc | 30 +++++--- src/turbomind/models/llama/LlamaV2.cc | 61 +++++++++------- src/turbomind/models/llama/LlamaV2.h | 63 ++++++++-------- src/turbomind/models/llama/SequenceManager.cc | 2 +- src/turbomind/models/llama/SequenceManager.h | 2 + src/turbomind/models/llama/llama_params.h | 5 +- .../triton_backend/llama/LlamaTritonModel.cc | 12 ++- 22 files changed, 287 insertions(+), 166 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f3f1c7b171..a004d76af5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -198,10 +198,13 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2F set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") # set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose") -set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED") +set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED") +set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED") + if(BUILD_FAST_MATH) -set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math") -message("CMAKE_CUDA_FLAGS_RELEASE: ${CMAKE_CUDA_FLAGS_RELEASE}") + set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math") + set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --use_fast_math") + message("Release build CUDA flags: ${CMAKE_CUDA_FLAGS_RELEASE}") endif() set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) @@ -268,11 +271,13 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');" OUTPUT_VARIABLE USE_CXX11_ABI) message("-- USE_CXX11_ABI=${USE_CXX11_ABI}") if (USE_CXX11_ABI) + set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1") else() + set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0") diff --git 
a/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt index 7176017671..fe67d11f0a 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt +++ b/src/turbomind/kernels/decoder_multihead_attention/CMakeLists.txt @@ -1,15 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv_cache.cu) -target_compile_options(decoder_multihead_attention PRIVATE - --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep) +# target_compile_options(decoder_multihead_attention PRIVATE +# --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep) set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass) add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu) -target_compile_options(test_decoder_multihead_attention PRIVATE - --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr) +# target_compile_options(test_decoder_multihead_attention PRIVATE +# --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr) target_link_libraries(test_decoder_multihead_attention PRIVATE decoder_multihead_attention decoder_masked_multihead_attention diff --git a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h index a847ada855..209da7e71d 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/array_ops.h +++ b/src/turbomind/kernels/decoder_multihead_attention/array_ops.h @@ -87,27 +87,40 @@ inline __device__ Array operator*(const Array& a, const T& b) } // namespace ops +template +inline __device__ Array cast(const Array& src) +{ + Array dst; + PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + dst[i] = (To)src[i]; + } + return dst; +} + template struct RotaryEmbedding { static_assert(N % 2 == 0); - Array inv_freqs_; + Array cs_; __device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset) { PRAGMA_UNROLL for (int i = 0; i < N; i += 2) { - const float2 tmp = rotary_embedding_coefficient(offset.x + i, dims, base, timestep); - inv_freqs_[i] = tmp.x; - inv_freqs_[i + 1] = tmp.y; + const float2 tmp = get_coefficient(offset.x + i, dims, base, timestep); + cs_[i] = tmp.x; + cs_[i + 1] = tmp.y; } } - inline __device__ float2 rotary_embedding_coefficient(int idx, int dims, float base, int timestep) + static __device__ inline float2 get_coefficient(int idx, int dims, float base, int timestep) { const float inv_freq = timestep / powf(base, idx / (float)dims); - return {cos(inv_freq), sin(inv_freq)}; + float2 cs; + sincosf(inv_freq, &cs.y, &cs.x); + return cs; } template @@ -115,35 +128,42 @@ struct RotaryEmbedding { { PRAGMA_UNROLL for (int i = 0; i < N; i += 2) { - float tmp0 = inv_freqs_[i] * (float)x[i] - inv_freqs_[i + 1] * (float)x[i + 1]; - float tmp1 = inv_freqs_[i] * (float)x[i + 1] + inv_freqs_[i + 1] * (float)x[i]; + float tmp0 = cs_[i] * (float)x[i] - cs_[i + 1] * (float)x[i + 1]; + float tmp1 = cs_[i] * (float)x[i + 1] + cs_[i + 1] * (float)x[i]; x[i] = (T)tmp0; x[i + 1] = (T)tmp1; } } }; -template struct LogNScaling { - __device__ void apply(VecQk& x) + + float scale_; + + __device__ static float get_scale(int seq_len, int 
max_position_embeddings) { - PRAGMA_UNROLL - for (int i = 0; i < VecQk::kSize; ++i) { - // TODO: + if (seq_len <= max_position_embeddings) { + return 1.f; + } + else { + return log2(seq_len) / log2(max_position_embeddings); } } -}; -template -inline __device__ Array cast(const Array& src) -{ - Array dst; - PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - dst[i] = (To)src[i]; + __device__ LogNScaling(int seq_len, int max_position_embeddings) + { + scale_ = get_scale(seq_len, max_position_embeddings); } - return dst; -} + + template + __device__ void apply(Array& x) const + { + PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + x[i] = (T)((float)x[i] * scale_); + } + } +}; template inline __device__ void Store(T* dst, const Array& src) diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu index 709db6ebc0..02cc827694 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu +++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu @@ -40,7 +40,7 @@ void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams& p static const size_t kDynSmemSize = Attn::GetDynamicSmemSize(); - [[maybe_unused]] static const bool _ = Print(kDynSmemSize); + // [[maybe_unused]] static const bool _ = Print(kDynSmemSize); const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen; const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count)); diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h index 5f18b45216..add5a7161c 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h +++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h @@ -22,6 +22,7 @@ struct DecoderMultiHeadAttentionParams { // sequence-level buffers const int* __restrict__ per_sample_length; const bool* __restrict__ finished; + const float* __restrict__ rope_theta; // kv cache void** __restrict__ per_sample_k_cache; // [H, S, D] @@ -50,7 +51,7 @@ struct DecoderMultiHeadAttentionParams { int rotary_embedding_dim; float rotary_embedding_base; int max_position_embeddings; - bool use_dynamic_ntk; + // bool use_dynamic_ntk; // log(n) attention bool use_logn_attn; diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h index dfeb86e568..ae82a8b786 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h +++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h @@ -233,9 +233,15 @@ struct DecoderMultiHeadAttentionKernel { frag_V = frag_V + bias_V; } + // for (int i = 0; i < kVecQSize; ++i) { + // printf("q[%2d][%3d] = %f\n", (int)head_idx_, (int)(offset.x + i), (float)frag_Q[0][i]); + // } + + float rotary_embedding_base = + params_.rope_theta ? 
params_.rope_theta[batch_idx_] : params_.rotary_embedding_base; + // Apply rotary embedding - RotaryEmbedding rotary_emb( - params_.rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset); + RotaryEmbedding rotary_emb(rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset); PRAGMA_UNROLL for (int s = 0; s < kQHeadPerThread; ++s) { @@ -243,6 +249,14 @@ struct DecoderMultiHeadAttentionKernel { } rotary_emb.apply(frag_K); + if (params_.use_logn_attn) { + LogNScaling logn_scaling(timestep_ + 1, params_.max_position_embeddings); + PRAGMA_UNROLL + for (int s = 0; s < kQHeadPerThread; ++s) { + logn_scaling.apply(frag_Q[s]); + } + } + if (kSplitK && step_begin_) { // Split idx > 0 PRAGMA_UNROLL for (int s = 0; s < kQHeadPerThread; ++s) { @@ -268,6 +282,7 @@ struct DecoderMultiHeadAttentionKernel { qk *= params_.inv_sqrt_dh; smem_M_[qi] = qk; smem_L_[qi] = 1.f; + // printf("qk[%2d] = %f\n", head_idx_, qk); } // write Q and O Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]); @@ -467,10 +482,6 @@ struct DecoderMultiHeadAttentionKernel { /// block synchronization frag_M = qk_max(frag_M, smem_red_max_, warp_id_, lane_id_); - if (threadIdx.x == 0 && step == timestep_ - kSliceLen) { - // printf("frag_M[%d] = %f\n", head_idx_, (float)frag_M[0]); - } - // wait while smem_red_ is being used. // __syncthreads(); @@ -488,6 +499,10 @@ struct DecoderMultiHeadAttentionKernel { } } + // if (threadIdx.x == 0 && step + iter_length == timestep_) { + // printf("frag_M[%2d] = %f\n", head_idx_, (float)frag_M[0]); + // } + // __syncthreads(); // DEBUG ///////////////////////////////////////////////////////////////////////////////////////// @@ -506,17 +521,9 @@ struct DecoderMultiHeadAttentionKernel { } } - // if (thread0()) { - // printf("frag_L0 = %f\n", (float)frag_L[0]); - // } - /// block synchronization frag_L = blockSum(frag_L, smem_red_sum_, warp_id_, lane_id_); - if (thread0()) { - // printf("frag_L = %f\n", (float)frag_L[0]); - } - for (int qi = 0; qi < kHeadPerCta; ++qi) { // exp(m1 - m2) * l1 frag_L[qi] += exp_M_diff[qi] * smem_L_[qi]; diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu index b5249f31c2..e4636bcea0 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu +++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu @@ -109,13 +109,13 @@ int main(int argc, char* argv[]) constexpr int kHeadNum = 32; constexpr int kHeadDim = 128; constexpr int KvHeadNum = 32; - constexpr int kBatchSize = 1; - constexpr int kContextLen = 1024; + constexpr int kBatchSize = 32; + constexpr int kContextLen = 7306; // constexpr int kContextLen = 1024; constexpr int kSequenceLen = kContextLen + 1; constexpr int kBlockSz = 128; constexpr int kTestIter = 1; - constexpr int kMaxSplitK = 4; + constexpr int kMaxSplitK = 1; RNG rng{}; diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu index 883f0fc3d0..c3fb0d77bc 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu +++ b/src/turbomind/kernels/decoder_multihead_attention/test_utils.cu @@ -226,7 +226,7 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams& p, cudaStream_t params.hidden_size_per_head = p.size_per_head; params.rotary_embedding_dim = p.rotary_embedding_dim; params.max_position_embeddings = 
p.max_position_embeddings; - params.use_dynamic_ntk = p.use_dynamic_ntk; + params.use_dynamic_ntk = false; params.use_logn_attn = p.use_logn_attn; // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust) diff --git a/src/turbomind/kernels/decoder_multihead_attention/thread_map.h b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h index 47b2636f6d..f4c2be1da2 100644 --- a/src/turbomind/kernels/decoder_multihead_attention/thread_map.h +++ b/src/turbomind/kernels/decoder_multihead_attention/thread_map.h @@ -3,7 +3,6 @@ #pragma once #include "../gemm_s_f16/common.h" -#include "src/turbomind/kernels/custom_ar_kernels.h" namespace turbomind { diff --git a/src/turbomind/kernels/unfused_attention_kernels.cu b/src/turbomind/kernels/unfused_attention_kernels.cu index abbbfd5562..040f7204bf 100644 --- a/src/turbomind/kernels/unfused_attention_kernels.cu +++ b/src/turbomind/kernels/unfused_attention_kernels.cu @@ -15,7 +15,7 @@ * limitations under the License. */ -#include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h" #include "src/turbomind/kernels/reduce_kernel_utils.cuh" #include "src/turbomind/kernels/unfused_attention_kernels.h" #include "src/turbomind/utils/cuda_type_utils.cuh" @@ -854,19 +854,20 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* v_buf, T* QKV, const T* __restrict qkv_bias, - const int* padding_offset, - const int* context_length, - const int* input_length, - int batch_size, - int seq_len, - int head_num, - int kv_head_num, - int size_per_head, - int rotary_embedding_dim, - float rotary_embedding_base, - int max_position_embeddings, - bool use_dynamic_ntk, - bool use_logn_attn) + const int* padding_offset, + const int* context_length, + const int* input_length, + const float* rope_theta, + int batch_size, + int seq_len, + int head_num, + int kv_head_num, + int size_per_head, + int rotary_embedding_dim, + float rotary_embedding_base, + int max_position_embeddings, + bool use_dynamic_ntk, + bool use_logn_attn) { // This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and // QKV split to 3 split buffer q, k, v and transpose them to [batch_size, head_num, seq_len, size_per_head]. 
@@ -907,12 +908,18 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, Vec_t q, k, v; Vec_t q_bias, k_bias, v_bias; + using Vec = Array; + + static_assert(sizeof(Vec_t) == sizeof(Vec)); + + using namespace ops; + // load Q and apply bias if (!is_masked) { q = *reinterpret_cast(&QKV[src_q_idx]); if (qkv_bias) { - q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); - q = mmha::add(q, q_bias); + q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); + (Vec&)q = (Vec&)q + (Vec&)q_bias; } } @@ -921,10 +928,10 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, k = *reinterpret_cast(&QKV[src_k_idx]); v = *reinterpret_cast(&QKV[src_v_idx]); if (qkv_bias) { - k_bias = *reinterpret_cast(&qkv_bias[hidden_idx + k_offset]); - v_bias = *reinterpret_cast(&qkv_bias[hidden_idx + v_offset]); - k = mmha::add(k, k_bias); - v = mmha::add(v, v_bias); + k_bias = *reinterpret_cast(&qkv_bias[hidden_idx + k_offset]); + v_bias = *reinterpret_cast(&qkv_bias[hidden_idx + v_offset]); + (Vec&)k = (Vec&)k + (Vec&)k_bias; + (Vec&)v = (Vec&)v + (Vec&)v_bias; } } @@ -932,24 +939,21 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, const int history_len = context_len - input_length[batch_idx]; const int timestep = history_len + seq_idx; - if (use_dynamic_ntk) { - rotary_embedding_base = mmha::rotary_embedding_get_base( - context_len, max_position_embeddings, rotary_embedding_dim, rotary_embedding_base); + if (rope_theta) { + rotary_embedding_base = rope_theta[batch_idx]; } - // TODO: unused computation on k if GQA is used - mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rotary_embedding_base, timestep); + RotaryEmbedding rotary_emb(rotary_embedding_base, rotary_embedding_dim, timestep, {tidx * vec_size, 0}); + rotary_emb.apply((Array&)q); + + if (head_idx < kv_head_num) { + rotary_emb.apply((Array&)k); + } if (use_logn_attn) { // +1 to convert to context length at the timestep - float logn_scaling = mmha::logn_attn_get_scaling(timestep + 1, max_position_embeddings); - if constexpr (std::is_same_v) { - q = mmha::mul(logn_scaling, q); - } - else if constexpr (std::is_same_v) { - half tmp = __float2half(logn_scaling); - q = mmha::mul((uint16_t&)tmp, q); - } + LogNScaling logn_scaling(timestep + 1, max_position_embeddings); + logn_scaling.apply((Array&)q); } if (!is_masked && !q_buf) { // also skip modifying QKV if q/k/v_buf are present @@ -984,6 +988,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, padding_offset, \ context_length, \ input_length, \ + rope_theta, \ batch_size, \ seq_len, \ head_num, \ @@ -1004,6 +1009,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int* padding_offset, const int* context_length, const int* input_length, + const float* rope_theta, const int batch_size, const int seq_len, const int token_num, @@ -1034,6 +1040,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int* padding_offset, \ const int* history_length, \ const int* input_length, \ + const float* rope_theta, \ const int batch_size, \ const int seq_len, \ const int token_num, \ diff --git a/src/turbomind/kernels/unfused_attention_kernels.h b/src/turbomind/kernels/unfused_attention_kernels.h index 846a1b7371..758fe7fba0 100644 --- a/src/turbomind/kernels/unfused_attention_kernels.h +++ b/src/turbomind/kernels/unfused_attention_kernels.h @@ -72,6 +72,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int* padding_offset, const int* context_length, const int* input_length, + const float* rope_theta, const int batch_size, const int seq_len, const int 
token_num, diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index d6d94be43d..f46d7ebe35 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -12,11 +12,12 @@ #include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/debug_utils.h" +#include "src/turbomind/utils/gemm_test/gemm_func.h" #include "src/turbomind/utils/logger.h" #include +#include #include #include -#include #include #include #include @@ -59,11 +60,11 @@ void LlamaBatch::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_r ec = Request::kInvalid; } else if (input_length > session_len_) { - ec = Request::kInvalid; + ec = Request::kTooLong; } else if (!r->start_flag) { if (auto seq = sequence_manager_->Get(r->id); seq == nullptr) { - ec = Request::kTooLong; + ec = Request::kInvalid; } else if (get_offset(seq->tokens.size()) + input_length > session_len_) { ec = Request::kTooLong; @@ -230,7 +231,7 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests) if (rank_ == 0) { const int trunc_output_len = state.seq_len_limit[i] - state.h_context_length[i]; TM_LOG_WARNING( - "[initialize] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d", + "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d", (long)seq.id, state.h_context_length[i], request_output_len, @@ -239,7 +240,35 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests) } } - // recover random state HtoD if not a new sequence + // compute rope scaling factor + if (r->start_flag) { + seq.rope_theta = model_->attn_params_.rotary_embedding_base; + auto scaling_factor = 1.f; + if (r->inputs[rank_].isExist("rope_scaling_factor")) { // runtime scaling factor + scaling_factor = r->inputs[rank_].getVal("rope_scaling_factor"); + } + else if (model_->attn_params_.rope_scaling_factor >= 1.f) { // infer by `seq_len_limit` + scaling_factor = model_->attn_params_.rope_scaling_factor; + auto max_seq_len = state.seq_len_limit[i]; + auto max_pos_emb = model_->attn_params_.max_position_embeddings; + if (max_seq_len > max_pos_emb) { + scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1); + // scaling_factor = std::max(exp2f(ceilf(log2f((float)max_seq_len / max_pos_emb) + 1.f)) + // - 1.f, 1.f); + } + } + if (scaling_factor != 1.f) { + float rope_dim = model_->attn_params_.rotary_embedding_dim; + seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f)); + TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f", + (long)seq.id, + scaling_factor, + seq.rope_theta); + } + } + state.h_rope_theta[i] = seq.rope_theta; + + // recover device states if not a new sequence if (!r->start_flag) { Copy((curandState_t*)seq.random_state.data() + 0, 1, (curandState_t*)state.top_k_curand_state); Copy((curandState_t*)seq.random_state.data() + 1, 1, (curandState_t*)state.top_p_curand_state); @@ -415,6 +444,7 @@ void LlamaBatch::CopyState(const std::pair _src, const std: dst->h_context_length[j] = src->h_context_length[i]; dst->h_finished[j] = src->h_finished[i]; + dst->h_rope_theta[j] = src->h_rope_theta[i]; dst->seq_len_limit[j] = src->seq_len_limit[i]; dst->sequences[j] = src->sequences[i]; dst->is_swap_in[j] = src->is_swap_in[i]; @@ -495,6 +525,8 @@ void LlamaBatch::AllocateBuffer(size_t batch_size, size_t session_len) 
request_output_ids_lens_ = (int*)allocator_->reMalloc(request_output_ids_lens_, sizeof(int) * batch_size, true); request_seqlen_ptrs_ = (int**)allocator_->reMalloc(request_seqlen_ptrs_, sizeof(int*) * batch_size, true); + rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false); + is_allocate_buffer_ = true; } @@ -549,7 +581,8 @@ void LlamaBatch::AllocatePersistantBuffer(size_t max_batch_size) for (auto& s : states_) { s.h_context_length = (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true); - s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true); + s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true); + s.h_rope_theta = (float*)allocator_->reMalloc(s.h_rope_theta, sizeof(float) * max_batch_size, false, true); } h_seq_limit_len_ = @@ -613,6 +646,8 @@ void LlamaBatch::FreeBuffer() allocator_->free((void**)&request_output_ids_lens_); allocator_->free((void**)&request_seqlen_ptrs_); + allocator_->free((void**)&rope_theta_); + is_allocate_buffer_ = false; } @@ -620,6 +655,7 @@ void LlamaBatch::FreeBuffer() for (auto& s : states_) { allocator_->free((void**)&s.h_context_length, true); allocator_->free((void**)&s.h_finished, true); + allocator_->free((void**)&s.h_rope_theta, true); allocator_->free((void**)&s.output_ids); } allocator_->free((void**)&h_tmp_k_ptrs_, true); @@ -792,6 +828,8 @@ auto LlamaBatch::InitializeGeneration() -> GenerationState Copy(h_request_output_ids_lens_, batch_size, request_output_ids_lens_); Copy(h_request_seqlen_ptrs_, batch_size, request_seqlen_ptrs_); + Copy(state_->h_rope_theta, batch_size, rope_theta_); + // ! range of step_ [1, 2 * session_len] // consider a sequence with context_len == session_len and another sequence with context_len == 1 and // request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len) @@ -851,6 +889,7 @@ bool LlamaBatch::Generate(GenerationState& g) sequence_lengths_, finished_buf_, cu_block_counts_, + rope_theta_, g.step, 0, g.sum_seq_len, @@ -938,6 +977,7 @@ void LlamaBatch::ContextDecode() const int context_decode_count = batch_size - base; Copy(state_->h_context_length, batch_size, context_length_buf_); + Copy(state_->h_rope_theta, batch_size, rope_theta_); Copy(h_input_length_buf_, batch_size, input_length_buf_); check_cuda_error(cudaStreamSynchronize(stream_)); @@ -1042,6 +1082,7 @@ void LlamaBatch::ContextDecode() input_length_buf_ + first, context_length_buf_ + first, cu_block_counts_ + first, + rope_theta_ + first, token_count, max_input_len, max_context_cnts[k], diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index 4c8f8154be..4e7c2e7b11 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -22,6 +22,8 @@ struct BatchState { void* top_p_curand_state; int* output_ids; // output ids in [B, S] + float* h_rope_theta; + std::vector seq_len_limit; std::vector is_swap_in; @@ -180,6 +182,8 @@ class LlamaBatch { float* context_logits_buf_{}; float* local_context_logits_buf_{}; + float* rope_theta_{}; + // used by dynamic decoder int* token_ids_buf_{}; // all token IDs in [S, B], indexed using `step` int* end_ids_buf_{}; @@ -194,9 +198,8 @@ class LlamaBatch { int** h_request_seqlen_ptrs_{}; // pinned buffers - int* h_input_ids_buf_{}; - int* h_input_length_buf_{}; - // int* h_sequence_lengths_{}; + int* h_input_ids_buf_{}; + int* 
h_input_length_buf_{}; uint32_t* h_seq_limit_len_{}; int* h_cu_block_counts_{}; uintptr_t* h_k_block_ptrs_{}; diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc index 1a62e2fb77..92fe00dc56 100644 --- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc +++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc @@ -149,6 +149,8 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* int* cu_seqlens = input_tensors->at("cu_seqlens").getPtr(); int* cu_block_counts = input_tensors->at("cu_block_counts").getPtr(); + const float* rope_theta = input_tensors->getPtr("rope_theta", nullptr); + const auto padding_offset = input_tensors->at("padding_offset").getPtr(); auto Show = [&](const T* x, size_t n) { @@ -179,16 +181,17 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* padding_offset, // padding_offset, context_length, // used for applying rotary embedding input_length, + rope_theta, batch_size, max_q_len, // seq_len num_token, // batch_size * seq_len local_head_num_, local_kv_head_num_, size_per_head_, - params_.rotray_embedding_dim, + params_.rotary_embedding_dim, params_.rotary_embedding_base, params_.max_position_embeddings, - params_.use_dynamic_ntk, + false, // params_.use_dynamic_ntk, params_.use_logn_attn, stream_); sync_check_cuda_error(); diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc index 2047ffa050..268ff7ab58 100644 --- a/src/turbomind/models/llama/LlamaContextDecoder.cc +++ b/src/turbomind/models/llama/LlamaContextDecoder.cc @@ -114,6 +114,7 @@ void LlamaContextDecoder::forwardSelfAttn(const Session& {"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}}, {"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}}, {"cu_block_counts", input_tensors->at("cu_block_counts")}, + {"rope_theta", input_tensors->at("rope_theta")}, {"max_seq_len", input_tensors->at("max_seq_len")}}; TensorMap self_attention_output_tensors{ diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc index d411f3f412..ecce30072c 100644 --- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc +++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc @@ -98,18 +98,14 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o const int layer_id = input_tensors->getVal("layer_id"); - // const int step = input_tensors->getVal("step"); + const int step = input_tensors->getVal("step"); // const int step_1 = step - 1; const int batch_size = input_tensors->at("input_query").shape[0]; - allocateBuffer(batch_size); + const float* rope_theta = input_tensors->getPtr("rope_theta", nullptr); - // std::vector seqlens(batch_size); - // check_cuda_error( - // cudaMemcpyAsync(seqlens.data(), sequence_lengths_data, sizeof(int) * batch_size, cudaMemcpyDefault, - // stream_)); - // check_cuda_error(cudaStreamSynchronize(stream_)); + allocateBuffer(batch_size); // for (int i = 0; i < batch_size; ++i) { // if (gSequenceIds(i) == 1) { @@ -126,6 +122,10 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv); } + // if (layer_id == 0) { + // Compare(qkv_buf_, batch_size * 3 * hidden_units_, Concat("qkv_buf", step, layer_id), kCmpRead, stream_); + // } + const auto layer_offset = layer_id * local_kv_head_num_ * kv_cache_block_len_ * 
size_per_head_; // const int memory_len = max_seq_len; @@ -137,6 +137,10 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o params.v = params.k + local_kv_head_num_ * size_per_head_; params.stride = (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_; + params.q_bias = weights->qkv.bias; + params.k_bias = params.q_bias + local_head_num_ * size_per_head_; + params.v_bias = params.k_bias + local_kv_head_num_ * size_per_head_; + params.batch_size = batch_size; params.cu_block_cnts = cu_block_counts; @@ -146,6 +150,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o params.finished = finished_data; params.per_sample_length = sequence_lengths_data; + params.rope_theta = rope_theta; params.layer_offset = layer_offset; @@ -154,8 +159,11 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o params.size_per_head = size_per_head_; params.inv_sqrt_dh = 1.f / std::sqrt((float)params.size_per_head); - params.rotary_embedding_dim = size_per_head_; - params.rotary_embedding_base = 10000.f; + params.rotary_embedding_dim = size_per_head_; + params.rotary_embedding_base = params_.rotary_embedding_base; + params.max_position_embeddings = params_.max_position_embeddings; + // params.use_dynamic_ntk = params_.use_dynamic_ntk; + params.use_logn_attn = params_.use_logn_attn; params.partial_O = workspace_; params.partial_M = params.partial_O + batch_size * local_head_num_ * kMaxSplitK * size_per_head_; @@ -198,6 +206,10 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o // } // } + // if (layer_id == 0) { + // Compare(context_buf_, batch_size * hidden_units_, Concat("context_buf", step, layer_id), kCmpRead, stream_); + // } + { NvtxScope scope("o_gemm"); linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output); diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 53160c8ede..fca323f0ad 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -74,6 +74,7 @@ LlamaV2::LlamaV2(size_t head_num, inter_size_(inter_size), num_layer_(num_layer), vocab_size_(vocab_size), + attn_params_(attn_params), vocab_size_padded_(vocab_size), rmsnorm_eps_(norm_eps), start_id_(start_id), @@ -222,22 +223,23 @@ void LlamaV2::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba } template -void LlamaV2::contextDecode(T* deocder_output, - uintptr_t* k_cache_ptr, - uintptr_t* v_cache_ptr, - void** tmp_k_ptrs, - void** tmp_v_ptrs, - T* context_decoder_input_buf, - T* context_decoder_output_buf, - const int* input_ids, - const int* input_length, - const int* context_length, - const int* cu_block_counts, - size_t token_num, - size_t max_input_len, - size_t max_context_len, - size_t session_len, - size_t batch_size) +void LlamaV2::contextDecode(T* deocder_output, + uintptr_t* k_cache_ptr, + uintptr_t* v_cache_ptr, + void** tmp_k_ptrs, + void** tmp_v_ptrs, + T* context_decoder_input_buf, + T* context_decoder_output_buf, + const int* input_ids, + const int* input_length, + const int* context_length, + const int* cu_block_counts, + const float* rope_theta, + size_t token_num, + size_t max_input_len, + size_t max_context_len, + size_t session_len, + size_t batch_size) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -274,6 +276,7 @@ void LlamaV2::contextDecode(T* deocder_output, {"max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_q_len}}, {"max_kv_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_kv_len}}, {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}}, + {"rope_theta", {MEMORY_GPU, 
TYPE_FP32, {hidden_units_}, rope_theta}}, {"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {batch_size}, cu_block_counts}}}; std::unordered_map decoder_output_tensors{ @@ -292,18 +295,19 @@ void LlamaV2::contextDecode(T* deocder_output, } template -void LlamaV2::decoderForward(T* decoder_output, - uintptr_t* k_cache_ptr, - uintptr_t* v_cache_ptr, - T* decoder_input, - const int* sequence_length, - const bool* finished, - const int* cu_block_counts, - int step, - int ite, - int sum_seq_len, - int max_seq_len, - size_t batch_size) +void LlamaV2::decoderForward(T* decoder_output, + uintptr_t* k_cache_ptr, + uintptr_t* v_cache_ptr, + T* decoder_input, + const int* sequence_length, + const bool* finished, + const int* cu_block_counts, + const float* rope_theta, + int step, + int ite, + int sum_seq_len, + int max_seq_len, + size_t batch_size) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -319,6 +323,7 @@ void LlamaV2::decoderForward(T* decoder_output, {"max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &max_seq_len}}, {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}}, {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}}, + {"rope_theta", {MEMORY_GPU, TYPE_FP32, {batch_size}, rope_theta}}, {"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}}, {"ite", {MEMORY_CPU, TYPE_INT32, {1}, &ite}}, }; diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index 99d5352746..f26900eaa0 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -29,6 +29,7 @@ #include "src/turbomind/models/llama/LlamaWeight.h" #include "src/turbomind/models/llama/Request.h" #include "src/turbomind/models/llama/SequenceManager.h" +#include "src/turbomind/models/llama/llama_params.h" #include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/instance_comm.h" @@ -112,35 +113,37 @@ class LlamaV2 { void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); - void contextDecode(T* deocder_output, - uintptr_t* k_block_ptrs, - uintptr_t* v_block_ptrs, - void** k_tmp_ptrs, - void** v_tmp_ptrs, - T* context_decoder_input_buf, - T* context_decoder_output_buf, - const int* input_ids, - const int* input_length, - const int* context_length, - const int* cu_block_counts, - size_t token_num, - size_t max_input_len, - size_t max_context_len, - size_t session_len, - size_t batch_size); - - void decoderForward(T* decoder_output, - uintptr_t* k_cache_ptr, - uintptr_t* v_cache_ptr, - T* decoder_input, - const int* sequence_length, - const bool* finished, - const int* cu_block_counts, - int step, - int ite, - int sum_seq_len, - int max_seq_len, - size_t batch_size); + void contextDecode(T* deocder_output, + uintptr_t* k_block_ptrs, + uintptr_t* v_block_ptrs, + void** k_tmp_ptrs, + void** v_tmp_ptrs, + T* context_decoder_input_buf, + T* context_decoder_output_buf, + const int* input_ids, + const int* input_length, + const int* context_length, + const int* cu_block_counts, + const float* rope_theta, + size_t token_num, + size_t max_input_len, + size_t max_context_len, + size_t session_len, + size_t batch_size); + + void decoderForward(T* decoder_output, + uintptr_t* k_cache_ptr, + uintptr_t* v_cache_ptr, + T* decoder_input, + const int* sequence_length, + const bool* finished, + const int* cu_block_counts, + const float* rope_theta, + int step, + int ite, + int sum_seq_len, + int max_seq_len, + size_t batch_size); void postDecodeEmbedding(float* logits, float* 
local_logits, const T* decoder_output, int batch_size); @@ -181,6 +184,8 @@ class LlamaV2 { size_t vocab_size_padded_; float rmsnorm_eps_ = 1e-6f; + const LlamaAttentionParams attn_params_; + static constexpr bool neox_rotary_style_ = false; const int start_id_; diff --git a/src/turbomind/models/llama/SequenceManager.cc b/src/turbomind/models/llama/SequenceManager.cc index 12f982be26..6c2778daa1 100644 --- a/src/turbomind/models/llama/SequenceManager.cc +++ b/src/turbomind/models/llama/SequenceManager.cc @@ -36,7 +36,7 @@ SequenceManager::SequenceManager(size_t layer_num, const Sequence* SequenceManager::Create(uint64_t id) { - Sequence sequence{id, {}, {}, {}, {}, {}}; + Sequence sequence{id, {}, {}, {}, {}, {}, {}, 0.f}; auto it = sequences_.find(id); if (it != sequences_.end()) { diff --git a/src/turbomind/models/llama/SequenceManager.h b/src/turbomind/models/llama/SequenceManager.h index 8800149bf1..be99e120e3 100644 --- a/src/turbomind/models/llama/SequenceManager.h +++ b/src/turbomind/models/llama/SequenceManager.h @@ -27,6 +27,8 @@ struct Sequence { // additional data kept round-to-round mutable std::vector random_state; // update by user + mutable float rope_theta; + friend std::ostream& operator<<(std::ostream& os, const Sequence& seq); }; diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h index 8f8c96837b..78b1570f02 100644 --- a/src/turbomind/models/llama/llama_params.h +++ b/src/turbomind/models/llama/llama_params.h @@ -5,10 +5,11 @@ namespace turbomind { struct LlamaAttentionParams { - int rotray_embedding_dim; + int rotary_embedding_dim; float rotary_embedding_base; int max_position_embeddings; - bool use_dynamic_ntk; + float rope_scaling_factor; + // bool use_dynamic_ntk; bool use_logn_attn; }; diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index beab5d7d94..3a60896a59 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -74,6 +74,12 @@ void LlamaTritonModel::handleMissingParams() TM_LOG_WARNING("[LlamaTritonModel] `session_len` is not set, default to %d.", (int)session_len_); } + if (!attn_params_.max_position_embeddings) { + attn_params_.max_position_embeddings = session_len_; + TM_LOG_WARNING("[LlamaTritonModel] `max_position_embeddings` is not set, default to `session_len` (%d).", + (int)attn_params_.max_position_embeddings); + } + if (!max_context_token_num_) { max_context_token_num_ = (int)std::sqrt(max_batch_size_); TM_LOG_WARNING("[LlamaTritonModel] `max_context_token_num` is not set, default to %d.", @@ -142,10 +148,12 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, quant_policy_ = reader.GetInteger("llama", "quant_policy", 0); group_size_ = reader.GetInteger("llama", "group_size", 0); - attn_params_.rotray_embedding_dim = reader.GetInteger("llama", "rotary_embedding"); + // rotary embedding parameters + attn_params_.rotary_embedding_dim = reader.GetInteger("llama", "rotary_embedding"); attn_params_.rotary_embedding_base = reader.GetFloat("llama", "rope_theta", 10000.0f); + attn_params_.rope_scaling_factor = reader.GetFloat("llama", "rope_scaling_factor", 0.f); attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0); - attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0); + // attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0); attn_params_.use_logn_attn = 
reader.GetInteger("llama", "use_logn_attn", 0); handleMissingParams();
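
Note on the arithmetic this patch introduces: when a sequence carries a `rope_scaling_factor` (passed at runtime in the request or taken from the model config), `LlamaBatch::ProcessInferRequests` derives an effective factor from the sequence length limit and raises the per-sequence RoPE base to theta' = theta * s^(dim / (dim - 2)); `LogNScaling` then scales the query vectors by log(seq_len) / log(max_position_embeddings) once the context grows past the trained position range. Below is a minimal host-side C++ sketch of the same formulas; the helper names are illustrative only and are not part of the patch.

#include <cmath>

// Dynamic NTK: grow the RoPE base for sequences longer than the trained
// context, mirroring the update in LlamaBatch::ProcessInferRequests.
inline float scaled_rope_theta(float rope_theta, float scaling_factor,
                               int max_seq_len, int max_pos_emb, int rope_dim)
{
    if (max_seq_len > max_pos_emb) {
        scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1.f);
    }
    // raising the base by s^(d / (d - 2)) stretches the lowest-frequency rotary
    // band by roughly the scaling factor while leaving the highest frequency untouched
    return rope_theta * std::pow(scaling_factor, rope_dim / (rope_dim - 2.f));
}

// log(n) attention: queries are scaled by log(seq_len) / log(max_position_embeddings)
// once the context exceeds the trained position range, matching LogNScaling::get_scale.
inline float logn_scale(int seq_len, int max_pos_emb)
{
    return seq_len <= max_pos_emb
               ? 1.f
               : std::log2((float)seq_len) / std::log2((float)max_pos_emb);
}

The kernels consume these values per sequence: the adjusted base reaches both the context and decoding attention paths through the new `rope_theta` buffer instead of being recomputed inside the kernel, and the log(n) factor is applied element-wise to the Q fragments before the QK dot product.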