fix rms norm, rotary embedding & deepseek v2 attention

lzhangzz committed Nov 28, 2024
1 parent 329e441 commit 90d2529
Showing 13 changed files with 320 additions and 73 deletions.
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/config.py
@@ -84,6 +84,7 @@ def verify(self):
class AttentionConfig:
rotary_embedding: int = 128
rope_theta: float = 10000.0
softmax_scale: float = 0
attention_factor: float = None
max_position_embeddings: int = 0
original_max_position_embeddings: int = 0
8 changes: 0 additions & 8 deletions lmdeploy/turbomind/deploy/module.py
@@ -245,14 +245,6 @@ def _export(self, idx: int, xs, kind: str, pack_fn, **kwargs):
q_b = q

cfg = self.model.model_config
qk_nope_dim = cfg.size_per_head - cfg.qk_rope_dim

q_b = q_b.reshape(-1, cfg.size_per_head)

# [nope_dim | rope_dim] -> [rope_dim | nope_dim]
q_nope, q_pe = torch.split(q_b, (qk_nope_dim, cfg.qk_rope_dim), dim=-1)
q_b = torch.cat((q_pe, q_nope),
dim=-1).view(-1, cfg.head_num * cfg.size_per_head)

o = o.reshape(cfg.head_num, cfg.v_head_dim, -1)
o = torch.nn.functional.pad(
23 changes: 16 additions & 7 deletions lmdeploy/turbomind/deploy/source_model/deepseek2.py
@@ -57,7 +57,7 @@ def mla_norm(self, i: int):
return (*result, )


def get_yarn_attention_factor(rope_scaling: dict):
def get_yarn_params(rope_scaling: dict):

scaling_factor = float(rope_scaling['factor'])
mscale = rope_scaling['mscale']
@@ -71,7 +71,13 @@ def yarn_get_mscale(scale=1, mscale=1):
_mscale = float(
yarn_get_mscale(scaling_factor, mscale) /
yarn_get_mscale(scaling_factor, mscale_all_dim))
return _mscale

softmax_scale = 0
if mscale_all_dim:
scale = yarn_get_mscale(scaling_factor, mscale_all_dim)
softmax_scale = scale * scale

return _mscale, softmax_scale


@INPUT_MODELS.register_module(name='deepseek2')
@@ -100,11 +106,12 @@ def model_info(self):
inter_size = [n_shared_experts * expert_inter_size] * num_layer
inter_size[0] = cfg['intermediate_size']
norm_topk_prob = cfg['norm_topk_prob']
size_per_head = qk_rope_dim + qk_nope_dim
info.update(kv_lora_rank=cfg['kv_lora_rank'],
q_lora_rank=cfg['q_lora_rank'] or 0,
qk_rope_dim=qk_rope_dim,
v_head_dim=cfg['v_head_dim'],
size_per_head=qk_rope_dim + qk_nope_dim,
size_per_head=size_per_head,
rotary_embedding=qk_rope_dim,
expert_num=expert_num,
expert_inter_size=expert_inter_size,
@@ -118,8 +125,10 @@ def model_info(self):
tune_layer_num=2)
rope_scaling = cfg.get('rope_scaling')
if rope_scaling and rope_scaling['type'] == 'yarn':
info.update(
max_position_embeddings=rope_scaling[
'original_max_position_embeddings'],
attention_factor=get_yarn_attention_factor(rope_scaling))
attention_factor, softmax_scale = get_yarn_params(rope_scaling)
softmax_scale *= size_per_head**(-0.5)
info.update(max_position_embeddings=rope_scaling[
'original_max_position_embeddings'],
attention_factor=attention_factor,
softmax_scale=softmax_scale)
return info
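
For reference, a minimal standalone sketch of the YaRN parameter computation above. The body of yarn_get_mscale is collapsed in this hunk, so the formula below follows the standard DeepSeek-V2 formulation and should be treated as an assumption; the helper name and example values are illustrative, not part of the commit.

```python
import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Standard YaRN magnitude correction (assumed; the body is collapsed above).
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def yarn_params_example(rope_scaling: dict, size_per_head: int):
    scaling_factor = float(rope_scaling['factor'])
    mscale = rope_scaling['mscale']
    mscale_all_dim = rope_scaling['mscale_all_dim']

    # Factor later applied to cos/sin inside the rotary embedding.
    attention_factor = yarn_get_mscale(scaling_factor, mscale) \
        / yarn_get_mscale(scaling_factor, mscale_all_dim)

    # Square of the all-dim mscale; stays 0 when mscale_all_dim is unset.
    softmax_scale = 0.0
    if mscale_all_dim:
        s = yarn_get_mscale(scaling_factor, mscale_all_dim)
        softmax_scale = s * s

    # model_info() folds the usual 1/sqrt(head_dim) into the same value.
    softmax_scale *= size_per_head ** (-0.5)
    return attention_factor, softmax_scale


# Illustrative call with values resembling DeepSeek-V2's published config:
# yarn_params_example({'factor': 40.0, 'mscale': 0.707, 'mscale_all_dim': 0.707}, 192)
```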
17 changes: 17 additions & 0 deletions src/turbomind/kernels/attention/rotary_embedding.h
@@ -131,6 +131,7 @@ struct FastRoPE {
template<typename T>
__device__ void apply(Array<T, N>& x, float timestep)
{
#if 0
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
float c, s;
@@ -144,6 +145,22 @@ struct FastRoPE {
x[i + 1] = (T)tmp1;
}
}
#else
// Most models apply rotary embedding in half precision
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
float c, s;
sincosf(timestep * inv_freq_[i / 2], &s, &c);
s *= attention_scaling_;
c *= attention_scaling_;
T tmp0 = (T)c * x[i] - (T)s * x[i + 1];
T tmp1 = (T)c * x[i + 1] + (T)s * x[i];
if (is_valid_) {
x[i] = tmp0;
x[i + 1] = tmp1;
}
}
#endif
}
};

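
A minimal PyTorch sketch of what the new #else branch computes for a single head: the cos/sin terms are scaled by the YaRN attention factor and cast to the input dtype before the rotation, so the multiply-adds run in half precision like most reference model implementations. The is_valid_ guard and per-thread vector layout are omitted; names here are illustrative.

```python
import torch


def fast_rope_apply(x: torch.Tensor, timestep: float, inv_freq: torch.Tensor,
                    attention_scaling: float = 1.0) -> torch.Tensor:
    """Rotate adjacent pairs (x[2i], x[2i+1]); x is [..., d] with d even,
    inv_freq is [d // 2]."""
    angles = timestep * inv_freq
    # Scale and cast cos/sin to x.dtype so the rotation itself runs in
    # half precision, mirroring the '#else' branch of FastRoPE::apply.
    cos = (torch.cos(angles) * attention_scaling).to(x.dtype)
    sin = (torch.sin(angles) * attention_scaling).to(x.dtype)
    x0, x1 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = cos * x0 - sin * x1
    out[..., 1::2] = cos * x1 + sin * x0
    return out
```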
19 changes: 19 additions & 0 deletions src/turbomind/kernels/gemm/moe_utils_v2.cu
@@ -733,6 +733,7 @@ __global__ void MoeReduceKernel(T* dst, // [ n, d]
}

for (int i = threadIdx.x; i < dims; i += block_dim) {
#if 1
Array<float, vec_size> accum{};
if (dst_scale) {
Vec v;
@@ -749,6 +750,24 @@ __global__ void MoeReduceKernel(T* dst, // [ n, d]
accum = accum + x;
}
Store(dst_ptr[i].data(), cast<T>(accum));
#else
Array<T, vec_size> accum{};
if (dst_scale) {
Vec v;
Ldg(v, dst_ptr[i].data());
using namespace ops;
accum = v * (T)dst_scale;
}
PRAGMA_UNROLL
for (int e = 0; e < exp_k; ++e) {
Vec v;
Ldg(v, src_ptr[e][i].data());
using namespace ops;
const auto x = v * (T)scale[e];
accum = accum + x;
}
Store(dst_ptr[i].data(), accum);
#endif
}
}

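
A rough PyTorch equivalent of the reduction in MoeReduceKernel, assuming one token and exp_k routed experts; accum_fp32 switches between the float accumulator kept in the '#if 1' branch and the low-precision accumulation of the '#else' branch. This is a sketch for comparison only, not the kernel's implementation.

```python
import torch


def moe_reduce(dst: torch.Tensor, src, scale, dst_scale: float = 0.0,
               accum_fp32: bool = True) -> torch.Tensor:
    """Combine routed-expert outputs for one token.

    dst:   [d] existing output, blended in only when dst_scale is non-zero
    src:   sequence of exp_k tensors of shape [d]
    scale: per-expert routing weights
    """
    acc_dtype = torch.float32 if accum_fp32 else dst.dtype
    accum = torch.zeros_like(dst, dtype=acc_dtype)
    if dst_scale:
        accum = dst.to(acc_dtype) * dst_scale
    for s, x in zip(scale, src):
        accum = accum + x.to(acc_dtype) * s
    return accum.to(dst.dtype)
```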
159 changes: 143 additions & 16 deletions src/turbomind/kernels/norm/rms_norm.cu
@@ -8,8 +8,15 @@
namespace turbomind {

template<class T, class Accum, int block_dim, int vec_size>
__global__ void RMSNormKernel(
T* dst, int dst_ld, const T* src, int src_ld, const T* weights, int dims, int num, float eps, float inv_dims)
__global__ void RMSNormKernel(T* dst,
int dst_ld,
const T* src,
int src_ld,
const T* __restrict__ weights,
int dims,
int num,
float eps,
float inv_dims)
{
const int ti = blockIdx.x;
const int di = threadIdx.x * vec_size;
@@ -56,32 +63,34 @@ __global__ void RMSNormKernel(
Array<T, vec_size> sv;
for (int i = di; i < dims; i += block_dim * vec_size) {
Load(vec, &src[i]);
Array<Accum, vec_size> tmp = cast<Accum>(vec);
Load(sv, &weights[i]);
Ldg(sv, &weights[i]);
PRAGMA_UNROLL
for (int c = 0; c < vec_size; ++c) {
tmp[c] *= (float)sv[c] * sum;
vec[c] = (T)((float)vec[c] * sum) * sv[c];
// vec[c] = (T)((float)vec[c] * sum * (float)sv[c]);
}
Store(&dst[i], cast<T>(tmp));
Store(&dst[i], vec);
}
}

template<class T>
void invokeRMSNorm(
T* dst, int dst_ld, const T* src, int src_ld, const T* weights, int dims, int num, float eps, cudaStream_t st)
{
constexpr int threads = 256;
constexpr int vec_size = 16 / sizeof(T);

constexpr int threads = 512;
const int blocks = num;

RMSNormKernel<T, float, threads, 8><<<blocks, threads, 0, st>>>(dst, //
dst_ld,
src,
src_ld,
weights,
dims,
num,
eps,
1.f / dims);
RMSNormKernel<T, float, threads, vec_size><<<blocks, threads, 0, st>>>(dst, //
dst_ld,
src,
src_ld,
weights,
dims,
num,
eps,
1.f / dims);
}

template void invokeRMSNorm(half* dst,
@@ -105,4 +114,122 @@ template void invokeRMSNorm(nv_bfloat16* dst,
cudaStream_t st);
#endif

// r' <- r + (h + b)
// h' <- norm(r') * w
template<class T, class Tacc, int block_dim, int vec_size>
__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual,
T* __restrict__ hidden_states,
const T* __restrict__ weights,
const T* __restrict__ bias,
int dims,
int num,
float eps,
float inv_dims)
{
const int ti = blockIdx.x;
const int di = threadIdx.x * vec_size;

if (ti >= num) {
return;
}

residual += dims * ti;
hidden_states += dims * ti;

Array<Tacc, vec_size> accum{};

Array<T, vec_size> r_vec;
Array<T, vec_size> h_vec;
Array<T, vec_size> b_vec;

for (int i = di; i < dims; i += block_dim * vec_size) {
Load(r_vec, &residual[i]);
Load(h_vec, &hidden_states[i]);

using namespace ops;
r_vec = r_vec + h_vec;

if (bias) {
Ldg(b_vec, &bias[i]);
r_vec = r_vec + b_vec;
}

Store(&residual[i], r_vec);

Array<Tacc, vec_size> tmp = cast<Tacc>(r_vec);

accum = accum + tmp * tmp;
}

float sum{};
PRAGMA_UNROLL
for (int i = 0; i < vec_size; ++i) {
sum += accum[i];
}

using BlockReduce = cub::BlockReduce<Tacc, block_dim>;
__shared__ typename BlockReduce::TempStorage temp_storage;

sum = BlockReduce{temp_storage}.Sum(sum);

__shared__ float shared_sum;

if (threadIdx.x == 0) {
shared_sum = rsqrtf(sum * inv_dims + eps);
}

__syncthreads();

sum = shared_sum;

Array<T, vec_size> w_vec;
for (int i = di; i < dims; i += block_dim * vec_size) {
Load(r_vec, &residual[i]);
Ldg(w_vec, &weights[i]);
PRAGMA_UNROLL
for (int c = 0; c < vec_size; ++c) {
r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c];
}
Store(&hidden_states[i], r_vec);
}
}

template<class T>
void invokeBiasResidualRMSNorm(
T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st)
{
constexpr int vec_size = 16 / sizeof(T);
constexpr int threads = 512;
const int blocks = num;

BiasResidualRMSNormKernel<T, float, threads, vec_size><<<blocks, threads, 0, st>>>(residual, //
hidden_states,
weights,
bias,
dims,
num,
eps,
1.f / dims);
}

template void invokeBiasResidualRMSNorm(half* residual,
half* hidden_states,
const half* weights,
const half* bias,
int dims,
int num,
float eps,
cudaStream_t st);

#if ENABLE_BF16
template void invokeBiasResidualRMSNorm(nv_bfloat16* residual,
nv_bfloat16* hidden_states,
const nv_bfloat16* weights,
const nv_bfloat16* bias,
int dims,
int num,
float eps,
cudaStream_t st);
#endif

} // namespace turbomind
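
As a reference for the kernels above, a short PyTorch sketch of the same math: plain RMSNorm with an fp32 reduction, and the fused bias + residual variant (r' = r + h + b, h' = norm(r') * w). In-place updates and vectorized loads are left out; this is only a numerical reference under those assumptions, not the implementation.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Reduce in fp32, normalize, cast back, then scale by the weight in the
    # input dtype -- matching vec[c] = (T)((float)vec[c] * sum) * sv[c].
    inv = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * inv).to(x.dtype) * weight


def bias_residual_rms_norm(residual, hidden, weight, bias, eps):
    # r' = r + h (+ b); h' = rms_norm(r') * w.  The kernel updates both
    # buffers in place; new tensors are returned here for clarity.
    residual = residual + hidden
    if bias is not None:
        residual = residual + bias
    return residual, rms_norm(residual, weight, eps)
```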
10 changes: 10 additions & 0 deletions src/turbomind/kernels/norm/rms_norm.h
@@ -8,4 +8,14 @@ template<class T>
void invokeRMSNorm(
T* dst, int dst_ld, const T* src, int src_ld, const T* weights, int dims, int num, float eps, cudaStream_t st);

template<class T>
void invokeRMSNorm(T* dst, const T* src, const T* weights, int dims, int num, float eps, cudaStream_t st)
{
invokeRMSNorm(dst, dims, src, dims, weights, dims, num, eps, st);
}

template<class T>
void invokeBiasResidualRMSNorm(
T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, float eps, cudaStream_t st);

} // namespace turbomind
1 change: 1 addition & 0 deletions src/turbomind/models/llama/llama_params.h
@@ -63,6 +63,7 @@ struct AttentionParam {
int rotary_embedding_dim;
float rotary_embedding_base;
int max_position_embeddings;
float softmax_scale;
std::string rope_scaling_type;
int original_max_position_embeddings;
float rope_scaling_factor;
15 changes: 12 additions & 3 deletions src/turbomind/models/llama/llama_utils.cu
@@ -65,14 +65,23 @@ void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
check_cuda_error(cudaMemcpyAsync(h_b.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));
check_cuda_error(cudaStreamSynchronize(stream));

using Tacc = std::conditional_t<std::is_integral_v<T>, int64_t, float>;
using Tacc = std::conditional_t<std::is_integral_v<T>, int64_t, float>;
constexpr Tacc eps = std::is_integral_v<T> ? 1 : 1e-8f;

Tacc asum{};
Tacc rsum{};
Tacc amean{};
for (size_t i = 0; i < size; ++i) {
asum += std::abs((Tacc)h_a[i] - (Tacc)h_b[i]);
Tacc x = (Tacc)h_b[i];
Tacc r = (Tacc)h_a[i];
Tacc abs_diff = std::abs(x - r);
Tacc rel_diff = abs_diff / std::max(std::max(std::abs(r), std::abs(x)), eps);
asum += abs_diff;
rsum += rel_diff;
amean += std::abs(r);
}

std::cerr << key << ": " << asum << " " << asum / size << "\n";
std::cerr << key << ": " << amean / size << " " << asum << " " << asum / size << " " << rsum / size << "\n";

check_cuda_error(cudaMemcpyAsync(ptr, h_a.data(), sizeof(T) * h_a.size(), cudaMemcpyDefault, stream));
check_cuda_error(cudaStreamSynchronize(stream));
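
A small NumPy sketch of the statistics CmpRead now prints (mean |ref|, sum and mean of the absolute difference, and mean relative difference), assuming h_a holds the reference values and h_b the device output; the function name is illustrative.

```python
import numpy as np


def cmp_stats(ref: np.ndarray, out: np.ndarray, eps: float = 1e-8):
    ref = ref.astype(np.float64)
    out = out.astype(np.float64)
    abs_diff = np.abs(out - ref)
    # Relative difference is taken against the larger magnitude, floored by eps.
    rel_diff = abs_diff / np.maximum(np.maximum(np.abs(ref), np.abs(out)), eps)
    return np.abs(ref).mean(), abs_diff.sum(), abs_diff.mean(), rel_diff.mean()
```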