Deepseek v3 (#11502)
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Co-authored-by: robertgshaw2-neuralmagic <[email protected]>
3 people authored Dec 27, 2024
1 parent 55fb97f commit f49777b
Showing 7 changed files with 887 additions and 61 deletions.
160 changes: 141 additions & 19 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -113,6 +113,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
}
}

// TODO(simon): this is temporarily adapted from
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
// We did this to unblock Deepseek V3, but there should be a better
// implementation to manage shared memory.
template <typename scalar_t>
__global__ void moe_align_block_size_global_mem_kernel(
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;

for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}

/**
* In the first step we compute tokens_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}

__syncthreads();

// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}

__syncthreads();

// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}

__syncthreads();

/**
* Each thread (one per expert, when threadIdx.x < num_experts) writes the
* expert_id into expert_ids for every block in that expert's padded segment.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}

/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, the output would be
* [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value (preset in Python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and (after the prefix sum above)
* tokens_cnts[threadIdx.x][expert_id] stores how many tokens assigned to
* expert_id come from lower-indexed thread shards, i.e. this thread's
* starting offset within that expert's segment.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
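For reference, here is a minimal CPU sketch in Python of the same alignment logic (a hypothetical helper written for this commentary, not part of the commit). It reproduces the worked example from the comment above, using -1 as the padding value that the Python caller would otherwise preset.

import math
from typing import List, Tuple

def moe_align_block_size_ref(topk_ids: List[int], num_experts: int,
                             block_size: int,
                             pad_value: int = -1) -> Tuple[List[int], List[int], int]:
    # Step 1: count how many tokens are routed to each expert.
    counts = [0] * num_experts
    for e in topk_ids:
        counts[e] += 1

    # Step 2: prefix sum of per-expert counts, each rounded up to a multiple
    # of block_size (this is what cumsum[] holds in the kernel).
    cumsum = [0] * (num_experts + 1)
    for e in range(num_experts):
        cumsum[e + 1] = cumsum[e] + math.ceil(counts[e] / block_size) * block_size
    total_post_pad = cumsum[num_experts]

    # Step 3: every block of block_size slots belongs to exactly one expert.
    expert_ids = [0] * (total_post_pad // block_size)
    for e in range(num_experts):
        for i in range(cumsum[e], cumsum[e + 1], block_size):
            expert_ids[i // block_size] = e

    # Step 4: scatter each token index into its expert's padded segment.
    sorted_token_ids = [pad_value] * total_post_pad
    offsets = [0] * num_experts
    for i, e in enumerate(topk_ids):
        sorted_token_ids[cumsum[e] + offsets[e]] = i
        offsets[e] += 1
    return sorted_token_ids, expert_ids, total_post_pad

# topk_ids = [0, 1, 2, 1, 2, 3, 0, 3, 4] with block_size = 4 and num_experts = 5
# yields sorted_token_ids =
# [0, 6, -1, -1, 1, 3, -1, -1, 2, 4, -1, -1, 5, 7, -1, -1, 8, -1, -1, -1],
# matching the example in the kernel comment above.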

template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
@@ -137,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem =
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);

// set dynamic shared mem
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<1, num_thread, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});

// If we have a very large number of experts, we can no longer use shared
// memory.
// TODO(simon): the right solution is to calculate the exact amount of
// shared memory needed and use that. The num_experts >= 256 check is just
// a temporary solution to unblock Deepseek V3.
if (num_experts >= 256) {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
// calc needed amount of global memory for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

const int32_t mem_tokens_cnts =
((num_experts + 1) * num_experts) * sizeof(int32_t);
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
// allocate global memory
int32_t* tokens_cnts;
int32_t* cumsum;
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
cudaMalloc(&cumsum, mem_cumsum);

auto kernel =
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
kernel<<<1, num_thread, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), tokens_cnts, cumsum);
cudaFree(tokens_cnts);
cudaFree(cumsum);
});
} else {
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem =
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);

// set dynamic shared mem
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<1, num_thread, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});
}
}
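To see why the num_experts >= 256 check forces the global-memory path, it helps to work out the shared-memory requirement of the original kernel. A back-of-the-envelope sketch in Python (WARP_SIZE = 32 and the per-block shared-memory limits in the comments are typical values, assumed here rather than taken from this diff):

WARP_SIZE = 32  # assumed warp size

def shared_mem_bytes(num_experts: int) -> int:
    # Mirrors the host-side formula above: space for `tokens_cnts` plus
    # `cumsum`, in int32 entries.
    num_thread = max(num_experts, WARP_SIZE)
    return ((num_thread + 1) * num_experts + (num_experts + 1)) * 4

print(shared_mem_bytes(64))   # 16900 bytes (~17 KB): fits in shared memory
print(shared_mem_bytes(256))  # 264196 bytes (~258 KB): beyond typical 48-228 KB per-block limits

Deepseek V3 routes across 256 experts, so the request would exceed the dynamic shared memory available per block on current GPUs; the temporary workaround above allocates `tokens_cnts` and `cumsum` in global memory instead.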

void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
11 changes: 9 additions & 2 deletions vllm/config.py
@@ -596,6 +596,12 @@ def _verify_cuda_graph(self) -> None:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)

if (self.hf_config.model_type == 'deepseek_v3'
and not self.enforce_eager):
logger.warning("CUDA graph is not supported for Deepseek V3 yet, "
"fallback to the eager mode.")
self.enforce_eager = True

def _verify_bnb_config(self) -> None:
"""
The current version of bitsandbytes (0.44.0) with 8-bit models does not
@@ -712,8 +718,9 @@ def get_hidden_size(self) -> int:

def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
in ('deepseek_v2', 'deepseek_v3')):
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
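For context, the 192 head size being padded here comes from MLA's query/key layout: qk_nope_head_dim + qk_rope_head_dim = 128 + 64 = 192 in the published DeepSeek-V2/V3 configs. A small illustrative sketch (the helper and the hard-coded dims are for illustration only, not vLLM API):

FLASH_ATTN_HEAD_SIZES = (32, 64, 128, 256)  # head sizes FlashAttention supports

def padded_head_size(qk_nope_head_dim: int = 128, qk_rope_head_dim: int = 64) -> int:
    head_size = qk_nope_head_dim + qk_rope_head_dim  # 192 for deepseek_v2 / deepseek_v3
    # Pad up to the smallest supported head size.
    return min(s for s in FLASH_ATTN_HEAD_SIZES if s >= head_size)

assert padded_head_size() == 256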
17 changes: 14 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -476,18 +476,29 @@ def fused_topk(
return topk_weights, topk_ids


# This is used by the Deepseek-V2 model
# This is used by the Deepseek-V2 and Deepseek-V3 models
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0):
topk_group: int = 0,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None):

assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")

scores = torch.softmax(gating_output, dim=-1)
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")

if e_score_correction_bias is not None:
scores.add_(e_score_correction_bias.unsqueeze(0))

num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
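The hunk is truncated here, so for completeness, below is a simplified, self-contained sketch of the grouped top-k flow that the new scoring_func and e_score_correction_bias arguments feed into. The group-selection steps after group_scores are not shown in this diff; that part is modeled on the pre-existing Deepseek-V2 path and should be read as an approximation rather than the committed code.

from typing import Optional

import torch


def grouped_topk_sketch(hidden_states: torch.Tensor,
                        gating_output: torch.Tensor,
                        topk: int,
                        renormalize: bool,
                        num_expert_group: int,
                        topk_group: int,
                        scoring_func: str = "softmax",
                        e_score_correction_bias: Optional[torch.Tensor] = None):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if e_score_correction_bias is not None:
        scores = scores + e_score_correction_bias.unsqueeze(0)

    num_token, num_experts = scores.shape
    # Score each group by its best expert and keep the top `topk_group` groups.
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        num_experts // num_expert_group).reshape(num_token, -1)

    # Take the top-k experts within the selected groups only.
    masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)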