Commit c6bd5fe

fix moe gating & config parsing
1 parent 9b1a82b commit c6bd5fe

5 files changed (+40, -50 lines)


lmdeploy/turbomind/deploy/config.py (+2, -2)

@@ -58,8 +58,8 @@ class ModelConfig:
     expert_num: List[int] = ()
     expert_inter_size: int = 0
     experts_per_token: int = 0
-    moe_shared_gate: int = False
-    norm_topk_prob: int = False
+    moe_shared_gate: bool = False
+    norm_topk_prob: bool = False
     routed_scale: float = 1.0
     topk_group: int = 1
     topk_method: str = 'greedy'
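
Note on the fix: both fields carried bool defaults but were annotated as int, so the generated config advertised the wrong type to readers. For orientation, a sketch of the C++-side parameter block these fields map onto (field names and types taken from the LlamaTritonModel.cc hunk below; the actual struct definition is not part of this diff and may differ):

    #include <string>

    // Sketch of the MoE parameter block the config fields feed into.
    // Hypothetical layout; only the names/types are taken from this diff.
    struct MoeParam {
        int         experts_per_token = 0;
        int         inter_size        = 0;       // expert_inter_size
        bool        shared_gate       = false;   // moe_shared_gate (was parsed as int)
        bool        norm_topk_prob    = false;   // was exported under a mismatched key
        float       routed_scale      = 1.f;
        int         topk_group        = 1;
        std::string topk_method       = "greedy";
    };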

lmdeploy/turbomind/deploy/source_model/qwen.py (+1, -1)

@@ -178,6 +178,6 @@ def model_info(self):
         info['experts_per_token'] = cfg['num_experts_per_tok']
         info['inter_size'] = cfg['shared_expert_intermediate_size']
         info['moe_shared_gate'] = True
-        info['moe_norm_topk_prob'] = cfg['norm_topk_prob']
+        info['norm_topk_prob'] = cfg['norm_topk_prob']
         info['attn_bias'] = 1
         return info
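
The gating flag was being exported as moe_norm_topk_prob while the TurboMind reader looks up norm_topk_prob (see LlamaTritonModel.cc below), so the value never reached the engine; renaming the key fixes the round trip. For context, norm_topk_prob in Qwen2-MoE style gating controls whether the selected top-k routing weights are renormalized to sum to 1 before use. A minimal host-side sketch of that step, illustrative only (the real implementation is MoeGateKernel_v8 in the next file):

    #include <algorithm>
    #include <cmath>
    #include <numeric>
    #include <utility>
    #include <vector>

    // Top-k softmax gating for one token. Hypothetical reference code, not
    // the TurboMind kernel: it only shows what norm_topk_prob/routed_scale do.
    std::vector<std::pair<int, float>>
    gate(const std::vector<float>& logits, int k, bool norm_topk_prob, float routed_scale)
    {
        // softmax over all expert logits
        float m = *std::max_element(logits.begin(), logits.end());
        std::vector<float> p(logits.size());
        float sum = 0.f;
        for (size_t e = 0; e < logits.size(); ++e) {
            p[e] = std::exp(logits[e] - m);
            sum += p[e];
        }
        for (auto& x : p) {
            x /= sum;
        }
        // pick the k largest probabilities
        std::vector<int> idx(logits.size());
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                          [&](int a, int b) { return p[a] > p[b]; });
        // optionally renormalize the selected weights to sum to 1
        float topk_sum = 0.f;
        for (int i = 0; i < k; ++i) {
            topk_sum += p[idx[i]];
        }
        std::vector<std::pair<int, float>> out;
        for (int i = 0; i < k; ++i) {
            float w = norm_topk_prob ? p[idx[i]] / topk_sum : p[idx[i]];
            out.emplace_back(idx[i], w * routed_scale);
        }
        return out;
    }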

src/turbomind/kernels/gemm/moe_utils_v2.cu (+16, -5)

@@ -515,9 +515,11 @@ __global__ void MoeGateKernel_v8(float* scales,  // [e,n]

     PRAGMA_UNROLL
     for (int i = 0; i < max_tiles * max_expert_num; i += block_dim) {
-        int e = (i + threadIdx.x) % max_expert_num;
-        int t = (i + threadIdx.x) / max_expert_num;
-        smem.shared_accum[t][e] = 0;
+        int e = (i + threadIdx.x) % max_expert_num;
+        int t = (i + threadIdx.x) / max_expert_num;
+        if (t < max_tiles) {
+            smem.shared_accum[t][e] = 0;
+        }
     }

     __syncthreads();
@@ -538,8 +540,6 @@ __global__ void MoeGateKernel_v8(float* scales,  // [e,n]
             masks[expert_id * token_num_padded + ti2] = idx;
             scales[idx * token_num + ti2] = scale * routed_scale;
             atomicAdd(&smem.shared_accum[ti2 >> log_tile][expert_id], 1);
-
-            // printf("%d %d %f\n", idx, expert_id, scale);
         }
     }

@@ -613,6 +613,17 @@ void invokeMoeGate_V2(int* f2n,  // [e*n] -> n

     if (experts <= 8) {
         if (experts_per_token <= 2) {
+            // MoeGateKernel_V2<2, 128><<<cdiv(tokens, 128), 128, 0, st>>>(scales,
+            //                                                            (int8_t*)masks,
+            //                                                            accum,
+            //                                                            logits,
+            //                                                            log_tile,
+            //                                                            tiles,
+            //                                                            tokens,
+            //                                                            tokens_padded,
+            //                                                            experts);
+
+            // std::cout << tokens << " " << experts << " " << experts_per_token << " " << tokens_padded << "\n";
             invoke(_Int<8>, _Int<2>, _Int<8>, _Int<4>);
         }
         else {
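
The added t < max_tiles guard matters whenever max_tiles * max_expert_num is not a multiple of block_dim: on the last loop iteration, i + threadIdx.x runs past the end of the flattened array, t lands one row past shared_accum, and the zeroing writes out of bounds in shared memory. A stripped-down repro of the pattern with small hypothetical sizes:

    // Minimal repro of the zeroing pattern, not the real kernel.
    // With TILES * EXPERTS = 24 and block_dim = 16, the second iteration
    // (i = 16) covers indices 16..31; indices 24..31 give t == 3 == TILES,
    // so the unguarded version writes one row past accum. Launch as
    // zero_accum<<<1, 16>>>().
    __global__ void zero_accum()
    {
        constexpr int TILES = 3, EXPERTS = 8;
        constexpr int block_dim = 16;
        __shared__ int accum[TILES][EXPERTS];

        for (int i = 0; i < TILES * EXPERTS; i += block_dim) {
            int e = (i + threadIdx.x) % EXPERTS;
            int t = (i + threadIdx.x) / EXPERTS;
            if (t < TILES) {  // the guard added in this commit
                accum[t][e] = 0;
            }
        }
        __syncthreads();
    }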

src/turbomind/models/llama/llama_utils.cu (+19, -40)

@@ -1,49 +1,26 @@
 // Copyright (c) OpenMMLab. All rights reserved.

-#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
-#include "src/turbomind/models/llama/llama_utils.h"
-#include "src/turbomind/utils/cuda_utils.h"
 #include <cmath>
 #include <cstdio>
 #include <cstdlib>
 #include <cstring>
+#include <type_traits>
+#include <vector>
+
 #include <cuda_fp16.h>
 #include <curand_kernel.h>
 #include <thrust/device_vector.h>
 #include <thrust/execution_policy.h>
 #include <thrust/host_vector.h>
-#include <vector>
+
+#include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/utils/cuda_utils.h"

 namespace turbomind {

 CmpMode compare_mode = kCmpRead;
 // CmpMode compare_mode = kCmpWrite;

-template<typename T>
-struct abs_diff_t {
-    using type = T;
-};
-
-template<>
-struct abs_diff_t<half> {
-    using type = float;
-};
-
-template<>
-struct abs_diff_t<__nv_bfloat16> {
-    using type = float;
-};
-
-template<typename T>
-struct abs_diff: public thrust::unary_function<thrust::tuple<T, T>, typename abs_diff_t<T>::type> {
-    __host__ __device__ float operator()(thrust::tuple<T, T> x) const
-    {
-        using R = typename abs_diff_t<T>::type;
-        auto r = R(thrust::get<0>(x)) - R(thrust::get<1>(x));
-        return r < R(0) ? -r : r;
-    }
-};
-
 template<typename T>
 void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream)
 {
@@ -64,10 +41,8 @@ void CheckNan(const T* ptr, size_t size, std::string key, cudaStream_t stream)
 template<typename T>
 void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
 {
-    // wait for b
-    check_cuda_error(cudaStreamSynchronize(stream));
     // read a from file
-    thrust::host_vector<T> h_a(size);
+    std::vector<T> h_a(size);
     {
         const auto filename = "tmp/" + key + ".cmp";
         std::ifstream ifs(filename, std::ios::binary);
@@ -86,17 +61,21 @@ void CmpRead(T* ptr, size_t size, std::string key, cudaStream_t stream)
         }
         ifs.read((char*)h_a.data(), sizeof(T) * h_a.size());
     }
-    // copy a to device
-    thrust::device_vector<T> a = h_a;
-    // create abs(a - b) iterator
-    thrust::device_ptr<T> dev_ptr(ptr);
-    auto zip_iter = thrust::make_zip_iterator(thrust::make_tuple(a.begin(), dev_ptr));
-    auto transform_iter = thrust::make_transform_iterator(zip_iter, abs_diff<T>{});
-    // sum(abs(a - b))
-    auto asum = thrust::reduce(thrust::device, transform_iter, transform_iter + size);
+    std::vector<T> h_b(size);
+    check_cuda_error(cudaMemcpyAsync(h_b.data(), ptr, sizeof(T) * size, cudaMemcpyDefault, stream));
+    check_cuda_error(cudaStreamSynchronize(stream));
+
+    using Tacc = std::conditional_t<std::is_integral_v<T>, int64_t, float>;
+
+    Tacc asum{};
+    for (size_t i = 0; i < size; ++i) {
+        asum += std::abs((Tacc)h_a[i] - (Tacc)h_b[i]);
+    }
+
     std::cerr << key << ": " << asum << " " << asum / size << "\n";

     check_cuda_error(cudaMemcpyAsync(ptr, h_a.data(), sizeof(T) * h_a.size(), cudaMemcpyDefault, stream));
+    check_cuda_error(cudaStreamSynchronize(stream));
 }

 template<typename T>
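
The rewrite drops the thrust zip/transform/reduce pipeline (and its abs_diff helper templates) in favor of a plain host loop: copy the device buffer down, then accumulate |a - b| in a widened type, int64_t for integral element types and float otherwise, so half/bfloat16 inputs are compared in float and integer sums cannot overflow. The synchronize added after the final copy-back also ensures h_a is not freed while the async copy may still be reading from it. The widening pattern in isolation (standalone sketch, not the project's code):

    #include <cmath>
    #include <cstdint>
    #include <cstdlib>
    #include <iostream>
    #include <type_traits>
    #include <vector>

    // Sum of absolute differences with a widened accumulator: integral
    // element types sum in int64_t, floating types in float.
    template<typename T>
    void l1_report(const std::vector<T>& a, const std::vector<T>& b, const char* key)
    {
        using Tacc = std::conditional_t<std::is_integral_v<T>, int64_t, float>;
        Tacc asum{};
        for (size_t i = 0; i < a.size(); ++i) {
            asum += std::abs((Tacc)a[i] - (Tacc)b[i]);
        }
        std::cerr << key << ": " << asum << " " << asum / a.size() << "\n";
    }

    int main()
    {
        l1_report<int8_t>({1, 2, 3}, {1, 2, 7}, "int8");    // sums in int64_t: 4
        l1_report<float>({1.f, 2.f}, {1.5f, 2.f}, "fp32");  // sums in float: 0.5
    }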

src/turbomind/triton_backend/llama/LlamaTritonModel.cc (+2, -2)

@@ -315,8 +315,8 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,

     moe_param_.experts_per_token = model_reader["experts_per_token"].as<int>(0);
     moe_param_.inter_size = model_reader["expert_inter_size"].as<int>(0);
-    moe_param_.shared_gate = model_reader["moe_shared_gate"].as<int>(0);
-    moe_param_.norm_topk_prob = model_reader["norm_topk_prob"].as<bool>(false);
+    moe_param_.shared_gate = model_reader["moe_shared_gate"].as<bool>();
+    moe_param_.norm_topk_prob = model_reader["norm_topk_prob"].as<bool>();
     moe_param_.routed_scale = model_reader["routed_scale"].as<float>(1.f);
     moe_param_.topk_group = model_reader["topk_group"].as<int>(1);
     moe_param_.topk_method = model_reader["topk_method"].as<std::string>("greedy");
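
Dropping the fallback from as<bool>() turns a missing or misspelled key into a hard parse error instead of a silent false, which is plausibly how the moe_norm_topk_prob/norm_topk_prob mismatch above went unnoticed. A small sketch of the difference, assuming stock yaml-cpp behavior (an undefined node's as<T>() throws, while as<T>(fallback) returns the fallback):

    #include <iostream>
    #include <yaml-cpp/yaml.h>

    int main()
    {
        YAML::Node cfg = YAML::Load("moe_shared_gate: true");

        // With a fallback: a missing key silently becomes the default.
        bool norm = cfg["norm_topk_prob"].as<bool>(false);  // false, no error

        // Without a fallback: a missing key throws, surfacing the typo'd
        // key immediately instead of mis-gating the model.
        try {
            bool strict = cfg["norm_topk_prob"].as<bool>();
            (void)strict;
        } catch (const YAML::Exception& e) {
            std::cerr << "config error: " << e.what() << "\n";
        }

        std::cout << cfg["moe_shared_gate"].as<bool>() << " " << norm << "\n";
    }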
