Commit: use FastRoPE

irexyc committed Dec 4, 2024
1 parent d9d5a38 commit 832ade3
Showing 4 changed files with 14 additions and 100 deletions.
2 changes: 1 addition & 1 deletion src/turbomind/kernels/attention/attention_universal.h
@@ -231,7 +231,7 @@ struct AttentionUniversal {
ApplyBias(vec_Q, vec_K, vec_V, params, head_idx, kv_head_idx, offset);

if (params.cos_sin) {
-PrecomputeFastRoPE rope{};
+FastRoPE rope{};
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
PRAGMA_UNROLL
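The kernel takes this precomputed-table path only when params.cos_sin is non-null, i.e. when the cos/sin value for every (position, channel) pair has been filled in ahead of time. As a rough illustration of what such a table could look like — the layout, names, and default base below are assumptions for this sketch, not code from the repository:

```C++
#include <cmath>
#include <vector>

// Illustrative sketch: a [seq_len, dims] table of interleaved (cos, sin)
// pairs that a kernel branch like the one above could consume.
std::vector<float> make_cos_sin(int seq_len, int dims, float base = 10000.f)
{
    std::vector<float> table(static_cast<size_t>(seq_len) * dims);
    for (int t = 0; t < seq_len; ++t) {
        for (int i = 0; i < dims; i += 2) {
            const float inv_freq = exp2f(-log2f(base) * i / dims);  // base^(-i/dims)
            table[t * dims + i] = cosf(t * inv_freq);
            table[t * dims + i + 1] = sinf(t * inv_freq);
        }
    }
    return table;
}
```

Precomputing once moves the per-token sincosf work out of the attention kernel; the same substitution is made in the KV-cache path below.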
2 changes: 1 addition & 1 deletion src/turbomind/kernels/attention/kv_cache_utils_v2.cu
@@ -130,7 +130,7 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
}
}

-PrecomputeFastRoPE rope;
+FastRoPE rope{};
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
PRAGMA_UNROLL
99 changes: 1 addition & 98 deletions src/turbomind/kernels/attention/rotary_embedding.h
@@ -67,7 +67,7 @@ __device__ void ApplyRotaryEmbedding(Array<T, 4>& x, float base, int dims, int t
}
}

-struct PrecomputeFastRoPE {
+struct FastRoPE {

template<typename T, int N>
__device__ void apply(Array<T, N>& x, Array<T, N>& cs)
@@ -82,103 +82,6 @@ struct PrecomputeFastRoPE {
}
};
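The hunk fold above hides the body of apply. For orientation, here is a minimal sketch of the pairwise rotation such an interface performs — assuming cs carries the interleaved (cos, sin) pair for each even/odd element of x; this is a reading of the signature, not the folded source:

```C++
template<typename T, int N>
__device__ void apply(Array<T, N>& x, Array<T, N>& cs)
{
    PRAGMA_UNROLL
    for (int i = 0; i < N; i += 2) {
        // rotate (x[i], x[i+1]) by the angle whose cos/sin sit in cs[i], cs[i+1]
        const float c = (float)cs[i];
        const float s = (float)cs[i + 1];
        const float tmp0 = c * (float)x[i] - s * (float)x[i + 1];
        const float tmp1 = c * (float)x[i + 1] + s * (float)x[i];
        x[i] = (T)tmp0;
        x[i + 1] = (T)tmp1;
    }
}
```

The struct deleted below is what this replaces: a variant that rebuilt inv_freq_ and called sincosf on every application instead of reading a table.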

-template<class D, int N>
-struct FastRoPE {
-
-static_assert(N % 2 == 0);
-
-Array<float, N / 2> inv_freq_;
-bool is_valid_;
-float attention_scaling_;
-
-__device__ FastRoPE(int idx,
-D dims,
-float base,
-float ti_scale,
-float factor,
-float llama3_inv_scaling_factor,
-float llama3_alpha,
-float llama3_beta,
-float yarn_ramp_inv_factor_div_2,
-float yarn_ramp_inv_factor_mul_min,
-float yarn_inv_scaling_factor,
-float attention_scaling,
-std::integral_constant<int, N>)
-{
-is_valid_ = idx < dims;
-attention_scaling_ = attention_scaling;
-/// TODO: Take this away from device code
-const float scale_factor = -log2f(base) / dims;
-PRAGMA_UNROLL
-for (int i = 0; i < N; i += 2) {
-inv_freq_[i / 2] = ti_scale * exp2f((idx + i) * scale_factor);
-}
-// clang-format off
-/* The [llama3 rope](https://github.com/huggingface/transformers/blob/5f4ee98a7ade33e1c54fdd6181d04ee7b426b392/src/transformers/modeling_rope_utils.py#L298)
-* used by llama3.1 is equivalent to the following equation, given the precomputed parameters:
-```C++
-inv_scaling_factor = 1 / factor;
-inv_diff_freq_factor = 1 / (high_freq_factor - low_freq_factor);
-alpha = old_context_len / (2 * PI) * inv_diff_freq_factor;
-beta = low_freq_factor * inv_diff_freq_factor;
-```
-*/
-// clang-format on
-if (llama3_inv_scaling_factor) {
-PRAGMA_UNROLL
-for (int i = 0; i < N; i += 2) {
-auto freq = inv_freq_[i / 2];
-auto smooth = fmaxf(0.f, fminf(1.f, llama3_alpha * freq - llama3_beta));
-inv_freq_[i / 2] = (1 - smooth) * freq * llama3_inv_scaling_factor + smooth * freq;
-}
-}
-if (yarn_ramp_inv_factor_div_2) {
-PRAGMA_UNROLL
-for (int i = 0; i < N; i += 2) {
-auto freq = inv_freq_[i / 2];
-float alpha = (idx + i) * yarn_ramp_inv_factor_div_2 - yarn_ramp_inv_factor_mul_min;
-alpha = fmaxf(0.f, fminf(1.f, alpha));
-inv_freq_[i / 2] = freq - freq * alpha * yarn_inv_scaling_factor;
-}
-}
-}
-
-template<typename T>
-__device__ void apply(Array<T, N>& x, float timestep)
-{
-#if 0
-PRAGMA_UNROLL
-for (int i = 0; i < N; i += 2) {
-float c, s;
-sincosf(timestep * inv_freq_[i / 2], &s, &c);
-s *= attention_scaling_;
-c *= attention_scaling_;
-float tmp0 = c * (float)x[i] - s * (float)x[i + 1];
-float tmp1 = c * (float)x[i + 1] + s * (float)x[i];
-if (is_valid_) {
-x[i] = (T)tmp0;
-x[i + 1] = (T)tmp1;
-}
-}
-#else
-// Most models apply rotary embedding in half precision
-PRAGMA_UNROLL
-for (int i = 0; i < N; i += 2) {
-float c, s;
-sincosf(timestep * inv_freq_[i / 2], &s, &c);
-s *= attention_scaling_;
-c *= attention_scaling_;
-T tmp0 = (T)c * x[i] - (T)s * x[i + 1];
-T tmp1 = (T)c * x[i + 1] + (T)s * x[i];
-if (is_valid_) {
-x[i] = tmp0;
-x[i + 1] = tmp1;
-}
-}
-#endif
-}
-};
-
template<int N, int C = 8>
struct RoPE {
Array<float, N> inv_freqs_;
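The fused form in the deleted constructor — smooth = clamp(llama3_alpha * freq - llama3_beta, 0, 1) — is algebraically identical to the branchy Hugging Face llama3 scaling, which is what the comment claims. A small host-side check; the llama3.1-style values (factor 8, low/high_freq_factor 1/4, old_context_len 8192, base 500000, dims 128) are assumed for illustration, not read from any config:

```C++
#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    const float PI = 3.14159265f;
    const float factor = 8.f, low_freq_factor = 1.f, high_freq_factor = 4.f;
    const float old_context_len = 8192.f;

    // Precomputed parameters, as documented in the comment.
    const float inv_scaling_factor = 1.f / factor;
    const float inv_diff_freq_factor = 1.f / (high_freq_factor - low_freq_factor);
    const float alpha = old_context_len / (2.f * PI) * inv_diff_freq_factor;
    const float beta = low_freq_factor * inv_diff_freq_factor;

    for (int i = 0; i < 128; i += 2) {
        const float freq = exp2f(-log2f(500000.f) * i / 128.f);  // base^(-i/dims)
        const float wavelen = 2.f * PI / freq;

        // Reference: the branchy Hugging Face formulation.
        float ref;
        if (wavelen > old_context_len / low_freq_factor) {
            ref = freq / factor;
        }
        else if (wavelen < old_context_len / high_freq_factor) {
            ref = freq;
        }
        else {
            const float s = (old_context_len / wavelen - low_freq_factor) * inv_diff_freq_factor;
            ref = (1.f - s) * freq / factor + s * freq;
        }

        // Fused form used by the kernel: one clamp, no branches.
        const float smooth = std::max(0.f, std::min(1.f, alpha * freq - beta));
        const float fused = (1.f - smooth) * freq * inv_scaling_factor + smooth * freq;

        printf("i=%3d ref=%.8f fused=%.8f\n", i, ref, fused);
    }
}
```

The two columns agree because old_context_len / wavelen = old_context_len * freq / (2 * PI), so the Hugging Face smooth factor is exactly alpha * freq - beta, and the clamp to [0, 1] reproduces the two outer branches.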
11 changes: 11 additions & 0 deletions src/turbomind/models/llama/rotary_emb.cu
@@ -168,6 +168,17 @@ RotaryEmbeddingV2<T>::RotaryEmbeddingV2(const AttentionParam& param, cudaStream_
break;
}
case RopeType::kLlama3: {
+// clang-format off
+/* The [llama3 rope](https://github.com/huggingface/transformers/blob/5f4ee98a7ade33e1c54fdd6181d04ee7b426b392/src/transformers/modeling_rope_utils.py#L298)
+* used by llama3.1 is equivalent to the following equation, given the precomputed parameters:
+```C++
+inv_scaling_factor = 1 / factor;
+inv_diff_freq_factor = 1 / (high_freq_factor - low_freq_factor);
+alpha = old_context_len / (2 * PI) * inv_diff_freq_factor;
+beta = low_freq_factor * inv_diff_freq_factor;
+```
+*/
+// clang-format on
const double PI = 3.14159265358979323846;
float inv_diff_freq_factor = 1.0 / (param.rope.llama3.high_freq_factor - param.rope.llama3.low_freq_factor);
llama3_.inv_scaling_factor = 1.0 / param.rope.factor;
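The hunk ends before the remaining two parameters from the comment are assigned. A hypothetical continuation in the same spirit — llama3_.alpha, llama3_.beta, and especially the config field original_max_position_embeddings are guesses for illustration, not the folded source:

```C++
// Hypothetical continuation: the other two precomputed parameters
// from the comment above (field names are assumed).
llama3_.alpha = param.rope.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
llama3_.beta = param.rope.llama3.low_freq_factor * inv_diff_freq_factor;
```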
