Skip to content

Commit

Permalink
Merge branch 'vllm-project:main' into qwen-support-lora
Browse files Browse the repository at this point in the history
  • Loading branch information
jeejeelee authored Oct 28, 2024
2 parents 6462961 + 2adb440 commit 9b126aa
Show file tree
Hide file tree
Showing 164 changed files with 4,706 additions and 1,748 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/stale.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: 'Close inactive issues and PRs'

on:
schedule:
# Daily at 1:30 AM UTC
- cron: '30 1 * * *'

jobs:
close-issues-and-pull-requests:
permissions:
issues: write
pull-requests: write
actions: write
runs-on: ubuntu-latest
steps:
- uses: actions/stale@28ca1036281a5e5922ead5184a1bbf96e5fc984e # v9.0.0
with:
# Increasing this value ensures that changes to this workflow
# propagate to all issues and PRs in days rather than months
operations-per-run: 1000

exempt-draft-pr: true
exempt-issue-labels: 'keep-open'
exempt-pr-labels: 'keep-open'

labels-to-add-when-unstale: 'unstale'
labels-to-remove-when-stale: 'unstale'

days-before-issue-stale: 90
days-before-issue-close: 30
stale-issue-label: 'stale'
stale-issue-message: >
This issue has been automatically marked as stale because it has not
had any activity within 90 days. It will be automatically closed if no
further activity occurs within 30 days. Leave a comment if
you feel this issue should remain open. Thank you!
close-issue-message: >
This issue has been automatically closed due to inactivity. Please
feel free to reopen if you feel it is still relevant. Thank you!
days-before-pr-stale: 90
days-before-pr-close: 30
stale-pr-label: 'stale'
stale-pr-message: >
This pull request has been automatically marked as stale because it
has not had any activity within 90 days. It will be automatically
closed if no further activity occurs within 30 days. Leave a comment
if you feel this pull request should remain open. Thank you!
close-pr-message: >
This pull request has been automatically closed due to inactivity.
Please feel free to reopen if you intend to continue working on it.
Thank you!
16 changes: 9 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
set(TORCH_SUPPORTED_VERSION_CUDA "2.5.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")

#
Expand Down Expand Up @@ -169,12 +169,12 @@ endif()

#
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
# Configure it to place files in vllm/.deps, in order to play nicely with sccache.
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
# Each dependency that produces build artifacts should override its BINARY_DIR to avoid
# conflicts between build types. It should instead be set to ${CMAKE_BINARY_DIR}/<dependency>.
#
include(FetchContent)
get_filename_component(PROJECT_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" ABSOLUTE)
file(MAKE_DIRECTORY "${FETCHCONTENT_BASE_DIR}")
set(FETCHCONTENT_BASE_DIR "${PROJECT_ROOT_DIR}/.deps")
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

#
Expand All @@ -195,7 +195,6 @@ set(VLLM_EXT_SRC
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")

Expand Down Expand Up @@ -405,6 +404,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu")

set_gencode_flags_for_srcs(
Expand Down Expand Up @@ -507,8 +507,10 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
GIT_TAG 5259c586c403a4e4d8bf69973c159b40cc346fb9
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
endif()

Expand Down
4 changes: 3 additions & 1 deletion benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,10 @@ def main(args: argparse.Namespace):
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
total_output_tokens = sum(output_len for _, _, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

# Output JSON results if specified
if args.output_json:
Expand Down
33 changes: 17 additions & 16 deletions benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,23 @@ def prepare(i: int):
input_gating.copy_(gating_output[i])

def run():
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
override_config=config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)

# JIT compilation & warmup
run()
Expand Down
6 changes: 1 addition & 5 deletions cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
# dependencies that are not necessary and may not be installed.
if (GPU_LANGUAGE STREQUAL "CUDA")
if ("${CUDA_CUDA_LIB}" STREQUAL "")
set(CUDA_CUDA_LIB "${CUDA_CUDA_LIBRARY}")
endif()
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB}
${CUDA_LIBRARIES})
target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver)
else()
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
endif()
Expand Down
42 changes: 42 additions & 0 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,48 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]

namespace vllm {

template <typename T>
__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
const float f = (float)x;
return (T)(f > threshold ? f : 0.0f);
}

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
__global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x, param) * y;
}
}

} // namespace vllm

#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
});

void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
namespace vllm {

// Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "../cuda_compat.h"
#include "../dispatch_utils.h"

#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

namespace vllm {
namespace moe {

namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
Expand All @@ -32,10 +34,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
extern __shared__ int32_t shared_mem[];

int32_t* tokens_cnts =
shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
int32_t* cumsum =
shared_mem + (num_experts + 1) *
num_experts; // 1d tensor with shape (num_experts + 1)
shared_mem +
(blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)

for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
Expand All @@ -53,10 +55,12 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
__syncthreads();

// For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}

__syncthreads();
Expand All @@ -79,9 +83,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}

/**
Expand All @@ -106,6 +112,24 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}

template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., topk, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
scalar_t x = 0.0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
}
out[token_idx * d + idx] = x;
}
}

} // namespace moe
} // namespace vllm

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
Expand All @@ -117,18 +141,62 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem =
((num_experts + 1) * num_experts + (num_experts + 1)) *
((num_thread + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);

// set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>(
kernel<<<1, num_thread, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel());
});
}

void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
const int hidden_size = input.size(-1);
const int num_tokens = output.numel() / hidden_size;
const int topk = input.size(1);

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

switch (topk) {
case 2:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;

case 3:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;

case 4:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;

default:
at::sum_out(output, input, 1);
break;
}
}
7 changes: 7 additions & 0 deletions csrc/moe/moe_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,10 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);

void moe_sum(torch::Tensor& input, torch::Tensor& output);

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
Loading

0 comments on commit 9b126aa

Please sign in to comment.