diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index cd3a5e80d7bd0..445d74d6d9bbe 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -72,7 +72,7 @@ steps: commands: - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py - - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py + - pytest -v -s distributed/test_pipeline_parallel.py - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py @@ -115,12 +115,7 @@ steps: working_dir: "/vllm-workspace/tests" num_gpus: 4 commands: - - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py - - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py - - TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py - - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py - - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py - + - pytest -v -s distributed/test_pipeline_parallel.py - label: Engine Test mirror_hardwares: [amd] diff --git a/CMakeLists.txt b/CMakeLists.txt index ced73ca03bfbc..335623bd2677d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -151,6 +151,7 @@ set(VLLM_EXT_SRC "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" + "csrc/prepare_inputs/advance_step.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/Dockerfile.rocm b/Dockerfile.rocm index befb0499f2e68..85dfda8dbb532 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,11 +1,6 @@ # Default ROCm 6.1 base image ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" -# Tested and supported base rocm/pytorch images -ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \ - ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \ - ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" - # Default ROCm ARCHes to build vLLM for. 
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" @@ -54,18 +49,6 @@ RUN pip install --upgrade pip RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)" # Install torch == 2.5.0 on ROCm RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-5.7"*) \ - pip uninstall -y torch torchaudio torchvision \ - && pip install --no-cache-dir --pre \ - torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \ - torchvision==0.20.0.dev20240710 \ - --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \ - *"rocm-6.0"*) \ - pip uninstall -y torch torchaudio torchvision \ - && pip install --no-cache-dir --pre \ - torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \ - torchvision==0.20.0.dev20240710 \ - --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \ *"rocm-6.1"*) \ pip uninstall -y torch torchaudio torchvision \ && pip install --no-cache-dir --pre \ @@ -104,11 +87,6 @@ RUN --mount=type=cache,target=${CCACHE_DIR} \ && cd flash-attention \ && git checkout "${FA_BRANCH}" \ && git submodule update --init \ - && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-5.7"*) \ - export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \ - && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \ - *) ;; esac \ && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ # Create an empty directory otherwise as later build stages expect one else mkdir -p /install; \ @@ -161,12 +139,9 @@ RUN --mount=type=cache,target=${CCACHE_DIR} \ --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.0"*) \ - patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \ *"rocm-6.1"*) \ # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM - wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \ - && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \ + wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \ # Prevent interference if torch bundles its own HIP runtime && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ *) ;; esac \ diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 377f8683c021f..234c2c8a1074c 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -20,18 +20,18 @@ # helpers -def to_fp8(tensor: torch.tensor) -> torch.tensor: +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fn) return torch.round(tensor.clamp( min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) -def to_int8(tensor: torch.tensor) -> torch.tensor: +def to_int8(tensor: torch.Tensor) -> torch.Tensor: return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.tensor, torch.tensor]: + k: int) -> Tuple[torch.Tensor, torch.Tensor]: a = torch.randn((m, k), device='cuda') * 5 b = torch.randn((n, k), device='cuda').t() * 5 @@ -47,15 +47,15 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int, # impl -def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, - scale_b: torch.tensor, - out_dtype: torch.dtype) -> torch.tensor: +def pytorch_mm_impl(a: 
torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: return torch.mm(a, b) -def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, - scale_b: torch.tensor, - out_dtype: torch.dtype) -> torch.tensor: +def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: return torch._scaled_mm(a, b, scale_a=scale_a, @@ -63,9 +63,9 @@ def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, out_dtype=out_dtype) -def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, - scale_a: torch.tensor, scale_b: torch.tensor, - out_dtype: torch.dtype) -> torch.tensor: +def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor, + scale_a: torch.Tensor, scale_b: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: return torch._scaled_mm(a, b, scale_a=scale_a, @@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor, use_fast_accum=True) -def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, - scale_b: torch.tensor, - out_dtype: torch.dtype) -> torch.tensor: +def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype) -> torch.Tensor: return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype) # bench -def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor, - scale_b: torch.tensor, out_dtype: torch.dtype, label: str, +def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, + scale_b: torch.Tensor, out_dtype: torch.dtype, label: str, sub_label: str, fn: Callable, description: str) -> TMeasurement: min_run_time = 1 diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 16de60477c305..78cac8a555d1b 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -100,7 +100,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = time.perf_counter() # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 for _ in range(num_iters): if version == "v1": @@ -117,7 +117,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) elif version == "v2": ops.paged_attention_v2( @@ -136,7 +137,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) else: raise ValueError(f"Invalid version: {version}") diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 91083481705cb..350dbce1d7ba9 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -105,9 +105,9 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int 
blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, kv_scale); + k_vec_quant, k_scale); } } @@ -415,7 +415,7 @@ __device__ void paged_attention_kernel( *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert(v_quant_vec, - kv_scale); + v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, kv_scale, tp_rank, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel( out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - kv_scale, tp_rank, blocksparse_local_blocks, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -694,8 +694,8 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale, 
- const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -770,7 +770,7 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); @@ -815,8 +815,8 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); @@ -833,7 +833,7 @@ void paged_attention_v1( exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel& alibi_slopes, float kv_scale, - const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -932,8 +932,9 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ switch (is_block_sparse) { \ @@ -980,8 +981,8 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cache.h b/csrc/cache.h index 86caa9345361d..52177e8901a89 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -18,8 +18,8 @@ void copy_blocks(std::vector const& 
key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const double kv_scale); + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 72041076ae009..caef7f5e18630 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel( // block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, - const float kv_scale) { + const int head_size, const int block_size, const int x, const float k_scale, + const float v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, kv_scale); + fp8::scaled_convert(tgt_key, k_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, kv_scale); + fp8::scaled_convert(tgt_value, v_scale); } } } @@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, kv_scale); + num_heads, head_size, block_size, x, k_scale, v_scale); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -258,7 +258,8 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double kv_scale) { + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -318,13 +319,13 @@ namespace vllm { template __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, - const float kv_scale, + const float scale, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; dst_cache[idx] = - fp8::scaled_convert(src_cache[idx], kv_scale); + fp8::scaled_convert(src_cache[idx], scale); } } @@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ vllm::convert_fp8_kernel<<>>( \ reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); + reinterpret_cast(dst_cache.data_ptr()), scale, block_stride); // Only for testing. 
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const double kv_scale, const std::string& kv_cache_dtype) { + const double scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 8367093325314..abb4e3bea14bb 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -423,11 +423,11 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", @@ -742,11 +742,11 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 2b5c3bd6ee70b..31d454328b2c1 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -107,8 +107,9 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double kv_scale) { - TORCH_CHECK(kv_scale == 1.0f); + const std::string& kv_cache_dtype, double k_scale, + double v_scale) { + TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); int num_tokens = key.size(0); int num_heads = key.size(1); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 39e8cf3ed3c10..5be0e9810b5b9 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); @@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); @@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float kv_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/ops.h b/csrc/ops.h index fb1099e4fe0c2..1e94a9f45ef08 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -8,8 +8,8 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -19,8 +19,8 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, - const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); +void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables); + #ifndef USE_ROCM torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebooks, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu new file mode 100644 index 0000000000000..0e537ddd6c4cd --- /dev/null +++ b/csrc/prepare_inputs/advance_step.cu @@ -0,0 +1,131 @@ +/* + * The goal of this GPU kernel is to advance input tensors on the GPU directly + * PR: https://github.com/vllm-project/vllm/pull/6338 + * Current restrictions: + * 1. 
Specialized for DraftModelRunner + * 2. Supports flash_attn only + */ + +#include "advance_step.cuh" + +namespace prepare_inputs { + +// +template +__global__ void advance_step_kernel(int num_seqs, int num_queries, + int block_size, long* input_tokens_ptr, + long const* sampled_token_ids_ptr, + long* input_positions_ptr, + int* seq_lens_ptr, long* slot_mapping_ptr, + int const* block_tables_ptr, + int64_t const block_tables_stride) { + int num_query_blocks = div_ceil(num_queries, num_threads); + + if (blockIdx.x >= num_query_blocks) { + return; + } + + int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + + if (cur_query_id >= num_queries) { + return; + } + + // Update input_tokens + input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; + + int seq_len = seq_lens_ptr[cur_query_id]; + int next_seq_len = seq_len + 1; + int next_input_pos = next_seq_len - 1; + + // Update seq_lens + seq_lens_ptr[cur_query_id] = next_seq_len; + // Update input_positions + input_positions_ptr[cur_query_id] = next_input_pos; + + int const* seq_block_tables_ptr = + block_tables_ptr + block_tables_stride * cur_query_id; + + int block_index = next_input_pos / block_size; + int block_offset = next_input_pos % block_size; + + int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; + // Update slot_mapping + slot_mapping_ptr[cur_query_id] = slot_num; +} + +inline void verify_tensor(std::string const& name, torch::Tensor& t, + int64_t const size_0, int64_t const size_1, + c10::ScalarType const type) { + bool size_0_cond = true; + if (size_0 != -1) { + size_0_cond = t.size(0) == size_0; + } + + bool size_1_cond = true; + if (size_1 != -1) { + size_1_cond = t.size(1) == size_1; + } + + bool is_contiguous = t.is_contiguous(); + bool same_type = t.dtype() == type; + + bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; + if (!pass) { + TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), + " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), + " is not as expected: shape = [", size_0, ", ", size_1, + "], type = ", type); + } +} + +void advance_step(int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables) { // type: int + + if (logging) { + printf("advance_step:\n"); + printf(" num_seqs = %d\n", num_seqs); + printf(" num_queries = %d\n", num_queries); + printf(" block_size = %d\n", block_size); + } + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + + int dev = sampled_token_ids.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + advance_step_kernel<<>>( + num_seqs, num_queries, block_size, + reinterpret_cast(input_tokens.data_ptr()), + reinterpret_cast(sampled_token_ids.data_ptr()), + reinterpret_cast(input_positions.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), + 
reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0)); +} + +} // namespace prepare_inputs + +void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables) { + prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, + sampled_token_ids, input_positions, seq_lens, + slot_mapping, block_tables); +} \ No newline at end of file diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh new file mode 100644 index 0000000000000..f21574681b1ab --- /dev/null +++ b/csrc/prepare_inputs/advance_step.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace prepare_inputs { + +static constexpr int max_threads = 256; +static constexpr bool logging = false; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +} // namespace prepare_inputs diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 18331a674eeba..ff9875e0e17a3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); @@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float kv_scale, int tp_rank," - " int blocksparse_local_blocks," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); @@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); + // prepare_inputs advance_step + ops.def("advance_step", &advance_step); + ops.impl("advance_step", torch::kCUDA, &advance_step); + // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -223,7 +227,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float kv_scale) -> ()"); + " float k_scale, float v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. 
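
The kernel and binding changes above replace the single `kv_scale` argument with separate `k_scale` and `v_scale` values threaded through `paged_attention_v1/v2`, `reshape_and_cache`, and the CPU fallbacks. As a rough illustration of what that split buys (not the code path this diff touches), here is a self-contained PyTorch sketch of per-tensor FP8 quantization where keys and values get independent scales; the helper name, shapes, and max-abs calibration below are assumptions made for this example, while the real kernels do the conversion on-device via `fp8::scaled_convert`.

```python
import torch


def quantize_to_fp8(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Divide by the per-tensor scale on the way into the cache; the kernels
    # multiply by the same scale when reading values back out.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)


# Toy K/V tensors: [num_tokens, num_heads, head_size] (sizes are illustrative).
key = torch.randn(16, 8, 128)
value = torch.randn(16, 8, 128)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
# Previously a single kv_scale covered both tensors; with this change K and V
# can be calibrated independently (max-abs calibration shown here as one option).
k_scale = max(key.abs().max().item(), 1e-12) / fp8_max
v_scale = max(value.abs().max().item(), 1e-12) / fp8_max

key_fp8 = quantize_to_fp8(key, k_scale)
value_fp8 = quantize_to_fp8(value, v_scale)
```

This mirrors the default `k_scale = v_scale = 1.0` seen in the benchmark and test updates: with a scale of 1.0 the conversion degenerates to a plain cast, which is why the CPU backend can simply `TORCH_CHECK` both scales equal 1.0.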
diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index cc41d47296f8d..1f9e4fabc4fc9 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,7 +3,7 @@ Installation with ROCm ====================== -vLLM supports AMD GPUs with ROCm 5.7 and 6.0. +vLLM supports AMD GPUs with ROCm 6.1. Requirements ------------ @@ -11,7 +11,7 @@ Requirements * OS: Linux * Python: 3.8 -- 3.11 * GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -* ROCm 6.0 and ROCm 5.7 +* ROCm 6.1 Installation options: @@ -27,10 +27,10 @@ You can build and install vLLM from source. First, build a docker image from `Dockerfile.rocm `_ and launch a docker container from the image. -`Dockerfile.rocm `_ uses ROCm 6.0 by default, but also supports ROCm 5.7. +`Dockerfile.rocm `_ uses ROCm 6.1 by default, but also supports ROCm 5.7 and 6.0 in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: -* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` +* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. * `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For `Radeon RX 7900 series (gfx1100) `_, this should be set to 0 before flash-attention supports this target. * `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` * `FA_BRANCH`: specifies the branch used to build the CK flash-attention in `ROCm's flash-attention repo `_. The default is `ae7928c` @@ -39,24 +39,17 @@ It provides flexibility to customize the build of docker image using the followi Their values can be passed in when running ``docker build`` with ``--build-arg`` options. -To build vllm on ROCm 6.0 for MI200 and MI300 series, you can use the default: +To build vllm on ROCm 6.1 for MI200 and MI300 series, you can use the default: .. code-block:: console - $ docker build -f Dockerfile.rocm -t vllm-rocm . + $ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm . -To build vllm on ROCm 6.0 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: +To build vllm on ROCm 6.1 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: .. code-block:: console - $ docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm . - -To build docker image for vllm on ROCm 5.7, you can specify ``BASE_IMAGE`` as below: - -.. code-block:: console - - $ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \ - -f Dockerfile.rocm -t vllm-rocm . + $ DOCKER_BUILDKIT=1 docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm . To run the above docker image ``vllm-rocm``, use the below command: @@ -85,25 +78,12 @@ Option 2: Build from source 0. 
Install prerequisites (skip if you are already in an environment/docker with the following installed): - `ROCm `_ -- `Pytorch `_ +- `PyTorch `_ - `hipBLAS `_ -For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`. - -Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started `_ - -For rocm6.0: - -.. code-block:: console - - $ pip3 install torch --index-url https://download.pytorch.org/whl/rocm6.0 - - -For rocm5.7: - -.. code-block:: console +For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch-nightly`. - $ pip install torch --index-url https://download.pytorch.org/whl/rocm5.7 +Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guild in PyTorch `Getting Started `_ 1. Install `Triton flash attention for ROCm `_ @@ -115,8 +95,6 @@ Install ROCm's Triton flash attention (the default triton-mlir branch) following Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention `_ .. note:: - - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. - - If you fail to install `ROCm/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) @@ -131,7 +109,6 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl .. tip:: - - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation. - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. - - To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention. - - The ROCm version of pytorch, ideally, should match the ROCm driver version. + - To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention. + - The ROCm version of PyTorch, ideally, should match the ROCm driver version. diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst index 7c44a96865a50..89bdc247c5e8e 100644 --- a/docs/source/getting_started/quickstart.rst +++ b/docs/source/getting_started/quickstart.rst @@ -73,16 +73,13 @@ Start the server: .. code-block:: console - $ python -m vllm.entrypoints.openai.api_server \ - $ --model facebook/opt-125m + $ vllm serve facebook/opt-125m By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument: .. code-block:: console - $ python -m vllm.entrypoints.openai.api_server \ - $ --model facebook/opt-125m \ - $ --chat-template ./examples/template_chatml.jinja + $ vllm serve facebook/opt-125m --chat-template ./examples/template_chatml.jinja This server can be queried in the same format as OpenAI API. 
For example, list the models: diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 53c19e5829218..5cffb58cafd96 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -114,7 +114,7 @@ Just add the following lines in your code: from your_code import YourModelForCausalLM ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM) -If you are running api server with `python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code: +If you are running api server with :code:`vllm serve `, you can wrap the entrypoint with the following code: .. code-block:: python @@ -124,4 +124,4 @@ If you are running api server with `python -m vllm.entrypoints.openai.api_server import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') -Save the above code in a file and run it with `python your_file.py args`. +Save the above code in a file and run it with :code:`python your_file.py `. diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index bdf566d3ebbd1..e7ce8cdcabe88 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -8,7 +8,7 @@ Below, you can find an explanation of every engine argument for vLLM: .. argparse:: :module: vllm.engine.arg_utils :func: _engine_args_parser - :prog: -m vllm.entrypoints.openai.api_server + :prog: vllm serve :nodefaultconst: Async Engine Arguments @@ -19,5 +19,5 @@ Below are the additional arguments related to the asynchronous engine: .. argparse:: :module: vllm.engine.arg_utils :func: _async_engine_args_parser - :prog: -m vllm.entrypoints.openai.api_server + :prog: vllm serve :nodefaultconst: \ No newline at end of file diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index 5cc3076073fbd..f08773fe59d92 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -61,8 +61,7 @@ LoRA adapted models can also be served with the Open-AI compatible vLLM server. .. code-block:: bash - python -m vllm.entrypoints.openai.api_server \ - --model meta-llama/Llama-2-7b-hf \ + vllm serve meta-llama/Llama-2-7b-hf \ --enable-lora \ --lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/ diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index d488b0fefdf06..92aca168dadf2 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -94,9 +94,7 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with .. code-block:: bash - python -m vllm.entrypoints.openai.api_server \ - --model llava-hf/llava-1.5-7b-hf \ - --chat-template template_llava.jinja + vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja .. important:: We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow diff --git a/docs/source/serving/deploying_with_cerebrium.rst b/docs/source/serving/deploying_with_cerebrium.rst index ff0ac911108c4..9585b6ef5cb38 100644 --- a/docs/source/serving/deploying_with_cerebrium.rst +++ b/docs/source/serving/deploying_with_cerebrium.rst @@ -28,6 +28,9 @@ Next, to install the required packages, add the following to your cerebrium.toml .. 
code-block:: toml + [cerebrium.deployment] + docker_base_image_url = "nvidia/cuda:12.1.1-runtime-ubuntu22.04" + [cerebrium.dependencies.pip] vllm = "latest" diff --git a/docs/source/serving/deploying_with_dstack.rst b/docs/source/serving/deploying_with_dstack.rst index baf87314ca8e4..e1eb45b225d9c 100644 --- a/docs/source/serving/deploying_with_dstack.rst +++ b/docs/source/serving/deploying_with_dstack.rst @@ -40,7 +40,7 @@ Next, to provision a VM instance with LLM of your choice(`NousResearch/Llama-2-7 gpu: 24GB commands: - pip install vllm - - python -m vllm.entrypoints.openai.api_server --model $MODEL --port 8000 + - vllm serve $MODEL --port 8000 model: format: openai type: chat diff --git a/docs/source/serving/distributed_serving.rst b/docs/source/serving/distributed_serving.rst index 2dfb83f168b5d..fa1b04dc3dce5 100644 --- a/docs/source/serving/distributed_serving.rst +++ b/docs/source/serving/distributed_serving.rst @@ -35,16 +35,14 @@ To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument wh .. code-block:: console - $ python -m vllm.entrypoints.openai.api_server \ - $ --model facebook/opt-13b \ + $ vllm serve facebook/opt-13b \ $ --tensor-parallel-size 4 You can also additionally specify :code:`--pipeline-parallel-size` to enable pipeline parallelism. For example, to run API server on 8 GPUs with pipeline parallelism and tensor parallelism: .. code-block:: console - $ python -m vllm.entrypoints.openai.api_server \ - $ --model gpt2 \ + $ vllm serve gpt2 \ $ --tensor-parallel-size 4 \ $ --pipeline-parallel-size 2 \ $ --distributed-executor-backend ray diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 092c3c6cb9a3d..a06c30d9c48c6 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat You can start the server using Python, or using [Docker](deploying_with_docker.rst): ```bash -python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 +vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 ``` To call the server, you can use the official OpenAI Python client library, or any other HTTP client. @@ -97,9 +97,7 @@ template, or the template in string form. Without a chat template, the server wi and all chat requests will error. ```bash -python -m vllm.entrypoints.openai.api_server \ - --model ... \ - --chat-template ./path-to-chat-template.jinja +vllm serve --chat-template ./path-to-chat-template.jinja ``` vLLM community provides a set of chat templates for popular models. You can find them in the examples @@ -110,7 +108,7 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) ```{argparse} :module: vllm.entrypoints.openai.cli_args :func: create_parser_for_docs -:prog: -m vllm.entrypoints.openai.api_server +:prog: vllm serve ``` ## Tool calling in the chat completion API diff --git a/examples/api_client.py b/examples/api_client.py index 5f7daa14d5044..27a2a08b7b0c3 100644 --- a/examples/api_client.py +++ b/examples/api_client.py @@ -1,8 +1,7 @@ -"""Example Python client for vllm.entrypoints.api_server +"""Example Python client for `vllm.entrypoints.api_server` NOTE: The API server is used only for demonstration and simple performance benchmarks. It is not intended for production use. 
-For production use, we recommend vllm.entrypoints.openai.api_server -and the OpenAI client API +For production use, we recommend `vllm serve` and the OpenAI client API. """ import argparse diff --git a/examples/logging_configuration.md b/examples/logging_configuration.md index 75b4b31a80462..0d278b0392403 100644 --- a/examples/logging_configuration.md +++ b/examples/logging_configuration.md @@ -95,9 +95,7 @@ to the path of the custom logging configuration JSON file: ```bash VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ - python3 -m vllm.entrypoints.openai.api_server \ - --max-model-len 2048 \ - --model mistralai/Mistral-7B-v0.1 + vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` @@ -152,9 +150,7 @@ to the path of the custom logging configuration JSON file: ```bash VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ - python3 -m vllm.entrypoints.openai.api_server \ - --max-model-len 2048 \ - --model mistralai/Mistral-7B-v0.1 + vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` @@ -167,9 +163,7 @@ loggers. ```bash VLLM_CONFIGURE_LOGGING=0 \ - python3 -m vllm.entrypoints.openai.api_server \ - --max-model-len 2048 \ - --model mistralai/Mistral-7B-v0.1 + vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048 ``` diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index d4d9738a1f7bc..2082c378e267c 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -1,9 +1,7 @@ """An example showing how to use vLLM to serve VLMs. Launch the vLLM server with the following command: -python -m vllm.entrypoints.openai.api_server \ - --model llava-hf/llava-1.5-7b-hf \ - --chat-template template_llava.jinja +vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja """ import base64 diff --git a/examples/production_monitoring/Otel.md b/examples/production_monitoring/Otel.md index 1449442273c7a..2c7a7caa1bd7c 100644 --- a/examples/production_monitoring/Otel.md +++ b/examples/production_monitoring/Otel.md @@ -36,7 +36,7 @@ ``` export OTEL_SERVICE_NAME="vllm-server" export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true - python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" + vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" ``` 1. In a new shell, send requests with trace context from a dummy client @@ -62,7 +62,7 @@ By default, `grpc` is used. To set `http/protobuf` as the protocol, configure th ``` export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces -python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" +vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" ``` ## Instrumentation of FastAPI @@ -74,7 +74,7 @@ OpenTelemetry allows automatic instrumentation of FastAPI. 1. Run vLLM with `opentelemetry-instrument` ``` - opentelemetry-instrument python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" + opentelemetry-instrument vllm serve facebook/opt-125m ``` 1. Send a request to vLLM and find its trace in Jaeger. It should contain spans from FastAPI. 
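
The documentation updates above consistently replace `python -m vllm.entrypoints.openai.api_server --model ...` with the `vllm serve ...` entrypoint. As a usage illustration (not part of this diff), a client session against a server started with `vllm serve facebook/opt-125m` might look like the sketch below; the port and the `"EMPTY"` API key are assumptions for a local, unauthenticated server.

```python
from openai import OpenAI

# Assumes a local server started with: vllm serve facebook/opt-125m
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# List the served models, then request a short deterministic completion.
print([model.id for model in client.models.list().data])

completion = client.completions.create(
    model="facebook/opt-125m",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
)
print(completion.choices[0].text)
```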
diff --git a/examples/production_monitoring/README.md b/examples/production_monitoring/README.md index 268f2e771018f..807c0470e7b30 100644 --- a/examples/production_monitoring/README.md +++ b/examples/production_monitoring/README.md @@ -10,8 +10,7 @@ Install: Prometheus metric logging is enabled by default in the OpenAI-compatible server. Launch via the entrypoint: ```bash -python3 -m vllm.entrypoints.openai.api_server \ - --model mistralai/Mistral-7B-v0.1 \ +vllm serve mistralai/Mistral-7B-v0.1 \ --max-model-len 2048 \ --disable-log-requests ``` diff --git a/rocm_patch/rocm_bf16.patch b/rocm_patch/rocm_bf16.patch deleted file mode 100644 index a0f07da2a3e2b..0000000000000 --- a/rocm_patch/rocm_bf16.patch +++ /dev/null @@ -1,15 +0,0 @@ ---- amd_hip_bf16.h 2024-02-06 18:28:58.268699142 +0000 -+++ amd_hip_bf16.h.new 2024-02-06 18:28:31.988647133 +0000 -@@ -90,10 +90,10 @@ - #include "math_fwd.h" // ocml device functions - - #if defined(__HIPCC_RTC__) --#define __HOST_DEVICE__ __device__ -+#define __HOST_DEVICE__ __device__ static - #else - #include --#define __HOST_DEVICE__ __host__ __device__ -+#define __HOST_DEVICE__ __host__ __device__ static inline - #endif - - // Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 575f8f19b8ebe..5ecd770ede836 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -9,17 +9,17 @@ @pytest.fixture(scope="module") def server(): - with RemoteOpenAIServer([ - "--model", - MODEL_NAME, - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--max-model-len", - "2048", - "--enforce-eager", - "--engine-use-ray" - ]) as remote_server: + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--enforce-eager", + "--engine-use-ray" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 2d9f63795189d..123a77e14ad74 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -1,27 +1,18 @@ -import os - -import openai # use the official client for correctness check import pytest from ..utils import RemoteOpenAIServer -# downloading lora to test lora requests - -# any model with a chat template should work here -MODEL_NAME = "meta-llama/Meta-Llama-3-8B" -EAGER_MODE = bool(int(os.getenv("EAGER_MODE", 0))) -CHUNKED_PREFILL = bool(int(os.getenv("CHUNKED_PREFILL", 0))) -TP_SIZE = int(os.getenv("TP_SIZE", 1)) -PP_SIZE = int(os.getenv("PP_SIZE", 1)) - -pytestmark = pytest.mark.asyncio - -@pytest.fixture(scope="module") -def server(): - args = [ - "--model", - MODEL_NAME, +@pytest.mark.parametrize( + "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [ + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"), + ]) +def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME): + pp_args = [ # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", @@ -32,109 +23,105 @@ def server(): "--distributed-executor-backend", "ray", ] 
+ + # compare without pipeline parallelism + # NOTE: use mp backend for TP + # PP tests might involve multiple nodes, and ray might + # schedule all workers in a node other than the head node, + # which can cause the test to fail. + tp_args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--tensor-parallel-size", + str(max(TP_SIZE, 2)), # use at least TP_SIZE=2 to hold the model + "--distributed-executor-backend", + "mp", + ] if CHUNKED_PREFILL: - args += [ - "--enable-chunked-prefill", - ] + pp_args.append("--enable-chunked-prefill") + tp_args.append("--enable-chunked-prefill") if EAGER_MODE: - args += [ - "--enforce-eager", - ] - with RemoteOpenAIServer(args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() - - -async def test_check_models(server, client: openai.AsyncOpenAI): - models = await client.models.list() - models = models.data - served_model = models[0] - assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in models) - - -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_single_completion(server, client: openai.AsyncOpenAI, - model_name: str): - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - assert completion.choices[0].text is not None and len( - completion.choices[0].text) >= 5 - assert completion.choices[0].finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) - - # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert completion.choices[0].text is not None and len( - completion.choices[0].text) >= 5 - - -@pytest.mark.parametrize( - # just test 1 lora hereafter - "model_name", - [MODEL_NAME], -) -async def test_batch_completions(server, client: openai.AsyncOpenAI, - model_name: str): - # test simple list - batch = await client.completions.create( - model=model_name, - prompt=["Hello, my name is", "Hello, my name is"], - max_tokens=5, - temperature=0.0, - ) - assert len(batch.choices) == 2 - assert batch.choices[0].text == batch.choices[1].text - - # test n = 2 - batch = await client.completions.create( - model=model_name, - prompt=["Hello, my name is", "Hello, my name is"], - n=2, - max_tokens=5, - temperature=0.0, - extra_body=dict( - # NOTE: this has to be true for n > 1 in vLLM, but not necessary - # for official client. 
- use_beam_search=True), - ) - assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" - - # test streaming - batch = await client.completions.create( - model=model_name, - prompt=["Hello, my name is", "Hello, my name is"], - max_tokens=5, - temperature=0.0, - stream=True, - ) - texts = [""] * 2 - async for chunk in batch: - assert len(chunk.choices) == 1 - choice = chunk.choices[0] - texts[choice.index] += choice.text - assert texts[0] == texts[1] + pp_args.append("--enforce-eager") + tp_args.append("--enforce-eager") + + results = [] + for args in [pp_args, tp_args]: + with RemoteOpenAIServer(MODEL_NAME, args) as server: + client = server.get_client() + + # test models list + models = client.models.list() + models = models.data + served_model = models[0] + results.append({ + "test": "models_list", + "id": served_model.id, + "root": served_model.root, + }) + + # test with text prompt + completion = client.completions.create(model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + results.append({ + "test": "single_completion", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + }) + + # test using token IDs + completion = client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + + results.append({ + "test": "token_ids", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + }) + + # test simple list + batch = client.completions.create( + model=MODEL_NAME, + prompt=["Hello, my name is", "Hello, my name is"], + max_tokens=5, + temperature=0.0, + ) + + results.append({ + "test": "simple_list", + "text0": batch.choices[0].text, + "text1": batch.choices[1].text, + }) + + # test streaming + batch = client.completions.create( + model=MODEL_NAME, + prompt=["Hello, my name is", "Hello, my name is"], + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + results.append({ + "test": "streaming", + "texts": texts, + }) + + n = len(results) // 2 + pp_results = results[:n] + tp_results = results[n:] + for pp, tp in zip(pp_results, tp_results): + assert pp == tp diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 32e2d29f2aec5..1abaa01ae192a 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -22,27 +22,27 @@ @pytest.fixture(scope="module") def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811 - with RemoteOpenAIServer([ - "--model", - MODEL_NAME, - # use half precision for speed and memory savings in CI environment - "--dtype", - "bfloat16", - "--max-model-len", - "8192", - "--enforce-eager", - # lora config below - "--enable-lora", - "--lora-modules", - f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", - "--max-lora-rank", - "64", - "--max-cpu-loras", - "2", - "--max-num-seqs", - "128", - ]) as remote_server: + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + 
"bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "128", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index fc5c301f5d536..0896e337b5d24 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -56,36 +56,36 @@ def zephyr_pa_files(): @pytest.fixture(scope="module") def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files): - with RemoteOpenAIServer([ - "--model", - MODEL_NAME, - # use half precision for speed and memory savings in CI environment - "--dtype", - "bfloat16", - "--max-model-len", - "8192", - "--max-num-seqs", - "128", - "--enforce-eager", - # lora config - "--enable-lora", - "--lora-modules", - f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", - "--max-lora-rank", - "64", - "--max-cpu-loras", - "2", - # pa config - "--enable-prompt-adapter", - "--prompt-adapters", - f"zephyr-pa={zephyr_pa_files}", - f"zephyr-pa2={zephyr_pa_files}", - "--max-prompt-adapters", - "2", - "--max-prompt-adapter-token", - "128", - ]) as remote_server: + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + # pa config + "--enable-prompt-adapter", + "--prompt-adapters", + f"zephyr-pa={zephyr_pa_files}", + f"zephyr-pa2={zephyr_pa_files}", + "--max-prompt-adapters", + "2", + "--max-prompt-adapter-token", + "128", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 4a32aadc8c3ae..2ca0c0d63c25c 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -11,17 +11,17 @@ @pytest.fixture(scope="module") def embedding_server(): - with RemoteOpenAIServer([ - "--model", - EMBEDDING_MODEL_NAME, - # use half precision for speed and memory savings in CI environment - "--dtype", - "bfloat16", - "--enforce-eager", - "--max-model-len", - "8192", - "--enforce-eager", - ]) as remote_server: + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--enforce-eager", + "--max-model-len", + "8192", + "--enforce-eager", + ] + + with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index bf63f9a813f2c..c2cfff228c546 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -19,27 +19,27 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def server(zephyr_lora_files): - with RemoteOpenAIServer([ - "--model", - MODEL_NAME, - # use half precision for speed and memory savings in CI environment - "--dtype", - "bfloat16", - "--max-model-len", - "8192", - "--enforce-eager", - # lora 
config below - "--enable-lora", - "--lora-modules", - f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_files}", - "--max-lora-rank", - "64", - "--max-cpu-loras", - "2", - "--max-num-seqs", - "128", - ]) as remote_server: + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "128", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 64f5df50a0eaf..ebf2dbfbb2b4b 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -14,27 +14,26 @@ @pytest.fixture(scope="module") def server(zephyr_lora_added_tokens_files: str): # noqa: F811 - with RemoteOpenAIServer([ - "--model", - MODEL_NAME, - # use half precision for speed and memory savings in CI environment - "--dtype", - "bfloat16", - "--max-model-len", - "8192", - "--enforce-eager", - "--max-num-seqs", - "128", - # lora config - "--enable-lora", - "--lora-modules", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", - "--max-lora-rank", - "64", - ]) as remote_server: + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + "--max-num-seqs", + "128", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server - @pytest.fixture(scope="module") def tokenizer_name(model_name: str, zephyr_lora_added_tokens_files: str): # noqa: F811 diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 563b68566bd2c..cc5c8d619183f 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -23,17 +23,17 @@ @pytest.fixture(scope="module") def server(): - with RemoteOpenAIServer([ - "--model", - MODEL_NAME, - "--dtype", - "bfloat16", - "--max-model-len", - "4096", - "--enforce-eager", - "--chat-template", - str(LLAVA_CHAT_TEMPLATE), - ]) as remote_server: + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--enforce-eager", + "--chat-template", + str(LLAVA_CHAT_TEMPLATE), + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index f848ad51c7014..2e6412c28958e 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -175,7 +175,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 # Call the paged attention kernel. 
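# Illustrative sketch (not part of this diff): why the single kv_scale becomes
# separate k_scale / v_scale in the kernel calls of this test. With an fp8 KV
# cache, keys and values can have very different dynamic ranges, so each tensor
# gets its own dequantization scale. The helper below is a hypothetical
# per-tensor scale computation, shown only to motivate the two-argument form;
# it assumes a PyTorch build that provides torch.float8_e4m3fn.
import torch

def per_tensor_fp8_scale(x: torch.Tensor) -> float:
    # Map the tensor's max magnitude onto the fp8 e4m3 representable range.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x.abs().max() / finfo.max).clamp(min=1e-12).item()

key = torch.randn(16, 8, 128)
value = torch.randn(16, 8, 128) * 4.0  # values often span a wider range than keys
k_scale = per_tensor_fp8_scale(key)
v_scale = per_tensor_fp8_scale(value)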
output = torch.empty_like(query) @@ -193,7 +193,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) elif version == "v2": num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) @@ -224,7 +225,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) else: raise AssertionError(f"Unknown version: {version}") diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 402545d1980d6..b3adb152949a2 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -212,7 +212,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 tp_rank = 0 # Call the paged attention kernel. @@ -231,7 +231,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, @@ -267,7 +268,8 @@ def test_paged_attention( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, tp_rank=tp_rank, blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 23b6baa60c05b..70ae3d0c6e0c3 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -155,11 +155,11 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Using default kv_scale - kv_scale = 1.0 + k_scale = v_scale = 1.0 # Call the reshape_and_cache kernel. ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, kv_scale) + kv_cache_dtype, k_scale, v_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 0191d85194e33..42b15cd6c458e 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,11 +1,13 @@ from typing import List import pytest +import ray from prometheus_client import REGISTRY from vllm import EngineArgs, LLMEngine from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.metrics import RayPrometheusStatLogger from vllm.sampling_params import SamplingParams MODELS = [ @@ -168,6 +170,55 @@ def test_engine_log_metrics_regression( assert_metrics(engine, disable_log_stats, len(example_prompts)) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_metric_spec_decode( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + k = 5 + + with vllm_runner(model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.4, + speculative_model=model, + num_speculative_tokens=k, + use_v2_block_manager=True) as vllm_model: + + # Force log interval to be 0 to catch all metrics. + stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + stat_logger.local_interval = 0 + + # Note that the purpose of this test is to verify spec decode + # metrics instead of functional correctness, so the expected values + # are intended to be loose. 
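# Minimal sketch (not part of this diff) of how the spec-decode ratios checked
# below relate to the raw counters, assuming the usual definitions: the draft
# acceptance rate is accepted / draft tokens, and efficiency compares emitted
# tokens against the best case of k accepted drafts plus one bonus token per step.
def draft_acceptance_rate(num_accepted: int, num_draft: int) -> float:
    return num_accepted / num_draft if num_draft else 0.0

def spec_decode_efficiency(num_emitted: int, num_steps: int, k: int) -> float:
    max_emitted = num_steps * (k + 1)
    return num_emitted / max_emitted if max_emitted else 0.0

assert draft_acceptance_rate(num_accepted=3, num_draft=5) == 0.6
assert spec_decode_efficiency(num_emitted=4, num_steps=1, k=5) == 4 / 6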
+ metric_name_to_expected_fn = { + "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1, + "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1, + "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k, + "counter_spec_decode_num_draft_tokens": lambda v: v == k, + "counter_spec_decode_num_emitted_tokens": + lambda v: 0 <= v <= k + 1, + } + + # Use one request to better inspect the metrics. + prompts = example_prompts[:1] + + _ = vllm_model.generate_greedy(prompts, max_tokens) + for metric_name, is_expected in metric_name_to_expected_fn.items(): + metric_val = getattr( + stat_logger.metrics, + metric_name).labels(**stat_logger.labels)._value.get() + assert is_expected(metric_val), ( + f"the value of metric {metric_name} ({metric_val}) " + "does not meet expectation") + + def assert_metrics(engine: LLMEngine, disable_log_stats: bool, num_requests: int) -> None: if disable_log_stats: @@ -192,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool, labels) assert ( metric_value == num_requests), "Metrics should be collected" + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [16]) +def test_engine_log_metrics_ray( + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # This test is quite weak - it only checks that we can use + # RayPrometheusStatLogger without exceptions. + # Checking whether the metrics are actually emitted is unfortunately + # non-trivial. + + # We have to run in a Ray task for Ray metrics to be emitted correctly + @ray.remote(num_gpus=1) + def _inner(): + + class _RayPrometheusStatLogger(RayPrometheusStatLogger): + + def __init__(self, *args, **kwargs): + self._i = 0 + super().__init__(*args, **kwargs) + + def log(self, *args, **kwargs): + self._i += 1 + return super().log(*args, **kwargs) + + engine_args = EngineArgs( + model=model, + dtype=dtype, + disable_log_stats=False, + ) + engine = LLMEngine.from_engine_args(engine_args) + logger = _RayPrometheusStatLogger( + local_interval=0.5, + labels=dict(model_name=engine.model_config.served_model_name), + max_model_len=engine.model_config.max_model_len) + engine.add_logger("ray", logger) + for i, prompt in enumerate(example_prompts): + engine.add_request( + f"request-id-{i}", + prompt, + SamplingParams(max_tokens=max_tokens), + ) + while engine.has_unfinished_requests(): + engine.step() + assert logger._i > 0, ".log must be called at least once" + + ray.get(_inner.remote()) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 0ed91cbb447fd..82dc775f8d812 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -7,19 +7,49 @@ from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod, + Fp8LinearMethod) MODELS = [ - "neuralmagic/Meta-Llama-3-8B-Instruct-FP8", + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", "nm-testing/Phi-3-mini-128k-instruct-FP8", ] @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="FP8 is not supported on this GPU type.") -@pytest.mark.parametrize("model", MODELS) -def test_model_load_and_run(vllm_runner, model: str): - with vllm_runner(model) as llm: +@pytest.mark.parametrize("model_id", MODELS) +def test_model_load_and_run(vllm_runner, model_id: str): + with vllm_runner(model_id) as llm: + # note: this 
does not test accuracy, just that we can run through + # see lm-eval tests for accuracy + outputs = llm.generate_greedy(prompts=["Hello my name is"], + max_tokens=10) + print(outputs[0][1]) + + +KV_CACHE_MODELS = [ + # Deprecated AutoFP8 format using .kv_scale + "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV", + # AutoFP8 format using separate .k_scale and .v_scale + "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", +] + + +@pytest.mark.skipif(not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_id", KV_CACHE_MODELS) +def test_kv_cache_model_load_and_run(vllm_runner, model_id: str): + with vllm_runner(model_id, kv_cache_dtype="fp8") as llm: + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + attn = model.model.layers[0].self_attn.attn + assert isinstance(attn.quant_method, Fp8KVCacheMethod) + # NOTE: it is valid for scales to be 1.0 (default value), but we know + # these checkpoints have scales < 1.0 + assert 0.0 < attn._k_scale < 1.0 + assert 0.0 < attn._v_scale < 1.0 + # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy outputs = llm.generate_greedy(prompts=["Hello my name is"], diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index fb3415b5db153..da72f6d503c11 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs, } test_name = request.node.name + model = kwargs["model"] + draft_model = kwargs.get("speculative_model", None) + same_draft_target_model = (draft_model is not None + and draft_model == model) + def generator_inner(): wait_for_gpu_memory_to_clear( @@ -177,6 +182,13 @@ def generator_inner(): print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs) + + # Override logging interval to 0 for spec decode test run to + # log all metrics in time. + if (baseline_or_test == "test" and not use_async + and llm.llm_engine.log_stats): + for sate_logger in llm.llm_engine.stat_loggers.values(): + sate_logger.local_interval = 0 set_random_seed(seed) yield llm @@ -188,6 +200,9 @@ def generator_outer(): yield llm del llm + # Set an attribute to the generator_outer function to allow us to + # determine whether to further check the acceptance rate in tests. + generator_outer.same_draft_target_model = same_draft_target_model # type: ignore return generator_outer @@ -204,18 +219,27 @@ def maybe_assert_ngram_worker(llm): def get_output_from_llm_generator( llm_generator, prompts, - sampling_params) -> Tuple[List[str], List[List[int]]]: + sampling_params) -> Tuple[List[str], List[List[int]], float]: tokens: List[str] = [] token_ids: List[List[int]] = [] + acceptance_rate: float = -1.0 for llm in llm_generator(): maybe_assert_ngram_worker(llm) outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] tokens = [output.outputs[0].text for output in outputs] + + # Fetch acceptance rate if logging is enabled. + if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None): + stat_logger = stat_loggers["prometheus"] + acceptance_rate = (stat_logger.metrics. 
+ gauge_spec_decode_draft_acceptance_rate.labels( + **stat_logger.labels)._value.get()) del llm - return tokens, token_ids + return tokens, token_ids, acceptance_rate def get_logprobs_from_llm_generator( @@ -237,7 +261,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, batch_size, max_output_len, force_output_len: bool, - print_tokens: bool = False): + print_tokens: bool = False, + ensure_all_accepted: bool = False): """Helper method that compares the outputs of both the baseline LLM and the test LLM. It asserts greedy equality, e.g. that the outputs are exactly the same when temperature is zero. @@ -267,12 +292,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, temperature=temperature, ) - spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( - test_llm_generator, prompts, sampling_params) + (spec_batch_tokens, spec_batch_token_ids, + acceptance_rate) = get_output_from_llm_generator(test_llm_generator, + prompts, sampling_params) - (baseline_batch_tokens, - baseline_batch_token_ids) = get_output_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) + (baseline_batch_tokens, baseline_batch_token_ids, + _) = get_output_from_llm_generator(baseline_llm_generator, prompts, + sampling_params) assert len(baseline_batch_token_ids) == len(prompts) assert len(spec_batch_token_ids) == len(prompts) @@ -287,3 +313,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, print(f'{i=} {baseline_token_ids=}') print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + + if ensure_all_accepted: + assert acceptance_rate == 1.0 diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 94cc36f22875a..86cab7aba2380 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -97,7 +97,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, temperature=temperature, ) - batch_tokens, batch_token_ids = get_output_from_llm_generator( + batch_tokens, batch_token_ids, _ = get_output_from_llm_generator( test_llm_generator, prompts, sampling_params) # Expect a generation for each prompt in the batch. @@ -200,12 +200,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( Since this test is cheaper than other e2e correctness tests, we generate with a higher output_len. + + When the draft model is the same as the target model, we further check + whether all speculative tokens are accepted. 
""" - run_greedy_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len=output_len, - force_output_len=True) + ensure_all_accepted = test_llm_generator.same_draft_target_model + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ensure_all_accepted=ensure_all_accepted) @pytest.mark.parametrize( diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 9832d4f267e8a..442e40f07f0bb 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k(): assert proposals.proposal_lens.tolist() == [ k for _ in range(expected_num_proposal_seqs - 1) ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k] + + +@torch.inference_mode() +def test_use_draft_model_runner_advance_step(): + """Verify that draft model runner triggers advance step + when applicable. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + k = 5 + batch_size = 32 + block_size = 32 + num_gpu_blocks = 2048 // block_size + worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + + # Mock "_gpu_advance_step" to raise an exception when called. + exception_secret = "artificial stop" + worker.model_runner._gpu_advance_step = MagicMock() + worker.model_runner._gpu_advance_step.side_effect = ValueError( + exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + + # Fallback (should not call) when num_steps=1. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + num_steps=1) + worker.execute_model(execute_model_req=execute_model_req) + + # Expect exception if _gpu_advance_step is called. 
+ execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + num_steps=k) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + call_args_list = worker.model_runner._gpu_advance_step.call_args_list + assert len(call_args_list) == 1 diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index a43f9132585b5..b7030e3cd6d42 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -214,12 +214,12 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ## Start OpenAI API server openai_args = [ - "--model", model_ref, "--dtype", "float16", "--load-format", + "--dtype", "float16", "--load-format", "tensorizer", "--model-loader-extra-config", json.dumps(model_loader_extra_config), ] - with RemoteOpenAIServer(openai_args) as server: + with RemoteOpenAIServer(model_ref, openai_args) as server: print("Server ready.") client = server.get_client() diff --git a/tests/utils.py b/tests/utils.py index 8780d45a31b29..80e0895c551b2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -49,7 +49,13 @@ class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds - def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None: + def __init__( + self, + model: str, + cli_args: List[str], + *, + auto_port: bool = True, + ) -> None: if auto_port: if "-p" in cli_args or "--port" in cli_args: raise ValueError("You have manually specified the port" @@ -68,12 +74,10 @@ def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None: # the current process might initialize cuda, # to be safe, we should use spawn method env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' - self.proc = subprocess.Popen( - [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] + - cli_args, - env=env, - stdout=sys.stdout, - stderr=sys.stderr) + self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args, + env=env, + stdout=sys.stdout, + stderr=sys.stderr) self._wait_for_server(url=self.url_for("health"), timeout=self.MAX_SERVER_START_WAIT_S) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index ae818ee360f19..2126fafb2323b 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -3,7 +3,7 @@ import torch -from vllm.attention import AttentionMetadata +from vllm.attention import AttentionMetadata, AttentionMetadataBuilder from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata @@ -26,6 +26,10 @@ def get_impl_cls(): def get_metadata_cls() -> Type["AttentionMetadata"]: return AttentionMetadata + @staticmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise AttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 03308d04012aa..143957f7b65f0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -84,7 +84,8 @@ def paged_attention_v1( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -94,8 +95,9 @@ def 
paged_attention_v1( torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, - blocksparse_block_size, blocksparse_head_sliding_step) + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step) def paged_attention_v2( @@ -114,7 +116,8 @@ def paged_attention_v2( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -124,7 +127,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -163,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def advance_step(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, seq_lens: torch.Tensor, + slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return torch.ops._C.advance_step(num_seqs, num_queries, block_size, + input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, + block_tables) + + # quantization ops # awq def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, @@ -374,11 +389,12 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, ) -> None: torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, kv_scale) + kv_cache_dtype, k_scale, v_scale) def reshape_and_cache_flash( diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 99a875c9b3fb7..b4721b4e1aedd 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -59,7 +59,8 @@ def paged_attention_v1( max_context_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -99,7 +100,8 @@ def paged_attention_v2( max_context_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -227,7 +229,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, ) -> None: assert kv_cache_dtype == "auto" ipex.llm.modules.PagedAttention.reshape_and_cache( diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index f6bce9a187c64..44bfae44cfddd 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,5 +1,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataBuilder) 
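# Illustrative sketch (not part of this diff): the new builder hook mirrors the
# existing metadata hook, so generic runner code can construct a per-batch
# builder from whatever backend was selected. `backend_cls` and `input_builder`
# are placeholders for objects provided by the model runner.
from typing import Type
from vllm.attention import AttentionBackend, AttentionMetadataBuilder

def make_builder(backend_cls: Type[AttentionBackend],
                 input_builder) -> AttentionMetadataBuilder:
    # Equivalent to backend_cls.make_metadata_builder(input_builder).
    builder_cls = backend_cls.get_builder_cls()
    return builder_cls(input_builder)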
from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -7,6 +8,7 @@ "Attention", "AttentionBackend", "AttentionMetadata", + "AttentionMetadataBuilder", "Attention", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index adb8325168cdf..191c6ff000c85 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,11 +1,15 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields from enum import Enum, auto -from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, - TypeVar) +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, + Tuple, Type, TypeVar) import torch +if TYPE_CHECKING: + from vllm.sequence import SequenceGroupMetadata + from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase + class AttentionType(Enum): DECODER = auto() # Decoder attention between previous layer Q/K/V @@ -35,6 +39,16 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": return cls.get_metadata_cls()(*args, **kwargs) + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @classmethod + def make_metadata_builder(cls, *args, + **kwargs) -> "AttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + @staticmethod @abstractmethod def get_kv_cache_shape( @@ -110,6 +124,33 @@ def asdict_zerocopy(self, T = TypeVar("T", bound=AttentionMetadata) +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self, input_builder) -> None: + raise NotImplementedError + + @abstractmethod + def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata", + token_lens: List[int], seq_lens: List[int], + curr_seq_lens: List[int], query_lens: List[int], + context_lens: List[int], + curr_sliding_window_blocks: List[int], + prefix_cache_hit: bool, chunked_prefill_enabled: bool): + """Add a sequence group to the metadata and update + corresponding fields (in Python objects). 
+ """ + raise NotImplementedError + + @abstractmethod + def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int], + query_lens: List[int], cuda_graph_pad_size: int, + batch_size: int) -> T: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + class AttentionImpl(ABC, Generic[T]): @abstractmethod @@ -134,7 +175,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index fe4c4a45dca0d..71954f864a9b4 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -5,6 +5,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonMetadataBuilder from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn, get_head_sliding_step) from vllm.attention.ops.paged_attn import PagedAttention @@ -93,6 +94,10 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return BlocksparseFlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: + return BlocksparseFlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -244,6 +249,12 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: return self._cached_decode_metadata +class BlocksparseFlashAttentionMetadataBuilder( + CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): + + _metadata_cls = BlocksparseFlashAttentionMetadata + + class BlocksparseFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -327,7 +338,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. 
@@ -368,7 +380,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) if prefill_meta := attn_metadata.prefill_metadata: @@ -405,7 +418,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + k_scale, + v_scale, tp_rank=self.tp_rank, blocksparse_local_blocks=self.local_blocks, blocksparse_vert_stride=self.vert_stride, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 048abed48d2e9..b8a64205b362b 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,13 +1,24 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUBuilder) class FlashAttentionBackend(AttentionBackend): @@ -28,6 +39,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -184,6 +199,170 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: return self._cached_decode_metadata +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + token_lens: List[int], seq_lens: List[int], + curr_seq_lens: List[int], query_lens: List[int], + context_lens: List[int], + curr_sliding_window_blocks: List[int], + prefix_cache_hit: bool, chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = seq_group_metadata.is_prompt + block_tables = seq_group_metadata.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + seq_group_metadata.seq_data.keys(), token_lens, seq_lens, + curr_seq_lens, query_lens, context_lens, + curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, + seq_group_metadata.block_tables) + + def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors.""" + device = runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + logits_soft_cap = getattr(runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap" + " (i.e., Gemma-2). Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. 
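# Minimal sketch (not part of this diff) of the padding step performed just
# below: ragged per-sequence block tables are copied row by row into a
# preallocated [max_batch_size, max_blocks] array so the decode step has a
# fixed-shape input suitable for CUDA graph capture. Sizes here are made up.
import numpy as np

graph_block_tables = np.zeros((8, 4), dtype=np.int32)   # stand-in for runner state
block_tables = [[0, 5, 8], [1, 6], []]                  # ragged, one row per sequence

batch = graph_block_tables[:len(block_tables)]
for i, bt in enumerate(block_tables):
    if bt:
        batch[i, :len(bt)] = bt
# batch[0] -> [0, 5, 8, 0], batch[1] -> [1, 6, 0, 0], batch[2] stays all zeros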
+ input_block_tables = runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + return FlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -256,7 +435,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -277,7 +457,8 @@ def forward( "FlashAttentionImpl") # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. 
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b27e3e40f566d..daff76051a956 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -14,7 +14,18 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUBuilder) class FlashInferBackend(AttentionBackend): @@ -31,6 +42,10 @@ def get_impl_cls() -> Type["FlashInferImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return FlashInferMetadata + @staticmethod + def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -188,6 +203,225 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return self +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + token_lens: List[int], seq_lens: List[int], + curr_seq_lens: List[int], query_lens: List[int], + context_lens: List[int], + curr_sliding_window_blocks: List[int], + prefix_cache_hit: bool, chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = seq_group_metadata.is_prompt + block_tables = seq_group_metadata.block_tables + computed_block_nums = seq_group_metadata.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + seq_group_metadata.seq_data.keys(), token_lens, seq_lens, + curr_seq_lens, query_lens, context_lens, + curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, + seq_group_metadata.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. + if is_profile_run: + return + + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + block_table = block_tables[seq_id] + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens, + cuda_graph_pad_size: int, batch_size: int): + device = runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. 
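# Worked example (not part of this diff) of the FlashInfer paged-KV layout that
# add_seq_group builds up, using the page indices from the comment above:
# request 1 -> [0, 5, 8], request 2 -> [1, 6, 7], request 3 -> [3, 4].
def build_paged_kv_layout(page_tables, seq_lens, page_size):
    indices, indptr, last_page_len = [], [0], []
    for pages, seq_len in zip(page_tables, seq_lens):
        indices.extend(pages)
        indptr.append(indptr[-1] + len(pages))
        rem = seq_len % page_size
        last_page_len.append(rem if rem != 0 else page_size)
    return indices, indptr, last_page_len

indices, indptr, last = build_paged_kv_layout(
    page_tables=[[0, 5, 8], [1, 6, 7], [3, 4]],
    seq_lens=[40, 48, 20],
    page_size=16)
assert indices == [0, 5, 8, 1, 6, 7, 3, 4]
assert indptr == [0, 3, 6, 8]
assert last == [8, 16, 4]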
+ input_block_tables = runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + logits_soft_cap = getattr(runner.model_config.hf_config, + "attn_logit_softcapping", None) + + if len(self.paged_kv_indptr) > 0: + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + + kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype, + runner.model_config.dtype) + return FlashInferMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + num_qo_heads=runner.model_config.get_num_attention_heads( + runner.parallel_config), + num_kv_heads=runner.model_config.get_num_kv_heads( + runner.parallel_config), + head_dim=runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + use_cuda_graph=use_captured_graph, + logits_soft_cap=logits_soft_cap) + + class FlashInferImpl(AttentionImpl): def __init__( @@ -223,10 +457,12 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: FlashInferMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: - assert kv_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashInfer.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 
6a1295b1000bc..4559dd15f600c 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -156,7 +156,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: IpexAttnMetadata, # type: ignore - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. @@ -170,7 +171,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert kv_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -192,7 +193,8 @@ def forward( value_cache, attn_metadata.slot_mapping.flatten(), self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) if attn_metadata.is_prompt: @@ -273,7 +275,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) else: # Run PagedAttention V2. @@ -305,7 +308,8 @@ def forward( max_seq_len, self.alibi_slopes, self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index c45f7b28b2afb..b83a83bb177d4 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -131,7 +131,8 @@ def forward( value: torch.Tensor, kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], attn_metadata: PallasMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -146,7 +147,7 @@ def forward( Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - assert kv_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 81b546c65c819..17c3b25034bf3 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonMetadataBuilder from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -28,6 +29,10 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return ROCmFlashAttentionMetadata + @staticmethod + def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: + return ROCmFlashAttentionMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -166,6 +171,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: return self._cached_decode_metadata +class ROCmFlashAttentionMetadataBuilder( + CommonMetadataBuilder[ROCmFlashAttentionMetadata]): + + _metadata_cls = ROCmFlashAttentionMetadata + + def _make_alibi_bias(alibi_slopes: torch.Tensor, dtype: torch.dtype, seq_lens: Optional[List[int]], @@ -296,7 +307,8 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, 
attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -336,7 +348,8 @@ def forward( value_cache, attn_metadata.slot_mapping, self.kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) num_prefill_tokens = attn_metadata.num_prefill_tokens @@ -456,7 +469,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 48418f24870f9..fe6a56123ce72 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -144,7 +144,8 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -158,7 +159,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert kv_scale == 1.0 + assert k_scale == 1.0 and v_scale == 1.0 if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -176,7 +177,8 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + self.kv_cache_dtype, k_scale, + v_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -239,7 +241,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index a3cfc6e20748b..62d0eeb249bd4 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,239 @@ """Attention backend utils""" +from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union + +import torch + +from vllm.attention import AttentionMetadata, AttentionMetadataBuilder +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad # Error string(s) for encoder/decoder # unsupported attention scenarios - STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " "with encoder/decoder models.") + +PAD_SLOT_ID = -1 + +if TYPE_CHECKING: + from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUBuilder) + + +def is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False + + +def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, + context_len: int, sliding_window: int, + use_v2_block_manager: bool): + """ + Compute the start index of slot mapping. + """ + start_idx = 0 + if is_prompt and sliding_window is not None: + assert use_v2_block_manager or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # When prefill, we use it to not write slots to kv cache + # to save memory. 
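# Worked example (not part of this diff): with sliding-window prefill the first
# query_len - sliding_window prompt tokens never need their KV written, so the
# slot mapping can start later. A 10-token prompt with window 8 skips the first
# two tokens, matching the [-1, -1, ...] example in compute_slot_mapping below.
assert max(0, 10 - 8) == 2     # prompt len 10, sliding window 8 -> start_idx 2
assert max(0, 6 - 8) == 0      # prompt shorter than the window -> nothing skipped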
+ start_idx = max(0, query_len - sliding_window) + return start_idx + + +def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], + seq_id: int, seq_len: int, context_len: int, + start_idx: int, block_size: int, + block_tables: Dict[int, List[int]]): + """ + Compute slot mapping. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + block_table = block_tables[seq_id] + slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len)) + for i in range(max(start_idx, context_len), seq_len): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + +TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') + + +class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): + + _metadata_cls: Type[TAttentionMetadata] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata, + token_lens: List[int], seq_lens: List[int], + curr_seq_lens: List[int], query_lens: List[int], + context_lens: List[int], + curr_sliding_window_blocks: List[int], prefix_cache_hit, + chunked_prefill_enabled): + is_prompt = seq_group_metadata.is_prompt + block_tables = seq_group_metadata.block_tables + computed_block_nums = seq_group_metadata.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + seq_group_metadata.seq_data.keys(), token_lens, seq_lens, + curr_seq_lens, query_lens, context_lens, + curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. 
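The new `compute_slot_mapping_start_idx`/`compute_slot_mapping` helpers centralize the sliding-window masking described in the comments above. A small worked example under an assumed block table (block numbers chosen for illustration only):

```python
# Standalone replica of the slot-mapping arithmetic from
# vllm/attention/backends/utils.py, with PAD_SLOT_ID = -1.
PAD_SLOT_ID = -1

def slot_mapping_for(seq_len, context_len, start_idx, block_size, block_table):
    mapping = [PAD_SLOT_ID] * max(0, start_idx - context_len)
    for i in range(max(start_idx, context_len), seq_len):
        block_number = block_table[i // block_size]
        mapping.append(block_number * block_size + i % block_size)
    return mapping

# Prompt of 10 tokens, sliding window 8, block size 4, no prefix cache:
# start_idx = max(0, 10 - 8) = 2, so the first two tokens get PAD_SLOT_ID.
print(slot_mapping_for(seq_len=10, context_len=0, start_idx=2,
                       block_size=4, block_table=[0, 1, 2]))
# -> [-1, -1, 2, 3, 4, 5, 6, 7, 8, 9]
```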
+ is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, + seq_group_metadata.block_tables) + + def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int], + query_lens: List[int], cuda_graph_pad_size: int, + batch_size: int): + device = runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + logits_soft_cap = getattr(runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap " + "(i.e., Gemma-2). Otherwise, the output might be wrong. " + "Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + cuda_graph_pad_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + max_block_table_len = max( + len(block_table) for block_table in self.block_tables) + block_tables = make_tensor_with_pad( + self.block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + slot_mapping_tensor = torch.tensor(self.slot_mapping, + dtype=torch.long, + device=device) + + return self._metadata_cls( # type: ignore + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6cc5f1d1477ae..1573cd7da94cd 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, 
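`CommonMetadataBuilder.build()` turns per-sequence lengths into the prefix-sum `query_start_loc`/`seq_start_loc` tensors consumed by the attention kernels. A short sketch of that step with made-up lengths:

```python
# Sketch of the cumsum-into-a-slice pattern used by build(); the lengths here
# are arbitrary examples, not taken from a real batch.
import torch

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int)
query_lens = torch.tensor([5, 3, 7], dtype=torch.long)

seq_start_loc = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
query_start_loc = torch.zeros(query_lens.numel() + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=seq_start_loc.dtype,
             out=seq_start_loc[1:])
torch.cumsum(query_lens, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])

print(seq_start_loc)   # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```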
AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonMetadataBuilder from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -32,6 +33,10 @@ def get_impl_cls() -> Type["XFormersImpl"]: def get_metadata_cls() -> Type["AttentionMetadata"]: return XFormersMetadata + @staticmethod + def get_builder_cls() -> Type["XFormersMetadataBuilder"]: + return XFormersMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -362,6 +367,11 @@ def _get_seq_len_block_table_args( raise AttributeError(f"Invalid attention type {str(attn_type)}") +class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): + + _metadata_cls = XFormersMetadata + + class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -427,7 +437,8 @@ def forward( value: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", - kv_scale: float = 1.0, + k_scale: float = 1.0, + v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -531,7 +542,7 @@ def forward( value_cache, updated_slot_mapping, self.kv_cache_dtype, - kv_scale) + k_scale, v_scale) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. @@ -620,7 +631,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - kv_scale, + k_scale, + v_scale, ) # Reshape the output tensor. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b8cc87be8c748..0619bda90a2a7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -47,13 +47,14 @@ def __init__( if num_kv_heads is None: num_kv_heads = num_heads - # The default kv_scale is set to 1.0. This is ignored + # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we - # expect the pre-quantized kv_scale to be loaded along + # expect the pre-quantized k/v_scale to be loaded along # with the model weights. self.kv_cache_dtype = kv_cache_dtype - self._kv_scale = 1.0 + self._k_scale = 1.0 + self._v_scale = 1.0 quant_method = quant_config.get_quant_method( self) if quant_config else None if quant_method is not None: @@ -66,8 +67,8 @@ def __init__( "fp8 checkpoints.") # When FP8 quantization is enabled, we make a parameter # "kv_scale" so that it can be loaded from FP8 checkpoint. - # The kv_scale will then be converted back to self._kv_scale - # in a native float32 value after weight loading. 
+ # The k/v_scale will then be converted back to + # self._kv_scale in a native float32 value after weight loading self.quant_method = quant_method self.quant_method.create_weights(self) @@ -98,7 +99,8 @@ def forward( value, kv_cache, attn_metadata, - self._kv_scale, + self._k_scale, + self._v_scale, attn_type=attn_type) def extra_repr(self) -> str: diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 5a5317b65004e..81d308c4d4e22 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -45,7 +45,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( @@ -64,7 +65,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - kv_scale: float, + k_scale: float, + v_scale: float, *args, ) -> torch.Tensor: output = torch.empty_like(query) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index a214f40d16514..ce7b4d129779c 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -66,7 +66,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - kv_scale: float, + k_scale: float, + v_scale: float, ) -> None: ops.reshape_and_cache( key, @@ -75,7 +76,8 @@ def write_to_paged_cache( value_cache, slot_mapping.flatten(), kv_cache_dtype, - kv_scale, + k_scale, + v_scale, ) @staticmethod @@ -90,7 +92,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - kv_scale: float, + k_scale: float, + v_scale: float, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -135,7 +138,8 @@ def forward_decode( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, @@ -172,7 +176,8 @@ def forward_decode( max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale, + k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 084100f6c1135..8fcd85585a18f 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu logger = init_logger(__name__) @@ -136,7 +137,7 @@ def which_attn_to_use( selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: - if torch.cuda.get_device_capability()[0] != 9: + if current_platform.get_device_capability()[0] != 9: # not Instinct series GPUs. logger.info("flash_attn is not supported on NAVI GPUs.") else: @@ -145,7 +146,7 @@ def which_attn_to_use( # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: - if torch.cuda.get_device_capability()[0] < 8: + if current_platform.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. 
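The selector hunk below swaps direct `torch.cuda.get_device_capability()` calls for the platform abstraction imported above. A sketch of the resulting gate (the two-element capability tuple is indexed exactly as in the diff):

```python
# Sketch, assuming only that get_device_capability() is indexable as in the
# diff; major compute capability < 8 (Volta/Turing) disables FlashAttention-2.
from vllm.platforms import current_platform

major = current_platform.get_device_capability()[0]
use_flash_attn = major >= 8   # otherwise the selector falls back to xFormers
```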
logger.info( "Cannot use FlashAttention-2 backend for Volta and Turing " diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index db0064951cd1b..151b08c1b996c 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -170,7 +170,7 @@ def __init__( self.n_remote_reader = n_remote_reader if connect_ip is None: - connect_ip = get_ip() + connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1" context = Context() @@ -230,6 +230,8 @@ def __init__( remote_sync_port=remote_sync_port, ) + logger.info("vLLM message queue communication handle: %s", self.handle) + def export_handle(self) -> Handle: return self.handle diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 32514078fd68e..8bced12a14347 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -12,6 +12,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine +from vllm.engine.metrics import StatLoggerBase from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger @@ -389,6 +390,7 @@ def from_engine_args( engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> "AsyncLLMEngine": """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. @@ -451,6 +453,7 @@ def from_engine_args( max_log_len=engine_args.max_log_len, start_engine_loop=start_engine_loop, usage_context=usage_context, + stat_loggers=stat_loggers, ) return engine @@ -962,3 +965,19 @@ async def is_tracing_enabled(self) -> bool: ) else: return self.engine.is_tracing_enabled() + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if self.engine_use_ray: + ray.get( + self.engine.add_logger.remote( # type: ignore + logger_name=logger_name, logger=logger)) + else: + self.engine.add_logger(logger_name=logger_name, logger=logger) + + def remove_logger(self, logger_name: str) -> None: + if self.engine_use_ray: + ray.get( + self.engine.remove_logger.remote( # type: ignore + logger_name=logger_name)) + else: + self.engine.remove_logger(logger_name=logger_name) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fbe8e2ecf38a0..dc0a49c620d2e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -379,6 +379,7 @@ def from_engine_args( cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. 
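The `stat_loggers` parameter threaded through `from_engine_args`, together with the new `add_logger`/`remove_logger` methods, lets callers attach custom `StatLoggerBase` implementations. A hypothetical usage sketch; the `StatLoggerBase` constructor argument and abstract methods shown here are assumptions about the existing base class, not something this diff guarantees:

```python
# Hypothetical custom logger wired in through the new plumbing.  Assumes
# StatLoggerBase takes a local_interval and requires log()/info() overrides.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import StatLoggerBase

class PrintingStatLogger(StatLoggerBase):
    def log(self, stats) -> None:       # called once per logging interval
        print("engine stats:", stats)

    def info(self, type, obj) -> None:  # config info; no-op in this sketch
        pass

engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m"),
    stat_loggers={"printer": PrintingStatLogger(local_interval=5)},
)
engine.remove_logger("printer")   # or add_logger(...) to attach another one
```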
@@ -423,6 +424,7 @@ def from_engine_args( executor_class=executor_class, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, + stat_loggers=stat_loggers, ) return engine diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 48aec84298d86..4ed7da2377111 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -30,55 +30,55 @@ # begin-metrics-definitions class Metrics: labelname_finish_reason = "finished_reason" - _base_library = prometheus_client + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + _histogram_cls = prometheus_client.Histogram def __init__(self, labelnames: List[str], max_model_len: int): # Unregister any existing vLLM collectors self._unregister_vllm_metrics() # Config Information - self.info_cache_config = prometheus_client.Info( - name='vllm:cache_config', - documentation='information of cache_config') + self._create_info_cache_config() # System stats # Scheduler State - self.gauge_scheduler_running = self._base_library.Gauge( + self.gauge_scheduler_running = self._gauge_cls( name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", labelnames=labelnames) - self.gauge_scheduler_waiting = self._base_library.Gauge( + self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames) - self.gauge_scheduler_swapped = self._base_library.Gauge( + self.gauge_scheduler_swapped = self._gauge_cls( name="vllm:num_requests_swapped", documentation="Number of requests swapped to CPU.", labelnames=labelnames) # KV Cache Usage in % - self.gauge_gpu_cache_usage = self._base_library.Gauge( + self.gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames) - self.gauge_cpu_cache_usage = self._base_library.Gauge( + self.gauge_cpu_cache_usage = self._gauge_cls( name="vllm:cpu_cache_usage_perc", documentation="CPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames) # Iteration stats - self.counter_num_preemption = self._base_library.Counter( + self.counter_num_preemption = self._counter_cls( name="vllm:num_preemptions_total", documentation="Cumulative number of preemption from the engine.", labelnames=labelnames) - self.counter_prompt_tokens = self._base_library.Counter( + self.counter_prompt_tokens = self._counter_cls( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", labelnames=labelnames) - self.counter_generation_tokens = self._base_library.Counter( + self.counter_generation_tokens = self._counter_cls( name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", labelnames=labelnames) - self.histogram_time_to_first_token = self._base_library.Histogram( + self.histogram_time_to_first_token = self._histogram_cls( name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", labelnames=labelnames, @@ -86,7 +86,7 @@ def __init__(self, labelnames: List[str], max_model_len: int): 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0 ]) - self.histogram_time_per_output_token = self._base_library.Histogram( + self.histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", documentation="Histogram of time per output token in seconds.", labelnames=labelnames, @@ -97,59 +97,157 @@ def __init__(self, labelnames: List[str], max_model_len: int): # Request stats # Latency - self.histogram_e2e_time_request = self._base_library.Histogram( + self.histogram_e2e_time_request = self._histogram_cls( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) # Metadata - self.histogram_num_prompt_tokens_request = self._base_library.Histogram( + self.histogram_num_prompt_tokens_request = self._histogram_cls( name="vllm:request_prompt_tokens", documentation="Number of prefill tokens processed.", labelnames=labelnames, buckets=build_1_2_5_buckets(max_model_len), ) self.histogram_num_generation_tokens_request = \ - self._base_library.Histogram( + self._histogram_cls( name="vllm:request_generation_tokens", documentation="Number of generation tokens processed.", labelnames=labelnames, buckets=build_1_2_5_buckets(max_model_len), ) - self.histogram_best_of_request = self._base_library.Histogram( + self.histogram_best_of_request = self._histogram_cls( name="vllm:request_params_best_of", documentation="Histogram of the best_of request parameter.", labelnames=labelnames, buckets=[1, 2, 5, 10, 20], ) - self.histogram_n_request = self._base_library.Histogram( + self.histogram_n_request = self._histogram_cls( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", labelnames=labelnames, buckets=[1, 2, 5, 10, 20], ) - self.counter_request_success = self._base_library.Counter( + self.counter_request_success = self._counter_cls( name="vllm:request_success_total", documentation="Count of successfully processed requests.", labelnames=labelnames + [Metrics.labelname_finish_reason]) + # Speculatie decoding stats + self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls( + name="vllm:spec_decode_draft_acceptance_rate", + documentation="Speulative token acceptance rate.", + labelnames=labelnames) + self.gauge_spec_decode_efficiency = self._gauge_cls( + 
name="vllm:spec_decode_efficiency", + documentation="Speculative decoding system efficiency.", + labelnames=labelnames) + self.counter_spec_decode_num_accepted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames)) + self.counter_spec_decode_num_draft_tokens = self._counter_cls( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_emitted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_emitted_tokens_total", + documentation="Number of emitted tokens.", + labelnames=labelnames)) + # Deprecated in favor of vllm:prompt_tokens_total - self.gauge_avg_prompt_throughput = self._base_library.Gauge( + self.gauge_avg_prompt_throughput = self._gauge_cls( name="vllm:avg_prompt_throughput_toks_per_s", documentation="Average prefill throughput in tokens/s.", labelnames=labelnames, ) # Deprecated in favor of vllm:generation_tokens_total - self.gauge_avg_generation_throughput = self._base_library.Gauge( + self.gauge_avg_generation_throughput = self._gauge_cls( name="vllm:avg_generation_throughput_toks_per_s", documentation="Average generation throughput in tokens/s.", labelnames=labelnames, ) + def _create_info_cache_config(self) -> None: + # Config Information + self.info_cache_config = prometheus_client.Info( + name='vllm:cache_config', + documentation='information of cache_config') + def _unregister_vllm_metrics(self) -> None: - for collector in list(self._base_library.REGISTRY._collector_to_names): + for collector in list(prometheus_client.REGISTRY._collector_to_names): if hasattr(collector, "_name") and "vllm" in collector._name: - self._base_library.REGISTRY.unregister(collector) + prometheus_client.REGISTRY.unregister(collector) + + +# end-metrics-definitions + + +class _RayGaugeWrapper: + """Wraps around ray.util.metrics.Gauge to provide same API as + prometheus_client.Gauge""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._gauge = ray_metrics.Gauge(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._gauge.set_default_tags(labels) + return self + + def set(self, value: Union[int, float]): + return self._gauge.set(value) + + +class _RayCounterWrapper: + """Wraps around ray.util.metrics.Counter to provide same API as + prometheus_client.Counter""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._counter = ray_metrics.Counter(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._counter.set_default_tags(labels) + return self + + def inc(self, value: Union[int, float] = 1.0): + if value == 0: + return + return self._counter.inc(value) + + +class _RayHistogramWrapper: + """Wraps around ray.util.metrics.Histogram to provide same API as + prometheus_client.Histogram""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + buckets: Optional[List[float]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._histogram = ray_metrics.Histogram(name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=buckets) + + def labels(self, **labels): + 
self._histogram.set_default_tags(labels) + return self + + def observe(self, value: Union[int, float]): + return self._histogram.observe(value) class RayMetrics(Metrics): @@ -157,7 +255,9 @@ class RayMetrics(Metrics): RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. Provides the same metrics as Metrics but uses Ray's util.metrics library. """ - _base_library = ray_metrics + _gauge_cls = _RayGaugeWrapper + _counter_cls = _RayCounterWrapper + _histogram_cls = _RayHistogramWrapper def __init__(self, labelnames: List[str], max_model_len: int): if ray_metrics is None: @@ -168,8 +268,9 @@ def _unregister_vllm_metrics(self) -> None: # No-op on purpose pass - -# end-metrics-definitions + def _create_info_cache_config(self) -> None: + # No-op on purpose + pass def build_1_2_5_buckets(max_value: int) -> List[int]: @@ -454,7 +555,26 @@ def log(self, stats: Stats): self.num_generation_tokens = [] self.last_local_log = stats.now + if stats.spec_decode_metrics is not None: + self._log_gauge( + self.metrics.gauge_spec_decode_draft_acceptance_rate, + stats.spec_decode_metrics.draft_acceptance_rate) + self._log_gauge(self.metrics.gauge_spec_decode_efficiency, + stats.spec_decode_metrics.system_efficiency) + self._log_counter( + self.metrics.counter_spec_decode_num_accepted_tokens, + stats.spec_decode_metrics.accepted_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_draft_tokens, + stats.spec_decode_metrics.draft_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_emitted_tokens, + stats.spec_decode_metrics.emitted_tokens) + class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" _metrics_cls = RayMetrics + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + return None diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index a0e248b2e1992..01ed9d1219e7f 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -1,5 +1,7 @@ import asyncio import os +import signal +import weakref from functools import partial from typing import Any, List, Optional @@ -78,6 +80,19 @@ def _init_executor(self) -> None: result_handler.start() self.worker_monitor.start() + # Set up signal handlers to shutdown the executor cleanly + # sometimes gc does not work well + + # Use weakref to avoid holding a reference to self + ref = weakref.ref(self) + + def shutdown(signum, frame): + if executor := ref(): + executor.shutdown() + + signal.signal(signal.SIGINT, shutdown) + signal.signal(signal.SIGTERM, shutdown) + self.driver_worker = self._create_worker( distributed_init_method=distributed_init_method) self._run_workers("init_device") diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 388f934ef75a6..edff9b6c93e09 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -224,13 +224,27 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # broadcasted to. self.non_driver_workers: List[RayWorkerWrapper] = [] + tp_driver_worker_ranks = [] + non_driver_worker_ranks = [] for idx, rank in enumerate(worker_ranks[1:]): # We need to skip the driver worker, which we # do by skipping worker_ranks[0] which is always 0. 
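The Ray-backed wrappers above exist so that `RayMetrics` can keep the `metric.labels(**tags).set(value)` call pattern of `prometheus_client` while delegating to `ray.util.metrics`. A self-contained illustration of that adapter pattern with a stand-in backend (no Ray dependency):

```python
# Minimal illustration of the wrapper idea: labels() stores default tags and
# returns self so the prometheus-style chained call keeps working.
class _GaugeBackend:                      # stands in for ray_metrics.Gauge
    def set_default_tags(self, tags): self._tags = tags
    def set(self, value): print("set", value, getattr(self, "_tags", {}))

class GaugeWrapper:
    def __init__(self, backend): self._gauge = backend
    def labels(self, **labels):
        self._gauge.set_default_tags(labels)
        return self                        # chaining preserves the old API
    def set(self, value): return self._gauge.set(value)

GaugeWrapper(_GaugeBackend()).labels(model="llama").set(0.5)
```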
if rank % self.parallel_config.tensor_parallel_size == 0: self.tp_driver_workers.append(self.workers[idx]) + tp_driver_worker_ranks.append(rank) else: self.non_driver_workers.append(self.workers[idx]) + non_driver_worker_ranks.append(rank) + + # Enforce rank order for correct rank to return final output. + self.tp_driver_workers = [ + worker for _, worker in sorted( + zip(tp_driver_worker_ranks, self.tp_driver_workers)) + ] + self.non_driver_workers = [ + worker for _, worker in sorted( + zip(non_driver_worker_ranks, self.non_driver_workers)) + ] def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index bc07d2b831862..684e1abf7bcf7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -196,6 +196,15 @@ def __init__(self, else: self.register_parameter("bias", None) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5c916c9b4d7e4..cfef914ed6cf7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -407,31 +407,56 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module): - """Create "weight" (aka kv_scale) for an attention layer. + """Create "weight" (aka k_scale and v_scale) for an attention layer. Args: layer: The layer that is using the QuantizeMethodBase factory. """ - # Initialize the KV cache scale to 1.0 as the default value. - # If the kv_scale appears in the checkpoint, it will be + # Initialize the KV cache scales to -1.0, which is an invalid value. + # If the k/v_scale appears in the checkpoint, it will be # overwritten when loading weights. - layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False) + layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") def process_weights_after_loading(self, layer: Module) -> None: - # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. 
if layer.kv_cache_dtype != "auto": - kv_scale = layer.kv_scale.to("cpu").tolist() - if not isinstance(kv_scale, float): + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = Parameter(torch.tensor(1.0), requires_grad=False) + v_scale = Parameter(torch.tensor(1.0), requires_grad=False) + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): raise ValueError("Only support per-tensor scaling factor " "for fp8 KV cache") - layer._kv_scale = kv_scale - if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + + # These are used in the final Attention.forward() + layer._k_scale = k_scale + layer._v_scale = v_scale + if (layer._k_scale == 1.0 and layer._v_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " - "cause accuracy issues. Please make sure kv-cache scaling " - "factor is available in the fp8 checkpoint.") - del layer.kv_scale + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + del layer.k_scale + del layer.v_scale diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6d00ea64f7cb8..5c376797a054f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -47,6 +47,32 @@ def __init__(self): # speculative decoding. self.include_gpu_probs_tensor = False + def _init_sampling_tensors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + """The goal here is to reuse sampling tensors between similar decode + runs. This is possible because sampling logic does not change between + decodes of the same sequences. + """ + _, vocab_size = logits.shape + + # First free any existing stored sampling tensors. + # This is necessary because some sampling tensors may + # have pinned memory. + self._sampling_tensors = None + + # Initialize new sampling tensors + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + self._sampling_tensors = sampling_tensors + self._do_penalties = do_penalties + self._do_top_p_top_k = do_top_p_top_k + self._do_min_p = do_min_p + def forward( self, logits: torch.Tensor, @@ -60,12 +86,23 @@ def forward( assert logits is not None _, vocab_size = logits.shape - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - # Prepare sampling tensors with pinned memory to avoid blocking. 
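The branching added to `Fp8KVCacheMethod.process_weights_after_loading` reduces to a small decision table over the two sentinel-initialized scales. A standalone restatement with plain floats instead of `Parameter`s:

```python
# Distilled k/v-scale resolution: scales initialized to -1.0 mean "not found
# in the checkpoint"; a lone legacy kv_scale was already remapped to k_scale.
def resolve_scales(k_scale: float, v_scale: float) -> tuple:
    if k_scale > 0.0 and v_scale > 0.0:
        return k_scale, v_scale            # separate scales were loaded
    if k_scale < 0.0 and v_scale < 0.0:
        return 1.0, 1.0                    # nothing loaded: default to 1.0
    scale = max(k_scale, v_scale)          # duplicate the single loaded scale
    return scale, scale

print(resolve_scales(-1.0, -1.0))   # (1.0, 1.0)
print(resolve_scales(0.02, -1.0))   # (0.02, 0.02)
```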
- (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) + if not sampling_metadata.reuse_sampling_tensors: + self._init_sampling_tensors(logits, sampling_metadata) + elif self._do_penalties: + # In this case, the sampling tensors logic depends on + # "output_tokens" of a sequence. As a result, we cannot + # reuse sampling tensors, since "output_tokens" changes + # between decode runs. + self._init_sampling_tensors(logits, sampling_metadata) + + assert self._sampling_tensors is not None + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + do_min_p = self._do_min_p + + logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Apply presence and frequency penalties. if do_penalties: @@ -77,7 +114,7 @@ def forward( # Apply temperature scaling. # Use in-place division to avoid creating a new tensor. - logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) if do_top_p_top_k: logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, @@ -109,13 +146,19 @@ def forward( on_device_tensors = None # Get the logprobs query results. - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors) + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, sampling_metadata, sample_results) + + return _build_sampler_output( + sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) @property def _should_modify_greedy_probs_inplace(self) -> bool: @@ -535,24 +578,29 @@ def _sample_with_torch( # GPU<->CPU sync happens in the loop below. # This also converts the sample output to Python objects. 
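The temperature scaling also changes from in-place `unsqueeze_` to out-of-place `unsqueeze`, presumably because the sampling tensors may now be cached and reused across decode steps, so mutating their shape in place would corrupt the cache. A tiny demonstration of why the out-of-place form is safe:

```python
# The cached temperatures tensor keeps its 1-D shape across steps when the
# broadcast view is created out of place.
import torch

temperatures = torch.tensor([0.7, 1.0, 0.9])   # hypothetical cached tensor
logits = torch.randn(3, 8)

logits.div_(temperatures.unsqueeze(dim=1))     # view only; cache untouched
print(temperatures.shape)                       # torch.Size([3])
```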
- for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) + if not sampling_metadata.skip_sampler_cpu_output: + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, + SamplingType.RANDOM_SEED): + sample_results = _random_sample( + seq_groups, multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + sample_results = [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + else: + sample_results = [] - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] return sample_results, sampled_token_ids_tensor @@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( sample_results: SampleResultType, sampling_metadata: SamplingMetadata, - prompt_logprobs: List[Optional[PromptLogprobs]], - sample_logprobs: List[SampleLogprobs], + prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], + sample_logprobs: Optional[List[SampleLogprobs]], on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + skip_sampler_cpu_output: bool = False, ) -> SamplerOutput: """Construct Python objects with the output of sampling. @@ -1010,22 +1059,26 @@ def _build_sampler_output( allows post-processing without copies to CPU/serialization, e.g. in speculative decoding rejection sampling. 
""" - sampler_output: List[CompletionSequenceGroupOutput] = [] - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: List[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip(parent_ids, - next_token_ids, - group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + if not skip_sampler_cpu_output: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, + sample_results, prompt_logprobs, + sample_logprobs): + seq_ids = seq_group.seq_ids + next_token_ids, parent_ids = sample_result + seq_outputs: List[SequenceOutput] = [] + for parent_id, next_token_id, logprobs in zip( + parent_ids, next_token_ids, group_sample_logprobs): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + logprobs)) + sampler_output.append( + CompletionSequenceGroupOutput(seq_outputs, + group_prompt_logprobs)) # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c8568b3dc6690..698c59d49fe06 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) from vllm.model_executor.layers.quantization.schema import QuantParamSchema +from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -431,11 +432,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" - # If the weight on disk does not have a shape, give it one - # (such scales for AutoFp8). - if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.reshape(1) - assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) @@ -444,6 +440,7 @@ def initialize_dummy_weights( model: torch.nn.Module, low: float = -1e-3, high: float = 1e-3, + seed: int = 1234, ) -> None: """Initialize model weights with random values. @@ -451,14 +448,74 @@ def initialize_dummy_weights( measurements. Additionally, the model weights should not cause NaNs in the forward pass. We empirically found that initializing the weights with values between -1e-3 and 1e-3 works well for most models. + + We use per-parameter random seed, so that dummy weights are consistent, + even if the model is partitioned across multiple devices. When the seed + is fixed, the random values generated by this function only depends on + the parameter's number of elements and its data type. 
""" for param in model.state_dict().values(): if torch.is_floating_point(param): + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) if torch.finfo(param.data.dtype).bits < 16: # uniform_ doesn't support < 16-bit datatypes (FP8) dtype = param.data.dtype tmp_param = param.data.to(torch.float16) - tmp_param = tmp_param.uniform_(low, high).to(dtype) + tmp_param = tmp_param.uniform_(low, high, + generator=generator).to(dtype) param.data.copy_(tmp_param) else: - param.uniform_(low, high) + param.uniform_(low, high, generator=generator) + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. + """ + if name.endswith(".kv_scale"): + print_warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale") + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + print_warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). kv_scale is " + "not loaded.") + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + print_warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). 
{scale_name} is " + "not loaded.") + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a777d1fbfa802..4c434e54cf743 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -44,13 +44,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader) + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import is_hip, print_warning_once +from vllm.utils import is_hip from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers class LlamaMLP(nn.Module): @@ -257,17 +257,24 @@ def __init__( (lora_config.max_loras or 1)) if lora_config else 0 self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda: LlamaDecoderLayer(config=config, cache_config=cache_config, quant_config=quant_config)) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -360,26 +367,30 @@ def __init__( cache_config, quant_config, lora_config=lora_config) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - quant_config=quant_config, - ) - if config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) - self.sampler = Sampler() + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + 
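`maybe_remap_kv_scale_name` handles both the deprecated single `kv_scale` name and the new split names. An illustrative call with a toy `params_dict` whose keys merely stand in for the model's named parameters:

```python
# Illustrative behaviour of the new helper; the parameter names are made up.
from vllm.model_executor.model_loader.weight_utils import (
    maybe_remap_kv_scale_name)

params_dict = {"model.layers.0.self_attn.attn.k_scale": None,
               "model.layers.0.self_attn.attn.v_scale": None}

# Deprecated single scale: remapped onto k_scale (duplicated to v_scale later).
print(maybe_remap_kv_scale_name(
    "model.layers.0.self_attn.kv_scale", params_dict))
# -> model.layers.0.self_attn.attn.k_scale  (plus a deprecation warning)

# Already-split scale: only the ".attn" segment is inserted.
print(maybe_remap_kv_scale_name(
    "model.layers.0.self_attn.v_scale", params_dict))
# -> model.layers.0.self_attn.attn.v_scale
```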
self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() def forward( self, @@ -460,18 +471,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - f"Found kv scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded.") - continue - else: - name = remapped_kv_scale_name + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue if is_pp_missing_parameter(name, self): continue diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0c456ada61230..e739df87cf96a 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -42,10 +42,10 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -415,19 +415,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - "Found kv scale in the checkpoint " - f"(e.g. {name}), but not found the expected " - f"name in the model " - f"(e.g. {remapped_kv_scale_name}). 
" - "kv-scale is not loaded.") - continue - else: - name = remapped_kv_scale_name + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e9ae2192f280d..e9aa4416eded4 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -43,10 +43,10 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -382,18 +382,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - f"Found kv scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded.") - continue - else: - name = remapped_kv_scale_name + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index c346cd0562867..29b077cf6d912 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -87,6 +87,12 @@ def sample(logits): The first tuple is [1, 2] (sampled index within original logit), and the second tuple is [0, 1] (sampled index within pruned logit). num_prompts: Number of prompt sequence groups in seq_groups. + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + serialization of token outputs. + reuse_sampling_tensors: Indicates if we want to reuse sampling + tensors that are part of the sampler forward pass. Currently, + it is mainly used for multi-step decode. 
+ """ def __init__( @@ -95,11 +101,15 @@ def __init__( selected_token_indices: torch.Tensor, categorized_sample_indices: Dict[SamplingType, torch.Tensor], num_prompts: int, + skip_sampler_cpu_output: bool = False, + reuse_sampling_tensors: bool = False, ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices self.num_prompts = num_prompts + self.skip_sampler_cpu_output = skip_sampler_cpu_output + self.reuse_sampling_tensors = reuse_sampling_tensors @staticmethod def prepare( diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 90bba96ee8acb..3cb7ec58da4c1 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,17 +2,22 @@ import torch +from vllm import _custom_ops as ops +from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput) from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) logger = init_logger(__name__) +debug_advance_input = False +enable_gpu_advance_step = True + class TP1DraftModelRunner(ModelRunner): """Specialized model runner for speculative decoding draft model. @@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner): we could get rid of most CPU-GPU synchronization and data transfer overheads by keeping model input and output tensors on GPU all the time. - This runner is still under development so there's no performance gain - at this moment. Currently we adopt a temporary solution that caches the - seq_group_metadata_list for multi-step execution, so that we can - leverage existing prepare_model_input to be compatible with the current - execution flow, but we plan to remove this cache and avoid calling - prepare_model_input in execute_model at all. - - The detail development plan includes: - 1. Use "update_model_input" to update existing model_input without - creating a new one. - 2. Improve the performance of "update_model_input" with a GPU kernel. - 3. Support TP > 1 (this requires some designs because we do not expect + TODOs: + 1. Currently supports only flash-attn, add support for other attn_backends. + 2. Support TP > 1 (this requires some designs because we do not expect any broadcasting inside execute_model). """ @@ -71,51 +67,156 @@ def __init__( return_hidden_states=return_hidden_states, ) - # TODO: Remove this cache when we are able to update model_input - # directly in advance_step. - self.cached_seq_group_metadata_list: Optional[ - List[SequenceGroupMetadata]] = None + def _update_flash_attn_metadata(self, attn_metadata, num_seqs, + num_queries): + assert isinstance(attn_metadata, FlashAttentionMetadata) - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithSamplingMetadata: - """A temporary solution that caches the seq_group_metadata_list - for multi-step execution. - TODO: In-place update model_input and remove this function. 
- """ - self.cached_seq_group_metadata_list = seq_group_metadata_list - return super().prepare_model_input( - seq_group_metadata_list, - finished_requests_ids=finished_requests_ids) + if num_seqs != num_queries: + assert num_seqs > num_queries + assert attn_metadata.use_cuda_graph + + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + assert attn_metadata.num_decode_tokens == num_seqs + assert attn_metadata.slot_mapping.shape == (num_seqs, ) + + assert len(attn_metadata.seq_lens) == num_seqs + assert attn_metadata.seq_lens_tensor.shape == (num_seqs, ) + assert attn_metadata.max_query_len == 1 + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens) + + assert attn_metadata.query_start_loc.shape == (num_queries + 1, ) + assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, ) + + assert attn_metadata.context_lens_tensor.shape == (num_queries, ) + + assert attn_metadata.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + attn_metadata.seq_lens[i] += 1 + attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens) - def update_model_input( + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _gpu_advance_step( self, model_input: ModelInputForGPUWithSamplingMetadata, last_output: SamplerOutput ) -> ModelInputForGPUWithSamplingMetadata: - """Prepare the model inputs for the next step. - TODO: In-place update model_input instead of calling - prepare_model_input. 
+ # Currently, we expect "decode mode" only + assert not model_input.is_prompt + + # Get num_seqs + num_seqs = len(model_input.seq_lens) + num_queries = len(model_input.query_lens) + + # Get output tokens GPU tensor + sampled_token_ids = last_output.sampled_token_ids + assert sampled_token_ids is not None + + # Update attn_metadata + attn_metadata = model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + # Update sampling_metadata + sampling_metadata = model_input.sampling_metadata + self._update_sampling_metadata(sampling_metadata, num_seqs, + num_queries) + + # Create new input + new_model_input = self._model_input_cls( + input_tokens=model_input.input_tokens, + input_positions=model_input.input_positions, + attn_metadata=attn_metadata, + seq_lens=attn_metadata.seq_lens, + query_lens=model_input.query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + sampling_metadata=model_input.sampling_metadata, + is_prompt=False, + ) + + # Ensure we skip CPU samples + assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True + # We can reuse sampling tensors since every decode iteration is the same + new_model_input.sampling_metadata.reuse_sampling_tensors = True + + if debug_advance_input: + logger.debug("NEW INPUT: ") + logger.debug(" input_tokens = %s", new_model_input.input_tokens) + logger.debug(" input_positions = %s", + new_model_input.input_positions) + logger.debug(" seq_lens = %d", new_model_input.seq_lens) + logger.debug(" query_lens = %d", new_model_input.query_lens) + logger.debug(" attn_metadata:") + logger.debug(" seq_lens_tensor: %s", + attn_metadata.seq_lens_tensor) + logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping) + logger.debug(" block_tables: %s", attn_metadata.block_tables) + + return new_model_input + + def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): + """Determines if draft_model_runner GPU multi-step can be used. + Currently required conditions are: + 1. Only decodes + 2. Only flash-attn + 3. No LORA + 4. No prompt_adapter_config """ + if not enable_gpu_advance_step: + return False - # Append the output token to the sequence data. 
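Note on _gpu_advance_step above: the ops.advance_step(...) call performs, in a single GPU pass, the per-query bookkeeping that the removed CPU loop below used to do through seq_data. The following standalone Python sketch only illustrates that per-query update (next input token, position, sequence length, and KV-cache slot); it is not the custom op's implementation, and the exact kernel semantics and tensor shapes are assumptions.

# Illustrative CPU reference for the per-query update that ops.advance_step
# is expected to perform on GPU. Argument names mirror the keyword arguments
# passed above; sampled_token_ids is assumed to have shape [num_queries, 1].
import torch

def advance_step_reference(num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor,
                           slot_mapping: torch.Tensor,
                           block_tables: torch.Tensor) -> None:
    for i in range(num_queries):
        # The token sampled at the previous step becomes the next input token.
        input_tokens[i] = sampled_token_ids[i, 0]
        # Its position is the current sequence length; the sequence grows by 1.
        next_pos = int(seq_lens[i])
        input_positions[i] = next_pos
        seq_lens[i] = next_pos + 1
        # Translate the position into a KV-cache slot via the block table.
        block_number = int(block_tables[i, next_pos // block_size])
        block_offset = next_pos % block_size
        slot_mapping[i] = block_number * block_size + block_offset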
- assert self.cached_seq_group_metadata_list is not None - for seq_group_metadata, sequence_group_outputs in zip( - self.cached_seq_group_metadata_list, last_output.outputs): - seq_group_metadata.is_prompt = False + # We allow multi-step GPU only in decode mode + for seq_group in execute_model_req.seq_group_metadata_list: + if seq_group.is_prompt: + return False - for seq_output in sequence_group_outputs.samples: - seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + # TODO: Add support for other attn backends + if self.attn_backend.get_name() != "flash-attn": + return False - token_id = seq_output.output_token - token_logprob = seq_output.logprobs[token_id] + # TODO: Add support for LORA + if self.lora_config: + return False - seq.append_token_id(token_id, token_logprob.logprob) - seq.update_num_computed_tokens(1) + # TODO: Add soft-tuning prompt adapter support + if self.prompt_adapter_config: + return False - return self.prepare_model_input(self.cached_seq_group_metadata_list) + return True @torch.inference_mode() def execute_model( @@ -125,42 +226,86 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: - # Since we do not broadcast data inside execute_model anymore, - # we need to figure out the best way to support TP > 1 in this - # case, because we will at least need to broadcast the sampled - # tokens to all workers. - if not self.is_driver_worker: - raise ValueError("TP1DraftModelRunner only supports TP=1.") + """Executes num_steps forward passes with advancement of input tensors + on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) + Optimizations used: + 1. Input tensors are updated on the GPU directly + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + them since we do batch expansion later that uses GPU outputs) + 3. Reuses sampling tensors (since we run only decodes and they have + a repeating sampling logic) + """ - if self.prompt_adapter_config: - assert model_input.prompt_adapter_requests is not None - assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters( - model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) + # When num_steps == 1, we execute the fallback here for the GPU + # advance_step, which runs prepare_inputs on CPU and for each spec + # iteration invokes this function only once + # (Look at multi-step-worker code) + is_fallback = num_steps == 1 + if not is_fallback: + # Since we do not broadcast data inside execute_model anymore, + # we need to figure out the best way to support TP > 1 in this + # case, because we will at least need to broadcast the sampled + # tokens to all workers.
+ if not self.is_driver_worker: + raise ValueError("TP1DraftModelRunner only supports TP=1.") + + # Sanity + if self.lora_config is not None: + raise ValueError("TP1DraftModelRunner has no support for LORA") + if self.prompt_adapter_config is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "prompt_adapter_config") + if model_input.multi_modal_kwargs: + raise ValueError( + "TP1DraftModelRunner has no support for multi_modal_kwargs" + ) + else: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + # Detect exec mode + assert model_input.attn_metadata is not None + use_cuda_graph = False + if model_input.attn_metadata.num_prefills > 0: + # In this case, execute_model(..) was called directly + if num_steps > 1: + raise ValueError( + "execute_model(..) of draft_model_runner can be called " + "directly only with a single-step prefill") + else: + # We can skip CPU samples for spec token generation. + # (We do allow CPU samples for num_steps == 1 to support the + # fallback case, where supports_gpu_multi_step(..) does not pass) + model_input.sampling_metadata.skip_sampler_cpu_output = ( + not is_fallback) + + # Attn attr defines if we use cuda graphs + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + + # Get model + if use_cuda_graph: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = (self.graph_runners[model_input.virtual_engine] + [graph_batch_size]) + else: + model_executable = self.model - virtual_engine = model_input.virtual_engine outputs: List[SamplerOutput] = [] for step in range(num_steps): - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[virtual_engine][graph_batch_size]) - else: - model_executable = self.model - multi_modal_kwargs = model_input.multi_modal_kwargs or {} + + # Run model hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -181,8 +326,8 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, )) - # Prepare the inputs for the next step. 
+ # Prepare inputs for the next step if step != num_steps - 1: - model_input = self.update_model_input(model_input, outputs[-1]) + model_input = self._gpu_advance_step(model_input, outputs[-1]) return outputs diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 09a77f9e870fb..91689324557b5 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -43,7 +43,7 @@ def init_device(self) -> None: ) def set_include_gpu_probs_tensor(self) -> None: - # Need include_gpu_probs_tensor for multi_step_worker + # Need include_gpu_probs_tensor for MultiStepWorker self.model_runner.model.sampler.include_gpu_probs_tensor = True @torch.inference_mode() @@ -67,14 +67,23 @@ def sampler_output( expanded_request, indices_of_seq_with_bonus_tokens =\ self._expand_execute_model_request( execute_model_req, seq_ids_with_bonus_token_in_last_step) + # Run model sample_len times. model_outputs: List[SamplerOutput] = [] - if isinstance(self.model_runner, TP1DraftModelRunner): + if isinstance( + self.model_runner, TP1DraftModelRunner + ) and self.model_runner.supports_gpu_multi_step(expanded_request): + # Here we run the draft_model_runner with multi-step prepare + # on the GPU directly expanded_request.num_steps = sample_len model_outputs = self.execute_model( execute_model_req=expanded_request) else: - # TODO: Remove this branch once DraftModelRunner supports TP>1. + # Here we run multi-step directly, with every step prepared + # on the CPU. + # TODO: Remove this branch once DraftModelRunner supports TP>1 + # and other restrictions that are part of DraftModelRunner's + # supports_gpu_multi_step(..) for _ in range(sample_len): model_output: List[SamplerOutput] = super().execute_model( execute_model_req=expanded_request) @@ -171,7 +180,7 @@ def _filter_model_output( outputs=[ expanded_batch_output.outputs[i] for i in output_indices_to_retain - ], + ] if len(expanded_batch_output.outputs) > 0 else [], sampled_token_probs=( expanded_batch_output. sampled_token_probs[output_indices_to_retain] diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 07991df52e655..a21222fec269b 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -13,7 +13,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase): """NGramWorker provides a light drafter without need for model. - Current NGramWorker only implement prompt lookup decoding, + Current NGramWorker only implements prompt lookup decoding, and in future we may also do RAG type drafter and other scenarios which don't rely on LLM model to give proposals. 
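For context on the NGramWorker docstring above: prompt lookup decoding drafts tokens by matching the trailing n-gram of the current context against an earlier occurrence and proposing the tokens that followed it. A minimal, list-based sketch of the idea (an illustration only; NGramWorker's real implementation is tensor-based and uses configurable minimum/maximum window sizes not shown here):

# Minimal illustration of prompt lookup decoding (not NGramWorker's code).
from typing import List

def prompt_lookup_propose(token_ids: List[int], ngram_size: int,
                          num_speculative_tokens: int) -> List[int]:
    if len(token_ids) <= ngram_size:
        return []
    pattern = token_ids[-ngram_size:]
    # Scan backwards for the most recent earlier occurrence of the n-gram.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == pattern:
            return token_ids[start + ngram_size:
                             start + ngram_size + num_speculative_tokens]
    return []

# The trailing bigram (7, 8) also occurs at positions 1-2, so the tokens
# that followed it there, [9, 1], become the draft proposal.
assert prompt_lookup_propose([5, 7, 8, 9, 1, 2, 7, 8], 2, 2) == [9, 1]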
""" @@ -37,7 +37,7 @@ def init_device(self): self.device = torch.device(f"cuda:{self.local_rank}") self.load_model = lambda *args, **kwargs: None - # Current only support Top1Proposer + # Current NGramWorker only supports Top1Proposer self._proposer = Top1Proposer( weakref.proxy(self), # type: ignore[arg-type] device=self.device, diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index fffa557121e17..51cefc0cbca8b 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -24,7 +24,7 @@ def sampler_output( ) -> Tuple[Optional[List[SamplerOutput]], bool]: raise NotImplementedError - def set_include_gpu_probs_tensor(self): + def set_include_gpu_probs_tensor(self) -> None: """Implementation optional""" pass diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3c8e3dee46831..903264aad7a15 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -206,7 +206,7 @@ def __init__( self.probs_dtype = self.spec_decode_sampler.probs_dtype self.token_id_dtype = self.spec_decode_sampler.token_id_dtype - # Lazy initiazliation. + # Lazy initialization. self.scorer: SpeculativeScorer # Hidden states from target model to pass to proposer diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 7b34b5d34208b..ade293c2c0757 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -138,7 +138,7 @@ def _split_by_proposal_len( # Currently only proposal lens of 0 or the global batch proposal len # are supported. - # If max_proposal_len is defined, then we shall no exccess this + # If max_proposal_len is defined, then we shall no exceed this # quota for nonzero_proposal new_k = 0 if (self.max_proposal_len is None @@ -219,7 +219,7 @@ def _merge_outputs( proposal_lens: List[int], nonzero_proposal_len_indices: List[int], sampler_transposed: bool, - ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """After speculations are produced, merge the speculation results with the skipped sequences. 
""" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 205b4f58f7a83..75a2607d0d9c4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -2,6 +2,7 @@ import gc import time import warnings +import weakref from collections import defaultdict from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, TypeVar, Union) @@ -48,9 +49,9 @@ from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, - is_pin_memory_available, make_tensor_with_pad) + is_pin_memory_available) from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, @@ -165,6 +166,298 @@ def from_broadcasted_tensor_dict( return cls(**tensor_dict) +class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): + """TBA""" + + def __init__(self, + runner: "GPUModelRunnerBase", + finished_requests_ids: Optional[List[str]] = None): + super().__init__() + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.scheduler_config = self.runner.scheduler_config + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.enable_lora = self.runner.lora_config is not None + self.enable_prompt_adapter = (self.runner.prompt_adapter_config + is not None) + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.finished_requests_ids = finished_requests_ids + self.decode_only = True + + # Common inputs. + self.input_tokens: List[int] = [] + self.input_positions: List[int] = [] + self.seq_lens: List[int] = [] + self.query_lens: List[int] = [] + self.max_decode_seq_len: int = 0 + self.request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) + + # LoRA inputs. + self.lora_index_mapping: List[int] = [] + self.lora_prompt_mapping: List[int] = [] + self.lora_requests: Set[LoRARequest] = set() + + # Prompt adapter inputs. + self.prompt_adapter_index_mapping: List[int] = [] + self.prompt_adapter_prompt_mapping: List[int] = [] + self.prompt_adapter_requests: Set[PromptAdapterRequest] = set() + + # Multi-modal inputs. + self.multi_modal_inputs_list: List[MultiModalInputs] = [] + + # Attention metadata inputs. + self.attn_metadata_builder = self.attn_backend.make_metadata_builder( + self) + + # Engine/Model configurations. + self.chunked_prefill_enabled = ( + self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled) + if self.sliding_window is not None: + self.sliding_window_blocks = ( + self.sliding_window + self.block_size - 1) // self.block_size + self.block_aligned_sliding_window = \ + self.sliding_window_blocks * self.block_size + + def _compute_len_for_sliding_window(self, seq_len: int): + curr_sliding_window_blocks = 0 + sliding_seq_len = seq_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. 
+ if self.sliding_window is not None: + curr_sliding_window_blocks = self.sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, self.block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + return curr_sliding_window_blocks, sliding_seq_len + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + seq_ids = list(seq_group_metadata.seq_data.keys()) + n_seqs = len(seq_ids) + is_prompt = seq_group_metadata.is_prompt + token_chunk_size = seq_group_metadata.token_chunk_size + + if is_prompt: + assert n_seqs == 1 + self.decode_only = False + + # Mapping from request IDs to sequence IDs. Used for Jamba models + # that manages the cache by itself. + self.request_ids_to_seq_ids[seq_group_metadata.request_id] = [] + # The number of input tokens in each sequence. + token_lens: List[int] = [] + # The number of tokens that are already computed. + context_lens: List[int] = [] + # The current sliding window block for each sequence. + curr_sliding_window_blocks: List[int] = [] + # The original sequence length (before applying sliding window) + # for each sequence. + orig_seq_lens: List[int] = [] + # The sequence length (may be capped to the sliding window). + curr_seq_lens: List[int] = [] + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + self.request_ids_to_seq_ids[seq_group_metadata.request_id].append( + seq_id) + computed_block_nums = seq_group_metadata.computed_block_nums + + # Check if hit prefix cache (i.e., some blocks are already computed) + # Note that prefix caching does not support sliding window. + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None and is_prompt) + if self.chunked_prefill_enabled and prefix_cache_hit: + raise RuntimeError( + "chunked prefill cannot be used with prefix caching now.") + + # Compute context length (the number of tokens that are + # already computed) and sequence length (total number of tokens). + seq_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_len - 1 + seq_len = min(seq_len, context_len + token_chunk_size) + + # Compute tokens. + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. 
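A worked example of _compute_len_for_sliding_window above, for the v2 block manager branch. The window and block sizes here are illustrative assumptions, not values taken from this diff:

# Illustrative numbers for the block-aligned sliding-window math above.
sliding_window = 512
block_size = 16
sliding_window_blocks = (sliding_window + block_size - 1) // block_size  # 32
block_aligned_sliding_window = sliding_window_blocks * block_size  # 512

seq_len = 1000
suff_len = seq_len % block_size  # 8 tokens in the last, partially filled block
# Keep a full window of aligned blocks plus the trailing partial block.
sliding_seq_len = min(seq_len, block_aligned_sliding_window + suff_len)  # 520
curr_sliding_window_blocks = sliding_window_blocks + (1 if suff_len > 0 else 0)

assert (curr_sliding_window_blocks, sliding_seq_len) == (33, 520)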
+ if is_prompt: + curr_sliding_window_block = 0 + sliding_seq_len = seq_len + query_len = seq_len - context_len + else: + curr_sliding_window_block, sliding_seq_len = ( + self._compute_len_for_sliding_window(seq_len)) + query_len = 1 + + self.seq_lens.append(sliding_seq_len) + if not is_prompt: + self.max_decode_seq_len = max(self.max_decode_seq_len, + sliding_seq_len) + self.query_lens.append(query_len) + self.input_tokens.extend(tokens) + self.input_positions.extend(list(range(context_len, seq_len))) + + # Intermediate data of the current sequence group for + # the attention metadata. + token_lens.append(len(tokens)) + context_lens.append(context_len) + curr_seq_lens.append(sliding_seq_len) + curr_sliding_window_blocks.append(curr_sliding_window_block) + orig_seq_lens.append(seq_len) + + # Update attention metadata. Note that input builder attributes + # (self.xxx) include all added sequences, so we need to slice + # the last n_seqs sequences. + self.attn_metadata_builder.add_seq_group( + seq_group_metadata, token_lens, orig_seq_lens, curr_seq_lens, + self.query_lens[-n_seqs:], context_lens, + curr_sliding_window_blocks, prefix_cache_hit, + self.chunked_prefill_enabled) + + # LoRA data. + if self.enable_lora: + lora_id = seq_group_metadata.lora_int_id + for query_len in self.query_lens[-n_seqs:]: + if lora_id > 0: + self.lora_requests.add(seq_group_metadata.lora_request) + self.lora_index_mapping += [lora_id] * query_len + self.lora_prompt_mapping.extend( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + is not None else 1)) + + # Prompt adapter data. Note that when is_prompt=True, + # we expect only one sequence in the group. + if self.enable_prompt_adapter: + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + if prompt_adapter_id > 0 and is_prompt: + query_len = self.query_lens[-1] + self.prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + self.prompt_adapter_index_mapping += pm + self.prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + # Multi-modal data. + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + mm_kwargs = self.multi_modal_input_mapper(mm_data) + self.multi_modal_inputs_list.append(mm_kwargs) + + def build(self) -> ModelInputForGPU: + if not self.input_tokens: + return self.model_input_cls() + + batch_size = len(self.input_tokens) + use_captured_graph = ( + self.decode_only and not self.runner.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and self.max_decode_seq_len <= self.runner.max_seq_len_to_capture) + + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + cuda_graph_pad_size = -1 + if use_captured_graph: + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + cuda_graph_pad_size = graph_batch_size - batch_size + batch_size = graph_batch_size + + # Tokens and positions. 
+ self.input_tokens.extend([0] * cuda_graph_pad_size) + self.input_positions.extend([0] * cuda_graph_pad_size) + input_tokens_tensor = torch.tensor(self.input_tokens, + dtype=torch.long, + device=self.runner.device) + input_positions_tensor = torch.tensor(self.input_positions, + dtype=torch.long, + device=self.runner.device) + + # Sequence and query lengths. + self.seq_lens.extend([1] * cuda_graph_pad_size) + + # Attention metadata. + attn_metadata = self.attn_metadata_builder.build( + self.runner, self.seq_lens, self.query_lens, cuda_graph_pad_size, + batch_size) + + # LoRA data. + if self.enable_lora: + self.lora_index_mapping.extend([0] * cuda_graph_pad_size) + lora_mapping = LoRAMapping( + self.lora_index_mapping, + self.lora_prompt_mapping, + ) + else: + lora_mapping = None + + # Prompt adapter data. + if self.enable_prompt_adapter: + self.prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size) + prompt_adapter_mapping = PromptAdapterMapping( + self.prompt_adapter_index_mapping, + self.prompt_adapter_prompt_mapping, + ) + else: + prompt_adapter_mapping = None + + # Multi-modal data. + multi_modal_kwargs = MultiModalInputs.batch( + self.multi_modal_inputs_list, device=self.runner.device) + + return self.model_input_cls( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=self.seq_lens, + query_lens=self.query_lens, + lora_mapping=lora_mapping, + lora_requests=self.lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + request_ids_to_seq_ids=self.request_ids_to_seq_ids, + finished_requests_ids=self.finished_requests_ids, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=self.prompt_adapter_requests) + + class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): """ Helper class for shared methods between GPU model runners. @@ -368,464 +661,11 @@ def _prepare_model_input_tensors( If cuda graph is required, this API automatically pads inputs. """ - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() - prompt_adapter_index_mapping: List[int] = [] - prompt_adapter_prompt_mapping: List[int] = [] - prompt_adapter_requests: Set[PromptAdapterRequest] = set() - - seq_lens: List[int] = [] - prefill_seq_lens: List[int] = [] - decode_seq_lens: List[int] = [] - context_lens: List[int] = [] - query_lens: List[int] = [] - block_tables: List[List[int]] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] - request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list) - decode_only = True - num_prefills = 0 - num_prefill_tokens = 0 - num_decode_tokens = 0 - - # The following fields are only for flashinfer - # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout - # for the precise definition of the following fields. - # An example: - # request 1, page indices [0, 5, 8] - # request 2, page indices [1, 6, 7] - # request 3, page indices [3, 4] - # paged_kv_indices is a concatenation of page indices of all requests: - # [0, 5, 8, 1, 6, 7, 3, 4] - # paged_kv_indptr is used to index into paged_kv_indices: - # [0, 3, 6, 8] - paged_kv_indices: List[int] = [] - # 0 at the beginning of paged_kv_indptr indicates the start of the - # first request’s page indices in the paged_kv_indices list. 
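A quick check of the flashinfer page-layout example described in the removed comments above (three requests with page indices [0, 5, 8], [1, 6, 7] and [3, 4]); this sketch only reproduces the indices/indptr construction and is not the removed code itself:

# Rebuild paged_kv_indices / paged_kv_indptr from per-request page indices.
from typing import List

requests_page_indices = [[0, 5, 8], [1, 6, 7], [3, 4]]
paged_kv_indices: List[int] = []
paged_kv_indptr: List[int] = [0]  # leading 0 marks the start of request 1
for pages in requests_page_indices:
    paged_kv_indices.extend(pages)
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(pages))

assert paged_kv_indices == [0, 5, 8, 1, 6, 7, 3, 4]
assert paged_kv_indptr == [0, 3, 6, 8]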
- paged_kv_indptr: List[int] = [0] - # paged_kv_last_page_len is the length of the last page of each request - paged_kv_last_page_len: List[int] = [] - - if len(seq_group_metadata_list) == 0: - return self._model_input_cls() - - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window + self.block_size - - 1) // self.block_size - block_aligned_sliding_window = \ - sliding_window_blocks * self.block_size - + builder = ModelInputForGPUBuilder(weakref.proxy(self), + finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: - seq_ids = list(seq_group_metadata.seq_data.keys()) - is_prompt = seq_group_metadata.is_prompt - - for seq_id in seq_ids: - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - seq_data = seq_group_metadata.seq_data[seq_id] - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. - context_len = seq_data.get_len() - 1 - - seq_len = min( - seq_data.get_len(), - context_len + seq_group_metadata.token_chunk_size) - if is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = [seq_data.get_last_token_id()] - - # Prefix cache was hit. - # Prefix is not supported with sliding_window - prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and is_prompt) - - # These are seq_len/context_len capped to the sliding window. - # They are passed to decode kernel. - # We still need original seq_len/context_len to compute slot - # mapping (and input position) below. - curr_sliding_window_blocks = None - sliding_seq_len = seq_len - sliding_context_len = context_len - - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if (self.sliding_window is not None and not is_prompt): - curr_sliding_window_blocks = sliding_window_blocks - if self.scheduler_config.use_v2_block_manager: - # number of elements in last block - suff_len = seq_len % self.block_size - sliding_seq_len = min( - seq_len, block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_blocks += 1 - else: - sliding_seq_len = min(seq_len, self.sliding_window) - sliding_context_len = sliding_seq_len - 1 - - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - tokens = tokens[context_len:] - - # need to think what to set it to when we have both sliding - # window and prefix caching... - assert self.sliding_window is None, \ - "Prefix caching is not supported with sliding window" - sliding_context_len = context_len - - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. 
We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums - elif (self.scheduler_config.chunked_prefill_enabled - or not is_prompt): - if seq_group_metadata.block_tables is not None: - # chunked prefill or decode - block_table = seq_group_metadata.block_tables[seq_id] - if curr_sliding_window_blocks is not None: - block_table = block_table[ - -curr_sliding_window_blocks:] - else: - # Only happens when memory profiling runs. - block_table = [] - else: - # Prefill without chunked prefill or memory profiling. - block_table = [] - block_tables.append(block_table) - - seq_lens.append(sliding_seq_len) - context_lens.append(sliding_context_len) - query_len = sliding_seq_len - sliding_context_len - query_lens.append(query_len) - input_tokens.extend(tokens) - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - prompt_adapter_id = seq_group_metadata.prompt_adapter_id - - if is_prompt: - assert len(seq_ids) == 1 - num_prefills += 1 - num_prefill_tokens += len(tokens) - decode_only = False - prefill_seq_lens.append(seq_len) - else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) - num_decode_tokens += query_len - decode_seq_lens.append(sliding_seq_len) - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * query_len - lora_prompt_mapping.extend( - [lora_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - is not None else 1)) - - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - # Process multi-modal data - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - - if prompt_adapter_id > 0 and is_prompt: - prompt_adapter_requests.add( - seq_group_metadata.prompt_adapter_request) - - num_tokens = seq_group_metadata.\ - prompt_adapter_num_virtual_tokens - pm = [prompt_adapter_id - ] * num_tokens + [0] * (query_len - num_tokens) - prompt_adapter_index_mapping += pm - prompt_adapter_prompt_mapping.extend( - [prompt_adapter_id] * - (query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs - else 1)) - - is_profile_run = _is_block_tables_empty( - seq_group_metadata.block_tables) - if is_profile_run: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with - # _PAD_SLOT_ID, where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - if is_prompt: - assert self.scheduler_config.use_v2_block_manager \ - or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # It is an optimization. When it is decoding, it is always - # 0. When prefill, we use it to not write slots to kv cache - # to save memory. 
- start_idx = max(0, query_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - # Prepare input tensors for flashinfer - if self.attn_backend.get_name() == "flashinfer": - seq_len = seq_data.get_len() - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - - paged_kv_indices.extend(block_table[:block_table_bound]) - paged_kv_indptr.append(paged_kv_indptr[-1] + - block_table_bound) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) - - batch_size = len(input_tokens) - max_query_len = max(query_lens) - max_prefill_seq_len = max(prefill_seq_lens, default=0) - max_decode_seq_len = max(decode_seq_lens, default=0) - - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - use_captured_graph = ( - decode_only and not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_decode_seq_len <= self.max_seq_len_to_capture) - if use_captured_graph: - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - for _ in range(graph_batch_size - batch_size): - input_tokens.append(0) - input_positions.append(0) - slot_mapping.append(_PAD_SLOT_ID) - seq_lens.append(1) - block_tables.append([]) - lora_index_mapping.append(0) - prompt_adapter_index_mapping.append(0) - if self.attn_backend.get_name() == "flashinfer": - last_paged_kv_indptr = paged_kv_indptr[-1] - paged_kv_indptr.append(last_paged_kv_indptr) - paged_kv_last_page_len.append(0) - batch_size = graph_batch_size - num_decode_tokens = batch_size - - if use_captured_graph: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. 
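The removed slot-mapping loop above encodes the example given in its comment (prompt length 10, sliding window 8, block size 4). A standalone sketch reproducing it; the block table [0, 1, 0] is an assumed value chosen so that block 0 is reused once the first two tokens fall out of the window:

# Reproduce the sliding-window slot mapping example from the comments above.
from typing import List

_PAD_SLOT_ID = -1
block_size, sliding_window = 4, 8
context_len, seq_len = 0, 10
query_len = seq_len - context_len
block_table = [0, 1, 0]  # assumed layout matching the documented example

start_idx = max(0, query_len - sliding_window)  # 2: first two tokens masked
slot_mapping: List[int] = []
for i in range(context_len, seq_len):
    if i < start_idx:
        slot_mapping.append(_PAD_SLOT_ID)
        continue
    block_number = block_table[i // block_size]
    block_offset = i % block_size
    slot_mapping.append(block_number * block_size + block_offset)

assert slot_mapping == [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]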
- input_block_tables = self.graph_block_tables[:batch_size] - for i, block_table in enumerate(block_tables): - if block_table: - input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device=self.device) - else: - max_block_table_len = max( - len(block_table) for block_table in block_tables) - block_tables = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) - - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - logits_soft_cap = getattr(self.model_config.hf_config, - 'attn_logit_softcapping', None) - if logits_soft_cap is not None and self.attn_backend.get_name( - ) != "flashinfer": - raise ValueError("Please use Flashinfer backend for models with" - "logits_soft_cap (i.e., Gemma-2)." - " Otherwise, the output might be wrong." 
- " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - - if self.attn_backend.get_name() == "flashinfer": - if len(paged_kv_indptr) > 0: - paged_kv_indices_tensor = torch.tensor(paged_kv_indices, - device='cpu', - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, - device='cpu', - dtype=torch.int) - paged_kv_last_page_len_tensor = torch.tensor( - paged_kv_last_page_len, device='cpu', dtype=torch.int) - else: - paged_kv_indices_tensor = None - paged_kv_indptr_tensor = None - paged_kv_last_page_len_tensor = None - - kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, - self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - max_prefill_seq_len=max_prefill_seq_len, - block_tables=block_tables, - paged_kv_indptr=paged_kv_indptr_tensor, - paged_kv_indices=paged_kv_indices_tensor, - paged_kv_last_page_len=paged_kv_last_page_len_tensor, - num_qo_heads=self.model_config.get_num_attention_heads( - self.parallel_config), - num_kv_heads=self.model_config.get_num_kv_heads( - self.parallel_config), - head_dim=self.model_config.get_head_size(), - page_size=self.block_size, - seq_start_loc=seq_start_loc, - query_start_loc=query_start_loc, - device=self.device, - data_type=kv_cache_dtype, - use_cuda_graph=use_captured_graph, - logits_soft_cap=logits_soft_cap) - - else: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - if self.prompt_adapter_config: - prompt_adapter_mapping = PromptAdapterMapping( - prompt_adapter_index_mapping, - prompt_adapter_prompt_mapping, - ) - else: - prompt_adapter_mapping = None - - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) - request_ids_to_seq_ids = { - seq_group_metadata.request_id: - list(seq_group_metadata.seq_data.keys()) - for seq_group_metadata in seq_group_metadata_list - } - return self._model_input_cls( - input_tokens=input_tokens_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=finished_requests_ids, - prompt_adapter_mapping=prompt_adapter_mapping, - prompt_adapter_requests=prompt_adapter_requests, - ) + builder.add_seq_group(seq_group_metadata) + return builder.build() # type: ignore @torch.inference_mode() def profile_run(self) -> None: diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index bc0960fa16221..bc7a6a73b17c4 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -113,6 +113,21 @@ def from_broadcasted_tensor_dict( raise NotImplementedError +class 
ModelRunnerInputBuilderBase(ABC, Generic[T]): + """A builder to create ModelRunnerInputBase objects. + """ + + @abstractmethod + def add_seq_group(self, seq_group_metadata): + """TBA""" + raise NotImplementedError + + @abstractmethod + def build(self, *args, **kwargs) -> T: + """Build metadata with on-device tensors.""" + raise NotImplementedError + + class ModelRunnerBase(ABC, Generic[T]): """ Model runner interface that abstracts a particular hardware and/or type of diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 6c1149ee9dfca..bbf0db31ee383 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,5 +1,5 @@ import time -from typing import List, Mapping, Optional, Tuple +from typing import List, Optional, Tuple import numpy as np import torch @@ -12,8 +12,6 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors, - MultiModalInputs) from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, SamplerOutput, SequenceGroupMetadata, SequenceOutput) @@ -68,10 +66,6 @@ def __init__( False, ) - # Multi-modal data support - self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ - .create_input_mapper(self.model_config) - def load_model(self) -> None: self.device = self.device_config.device @@ -154,7 +148,7 @@ def _dummy_run( # Dummy run. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 self.model(token_ids, position_ids, kv_caches, attn_metadata, - input_lens, None, t, p, num_samples) + input_lens, t, p, num_samples) def warmup_model( self, @@ -199,14 +193,12 @@ def warmup_model( def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor, - Mapping[str, BatchedTensors]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] prompt_lens: List[int] = [] slot_mapping: List[List[int]] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -232,11 +224,6 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping[-1].append(slot) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - assert len(prompt_lens) > 0 num_prefills = len(prompt_lens) num_prefill_tokens = sum(prompt_lens) @@ -274,24 +261,17 @@ def _prepare_prompt( block_tables=None, context_lens=None, ) - - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) - - return (input_tokens, input_positions, attn_metadata, prompt_lens, - multi_modal_kwargs) + return input_tokens, input_positions, attn_metadata, prompt_lens def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor, - Mapping[str, BatchedTensors]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] context_lens: List[int] = [] - multi_modal_inputs_list: List[MultiModalInputs] = [] batch_idx 
= 0 for seq_group_metadata in seq_group_metadata_list: @@ -317,11 +297,6 @@ def _prepare_decode( slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - batch_size = _get_padded_batch_size(batch_idx) num_paddings = batch_size - batch_idx input_tokens = input_tokens + [[0]] * num_paddings @@ -355,12 +330,7 @@ def _prepare_decode( block_tables=block_tables, context_lens=context_lens, ) - - multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list, - device=self.device) - - return (input_tokens, input_positions, attn_metadata, input_lens, - multi_modal_kwargs) + return input_tokens, input_positions, attn_metadata, input_lens def _prepare_sample( self, @@ -513,7 +483,6 @@ def forward( kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], attn_metadata: AttentionMetadata, input_lens: torch.Tensor, - multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]], t: torch.Tensor, p: torch.Tensor, num_samples: int, @@ -527,8 +496,6 @@ def forward( memory profiling at initialization. attn_metadata: The Pallas attention metadata. input_lens: The actual input lengths of shape [batch_size]. - multi_modal_kwargs: Keyword arguments from multi-modal data to - pass to the model. t: The sampling temperature of shape [batch_size]. p: The top-p probability of shape [batch_size]. """ @@ -573,7 +540,6 @@ def forward( position_ids, kv_caches, attn_metadata, - **(multi_modal_kwargs or {}), ) hidden_states = hidden_states.flatten(0, 1) logits = self.model.compute_logits(hidden_states, sampling_metadata) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 56d8587f8f010..f3c379d1aa34d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -105,7 +105,7 @@ def __init__( # initialize_cache. self.cache_engine: List[CacheEngine] # Initialize gpu_cache as embedding models don't initialize kv_caches - self.gpu_cache: Optional[List[List[torch.tensor]]] = None + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None def init_device(self) -> None: if self.device_config.device.type == "cuda":
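Taken together, the model_runner.py changes above replace the monolithic _prepare_model_input_tensors body with the ModelInputForGPUBuilder / ModelRunnerInputBuilderBase pattern. A condensed sketch of the resulting call flow, using only names that appear in the diff (the method body is paraphrased, not quoted):

# Condensed view of the new input-preparation flow in GPUModelRunnerBase.
import weakref

def _prepare_model_input_tensors(self, seq_group_metadata_list,
                                 finished_requests_ids=None):
    # The diff wraps the runner in weakref.proxy, presumably to avoid a
    # reference cycle between the runner and its builder.
    builder = ModelInputForGPUBuilder(weakref.proxy(self),
                                      finished_requests_ids)
    # Each sequence group contributes tokens, positions, LoRA / prompt-adapter
    # mappings, multi-modal inputs and attention metadata to the builder.
    for seq_group_metadata in seq_group_metadata_list:
        builder.add_seq_group(seq_group_metadata)
    # build() pads for CUDA graphs when applicable and materializes the
    # on-device tensors as a ModelInputForGPU.
    return builder.build()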