From 832ade389fa5eae69acdc1849048251bc0cbdb42 Mon Sep 17 00:00:00 2001
From: irexyc
Date: Wed, 4 Dec 2024 08:36:55 +0000
Subject: [PATCH] use FastRoPE

---
 .../kernels/attention/attention_universal.h |  2 +-
 .../kernels/attention/kv_cache_utils_v2.cu  |  2 +-
 .../kernels/attention/rotary_embedding.h    | 99 +------------------
 src/turbomind/models/llama/rotary_emb.cu    | 11 +++
 4 files changed, 14 insertions(+), 100 deletions(-)

diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h
index 96e8c9317d..a4a209c4ab 100644
--- a/src/turbomind/kernels/attention/attention_universal.h
+++ b/src/turbomind/kernels/attention/attention_universal.h
@@ -231,7 +231,7 @@ struct AttentionUniversal {
         ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset);
 
         if (params.cos_sin) {
-            PrecomputeFastRoPE rope{};
+            FastRoPE rope{};
             PRAGMA_UNROLL
             for (int s = 0; s < ITER_S; ++s) {
                 PRAGMA_UNROLL
diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
index 3af6e7cd0b..e8c7d6bbfa 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
+++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
@@ -130,7 +130,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
         }
     }
 
-    PrecomputeFastRoPE rope;
+    FastRoPE rope{};
     PRAGMA_UNROLL
     for (int s = 0; s < ITER_S; ++s) {
         PRAGMA_UNROLL
diff --git a/src/turbomind/kernels/attention/rotary_embedding.h b/src/turbomind/kernels/attention/rotary_embedding.h
index c6ef49a6a2..59e4592731 100644
--- a/src/turbomind/kernels/attention/rotary_embedding.h
+++ b/src/turbomind/kernels/attention/rotary_embedding.h
@@ -67,7 +67,7 @@ __device__ void ApplyRotaryEmbedding(Array& x, float base, int dims, int t
     }
 }
 
-struct PrecomputeFastRoPE {
+struct FastRoPE {
 
     template
     __device__ void apply(Array& x, Array& cs)
@@ -82,103 +82,6 @@ struct PrecomputeFastRoPE {
     }
 };
 
-template
-struct FastRoPE {
-
-    static_assert(N % 2 == 0);
-
-    Array inv_freq_;
-    bool  is_valid_;
-    float attention_scaling_;
-
-    __device__ FastRoPE(int   idx,
-                        D     dims,
-                        float base,
-                        float ti_scale,
-                        float factor,
-                        float llama3_inv_scaling_factor,
-                        float llama3_alpha,
-                        float llama3_beta,
-                        float yarn_ramp_inv_factor_div_2,
-                        float yarn_ramp_inv_factor_mul_min,
-                        float yarn_inv_scaling_factor,
-                        float attention_scaling,
-                        std::integral_constant)
-    {
-        is_valid_          = idx < dims;
-        attention_scaling_ = attention_scaling;
-        /// TODO: Take this away from device code
-        const float scale_factor = -log2f(base) / dims;
-        PRAGMA_UNROLL
-        for (int i = 0; i < N; i += 2) {
-            inv_freq_[i / 2] = ti_scale * exp2f((idx + i) * scale_factor);
-        }
-        // clang-format off
-        /* The [llama3 rope](https://github.com/huggingface/transformers/blob/5f4ee98a7ade33e1c54fdd6181d04ee7b426b392/src/transformers/modeling_rope_utils.py#L298)
-         * used by llama3.1 equals to the following equation, given the precommuted parameters as:
-         ```C++
-         inv_scaling_factor = 1 / factor;
-         inv_diff_freq_factor = 1 / (high_freq_factor - low_freq_factor);
-         alpha = old_context_len / (2 * PI) * inv_diff_freq_factor;
-         beta = low_freq_factor * inv_diff_freq_factor
-         ```
-         */
-        // clang-format on
-        if (llama3_inv_scaling_factor) {
-            PRAGMA_UNROLL
-            for (int i = 0; i < N; i += 2) {
-                auto freq        = inv_freq_[i / 2];
-                auto smooth      = fmaxf(0.f, fminf(1.f, llama3_alpha * freq - llama3_beta));
-                inv_freq_[i / 2] = (1 - smooth) * freq * llama3_inv_scaling_factor + smooth * freq;
-            }
-        }
-        if (yarn_ramp_inv_factor_div_2) {
-            PRAGMA_UNROLL
-            for (int i = 0; i < N; i += 2) {
-                auto  freq  = inv_freq_[i / 2];
-                float alpha = (idx + i) * yarn_ramp_inv_factor_div_2 - yarn_ramp_inv_factor_mul_min;
-                alpha       = fmaxf(0.f, fminf(1.f, alpha));
-                inv_freq_[i / 2] = freq - freq * alpha * yarn_inv_scaling_factor;
-            }
-        }
-    }
-
-    template
-    __device__ void apply(Array& x, float timestep)
-    {
-#if 0
-        PRAGMA_UNROLL
-        for (int i = 0; i < N; i += 2) {
-            float c, s;
-            sincosf(timestep * inv_freq_[i / 2], &s, &c);
-            s *= attention_scaling_;
-            c *= attention_scaling_;
-            float tmp0 = c * (float)x[i] - s * (float)x[i + 1];
-            float tmp1 = c * (float)x[i + 1] + s * (float)x[i];
-            if (is_valid_) {
-                x[i]     = (T)tmp0;
-                x[i + 1] = (T)tmp1;
-            }
-        }
-#else
-        // Most models apply rotary embedding in half precision
-        PRAGMA_UNROLL
-        for (int i = 0; i < N; i += 2) {
-            float c, s;
-            sincosf(timestep * inv_freq_[i / 2], &s, &c);
-            s *= attention_scaling_;
-            c *= attention_scaling_;
-            T tmp0 = (T)c * x[i] - (T)s * x[i + 1];
-            T tmp1 = (T)c * x[i + 1] + (T)s * x[i];
-            if (is_valid_) {
-                x[i]     = tmp0;
-                x[i + 1] = tmp1;
-            }
-        }
-#endif
-    }
-};
-
 template
 struct RoPE {
     Array inv_freqs_;
diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu
index ef11977fc8..b25bc21d9f 100644
--- a/src/turbomind/models/llama/rotary_emb.cu
+++ b/src/turbomind/models/llama/rotary_emb.cu
@@ -168,6 +168,17 @@ RotaryEmbeddingV2::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_
             break;
         }
         case RopeType::kLlama3: {
+            // clang-format off
+            /* The [llama3 rope](https://github.com/huggingface/transformers/blob/5f4ee98a7ade33e1c54fdd6181d04ee7b426b392/src/transformers/modeling_rope_utils.py#L298)
+             * used by llama3.1 is equivalent to the following equation, given the precomputed parameters:
+             ```C++
+             inv_scaling_factor = 1 / factor;
+             inv_diff_freq_factor = 1 / (high_freq_factor - low_freq_factor);
+             alpha = old_context_len / (2 * PI) * inv_diff_freq_factor;
+             beta = low_freq_factor * inv_diff_freq_factor
+             ```
+             */
+            // clang-format on
             const double PI            = 3.14159265358979323846;
             float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor);
             llama3_.inv_scaling_factor = 1.0 / param.rope.factor;
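
For reference, a minimal host-side sketch of the llama3 parameter precomputation and frequency smoothing described by the comment added above. It only illustrates the math from the patch; `Llama3Params`, `make_llama3_params`, and `smooth_inv_freq` are illustrative names rather than TurboMind APIs, and the values in `main` assume the commonly published llama3.1 defaults (factor 8, low_freq_factor 1, high_freq_factor 4, original context 8192).

```C++
// Sketch only: compute the precomputed llama3 RoPE parameters described in the
// comment above, then apply the same frequency smoothing the removed FastRoPE
// constructor performed. Names here are illustrative, not TurboMind symbols.
#include <algorithm>
#include <cmath>
#include <cstdio>

struct Llama3Params {
    float inv_scaling_factor;  // 1 / factor
    float alpha;               // old_context_len / (2 * PI) * inv_diff_freq_factor
    float beta;                // low_freq_factor * inv_diff_freq_factor
};

Llama3Params make_llama3_params(float factor, float low_freq_factor, float high_freq_factor, float old_context_len)
{
    const double PI                   = 3.14159265358979323846;
    const float  inv_diff_freq_factor = 1.f / (high_freq_factor - low_freq_factor);
    return {1.f / factor,
            static_cast<float>(old_context_len / (2 * PI) * inv_diff_freq_factor),
            low_freq_factor * inv_diff_freq_factor};
}

// Blend scaled and unscaled inverse frequencies with a clamped ramp, as in the
// llama3 branch of the removed constructor.
float smooth_inv_freq(float freq, const Llama3Params& p)
{
    const float smooth = std::clamp(p.alpha * freq - p.beta, 0.f, 1.f);
    return (1.f - smooth) * freq * p.inv_scaling_factor + smooth * freq;
}

int main()
{
    const Llama3Params p    = make_llama3_params(8.f, 1.f, 4.f, 8192.f);
    const int          dims = 128;
    for (int i = 0; i < dims; i += 2) {
        // base^(-i/dims), matching ti_scale * exp2f((idx + i) * -log2f(base) / dims)
        const float freq = std::exp2(-std::log2(10000.f) * i / dims);
        std::printf("i=%3d  inv_freq: %.6e -> %.6e\n", i, freq, smooth_inv_freq(freq, p));
    }
}
```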