diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh new file mode 100644 index 0000000000000..d06604f96f2b8 --- /dev/null +++ b/.buildkite/run-gh200-test.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# This script build the GH200 docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +DOCKER_BUILDKIT=1 docker build . \ + --target vllm-openai \ + --platform "linux/arm64" \ + -t gh200-test \ + --build-arg max_jobs=66 \ + --build-arg nvcc_threads=2 \ + --build-arg torch_cuda_arch_list="9.0+PTX" \ + --build-arg vllm_fa_cmake_gpu_arches="90-real" + +# Setup cleanup +remove_docker_container() { docker rm -f gh200-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and test offline inference +docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' + python3 examples/offline_inference.py +' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8f57006214c88..44f47fac1c1b3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -181,14 +181,14 @@ steps: commands: - VLLM_USE_V1=1 pytest -v -s v1 -- label: Examples Test # 15min +- label: Examples Test # 25min working_dir: "/vllm-workspace/examples" #mirror_hardwares: [amd] source_file_dependencies: - vllm/entrypoints - examples/ commands: - - pip install awscli tensorizer # for llava example and tensorizer test + - pip install tensorizer # for tensorizer test - python3 offline_inference.py - python3 cpu_offload.py - python3 offline_inference_chat.py @@ -198,10 +198,13 @@ steps: - python3 offline_inference_vision_language_multi_image.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py - - python3 offline_profile.py --model facebook/opt-125m + - python3 offline_inference_classification.py + - python3 offline_inference_embedding.py + - python3 offline_inference_scoring.py + - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min - #mirror_hardwares: [amd] + mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/prefix_caching @@ -321,7 +324,7 @@ steps: ##### models test ##### -- label: Basic Models Test # 30min +- label: Basic Models Test # 24min source_file_dependencies: - vllm/ - tests/models @@ -331,7 +334,7 @@ steps: - pytest -v -s models/test_registry.py - pytest -v -s models/test_initialization.py -- label: Language Models Test (Standard) # 42min +- label: Language Models Test (Standard) # 32min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -342,7 +345,7 @@ steps: - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model -- label: Language Models Test (Extended) # 50min +- label: Language Models Test (Extended) # 1h10min optional: true source_file_dependencies: - vllm/ @@ -353,7 +356,7 @@ steps: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' -- label: Multi-Modal Models Test (Standard) # 26min +- label: Multi-Modal Models Test (Standard) # 28min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ 
-369,7 +372,7 @@ steps: - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model -- label: Multi-Modal Models Test (Extended) # 1h15m +- label: Multi-Modal Models Test (Extended) 1 # 1h16m optional: true source_file_dependencies: - vllm/ @@ -380,14 +383,24 @@ steps: commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' + - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' + - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/vision_language -m 'not core_model' - pytest -v -s models/encoder_decoder/language -m 'not core_model' - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' +- label: Multi-Modal Models Test (Extended) 2 # 38m + optional: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/vision_language + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model' + # This test is used only in PR development phase to test individual models and should never run on main - label: Custom Models Test optional: true @@ -422,11 +435,11 @@ steps: - tests/distributed/ commands: - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - label: Distributed Tests (2 GPUs) # 40min #mirror_hardwares: [amd] @@ -445,12 +458,12 @@ steps: commands: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' + - 
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' # Avoid importing model tests that cause CUDA reinitialization error - - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus - - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus - - pytest models/decoder_only/vision_language/test_models.py -v -s -m distributed_2_gpus + - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py @@ -540,7 +553,7 @@ steps: # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py - - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - pytest -v -s -x lora/test_mixtral.py - label: LM Eval Large Models # optional diff --git a/CMakeLists.txt b/CMakeLists.txt index c78cdc77a7e42..bf19b3d227171 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,6 +196,7 @@ set(VLLM_EXT_SRC "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" @@ -300,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. 
- cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) diff --git a/Dockerfile b/Dockerfile index 1aa863fd6d742..285ebac685b44 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,6 +11,7 @@ ARG CUDA_VERSION=12.4.1 FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base ARG CUDA_VERSION=12.4.1 ARG PYTHON_VERSION=3.12 +ARG TARGETPLATFORM ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies @@ -46,9 +47,14 @@ WORKDIR /workspace # install build and runtime dependencies COPY requirements-common.txt requirements-common.txt COPY requirements-cuda.txt requirements-cuda.txt +COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + python3 -m pip install -r requirements-cuda-arm64.txt; \ + fi # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -63,6 +69,7 @@ ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches} #################### WHEEL BUILD IMAGE #################### FROM base AS build +ARG TARGETPLATFORM # install build dependencies COPY requirements-build.txt requirements-build.txt @@ -70,6 +77,11 @@ COPY requirements-build.txt requirements-build.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-build.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + python3 -m pip install -r requirements-cuda-arm64.txt; \ + fi + COPY . . ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ @@ -134,8 +146,8 @@ COPY requirements-test.txt requirements-test.txt COPY requirements-dev.txt requirements-dev.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt - #################### DEV IMAGE #################### + #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base @@ -143,6 +155,9 @@ ARG CUDA_VERSION=12.4.1 ARG PYTHON_VERSION=3.12 WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive +ARG TARGETPLATFORM + +COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \ echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment @@ -168,18 +183,25 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ # or future versions of triton. RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/ -# install vllm wheel first, so that torch etc will be installed +# Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose RUN --mount=type=cache,target=/root/.cache/pip \ - . 
/etc/environment && \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + pip uninstall -y torch && \ + python3 -m pip install -r requirements-cuda-arm64.txt; \ + fi + +RUN --mount=type=cache,target=/root/.cache/pip \ +. /etc/environment && \ +if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ + python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \ +fi COPY examples examples #################### vLLM installation IMAGE #################### - #################### TEST IMAGE #################### # image to run unit testing suite # note that this uses vllm installed by `pip` @@ -209,7 +231,6 @@ COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1 RUN mkdir test_docs RUN mv docs test_docs/ RUN mv vllm test_docs/ - #################### TEST IMAGE #################### #################### OPENAI API SERVER #################### @@ -218,8 +239,11 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10 boto3 runai-model-streamer runai-model-streamer[s3] - + if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ + else \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ + fi ENV VLLM_USAGE_SOURCE production-docker-image ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/README.md b/README.md index ed5161ccffb45..93b71ddaccc61 100644 --- a/README.md +++ b/README.md @@ -134,3 +134,7 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs * For coordinating contributions and development, please use Slack. * For security disclosures, please use Github's security advisory feature. * For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. + +## Media Kit + +* If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). 
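The pipeline changes above replace the old `distributed_2_gpus` pytest marker with argument-carrying marker expressions such as `-m 'distributed(num_gpus=2)'` and `-m 'split(group=0)'`. As a rough, hypothetical sketch of tests decorated to match those expressions (the test names are invented, and marker registration plus argument-aware selection are assumed to be provided by the test suite's pytest configuration, not by anything shown in this diff):

```python
# Hypothetical sketch only: tests carrying the parametrized markers that the
# -m expressions in .buildkite/test-pipeline.yaml select on. Marker
# registration and argument-aware selection are assumed to come from the test
# suite's pytest configuration, which is not part of this diff.
import pytest


@pytest.mark.distributed(num_gpus=2)
def test_runs_in_the_two_gpu_distributed_suite():
    assert True


@pytest.mark.split(group=0)
def test_runs_in_vision_language_split_group_zero():
    assert True
```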
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3256692142c5e..4eb0e1f8ac903 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -781,6 +781,7 @@ def main(args: argparse.Namespace): backend = args.backend model_id = args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + tokenizer_mode = args.tokenizer_mode if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" @@ -790,6 +791,7 @@ def main(args: argparse.Namespace): base_url = f"http://{args.host}:{args.port}" tokenizer = get_tokenizer(tokenizer_id, + tokenizer_mode=tokenizer_mode, trust_remote_code=args.trust_remote_code) if args.dataset is not None: @@ -1210,5 +1212,15 @@ def main(args: argparse.Namespace): "from the sampled HF dataset.", ) + parser.add_argument( + '--tokenizer-mode', + type=str, + default="auto", + choices=['auto', 'slow', 'mistral'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer.') + args = parser.parse_args() main(args) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py new file mode 100644 index 0000000000000..ef91f9f8eb529 --- /dev/null +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -0,0 +1,173 @@ +import pickle as pkl +import time +from dataclasses import dataclass +from itertools import product +from typing import Callable, Iterable, List, Optional + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from tqdm import tqdm + +import vllm._custom_ops as ops +from vllm.model_executor.layers.layernorm import RMSNorm + + +@dataclass +class bench_params_t: + num_tokens: int + hidden_size: int + add_residual: bool + dtype: torch.dtype + + def description(self): + return (f'N {self.num_tokens} ' + f'x D {self.hidden_size} ' + f'x R {self.add_residual} ' + f'x DT {self.dtype}') + + +def get_bench_params() -> List[bench_params_t]: + ## Test Fixtures + NUM_TOKENS = [2**x for x in range(11)] + HIDDEN_SIZES = list(range(1024, 8129, 1024)) + ADD_RESIDUAL = [True, False] + DTYPES = [torch.bfloat16, torch.float] + + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + bench_params = list(map(lambda x: \ + bench_params_t(x[0], x[1], x[2], x[3]), combinations)) + return bench_params + + +# Reference impls +def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _, _ = ops.scaled_int8_quant(torch_out) + + +def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = ops.scaled_fp8_quant(torch_out) + + +def fused_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + out, _ = ops.rms_norm_dynamic_per_token_quant(x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + 
residual=residual) + + +# Bench functions +def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, + quant_dtype: torch.dtype, label: str, sub_label: str, + fn: Callable, description: str) -> TMeasurement: + + min_run_time = 1 + + globals = { + "rms_norm_layer": rms_norm_layer, + "x": x, + "residual": residual, + "quant_dtype": quant_dtype, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + +def bench(params: bench_params_t, label: str, sub_label: str) \ + -> Iterable[TMeasurement]: + + # Make inputs + layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) + # Make weights + layer.weight.data.normal_(mean=1.0, std=0.1) + # Make inputs + scale = 1 / params.hidden_size + x = torch.randn(params.num_tokens, + params.hidden_size, + dtype=params.dtype, + device='cuda') * scale + residual = (torch.randn_like(x) * scale).to(device='cuda') \ + if params.add_residual else None + + timers = [] + + # unfused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, + unfused_int8_impl, "unfused_int8_impl")) + + # unfused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + unfused_fp8_impl, "unfused_fp8_impl")) + + # fused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, + "fused_int8_impl")) + + # fused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + fused_impl, "fused_fp8_impl")) + + print_timers(timers) + + return timers + + +# launch bench +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def main(): + torch.set_default_device('cuda') + bench_params = get_bench_params() + + timers = [] + for bp in tqdm(bench_params): + timers.extend( + bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + print_timers(timers) + + # pickle all the results + timestamp = int(time.time()) + with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f: + pkl.dump(timers, f) + + +if __name__ == '__main__': + main() diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py new file mode 100644 index 0000000000000..baa5de0fff1bd --- /dev/null +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -0,0 +1,262 @@ +import itertools +from typing import Optional, Tuple, Union + +import torch +import triton +from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from torch import nn + +from vllm import _custom_ops as vllm_ops + + +class HuggingFaceRMSNorm(nn.Module): + + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: 
Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_flashinfer( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + fused_add_rmsnorm(x, residual, weight, eps) + output = (x, residual) + else: + output = rmsnorm(x, weight, eps) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_flashinfer = rmsnorm_flashinfer( + x.clone(), weight, + residual.clone() if residual is not None else None) + output_vllm = rmsnorm_vllm( + x.clone(), weight, + residual.clone() if residual is not None else None) + + if use_residual: + output_naive = output_naive[0] + output_flashinfer = output_flashinfer[0] + output_vllm = output_vllm[0] + + print(f"Naive output={output_naive}") + print(f"FlashInfer output={output_flashinfer}") + print(f"VLLM output={output_vllm}") + + if torch.allclose(output_naive, output_flashinfer, atol=1e-2, + rtol=1e-2) and torch.allclose( + output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +head_num_range = [32, 48] +configs = list( + itertools.product(head_num_range, batch_size_range, seq_length_range)) + + +def get_benchmark(use_residual): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["huggingface", "flashinfer", "vllm"], + line_names=["HuggingFace", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name= + f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + 
args={}, + )) + def benchmark(head_num, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=dtype, + device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + quantiles = [0.5, 0.2, 0.8] + + if provider == "huggingface": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_flashinfer( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size", + ) + parser.add_argument( + "--seq-len", + type=int, + default=128, + help="Sequence length", + ) + parser.add_argument( + "--hidden-size", + type=int, + default=4096, + help="Hidden size (2nd dimension) of the sequence", + ) + parser.add_argument("--use-residual", + action="store_true", + help="Whether to use residual connection") + parser.add_argument( + "--save-path", + type=str, + default="./configs/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual) + + # Get the benchmark function with proper use_residual setting + benchmark = get_benchmark(args.use_residual) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index a634e1c3d4886..03414b7e1ae93 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,6 +14,20 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +// TODO(luka/varun): use FP8_TYPE macro after refactoring +#ifndef USE_ROCM + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#endif + +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) 
\ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ diff --git a/csrc/ops.h b/csrc/ops.h index ea001190bc202..816b471d062d2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -66,6 +66,14 @@ void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& weight, torch::Tensor& scale, double epsilon); +void rms_norm_dynamic_per_token_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, + double const epsilon, + std::optional scale_ub, + std::optional residual); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index d7c0297d5333f..15bd5b6ed1564 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -1,6 +1,9 @@ #pragma once +#include "quantization/vectorization.cuh" + #include +#include #ifndef USE_ROCM #include @@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz; // issue when running dynamic quantization. Here use 224.0f for rocm. constexpr auto FP8_E4M3_MAX = 224.0f; #endif +constexpr static auto kFp8Type = c10::CppTypeToScalarType::value; namespace vllm { @@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale, } } -template -struct __align__(8) vec4_t { - scalar_t x; - scalar_t y; - scalar_t z; - scalar_t w; -}; - -typedef struct __align__(4) { - FP8_TYPE x; - FP8_TYPE y; - FP8_TYPE z; - FP8_TYPE w; -} -float8x4_t; - template __device__ float thread_max_vec(scalar_t const* __restrict__ input, int64_t const num_elems, int const tid, @@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, float const scale, int64_t const num_elems, int const tid, int const step) { + using float8x4_t = q8x4_t; // Vectorized input/output to better utilize memory bandwidth. 
- vec4_t const* vectorized_in = - reinterpret_cast const*>(input); - float8x4_t* vectorized_out = reinterpret_cast(out); + auto const* vectorized_in = reinterpret_cast const*>(input); + auto* vectorized_out = reinterpret_cast(out); int64_t const num_vec_elems = num_elems >> 2; diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu new file mode 100644 index 0000000000000..3c4f183bf4b59 --- /dev/null +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -0,0 +1,160 @@ + +#include +#include + +#include "../../dispatch_utils.h" +#include "layernorm_utils.cuh" +#include "quant_conversions.cuh" + +namespace vllm { + +template +__device__ void rms_norm_dynamic_per_token_quant_vec( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + float rms = 0.0f; + float token_scale = 0.0f; + + // Compute rms + vllm::vectorized::compute_rms( + &rms, input, hidden_size, var_epsilon, residual); + + // Compute scale + vllm::vectorized::compute_dynamic_per_token_scales( + &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor, + hidden_size, residual); + + // RMS Norm + Quant + if constexpr (std::is_same_v) { + vllm::vectorized::norm_and_quant( + out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + } else { + // FP8 - Do not invert token_scale for exact match with FBGemm + vllm::vectorized::norm_and_quant( + out, input, weight, rms, token_scale, hidden_size, residual); + } +} + +// RMS norm + quant kernel +template +__global__ void rms_norm_dynamic_per_token_quant_kernel( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + // For vectorization, token_input and token_output pointers need to be + // aligned at 8-byte and 4-byte addresses respectively. 
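+  // The vectorized path processes four elements per thread iteration using
+  // vec4_t loads and q8x4_t stores, so it is only taken when hidden_size is
+  // a multiple of 4; otherwise the scalar implementation below is used.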
+ bool const can_vectorize = hidden_size % 4 == 0; + + if (can_vectorize) { + return rms_norm_dynamic_per_token_quant_vec( + out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor, + hidden_size, residual); + } + + float rms = 0.0f; + float token_scale = 0.0f; + + // Compute RMS + vllm::compute_rms(&rms, input, hidden_size, + var_epsilon, residual); + // Compute Scale + vllm::compute_dynamic_per_token_scales( + &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor, + hidden_size, residual); + + // RMS Norm + Quant + if constexpr (std::is_same_v) { + vllm::norm_and_quant( + out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + } else { + // FP8 - Do not invert s_token_scale for exact match with FBGemm + vllm::norm_and_quant( + out, input, weight, rms, token_scale, hidden_size, residual); + } +} +} // namespace vllm + +// Residual add + RMS norm + dynamic per token +template +void rms_norm_dynamic_per_token_quant_dispatch( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional const& scale_ub, + std::optional& residual) { + int32_t hidden_size = input.size(-1); + int32_t num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const float min_scaling_factor = + out.dtype() == torch::kInt8 + ? std::numeric_limits::epsilon() + : 1.0f / (std::numeric_limits::max() * 512.f); + + if (residual.has_value()) { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { + vllm::rms_norm_dynamic_per_token_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + var_epsilon, min_scaling_factor, hidden_size, + residual->data_ptr()); + }); + + } else { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { + vllm::rms_norm_dynamic_per_token_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), + scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, + var_epsilon, min_scaling_factor, hidden_size, nullptr); + }); + } +} + +void rms_norm_dynamic_per_token_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional scale_ub, std::optional residual) { + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + + if (scale_ub.has_value()) { + TORCH_CHECK(out.dtype() == kFp8Type); + } + TORCH_CHECK(scales.dtype() == torch::kFloat32); + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { + rms_norm_dynamic_per_token_quant_dispatch( + out, input, weight, scales, var_epsilon, scale_ub, residual); + }); +} diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh new file mode 100644 index 0000000000000..cec6b54edb569 --- /dev/null +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -0,0 +1,327 @@ +#pragma once + +/** + * __device__ layernorm utilities. + */ + +#include "quantization/vectorization.cuh" +#include "quant_conversions.cuh" + +#ifndef USE_ROCM + #include +#else + #include +#endif + +namespace vllm { + +// has_residual must be true, if residual is not a nullptr +template +__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, + int32_t const hidden_size, float const epsilon, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + // sum of squares + float ss = 0.0f; + + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + ss += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + + *rms = s_rms; +} + +template +__device__ void compute_dynamic_per_token_scales( + float* __restrict__ token_scale, float* __restrict__ all_token_scales, + scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, + float const rms, float const* __restrict__ scale_ub, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + constexpr scalar_out_t qmax{std::numeric_limits::max()}; + + float block_absmax_val_maybe = 0.0f; + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = 
min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor); + s_token_scale = scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store + } + __syncthreads(); + + *token_scale = s_token_scale; +} + +template +__device__ void norm_and_quant(scalar_out_t* __restrict__ output, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + float const rms, float const scale, + int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + residual[token_offset + i] = static_cast(x); + } + // Norm + x = static_cast(static_cast(x * rms) * weight[i]); + // Quant + output[token_offset + i] = + ScaledQuant::quant_fn(x, scale); + } +} + +namespace vectorized { + +// Compute 1.0/rms(input) +// hidden_size must be a multiple of 4 +template +__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, + int32_t const hidden_size, float const epsilon, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + // sum of squares + float ss = 0.0f; + + int32_t const num_vec_elems = hidden_size >> 2; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + } + + ss += x.x * x.x; + ss += x.y * x.y; + ss += x.z * x.z; + ss += x.w * x.w; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + + *rms = s_rms; +} + +// Vectorized version of vllm::compute_dynamic_per_token_scales +// hidden_size must be a multiple of 4 +template +__device__ void compute_dynamic_per_token_scales( + float* __restrict__ token_scale, float* __restrict__ all_token_scales, + scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, + float const rms, float const* __restrict__ scale_ub, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + // Vectorized input/weight/residual to better utilize memory bandwidth. 
+ vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_weight = + reinterpret_cast const*>(weight); + vec4_t const* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + constexpr scalar_out_t qmax{std::numeric_limits::max()}; + + int32_t const num_vec_elems = hidden_size >> 2; + float block_absmax_val_maybe = 0.0f; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + } + + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.x * rms) * w.x)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.y * rms) * w.y)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.z * rms) * w.z)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.w * rms) * w.w)); + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor); + s_token_scale = scale; // shared memory store + all_token_scales[blockIdx.x] = scale; // global output store + } + __syncthreads(); + + *token_scale = s_token_scale; +} + +// hidden_size must be a multiple of 4 +template +__device__ void norm_and_quant(scalar_out_t* __restrict__ output, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + float const rms, float const scale, + int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + // Vectorized input/output/weight/residual to better utilize memory bandwidth. 
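+  // When has_residual is true, this routine also updates the residual tensor
+  // in place: the sum (input + residual) is written back before the
+  // normalized value is quantized into the output.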
+ vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_weight = + reinterpret_cast const*>(weight); + q8x4_t* vec_output = + reinterpret_cast*>(&output[token_offset]); + vec4_t* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = reinterpret_cast*>(&residual[token_offset]); + } + + int32_t const num_vec_elems = hidden_size >> 2; + +// TODO(luka/varun) extract into type-agnostic vectorized quant function to +// replace scaled_fp8_conversion_vec +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t const in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + // Update residual + r.x = static_cast(x.x); + r.y = static_cast(x.y); + r.z = static_cast(x.z); + r.w = static_cast(x.w); + vec_residual[i] = r; + } + + q8x4_t out; + out.x = ScaledQuant::quant_fn( + static_cast(x.x * rms) * w.x, scale); + out.y = ScaledQuant::quant_fn( + static_cast(x.y * rms) * w.y, scale); + out.z = ScaledQuant::quant_fn( + static_cast(x.z * rms) * w.z, scale); + out.w = ScaledQuant::quant_fn( + static_cast(x.w * rms) * w.w, scale); + vec_output[i] = out; + } +} + +} // namespace vectorized + +} // namespace vllm diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh new file mode 100644 index 0000000000000..f8a9872226a3a --- /dev/null +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -0,0 +1,81 @@ +#pragma once + +/** + * __device__ helper functions to deal with float -> quant datatype conversion + */ + +#include "quantization/vectorization.cuh" +// TODO(luka/varun):refactor common.cuh to use this file instead +#include "quantization/fp8/common.cuh" + +namespace vllm { + +// TODO(luka/varun): combine into common utilities for int8 +// (with int8_quant_kernels.cu) +static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) { + float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + +template +struct ScaledQuant; + +template +struct ScaledQuant< + quant_type_t, is_scale_inverted, + typename std::enable_if_t>> { + static __device__ __forceinline__ quant_type_t quant_fn(float const x, + float const scale) { + if constexpr (is_scale_inverted) { + return float_to_int8_rn(x * scale); + } else { + return float_to_int8_rn(x / scale); + } + } +}; + +template +struct ScaledQuant< + quant_type_t, is_scale_inverted, + typename std::enable_if_t>> { + static __device__ __forceinline__ quant_type_t quant_fn(float const x, + float const scale) { + if constexpr (is_scale_inverted) { + return float_to_fp8(x * scale); + } else { + return float_to_fp8(x / scale); + } + } +}; + +template +__device__ void 
scaled_quant_conversion(quant_type_t* __restrict__ output, + scalar_t const* __restrict__ input, + float const scale, int const tid, + int const num_elements, + int const step) { + for (int i = tid; i < num_elements; i += step) { + output[i] = ScaledQuant(input[i], scale); + } +} + +} // namespace vllm diff --git a/csrc/quantization/vectorization.cuh b/csrc/quantization/vectorization.cuh new file mode 100644 index 0000000000000..44c999130f756 --- /dev/null +++ b/csrc/quantization/vectorization.cuh @@ -0,0 +1,33 @@ +#pragma once +/** + * __device__ datatypes vectorized by 4 + */ + +// Include both AMD and NVIDIA fp8 types to avoid circular import +// TODO(luka/varun) use FP8_TYPE instead after refactoring +#include +#include + +namespace vllm { + +// Vectorization containers +template +struct __align__(8) vec4_t { + scalar_t x; + scalar_t y; + scalar_t z; + scalar_t w; +}; + +template +struct __align__(4) q8x4_t { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v); + quant_type_t x; + quant_type_t y; + quant_type_t z; + quant_type_t w; +}; + +} // namespace vllm diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4e64b9c92773a..1ffab14862fed 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -128,6 +128,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, &fused_add_rms_norm_static_fp8_quant); + // Fused Layernorm + Quant kernels + ops.def( + "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " + "Tensor weight, Tensor! scale, float epsilon, " + "Tensor? scale_ub, Tensor!? residual) -> ()"); + ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, + &rms_norm_dynamic_per_token_quant); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( diff --git a/docs/source/assets/usage/disagg_prefill/abstraction.jpg b/docs/source/assets/usage/disagg_prefill/abstraction.jpg new file mode 100644 index 0000000000000..1a99e3ed8cf5f Binary files /dev/null and b/docs/source/assets/usage/disagg_prefill/abstraction.jpg differ diff --git a/docs/source/assets/usage/disagg_prefill/overview.jpg b/docs/source/assets/usage/disagg_prefill/overview.jpg new file mode 100644 index 0000000000000..f029b4c05c808 Binary files /dev/null and b/docs/source/assets/usage/disagg_prefill/overview.jpg differ diff --git a/docs/source/design/multiprocessing.md b/docs/source/design/multiprocessing.md new file mode 100644 index 0000000000000..b58456ecc6da8 --- /dev/null +++ b/docs/source/design/multiprocessing.md @@ -0,0 +1,195 @@ +# Python Multiprocessing + +## Debugging + +Please see the [Debugging +Tips](https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing) +page for information on known issues and how to solve them. + +## Introduction + +*Note that source code references are to the state of the code at the time of writing in December, 2024.* + +The use of Python multiprocessing in vLLM is complicated by: + +- The use of vLLM as a library and the inability to control the code using vLLM +- Varying levels of incompatibilities between multiprocessing methods and vLLM + dependencies + +This document describes how vLLM deals with these challenges. + +## Multiprocessing Methods + +[Python multiprocessing methods](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods) include: + +- `spawn` - spawn a new Python process. This will be the default as of Python + 3.14. 
+ +- `fork` - Use `os.fork()` to fork the Python interpreter. This is the default + in Python versions prior to 3.14. + +- `forkserver` - Spawn a server process that will fork a new process on request. + +### Tradeoffs + +`fork` is the fastest method, but is incompatible with dependencies that use +threads. + +`spawn` is more compatible with dependencies, but can be problematic when vLLM +is used as a library. If the consuming code does not use a `__main__` guard (`if +__name__ == "__main__":`), the code will be inadvertently re-executed when vLLM +spawns a new process. This can lead to infinite recursion, among other problems. + +`forkserver` will spawn a new server process that will fork new processes on +demand. This unfortunately has the same problem as `spawn` when vLLM is used as +a library. The server process is created as a spawned new process, which will +re-execute code not protected by a `__main__` guard. + +For both `spawn` and `forkserver`, the process must not depend on inheriting any +global state as would be the case with `fork`. + +## Compatibility with Dependencies + +Multiple vLLM dependencies indicate either a preference or requirement for using +`spawn`: + +- +- +- + +It is perhaps more accurate to say that there are known problems with using +`fork` after initializing these dependencies. + +## Current State (v0) + +The environment variable `VLLM_WORKER_MULTIPROC_METHOD` can be used to control which method is used by vLLM. The current default is `fork`. + +- + +When we know we own the process because the `vllm` command was used, we use +`spawn` because it's the most widely compatible. + +- + +The `multiproc_xpu_executor` forces the use of `spawn`. + +- + +There are other miscellaneous places hard-coding the use of `spawn`: + +- +- + +Related PRs: + +- + +## Prior State in v1 + +There was an environment variable to control whether multiprocessing is used in +the v1 engine core, `VLLM_ENABLE_V1_MULTIPROCESSING`. This defaulted to off. + +- + +When it was enabled, the v1 `LLMEngine` would create a new process to run the +engine core. + +- +- +- https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/v1/engine/core_client.py#L44-L45 + +It was off by default for all the reasons mentioned above - compatibility with +dependencies and code using vLLM as a library. + +### Changes Made in v1 + +There is not an easy solution with Python's `multiprocessing` that will work +everywhere. As a first step, we can get v1 into a state where it does "best +effort" choice of multiprocessing method to maximize compatibility. + +- Default to `fork`. +- Use `spawn` when we know we control the main process (`vllm` was executed). +- If we detect `cuda` was previously initialized, force `spawn` and emit a + warning. We know `fork` will break, so this is the best we can do. + +The case that is known to still break in this scenario is code using vLLM as a +library that initializes `cuda` before calling vLLM. The warning we emit should +instruct users to either add a `__main__` guard or to disable multiprocessing. + +If that known-failure case occurs, the user will see two messages that explain +what is happening. First, a log message from vLLM: + +``` + WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously + initialized. We must use the `spawn` multiprocessing start method. Setting + VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See + https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing + for more information. 
+``` + +Second, Python itself will raise an exception with a nice explanation: + +``` +RuntimeError: + An attempt has been made to start a new process before the + current process has finished its bootstrapping phase. + + This probably means that you are not using fork to start your + child processes and you have forgotten to use the proper idiom + in the main module: + + if __name__ == '__main__': + freeze_support() + ... + + The "freeze_support()" line can be omitted if the program + is not going to be frozen to produce an executable. + + To fix this issue, refer to the "Safe importing of main module" + section in https://docs.python.org/3/library/multiprocessing.html +``` + +## Alternatives Considered + +### Detect if a `__main__` guard is present + +It has been suggested that we could behave better if we could detect whether +code using vLLM as a library has a `__main__` guard in place. This [post on +stackoverflow](https://stackoverflow.com/questions/77220442/multiprocessing-pool-in-a-python-class-without-name-main-guard) +was from a library author facing the same question. + +It is possible to detect whether we are in the original, `__main__` process, or +a subsequent spawned process. However, it does not appear to be straight forward +to detect whether a `__main__` guard is present in the code. + +This option has been discarded as impractical. + +### Use `forkserver` + +At first it appears that `forkserver` is a nice solution to the problem. +However, the way it works presents the same challenges that `spawn` does when +vLLM is used as a library. + +### Force `spawn` all the time + +One way to clean this up is to just force the use of `spawn` all the time and +document that the use of a `__main__` guard is required when using vLLM as a +library. This would unfortunately break existing code and make vLLM harder to +use, violating the desire to make the `LLM` class as easy as possible to use. + +Instead of pushing this on our users, we will retain the complexity to do our +best to make things work. + +## Future Work + +We may want to consider a different worker management approach in the future +that works around these challenges. + +1. We could implement something `forkserver`-like, but have the process manager + be something we initially launch by running our own subprocess and a custom + entrypoint for worker management (launch a `vllm-manager` process). + +2. We can explore other libraries that may better suit our needs. Examples to + consider: + +- diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 0c1afcbd7c0b9..d6c83014dc69f 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -136,6 +136,62 @@ If the test script hangs or crashes, usually it means the hardware/drivers are b Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup, being sure to execute different commands (with different ``--node-rank``) on different nodes. +Python multiprocessing +---------------------- + +`RuntimeError` Exception +^^^^^^^^^^^^^^^^^^^^^^^^ + +If you have seen a warning in your logs like this: + +.. code-block:: console + + WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously + initialized. We must use the `spawn` multiprocessing start method. Setting + VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See + https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing + for more information. 
+ +or an error from Python that looks like this: + +.. code-block:: console + + RuntimeError: + An attempt has been made to start a new process before the + current process has finished its bootstrapping phase. + + This probably means that you are not using fork to start your + child processes and you have forgotten to use the proper idiom + in the main module: + + if __name__ == '__main__': + freeze_support() + ... + + The "freeze_support()" line can be omitted if the program + is not going to be frozen to produce an executable. + + To fix this issue, refer to the "Safe importing of main module" + section in https://docs.python.org/3/library/multiprocessing.html + +then you must update your Python code to guard usage of ``vllm`` behind a ``if +__name__ == '__main__':`` block. For example, instead of this: + +.. code-block:: python + + import vllm + + llm = vllm.LLM(...) + +try this instead: + +.. code-block:: python + + if __name__ == '__main__': + import vllm + + llm = vllm.LLM(...) + Known Issues ---------------------------------------- - In ``v0.5.2``, ``v0.5.3``, and ``v0.5.3.post1``, there is a bug caused by `zmq `_ , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of ``vllm`` to include the `fix `_. diff --git a/docs/source/index.rst b/docs/source/index.rst index db5591ea6312a..d812885aafea9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -95,6 +95,8 @@ Documentation :caption: Models models/supported_models + models/generative_models + models/pooling_models models/adding_model models/enabling_multimodal_inputs @@ -113,6 +115,7 @@ Documentation usage/engine_args usage/env_vars usage/usage_stats + usage/disagg_prefill .. toctree:: :maxdepth: 1 @@ -172,6 +175,7 @@ Documentation design/input_processing/model_inputs_index design/kernel/paged_attention design/multimodal/multimodal_index + design/multiprocessing .. For Developers: contributing to the vLLM project diff --git a/docs/source/models/generative_models.rst b/docs/source/models/generative_models.rst new file mode 100644 index 0000000000000..fb71185600863 --- /dev/null +++ b/docs/source/models/generative_models.rst @@ -0,0 +1,146 @@ +.. _generative_models: + +Generative Models +================= + +vLLM provides first-class support for generative models, which covers most of LLMs. + +In vLLM, generative models implement the :class:`~vllm.model_executor.models.VllmModelForTextGeneration` interface. +Based on the final hidden states of the input, these models output log probabilities of the tokens to generate, +which are then passed through :class:`~vllm.model_executor.layers.Sampler` to obtain the final text. + +Offline Inference +----------------- + +The :class:`~vllm.LLM` class provides various methods for offline inference. +See :ref:`Engine Arguments ` for a list of options when initializing the model. + +For generative models, the only supported :code:`task` option is :code:`"generate"`. +Usually, this is automatically inferred so you don't have to specify it. + +``LLM.generate`` +^^^^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.generate` method is available to all generative models in vLLM. +It is similar to `its counterpart in HF Transformers `__, +except that tokenization and detokenization are also performed automatically. + +.. 
code-block:: python + + llm = LLM(model="facebook/opt-125m") + outputs = llm.generate("Hello, my name is") + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +You can optionally control the language generation by passing :class:`~vllm.SamplingParams`. +For example, you can use greedy sampling by setting :code:`temperature=0`: + +.. code-block:: python + + llm = LLM(model="facebook/opt-125m") + params = SamplingParams(temperature=0) + outputs = llm.generate("Hello, my name is", params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +A code example can be found in `examples/offline_inference.py `_. + +``LLM.beam_search`` +^^^^^^^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.beam_search` method implements `beam search `__ on top of :class:`~vllm.LLM.generate`. +For example, to search using 5 beams and output at most 50 tokens: + +.. code-block:: python + + llm = LLM(model="facebook/opt-125m") + params = BeamSearchParams(beam_width=5, max_tokens=50) + outputs = llm.generate("Hello, my name is", params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +``LLM.chat`` +^^^^^^^^^^^^ + +The :class:`~vllm.LLM.chat` method implements chat functionality on top of :class:`~vllm.LLM.generate`. +In particular, it accepts input similar to `OpenAI Chat Completions API `__ +and automatically applies the model's `chat template `__ to format the prompt. + +.. important:: + + In general, only instruction-tuned models have a chat template. + Base models may perform poorly as they are not trained to respond to the chat conversation. + +.. code-block:: python + + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, + ] + outputs = llm.chat(conversation) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +A code example can be found in `examples/offline_inference_chat.py `_. + +If the model doesn't have a chat template or you want to specify another one, +you can explicitly pass a chat template: + +.. code-block:: python + + from vllm.entrypoints.chat_utils import load_chat_template + + # You can find a list of existing chat templates under `examples/` + custom_template = load_chat_template(chat_template="") + print("Loaded chat template:", custom_template) + + outputs = llm.chat(conversation, chat_template=custom_template) + +Online Inference +---------------- + +Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference. +Please click on the above link for more details on how to launch the server. + +Completions API +^^^^^^^^^^^^^^^ + +Our Completions API is similar to ``LLM.generate`` but only accepts text. +It is compatible with `OpenAI Completions API `__ +so that you can use OpenAI client to interact with it. +A code example can be found in `examples/openai_completion_client.py `_. 
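+
+For instance, here is a minimal sketch of calling it with the official OpenAI Python client.
+It assumes a server launched locally with ``vllm serve facebook/opt-125m`` and no API key configured:
+
+.. code-block:: python
+
+    from openai import OpenAI
+
+    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+    completion = client.completions.create(
+        model="facebook/opt-125m",
+        prompt="Hello, my name is",
+        max_tokens=32,
+    )
+    print(completion.choices[0].text)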
+ +Chat API +^^^^^^^^ + +Our Chat API is similar to ``LLM.chat``, accepting both text and :ref:`multi-modal inputs `. +It is compatible with `OpenAI Chat Completions API `__ +so that you can use OpenAI client to interact with it. +A code example can be found in `examples/openai_chat_completion_client.py `_. diff --git a/docs/source/models/pooling_models.rst b/docs/source/models/pooling_models.rst new file mode 100644 index 0000000000000..4e67677a2767a --- /dev/null +++ b/docs/source/models/pooling_models.rst @@ -0,0 +1,136 @@ +.. _pooling_models: + +Pooling Models +============== + +vLLM also supports pooling models, including embedding, reranking and reward models. + +In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface. +These models use a :class:`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input +before returning them. + +.. note:: + + We currently support pooling models primarily as a matter of convenience. + As shown in the :ref:`Compatibility Matrix `, most vLLM features are not applicable to + pooling models as they only work on the generation or decode stage, so performance may not improve as much. + +Offline Inference +----------------- + +The :class:`~vllm.LLM` class provides various methods for offline inference. +See :ref:`Engine Arguments ` for a list of options when initializing the model. + +For pooling models, we support the following :code:`task` options: + +- Embedding (:code:`"embed"` / :code:`"embedding"`) +- Classification (:code:`"classify"`) +- Sentence Pair Scoring (:code:`"score"`) +- Reward Modeling (:code:`"reward"`) + +The selected task determines the default :class:`~vllm.model_executor.layers.Pooler` that is used: + +- Embedding: Extract only the hidden states corresponding to the last token, and apply normalization. +- Classification: Extract only the hidden states corresponding to the last token, and apply softmax. +- Sentence Pair Scoring: Extract only the hidden states corresponding to the last token, and apply softmax. +- Reward Modeling: Extract all of the hidden states and return them directly. + +When loading `Sentence Transformers `__ models, +we attempt to override the default pooler based on its Sentence Transformers configuration file (:code:`modules.json`). + +You can customize the model's pooling method via the :code:`override_pooler_config` option, +which takes priority over both the model's and Sentence Transformers's defaults. + +``LLM.encode`` +^^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM. +It returns the extracted hidden states directly, which is useful for reward models. + +.. code-block:: python + + llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward") + (output,) = llm.encode("Hello, my name is") + + data = output.outputs.data + print(f"Data: {data!r}") + +``LLM.embed`` +^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.embed` method outputs an embedding vector for each prompt. +It is primarily designed for embedding models. + +.. code-block:: python + + llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed") + (output,) = llm.embed("Hello, my name is") + + embeds = output.outputs.embedding + print(f"Embeddings: {embeds!r} (size={len(embeds)})") + +A code example can be found in `examples/offline_inference_embedding.py `_. + +``LLM.classify`` +^^^^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.classify` method outputs a probability vector for each prompt. +It is primarily designed for classification models. 
+ +.. code-block:: python + + llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify") + (output,) = llm.classify("Hello, my name is") + + probs = output.outputs.probs + print(f"Class Probabilities: {probs!r} (size={len(probs)})") + +A code example can be found in `examples/offline_inference_classification.py `_. + +``LLM.score`` +^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.score` method outputs similarity scores between sentence pairs. +It is primarily designed for `cross-encoder models `__. +These types of models serve as rerankers between candidate query-document pairs in RAG systems. + +.. note:: + + vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG. + To handle RAG at a higher level, you should use integration frameworks such as `LangChain `_. + +.. code-block:: python + + llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score") + (output,) = llm.score("What is the capital of France?", + "The capital of Brazil is Brasilia.") + + score = output.outputs.score + print(f"Score: {score}") + +A code example can be found in `examples/offline_inference_scoring.py `_. + +Online Inference +---------------- + +Our `OpenAI Compatible Server <../serving/openai_compatible_server>`__ can be used for online inference. +Please click on the above link for more details on how to launch the server. + +Embeddings API +^^^^^^^^^^^^^^ + +Our Embeddings API is similar to ``LLM.embed``, accepting both text and :ref:`multi-modal inputs `. + +The text-only API is compatible with `OpenAI Embeddings API `__ +so that you can use OpenAI client to interact with it. +A code example can be found in `examples/openai_embedding_client.py `_. + +The multi-modal API is an extension of the `OpenAI Embeddings API `__ +that incorporates `OpenAI Chat Completions API `__, +so it is not part of the OpenAI standard. Please see :ref:`this page ` for more details on how to use it. + +Score API +^^^^^^^^^ + +Our Score API is similar to ``LLM.score``. +Please see `this page <../serving/openai_compatible_server.html#score-api-for-cross-encoder-models>`__ for more details on how to use it. diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 6540e023c1ab0..3bef3f3226062 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -3,11 +3,21 @@ Supported Models ================ -vLLM supports a variety of generative and embedding models from `HuggingFace (HF) Transformers `_. -This page lists the model architectures that are currently supported by vLLM. +vLLM supports generative and pooling models across various tasks. +If a model supports more than one task, you can set the task via the :code:`--task` argument. + +For each task, we list the model architectures that have been implemented in vLLM. Alongside each architecture, we include some popular models that use it. -For other models, you can check the :code:`config.json` file inside the model repository. +Loading a Model +^^^^^^^^^^^^^^^ + +HuggingFace Hub ++++++++++++++++ + +By default, vLLM loads models from `HuggingFace (HF) Hub `_. + +To determine whether a given model is supported, you can check the :code:`config.json` file inside the HF repository. If the :code:`"architectures"` field contains a model architecture listed below, then it should be supported in theory. .. tip:: @@ -17,38 +27,57 @@ If the :code:`"architectures"` field contains a model architecture listed below, from vllm import LLM - llm = LLM(model=...) 
# Name or path of your model + # For generative models (task=generate) only + llm = LLM(model=..., task="generate") # Name or path of your model output = llm.generate("Hello, my name is") print(output) - If vLLM successfully generates text, it indicates that your model is supported. + # For pooling models (task={embed,classify,reward}) only + llm = LLM(model=..., task="embed") # Name or path of your model + output = llm.encode("Hello, my name is") + print(output) + + If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported. Otherwise, please refer to :ref:`Adding a New Model ` and :ref:`Enabling Multimodal Inputs ` for instructions on how to implement your model in vLLM. Alternatively, you can `open an issue on GitHub `_ to request vLLM support. -.. note:: - To use models from `ModelScope `_ instead of HuggingFace Hub, set an environment variable: +ModelScope +++++++++++ - .. code-block:: shell +To use models from `ModelScope `_ instead of HuggingFace Hub, set an environment variable: - $ export VLLM_USE_MODELSCOPE=True +.. code-block:: shell - And use with :code:`trust_remote_code=True`. + $ export VLLM_USE_MODELSCOPE=True - .. code-block:: python +And use with :code:`trust_remote_code=True`. - from vllm import LLM +.. code-block:: python - llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model - output = llm.generate("Hello, my name is") - print(output) + from vllm import LLM + + llm = LLM(model=..., revision=..., task=..., trust_remote_code=True) + + # For generative models (task=generate) only + output = llm.generate("Hello, my name is") + print(output) -Text-only Language Models -^^^^^^^^^^^^^^^^^^^^^^^^^ + # For pooling models (task={embed,classify,reward}) only + output = llm.encode("Hello, my name is") + print(output) -Text Generation ---------------- +List of Text-only Language Models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Generative Models ++++++++++++++++++ + +See :ref:`this page ` for more information on how to use generative models. + +Text Generation (``--task generate``) +------------------------------------- .. list-table:: :widths: 25 25 50 5 5 @@ -89,9 +118,9 @@ Text Generation - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. - ✅︎ - ✅︎ - * - :code:`CohereForCausalLM` + * - :code:`CohereForCausalLM`,:code:`Cohere2ForCausalLM` - Command-R - - :code:`CohereForAI/c4ai-command-r-v01`, etc. + - :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc. - ✅︎ - ✅︎ * - :code:`DbrxForCausalLM` @@ -174,6 +203,11 @@ Text Generation - :code:`ibm-granite/granite-3.0-1b-a400m-base`, :code:`ibm-granite/granite-3.0-3b-a800m-instruct`, :code:`ibm/PowerMoE-3b`, etc. - ✅︎ - ✅︎ + * - :code:`GritLM` + - GritLM + - :code:`parasail-ai/GritLM-7B-vllm`. + - ✅︎ + - ✅︎ * - :code:`InternLMForCausalLM` - InternLM - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc. @@ -328,8 +362,24 @@ Text Generation .. note:: Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. -Text Embedding --------------- +Pooling Models +++++++++++++++ + +See :ref:`this page ` for more information on how to use pooling models. + +.. important:: + Since some model architectures support both generative and pooling tasks, + you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. 
+ +Text Embedding (``--task embed``) +--------------------------------- + +Any text generation model can be converted into an embedding model by passing :code:`--task embed`. + +.. note:: + To get the best results, you should use pooling models that are specifically trained as such. + +The following table lists those that are tested in vLLM. .. list-table:: :widths: 25 25 50 5 5 @@ -350,6 +400,11 @@ Text Embedding - :code:`BAAI/bge-multilingual-gemma2`, etc. - - ✅︎ + * - :code:`GritLM` + - GritLM + - :code:`parasail-ai/GritLM-7B-vllm`. + - ✅︎ + - ✅︎ * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc. - Llama-based - :code:`intfloat/e5-mistral-7b-instruct`, etc. @@ -371,13 +426,6 @@ Text Embedding - - -.. important:: - Some model architectures support both generation and embedding tasks. - In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. - -.. tip:: - You can override the model's pooling method by passing :code:`--override-pooler-config`. - .. note:: :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`. @@ -389,8 +437,8 @@ Text Embedding On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention despite being described otherwise on its model card. -Reward Modeling ---------------- +Reward Modeling (``--task reward``) +----------------------------------- .. list-table:: :widths: 25 25 50 5 5 @@ -416,11 +464,8 @@ Reward Modeling For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. -.. note:: - As an interim measure, these models are supported in both offline and online inference via Embeddings API. - -Classification ---------------- +Classification (``--task classify``) +------------------------------------ .. list-table:: :widths: 25 25 50 5 5 @@ -437,11 +482,8 @@ Classification - ✅︎ - ✅︎ -.. note:: - As an interim measure, these models are supported in both offline and online inference via Embeddings API. - -Sentence Pair Scoring ---------------------- +Sentence Pair Scoring (``--task score``) +---------------------------------------- .. list-table:: :widths: 25 25 50 5 5 @@ -468,13 +510,10 @@ Sentence Pair Scoring - - -.. note:: - These models are supported in both offline and online inference via Score API. - .. _supported_mm_models: -Multimodal Language Models -^^^^^^^^^^^^^^^^^^^^^^^^^^ +List of Multimodal Language Models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The following modalities are supported depending on the model: @@ -491,8 +530,15 @@ On the other hand, modalities separated by :code:`/` are mutually exclusive. - e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. -Text Generation ---------------- +See :ref:`this page ` on how to pass multi-modal inputs to the model. + +Generative Models ++++++++++++++++++ + +See :ref:`this page ` for more information on how to use generative models. + +Text Generation (``--task generate``) +------------------------------------- .. 
list-table:: :widths: 25 25 15 20 5 5 5 @@ -618,9 +664,9 @@ Text Generation - ✅︎ - ✅︎ * - :code:`PaliGemmaForConditionalGeneration` - - PaliGemma + - PaliGemma, PaliGemma 2 - T + I\ :sup:`E` - - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. + - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, :code:`google/paligemma2-3b-ft-docci-448`, etc. - - ✅︎ - @@ -696,8 +742,24 @@ Text Generation The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 -Multimodal Embedding --------------------- +Pooling Models +++++++++++++++ + +See :ref:`this page ` for more information on how to use pooling models. + +.. important:: + Since some model architectures support both generative and pooling tasks, + you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. + +Text Embedding (``--task embed``) +--------------------------------- + +Any text generation model can be converted into an embedding model by passing :code:`--task embed`. + +.. note:: + To get the best results, you should use pooling models that are specifically trained as such. + +The following table lists those that are tested in vLLM. .. list-table:: :widths: 25 25 15 25 5 5 @@ -728,12 +790,7 @@ Multimodal Embedding - - ✅︎ -.. important:: - Some model architectures support both generation and embedding tasks. - In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. - -.. tip:: - You can override the model's pooling method by passing :code:`--override-pooler-config`. +---- Model Support Policy ===================== diff --git a/docs/source/quantization/bnb.rst b/docs/source/quantization/bnb.rst index 682938cc63d48..84f805bb60c2a 100644 --- a/docs/source/quantization/bnb.rst +++ b/docs/source/quantization/bnb.rst @@ -11,7 +11,7 @@ Below are the steps to utilize BitsAndBytes with vLLM. .. code-block:: console - $ pip install bitsandbytes>=0.44.0 + $ pip install bitsandbytes>=0.45.0 vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. diff --git a/docs/source/quantization/fp8.rst b/docs/source/quantization/fp8.rst index aacd07a34ad46..4dbf8e9d346e1 100644 --- a/docs/source/quantization/fp8.rst +++ b/docs/source/quantization/fp8.rst @@ -45,7 +45,7 @@ To produce performant FP8 quantized models with vLLM, you'll need to install the .. code-block:: console - $ pip install llmcompressor==0.1.0 + $ pip install llmcompressor Quantization Process -------------------- diff --git a/docs/source/quantization/int8.rst b/docs/source/quantization/int8.rst index 04fa308449507..aa5b251becb1c 100644 --- a/docs/source/quantization/int8.rst +++ b/docs/source/quantization/int8.rst @@ -19,7 +19,7 @@ To use INT8 quantization with vLLM, you'll need to install the `llm-compressor < .. code-block:: console - $ pip install llmcompressor==0.1.0 + $ pip install llmcompressor Quantization Process -------------------- @@ -142,4 +142,4 @@ Best Practices Troubleshooting and Support --------------------------- -If you encounter any issues or have feature requests, please open an issue on the ``vllm-project/llm-compressor`` GitHub repository. \ No newline at end of file +If you encounter any issues or have feature requests, please open an issue on the ``vllm-project/llm-compressor`` GitHub repository. 
diff --git a/docs/source/serving/deploying_with_docker.rst b/docs/source/serving/deploying_with_docker.rst index 14d94b09e9b9c..56f0020a1011a 100644 --- a/docs/source/serving/deploying_with_docker.rst +++ b/docs/source/serving/deploying_with_docker.rst @@ -37,6 +37,29 @@ You can build and run vLLM from source via the provided `Dockerfile `_ to start the cluster. +The first step, is to start containers and organize them into a cluster. We have provided a helper `script `_ to start the cluster. Please note, this script launches docker without administrative privileges that would be required to access GPU performance counters when running profiling and tracing tools. For that purpose, the script can have ``CAP_SYS_ADMIN`` to the docker container by using the ``--cap-add`` option in the docker run command. Pick a node as the head node, and run the following command: diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index f75653106cf66..1bc8d32d2d161 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -1,13 +1,13 @@ # OpenAI Compatible Server -vLLM provides an HTTP server that implements OpenAI's [Completions](https://platform.openai.com/docs/api-reference/completions) and [Chat](https://platform.openai.com/docs/api-reference/chat) API. +vLLM provides an HTTP server that implements OpenAI's [Completions](https://platform.openai.com/docs/api-reference/completions) and [Chat](https://platform.openai.com/docs/api-reference/chat) API, and more! -You can start the server using Python, or using [Docker](deploying_with_docker.rst): +You can start the server via the [`vllm serve`](#vllm-serve) command, or through [Docker](deploying_with_docker.rst): ```bash vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 ``` -To call the server, you can use the official OpenAI Python client library, or any other HTTP client. +To call the server, you can use the [official OpenAI Python client](https://github.com/openai/openai-python), or any other HTTP client. ```python from openai import OpenAI client = OpenAI( @@ -25,166 +25,71 @@ completion = client.chat.completions.create( print(completion.choices[0].message) ``` -## API Reference +## Supported APIs We currently support the following OpenAI APIs: -- [Completions API](https://platform.openai.com/docs/api-reference/completions) +- [Completions API](#completions-api) (`/v1/completions`) + - Only applicable to [text generation models](../models/generative_models.rst) (`--task generate`). - *Note: `suffix` parameter is not supported.* -- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat) - - [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Multimodal Inputs](../usage/multimodal_inputs.rst). - - *Note: `image_url.detail` parameter is not supported.* - - We also support `audio_url` content type for audio files. - - Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema. - - *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).* +- [Chat Completions API](#chat-api) (`/v1/chat/completions`) + - Only applicable to [text generation models](../models/generative_models.rst) (`--task generate`) with a [chat template](#chat-template). 
- *Note: `parallel_tool_calls` and `user` parameters are ignored.* -- [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings) - - Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API), - which will be treated as a single prompt to the model according to its chat template. - - This enables multi-modal inputs to be passed to embedding models, see [this page](../usage/multimodal_inputs.rst) for details. - - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.* - -## Score API for Cross Encoder Models +- [Embeddings API](#embeddings-api) (`/v1/embeddings`) + - Only applicable to [embedding models](../models/pooling_models.rst) (`--task embed`). -vLLM supports *cross encoders models* at the **/v1/score** endpoint, which is not an OpenAI API standard endpoint. You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). +In addition, we have the following custom APIs: -A ***Cross Encoder*** takes exactly two sentences / texts as input and either predicts a score or label for this sentence pair. It can for example predict the similarity of the sentence pair on a scale of 0 … 1. +- [Tokenizer API](#tokenizer-api) (`/tokenize`, `/detokenize`) + - Applicable to any model with a tokenizer. +- [Score API](#score-api) (`/score`) + - Only applicable to [cross-encoder models](../models/pooling_models.rst) (`--task score`). -### Example of usage for a pair of a string and a list of texts +(chat-template)= +## Chat Template -In this case, the model will compare the first given text to each of the texts containing the list. +In order for the language model to support chat protocol, vLLM requires the model to include +a chat template in its tokenizer configuration. The chat template is a Jinja2 template that +specifies how are roles, messages, and other chat-specific tokens are encoded in the input. -```bash -curl -X 'POST' \ - 'http://127.0.0.1:8000/v1/score' \ - -H 'accept: application/json' \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "BAAI/bge-reranker-v2-m3", - "text_1": "What is the capital of France?", - "text_2": [ - "The capital of Brazil is Brasilia.", - "The capital of France is Paris." - ] -}' -``` +An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models) -Response: +Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, +you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat +template, or the template in string form. Without a chat template, the server will not be able to process chat +and all chat requests will error. ```bash -{ - "id": "score-request-id", - "object": "list", - "created": 693570, - "model": "BAAI/bge-reranker-v2-m3", - "data": [ - { - "index": 0, - "object": "score", - "score": [ - 0.001094818115234375 - ] - }, - { - "index": 1, - "object": "score", - "score": [ - 1 - ] - } - ], - "usage": {} -} +vllm serve --chat-template ./path-to-chat-template.jinja ``` -### Example of usage for a pair of two lists of texts - -In this case, the model will compare the one by one, making pairs by same index correspondent in each list. +vLLM community provides a set of chat templates for popular models. 
You can find them in the examples +directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) -```bash -curl -X 'POST' \ - 'http://127.0.0.1:8000/v1/score' \ - -H 'accept: application/json' \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "BAAI/bge-reranker-v2-m3", - "encoding_format": "float", - "text_1": [ - "What is the capital of Brazil?", - "What is the capital of France?" - ], - "text_2": [ - "The capital of Brazil is Brasilia.", - "The capital of France is Paris." +With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies +both a `type` and a `text` field. An example is provided below: +```python +completion = client.chat.completions.create( + model="NousResearch/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]} ] -}' -``` - -Response: - -```bash -{ - "id": "score-request-id", - "object": "list", - "created": 693447, - "model": "BAAI/bge-reranker-v2-m3", - "data": [ - { - "index": 0, - "object": "score", - "score": [ - 1 - ] - }, - { - "index": 1, - "object": "score", - "score": [ - 1 - ] - } - ], - "usage": {} -} +) ``` -### Example of usage for a pair of two strings - -In this case, the model will compare the strings of texts. - -```bash -curl -X 'POST' \ - 'http://127.0.0.1:8000/v1/score' \ - -H 'accept: application/json' \ - -H 'Content-Type: application/json' \ - -d '{ - "model": "BAAI/bge-reranker-v2-m3", - "encoding_format": "float", - "text_1": "What is the capital of France?", - "text_2": "The capital of France is Paris." -}' -``` +Most chat templates for LLMs expect the `content` field to be a string, but there are some newer models like +`meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the OpenAI schema in the +request. vLLM provides best-effort support to detect this automatically, which is logged as a string like +*"Detected the chat template content format to be..."*, and internally converts incoming requests to match +the detected format, which can be one of: -Response: +- `"string"`: A string. + - Example: `"Hello world"` +- `"openai"`: A list of dictionaries, similar to OpenAI schema. + - Example: `[{"type": "text", "text": "Hello world!"}]` -```bash -{ - "id": "score-request-id", - "object": "list", - "created": 693447, - "model": "BAAI/bge-reranker-v2-m3", - "data": [ - { - "index": 0, - "object": "score", - "score": [ - 1 - ] - } - ], - "usage": {} -} -``` +If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument +to override which format to use. ## Extra Parameters @@ -204,7 +109,7 @@ completion = client.chat.completions.create( ) ``` -### Extra HTTP Headers +## Extra HTTP Headers Only `X-Request-Id` HTTP request header is supported for now. @@ -230,7 +135,53 @@ completion = client.completions.create( print(completion._request_id) ``` -### Extra Parameters for Completions API +## CLI Reference + +(vllm-serve)= +### `vllm serve` + +The `vllm serve` command is used to launch the OpenAI-compatible server. + +```{argparse} +:module: vllm.entrypoints.openai.cli_args +:func: create_parser_for_docs +:prog: vllm serve +``` + +#### Configuration file + +You can load CLI arguments via a [YAML](https://yaml.org/) config file. +The argument names must be the long form of those outlined [above](#vllm-serve). 
+ +For example: + +```yaml +# config.yaml + +host: "127.0.0.1" +port: 6379 +uvicorn-log-level: "info" +``` + +To use the above config file: + +```bash +$ vllm serve SOME_MODEL --config config.yaml +``` + +```{note} +In case an argument is supplied simultaneously using command line and the config file, the value from the command line will take precedence. +The order of priorities is `command line > config file values > defaults`. +``` + +## API Reference + +(completions-api)= +### Completions API + +Refer to [OpenAI's API reference](https://platform.openai.com/docs/api-reference/completions) for more details. + +#### Extra parameters The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. @@ -248,7 +199,17 @@ The following extra parameters are supported: :end-before: end-completion-extra-params ``` -### Extra Parameters for Chat Completions API +(chat-api)= +### Chat Completions API + +Refer to [OpenAI's API reference](https://platform.openai.com/docs/api-reference/chat) for more details. + +We support both [Vision](https://platform.openai.com/docs/guides/vision)- and +[Audio](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in)-related parameters; +see our [Multimodal Inputs](../usage/multimodal_inputs.rst) guide for more information. +- *Note: `image_url.detail` parameter is not supported.* + +#### Extra parameters The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. @@ -266,7 +227,19 @@ The following extra parameters are supported: :end-before: end-chat-completion-extra-params ``` -### Extra Parameters for Embeddings API +(embeddings-api)= +### Embeddings API + +Refer to [OpenAI's API reference](https://platform.openai.com/docs/api-reference/embeddings) for more details. + +If the model has a [chat template](#chat-template), you can replace `inputs` with a list of `messages` (same schema as [Chat Completions API](#chat-api)) +which will be treated as a single prompt to the model. + +```{tip} +This enables multi-modal inputs to be passed to embedding models, see [this page](../usage/multimodal_inputs.rst) for details. +``` + +#### Extra parameters The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported. @@ -276,7 +249,7 @@ The following [pooling parameters (click through to see documentation)](../dev/p :end-before: end-embedding-pooling-params ``` -The following extra parameters are supported: +The following extra parameters are supported by default: ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python @@ -284,80 +257,179 @@ The following extra parameters are supported: :end-before: end-embedding-extra-params ``` -## Chat Template +For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead: -In order for the language model to support chat protocol, vLLM requires the model to include -a chat template in its tokenizer configuration. The chat template is a Jinja2 template that -specifies how are roles, messages, and other chat-specific tokens are encoded in the input. 
+```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-chat-embedding-extra-params +:end-before: end-chat-embedding-extra-params +``` -An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models) +(tokenizer-api)= +### Tokenizer API -Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, -you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat -template, or the template in string form. Without a chat template, the server will not be able to process chat -and all chat requests will error. +The Tokenizer API is a simple wrapper over [HuggingFace-style tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer). +It consists of two endpoints: + +- `/tokenize` corresponds to calling `tokenizer.encode()`. +- `/detokenize` corresponds to calling `tokenizer.decode()`. + +(score-api)= +### Score API + +The Score API applies a cross-encoder model to predict scores for sentence pairs. +Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1. + +You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). + +#### Single inference + +You can pass a string to both `text_1` and `text_2`, forming a single sentence pair. + +Request: ```bash -vllm serve --chat-template ./path-to-chat-template.jinja +curl -X 'POST' \ + 'http://127.0.0.1:8000/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "encoding_format": "float", + "text_1": "What is the capital of France?", + "text_2": "The capital of France is Paris." +}' ``` -vLLM community provides a set of chat templates for popular models. You can find them in the examples -directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) +Response: -With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies -both a `type` and a `text` field. An example is provided below: -```python -completion = client.chat.completions.create( - model="NousResearch/Meta-Llama-3-8B-Instruct", - messages=[ - {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]} - ] -) +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693447, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": 1 + } + ], + "usage": {} +} ``` -Most chat templates for LLMs expect the `content` field to be a string, but there are some newer models like -`meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the OpenAI schema in the -request. vLLM provides best-effort support to detect this automatically, which is logged as a string like -*"Detected the chat template content format to be..."*, and internally converts incoming requests to match -the detected format, which can be one of: +#### Batch inference -- `"string"`: A string. - - Example: `"Hello world"` -- `"openai"`: A list of dictionaries, similar to OpenAI schema. 
- - Example: `[{"type": "text", "text": "Hello world!"}]` +You can pass a string to `text_1` and a list to `text_2`, forming multiple sentence pairs +where each pair is built from `text_1` and a string in `text_2`. +The total number of pairs is `len(text_2)`. -If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument -to override which format to use. +Request: -## Command line arguments for the server +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "text_1": "What is the capital of France?", + "text_2": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] +}' +``` -```{argparse} -:module: vllm.entrypoints.openai.cli_args -:func: create_parser_for_docs -:prog: vllm serve +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693570, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": 0.001094818115234375 + }, + { + "index": 1, + "object": "score", + "score": 1 + } + ], + "usage": {} +} ``` +You can pass a list to both `text_1` and `text_2`, forming multiple sentence pairs +where each pair is built from a string in `text_1` and the corresponding string in `text_2` (similar to `zip()`). +The total number of pairs is `len(text_2)`. + +Request: + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "encoding_format": "float", + "text_1": [ + "What is the capital of Brazil?", + "What is the capital of France?" + ], + "text_2": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] +}' +``` -### Config file +Response: -The `serve` module can also accept arguments from a config file in -`yaml` format. The arguments in the yaml must be specified using the -long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server): +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693447, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": 1 + }, + { + "index": 1, + "object": "score", + "score": 1 + } + ], + "usage": {} +} +``` -For example: +#### Extra parameters -```yaml -# config.yaml +The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported. -host: "127.0.0.1" -port: 6379 -uvicorn-log-level: "info" +```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-score-pooling-params +:end-before: end-score-pooling-params ``` -```bash -$ vllm serve SOME_MODEL --config config.yaml +The following extra parameters are supported: + +```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-score-extra-params +:end-before: end-score-extra-params ``` ---- -**NOTE** -In case an argument is supplied simultaneously using command line and the config file, the value from the commandline will take precedence. -The order of priorities is `command line > config file values > defaults`. 
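+
+If you prefer Python over `curl`, here is a minimal sketch using the third-party `requests` library.
+It assumes the same locally running `BAAI/bge-reranker-v2-m3` server as in the examples above; since `/score` is a custom (non-OpenAI) endpoint, the OpenAI client cannot be used for it.
+
+```python
+import requests
+
+# Mirrors the single-inference curl example above.
+response = requests.post(
+    "http://127.0.0.1:8000/score",
+    json={
+        "model": "BAAI/bge-reranker-v2-m3",
+        "text_1": "What is the capital of France?",
+        "text_2": "The capital of France is Paris.",
+    },
+)
+response.raise_for_status()
+print(response.json()["data"][0]["score"])
+```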
diff --git a/docs/source/serving/serving_with_llamastack.rst b/docs/source/serving/serving_with_llamastack.rst index 8ef96c4e54369..a2acd7b39f887 100644 --- a/docs/source/serving/serving_with_llamastack.rst +++ b/docs/source/serving/serving_with_llamastack.rst @@ -24,7 +24,7 @@ Then start Llama Stack server pointing to your vLLM server with the following co config: url: http://127.0.0.1:8000 -Please refer to `this guide `_ for more details on this remote vLLM provider. +Please refer to `this guide `_ for more details on this remote vLLM provider. Inference via Embedded vLLM --------------------------- diff --git a/docs/source/usage/compatibility_matrix.rst b/docs/source/usage/compatibility_matrix.rst index a93632ff36fb8..04dd72b1e3527 100644 --- a/docs/source/usage/compatibility_matrix.rst +++ b/docs/source/usage/compatibility_matrix.rst @@ -39,13 +39,13 @@ Feature x Feature - :abbr:`prmpt adptr (Prompt Adapter)` - :ref:`SD ` - CUDA graph - - :abbr:`emd (Embedding Models)` + - :abbr:`pooling (Pooling Models)` - :abbr:`enc-dec (Encoder-Decoder Models)` - :abbr:`logP (Logprobs)` - :abbr:`prmpt logP (Prompt Logprobs)` - :abbr:`async output (Async Output Processing)` - multi-step - - :abbr:`mm (Multimodal)` + - :abbr:`mm (Multimodal Inputs)` - best-of - beam-search - :abbr:`guided dec (Guided Decoding)` @@ -151,7 +151,7 @@ Feature x Feature - - - - * - :abbr:`emd (Embedding Models)` + * - :abbr:`pooling (Pooling Models)` - ✗ - ✗ - ✗ @@ -253,7 +253,7 @@ Feature x Feature - - - - * - :abbr:`mm (Multimodal)` + * - :abbr:`mm (Multimodal Inputs)` - ✅ - `✗ `__ - `✗ `__ @@ -386,7 +386,7 @@ Feature x Hardware - ✅ - ✗ - ✅ - * - :abbr:`emd (Embedding Models)` + * - :abbr:`pooling (Pooling Models)` - ✅ - ✅ - ✅ @@ -402,7 +402,7 @@ Feature x Hardware - ✅ - ✅ - ✗ - * - :abbr:`mm (Multimodal)` + * - :abbr:`mm (Multimodal Inputs)` - ✅ - ✅ - ✅ diff --git a/docs/source/usage/disagg_prefill.rst b/docs/source/usage/disagg_prefill.rst new file mode 100644 index 0000000000000..9fe714b4fd856 --- /dev/null +++ b/docs/source/usage/disagg_prefill.rst @@ -0,0 +1,69 @@ +.. _disagg_prefill: + +Disaggregated prefilling (experimental) +======================================= + +This page introduces you the disaggregated prefilling feature in vLLM. This feature is experimental and subject to change. + +Why disaggregated prefilling? +----------------------------- + +Two main reasons: + +* **Tuning time-to-first-token (TTFT) and inter-token-latency (ITL) separately**. Disaggregated prefilling put prefill and decode phase of LLM inference inside different vLLM instances. This gives you the flexibility to assign different parallel strategies (e.g. ``tp`` and ``pp``) to tune TTFT without affecting ITL, or to tune ITL without affecting TTFT. +* **Controlling tail ITL**. Without disaggregated prefilling, vLLM may insert some prefill jobs during the decoding of one request. This results in higher tail latency. Disaggregated prefilling helps you solve this issue and control tail ITL. Chunked prefill with a proper chunk size also can achieve the same goal, but in practice it's hard to figure out the correct chunk size value. So disaggregated prefilling is a much more reliable way to control tail ITL. + +.. note:: + Disaggregated prefill DOES NOT improve throughput. + +Usage example +------------- + +Please refer to ``examples/disaggregated_prefill.sh`` for the example usage of disaggregated prefilling. + + +Benchmarks +---------- + +Please refer to ``benchmarks/disagg_benchmarks/`` for disaggregated prefilling benchmarks. 
+
+
+Development
+-----------
+
+We implement disaggregated prefilling by running two vLLM instances: one for prefill (we call it the prefill instance) and one for decode (we call it the decode instance). A connector then transfers the prefill KV caches and results from the prefill instance to the decode instance.
+
+All of the disaggregated prefilling implementation is under ``vllm/distributed/kv_transfer``.
+
+Key abstractions for disaggregated prefilling:
+
+* **Connector**: Connector allows the **kv consumer** to retrieve the KV caches of a batch of requests from the **kv producer**.
+* **LookupBuffer**: LookupBuffer provides two APIs: ``insert`` KV cache and ``drop_select`` KV cache. The semantics of ``insert`` and ``drop_select`` are similar to SQL, where ``insert`` inserts a KV cache into the buffer, and ``drop_select`` returns the KV cache that matches the given condition and drops it from the buffer.
+* **Pipe**: A single-direction FIFO pipe for tensor transmission. It supports ``send_tensor`` and ``recv_tensor``.
+
+.. note::
+  ``insert`` is a non-blocking operation, while ``drop_select`` is a blocking operation.
+
+Here is a figure illustrating how the above three abstractions are organized:
+
+.. image:: /assets/usage/disagg_prefill/abstraction.jpg
+  :alt: Disaggregated prefilling abstractions
+
+The workflow of disaggregated prefilling is as follows:
+
+.. image:: /assets/usage/disagg_prefill/overview.jpg
+  :alt: Disaggregated prefilling workflow
+
+The ``buffer`` corresponds to the ``insert`` API in LookupBuffer, and ``drop_select`` corresponds to the ``drop_select`` API in LookupBuffer.
+
+
+Third-party contributions
+-------------------------
+
+Disaggregated prefilling is highly related to infrastructure, so vLLM relies on third-party connectors for production-level disaggregated prefilling (and the vLLM team will actively review and merge new PRs for third-party connectors).
+
+We recommend three implementation approaches:
+
+* **Fully-customized connector**: Implement your own ``Connector`` and call third-party libraries to send and receive KV caches, and much more (for example, editing vLLM's model input to perform customized prefilling). This approach gives you the most control, but at the risk of being incompatible with future vLLM versions.
+* **Database-like connector**: Implement your own ``LookupBuffer`` and support the ``insert`` and ``drop_select`` APIs just like SQL.
+* **Distributed P2P connector**: Implement your own ``Pipe`` and support the ``send_tensor`` and ``recv_tensor`` APIs, just like `torch.distributed`.
diff --git a/docs/source/usage/faq.rst b/docs/source/usage/faq.rst
index ce327abd5fa20..d88da32092924 100644
--- a/docs/source/usage/faq.rst
+++ b/docs/source/usage/faq.rst
@@ -11,7 +11,12 @@ A: Assuming that you're referring to using OpenAI compatible server to serve mul
 Q: Which model to use for offline inference embedding?
-A: If you want to use an embedding model, try: https://huggingface.co/intfloat/e5-mistral-7b-instruct. Instead models, such as Llama-3-8b, Mistral-7B-Instruct-v0.3, are generation models rather than an embedding model
+A: You can try `e5-mistral-7b-instruct `__ and `BAAI/bge-base-en-v1.5 `__;
+more are listed :ref:`here `.
+
+By extracting hidden states, vLLM can automatically convert text generation models like `Llama-3-8B `__,
+`Mistral-7B-Instruct-v0.3 `__ into embedding models,
+but they are expected to be inferior to models that are specifically trained on embedding tasks.
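+
+For example, here is a minimal sketch (illustrative only) that loads an instruction-tuned model in embedding mode:
+
+.. code-block:: python
+
+    from vllm import LLM
+
+    # Any text generation model can be loaded with task="embed", but
+    # purpose-trained embedding models usually give better results.
+    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3", task="embed")
+    (output,) = llm.embed("Hello, my name is")
+    print(len(output.outputs.embedding))
+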
---------------------------------------- diff --git a/docs/source/usage/multimodal_inputs.rst b/docs/source/usage/multimodal_inputs.rst index c93f65327e31b..680382e457cc5 100644 --- a/docs/source/usage/multimodal_inputs.rst +++ b/docs/source/usage/multimodal_inputs.rst @@ -315,7 +315,95 @@ You can use `these tests `_. +Here is a simple example using Ultravox-v0.3. + +First, launch the OpenAI-compatible server: + +.. code-block:: bash + + vllm serve fixie-ai/ultravox-v0_3 + +Then, you can use the OpenAI client as follows: + +.. code-block:: python + + import base64 + import requests + from openai import OpenAI + from vllm.assets.audio import AudioAsset + + def encode_base64_content_from_url(content_url: str) -> str: + """Encode a content retrieved from a remote url to base64 format.""" + + with requests.get(content_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode('utf-8') + + return result + + openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + + # Any format supported by librosa is supported + audio_url = AudioAsset("winning_call").url + audio_base64 = encode_base64_content_from_url(audio_url) + + chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": "wav" + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from input audio:", result) + +Alternatively, you can pass :code:`audio_url`, which is the audio counterpart of :code:`image_url` for image input: + +.. code-block:: python + + chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + }, + ], + }], + model=model, + max_completion_tokens=64, + ) + + result = chat_completion_from_url.choices[0].message.content + print("Chat completion output from audio url:", result) A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py `_. @@ -345,12 +433,12 @@ Here is an end-to-end example using VLM2Vec. To serve the model: .. code-block:: bash - vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \ + vllm serve TIGER-Lab/VLM2Vec-Full --task embed \ --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja .. important:: - Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding`` + Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embed`` to run this model in embedding mode instead of text generation mode. The custom chat template is completely different from the original one for this model, @@ -386,12 +474,12 @@ Below is another example, this time using the ``MrLight/dse-qwen2-2b-mrl-v1`` mo .. code-block:: bash - vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \ + vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embed \ --trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja .. important:: - Like with VLM2Vec, we have to explicitly pass ``--task embedding``. 
+ Like with VLM2Vec, we have to explicitly pass ``--task embed``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings, which is handled by `this custom chat template `__. diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 050b791b62adb..68b786961b14a 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -25,16 +25,16 @@ def run_ultravox(question: str, audio_count: int): tokenizer = AutoTokenizer.from_pretrained(model_name) messages = [{ - 'role': - 'user', - 'content': - "<|reserved_special_token_0|>\n" * audio_count + question + 'role': 'user', + 'content': "<|audio|>\n" * audio_count + question }] prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count}) + llm = LLM(model=model_name, + trust_remote_code=True, + limit_mm_per_prompt={"audio": audio_count}) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/examples/offline_inference_classification.py b/examples/offline_inference_classification.py new file mode 100644 index 0000000000000..de539b639a196 --- /dev/null +++ b/examples/offline_inference_classification.py @@ -0,0 +1,28 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM. +# You should pass task="classify" for classification models +model = LLM( + model="jason9693/Qwen2.5-1.5B-apeach", + task="classify", + enforce_eager=True, +) + +# Generate logits. The output is a list of ClassificationRequestOutputs. +outputs = model.classify(prompts) + +# Print the outputs. +for prompt, output in zip(prompts, outputs): + probs = output.outputs.probs + probs_trimmed = ((str(probs[:16])[:-1] + + ", ...]") if len(probs) > 16 else probs) + print(f"Prompt: {prompt!r} | " + f"Class Probabilities: {probs_trimmed} (size={len(probs)})") diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index ae158eef2ca4c..58d004313ad51 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -9,9 +9,20 @@ ] # Create an LLM. -model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) -# Generate embedding. The output is a list of PoolingRequestOutputs. -outputs = model.encode(prompts) +# You should pass task="embed" for embedding models +model = LLM( + model="intfloat/e5-mistral-7b-instruct", + task="embed", + enforce_eager=True, +) + +# Generate embedding. The output is a list of EmbeddingRequestOutputs. +outputs = model.embed(prompts) + # Print the outputs. 
-for output in outputs: - print(output.outputs.embedding) # list of 4096 floats +for prompt, output in zip(prompts, outputs): + embeds = output.outputs.embedding + embeds_trimmed = ((str(embeds[:16])[:-1] + + ", ...]") if len(embeds) > 16 else embeds) + print(f"Prompt: {prompt!r} | " + f"Embeddings: {embeds_trimmed} (size={len(embeds)})") diff --git a/examples/offline_inference_openai.md b/examples/offline_inference_openai.md index 4c64197975534..2436417cb543a 100644 --- a/examples/offline_inference_openai.md +++ b/examples/offline_inference_openai.md @@ -1,45 +1,48 @@ # Offline Inference with the OpenAI Batch file format - **NOTE:** This is a guide to performing batch inference using the OpenAI batch file format, **NOT** the complete Batch (REST) API. - - ## File Format - - The OpenAI batch file format consists of a series of json objects on new lines. +```{important} +This is a guide to performing batch inference using the OpenAI batch file format, **not** the complete Batch (REST) API. +``` + +## File Format - [See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/openai_example_batch.jsonl) +The OpenAI batch file format consists of a series of json objects on new lines. - Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. +[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/openai_example_batch.jsonl) - **NOTE:** We currently only support `/v1/chat/completions` and `/v1/embeddings` endpoints (completions coming soon). +Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. - ## Pre-requisites +```{note} +We currently only support `/v1/chat/completions` and `/v1/embeddings` endpoints (completions coming soon). +``` -* Ensure you are using `vllm >= 0.4.3`. You can check by running `python -c "import vllm; print(vllm.__version__)"`. +## Pre-requisites + * The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`. - Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens) - Install the token on your machine (Run `huggingface-cli login`). - Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions. - ## Example 1: Running with a local file - - ### Step 1: Create your batch file - - To follow along with this example, you can download the example batch, or create your own batch file in your working directory. - - ``` - wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl - ``` - - Once you've created your batch file it should look like this - - ``` - $ cat openai_example_batch.jsonl +## Example 1: Running with a local file + +### Step 1: Create your batch file + +To follow along with this example, you can download the example batch, or create your own batch file in your working directory. 
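If you would rather build the batch file programmatically than download it, here is a minimal, illustrative sketch that writes the same two chat-completion requests shown later in this guide (the file name and model are simply the ones used throughout these examples; adjust them to your setup). The `wget` command below fetches the ready-made example instead.

```python
import json

# Each line of an OpenAI-format batch file is one standalone JSON request object.
requests = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "Hello world!"},
            ],
            "max_completion_tokens": 1000,
        },
    } for i, system_prompt in enumerate(
        ["You are a helpful assistant.", "You are an unhelpful assistant."],
        start=1)
]

# Write one JSON object per line (JSONL).
with open("openai_example_batch.jsonl", "w") as f:
    for request in requests:
        f.write(json.dumps(request) + "\n")
```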
+ +``` +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl +``` + +Once you've created your batch file it should look like this + +``` +$ cat openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} - ``` - - ### Step 2: Run the batch +``` + +### Step 2: Run the batch The batch running tool is designed to be used from the command line. @@ -85,18 +88,18 @@ To integrate with cloud blob storage, we recommend using presigned urls. ### Step 1: Upload your input script To follow along with this example, you can download the example batch, or create your own batch file in your working directory. - - ``` - wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl - ``` - - Once you've created your batch file it should look like this - - ``` - $ cat openai_example_batch.jsonl + +``` +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl +``` + +Once you've created your batch file it should look like this + +``` +$ cat openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} - ``` +``` Now upload your batch file to your S3 bucket. @@ -104,7 +107,6 @@ Now upload your batch file to your S3 bucket. aws s3 cp openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl ``` - ### Step 2: Generate your presigned urls Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names. @@ -179,21 +181,19 @@ aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - ### Step 1: Create your batch file - Add embedding requests to your batch file. The following is an example: +Add embedding requests to your batch file. 
The following is an example: - ``` - {"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}} +``` +{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}} {"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}} ``` - - You can even mix chat completion and embedding requests in the batch file, as long as the model you are using supports both chat completion and embeddings (note that all requests must use the same model). +You can even mix chat completion and embedding requests in the batch file, as long as the model you are using supports both chat completion and embeddings (note that all requests must use the same model). - ### Step 2: Run the batch +### Step 2: Run the batch You can run the batch using the same command as in earlier examples. - ### Step 3: Check your results You can check your results by running `cat results.jsonl` @@ -201,5 +201,5 @@ You can check your results by running `cat results.jsonl` ``` $ cat results.jsonl {"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null} -...``` +... ``` diff --git a/examples/offline_inference_scoring.py b/examples/offline_inference_scoring.py new file mode 100644 index 0000000000000..5da9e710959b5 --- /dev/null +++ b/examples/offline_inference_scoring.py @@ -0,0 +1,23 @@ +from vllm import LLM + +# Sample prompts. +text_1 = "What is the capital of France?" +texts_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." +] + +# Create an LLM. +# You should pass task="score" for cross-encoder models +model = LLM( + model="BAAI/bge-reranker-v2-m3", + task="score", + enforce_eager=True, +) + +# Generate scores. The output is a list of ScoringRequestOutputs. +outputs = model.score(text_1, texts_2) + +# Print the outputs. +for text_2, output in zip(texts_2, outputs): + score = output.outputs.score + print(f"Pair: {[text_1, text_2]!r} | Score: {score}") diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index c6a274ee5894b..6d0495fdd4054 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -5,6 +5,8 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. """ +import random + from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -17,13 +19,168 @@ # Unless specified, these settings have been tested to work on a single L4. 
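Each of the `run_*` helpers added below follows the existing convention in this example script: it constructs the `LLM`, a model-specific prompt, and optional stop token IDs, which the rest of the script then feeds into `generate`. As a rough sketch of that pattern (the question, the `ImageAsset` name, and the sampling values here are placeholders for illustration, not part of this change; inside the script the helpers also read the module-level `args`):

```python
from vllm import SamplingParams
from vllm.assets.image import ImageAsset

# Pick any of the helpers defined below; they all share this interface.
llm, prompt, stop_token_ids = run_llava("What is in this image?", "image")

# Stop tokens differ per model, so they are threaded into SamplingParams.
sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)

# Multi-modal data is passed alongside the prompt.
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params)

for o in outputs:
    print(o.outputs[0].text)
```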
+# Aria +def run_aria(question: str, modality: str): + assert modality == "image" + model_name = "rhymes-ai/Aria" + + llm = LLM(model=model_name, + tokenizer_mode="slow", + trust_remote_code=True, + dtype="bfloat16", + mm_cache_preprocessor=args.mm_cache_preprocessor) + + prompt = (f"<|im_start|>user\n<|img|>\n{question}" + "<|im_end|>\n<|im_start|>assistant\n") + + stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] + return llm, prompt, stop_token_ids + + +# BLIP-2 +def run_blip2(question: str, modality: str): + assert modality == "image" + + # BLIP-2 prompt format is inaccurate on HuggingFace model repository. + # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa + prompt = f"Question: {question} Answer:" + llm = LLM(model="Salesforce/blip2-opt-2.7b", + mm_cache_preprocessor=args.mm_cache_preprocessor) + stop_token_ids = None + return llm, prompt, stop_token_ids + + +# Chameleon +def run_chameleon(question: str, modality: str): + assert modality == "image" + + prompt = f"{question}" + llm = LLM(model="facebook/chameleon-7b", + max_model_len=4096, + mm_cache_preprocessor=args.mm_cache_preprocessor) + stop_token_ids = None + return llm, prompt, stop_token_ids + + +# Fuyu +def run_fuyu(question: str, modality: str): + assert modality == "image" + + prompt = f"{question}\n" + llm = LLM(model="adept/fuyu-8b", + max_model_len=2048, + max_num_seqs=2, + mm_cache_preprocessor=args.mm_cache_preprocessor) + stop_token_ids = None + return llm, prompt, stop_token_ids + + +# GLM-4v +def run_glm4v(question: str, modality: str): + assert modality == "image" + model_name = "THUDM/glm-4v-9b" + + llm = LLM(model=model_name, + max_model_len=2048, + max_num_seqs=2, + trust_remote_code=True, + enforce_eager=True, + mm_cache_preprocessor=args.mm_cache_preprocessor) + prompt = question + stop_token_ids = [151329, 151336, 151338] + return llm, prompt, stop_token_ids + + +# H2OVL-Mississippi +def run_h2ovl(question: str, modality: str): + assert modality == "image" + + model_name = "h2oai/h2ovl-mississippi-2b" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + mm_cache_preprocessor=args.mm_cache_preprocessor, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + messages = [{'role': 'user', 'content': f"\n{question}"}] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for H2OVL-Mississippi + # https://huggingface.co/h2oai/h2ovl-mississippi-2b + stop_token_ids = [tokenizer.eos_token_id] + return llm, prompt, stop_token_ids + + +# Idefics3-8B-Llama3 +def run_idefics3(question: str, modality: str): + assert modality == "image" + model_name = "HuggingFaceM4/Idefics3-8B-Llama3" + + llm = LLM( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + enforce_eager=True, + # if you are running out of memory, you can reduce the "longest_edge". 
+ # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations + mm_processor_kwargs={ + "size": { + "longest_edge": 3 * 364 + }, + }, + mm_cache_preprocessor=args.mm_cache_preprocessor, + ) + prompt = ( + f"<|begin_of_text|>User:{question}\nAssistant:" + ) + stop_token_ids = None + return llm, prompt, stop_token_ids + + +# InternVL +def run_internvl(question: str, modality: str): + assert modality == "image" + + model_name = "OpenGVLab/InternVL2-2B" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + mm_cache_preprocessor=args.mm_cache_preprocessor, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + messages = [{'role': 'user', 'content': f"\n{question}"}] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for InternVL + # models variants may have different stop tokens + # please refer to the model card for the correct "stop words": + # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py + stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + return llm, prompt, stop_token_ids + + # LLaVA-1.5 def run_llava(question: str, modality: str): assert modality == "image" prompt = f"USER: \n{question}\nASSISTANT:" - llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096) + llm = LLM(model="llava-hf/llava-1.5-7b-hf", + max_model_len=4096, + mm_cache_preprocessor=args.mm_cache_preprocessor) stop_token_ids = None return llm, prompt, stop_token_ids @@ -33,7 +190,9 @@ def run_llava_next(question: str, modality: str): assert modality == "image" prompt = f"[INST] \n{question} [/INST]" - llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192) + llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", + max_model_len=8192, + mm_cache_preprocessor=args.mm_cache_preprocessor) stop_token_ids = None return llm, prompt, stop_token_ids @@ -44,7 +203,9 @@ def run_llava_next_video(question: str, modality: str): assert modality == "video" prompt = f"USER: