diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 2b70e2da5d87c..eec2a51e2f8fd 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -8,8 +8,7 @@ steps: containers: - image: badouralix/curl-jq command: - - sh - - .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - wait - label: "A100" agents: diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh index c785e6a0da628..f16862907def1 100644 --- a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh @@ -2,9 +2,11 @@ TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token) URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" +TIMEOUT_SECONDS=10 + retries=0 while [ $retries -lt 1000 ]; do - if [ $(curl -s -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then + if [ $(curl -s --max-time $TIMEOUT_SECONDS -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then exit 0 fi diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 6659440135ff4..9274a30e04325 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -83,6 +83,7 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_encoder_decoder_attn.py \ --ignore=kernels/test_flash_attn.py \ --ignore=kernels/test_flashinfer.py \ + --ignore=kernels/test_gguf.py \ --ignore=kernels/test_int8_quant.py \ --ignore=kernels/test_machete_gemm.py \ --ignore=kernels/test_mamba_ssm.py \ diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index d2ae926daa7c0..73ce82c5857ab 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -22,13 +22,11 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " - pip install pytest matplotlib einops transformers_stream_generator - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \ - --ignore=tests/models/test_oot_registration.py \ - --ignore=tests/models/test_registry.py \ - --ignore=tests/models/test_fp8.py \ - --ignore=tests/models/test_jamba.py \ - --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator + pytest -v -s tests/models/decoder_only/language \ + --ignore=tests/models/test_fp8.py \ + --ignore=tests/models/decoder_only/language/test_jamba.py \ + --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # Run compressed-tensor test docker exec cpu-test bash -c " diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 25f18cc57793e..37207b677a1ee 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -43,13 +43,16 @@ steps: fast_check: true source_file_dependencies: - vllm/ + - tests/mq_llm_engine - tests/async_engine - tests/test_inputs - tests/multimodal - tests/test_utils - tests/worker commands: - - pytest -v -s async_engine # Async Engine + - pytest -v -s mq_llm_engine # MQLLMEngine + - pytest -v -s async_engine 
# AsyncLLMEngine + - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils @@ -93,7 +96,6 @@ steps: - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -163,30 +165,6 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py -- label: Models Test # 1hr10min - source_file_dependencies: - - vllm/ - - tests/models - commands: - - pip install -e ./plugins/vllm_add_dummy_model - - pytest -v -s models/test_oot_registration.py # it needs a clean process - - pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py - -- label: torch compile integration test - source_file_dependencies: - - vllm/ - commands: - - pytest -v -s ./compile/test_full_graph.py - - pytest -v -s ./compile/test_wrapper.py - - -- label: Vision Language Models Test # 42min - #mirror_hardwares: [amd] - source_file_dependencies: - - vllm/ - commands: - - pytest -v -s models -m vlm - - label: Prefix Caching Test # 7min #mirror_hardwares: [amd] source_file_dependencies: @@ -276,6 +254,13 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: Encoder Decoder tests # 5min + source_file_dependencies: + - vllm/ + - tests/encoder_decoder + commands: + - pytest -v -s encoder_decoder + - label: OpenAI-Compatible Tool Use # 20 min fast_check: false mirror_hardwares: [ amd ] @@ -285,6 +270,45 @@ steps: commands: - pytest -v -s tool_use +##### models test ##### + +- label: Basic Models Test # 3min + source_file_dependencies: + - vllm/ + - tests/models + commands: + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s models/*.py --ignore=models/test_oot_registration.py + +- label: Decoder-only Language Models Test # 1h3min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/language + commands: + - pytest -v -s models/decoder_only/language + +- label: Decoder-only Multi-Modal Models Test # 56min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/audio_language + - tests/models/decoder_only/vision_language + commands: + - pytest -v -s models/decoder_only/audio_language + - pytest -v -s models/decoder_only/vision_language + +- label: Other Models Test # 5min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/embedding/language + - tests/models/encoder_decoder/language + commands: + - pytest -v -s models/embedding/language + - pytest -v -s models/encoder_decoder/language + ##### 1 GPU test ##### ##### multi gpus test ##### @@ -310,11 +334,11 @@ steps: - tests/distributed/ commands: - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d 
--rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' - label: Distributed Tests (2 GPUs) # 28min #mirror_hardwares: [amd] @@ -326,12 +350,14 @@ steps: - vllm/executor/ - vllm/model_executor/models/ - tests/distributed/ + - vllm/compilation commands: - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py - - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py - - pytest -v -s distributed/test_chunked_prefill_distributed.py - - pytest -v -s distributed/test_multimodal_broadcast.py + - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_wrapper.py + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus + # Avoid importing model tests that cause CUDA reinitialization error + - pytest models/encoder_decoder/language/test_bart.py models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 1a794af572fef..90735d6e2bbf9 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -25,10 +25,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2 + pip install -r requirements-lint.txt - name: Analysing the code with ruff run: | - ruff . + ruff check . 
- name: Spelling check with codespell run: | codespell --toml pyproject.toml diff --git a/CMakeLists.txt b/CMakeLists.txt index f8d6a2be9feae..c8f19de94e59b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -324,6 +324,25 @@ define_gpu_extension_target( WITH_SOABI) +if(VLLM_GPU_LANG STREQUAL "HIP") + # + # _rocm_C extension + # + set(VLLM_ROCM_EXT_SRC + "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/attention.cu") + + define_gpu_extension_target( + _rocm_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_ROCM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) +endif() + if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") @@ -331,5 +350,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) +endif() +if(VLLM_GPU_LANG STREQUAL "HIP") + message(STATUS "Enabling rocm extension.") + add_dependencies(default _rocm_C) endif() diff --git a/Dockerfile b/Dockerfile index 5484be5bc5785..001068b4b36ca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -82,6 +82,7 @@ ENV BUILDKITE_COMMIT=${buildkite_commit} ARG USE_SCCACHE ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 +ARG SCCACHE_S3_NO_CREDENTIALS=0 # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/pip \ if [ "$USE_SCCACHE" = "1" ]; then \ @@ -92,6 +93,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \ && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \ + && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ && export SCCACHE_IDLE_TIMEOUT=0 \ && export CMAKE_BUILD_TYPE=Release \ && sccache --show-stats \ @@ -180,10 +182,6 @@ FROM vllm-base AS test ADD . /vllm-workspace/ # install development dependencies (for testing) -# A newer setuptools is required for installing some test dependencies from source that do not publish python 3.12 wheels -# This installation must complete before the test dependencies are collected and installed. 
-RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install "setuptools>=74.1.1" RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 34b4c95e34ffc..4d7289366296b 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -24,6 +24,8 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl +WORKDIR /workspace + ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \ diff --git a/Dockerfile.xpu b/Dockerfile.xpu index 321da98cf6c89..50bbd8f7dad87 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -1,15 +1,23 @@ -FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04 +FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ - rm /etc/apt/sources.list.d/intel-graphics.list && \ wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ chmod 644 /usr/share/keyrings/intel-graphics.gpg RUN apt-get update -y \ && apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 + +RUN git clone https://github.com/intel/pti-gpu && \ + cd pti-gpu/sdk && \ + mkdir build && \ + cd build && \ + cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \ + make -j && \ + cmake --install . 
--config Release --prefix "/usr/local" + COPY ./ /workspace/vllm WORKDIR /workspace/vllm diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 3243bb94f787c..3def4a6d67acf 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -25,6 +25,7 @@ class RequestFuncInput: best_of: int = 1 use_beam_search: bool = False logprobs: Optional[int] = None + multi_modal_content: Optional[dict] = None @dataclass @@ -312,12 +313,15 @@ async def async_request_openai_chat_completions( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert not request_func_input.use_beam_search + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) payload = { "model": request_func_input.model, "messages": [ { "role": "user", - "content": request_func_input.prompt, + "content": content }, ], "temperature": 0.0, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9ba3f649810b7..3ace910a6cac6 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -24,6 +24,8 @@ """ import argparse import asyncio +import base64 +import io import json import os import random @@ -31,11 +33,13 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) +from datasets import load_dataset +from PIL.Image import Image from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase @@ -84,7 +88,7 @@ def sample_sharegpt_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, int, int]]: +) -> List[Tuple[str, int, int, None]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") # Load the dataset. @@ -119,7 +123,7 @@ def sample_sharegpt_requests( if prompt_len > 1024 or prompt_len + output_len > 2048: # Prune too long sequences. continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((prompt, prompt_len, output_len, None)) return filtered_dataset @@ -131,7 +135,7 @@ def sample_sonnet_requests( output_len: int, prefix_len: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, str, int, int]]: +) -> List[Tuple[str, str, int, int, None]]: assert ( input_len > prefix_len ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." 
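
For context on the multi-modal plumbing added in benchmarks/benchmark_serving.py and backend_request_func.py, the following is a minimal illustrative sketch (not part of the patch) of how a sampled image ends up in the OpenAI chat payload: the image is base64-encoded into a JPEG data URL (as sample_hf_requests does) and appended to a text/image_url content list (as async_request_openai_chat_completions does). Variable names and the model id here are hypothetical.

# Illustrative sketch, not part of the diff: mirrors sample_hf_requests and
# async_request_openai_chat_completions; names are hypothetical.
import base64
import io

from PIL import Image

# Stand-in for data["image"] coming from the HF dataset.
image = Image.new("RGB", (64, 64), color="white")

# Encode the image as a base64 JPEG data URL.
buf = io.BytesIO()
image.save(buf, format="JPEG")
image_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
mm_content = {
    "type": "image_url",
    "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
}

# The openai-chat backend then sends text + image as a content list.
prompt = "Describe the image."
content = [{"type": "text", "text": prompt}]
if mm_content is not None:
    content.append(mm_content)
payload = {
    "model": "example-model",  # hypothetical model id
    "messages": [{"role": "user", "content": content}],
    "temperature": 0.0,
}
print(payload["messages"][0]["content"][0])
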
@@ -189,7 +193,65 @@ def sample_sonnet_requests( message, add_generation_prompt=True, tokenize=False) prompt_len = len(tokenizer(prompt_formatted).input_ids) sampled_requests.append( - (prompt, prompt_formatted, prompt_len, output_len)) + (prompt, prompt_formatted, prompt_len, output_len, None)) + + return sampled_requests + + +def sample_hf_requests( + dataset_path: str, + dataset_subset: str, + dataset_split: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + dataset = load_dataset(dataset_path, + name=dataset_subset, + split=dataset_split, + streaming=True) + assert "conversations" in dataset.features, ( + "HF Dataset must have 'conversations' column.") + filtered_dataset = dataset.shuffle().filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests: List[Tuple[str, int, int, Dict[str, + Collection[str]]]] = [] + for data in filtered_dataset: + if len(sampled_requests) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = data["conversations"][0]["value"] + prompt_token_ids = tokenizer(prompt).input_ids + completion = data["conversations"][1]["value"] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + + if "image" in data and isinstance(data["image"], Image): + image: Image = data["image"] + image = image.convert("RGB") + image_data = io.BytesIO() + image.save(image_data, format='JPEG') + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + mm_content = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + else: + mm_content = None + + sampled_requests.append((prompt, prompt_len, output_len, mm_content)) return sampled_requests @@ -223,8 +285,8 @@ def sample_random_requests( [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) - input_requests.append( - (prompt, int(prefix_len + input_lens[i]), int(output_lens[i]))) + input_requests.append((prompt, int(prefix_len + input_lens[i]), + int(output_lens[i]), None)) return input_requests @@ -343,7 +405,12 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0]) + if backend != "openai-chat" and test_mm_content is not None: + # multi-modal benchmark is only available on OpenAI Chat backend. 
+ raise ValueError( + "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( model=model_id, prompt=test_prompt, @@ -353,6 +420,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, ) test_output = await request_func(request_func_input=test_input) if not test_output.success: @@ -373,6 +441,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -385,7 +454,7 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request + prompt, prompt_len, output_len, mm_content = request request_func_input = RequestFuncInput( model=model_id, prompt=prompt, @@ -395,6 +464,7 @@ async def benchmark( logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, + multi_modal_content=mm_content, ) tasks.append( asyncio.create_task( @@ -575,6 +645,16 @@ def main(args: argparse.Namespace): for prompt, prompt_formatted, prompt_len, output_len in input_requests] + elif args.dataset_name == "hf": + input_requests = sample_hf_requests( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.hf_output_len, + ) + elif args.dataset_name == "random": input_requests = sample_random_requests( prefix_len=args.random_prefix_len, @@ -685,13 +765,14 @@ def main(args: argparse.Namespace): "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random"], + choices=["sharegpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", type=str, default=None, - help="Path to the dataset.") + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.") parser.add_argument( "--model", type=str, @@ -718,26 +799,6 @@ def main(args: argparse.Namespace): default=1000, help="Number of prompts to process.", ) - parser.add_argument( - "--sharegpt-output-len", - type=int, - default=None, - help="Output length for each request. 
Overrides the output length " - "from the ShareGPT dataset.") - parser.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", - ) parser.add_argument( "--logprobs", type=int, @@ -748,42 +809,6 @@ def main(args: argparse.Namespace): "logprob is returned for each token; or (2) if beam search " "is enabled 1 logprob per token is computed"), ) - parser.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--random-input-len", - type=int, - default=1024, - help= - "Number of input tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-output-len", - type=int, - default=128, - help= - "Number of output tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-range-ratio", - type=float, - default=1.0, - help="Range of sampled ratio of input/output length, " - "used only for random sampling.", - ) - parser.add_argument( - "--random-prefix-len", - type=int, - default=0, - help="Number of fixed prefix tokens before random " - " context. The length range of context in a random " - " request is [random-prefix-len, " - " random-prefix-len + random-prefix-len * random-range-ratio).") parser.add_argument( "--request-rate", type=float, @@ -857,5 +882,85 @@ def main(args: argparse.Namespace): "Use \"--percentile-metrics\" to select metrics.", ) + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.") + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=1.0, + help="Range of sampled ratio of input/output length, " + "used only for random sampling.", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. 
The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + args = parser.parse_args() main(args) diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index 4947fda02e1cc..92f6053cc6d7e 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -1,10 +1,10 @@ -import random import time import torch from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) @torch.inference_mode() @@ -16,10 +16,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device("cuda") layer = RMSNorm(hidden_size).to(dtype=dtype) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index fd233c71b10a6..c2ad98b7e2656 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -10,7 +10,7 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything class BenchmarkConfig(TypedDict): @@ -166,7 +166,7 @@ class BenchmarkWorker: def __init__(self, seed: int) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(seed) + seed_everything(seed) self.seed = seed def benchmark( @@ -180,7 +180,7 @@ def benchmark( use_fp8_w8a8: bool, use_int8_w8a16: bool, ) -> Tuple[Dict[str, int], float]: - torch.cuda.manual_seed_all(self.seed) + seed_everything(self.seed) dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a04433142da42..87864d038d593 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -6,7 +6,7 @@ from vllm import _custom_ops as ops from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) + create_kv_caches_with_random, seed_everything) NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -28,10 +28,7 @@ def main( device: str = "cuda", kv_cache_dtype: Optional[str] = None, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 4c1a7b26213a5..743a5744e8614 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -1,10 +1,10 @@ -import random 
import time import torch from vllm import _custom_ops as ops -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) @torch.inference_mode() @@ -17,10 +17,7 @@ def main(num_tokens: int, do_profile: bool = False, num_warmup_iters: int = 5, num_iters: int = 100) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device("cuda") x = torch.randn(num_tokens, hidden_size, dtype=dtype) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index f542684a9a2a9..73fc9e9dbf461 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -6,7 +6,7 @@ from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, get_rope) -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything def benchmark_rope_kernels_multi_lora( @@ -22,9 +22,7 @@ def benchmark_rope_kernels_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index 1d076ed6d5c18..de608fd05af70 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -45,8 +45,7 @@ rows = int(math.ceil(len(results) / 2)) fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) axs = axs.flatten() - axs_idx = 0 - for shape, data in results.items(): + for axs_idx, (shape, data) in enumerate(results.items()): plt.sca(axs[axs_idx]) df = pd.DataFrame(data) sns.lineplot(data=df, @@ -59,6 +58,5 @@ palette="Dark2") plt.title(f"Shape: {shape}") plt.ylabel("time (median, s)") - axs_idx += 1 plt.tight_layout() plt.savefig("graph_machete_bench.pdf") diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 0cfc19097fded..2d7abe6145fee 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major // static-per-tensor quantization. 
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] - const torch::Tensor& scale) { + const torch::Tensor& scale, + c10::optional const& azp) { CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; @@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] - torch::Tensor& scale // [..., 1] -) { + torch::Tensor& scale, // [..., 1] + c10::optional const& azp) { CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index b45da1b386b5b..ab697e3e6aef7 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #ifdef __AVX512F__ // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " - "()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " - "()"); + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); // W8A8 GEMM, supporting symmetric per-tensor or per-row/column diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 82a3563979f16..9b82bec44c3c6 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink) { - auto inp_size = inp.numel() * inp.element_size(); - // custom allreduce requires input byte size to be multiples of 16 - if (inp_size % 16 != 0) return false; - if (!_is_weak_contiguous(inp)) return false; - if (world_size == 2 || full_nvlink) return inp_size <= max_size; - // for 4 or more non NVLink-capable GPUs, custom allreduce provides little - // performance improvement over NCCL. 
- return false; -} - void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { auto fa = reinterpret_cast(_fa); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 88a64a8ece585..32261ec17d897 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, const c10::optional &bias_, - bool silu_activation) { + bool silu_activation, + const c10::optional &conv_state_indices_) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x, const int width = weight.size(-1); CHECK_SHAPE(x, batch_size, dim); - CHECK_SHAPE(conv_state, batch_size, dim, width); CHECK_SHAPE(weight, dim, width); TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); @@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x, params.conv_state_c_stride = conv_state.stride(1); params.conv_state_l_stride = conv_state.stride(2); + if (conv_state_indices_.has_value()) { + auto conv_state_indices = conv_state_indices_.value(); + TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) + TORCH_CHECK(conv_state_indices.is_cuda()); + TORCH_CHECK(conv_state_indices.stride(0) == 1) + CHECK_SHAPE(conv_state_indices, batch_size); + + int conv_state_entries = conv_state.size(0); + CHECK_SHAPE(conv_state, conv_state_entries, dim, width); + + params.conv_state_indices_ptr = conv_state_indices.data_ptr(); + } else { + CHECK_SHAPE(conv_state, batch_size, dim, width); + params.conv_state_indices_ptr = nullptr; + } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)x.get_device()}; @@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { const int channel_id = blockIdx.y * kNThreads + tidx; input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride; - input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor + // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. + const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr + ? batch_id + : params.conv_state_indices_ptr[batch_id]; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + channel_id * params.out_c_stride; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index bb25314c8bbbd..32a7d83c09b8d 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -36,6 +36,10 @@ struct ConvParamsBase { void *__restrict__ conv_state_ptr; + // For the continuous batching case. 
Makes it so that the mamba state for + // the current batch doesn't need to be a contiguous tensor. + int32_t *__restrict__ conv_state_indices_ptr; + void *__restrict__ seq_idx_ptr; // No __restrict__ since initial_states could be the same as final_states. diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 92184f43c9eb0..666d87eb92595 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,8 @@ #include +#include "core/scalar_type.hpp" + template inline std::string str(T x) { return std::to_string(x); @@ -131,11 +133,26 @@ __device__ inline int lop3(int a, int b, int c) { return res; } -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 -// values. We mostly follow the strategy in the link below, with some small -// changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -156,6 +173,28 @@ __device__ inline FragB dequant(int q) { return frag_b; } +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { @@ -296,7 +335,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); @@ -840,10 +902,19 @@ __device__ inline void MarlinMoESingle( // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - FragB frag_b0 = dequant(b_quant); + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -855,8 +926,6 @@ __device__ inline void MarlinMoESingle( } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -881,13 +950,13 @@ __device__ inline void MarlinMoESingle( // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -1035,8 +1104,10 @@ __device__ inline void MarlinMoESingle( auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1088,9 +1159,9 @@ __device__ inline void MarlinMoESingle( // Start global fetch and register load pipelines. 
auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - __syncthreads(); + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + // __syncthreads(); #pragma unroll for (int i = 0; i < stages - 1; i++) { @@ -1166,28 +1237,70 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } } } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + if (slice_count > 1) { // only globally reduce if there is more than one // block in a slice barrier_acquire(&locks[slice_col], slice_idx); @@ -1227,7 +1340,8 @@ __device__ inline void MarlinMoESingle( } } -template 4) { + if (max_block > cfg_max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * max_block - pad) / 64; - par = min((16 * max_block - pad) / 64, max_par); - prob_m = 64 * par; - m_block_ctr += 4 * (par - 1); - max_block = 4; + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; } if (max_block == 1) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } 
else if (max_block == 3) { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { - MarlinMoESingle( + MarlinMoESingle( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1342,7 +1457,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template , \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ + MarlinMoE \ <<>>( \ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par); \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1423,6 +1543,11 @@ typedef struct { int num_threads; } thread_config_t; +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + thread_config_t small_batch_thread_configs[] = { // Ordered by priority @@ -1443,8 +1568,77 @@ thread_config_t large_batch_thread_configs[] = { {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, - int prob_k) { +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = ceildiv(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * STAGES; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = ceildiv(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * STAGES; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ 
-1472,64 +1666,88 @@ bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, return false; } + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + return true; } -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } - } - - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } - return thread_config_t{-1, -1, -1}; + return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, 
false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const void* sorted_ids, const void* topk_weights, const void* topk_ids, const void* s, const void* g_idx, const void* perm, void* a_tmp, void* expert_offsets, int prob_m, int prob_n, int prob_k, void* workspace, - bool has_act_order, bool is_k_full, int num_groups, - int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool replicate_input, + bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1537,26 +1755,42 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); } + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int num_bits = q_type.size_bits(); + // Set thread config - thread_config_t th_config; + exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config - th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; } else { // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks 
= ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; @@ -1590,11 +1824,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, } } - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - int tot_m = prob_m; const int* topk_ids_ptr = (const int*)topk_ids; @@ -1611,10 +1840,13 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order = false; } + int pack_factor = 32 / q_type.size_bits(); + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const int4* A_ptr = (const int4*)A; int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; @@ -1636,19 +1868,22 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, A_ptr = a_tmp_ptr; } - int max_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { - // Define kernel configurations - + int tot_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < tot_m_blocks; + m_block += 4 * exec_cfg.max_m_blocks) { // make it max possible value - int thread_m_blocks = 4; + int thread_m_blocks = exec_cfg.max_m_blocks; if (false) { } - CALL_IF_MOE(16, 4, 256) - CALL_IF_MOE(8, 8, 256) - CALL_IF_MOE(8, 4, 128) - CALL_IF_MOE(4, 8, 128) + CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1670,9 +1905,15 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); + + int pack_factor = 32 / b_q_type->size_bits(); + int max_par = 4; int dev = a.get_device(); @@ -1733,8 +1974,8 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - has_act_order, is_k_full, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, + topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; } diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 43d264e0770d6..adee8399a4d6f 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -2,11 +2,14 @@ #include +#include "core/scalar_type.hpp" + torch::Tensor marlin_gemm_moe( const torch::Tensor& a, const torch::Tensor& b_q_weights, const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, const torch::Tensor& topk_ids, const torch::Tensor& b_scales, const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, + int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 8a0e625b43fa1..cd65a8ee92b94 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,9 +13,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " - "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " - "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " - "bool replicate_input, bool apply_weights) -> Tensor"); + "g_idx, Tensor! perm, Tensor! 
workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 5333b22c536d6..15e9ebe87408a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -184,10 +184,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& scale); + torch::Tensor const& scale, + c10::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scales); + torch::Tensor& scales, + c10::optional const& azp); torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, @@ -220,11 +222,10 @@ std::vector selective_scan_fwd( const c10::optional& index_, const c10::optional& x); -at::Tensor causal_conv1d_update(const at::Tensor& x, - const at::Tensor& conv_state, - const at::Tensor& weight, - const c10::optional& bias_, - bool silu_activation); +at::Tensor causal_conv1d_update( + const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, + const c10::optional& bias, bool silu_activation, + const c10::optional& conv_state_indices); at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& bias_, @@ -239,8 +240,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, const std::vector& offsets, int64_t rank, bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 616fc149760e5..aec9fa002f96e 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -14,12 +14,17 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM - static const float i8_min = + static constexpr auto i8_min = static_cast(std::numeric_limits::min()); - static const float i8_max = + static constexpr auto i8_max = static_cast(std::numeric_limits::max()); - // round + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. float dst = std::nearbyint(x); + // saturate dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); @@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #endif } +static inline __device__ int32_t float_to_int32_rn(float x) { +#ifdef USE_ROCM + // int32_max is not exactly representable as float. + // Therefore, we need to be careful and manually return int32_max on overflow. + // For symmetry, we also do the same for int32_min, even though it is exactly + // representable as float and the conversion should be exact. 
+ static constexpr auto i32_min = std::numeric_limits::min(); + static constexpr auto i32_min_f = static_cast(i32_min); + static constexpr auto i32_max = std::numeric_limits::max(); + static constexpr auto i32_max_f = static_cast(i32_max); + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. + float dst = std::nearbyint(x); + + // saturate on the higher end. + if (dst >= i32_max_f) { + return i32_max; + } + // saturate on the lower end. + if (dst <= i32_min_f) { + return i32_min; + } + + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static inline __device__ int8_t int32_to_int8(int32_t x) { +#ifdef USE_ROCM + static constexpr auto i8_min = + static_cast(std::numeric_limits::min()); + static constexpr auto i8_max = + static_cast(std::numeric_limits::max()); + + // saturate + int32_t dst = std::clamp(x, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x)); + return reinterpret_cast(dst); +#endif +} + namespace vllm { template @@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel( } } +template +__global__ void static_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, azp_type const* azp_ptr, + const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + azp_type const azp = *azp_ptr; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); + out[token_idx * hidden_size + i] = quant_val; + } +} + template __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, @@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel( } } +template +__global__ void dynamic_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, azp_type* azp, const int hidden_size) { + int const token_idx = blockIdx.x; + + // Scan for the min and max value for this token + float max_val = std::numeric_limits::min(); + float min_val = std::numeric_limits::max(); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[token_idx * hidden_size + i]); + max_val = std::max(max_val, val); + min_val = std::min(min_val, val); + } + + // Reduce the max and min values across the block + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); + __syncthreads(); // Make sure min doesn't mess with max shared memory + min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); + + __shared__ scale_type scale_sh; + __shared__ azp_type azp_sh; + + // Compute the scale and zero point and store them, only on the first thread + if (threadIdx.x == 0) { + float const scale_val = (max_val - min_val) / 255.0f; + // Use rounding to even (same as torch.round) + auto const azp_float = std::nearbyint(-128.0f - min_val / 
scale_val); + auto const azp_val = static_cast(azp_float); + + // Store the scale and azp into shared and global + scale[token_idx] = scale_sh = scale_val; + azp[token_idx] = azp_sh = azp_val; + } + + // Wait for the scale and azp to be computed + __syncthreads(); + + float const scale_val = scale_sh; + azp_type const azp_val = azp_sh; + + // Quantize the values + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = + int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); + out[token_idx * hidden_size + i] = quant_val; + } +} + } // namespace vllm void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale) { + torch::Tensor const& scale, + c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp || azp->numel() == 1); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { - vllm::static_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scale.data_ptr(), hidden_size); + if (!azp) { + vllm::static_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), hidden_size); + } else { + vllm::static_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales) { + torch::Tensor& scales, c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(!azp || azp->is_contiguous()); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scales.data_ptr(), hidden_size); + if (!azp) { + vllm::dynamic_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), hidden_size); + } else { + vllm::dynamic_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } diff --git a/csrc/quantization/gguf/dequantize.cuh b/csrc/quantization/gguf/dequantize.cuh index 2069fba759ea0..c012262e49015 100644 --- a/csrc/quantization/gguf/dequantize.cuh +++ b/csrc/quantization/gguf/dequantize.cuh @@ -353,18 +353,47 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ template static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; const block_iq1_s * x = (const block_iq1_s *) vx; - const int tid = threadIdx.x; - const int il = tid/8; // 0...3 - const int ib = tid%8; // 0...7 + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 
0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA; + const float d = __half2float(x[i].d) * (2*((x[i].qh[ib] >> 12) & 7) + 1); + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = __float2half(d * (q[j] + delta)); + } +} + +template +static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq1_m * x = (const block_iq1_m *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; - const int i8 = 4*ib+il; - uint8_t h = x[i].scales[i8/2] >> 4*(i8%2); - const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5))); - const float d = __half2float(x[i].d) * (2*(h & 7) + 1); - for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j]); + const uint16_t * sc = (const uint16_t *)x[i].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); + const float d = __half2float(scale.f16) * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1); + const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA; + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = __float2half(d * (q[j] + delta)); + } } template @@ -475,6 +504,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c dequantize_block_iq1_s<<>>(vx, y); } +template +static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq1_m<<>>(vx, y); +} + template static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; @@ -525,6 +560,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) { return dequantize_row_iq2_s_cuda; case 23: return dequantize_row_iq4_xs_cuda; + case 29: + return dequantize_row_iq1_m_cuda; default: return nullptr; } diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h index d7989d84bf68e..fba94fd1d157b 100644 --- a/csrc/quantization/gguf/ggml-common.h +++ b/csrc/quantization/gguf/ggml-common.h @@ -149,14 +149,30 @@ typedef struct { uint8_t scales[IQ3S_N_SCALE]; } block_iq3_s; +// 1.5625 bpw #define QR1_S 8 #define QI1_S (QK_K / (4*QR1_S)) typedef struct { half d; - uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; } block_iq1_s; +// 1.75 bpw +#define QR1_M 8 +#define QI1_M (QK_K / (4*QR1_M)) +typedef struct { + uint8_t qs[QK_K/8]; // grid index, low 8 bits + uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) + uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64) +} block_iq1_m; + +// Used by IQ1_M quants +typedef union { + half f16; + uint16_t u16; +} 
iq1m_scale_t; + #define QK4_NL 32 #define QR4_NL 2 #define QI4_NL (QK4_NL / (4*QR4_NL)) @@ -733,135 +749,265 @@ static const __device__ uint32_t iq3xs_grid[512] = { 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404, }; -static const __device__ uint64_t iq1s_grid[512] = { - 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000, - 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01, - 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, - 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00, - 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101, - 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100, - 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00, - 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff, - 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000, - 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000, - 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001, - 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff, - 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01, - 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001, - 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00, - 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001, - 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100, - 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000, - 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000, - 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000, - 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff, - 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff, - 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01, - 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100, - 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff, - 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000, - 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101, - 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff, - 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff, - 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001, - 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01, - 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101, - 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100, - 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00, - 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001, - 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff, - 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000, - 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000, - 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100, - 0x00ffffff00010000, 
0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100, - 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01, - 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff, - 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101, - 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000, - 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff, - 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000, - 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff, - 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00, - 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101, - 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000, - 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000, - 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000, - 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100, - 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000, - 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001, - 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff, - 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000, - 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000, - 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000, - 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000, - 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff, - 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000, - 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, - 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01, - 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100, - 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000, - 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00, - 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100, - 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000, - 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, - 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00, - 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff, - 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100, - 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff, - 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000, - 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff, - 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff, - 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00, - 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001, - 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001, - 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01, - 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000, - 0x000100ff000101ff, 0x000100ff01ff0101, 
0x000100ff0100ffff, 0x000100ff01010101, - 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00, - 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100, - 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101, - 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101, - 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000, - 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff, - 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff, - 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101, - 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, - 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101, - 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001, - 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff, - 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff, - 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01, - 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff, - 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100, - 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001, - 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00, - 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff, - 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff, - 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000, - 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000, - 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101, - 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001, - 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000, - 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101, - 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000, - 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, - 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000, - 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100, - 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000, - 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000, - 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100, - 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff, - 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff, - 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, - 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101, - 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000, - 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00, - 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000, - 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff, - 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101, - 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 
0x010101ff0101ffff, - 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00, - 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff, +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f +static const __device__ uint64_t iq1s_grid_gpu[2048] = { + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 
0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 
0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 
0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 
0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 
0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 
0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 
0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; static const __device__ uint8_t ksigns_iq2xs[128] = { diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 966d9992b25fd..37e4de4e14dd3 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -166,6 +166,11 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight (void*)quant_X.data_ptr(), (half*)Y.data_ptr(), col, row, stream); break; + case 29: + mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(), + (void*)quant_X.data_ptr(), + (half*)Y.data_ptr(), col, row, stream); + break; } return Y; } diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index ef2ea072392d2..b221ae7896138 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -157,6 +157,14 @@ static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * <<>>(vx, vy, dst, ncols, nrows); } +static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) { const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index 78c749d3f3bc1..d5af345a6b26f 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -1,5 +1,18 @@ // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/vecdotq.cuh // and https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu +static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} + static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment int x32 = 0; @@ -1661,24 +1674,76 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 const block_iq1_s * bq1 = (const block_iq1_s *) vbq; - const int ib32 = iqs; - int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; - const uint8_t h1 = bq1->scales[2*ib32+0]; - const uint8_t h2 = bq1->scales[2*ib32+1]; - const int * q8 = (const int *)bq8_1[ib32].qs; - const int * grid1 = (const int 
*)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); - const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); - const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5))); - const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1))); - for (int j = 0; j < 2; ++j) { - sumi1 = __dp4a(q8[j+0], grid1[j], sumi1); - sumi2 = __dp4a(q8[j+2], grid2[j], sumi2); - sumi3 = __dp4a(q8[j+4], grid3[j], sumi3); - sumi4 = __dp4a(q8[j+6], grid4[j], sumi4); - } - const float d = __half2float(bq1->d) * __low2float(bq8_1[ib32].ds); - return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + - sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1)); + const int qs_packed = get_int_b2(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq1->qh[iqs]; + + int sumi = 0; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi = __dp4a(grid0, u0, sumi); + sumi = __dp4a(grid1, u1, sumi); + } + + const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + const float2 ds = __half22float2(bq8_1[iqs].ds); + return d1q * (ds.x*sumi + ds.y*delta); +#endif +} + +static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 + + const block_iq1_m * bq1 = (const block_iq1_m *) vbq; + + const int qs_packed = get_int_b4(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + int sumi[2] = {0}; + float sumf[2] = {0.0f}; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2)); + + const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi[l0/4] = __dp4a(grid0, u0, sumi[l0/4]); + sumi[l0/4] = __dp4a(grid1, u1, sumi[l0/4]); + + const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08); + int sumy = 0; + sumy = __dp4a(u0, 0x01010101, sumy); + sumy = __dp4a(u1, 0x01010101, sumy); + sumf[l0/4] += delta*sumy; + } + + const uint16_t * sc = (const uint16_t *) bq1->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000); + const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds); + + const int tmp = sc[iqs/2] >> (6*(iqs%2)); + const int sc0 = 2*((tmp >> 0) & 0x07) + 1; + const int sc1 = 2*((tmp >> 3) & 0x07) + 1; + return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1); #endif } diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu new file mode 100644 index 0000000000000..8fa7c862fbfa8 --- /dev/null +++ b/csrc/rocm/attention.cu @@ -0,0 +1,1038 @@ +/* + * Copyright (c) 2024, The vLLM team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +#define WARP_SIZE 64 + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support + + #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 + #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 + +using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +typedef float16x4 _Half4; +typedef struct _Half8 { + _Half4 xy[2]; +} _Half8; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; +typedef struct _B16x8 { + _B16x4 xy[2]; +} _B16x8; + +////// Non temporal load stores /////// + +template +__device__ __forceinline__ T load(T* addr) { + return addr[0]; +} + +template +__device__ __forceinline__ void store(T value, T* addr) { + addr[0] = value; +} + +template +__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.f = (_Float16)inp[i]; + ret[i] = t16.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.b = __float2bfloat16(inp[i]); + ret[i] = t16.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, + const _B16x4& inp2) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t1, t2, 
res; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.f = t1.f + t2.f; + ret[i] = res.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.b = t1.b + t2.b; + ret[i] = res.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +/////////////////////////////////////// + +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + #if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] + #endif + int max_ctx_blocks) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int partition_size = blockDim.x; + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + const int partition_start_token_idx = partition_idx * partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + constexpr int QHLOOP = + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, + // total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4 * QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; + _B16x8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int KHELOOP = HEAD_SIZE / x; + _B16x8 Klocal[KHELOOP]; + constexpr int VHELOOP = + HEAD_SIZE / + WARP_SIZE; // v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 + // 8xtokens + _B16x8 Vlocal[VHELOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = {0}; + qk_max[h] = -FLT_MAX; + } + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + + const int warp_start_token_idx = + partition_start_token_idx + warpid * WARP_SIZE; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int h = 0; h < GQA_RATIO4; h++) { + shared_qk_max[warpid][h] = -FLT_MAX; + shared_exp_sum[warpid][h] = 0.0f; + } + } else { // warp within context 
+ + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + const int local_token_idx = threadIdx.x; + const int global_token_idx = partition_start_token_idx + local_token_idx; + + const int block_idx = (global_token_idx < context_len) + ? global_token_idx / BLOCK_SIZE + : last_ctx_block; + // fetch block number for q and k + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // fetch vphysical block numbers up front + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; + + const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + const int vblock_idx = warp_start_block_idx + b; + const int vblock_idx_ctx = + (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; + vphysical_blocks[b] = block_table[vblock_idx_ctx]; + } + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + const scalar_t* q_ptr = + q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; + const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); + const int qhead_elemh8 = laneid / 4; + #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { + const int qhead_idx = h * 4 + lane4id; + Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } + const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id; + if (final_qhead_idx < GQA_RATIO) { + Qlocal[QHLOOP - 1] = + q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } else { + Qlocal[QHLOOP - 1].xy[0] = {0}; + Qlocal[QHLOOP - 1].xy[1] = {0}; + } + + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; + + const int physical_block_offset = + local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset + // is already cast as _H8 + + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + + float alibi_slope[QHLOOP]; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int qhead_idx = h * 4 + lane4id; + alibi_slope[h] = (qhead_idx < GQA_RATIO) + ? 
alibi_slopes[wg_start_head_idx + qhead_idx] + : 0.f; + } + } + + const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[0].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[0].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[1].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[1].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[2].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[2].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[3].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[3].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[4].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[4].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[5].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[5].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[6].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[6].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[7].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[7].xy[1], dout[h]); + if constexpr (KHELOOP > 8) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[8].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[8].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[9].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[9].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[10].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[10].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[11].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[11].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[12].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[12].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[13].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[13].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[14].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[14].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[15].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[15].xy[1], dout[h]); + } // KHELOOP>8 + dout[h] *= scale; + } + // transpose dout so that 4 token ids are in each lane, and 4 heads are across + // 4 lanes + #pragma unroll + for (int h = 0; h < QHLOOP; h++) 
{ + floatx4 tmp = {0}; + #pragma unroll + for (int i = 0; i < 4; i++) { + const float B = (lane4id == i) ? 1.0f : 0.0f; + // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; + tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); + // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); + } + dout[h] = tmp; + } + + const int lane4_token_idx = 4 * (global_token_idx >> 2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope[h] * (alibi_offset + i); + } + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + qk_max[h] = -FLT_MAX; + #pragma unroll + for (int i = 0; i < 4; i++) { + qk_max[h] = (lane4_token_idx + i < context_len) + ? fmaxf(qk_max[h], dout[h][i]) + : qk_max[h]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + } + } + + float exp_sum[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + exp_sum[h] = 0.0f; + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] = (lane4_token_idx + i < context_len) + ? __expf(dout[h][i] - qk_max[h]) + : 0.0f; + exp_sum[h] += dout[h][i]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + exp_sum[h] += __shfl_xor(exp_sum[h], mask); + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } + } // warp within context + + __syncthreads(); + + const int num_heads = gridDim.z * GQA_RATIO; + float* max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; + float* exp_sums_ptr = + exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + float global_qk_max = -FLT_MAX; + float warp_qk_max[NWARPS]; + const int head_idx = 4 * h + lane4id; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max[w] = shared_qk_max[w][head_idx]; + global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); + } + float global_exp_sum = 0.0f; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + global_exp_sum += + shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); + } + if (head_idx < GQA_RATIO) { + max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_qk_max; + exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_exp_sum; + } + const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) * + __expf(qk_max[h] - global_qk_max); + dout[h] *= global_inv_sum_scale; + } + // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there + // are 4x16 tokens across warp + _B16x4 logits[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + logits[h] = from_floatx4(dout[h]); + } + + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = {0}; + } + } + } else { // warp in context + // iterate across heads + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for 
(int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc = {0}; + // iterate over tokens + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[1], acc); + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + } + } + } // warp in context + + __syncthreads(); + + if (warpid == 0) { + _B16x4 vout[QHLOOP][VHELOOP]; + // iterate across heads + scalar_t* out_ptr; + int out_num_partitions; + if (context_len > partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; + } + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout[qh][vh] = {0}; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + vout[qh][vh] = + addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); + } + const int head_size_elem = vh * WARP_SIZE + laneid; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + #pragma unroll + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; + } + } + } + } + } +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // if num_partitions==1, main kernel will write to out directly, no work in + // reduction kernel + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + __shared__ float shared_exp_sums[2 * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + const int valid_partition = + (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; + const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) + ? WARP_SIZE + threadIdx.x + : num_partitions - 1; + float reg_max_logit = max_logits_ptr[valid_partition]; + float reg_max_logit2 = max_logits_ptr[valid_partition2]; + float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float global_exp_sum = 0.0f; + float rescaled_exp_sum = exp_sums_ptr[valid_partition]; + float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; + rescaled_exp_sum *= + (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; + rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) + ? expf(reg_max_logit2 - max_logit) + : 0.0f; + global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; + shared_exp_sums[threadIdx.x] = rescaled_exp_sum; + shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? 
j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + if (num_partitions > MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + #if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] + #endif + int max_ctx_blocks) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). 
+template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions){UNREACHABLE_CODE} + +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ + paged_attention_ll4mi_QKV_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes) { + + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); +#if 0 + T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); + T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); +#endif + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + assert(max_num_partitions <= 128); + + constexpr int NTHR = PARTITION_SIZE; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + // dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + + // reduction kernel is only required if max_context_len > partition size, + // otherwise main kernel writes directly to final output + // note there are cases with graphing where max_context_len is the max + // supported by graphing, not the actual max among all the sequences: in that + // case reduction kernel will still run but return immediately + if (max_context_len > PARTITION_SIZE) { + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + paged_attention_ll4mi_reduce_kernel + <<>>( + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, + context_lens_ptr, max_num_partitions); + } +} + +#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes); + +#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ + switch 
(head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ + } + +void paged_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int64_t block_size, int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + assert(kv_cache_dtype == "auto"); + const int head_size = query.size(2); + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h new file mode 100644 index 0000000000000..4a07a3f1775bd --- /dev/null +++ b/csrc/rocm/ops.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, + double scale, torch::Tensor& block_tables, + torch::Tensor& context_lens, int64_t block_size, + int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp new file mode 100644 index 0000000000000..082e314587908 --- /dev/null +++ b/csrc/rocm/torch_bindings.cpp @@ -0,0 +1,33 @@ +#include "core/registration.h" +#include "rocm/ops.h" + +// Note on op signatures: +// The X_meta signatures are for the meta functions corresponding to op X. +// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function +// schemas. +// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { + // vLLM custom ops for rocm + + // Custom attention op + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + rocm_ops.def( + "paged_attention(Tensor! out, Tensor exp_sums," + " Tensor max_logits, Tensor tmp_out," + " Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads," + " float scale, Tensor block_tables," + " Tensor context_lens, int block_size," + " int max_context_len," + " Tensor? 
alibi_slopes," + " str kv_cache_dtype) -> ()"); + rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 51afeacfdc0ad..045203c3de8a8 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "causal_conv1d_update(Tensor! x," "Tensor! conv_state," "Tensor! weight," - "Tensor? bias_," - "bool silu_activation) -> Tensor"); + "Tensor? bias," + "bool silu_activation," + "Tensor? conv_state_indices) -> Tensor"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( @@ -336,14 +337,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " - "()"); + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " - "()"); + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); } @@ -411,11 +412,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { "bool full_nvlink) -> int"); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); - custom_ar.def( - "should_custom_ar(Tensor inp, int max_size, int world_size, " - "bool full_nvlink) -> bool"); - custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); - custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst index e22d547293445..9e8b2f1817567 100644 --- a/docs/source/dev/profiling/profiling_index.rst +++ b/docs/source/dev/profiling/profiling_index.rst @@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/. .. tip:: To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. - Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes. - ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000`` + Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes. + ``export VLLM_RPC_TIMEOUT=1800000`` Example commands and usage: =========================== diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 7fc469e06844f..816e0a29ef28b 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -59,6 +59,20 @@ Build from source $ pip install wheel packaging ninja "setuptools>=49.4.0" numpy $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu +- Third, build and install oneDNN library from source: + +.. 
code-block:: console + + $ git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git + $ cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ + -DONEDNN_BUILD_DOC=OFF \ + -DONEDNN_BUILD_EXAMPLES=OFF \ + -DONEDNN_BUILD_TESTS=OFF \ + -DONEDNN_BUILD_GRAPH=OFF \ + -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ + -DONEDNN_ENABLE_PRIMITIVE=MATMUL + $ cmake --build ./oneDNN/build --target install --config Release + - Finally, build and install vLLM CPU backend: .. code-block:: console diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 31ecca1332e5d..81287762d3c0a 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -98,6 +98,13 @@ Here are some common issues that can cause hangs: If the script runs successfully, you should see the message ``sanity check is successful!``. + Note that multi-node environment is more complicated than single-node. If you see errors such as ``torch.distributed.DistNetworkError``, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign node rank and specify the IP via command line arguments: + + - In the first node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py``. + - In the second node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py``. + + Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup. The difference is that you need to execute different commands (with different ``--node-rank``) on different nodes. + If the problem persists, feel free to `open an issue on GitHub `_, with a detailed description of the issue, your environment, and the logs. Some known issues: diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index f0e54c29fcad7..0322503a89a56 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -26,6 +26,10 @@ You can install vLLM using pip: $ # Install vLLM with CUDA 12.1. $ pip install vllm +.. note:: + + Although we recommend using ``conda`` to create and manage Python environments, it is highly recommended to use ``pip`` to install vLLM. This is because ``pip`` can install ``torch`` with separate library packages like ``NCCL``, while ``conda`` installs ``torch`` with statically linked ``NCCL``. This can cause issues when vLLM tries to use ``NCCL``. See `this issue `_ for more details. + .. note:: As of now, vLLM's binaries are compiled with CUDA 12.1 and public PyTorch release versions by default. @@ -34,7 +38,7 @@ You can install vLLM using pip: .. code-block:: console $ # Install vLLM with CUDA 11.8. - $ export VLLM_VERSION=0.4.0 + $ export VLLM_VERSION=0.6.1.post1 $ export PYTHON_VERSION=310 $ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 @@ -48,7 +52,7 @@ You can install vLLM using pip: .. 
code-block:: console - $ export VLLM_VERSION=0.5.4 # vLLM's main branch version is currently set to latest released tag + $ export VLLM_VERSION=0.6.1.post1 # vLLM's main branch version is currently set to latest released tag $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl $ # You can also access a specific commit $ # export VLLM_COMMIT=... @@ -80,17 +84,19 @@ You can also build and install vLLM from source: .. tip:: - Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either `conda install ccache` or `apt install ccache` . As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. + Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either ``conda install ccache`` or ``apt install ccache`` . As long as ``which ccache`` command can find the ``ccache`` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. .. tip:: To avoid your system being overloaded, you can limit the number of compilation jobs - to be run simultaneously, via the environment variable `MAX_JOBS`. For example: + to be run simultaneously, via the environment variable ``MAX_JOBS``. For example: .. code-block:: console $ export MAX_JOBS=6 $ pip install -e . + This is especially useful when you are building on less powerful machines. For example, when you use WSL, it only `gives you half of the memory by default `_, and you'd better use ``export MAX_JOBS=1`` to avoid compiling multiple files simultaneously and running out of memory. The side effect is that the build process will be much slower. If you only touch the Python code, slow compilation is okay, as you are building in an editable mode: you can just change the code and run the Python script without any re-compilation or re-installation. + .. tip:: If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image. @@ -99,7 +105,7 @@ You can also build and install vLLM from source: $ # Use `--ipc=host` to make sure the shared memory is large enough. $ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3 - If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from `the official website `_. After installation, set the environment variable `CUDA_HOME` to the installation path of CUDA Toolkit, and make sure that the `nvcc` compiler is in your `PATH`, e.g.: + If you don't want to use docker, it is recommended to have a full installation of CUDA Toolkit. You can download and install it from `the official website `_. After installation, set the environment variable ``CUDA_HOME`` to the installation path of CUDA Toolkit, and make sure that the ``nvcc`` compiler is in your ``PATH``, e.g.: .. 
code-block:: console diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index faac2b97722b7..745b4b8e2e0eb 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -107,6 +107,10 @@ Decoder-only Language Models - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. - + * - :code:`MiniCPM3ForCausalLM` + - MiniCPM3 + - :code:`openbmb/MiniCPM3-4B`, etc. + - * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. @@ -175,6 +179,10 @@ Decoder-only Language Models - Starcoder2 - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc. - + * - :code:`SolarForCausalLM` + - Solar Pro + - :code:`upstage/solar-pro-preview-instruct`, etc. + - * - :code:`XverseForCausalLM` - Xverse - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. @@ -342,7 +350,7 @@ Note that, as an inference engine, vLLM does not introduce new models. Therefore We have the following levels of testing for models: -1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `test_models.py `_ and `test_big_models.py `_ for the models that have passed this test. +1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `models tests `_ for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests `_ and `examples `_ for the models that have passed this test. 4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
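The rows added above register MiniCPM3 (`openbmb/MiniCPM3-4B`) and Solar Pro (`upstage/solar-pro-preview-instruct`) in the supported-models table, and the testing levels listed above describe "Runtime Functionality" as the model simply loading and generating without errors. A minimal offline check of that kind for one of the new entries might look like the following sketch; the prompt, sampling settings, and the use of `trust_remote_code` are assumptions rather than anything specified in this diff.

```python
# Minimal "Runtime Functionality"-style check for a newly listed model.
# The model name comes from the new table row; trust_remote_code is
# assumed to be required for the openbmb repository, and the prompt and
# token budget are arbitrary.
from vllm import LLM, SamplingParams

llm = LLM(model="openbmb/MiniCPM3-4B", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Briefly introduce yourself."], params)
print(outputs[0].outputs[0].text)
```

The same pattern applies to the Solar Pro row by swapping in `upstage/solar-pro-preview-instruct`; the `datamodel_code_generator` entry added to `requirements-test.txt` further down in this diff is noted as required for the MiniCPM3 test.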
diff --git a/examples/offline_chat_with_tools.py b/examples/offline_chat_with_tools.py new file mode 100644 index 0000000000000..e69a6c067e4da --- /dev/null +++ b/examples/offline_chat_with_tools.py @@ -0,0 +1,138 @@ +# ruff: noqa +import json +import random +import string + +from vllm import LLM +from vllm.sampling_params import SamplingParams + +# This script is an offline demo for function calling +# +# If you want to run a server/client setup, please follow this code: +# +# - Server: +# +# ```bash +# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral +# ``` +# +# - Client: +# +# ```bash +# curl --location 'http://:8000/v1/chat/completions' \ +# --header 'Content-Type: application/json' \ +# --header 'Authorization: Bearer token' \ +# --data '{ +# "model": "mistralai/Mistral-7B-Instruct-v0.3" +# "messages": [ +# { +# "role": "user", +# "content": [ +# {"type" : "text", "text": "Describe this image in detail please."}, +# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}}, +# {"type" : "text", "text": "and this one as well. Answer in French."}, +# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}} +# ] +# } +# ] +# }' +# ``` +# +# Usage: +# python demo.py simple +# python demo.py advanced + +model_name = "mistralai/Mistral-7B-Instruct-v0.3" +# or switch to "mistralai/Mistral-Nemo-Instruct-2407" +# or "mistralai/Mistral-Large-Instruct-2407" +# or any other mistral model with function calling ability + +sampling_params = SamplingParams(max_tokens=8192, temperature=0.0) +llm = LLM(model=model_name, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") + + +def generate_random_id(length=9): + characters = string.ascii_letters + string.digits + random_id = ''.join(random.choice(characters) for _ in range(length)) + return random_id + + +# simulate an API that can be called +def get_current_weather(city: str, state: str, unit: 'str'): + return (f"The weather in {city}, {state} is 85 degrees {unit}. It is " + "partly cloudly, with highs in the 90's.") + + +tool_funtions = {"get_current_weather": get_current_weather} + +tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] + +messages = [{ + "role": + "user", + "content": + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" 
+}] + +outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools) +output = outputs[0].outputs[0].text.strip() + +# append the assistant message +messages.append({ + "role": "assistant", + "content": output, +}) + +# let's now actually parse and execute the model's output simulating an API call by using the +# above defined function +tool_calls = json.loads(output) +tool_answers = [ + tool_funtions[call['name']](**call['arguments']) for call in tool_calls +] + +# append the answer as a tool message and let the LLM give you an answer +messages.append({ + "role": "tool", + "content": "\n\n".join(tool_answers), + "tool_call_id": generate_random_id(), +}) + +outputs = llm.chat(messages, sampling_params, tools=tools) + +print(outputs[0].outputs[0].text.strip()) +# yields +# 'The weather in Dallas, TX is 85 degrees fahrenheit. ' +# 'It is partly cloudly, with highs in the 90's.' diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index 738d890607e37..c12ff7021cf51 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -11,7 +11,7 @@ # - Server: # # ```bash -# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384 +# vllm serve mistralai/Pixtral-12B-2409 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384 # ``` # # - Client: @@ -45,6 +45,7 @@ def run_simple_demo(): model_name = "mistralai/Pixtral-12B-2409" sampling_params = SamplingParams(max_tokens=8192) + # Lower max_num_seqs or max_model_len on low-VRAM GPUs. llm = LLM(model=model_name, tokenizer_mode="mistral") prompt = "Describe this image in one sentence." @@ -83,7 +84,7 @@ def run_advanced_demo(): model=model_name, tokenizer_mode="mistral", limit_mm_per_prompt={"image": max_img_per_msg}, - max_num_batched_tokens=max_img_per_msg * max_tokens_per_img, + max_model_len=max_img_per_msg * max_tokens_per_img, ) prompt = "Describe the following image." diff --git a/examples/offline_inference_with_profiler.py b/examples/offline_inference_with_profiler.py index 906c9502800d8..1f00d26808771 100644 --- a/examples/offline_inference_with_profiler.py +++ b/examples/offline_inference_with_profiler.py @@ -16,7 +16,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1) llm.start_profile() diff --git a/format.sh b/format.sh index 2204b3ba59498..6563d89b192ea 100755 --- a/format.sh +++ b/format.sh @@ -159,7 +159,7 @@ echo 'vLLM codespell: Done' # Lint specified files lint() { - ruff "$@" + ruff check "$@" } # Lint files that differ from main branch. Ignores dirs that are not slated @@ -175,7 +175,7 @@ lint_changed() { if ! 
git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - ruff + ruff check fi } diff --git a/pyproject.toml b/pyproject.toml index 22a25d9cf32e6..14f0934499c46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,8 @@ ignore = [ "E731", # Loop control variable not used within loop body "B007", + # f-string format + "UP032", ] [tool.mypy] @@ -76,7 +78,7 @@ exclude = [ [tool.codespell] ignore-words-list = "dout, te, indicies, subtile" -skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" +skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" [tool.isort] use_parentheses = true @@ -85,5 +87,6 @@ skip_gitignore = true [tool.pytest.ini_options] markers = [ "skip_global_cleanup", - "vlm: run tests for vision language models only", + "core_model: run this model test in each PR instead of just daily", + "distributed_2_gpus: run this test only in distributed tests for 2 GPUs", ] diff --git a/requirements-common.txt b/requirements-common.txt index 3a9ae4aa77421..ad53395307ec5 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -7,11 +7,12 @@ py-cpuinfo transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfox. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. -fastapi +fastapi < 0.113.0; python_version < '3.9' +fastapi >= 0.114.1; python_version >= '3.9' aiohttp openai >= 1.40.0 # Ensure modern openai package (ensure types module present) uvicorn[standard] -pydantic >= 2.8 # Required for OpenAI server. +pydantic >= 2.9 # Required for fastapi >= 0.113.0 pillow # Required for image processing prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 @@ -23,9 +24,10 @@ filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 partial-json-parser # used for parsing partial JSON outputs pyzmq msgspec -gguf == 0.9.1 +gguf == 0.10.0 importlib_metadata mistral_common >= 1.4.0 pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 +setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. 
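The `[tool.pytest.ini_options]` change above replaces the old `vlm` marker with `core_model` ("run this model test in each PR instead of just daily") and `distributed_2_gpus` ("run this test only in distributed tests for 2 GPUs"). A rough sketch of how tests might opt into those markers is shown below; the model names, assertions, and marker usage are illustrative assumptions, not code from this diff.

```python
import pytest

from vllm import LLM, SamplingParams


# Hypothetical tests showing the markers declared in pyproject.toml.
@pytest.mark.core_model  # run on every PR rather than only in the daily job
def test_small_model_smoke():
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
    assert out[0].outputs[0].text


@pytest.mark.distributed_2_gpus  # picked up only by the 2-GPU distributed job
def test_tensor_parallel_smoke():
    llm = LLM(model="facebook/opt-125m",
              tensor_parallel_size=2,
              enforce_eager=True)
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    assert out[0].outputs[0].text
```

Selection would then happen on the pytest command line, e.g. `pytest -m core_model tests/models` or `pytest -m "not distributed_2_gpus"`; the exact CI invocation is not part of this diff.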
diff --git a/requirements-lint.txt b/requirements-lint.txt index d0b2fef6deaef..07f738873e1a8 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -2,7 +2,7 @@ yapf==0.32.0 toml==0.10.2 tomli==2.0.1 -ruff==0.1.5 +ruff==0.6.5 codespell==2.3.0 isort==5.13.2 clang-format==18.1.5 diff --git a/requirements-test.txt b/requirements-test.txt index ca3bfa7aff629..10d463de27be5 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -14,13 +14,14 @@ librosa # required for audio test opencv-python # required for video test peft requests -ray[adag]>=2.35 +ray[adag]==2.35 sentence-transformers # required for embedding soundfile # required for audio test compressed-tensors==0.4.0 # required for compressed-tensors timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test +datamodel_code_generator # required for minicpm3 test # TODO: Add this after fully implementing llava(mantis) # git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test diff --git a/requirements-xpu.txt b/requirements-xpu.txt index 48d899ec70eda..f07211b48b68d 100644 --- a/requirements-xpu.txt +++ b/requirements-xpu.txt @@ -3,9 +3,10 @@ setuptools < 70.0.0 # IPEX's torch have some dependency. to be removed. -torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl -intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl -oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl +torch == 2.3.1+cxx11.abi +intel-extension-for-pytorch == 2.3.110+xpu +oneccl_bind_pt == 2.3.100+xpu -triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +triton-xpu == 3.0.0b2 +--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ diff --git a/setup.py b/setup.py index 10770b8c9aa4a..7da9115440433 100644 --- a/setup.py +++ b/setup.py @@ -371,7 +371,9 @@ def get_vllm_version() -> str: cuda_version = str(get_nvcc_cuda_version()) if cuda_version != MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] - version += f"+cu{cuda_version_str}" + # skip this for source tarball, required for pypi + if "sdist" not in sys.argv: + version += f"+cu{cuda_version_str}" elif _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() @@ -462,6 +464,9 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) +if _is_hip(): + ext_modules.append(CMakeExtension(name="vllm._rocm_C")) + if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 3bf11fbcfb3b8..6cae76f74603d 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -1,7 +1,10 @@ import asyncio +import os +import uuid from asyncio import CancelledError +from copy import copy from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import pytest import pytest_asyncio @@ -11,6 +14,7 @@ from vllm.config import ParallelConfig from vllm.engine.async_llm_engine import AsyncEngineArgs, 
AsyncLLMEngine from vllm.outputs import RequestOutput as RealRequestOutput +from vllm.sampling_params import RequestOutputKind from ..conftest import cleanup from ..utils import wait_for_gpu_memory_to_clear @@ -22,6 +26,11 @@ class RequestOutput: finished: bool = False +@dataclass +class MockModelConfig: + use_async_output_proc = True + + class MockEngine: def __init__(self): @@ -31,6 +40,7 @@ def __init__(self): self.request_id = None # Ugly, remove dependency when possible self.parallel_config = ParallelConfig(1, 1, False) + self.model_config = MockModelConfig() async def step_async(self, virtual_engine): # PP size is 1, ignore virtual engine @@ -76,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): - engine = MockAsyncLLMEngine(worker_use_ray=False) + engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 @@ -109,7 +119,7 @@ async def test_new_requests_event(): assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 - engine = MockAsyncLLMEngine(worker_use_ray=True) + engine = MockAsyncLLMEngine() assert engine.get_model_config() is not None assert engine.get_tokenizer() is not None assert engine.get_decoding_config() is not None @@ -122,8 +132,17 @@ def start_engine(): timeout_s=60, ) + num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1")) + print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}") + return AsyncLLMEngine.from_engine_args( - AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True)) + AsyncEngineArgs(model="facebook/opt-125m", + enforce_eager=True, + num_scheduler_steps=num_scheduler_steps)) + + +def uid() -> str: + return str(uuid.uuid4()) @pytest_asyncio.fixture(scope="module") @@ -146,59 +165,195 @@ def should_do_global_cleanup_after_test(request) -> bool: @pytest.mark.asyncio(scope="module") -async def test_asyncio_run(async_engine): +@pytest.mark.parametrize("stop", [None, ["a stop string"]]) +async def test_asyncio_run(async_engine, stop): + + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps async def run(prompt: str): sampling_params = SamplingParams( temperature=0, max_tokens=32, + min_tokens=32, + stop=stop, ) + output_count = 0 + final_output = None async for output in async_engine.generate(prompt, sampling_params, - request_id=prompt): + request_id=uid()): + output_count += 1 final_output = output - return final_output + return final_output, output_count results = await asyncio.gather( run("test0"), - run("test1"), + run("test0"), ) assert len(results) == 2 + first, second = results + + # remove nondeterministic fields for comparison + first[0].metrics = None + second[0].metrics = None + first[0].request_id = None + second[0].request_id = None + + assert str(first) == str(second) + + output_count = results[0][1] + if num_scheduler_steps == 1: + assert output_count == 32 + else: + assert 1 < output_count < 32 @pytest.mark.asyncio(scope="module") -async def test_cancellation(async_engine): +@pytest.mark.parametrize("stop", [None, ["a stop string"]]) +async def test_output_kinds(async_engine, stop): + """Test that output_kind works as expected and that + results are equivalent across different kinds.""" + + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + sampling_params = SamplingParams( temperature=0, - 
min_tokens=10, - max_tokens=10, + max_tokens=32, + min_tokens=32, + stop=stop, + ) + + async def run(prompt: str, kind: RequestOutputKind): + params = copy(sampling_params) + params.output_kind = kind + + output_count = 0 + final_output = None + async for output in async_engine.generate(prompt, + params, + request_id=uid()): + output_count += 1 + final_output = output + + assert final_output is not None + assert final_output.finished + + return (final_output.prompt_token_ids, + final_output.outputs[0].token_ids, + final_output.outputs[0].text, output_count) + + async def run_deltas(prompt: str): + params = copy(sampling_params) + params.output_kind = RequestOutputKind.DELTA + + prompt_tokens = None + output_tokens: List[int] = [] + output_text = "" + output_count = 0 + final_output = None + async for output in async_engine.generate(prompt, + params, + request_id=uid()): + token_ids = output.outputs[0].token_ids + text = output.outputs[0].text + final_output = output + + # Ensure we get prompt ids iff we haven't yet received output tokens + if output_tokens: + assert 1 <= len(token_ids) <= num_scheduler_steps + assert stop or text + assert not output.prompt_token_ids + else: + assert output.prompt_token_ids + prompt_tokens = output.prompt_token_ids + + output_tokens.extend(token_ids) + output_text += text + + output_count += 1 + + assert final_output is not None + assert final_output.finished + + return prompt_tokens, output_tokens, output_text, output_count + + results = await asyncio.gather( + run("common input prompt", RequestOutputKind.CUMULATIVE), + run("common input prompt", RequestOutputKind.FINAL_ONLY), + run_deltas("common input prompt")) + + # Make sure outputs are the same + prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results) + assert len(prompt_set) == 1 + + text_set = set(text for _, _, text, _ in results) + assert len(text_set) == 1 + + tokens_set = set(tuple(ids) for _, ids, _, _ in results) + assert len(tokens_set) == 1 + + cumulative, final, deltas = results + + # output message counts + assert cumulative[3] == deltas[3] + + if num_scheduler_steps == 1: + assert cumulative[3] == 32 + else: + assert 1 < cumulative[3] < 32 + + assert final[3] == 1 + + +@pytest.mark.asyncio(scope="module") +@pytest.mark.parametrize("stop", [None, ["a stop string"]]) +async def test_cancellation(async_engine, stop): + scheduler_config = await async_engine.get_scheduler_config() + num_scheduler_steps = scheduler_config.num_scheduler_steps + + sampling_params = SamplingParams( + temperature=0, + min_tokens=13, + max_tokens=13, + stop=stop, ) + stop_at = 5 if num_scheduler_steps == 1 else 1 + + request_id = uid() + i = 0 with pytest.raises(CancelledError): async for output in async_engine.generate("test2", sampling_params, - request_id="test2"): + request_id=request_id): assert not output.finished i += 1 - if i == 5: - await async_engine.abort("test2") + if i == stop_at: + await async_engine.abort(request_id) - assert i == 5 + assert i == stop_at @pytest.mark.asyncio(scope="module") -async def test_delayed_generator(async_engine): +@pytest.mark.parametrize("stop", [None, ["a stop string"]]) +async def test_delayed_generator(async_engine, stop): + scheduler_config = await async_engine.get_scheduler_config() + + if scheduler_config.num_scheduler_steps != 1: + pytest.skip("no need to test this one with multistep") + sampling_params = SamplingParams( temperature=0, min_tokens=10, max_tokens=10, + stop=stop, ) - stream = async_engine.generate("test3", - sampling_params, - 
request_id="test3") + stream = async_engine.generate("test3", sampling_params, request_id=uid()) i = 0 final_output: Optional[RealRequestOutput] = None async for output in stream: diff --git a/tests/async_engine/test_openapi_server.py b/tests/async_engine/test_openapi_server.py deleted file mode 100644 index 9e5c7c04287eb..0000000000000 --- a/tests/async_engine/test_openapi_server.py +++ /dev/null @@ -1,106 +0,0 @@ -import openai # use the official client for correctness check -import pytest -import pytest_asyncio - -from ..utils import VLLM_PATH, RemoteOpenAIServer - -# any model with a chat template should work here -MODEL_NAME = "facebook/opt-125m" -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() - - -@pytest.fixture(scope="module") -def server(): - args = [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--max-model-len", - "2048", - "--enforce-eager", - "--chat-template", - str(chatml_jinja_path), - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def client(server): - async with server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -async def test_check_models(client: openai.AsyncOpenAI): - models = await client.models.list() - models = models.data - served_model = models[0] - assert served_model.id == MODEL_NAME - assert all(model.root == MODEL_NAME for model in models) - - -@pytest.mark.asyncio -async def test_single_completion(client: openai.AsyncOpenAI): - completion = await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert len(completion.choices) == 1 - assert len(completion.choices[0].text) >= 5 - assert completion.choices[0].finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) - - # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 5 - - -@pytest.mark.asyncio -async def test_single_chat_session(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] - - # test single completion - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10, - logprobs=True, - top_logprobs=5) - assert chat_completion.id is not None - assert len(chat_completion.choices) == 1 - - choice = chat_completion.choices[0] - assert choice.finish_reason == "length" - assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=55, total_tokens=65) - - message = choice.message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" - messages.append({"role": "assistant", "content": message.content}) - - # test multi-turn dialogue - messages.append({"role": "user", "content": "express your result in json"}) - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=10, - ) - message = chat_completion.choices[0].message - assert message.content is not None and len(message.content) >= 0 diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index b970cd48f9170..0fe88e792520a 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -15,12 +15,15 @@ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from ..models.utils import check_outputs_equal +from ..utils import multi_gpu_test MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] +TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") + def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" @@ -70,6 +73,65 @@ def test_models( ) +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model, distributed_executor_backend, attention_backend, " + "test_suite", [ + ("facebook/opt-125m", "ray", "", "L4"), + ("facebook/opt-125m", "mp", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"), + ("facebook/opt-125m", "ray", "", "A100"), + ("facebook/opt-125m", "mp", "", "A100"), + ("facebook/opt-125m", "mp", "FLASHINFER", "A100"), + ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), + ]) +def test_models_distributed( + hf_runner, + vllm_runner, + example_prompts, + model: str, + distributed_executor_backend: str, + attention_backend: str, + test_suite: str, +) -> None: + + if test_suite != TARGET_TEST_SUITE: + pytest.skip(f"Skip test for {test_suite}") + + if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa + # test ray adag + os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" + os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" + + if attention_backend: + os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend + + dtype = "half" + max_tokens = 5 + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + def test_model_with_failure(vllm_runner) -> None: try: with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward", diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 9c34b2a13fd53..14c5447680729 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -6,11 +6,13 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ +import os from contextlib import nullcontext import pytest from ..models.utils import check_logprobs_close, check_outputs_equal +from ..utils import multi_gpu_test MODELS = [ "facebook/opt-125m", @@ -66,6 +68,59 @@ def test_models( ) +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", MODELS) +def test_models_distributed( + hf_runner, + vllm_runner, + example_prompts, + model: str, + distributed_executor_backend: str, +) -> None: + if (model == "meta-llama/Llama-2-7b-hf" + and distributed_executor_backend == "ray"): + # test ray adag + os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" + os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" + + dtype = "half" + max_tokens = 5 + chunked_prefill_token_size = 16 + + # Add a chunked prefill config. + max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + with vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + @pytest.mark.parametrize( "kv_cache_dtype,model", [("fp8_e4m3", diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 7e77037da07d3..00806c3e129b1 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -19,10 +19,13 @@ "facebook/opt-125m", ] -assert ENABLE_ARTIFICIAL_PREEMPT is True, ( - "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. " - "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " - "tests/basic_correctness/test_preemption.py`") + +@pytest.fixture(scope="module", autouse=True) +def check_settings(): + assert ENABLE_ARTIFICIAL_PREEMPT is True, ( + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. 
" + "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " + "tests/basic_correctness/test_preemption.py`") @pytest.fixture @@ -64,6 +67,7 @@ def test_chunked_prefill_recompute( enable_chunked_prefill=enable_chunked_prefill, max_num_seqs=max_num_seqs, worker_use_ray=worker_use_ray, + disable_log_stats=False, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt diff --git a/tests/entrypoints/openai/rpc/__init__.py b/tests/compile/__init__.py similarity index 100% rename from tests/entrypoints/openai/rpc/__init__.py rename to tests/compile/__init__.py diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 0a6e781e18834..2e309aaa58d48 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -2,11 +2,23 @@ import pytest +from vllm.utils import cuda_device_count_stateless + +from ..utils import fork_new_process_for_each_test + @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -def test_full_graph(model): +@pytest.mark.parametrize("tp_size", [1, 2]) +@fork_new_process_for_each_test +def test_full_graph(model, tp_size): + + # Skip the test if there are not enough CUDA devices. + if cuda_device_count_stateless() < tp_size: + pytest.skip("Not enough CUDA devices for the test.") + # make sure these models can be captured in full graph mode - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" + if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: + os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" from vllm import LLM, SamplingParams prompts = [ @@ -16,7 +28,15 @@ def test_full_graph(model): "The future of AI is", ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model="meta-llama/Meta-Llama-3-8B", + llm = LLM(model=model, enforce_eager=True, - load_format="dummy") - llm.generate(prompts, sampling_params) + tensor_parallel_size=tp_size, + disable_custom_all_reduce=True) + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/conftest.py b/tests/conftest.py index c850e60a9ca6c..c2616bcf7091c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,8 @@ import tempfile from collections import UserList from enum import Enum -from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict, - TypeVar, Union) +from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, + TypedDict, TypeVar, Union) import numpy as np import pytest @@ -18,7 +18,10 @@ from PIL import Image from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature) +from transformers.models.auto.auto_factory import _BaseAutoModelClass +from tests.models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs) from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset @@ -32,7 +35,6 @@ to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput -from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, identity, is_cpu) @@ -157,10 +159,7 @@ def should_do_global_cleanup_after_test(request) -> bool: to initialize torch. 
""" - if request.node.get_closest_marker("skip_global_cleanup"): - return False - - return True + return not request.node.get_closest_marker("skip_global_cleanup") @pytest.fixture(autouse=True) @@ -260,7 +259,7 @@ def __init__( *, model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, - auto_cls=AutoModelForCausalLM, + auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding] = identity, ) -> None: @@ -292,20 +291,14 @@ def __init__( trust_remote_code=True, ) - try: - # don't put this import at the top level - # it will call torch.cuda.device_count() - from transformers import AutoProcessor # noqa: F401 - self.processor = AutoProcessor.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ) - except Exception as exc: - logger.warning( - "Unable to auto-load HuggingFace processor for model (%s). " - "Using tokenizer instead. Reason: %s", model_name, exc) - self.processor = self.tokenizer + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) self.postprocess_inputs = postprocess_inputs @@ -477,7 +470,7 @@ def generate_greedy_logprobs_limit( audios: Optional[PromptAudioInput] = None, videos: Optional[List[np.ndarray]] = None, **kwargs: Any, - ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ) -> List[TokensTextLogprobs]: all_logprobs: List[List[Dict[int, float]]] = [] all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] @@ -533,7 +526,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( max_tokens: int, num_logprobs: int, **kwargs: Any, - ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ) -> List[TokensTextLogprobs]: ''' Greedy logprobs generation for vLLM encoder/decoder models ''' @@ -658,17 +651,19 @@ def generate( outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs + @staticmethod def _final_steps_generate_w_logprobs( - self, req_outputs: List[RequestOutput], - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + ) -> List[TokensTextLogprobsPromptLogprobs]: + outputs: List[TokensTextLogprobsPromptLogprobs] = [] for req_output in req_outputs: + assert len(req_output.outputs) > 0 for sample in req_output.outputs: output_str = sample.text output_ids = list(sample.token_ids) output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs)) + outputs.append((output_ids, output_str, output_logprobs, + req_output.prompt_logprobs)) return outputs def generate_w_logprobs( @@ -678,7 +673,8 @@ def generate_w_logprobs( images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: assert sampling_params.logprobs is not None if images is not None: @@ -703,13 +699,20 @@ def generate_w_logprobs( req_outputs = self.model.generate(inputs, sampling_params=sampling_params) - return self._final_steps_generate_w_logprobs(req_outputs) + + toks_str_logsprobs_prompt_logprobs = ( + self._final_steps_generate_w_logprobs(req_outputs)) + # Omit prompt logprobs if not required by sampling params + return ([x[0:-1] 
for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None else + toks_str_logsprobs_prompt_logprobs) def generate_encoder_decoder_w_logprobs( self, encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], sampling_params: SamplingParams, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: ''' Logprobs generation for vLLM encoder/decoder models ''' @@ -717,7 +720,12 @@ def generate_encoder_decoder_w_logprobs( assert sampling_params.logprobs is not None req_outputs = self.model.generate(encoder_decoder_prompts, sampling_params=sampling_params) - return self._final_steps_generate_w_logprobs(req_outputs) + toks_str_logsprobs_prompt_logprobs = ( + self._final_steps_generate_w_logprobs(req_outputs)) + # Omit prompt logprobs if not required by sampling params + return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None else + toks_str_logsprobs_prompt_logprobs) def generate_greedy( self, @@ -735,44 +743,48 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, + num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, stop_token_ids: Optional[List[int]] = None, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens, - logprobs=num_logprobs, - stop_token_ids=stop_token_ids) - outputs = self.generate_w_logprobs(prompts, - greedy_logprobs_params, - images=images, - audios=audios, - videos=videos) - - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), + stop_token_ids=stop_token_ids) + + return self.generate_w_logprobs(prompts, + greedy_logprobs_params, + images=images, + audios=audios, + videos=videos) def generate_encoder_decoder_greedy_logprobs( self, encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, - ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: - greedy_logprobs_params = SamplingParams(temperature=0.0, - use_beam_search=False, - max_tokens=max_tokens, - logprobs=num_logprobs) + num_prompt_logprobs: Optional[int] = None, + ) -> Union[List[TokensTextLogprobs], + List[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + use_beam_search=False, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=(num_prompt_logprobs), + ) ''' Greedy logprobs generation for vLLM encoder/decoder models ''' - outputs = self.generate_encoder_decoder_w_logprobs( + return self.generate_encoder_decoder_w_logprobs( encoder_decoder_prompts, greedy_logprobs_params) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs] - def generate_beam_search( self, prompts: List[str], diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py deleted file mode 100644 index e254686f269b1..0000000000000 --- a/tests/distributed/test_basic_distributed_correctness.py +++ /dev/null @@ 
-1,80 +0,0 @@ -"""Compare the outputs of HF and distributed vLLM when using greedy sampling. - -Run: -```sh -cd $VLLM_PATH/tests - -pytest distributed/test_basic_distributed_correctness.py -``` -""" -import os - -import pytest - -from vllm.utils import cuda_device_count_stateless - -from ..models.utils import check_outputs_equal -from ..utils import fork_new_process_for_each_test - -TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") - - -@pytest.mark.skipif(cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize( - "model, distributed_executor_backend, attention_backend, " - "test_suite", [ - ("facebook/opt-125m", "ray", "", "L4"), - ("facebook/opt-125m", "mp", "", "L4"), - ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), - ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"), - ("facebook/opt-125m", "ray", "", "A100"), - ("facebook/opt-125m", "mp", "", "A100"), - ("facebook/opt-125m", "mp", "FLASHINFER", "A100"), - ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), - ]) -@fork_new_process_for_each_test -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - distributed_executor_backend: str, - attention_backend: str, - test_suite: str, -) -> None: - - if test_suite != TARGET_TEST_SUITE: - pytest.skip(f"Skip test for {test_suite}") - - if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa - # test ray adag - os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" - os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" - - if attention_backend: - os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend - - dtype = "half" - max_tokens = 5 - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/distributed/test_basic_distributed_correctness_enc_dec.py b/tests/distributed/test_basic_distributed_correctness_enc_dec.py deleted file mode 100644 index f00d5ef584a2a..0000000000000 --- a/tests/distributed/test_basic_distributed_correctness_enc_dec.py +++ /dev/null @@ -1,102 +0,0 @@ -"""For encoder/decoder models only: -Compare the outputs of HF and distributed vLLM when using greedy sampling. 
- -Run: -```sh -cd $VLLM_PATH/tests - -pytest distributed/test_basic_distributed_correctness_enc_dec.py -``` -""" - -import pytest -from transformers import AutoModelForSeq2SeqLM - -from vllm.utils import cuda_device_count_stateless - -from ..conftest import DecoderPromptType -from ..models.utils import check_logprobs_close -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.skipif(cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model, distributed_executor_backend", [ - ("facebook/bart-large-cnn", "ray"), - ("facebook/bart-large-cnn", "mp"), -]) -@fork_new_process_for_each_test -def test_models( - model: str, - distributed_executor_backend: str, - hf_runner, - vllm_runner, - example_encoder_decoder_prompts, -) -> None: - ''' - Test vLLM BART inference on more than one GPU, comparing - outputs against HF as a baseline. - - Fork a new process for each test, to prevent CUDA from - being re-initialized by successive tests within the same - process. - - Arguments: - - * model: the HF ID of the specific BART variant under test - * distributed_executor_backend - * hf_runner: HuggingFace (HF) test model runner - * vllm_runner: vLLM test model runner - * example_encoder_decoder_prompts: test fixture which provides a - dictionary of dummy prompts - ''' - - dtype = "float" - max_tokens = 64 - num_logprobs = 5 - - # Example inputs with non-trivial (i.e. not None/empty) encoder & - # decoder prompts. - test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM] - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, - ) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_prompts, max_tokens, num_logprobs) - - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py deleted file mode 100644 index 262845f19822f..0000000000000 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Compare the outputs of HF and distributed vLLM when using greedy sampling. 
- -Run: -```sh -pytest test_chunked_prefill_distributed.py -``` -""" - -import os - -import pytest - -from vllm.utils import cuda_device_count_stateless - -from ..models.utils import check_outputs_equal -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.skipif(cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model, distributed_executor_backend", [ - ("facebook/opt-125m", "ray"), - ("meta-llama/Llama-2-7b-hf", "ray"), - ("facebook/opt-125m", "mp"), - ("meta-llama/Llama-2-7b-hf", "mp"), -]) -@fork_new_process_for_each_test -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - distributed_executor_backend: str, -) -> None: - if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray": # noqa - assert distributed_executor_backend == "ray" - # test ray adag - os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" - os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" - - dtype = "half" - max_tokens = 5 - chunked_prefill_token_size = 16 - - # Add a chunked prefill config. - max_num_seqs = min(chunked_prefill_token_size, 256) - assert chunked_prefill_token_size != -1 - enable_chunked_prefill = True - max_num_batched_tokens = chunked_prefill_token_size - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - - with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - max_num_seqs=max_num_seqs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py deleted file mode 100644 index 73ef863c2f193..0000000000000 --- a/tests/distributed/test_multimodal_broadcast.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Compare the outputs of HF and distributed vLLM when using greedy sampling. 
- -Run: -```sh -pytest -s -v test_multimodal_broadcast.py -``` -""" - -import pytest - -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.skipif(cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model, distributed_executor_backend", [ - ("llava-hf/llava-1.5-7b-hf", "ray"), - ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"), - ("facebook/chameleon-7b", "ray"), - ("llava-hf/llava-1.5-7b-hf", "mp"), - ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"), - ("facebook/chameleon-7b", "mp"), -]) -@fork_new_process_for_each_test -def test_models(hf_runner, vllm_runner, image_assets, model: str, - distributed_executor_backend: str) -> None: - - dtype = "half" - max_tokens = 5 - num_logprobs = 5 - tensor_parallel_size = 2 - - if model.startswith("llava-hf/llava-1.5"): - from ..models.test_llava import models, run_test - elif model.startswith("llava-hf/llava-v1.6"): - from ..models.test_llava_next import run_test # type: ignore[no-redef] - from ..models.test_llava_next import models - elif model.startswith("facebook/chameleon"): - from ..models.test_chameleon import run_test # type: ignore[no-redef] - from ..models.test_chameleon import models - else: - raise NotImplementedError(f"Unsupported model: {model}") - - run_test( - hf_runner, - vllm_runner, - image_assets, - model=models[0], - # So that LLaVA-NeXT processor may return nested list - size_factors=[0.25, 0.5, 1.0], - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - ) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index d2219eed988e1..02288dc9dac90 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -32,9 +32,11 @@ (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "ray"), - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "ray"), - (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "ray"), + # NOTE: InternVL2 multi-node tests are flaky, + # use mp backend to skip the multi-node tests + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"), + (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"), + (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"), ], ) @fork_new_process_for_each_test diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py index 07e84d0ad54cd..defc4e23c8ce2 100644 --- a/tests/distributed/test_same_node.py +++ b/tests/distributed/test_same_node.py @@ -1,13 +1,13 @@ import os -import torch +import torch.distributed as dist from vllm.distributed.parallel_state import in_the_same_node_as -torch.distributed.init_process_group(backend="gloo") -test_result = all( - in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0)) +if __name__ == "__main__": + dist.init_process_group(backend="gloo") + test_result = all(in_the_same_node_as(dist.group.WORLD, source_rank=0)) -expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" -assert test_result == expected, f"Expected {expected}, got {test_result}" -print("Same node test passed!") + expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" + assert test_result == expected, f"Expected {expected}, got {test_result}" + print("Same node test passed!") diff --git 
a/vllm/model_executor/layers/ops/__init__.py b/tests/encoder_decoder/__init__.py similarity index 100% rename from vllm/model_executor/layers/ops/__init__.py rename to tests/encoder_decoder/__init__.py diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py new file mode 100644 index 0000000000000..9324a737a779c --- /dev/null +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -0,0 +1,98 @@ +"""E2E tests to verify the correctness of the encoder-decoder framework + +Run `pytest tests/encoder_decoder/test_e2e_correctness.py`. +""" +from typing import List, Optional, Tuple + +import pytest +from transformers import AutoModelForSeq2SeqLM + +from vllm.sequence import SampleLogprobs +from vllm.utils import is_cpu + +from ..conftest import DecoderPromptType +from ..models.utils import check_logprobs_close + + +def vllm_to_hf_output( + vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], + decoder_prompt_type: DecoderPromptType, +): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "" + if decoder_prompt_type == DecoderPromptType.NONE: + hf_output_str = "" + hf_output_str + + return output_ids, hf_output_str, out_logprobs + + +@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.skipif( + is_cpu(), + reason="CPU backend is not currently supported with encoder/decoder models" +) +def test_encoder_decoder_e2e( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + decoder_prompt_type: DecoderPromptType, + enforce_eager: bool, +) -> None: + ''' + End-to-End (E2E) test for the encoder-decoder framework. + This test evaluates the encoder-decoder functionality using the BART + model. We compare the outputs of the Hugging Face and vLLM + implementations to ensure that both implementations produce consistent + and correct results. 
+ ''' + test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + with vllm_runner(model, dtype=dtype, + enforce_eager=enforce_eager) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_case_prompts, max_tokens, num_logprobs) + + hf_skip_tokens = (1 + if decoder_prompt_type == DecoderPromptType.NONE else 0) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 3208d6bb48bdc..8dd200b35d0f3 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -1,6 +1,8 @@ +from argparse import ArgumentTypeError + import pytest -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import EngineArgs, nullable_kvs from vllm.utils import FlexibleArgumentParser @@ -13,6 +15,10 @@ "image": 16, "video": 2 }), + ("Image=16, Video=2", { + "image": 16, + "video": 2 + }), ]) def test_limit_mm_per_prompt_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) @@ -22,3 +28,15 @@ def test_limit_mm_per_prompt_parser(arg, expected): args = parser.parse_args(["--limit-mm-per-prompt", arg]) assert args.limit_mm_per_prompt == expected + + +@pytest.mark.parametrize( + ("arg"), + [ + "image", # Missing = + "image=4,image=5", # Conflicting values + "image=video=4" # Too many = in tokenized arg + ]) +def test_bad_nullable_kvs(arg): + with pytest.raises(ArgumentTypeError): + nullable_kvs(arg) diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index 338b208723ba9..b8818af5614cf 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -11,9 +11,10 @@ def test_skip_tokenizer_initialization(model: str): # token ids. 
llm = LLM(model=model, skip_tokenizer_init=True) sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) - with pytest.raises(ValueError) as err: + + with pytest.raises(ValueError, match="cannot pass text prompts when"): llm.generate("abc", sampling_params) - assert "prompts must be None if" in str(err.value) + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params) assert len(outputs) > 0 diff --git a/tests/entrypoints/openai/rpc/test_zmq_client.py b/tests/entrypoints/openai/rpc/test_zmq_client.py deleted file mode 100644 index cafd125c5a598..0000000000000 --- a/tests/entrypoints/openai/rpc/test_zmq_client.py +++ /dev/null @@ -1,120 +0,0 @@ -import asyncio -import tempfile -import unittest -import unittest.mock -import uuid - -import pytest -import pytest_asyncio - -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient, - RPCClientClosedError) -from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer - - -@pytest.fixture(scope="function") -def tmp_socket(): - with tempfile.TemporaryDirectory() as td: - yield f"ipc://{td}/{uuid.uuid4()}" - - -@pytest_asyncio.fixture(scope="function") -async def dummy_server(tmp_socket, monkeypatch): - dummy_engine = unittest.mock.AsyncMock() - - def dummy_engine_builder(*args, **kwargs): - return dummy_engine - - with monkeypatch.context() as m: - m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder) - server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket) - - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - try: - yield server - finally: - server_task.cancel() - server.cleanup() - - -@pytest_asyncio.fixture(scope="function") -async def client(tmp_socket): - client = AsyncEngineRPCClient(rpc_path=tmp_socket) - # Sanity check: the server is connected - await client._wait_for_server_rpc() - - try: - yield client - finally: - client.close() - - -@pytest.mark.asyncio -async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server _not_ reply with a model config - m.setattr(dummy_server, "get_config", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # And ensure the task completes anyway - # (client.setup() invokes server.get_config()) - client_task = asyncio.get_running_loop().create_task(client.setup()) - with pytest.raises(TimeoutError, match="Server didn't reply within"): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_aborts_use_timeouts(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Hang all abort requests - m.setattr(dummy_server, "abort", lambda x: None) - m.setattr(client, "_data_timeout", 10) - - # The client should suppress timeouts on `abort`s - # and return normally, assuming the server will eventually - # abort the request. 
- client_task = asyncio.get_running_loop().create_task( - client.abort("test request id")) - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_data_methods_reraise_exceptions( - monkeypatch, dummy_server, client: AsyncEngineRPCClient): - with monkeypatch.context() as m: - # Make the server raise some random exception - exception = RuntimeError("Client test exception") - - def raiser(): - raise exception - - m.setattr(dummy_server.engine, "get_model_config", raiser) - m.setattr(client, "_data_timeout", 10) - - client_task = asyncio.get_running_loop().create_task(client.setup()) - # And ensure the task completes, raising the exception - with pytest.raises(RuntimeError, match=str(exception)): - await asyncio.wait_for(client_task, timeout=0.05) - - -@pytest.mark.asyncio -async def test_client_errors_after_closing(monkeypatch, dummy_server, - client: AsyncEngineRPCClient): - - client.close() - - # Healthchecks and generate requests will fail with explicit errors - with pytest.raises(RPCClientClosedError): - await client.check_health() - with pytest.raises(RPCClientClosedError): - async for _ in client.generate(None, None, None): - pass - - # But no-ops like aborting will pass - await client.abort("test-request-id") - await client.do_log_stats() diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index b442a903c33ae..2ad8460023c25 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -18,38 +18,32 @@ FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 +DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] +MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] -@pytest.fixture(scope="module") -def server(): - args = [ - "--max-model-len", "4096", "--enable-chunked-prefill", - "--disable-log-requests", "--enforce-eager" - ] - - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def server_data(server): - return { - "url": f"{server.url_for('v1')}/completions", - } +@pytest.mark.parametrize("more_args", MORE_ARGS_LIST) +def test_lm_eval_accuracy(more_args): + args = list(DEFAULT_ARGS) + args.extend(more_args) + print(f"Running with: {args}") -def test_lm_eval_accuracy(server_data): - model_args = (f"model={MODEL_NAME}," - f"base_url={server_data['url']}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") - - results = lm_eval.simple_evaluate( - model="local-completions", - model_args=model_args, - tasks=TASK, - ) - - measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + url = f"{remote_server.url_for('v1')}/completions" + + model_args = ( + f"model={MODEL_NAME}," + f"base_url={url}," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + + results = lm_eval.simple_evaluate( + model="local-completions", + model_args=model_args, + tasks=TASK, + ) + + measured_value = results["results"][TASK][FILTER] + assert (measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/async_engine/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py similarity index 99% rename from 
tests/async_engine/test_chat_template.py rename to tests/entrypoints/openai/test_chat_template.py index 61a6d77cd8756..b98ab2e30d78d 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -5,7 +5,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer -from ..utils import VLLM_PATH +from ...utils import VLLM_PATH chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" assert chatml_jinja_path.exists() diff --git a/tests/entrypoints/openai/test_mp_api_server.py b/tests/entrypoints/openai/test_mp_api_server.py deleted file mode 100644 index fbfe0db19dd03..0000000000000 --- a/tests/entrypoints/openai/test_mp_api_server.py +++ /dev/null @@ -1,40 +0,0 @@ -import time - -import pytest - -from vllm.entrypoints.openai.api_server import build_async_engine_client -from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.utils import FlexibleArgumentParser - - -@pytest.mark.asyncio -async def test_mp_crash_detection(): - - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - # use an invalid tensor_parallel_size to trigger the - # error in the server - args.tensor_parallel_size = 65536 - - start = time.perf_counter() - async with build_async_engine_client(args): - pass - end = time.perf_counter() - - assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " - "if there is an error in the startup.") - - -@pytest.mark.asyncio -async def test_mp_cuda_init(): - # it should not crash, when cuda is initialized - # in the API server process - import torch - torch.cuda.init() - parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") - parser = make_arg_parser(parser) - args = parser.parse_args([]) - - async with build_async_engine_client(args): - pass diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index c3a6c65be1d90..de2a932199a01 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock from vllm.config import MultiModalConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.transformers_utils.tokenizer import get_tokenizer @@ -52,8 +52,9 @@ def test_async_serving_chat_init(): def test_serving_chat_should_set_correct_max_tokens(): - mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False serving_chat = OpenAIServingChat(mock_engine, MockModelConfig(), diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index 325bc03434287..6d9e620b4af7d 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -4,7 +4,7 @@ import pytest from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) @@ -18,7 +18,7 @@ async def 
_async_serving_engine_init(): - mock_engine_client = MagicMock(spec=AsyncEngineClient) + mock_engine_client = MagicMock(spec=EngineClient) mock_model_config = MagicMock(spec=ModelConfig) # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 73ecb74007272..25ab91ef69333 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path): prompt="Hello, my name is") # Now the server should shut down - return_code = remote_server.proc.wait(timeout=3) + return_code = remote_server.proc.wait(timeout=8) assert return_code is not None diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index ed050ce851535..9b476585fa19e 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, NewGELU, QuickGELU, SiluAndMul) +from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -34,9 +35,7 @@ def test_act_and_mul( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, 2 * d, dtype=dtype) if activation == "silu": @@ -77,9 +76,7 @@ def test_activation( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) x = torch.randn(num_tokens, d, dtype=dtype) layer = activation[0]() diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 7995f11f19e98..4bd6f7863a658 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -3,15 +3,17 @@ import pytest import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from tests.kernels.utils import opcheck from vllm import _custom_ops as ops -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from .allclose_default import get_default_atol, get_default_rtol +if not is_hip(): + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. 
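The seeding hunks above and in the kernel-test files that follow replace the hand-rolled `random.seed(seed)` / `torch.random.manual_seed(seed)` / `torch.cuda.manual_seed(seed)` boilerplate with a single call to `vllm.utils.seed_everything`. A minimal sketch of what such a helper is assumed to do (the actual vLLM implementation may differ in detail, for example it may also seed NumPy):

    import random

    import torch


    def seed_everything(seed: int) -> None:
        # Seed the Python RNG, the torch CPU RNG, and all CUDA devices when
        # CUDA is present; these are the same calls the individual tests made.
        random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

Each test body then reduces to `seed_everything(seed)` before building its tensors, as the hunks show.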
# - 512 as a buffer @@ -137,10 +139,8 @@ def test_paged_attention( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -328,6 +328,162 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) +@pytest.mark.parametrize("version", ["rocm"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128 +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", ["auto"]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(not is_hip(), reason="only for rocm") +def test_paged_attention_rocm( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + seed_everything(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + context_lens[-1] = MAX_SEQ_LEN + #context_lens = [8192 for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int) + #print('>>> ctx lens', context_lens) + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # TODO(charlifu) enable fp8 kv cache + # Using default kv_scale + # kv_scale = 1.0 + + # Call the paged attention kernel. 
+ output = torch.empty_like(query) + PARTITION_SIZE_ROCM = 256 + num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) // + PARTITION_SIZE_ROCM) + assert PARTITION_SIZE_ROCM % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + if version == "rocm": + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + ) + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(key_cache, dequantized_key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(value_cache, dequantized_value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if is_hip() else 1e-3 + rtol = get_default_rtol(output) if is_hip() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-4, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 2e-4, 1e-5 + if use_alibi: + if dtype == torch.half: + atol, rtol = 5e-4, 1e-5 + if dtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + + # TODO(woosuk): Add tests for USE_ALIBI=True. @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -335,6 +491,7 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(is_hip(), reason="skip for rocm") @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -344,10 +501,7 @@ def test_multi_query_kv_attention( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. 
# As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index a20a741c27f74..c1fb45955a0e5 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -45,7 +45,7 @@ def test_flash_attn(monkeypatch): override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + with patch("torch.cuda.get_device_capability", return_value=(7, 5)): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != STR_FLASH_ATTN_VAL diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py index 198d40a155ccb..e95e5bd948212 100644 --- a/tests/kernels/test_awq_triton.py +++ b/tests/kernels/test_awq_triton.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.quantization.awq_triton import ( AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) +from vllm.utils import seed_everything device = "cuda" @@ -79,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): zeros_cols = qweight_cols zeros_dtype = torch.int32 - torch.manual_seed(0) + seed_everything(0) qweight = torch.randint(0, torch.iinfo(torch.int32).max, @@ -133,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size): qzeros_rows = scales_rows qzeros_cols = qweight_cols - torch.manual_seed(0) + seed_everything(0) input = torch.rand((input_rows, input_cols), dtype=input_dtype, diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index 7357508751ae1..f3bd8f0524264 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn) -from vllm.utils import get_max_shared_memory_bytes, is_hip +from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -172,10 +172,7 @@ def test_paged_attention( blocksparse_block_size: int, blocksparse_head_sliding_step: int, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads @@ -386,10 +383,7 @@ def test_varlen_blocksparse_attention_prefill( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. 
# As the xformers library is already tested with its own tests, we can use diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 19402a337b8d6..b0e7097fdfbd4 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -6,6 +6,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops +from vllm.utils import seed_everything COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -55,10 +56,7 @@ def test_copy_blocks( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Generate random block mappings where each source block is mapped to two # destination blocks. @@ -134,10 +132,7 @@ def test_reshape_and_cache( ) -> None: if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks @@ -229,9 +224,7 @@ def test_reshape_and_cache_flash( device: str, kv_cache_dtype: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) # Create a random slot mapping. @@ -345,10 +338,8 @@ def test_swap_blocks( pytest.skip() if kv_cache_dtype == "fp8" and head_size % 16: pytest.skip() - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) src_device = device if direction[0] == "cuda" else 'cpu' dst_device = device if direction[1] == "cuda" else 'cpu' @@ -417,9 +408,7 @@ def test_fp8_e4m3_conversion( seed: int, device: str, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) low = -224.0 high = 224.0 diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 7bf338b36953a..043c4923bd660 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) +from vllm.utils import seed_everything def causal_conv1d_ref( @@ -104,7 +105,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - torch.random.manual_seed(0) + seed_everything(0) if not channel_last: x = torch.randn(batch, 4096 + dim + 64, @@ -175,7 +176,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed - torch.random.manual_seed(0) + seed_everything(0) batch = 2 x = torch.randn(batch, dim, device=device, dtype=itype) conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) @@ -203,3 +204,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", 
[False, True]) +@pytest.mark.parametrize("seqlen", [1, 4, 5]) +@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, + silu_activation, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + + # set seed + torch.random.manual_seed(0) + batch = 64 + + x = torch.randn(batch, dim, device=device, dtype=itype) + + total_entries = 10 * batch + conv_state = torch.randn(total_entries, + dim, + width, + device=device, + dtype=itype) + conv_state_indices = torch.randperm(total_entries)[:batch].to( + dtype=torch.int32, device=device) + + weight = torch.randn(dim, + width, + device=device, + dtype=itype, + requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state[conv_state_indices, :].detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=conv_state_indices) + out_ref = causal_conv1d_update_ref(x, + conv_state_ref, + weight, + bias, + activation=activation) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index d1f0524f83c4c..cc4ca2e91e76f 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -15,9 +15,6 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] -capability = current_platform.get_device_capability() -capability = capability[0] * 10 + capability[1] - def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) @@ -119,7 +116,7 @@ def cutlass_int8_gemm_helper(m: int, @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool): @@ -157,7 +154,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, out_dtype: Type[torch.dtype], @@ -175,7 +172,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, 
use_bias: bool, device: str): @@ -207,7 +204,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(capability < 89, +@pytest.mark.skipif(not current_platform.has_device_capability(89), reason="FP8 is not supported on this GPU type.") def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool, use_bias: bool): diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 870a8bf65eb92..8e960d098c408 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,6 +4,7 @@ import torch import vllm.attention.backends.flash_attn # noqa: F401 +from vllm.utils import seed_everything NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] @@ -87,7 +88,7 @@ def test_flash_attn_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -174,7 +175,7 @@ def test_varlen_with_paged_kv( num_blocks: int, ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index 696cc0c6cdf10..80a388db6530e 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -4,6 +4,8 @@ import pytest import torch +from vllm.utils import seed_everything + NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] @@ -82,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv( soft_cap: Optional[float], ) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] @@ -168,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -266,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( head_size: int, dtype: torch.dtype, block_size: int, soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -379,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv( ) -> None: # test doesn't work for num_heads = (16,16) torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) + seed_everything(0) num_seqs = len(kv_lens) num_query_heads = num_heads[0] num_kv_heads = num_heads[1] diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index bae9b39203ff9..49f5ce53aab54 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -5,6 +5,7 @@ from tests.kernels.quant_utils import (FP8_DTYPE, ref_dynamic_per_tensor_fp8_quant, ref_dynamic_per_token_quant) +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, @@ -24,8 +25,7 @@ def 
test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 # avoid nans @@ -49,8 +49,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") @@ -67,8 +66,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, @torch.inference_mode() @pytest.mark.parametrize("seed", SEEDS) def test_fp8_quant_large(seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings hidden_size = 1152 # Smallest hidden_size to reproduce the error diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py new file mode 100644 index 0000000000000..1513fc196153c --- /dev/null +++ b/tests/kernels/test_gguf.py @@ -0,0 +1,127 @@ +from pathlib import Path +from typing import List + +import pytest +import torch +from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize +from huggingface_hub import snapshot_download + +import vllm._custom_ops as ops +from vllm.utils import seed_everything + +GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample") + + +def get_gguf_sample_tensors( + hidden_size: int, + quant_type: GGMLQuantizationType) -> List[ReaderTensor]: + sample_dir = GGUF_SAMPLE + filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" + sample_file = Path(sample_dir) / filename + return GGUFReader(sample_file).tensors + + +DTYPES = [torch.half] +# Hidden_size for testing, must match the sample file in HF repo, +# we have `hidden_size = 256, 1024` for test in HF repo currently. 
+HIDDEN_SIZES = [256, 1024] +NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing +SEEDS = [0] +QUANT_TYPES = [ + # i-matrix + GGMLQuantizationType.IQ1_M, + GGMLQuantizationType.IQ1_S, + GGMLQuantizationType.IQ2_S, + GGMLQuantizationType.IQ2_XS, + GGMLQuantizationType.IQ3_S, + GGMLQuantizationType.IQ3_XXS, + GGMLQuantizationType.IQ4_NL, + GGMLQuantizationType.IQ4_XS, + # k-quants + GGMLQuantizationType.Q2_K, + GGMLQuantizationType.Q3_K, + GGMLQuantizationType.Q4_K, + GGMLQuantizationType.Q5_K, + GGMLQuantizationType.Q6_K, + # standard quantization + GGMLQuantizationType.Q4_0, + GGMLQuantizationType.Q5_0, + GGMLQuantizationType.Q8_0, +] + + +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_type", QUANT_TYPES) +@torch.inference_mode() +def test_dequantize(hidden_size: int, dtype: torch.dtype, + quant_type: GGMLQuantizationType): + tensors = get_gguf_sample_tensors(hidden_size, quant_type) + for tensor in tensors: + shape_str = tensor.name.split("_")[-1] + shape = map(int, shape_str.split("x")) + + ref_output = torch.tensor(dequantize(tensor.data, quant_type), + device="cuda").to(dtype) + output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"), + quant_type, *list(shape)).to(dtype) + + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2) + + +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_type", QUANT_TYPES) +@torch.inference_mode() +def test_mmvq(hidden_size: int, dtype: torch.dtype, + quant_type: GGMLQuantizationType): + seed_everything(0) + + tensors = get_gguf_sample_tensors(hidden_size, quant_type) + x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") + for tensor in tensors: + weight = torch.tensor(dequantize(tensor.data, quant_type), + device="cuda").to(dtype) + ref_output = x @ weight.T + + qweight = torch.tensor(tensor.data, device="cuda") + output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, + qweight.shape[0]).to(dtype) + + torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize( + "quant_type", + [ + # k-quants + GGMLQuantizationType.Q2_K, + GGMLQuantizationType.Q3_K, + GGMLQuantizationType.Q4_K, + GGMLQuantizationType.Q5_K, + GGMLQuantizationType.Q6_K, + # standard quants + GGMLQuantizationType.Q4_0, + GGMLQuantizationType.Q5_0, + GGMLQuantizationType.Q8_0, + ]) +@torch.inference_mode() +def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, + quant_type: GGMLQuantizationType): + seed_everything(0) + + tensors = get_gguf_sample_tensors(hidden_size, quant_type) + x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") + for tensor in tensors: + weight = torch.tensor(dequantize(tensor.data, quant_type), + device="cuda").to(dtype) + ref_output = x @ weight.T + + qweight = torch.tensor(tensor.data, device="cuda") + output = ops.ggml_mul_mat_a8(qweight, x, quant_type, + qweight.shape[0]).to(dtype) + + torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index a82ecb026482e..41e103e1d09f9 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -4,6 +4,7 @@ from tests.kernels.quant_utils import ref_dynamic_per_token_quant from 
tests.kernels.utils import opcheck from vllm._custom_ops import scaled_int8_quant +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -13,14 +14,28 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] -def opcheck_int8_quant(output, input, scale=None): - if scale is not None: - opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale)) +def opcheck_int8_quant_static(output, input, scale, azp=None): + if azp is None: + opcheck(torch.ops._C.static_scaled_int8_quant, + (output, input, scale, None)) else: - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale)) + opcheck(torch.ops._C.static_scaled_int8_quant, + (output, input, scale, azp)) + + +def opcheck_int8_quant_dynamic(output, input, symmetric=True): + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + if symmetric: + opcheck(torch.ops._C.dynamic_scaled_int8_quant, + (output, input, scale, None)) + else: + azp = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.int32) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, + (output, input, scale, azp)) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -30,22 +45,62 @@ def opcheck_int8_quant(output, input, scale=None): @torch.inference_mode() def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 # reference ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8) # kernel - ops_out, ops_scales = scaled_int8_quant(x) + ops_out, ops_scales, _ = scaled_int8_quant(x) torch.testing.assert_close(ops_scales, ref_scales) - torch.testing.assert_close( - ops_out, ref_out, atol=1, - rtol=0.0) # big atol to account for rounding errors + # big atol to account for rounding errors + torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0) - opcheck_int8_quant(ops_out, x) + opcheck_int8_quant_dynamic(ops_out, x) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int) -> None: + seed_everything(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") * 1000 - 300 + + x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) + x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) + + # calculate scale and azp, and adjust the range + scales = (x_token_max - x_token_min) / torch.tensor(255.0) + azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to( + torch.int32) + + torch_out = ((x / scales).round() + azps).clamp( + int8_traits.min, int8_traits.max).to(torch.int8) + assert torch_out.min() >= int8_traits.min and torch_out.max( + ) <= int8_traits.max + + ops_out = torch.empty_like(x, dtype=torch.int8) + scales_out = torch.empty_like(scales, dtype=torch.float32) + azp_out = torch.empty_like(azps, dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out) + + if (not 
torch.allclose(scales_out, scales)): + print(torch.argmax(torch.abs(scales_out - scales))) + torch.testing.assert_close(scales_out, scales) + # big atol to account for rounding errors + torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0) + # if AZP is off by 1, after rounding-to-even, the output may be off by 2 + torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0) + + opcheck_int8_quant_dynamic(ops_out, x, False) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -57,19 +112,79 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float) -> None: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - scale = torch.tensor([scale], dtype=torch.float32, device="cuda") + scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") + + out1 = (x / scale_arg).round().clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + out2, _, _ = scaled_int8_quant(x, scale_arg) + + # big atol to account for rounding errors + torch.testing.assert_close(out1, out2, atol=1, rtol=0.0) + + opcheck_int8_quant_static(out2, x, scale_arg) - out1 = (x / scale).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) - out2, _ = scaled_int8_quant(x, scale) - torch.testing.assert_close( - out1, out2, atol=1, - rtol=0.0) # big atol to account for rounding errors +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE[2:]) # Reduce test time +@pytest.mark.parametrize("azp", [-255, 54]) +@torch.inference_mode() +def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, + dtype: torch.dtype, seed: int, + scale: float, azp: int) -> None: + seed_everything(seed) + int8_traits = torch.iinfo(torch.int8) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, + device="cuda") * 1000 - 300 + + out1 = ((x / scale).round() + azp).clamp(int8_traits.min, + int8_traits.max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") + azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda") + + torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg) + + # big atol to account for rounding errors + torch.testing.assert_close(out1, out2, atol=1, rtol=0.0) + + opcheck_int8_quant_static(out2, x, scale_arg, azp_arg) + + +@pytest.mark.parametrize("is_max", [True, False]) +@torch.inference_mode() +def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None: + # Test that the saturating cast works correctly for values near i32 max/min + + from numpy import inf, nextafter + + int32_traits = torch.iinfo(torch.int32) + val = float(int32_traits.max if is_max else int32_traits.min) + + x_vals = [[ + nextafter(val, inf), val + 1, val, val - 1, + nextafter(val, -inf) + ]] + x = torch.tensor(x_vals, dtype=torch.float32, device="cuda") + + # The calculation in the kernel is: cast(cast(x / scale) + azp) + # where cast is a saturating cast to type T. + # Scale is set to 1.0 so that the input values are the ones that are cast. + # AZP is set to 0 to make sure the int8 saturating cast is tested as well. 
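As a quick worked illustration of the `cast(cast(x / scale) + azp)` computation described in the comment above (a sketch only, assuming `scale = 1.0` and `azp = 0` exactly as this test sets them):

    import numpy as np

    i32 = np.iinfo(np.int32)
    i8 = np.iinfo(np.int8)


    def saturate(v: int, lo: int, hi: int) -> int:
        # Saturating cast: clamp to the target type's representable range.
        return min(max(v, lo), hi)


    # A value just past the int32 maximum saturates to i32.max in the inner
    # cast and then to 127 in the outer int8 cast.
    x = float(i32.max) + 1.0
    inner = saturate(round(x / 1.0) + 0, i32.min, i32.max)
    assert saturate(inner, i8.min, i8.max) == 127

    # Symmetrically, a value below the int32 minimum pins to -128.
    x = float(i32.min) - 1.0
    inner = saturate(round(x / 1.0) + 0, i32.min, i32.max)
    assert saturate(inner, i8.min, i8.max) == -128

This matches the `expected` tensor built below: every entry of `x_vals` ends up at `int8.max` or `int8.min` after the two saturating casts.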
+ scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda") + azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda") + + int8_traits = torch.iinfo(torch.int8) + val_i8 = int8_traits.max if is_max else int8_traits.min + expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda") - opcheck_int8_quant(out2, x, scale) + out = torch.empty_like(expected) + torch.ops._C.static_scaled_int8_quant(out, x, scale, azp) + torch.testing.assert_close(expected, out, atol=0, rtol=0) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 6eaf67ec75f41..382079d472ee9 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -3,6 +3,7 @@ from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing @@ -30,9 +31,7 @@ def test_rms_norm( seed: int, device: str, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) layer = RMSNorm(hidden_size).to(dtype=dtype) layer.weight.data.normal_(mean=1.0, std=0.1) diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index ce65aaef60ac6..0a90882223077 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -48,7 +48,7 @@ # `is_quant_method_supported` conflates kernels with quantization methods # an assumption which is breaking down as quantizations methods can have # have kernels and some kernels support multiple quantization methods. -IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) def rand_data(shape, dtype=torch.float16): diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index d3cb0a8656a02..f582445692344 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -5,6 +5,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) +from vllm.utils import seed_everything def selective_state_update_ref(state, @@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, rtolw = max(rtolw, rtol) atolw = max(atolw, atol) # set seed - torch.random.manual_seed(0) + seed_everything(0) batch_size = 2 dim = 4 dstate = 8 @@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): if torch.version.hip: atol *= 2 # set seed - torch.random.manual_seed(0) + seed_everything(0) batch_size = 1 state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2250cf1598b8b..b1f0516dfa0b3 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -18,6 +18,7 @@ marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.scalar_type import scalar_types +from vllm.utils import seed_everything def torch_moe(a, w1, w2, score, topk): @@ -140,6 +141,7 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) +@pytest.mark.parametrize("num_bits", [4, 8]) def test_fused_marlin_moe( m: int, n: int, @@ 
-148,8 +150,9 @@ def test_fused_marlin_moe( topk: int, group_size: int, act_order: bool, + num_bits: int, ): - torch.manual_seed(7) + seed_everything(7) if topk > e: return @@ -161,13 +164,12 @@ def test_fused_marlin_moe( if group_size in (k, n): return - quant_type = scalar_types.uint4b8 + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - for i in range(w2.shape[0]): - w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) w_ref1_l = [] qweight1_l = [] @@ -240,6 +242,7 @@ def test_fused_marlin_moe( topk_ids, w1_scale=scales1, w2_scale=scales2, + num_bits=num_bits, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -254,7 +257,8 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) -def test_marlin_moe_mmm( +@pytest.mark.parametrize("num_bits", [4, 8]) +def test_single_marlin_moe_multiply( m: int, n: int, k: int, @@ -262,6 +266,7 @@ def test_marlin_moe_mmm( topk: int, group_size: int, act_order: bool, + num_bits: int, ): if topk > e: return @@ -273,7 +278,8 @@ def test_marlin_moe_mmm( if group_size == k: return - quant_type = scalar_types.uint4b8 + quant_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) dtype = torch.float16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 @@ -308,7 +314,8 @@ def test_marlin_moe_mmm( g_idx, sort_indices, topk, - renormalize=False) + renormalize=False, + num_bits=num_bits) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 65242e275650c..ba9d2d4389b21 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -5,6 +5,7 @@ import torch from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -46,9 +47,8 @@ def test_rotary_embedding( ) -> None: if rotary_dim is None: rotary_dim = head_size - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -100,9 +100,7 @@ def test_batched_rotary_embedding( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -162,9 +160,7 @@ def test_batched_rotary_embedding_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 60f9a4dc9f90f..3181d92562399 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -9,7 +9,7 @@ from vllm.attention.backends.xformers import 
_make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] @@ -39,10 +39,7 @@ def test_contexted_kv_attention( kv_cache_dtype: str, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process @@ -237,10 +234,7 @@ def test_contexted_kv_attention_alibi( kv_cache_dtype: str, device: str, ) -> None: - random.seed(0) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed(0) + seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py deleted file mode 100644 index a4242d22eb489..0000000000000 --- a/tests/kernels/test_rand.py +++ /dev/null @@ -1,52 +0,0 @@ -import random - -import pytest -import torch - -from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.model_executor.utils import set_random_seed - - -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_3d", [True, False]) -def test_seeded_uniform(dtype: torch.dtype, use_3d: bool): - device = "cuda" - for seed in range(512): - set_random_seed(seed) - rows = random.randint(1, 512) - cols = random.randint(1, 64000) - if use_3d: - third_dim = random.randint(2, 10) - dims = [rows, third_dim, cols] - else: - dims = [rows, cols] - seeds = torch.randint(torch.iinfo(torch.long).min, - torch.iinfo(torch.long).max, (rows, ), - device=device) - - # Test that the same seed produces the same output - out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - torch.testing.assert_close(out, out2) - # del to save memory - del out2 - - out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - torch.testing.assert_close(out, out3) - # del to save memory - del out3 - - # Initialize out tensor with garbage to ensure that it is overwritten - out_with_tensor = seeded_uniform( - *dims, - out=torch.full( - (*dims, ), - -1, - dtype=dtype, - device=device, - ), - seeds=seeds, - dtype=dtype, - ) - torch.testing.assert_close(out, out_with_tensor) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py deleted file mode 100644 index 03844aba20f8a..0000000000000 --- a/tests/kernels/test_sampler.py +++ /dev/null @@ -1,209 +0,0 @@ -import gc -from unittest.mock import patch - -import pytest -import torch -import triton -import triton.language as tl - -from vllm.model_executor.layers.ops.sample import (_sample_triton, - _uniform_to_exponential, - sample) -from vllm.model_executor.sampling_metadata import SamplingTensors -from vllm.model_executor.utils import set_random_seed -from vllm.triton_utils.libentry import LibEntry -from vllm.triton_utils.sample import (MAX_TRITON_N_COLS, - get_num_triton_sampler_splits) - -SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size -MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 - - -@pytest.fixture(autouse=True) -def _cleanup(): - yield - gc.collect() - torch.cuda.empty_cache() - - -@triton.jit -def _uniform_to_exponential_kernel(input, output, n: tl.constexpr): - idx = tl.arange(0, n) - x = tl.load(input + idx) - y = 
_uniform_to_exponential(x) - tl.store(output + idx, y) - - -def test_uniform_to_exponential(): - """Test that we can convert uniform to exponential without div by 0.""" - input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps], - dtype=torch.float32, - device="cuda") - output = torch.zeros(input.shape, dtype=torch.float32, device="cuda") - _uniform_to_exponential_kernel[(1, )](input, output, 2) - assert torch.all(torch.isfinite(output)) - assert torch.all(output > 0) - assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output)) - - -@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) -@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) -@pytest.mark.parametrize("modify_greedy_probs", [True, False]) -@pytest.mark.parametrize("seed", [1337]) -@pytest.mark.parametrize("vocab_size", - [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) -@pytest.mark.parametrize("save_logprobs", [True, False]) -def test_sample_decoding_only(random_sampling, max_best_of, - modify_greedy_probs, seed, vocab_size, - save_logprobs): - set_random_seed(seed) - bs = 8 - probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") - for i in range(bs): - probs[i, i * (vocab_size // bs)] = 1.0 - logprobs = torch.rand_like(probs) - sample_indices = torch.arange(bs, dtype=torch.long, device="cuda") - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if random_sampling == "mixed": - random_sampling_mask = (torch.rand( - (1, bs), device="cuda") < 0.5).expand(n_splits, bs) - elif random_sampling: - random_sampling_mask = torch.ones((n_splits, bs), - dtype=torch.bool, - device="cuda") - else: - random_sampling_mask = torch.zeros((n_splits, bs), - dtype=torch.bool, - device="cuda") - - seeds = torch.randint(1, - torch.iinfo(torch.long).max, (n_splits, bs), - device="cuda").mul_(random_sampling_mask) - #The current _sample_triton does not utilize the - # libentry decoration. The purpose of adding this patch is to test - # the correctness of libentry. - with patch("vllm.model_executor.layers.ops.sample._sample_triton", - LibEntry(_sample_triton)): - sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - _save_modified_probs=True) - assert sampled_tokens.shape == (bs, max_best_of) - for i in range(bs): - assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) - request_uses_random_sampling = random_sampling_mask[0, i] - if modify_greedy_probs and not request_uses_random_sampling: - # If we are modifying greedy probs and the request is greedy, - # we want to make sure the probs tensor is modified in place - torch.testing.assert_close( - probs[i][sampled_tokens[i]], - torch.full_like(probs[i][sampled_tokens[i]], 1.0)) - assert torch.sum(probs[i]) == 1.0 - torch.testing.assert_close( - sampled_modified_probs[i][0], - torch.full_like(sampled_modified_probs[i][0], 1.0)) - elif request_uses_random_sampling: - # If the request is random, we want to make sure - # sampled_modified_probs tensor has noise added - # (and thus is different from probs tensor) - assert not torch.allclose(sampled_modified_probs[i][0], - probs[i][sampled_tokens[i]]) - elif not request_uses_random_sampling: - # If the request is greedy and we are not modifying greedy probs, - # we want to make sure sampled_modified_probs tensor is the same as - # the probs tensor. 
- torch.testing.assert_close(sampled_modified_probs[i], - probs[i][sampled_tokens[i]]) - - if save_logprobs: - assert sampled_logprobs.shape == (bs, max_best_of) - for i in range(bs): - for best_of in range(max_best_of): - assert torch.all(sampled_logprobs[i] == logprobs[i][ - sampled_tokens[i, best_of]]) - else: - assert sampled_logprobs is None - - -@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) -@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) -@pytest.mark.parametrize("modify_greedy_probs", [True, False]) -@pytest.mark.parametrize("seed", [1337]) -@pytest.mark.parametrize("vocab_size", - [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) -def test_sample_prompt_logprobs(random_sampling, max_best_of, - modify_greedy_probs, seed, vocab_size): - - set_random_seed(seed) - prompt_sizes = [16, 32, 64, 128] * 2 - samples = 8 - bs = samples + sum(prompt_sizes) - probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") - for i in range(bs): - probs[i, i * (vocab_size // bs)] = 1.0 - logprobs = torch.rand_like(probs) - sample_indices = torch.tensor(prompt_sizes, - dtype=torch.long, - device="cuda").cumsum_(0) - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if random_sampling == "mixed": - random_sampling_mask = torch.rand( - (n_splits, samples), device="cuda") < 0.5 - elif random_sampling: - random_sampling_mask = torch.ones((n_splits, samples), - dtype=torch.bool, - device="cuda") - else: - random_sampling_mask = torch.zeros((n_splits, samples), - dtype=torch.bool, - device="cuda") - - seeds = torch.randint(1, - torch.iinfo(torch.long).max, (n_splits, samples), - device="cuda").mul_(random_sampling_mask) - #ditto - with patch("vllm.model_executor.layers.ops.sample._sample_triton", - LibEntry(_sample_triton)): - sampled_tokens, sampled_logprobs, _ = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=True) - assert sampled_tokens.shape == (samples, max_best_of) - assert sampled_logprobs.shape == (samples, max_best_of) - for i, t in enumerate(sample_indices): - assert torch.all(sampled_tokens[i] == t * (vocab_size // bs)) - for best_of in range(max_best_of): - assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]] - [sampled_tokens[i, best_of]]) - - -@pytest.mark.parametrize("seed", list(range(16))) -def test_get_sequence_seeds(seed): - """Ensure that we get a different child seed from base - seed + extra entropy""" - starting_seed = seed - seq_seed = None - extra_entropy = 1 - for i in range(512): - new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed, - i, - seeds_to_generate=1, - is_greedy=False)[0] - new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( - starting_seed, - i, - extra_entropy, - seeds_to_generate=1, - is_greedy=False)[0] - assert new_seq_seed_extra_entropy != new_seq_seed - assert seq_seed != new_seq_seed - seq_seed = new_seq_seed diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index dbddd69c07dbc..5746932c30a45 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -10,7 +10,6 @@ import torch from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType -from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -521,6 +520,9 @@ def make_backend(backend_name: str) -> AttentionBackend: * Backend instance ''' if backend_name 
== STR_XFORMERS_ATTN_VAL: + # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. + from vllm.attention.backends.xformers import XFormersBackend + return XFormersBackend() raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 0bcae5b0c96dc..4834a9d35a3ee 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -65,10 +65,7 @@ def should_do_global_cleanup_after_test(request) -> bool: to initialize torch. """ - if request.node.get_closest_marker("skip_global_cleanup"): - return False - - return True + return not request.node.get_closest_marker("skip_global_cleanup") @pytest.fixture(autouse=True) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index effcffc5c174e..e3233c6b60696 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) from vllm.model_executor.utils import set_random_seed +from vllm.utils import seed_everything from .utils import DummyLoRAManager @@ -922,9 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, seq_len) -> None: dtype = torch.float16 seed = 0 - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index c36fb3afb0cc3..314d6215cbd9c 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -4,7 +4,6 @@ whether the corresponding Triton kernel can run normally when tensor parallelism is set to [1, 2, 4, 8, 16, 32, 64]. """ -import random from unittest.mock import patch import pytest @@ -17,6 +16,7 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.triton_utils.libentry import LibEntry +from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -145,11 +145,8 @@ def test_punica_sgmv( seed: int, device: str, ): - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 128 ( @@ -238,11 +235,8 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 1 ( @@ -329,11 +323,9 @@ def test_punica_expand_nslices( ): from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) + seq_length = 128 if op_type == "sgmv" else 1 ( inputs_tensor, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index d026e34878e04..28a395af19e6d 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -3,7 +3,6 @@ under different conditions, including various batches, numbers of LoRA , and maximum ranks. 
""" -import random from unittest.mock import patch import pytest @@ -16,6 +15,7 @@ from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from vllm.triton_utils.libentry import LibEntry +from vllm.utils import seed_everything from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) @@ -60,11 +60,8 @@ def test_punica_sgmv( seed: int, device: str, ): - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 128 ( @@ -153,11 +150,8 @@ def test_punica_bgmv( from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) seq_length = 1 ( @@ -244,11 +238,9 @@ def test_punica_expand_nslices( ): from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel - random.seed(seed) torch.set_default_device(device) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) + seq_length = 128 if op_type == "sgmv" else 1 ( inputs_tensor, diff --git a/tests/models/decoder_only/__init__.py b/tests/models/decoder_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/decoder_only/audio_language/__init__.py b/tests/models/decoder_only/audio_language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py similarity index 98% rename from tests/models/test_ultravox.py rename to tests/models/decoder_only/audio_language/test_ultravox.py index e98db9b65f484..bfffd34d1142c 100644 --- a/tests/models/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -7,10 +7,8 @@ from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import HfRunner, VllmRunner -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import HfRunner, VllmRunner +from ...utils import check_logprobs_close MODEL_NAME = "fixie-ai/ultravox-v0_3" diff --git a/tests/models/decoder_only/language/__init__.py b/tests/models/decoder_only/language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_aqlm.py b/tests/models/decoder_only/language/test_aqlm.py similarity index 100% rename from tests/models/test_aqlm.py rename to tests/models/decoder_only/language/test_aqlm.py diff --git a/tests/models/test_big_models.py b/tests/models/decoder_only/language/test_big_models.py similarity index 77% rename from tests/models/test_big_models.py rename to tests/models/decoder_only/language/test_big_models.py index c3e48b56ee58f..fcc158639748d 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/decoder_only/language/test_big_models.py @@ -5,9 +5,10 @@ Run `pytest tests/models/test_big_models.py`. 
""" import pytest -import torch -from .utils import check_outputs_equal +from vllm.platforms import current_platform + +from ...utils import check_outputs_equal MODELS = [ "meta-llama/Llama-2-7b-hf", @@ -19,10 +20,12 @@ # "Qwen/Qwen1.5-0.5B" # Broken, ] +if not current_platform.is_cpu(): + # MiniCPM requires fused_moe which is not supported by CPU + MODELS.append("openbmb/MiniCPM3-4B") + #TODO: remove this after CPU float16 support ready -target_dtype = "float" -if torch.cuda.is_available(): - target_dtype = "half" +target_dtype = "float" if current_platform.is_cpu() else "half" @pytest.mark.parametrize("model", MODELS) @@ -39,7 +42,7 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( @@ -57,7 +60,7 @@ def test_model_print( model: str, dtype: str, ) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: # This test is for verifying whether the model's extra_repr # can be printed correctly. print(vllm_model.model.llm_engine.model_executor.driver_worker. diff --git a/tests/models/test_danube3_4b.py b/tests/models/decoder_only/language/test_danube3_4b.py similarity index 97% rename from tests/models/test_danube3_4b.py rename to tests/models/decoder_only/language/test_danube3_4b.py index bfaa275f73c19..bdd498edc293d 100644 --- a/tests/models/test_danube3_4b.py +++ b/tests/models/decoder_only/language/test_danube3_4b.py @@ -6,7 +6,7 @@ """ import pytest -from .utils import check_outputs_equal +from ...utils import check_outputs_equal MODELS = ["h2oai/h2o-danube3-4b-base"] diff --git a/tests/models/test_fp8.py b/tests/models/decoder_only/language/test_fp8.py similarity index 98% rename from tests/models/test_fp8.py rename to tests/models/decoder_only/language/test_fp8.py index 17acdb52322fd..5a947ce62c785 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/decoder_only/language/test_fp8.py @@ -10,7 +10,7 @@ from tests.kernels.utils import override_backend_env_variable from tests.quantization.utils import is_quant_method_supported -from ..models.utils import check_logprobs_close +from ...utils import check_logprobs_close os.environ["TOKENIZERS_PARALLELISM"] = "true" diff --git a/tests/models/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py similarity index 98% rename from tests/models/test_gguf.py rename to tests/models/decoder_only/language/test_gguf.py index 196cd88e039a1..8fc64a10c84af 100644 --- a/tests/models/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -11,7 +11,7 @@ from tests.quantization.utils import is_quant_method_supported -from .utils import check_logprobs_close +from ...utils import check_logprobs_close os.environ["TOKENIZERS_PARALLELISM"] = "true" diff --git a/tests/models/test_gptq_marlin.py b/tests/models/decoder_only/language/test_gptq_marlin.py similarity index 98% rename from tests/models/test_gptq_marlin.py rename to tests/models/decoder_only/language/test_gptq_marlin.py index 4abbc41c9c287..2155e83dbe915 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/decoder_only/language/test_gptq_marlin.py @@ -15,7 +15,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm.model_executor.layers.rotary_embedding import 
_ROPE_DICT -from .utils import check_logprobs_close +from ...utils import check_logprobs_close os.environ["TOKENIZERS_PARALLELISM"] = "true" diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/decoder_only/language/test_gptq_marlin_24.py similarity index 97% rename from tests/models/test_gptq_marlin_24.py rename to tests/models/decoder_only/language/test_gptq_marlin_24.py index 60d9ae2f1c629..d65be05f141b4 100644 --- a/tests/models/test_gptq_marlin_24.py +++ b/tests/models/decoder_only/language/test_gptq_marlin_24.py @@ -10,9 +10,10 @@ import pytest -from tests.models.utils import check_logprobs_close from tests.quantization.utils import is_quant_method_supported +from ...utils import check_logprobs_close + @dataclass class ModelPair: diff --git a/tests/models/test_granite.py b/tests/models/decoder_only/language/test_granite.py similarity index 84% rename from tests/models/test_granite.py rename to tests/models/decoder_only/language/test_granite.py index 2435b5dc3ff88..e5c5ce4a8f745 100644 --- a/tests/models/test_granite.py +++ b/tests/models/decoder_only/language/test_granite.py @@ -2,15 +2,10 @@ Run `pytest tests/models/test_granite.py`. """ -import importlib.metadata - import pytest +import transformers -from .utils import check_logprobs_close - -TRANSFORMERS_VERSION = tuple( - map(int, - importlib.metadata.version("transformers").split("."))) +from ...utils import check_logprobs_close MODELS = [ "ibm/PowerLM-3b", @@ -18,7 +13,7 @@ # GraniteForCausalLM will be in transformers >= 4.45 -@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45), +@pytest.mark.skipif(transformers.__version__ < "4.45", reason="granite model test requires transformers >= 4.45") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) diff --git a/tests/models/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py similarity index 99% rename from tests/models/test_jamba.py rename to tests/models/decoder_only/language/test_jamba.py index efb7b1c607721..36fa67a22b0f6 100644 --- a/tests/models/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,8 +1,9 @@ import pytest -from tests.models.utils import check_outputs_equal from vllm.worker.model_runner import _get_graph_batch_size +from ...utils import check_outputs_equal + MODELS = ["ai21labs/Jamba-tiny-random"] diff --git a/tests/models/test_marlin.py b/tests/models/decoder_only/language/test_marlin.py similarity index 98% rename from tests/models/test_marlin.py rename to tests/models/decoder_only/language/test_marlin.py index e86f6e29d1567..c802346dee8af 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/decoder_only/language/test_marlin.py @@ -16,7 +16,7 @@ from tests.quantization.utils import is_quant_method_supported -from .utils import check_logprobs_close +from ...utils import check_logprobs_close @dataclass diff --git a/tests/models/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py similarity index 50% rename from tests/models/test_mistral.py rename to tests/models/decoder_only/language/test_mistral.py index 0741174497e32..26f90456849f1 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -4,13 +4,61 @@ """ import pytest -from .utils import check_logprobs_close +from vllm import SamplingParams + +from ...utils import check_logprobs_close MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3", + # Mistral-Nemo is to big for CI, but passes locally + # 
"mistralai/Mistral-Nemo-Instruct-2407" ] +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) + +# for function calling +TOOLS = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] +MSGS = [{ + "role": + "user", + "content": ("Can you tell me what the temperate" + " will be in Dallas, in fahrenheit?") +}] +EXPECTED_FUNC_CALL = ( + '[{"name": "get_current_weather", "arguments": ' + '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]') + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -81,3 +129,22 @@ def test_mistral_format( name_0="hf", name_1="mistral", ) + + +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling +def test_mistral_function_calling( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, + dtype=dtype, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") as vllm_model: + outputs = vllm_model.model.chat(MSGS, + tools=TOOLS, + sampling_params=SAMPLING_PARAMS) + + assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL diff --git a/tests/models/test_modelopt.py b/tests/models/decoder_only/language/test_modelopt.py similarity index 100% rename from tests/models/test_modelopt.py rename to tests/models/decoder_only/language/test_modelopt.py diff --git a/tests/models/test_models.py b/tests/models/decoder_only/language/test_models.py similarity index 97% rename from tests/models/test_models.py rename to tests/models/decoder_only/language/test_models.py index 4cd2cb665c8f0..68055cbe29095 100644 --- a/tests/models/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -7,7 +7,7 @@ """ import pytest -from .utils import check_outputs_equal +from ...utils import check_outputs_equal MODELS = [ "facebook/opt-125m", diff --git a/tests/models/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py similarity index 98% rename from tests/models/test_phimoe.py rename to tests/models/decoder_only/language/test_phimoe.py index 2fb2eecc94672..dbdf5a1b934a6 100644 --- a/tests/models/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -7,7 +7,7 @@ from vllm.utils import is_cpu -from .utils import check_logprobs_close +from ...utils import check_logprobs_close MODELS = [ "microsoft/Phi-3.5-MoE-instruct", diff --git a/tests/models/decoder_only/vision_language/__init__.py b/tests/models/decoder_only/vision_language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_blip2.py b/tests/models/decoder_only/vision_language/test_blip2.py similarity index 95% rename from tests/models/test_blip2.py rename to tests/models/decoder_only/vision_language/test_blip2.py index 5d48bad0d7b35..e1e32b96d89ac 100644 --- a/tests/models/test_blip2.py +++ b/tests/models/decoder_only/vision_language/test_blip2.py @@ -6,10 
+6,8 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -56,7 +54,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_tokens: int, num_logprobs: int) -> None: """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalData objects and corresponding MultiModalConfig as input. diff --git a/tests/models/decoder_only/vision_language/test_broadcast.py b/tests/models/decoder_only/vision_language/test_broadcast.py new file mode 100644 index 0000000000000..d01490d74bd4d --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_broadcast.py @@ -0,0 +1,42 @@ +import pytest + +from ....utils import multi_gpu_test + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +@pytest.mark.parametrize("model", [ + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-v1.6-mistral-7b-hf", + "facebook/chameleon-7b", +]) +def test_models(hf_runner, vllm_runner, image_assets, + distributed_executor_backend, model) -> None: + + dtype = "half" + max_tokens = 5 + num_logprobs = 5 + tensor_parallel_size = 2 + + if model.startswith("llava-hf/llava-1.5"): + from .test_llava import models, run_test + elif model.startswith("llava-hf/llava-v1.6"): + from .test_llava_next import models, run_test # type: ignore[no-redef] + elif model.startswith("facebook/chameleon"): + from .test_chameleon import models, run_test # type: ignore[no-redef] + else: + raise NotImplementedError(f"Unsupported model: {model}") + + run_test( + hf_runner, + vllm_runner, + image_assets, + model=models[0], + # So that LLaVA-NeXT processor may return nested list + size_factors=[0.25, 0.5, 1.0], + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + ) diff --git a/tests/models/test_chameleon.py b/tests/models/decoder_only/vision_language/test_chameleon.py similarity index 95% rename from tests/models/test_chameleon.py rename to tests/models/decoder_only/vision_language/test_chameleon.py index e02b4b1ed72bd..8334451970a4f 100644 --- a/tests/models/test_chameleon.py +++ b/tests/models/decoder_only/vision_language/test_chameleon.py @@ -6,10 +6,8 @@ from vllm.multimodal.utils import rescale_image_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_outputs_equal - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ...utils import check_outputs_equal HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -36,7 +34,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding vision language config as input. 
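The Chameleon hunk above asserts exact agreement between the HF runner and the vLLM runner via check_outputs_equal, imported from tests/models/utils.py. The helper itself is not shown in this patch, so the following is only a minimal sketch of the contract it enforces, assuming each output is a (token_ids, generated_text) pair per prompt; the real signature and error reporting may differ. A looser, logprob-based variant used by most of the other tests in this group is sketched a little further down.

from typing import List, Sequence, Tuple


def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[Tuple[List[int], str]],
    outputs_1_lst: Sequence[Tuple[List[int], str]],
    name_0: str,
    name_1: str,
) -> None:
    # Sketch only: compare one runner's greedy outputs against the other's.
    assert len(outputs_0_lst) == len(outputs_1_lst)

    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        # Greedy decoding should yield identical token ids (and therefore
        # identical text) from both runners.
        assert output_ids_0 == output_ids_1, (
            f"Test{prompt_idx}:"
            f"\n{name_0}:\t{output_str_0!r}"
            f"\n{name_1}:\t{output_str_1!r}")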
diff --git a/tests/models/test_fuyu.py b/tests/models/decoder_only/vision_language/test_fuyu.py similarity index 95% rename from tests/models/test_fuyu.py rename to tests/models/decoder_only/vision_language/test_fuyu.py index 0d666d8f71a92..94b8431424db5 100644 --- a/tests/models/test_fuyu.py +++ b/tests/models/decoder_only/vision_language/test_fuyu.py @@ -6,10 +6,8 @@ from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -46,7 +44,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. diff --git a/tests/models/test_intern_vit.py b/tests/models/decoder_only/vision_language/test_intern_vit.py similarity index 97% rename from tests/models/test_intern_vit.py rename to tests/models/decoder_only/vision_language/test_intern_vit.py index 816f846f69bae..3c3b95b38baac 100644 --- a/tests/models/test_intern_vit.py +++ b/tests/models/decoder_only/vision_language/test_intern_vit.py @@ -6,9 +6,7 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor -from ..conftest import _ImageAssets, cleanup - -pytestmark = pytest.mark.vlm +from ....conftest import _ImageAssets, cleanup # we use snapshot_download to prevent conflicts between # dynamic_module and trust_remote_code for hf_runner diff --git a/tests/models/test_internvl.py b/tests/models/decoder_only/vision_language/test_internvl.py similarity index 98% rename from tests/models/test_internvl.py rename to tests/models/decoder_only/vision_language/test_internvl.py index 881068b3afe41..a756f8214edee 100644 --- a/tests/models/test_internvl.py +++ b/tests/models/decoder_only/vision_language/test_internvl.py @@ -9,11 +9,9 @@ from vllm.multimodal.utils import rescale_image_size from vllm.utils import is_cpu -from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -78,7 +76,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. 
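The Fuyu and InternVL tests above, like most of the renamed vision-language tests, compare HF and vLLM through check_logprobs_close rather than exact token equality. The helper lives in tests/models/utils.py and is not part of this patch; the sketch below only illustrates the core idea, under the assumption that each output is a (token_ids, text, logprobs) triple where logprobs[i] maps candidate token ids to log-probabilities at step i. Once the two generations pick different tokens, each side's choice must still appear among the other side's reported top candidates, after which the comparison stops because the continuations legitimately diverge.

from typing import Dict, List, Optional, Sequence, Tuple

TokenLogprobs = List[Dict[int, float]]  # one {token_id: logprob} dict per step


def check_logprobs_close(
    *,
    outputs_0_lst: Sequence[Tuple[List[int], str, Optional[TokenLogprobs]]],
    outputs_1_lst: Sequence[Tuple[List[int], str, Optional[TokenLogprobs]]],
    name_0: str,
    name_1: str,
) -> None:
    # Sketch only: tolerate the first divergence if it is still "close".
    assert len(outputs_0_lst) == len(outputs_1_lst)

    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        for idx, (id_0, id_1) in enumerate(zip(output_ids_0, output_ids_1)):
            if id_0 == id_1:
                continue

            # Each model's token must rank among the other's top candidates
            # at the step where they first disagree.
            assert logprobs_0 is not None and logprobs_1 is not None
            assert id_0 in logprobs_1[idx] and id_1 in logprobs_0[idx], (
                f"Test{prompt_idx}:"
                f"\n{name_0}:\t{output_str_0!r}"
                f"\n{name_1}:\t{output_str_1!r}")
            break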
diff --git a/tests/models/test_llava.py b/tests/models/decoder_only/vision_language/test_llava.py similarity index 96% rename from tests/models/test_llava.py rename to tests/models/decoder_only/vision_language/test_llava.py index 84ca23f6222a9..fd28a9367b4b2 100644 --- a/tests/models/test_llava.py +++ b/tests/models/decoder_only/vision_language/test_llava.py @@ -8,11 +8,9 @@ from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import check_logprobs_close _LIMIT_IMAGE_PER_PROMPT = 4 @@ -143,7 +141,7 @@ def _run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. @@ -239,7 +237,7 @@ def process(hf_inputs: BatchEncoding): @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_tokens: int, num_logprobs: int) -> None: + dtype, max_tokens, num_logprobs) -> None: run_test( hf_runner, vllm_runner, diff --git a/tests/models/test_llava_image_embeds.py b/tests/models/decoder_only/vision_language/test_llava_image_embeds.py similarity index 96% rename from tests/models/test_llava_image_embeds.py rename to tests/models/decoder_only/vision_language/test_llava_image_embeds.py index cc444fe32e79b..66414032509ed 100644 --- a/tests/models/test_llava_image_embeds.py +++ b/tests/models/decoder_only/vision_language/test_llava_image_embeds.py @@ -5,10 +5,8 @@ from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -62,7 +60,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding vision language config as input. 
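These image tests sweep size_factors (for example [0.25, 0.5, 1.0], as in the broadcast test above) and rescale each IMAGE_ASSETS image with rescale_image_size from vllm.multimodal.utils before handing it to both runners, so models with dynamic tiling are exercised at several resolutions. The helper's body is not part of this patch; a plausible stand-in, assuming a PIL.Image input and simple proportional resizing, is:

from PIL import Image


def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
    # Hypothetical stand-in: scale both dimensions by size_factor,
    # never collapsing below one pixel.
    new_width = max(int(image.width * size_factor), 1)
    new_height = max(int(image.height * size_factor), 1)
    return image.resize((new_width, new_height))


# Illustrative usage mirroring the tests' per-prompt image lists
# (image_assets is the fixture; .pil_image is assumed here):
# inputs_per_image = [
#     [rescale_image_size(asset.pil_image, factor)
#      for factor in (0.25, 0.5, 1.0)]
#     for asset in image_assets
# ]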
diff --git a/tests/models/test_llava_next.py b/tests/models/decoder_only/vision_language/test_llava_next.py similarity index 97% rename from tests/models/test_llava_next.py rename to tests/models/decoder_only/vision_language/test_llava_next.py index d5fe0cbe32880..f833fe0c8bbb4 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/decoder_only/vision_language/test_llava_next.py @@ -6,11 +6,9 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import check_logprobs_close _LIMIT_IMAGE_PER_PROMPT = 4 @@ -197,7 +195,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype, max_tokens, num_logprobs) -> None: """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. diff --git a/tests/models/test_llava_next_video.py b/tests/models/decoder_only/vision_language/test_llava_next_video.py similarity index 98% rename from tests/models/test_llava_next_video.py rename to tests/models/decoder_only/vision_language/test_llava_next_video.py index 6856b15f22ec3..373c8964054cd 100644 --- a/tests/models/test_llava_next_video.py +++ b/tests/models/decoder_only/vision_language/test_llava_next_video.py @@ -8,10 +8,8 @@ sample_frames_from_video) from vllm.sequence import SampleLogprobs -from ..conftest import VIDEO_ASSETS, HfRunner, VllmRunner, _VideoAssets -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import VIDEO_ASSETS, HfRunner, VllmRunner, _VideoAssets +from ...utils import check_logprobs_close _PREFACE = ( "A chat between a curious human and an artificial intelligence assistant. " diff --git a/tests/models/test_minicpmv.py b/tests/models/decoder_only/vision_language/test_minicpmv.py similarity index 97% rename from tests/models/test_minicpmv.py rename to tests/models/decoder_only/vision_language/test_minicpmv.py index 99e49c14f1f26..7bf5d75f400f9 100644 --- a/tests/models/test_minicpmv.py +++ b/tests/models/decoder_only/vision_language/test_minicpmv.py @@ -9,10 +9,8 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner +from ...utils import check_logprobs_close # The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ @@ -65,7 +63,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. 
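The LLaVA-NeXT-Video test above relies on VIDEO_ASSETS fixtures plus the sample_frames_from_video helper referenced in its imports to reduce a full clip to a fixed number of frames before prompting. That helper is not shown in this patch; assuming the clip arrives as a NumPy array of shape (num_frames, H, W, C), evenly spaced index sampling is the straightforward way to approximate it:

import numpy as np


def sample_frames_from_video(frames: np.ndarray,
                             num_frames: int) -> np.ndarray:
    # Sketch only: pick num_frames evenly spaced frames; -1 is assumed
    # here to mean "keep the whole clip".
    total_frames = frames.shape[0]
    if num_frames == -1 or num_frames >= total_frames:
        return frames

    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    return frames[frame_indices]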
diff --git a/tests/models/test_paligemma.py b/tests/models/decoder_only/vision_language/test_paligemma.py similarity index 96% rename from tests/models/test_paligemma.py rename to tests/models/decoder_only/vision_language/test_paligemma.py index beddaaf608a18..d7e29ea76ba4e 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/decoder_only/vision_language/test_paligemma.py @@ -8,10 +8,8 @@ from vllm.sequence import SampleLogprobs from vllm.utils import is_hip -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -69,7 +67,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. diff --git a/tests/models/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py similarity index 97% rename from tests/models/test_phi3v.py rename to tests/models/decoder_only/vision_language/test_phi3v.py index 6ecbf07a08b7c..e248151c40a60 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -9,10 +9,8 @@ from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ..conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -71,7 +69,7 @@ def run_test( ): """Inference result should be the same between hf and vllm. - All the image fixtures for the test is under tests/images. + All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py new file mode 100644 index 0000000000000..072bedfc01a1f --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -0,0 +1,199 @@ +"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. + +Run `pytest tests/models/test_mistral.py`. 
+""" +import json +import uuid +from dataclasses import asdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import pytest +from mistral_common.protocol.instruct.messages import ImageURLChunk +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from mistral_common.tokens.tokenizers.multimodal import image_from_chunk + +from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt +from vllm.multimodal import MultiModalDataBuiltins +from vllm.sequence import Logprob, SampleLogprobs + +from ....utils import VLLM_PATH +from ...utils import check_logprobs_close + +if TYPE_CHECKING: + from _typeshed import StrPath + +MODELS = ["mistralai/Pixtral-12B-2409"] +IMG_URLS = [ + "https://picsum.photos/id/237/400/300", + "https://picsum.photos/id/231/200/300", + "https://picsum.photos/id/27/500/500", + "https://picsum.photos/id/17/150/600", +] +PROMPT = "Describe each image in one short sentence." + + +def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]: + return [{ + "role": + "user", + "content": [{ + "type": "text", + "text": PROMPT, + }] + [{ + "type": "image_url", + "image_url": { + "url": url + } + } for url in urls], + }] + + +def _create_engine_inputs(urls: List[str]) -> TokensPrompt: + msg = _create_msg_format(urls) + + tokenizer = MistralTokenizer.from_model("pixtral") + + request = ChatCompletionRequest(messages=msg) # type: ignore[type-var] + tokenized = tokenizer.encode_chat_completion(request) + + engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens) + + images = [] + for chunk in request.messages[0].content: + if isinstance(chunk, ImageURLChunk): + images.append(image_from_chunk(chunk)) + + mm_data = MultiModalDataBuiltins(image=images) + engine_inputs["multi_modal_data"] = mm_data + + return engine_inputs + + +MSGS = [ + _create_msg_format(IMG_URLS[:1]), + _create_msg_format(IMG_URLS[:2]), + _create_msg_format(IMG_URLS), +] +ENGINE_INPUTS = [ + _create_engine_inputs(IMG_URLS[:1]), + _create_engine_inputs(IMG_URLS[:2]), + _create_engine_inputs(IMG_URLS), +] + +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) +LIMIT_MM_PER_PROMPT = dict(image=4) + +MAX_MODEL_LEN = [8192, 65536] + +FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures" +assert FIXTURES_PATH.exists() + +FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json" +FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json" + +OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]] + + +# For the test author to store golden output in JSON +def _dump_outputs_w_logprobs( + outputs: OutputsLogprobs, + filename: "StrPath", +) -> None: + json_data = [(tokens, text, + [{k: asdict(v) + for k, v in token_logprobs.items()} + for token_logprobs in (logprobs or [])]) + for tokens, text, logprobs in outputs] + + with open(filename, "w") as f: + json.dump(json_data, f) + + +def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: + with open(filename, "rb") as f: + json_data = json.load(f) + + return [(tokens, text, + [{int(k): Logprob(**v) + for k, v in token_logprobs.items()} + for token_logprobs in logprobs]) + for tokens, text, logprobs in json_data] + + +@pytest.mark.skip( + reason= + "Model is too big, test passed on A100 locally but will OOM on CI machine." 
+) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_chat( + vllm_runner, + max_model_len: int, + model: str, + dtype: str, +) -> None: + EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT) + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="mistral", + enable_chunked_prefill=False, + max_model_len=max_model_len, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + ) as vllm_model: + outputs = [] + for msg in MSGS: + output = vllm_model.model.chat(msg, + sampling_params=SAMPLING_PARAMS) + + outputs.extend(output) + + logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) + check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS, + outputs_1_lst=logprobs, + name_0="h100_ref", + name_1="output") + + +@pytest.mark.skip( + reason= + "Model is too big, test passed on A100 locally but will OOM on CI machine." +) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_model_engine(vllm_runner, model: str, dtype: str) -> None: + EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE) + args = EngineArgs( + model=model, + tokenizer_mode="mistral", + enable_chunked_prefill=False, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + dtype=dtype, + ) + engine = LLMEngine.from_engine_args(args) + + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS) + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS) + + outputs = [] + count = 0 + while True: + out = engine.step() + count += 1 + for request_output in out: + if request_output.finished: + outputs.append(request_output) + + if count == 2: + engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2], + SAMPLING_PARAMS) + if not engine.has_unfinished_requests(): + break + + logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) + check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS, + outputs_1_lst=logprobs, + name_0="h100_ref", + name_1="output") diff --git a/tests/models/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py similarity index 98% rename from tests/models/test_qwen.py rename to tests/models/decoder_only/vision_language/test_qwen.py index 5e7f1de99d6c3..e4f79092b7606 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -10,11 +10,9 @@ from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size -from ..conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, - VllmRunner, _ImageAssets) -from .utils import check_logprobs_close - -pytestmark = pytest.mark.vlm +from ....conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, + VllmRunner, _ImageAssets) +from ...utils import check_logprobs_close text_only_models = [ "Qwen/Qwen-7B-Chat" # Has no visual component diff --git a/tests/models/embedding/__init__.py b/tests/models/embedding/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/embedding/language/__init__.py b/tests/models/embedding/language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_embedding.py b/tests/models/embedding/language/test_embedding.py similarity index 100% rename from tests/models/test_embedding.py rename to tests/models/embedding/language/test_embedding.py diff --git a/tests/models/encoder_decoder/__init__.py 
b/tests/models/encoder_decoder/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/encoder_decoder/language/__init__.py b/tests/models/encoder_decoder/language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py similarity index 69% rename from tests/models/test_bart.py rename to tests/models/encoder_decoder/language/test_bart.py index 660b61d1a7ade..758a9b743b397 100644 --- a/tests/models/test_bart.py +++ b/tests/models/encoder_decoder/language/test_bart.py @@ -1,8 +1,8 @@ """Compare the outputs of HF and vLLM for BART models using greedy sampling. -Run `pytest tests/models/test_bart.py`. +Run `pytest tests/models/encoder_decoder/language/test_bart.py`. """ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type from vllm.utils import is_cpu @@ -16,8 +16,10 @@ from vllm.sequence import SampleLogprobs - from ..conftest import DecoderPromptType - from .utils import check_logprobs_close + from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt, + HfRunner, VllmRunner) + from ....utils import multi_gpu_test + from ...utils import check_logprobs_close MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] @@ -34,20 +36,18 @@ def vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs - @pytest.mark.parametrize("model", MODELS) - @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) - @pytest.mark.parametrize("max_tokens", [64]) - @pytest.mark.parametrize("num_logprobs", [5]) - @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) - def test_models( - hf_runner, - vllm_runner, - example_encoder_decoder_prompts, + def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts: List[ExplicitEncoderDecoderPrompt[str, str]], + decoder_prompt_type: DecoderPromptType, model: str, + *, dtype: str, max_tokens: int, num_logprobs: int, - decoder_prompt_type: DecoderPromptType, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, ) -> None: ''' Test the vLLM BART model for a variety of encoder/decoder input prompts, @@ -116,8 +116,29 @@ def test_models( token during the process of validating the vLLM decoded output. ''' - test_case_prompts = example_encoder_decoder_prompts[ - decoder_prompt_type] + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default). + + # Note: currently encoder/decoder models are only compatible with + # enforce_eager=True. Normally this is not a problem because + # for encoder/decoder models vLLM will + # default to enforce_eager=True if enforce_eager + # is left unspecified. However, the + # VllmRunner test fixture (which wraps around the LLM class) defaults to + # enforce_eager=False (a behavior which a number of already-exisitng + # decoder-only unit tests expect), so when testing an encoder/decoder + # model we must explicitly specify enforce_eager=True in the VllmRunner + # constructor. 
+ with vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs) # Configuration settings for HF baseline hf_kwargs = { @@ -135,26 +156,12 @@ def test_models( auto_cls=AutoModelForSeq2SeqLM) as hf_model: hf_outputs = ( hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, + prompts, max_tokens, num_logprobs, **hf_kwargs, )) - # Note: currently encoder/decoder models are only compatible with - # enforce_eager=True. Normally this is not a problem because - # for encoder/decoder models vLLM will - # default to enforce_eager=True if enforce_eager - # is left unspecified. However, the - # VllmRunner test fixture (which wraps around the LLM class) defaults to - # enforce_eager=False (a behavior which a number of already-exisitng - # decoder-only unit tests expect), so when testing an encoder/decoder - # model we must explicitly specify enforce_eager=True in the VllmRunner - # constructor. - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) - hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE else 0) @@ -168,3 +175,49 @@ def test_models( name_1="vllm", num_outputs_0_skip_tokens=hf_skip_tokens, ) + + @pytest.mark.parametrize("model", MODELS) + @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) + @pytest.mark.parametrize("max_tokens", [64]) + @pytest.mark.parametrize("num_logprobs", [5]) + @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) + def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, + model, dtype, max_tokens, num_logprobs, + decoder_prompt_type) -> None: + + run_test( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + @multi_gpu_test(num_gpus=2) + @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) + @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) + @pytest.mark.parametrize("dtype", ["float"]) + @pytest.mark.parametrize("max_tokens", [64]) + @pytest.mark.parametrize("num_logprobs", [5]) + @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) + def test_models_distributed(hf_runner, vllm_runner, + example_encoder_decoder_prompts, + distributed_executor_backend, model, dtype, + max_tokens, num_logprobs, + decoder_prompt_type) -> None: + run_test( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts[decoder_prompt_type], + decoder_prompt_type, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + ) diff --git a/tests/models/fixtures/pixtral_chat.json b/tests/models/fixtures/pixtral_chat.json new file mode 100644 index 0000000000000..643afb83d29b8 --- /dev/null +++ b/tests/models/fixtures/pixtral_chat.json @@ -0,0 +1 @@ +[[[1784, 3937, 6122, 1261, 7244, 10575, 18970, 1408, 1261, 32656, 4691, 1046, 2], "The image shows a black dog sitting on a wooden surface.", [{"1784": {"logprob": -0.11687260121107101, "rank": 1, "decoded_token": "The"}, "4380": {"logprob": -2.366872549057007, "rank": 2, 
"decoded_token": "This"}, "1049": {"logprob": -4.741872787475586, "rank": 3, "decoded_token": "1"}, "117991": {"logprob": -5.991872787475586, "rank": 4, "decoded_token": "Certain"}, "1785": {"logprob": -5.991872787475586, "rank": 5, "decoded_token": "In"}}, {"3937": {"logprob": -0.28887900710105896, "rank": 1, "decoded_token": " image"}, "2158": {"logprob": -1.4138790369033813, "rank": 2, "decoded_token": " first"}, "3977": {"logprob": -5.788878917694092, "rank": 3, "decoded_token": " top"}, "7244": {"logprob": -6.163878917694092, "rank": 4, "decoded_token": " black"}, "8061": {"logprob": -6.788878917694092, "rank": 5, "decoded_token": " images"}}, {"6122": {"logprob": -0.9653709530830383, "rank": 1, "decoded_token": " shows"}, "51948": {"logprob": -1.4653708934783936, "rank": 2, "decoded_token": " depicts"}, "6971": {"logprob": -1.4653708934783936, "rank": 3, "decoded_token": " features"}, "25981": {"logprob": -2.8403708934783936, "rank": 4, "decoded_token": " displays"}, "8688": {"logprob": -2.8403708934783936, "rank": 5, "decoded_token": " contains"}}, {"1261": {"logprob": -0.003059827256947756, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -6.2530598640441895, "rank": 2, "decoded_token": " an"}, "2295": {"logprob": -7.8780598640441895, "rank": 3, "decoded_token": " two"}, "2342": {"logprob": -7.8780598640441895, "rank": 4, "decoded_token": " only"}, "1278": {"logprob": -8.628059387207031, "rank": 5, "decoded_token": " the"}}, {"7244": {"logprob": -0.17616479098796844, "rank": 1, "decoded_token": " black"}, "6231": {"logprob": -2.3011648654937744, "rank": 2, "decoded_token": " close"}, "4249": {"logprob": -3.4261648654937744, "rank": 3, "decoded_token": " single"}, "4329": {"logprob": -5.113664627075195, "rank": 4, "decoded_token": " large"}, "10575": {"logprob": -5.176164627075195, "rank": 5, "decoded_token": " dog"}}, {"10575": {"logprob": -0.10940006375312805, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.4844000339508057, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -4.109400272369385, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -7.296900272369385, "rank": 4, "decoded_token": " Lab"}, "7990": {"logprob": -7.421900272369385, "rank": 5, "decoded_token": " cat"}}, {"18970": {"logprob": -0.8322296738624573, "rank": 1, "decoded_token": " sitting"}, "1454": {"logprob": -1.5822296142578125, "rank": 2, "decoded_token": " with"}, "28528": {"logprob": -1.9572296142578125, "rank": 3, "decoded_token": " lying"}, "7283": {"logprob": -2.2072296142578125, "rank": 4, "decoded_token": " looking"}, "15866": {"logprob": -3.0197296142578125, "rank": 5, "decoded_token": " standing"}}, {"1408": {"logprob": -0.08769982308149338, "rank": 1, "decoded_token": " on"}, "1321": {"logprob": -3.7126998901367188, "rank": 2, "decoded_token": " and"}, "3675": {"logprob": -3.9626998901367188, "rank": 3, "decoded_token": " against"}, "41132": {"logprob": -4.587699890136719, "rank": 4, "decoded_token": " attent"}, "1454": {"logprob": -5.087699890136719, "rank": 5, "decoded_token": " with"}}, {"1261": {"logprob": -0.5400654673576355, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -0.9150654673576355, "rank": 2, "decoded_token": " wooden"}, "3977": {"logprob": -5.415065288543701, "rank": 3, "decoded_token": " top"}, "12603": {"logprob": -5.540065288543701, "rank": 4, "decoded_token": " wood"}, "44130": {"logprob": -6.290065288543701, "rank": 5, "decoded_token": " rust"}}, {"32656": {"logprob": -0.02516966126859188, "rank": 1, "decoded_token": " 
wooden"}, "44130": {"logprob": -4.400169849395752, "rank": 2, "decoded_token": " rust"}, "12603": {"logprob": -5.275169849395752, "rank": 3, "decoded_token": " wood"}, "3403": {"logprob": -5.525169849395752, "rank": 4, "decoded_token": " text"}, "17253": {"logprob": -6.962669849395752, "rank": 5, "decoded_token": " weather"}}, {"4691": {"logprob": -0.7264319658279419, "rank": 1, "decoded_token": " surface"}, "11237": {"logprob": -0.8514319658279419, "rank": 2, "decoded_token": " floor"}, "7042": {"logprob": -2.6014318466186523, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -5.226431846618652, "rank": 4, "decoded_token": " deck"}, "1615": {"logprob": -5.726431846618652, "rank": 5, "decoded_token": " pl"}}, {"1046": {"logprob": -0.4668232202529907, "rank": 1, "decoded_token": "."}, "1044": {"logprob": -1.9668232202529907, "rank": 2, "decoded_token": ","}, "1321": {"logprob": -2.466823101043701, "rank": 3, "decoded_token": " and"}, "7283": {"logprob": -2.716823101043701, "rank": 4, "decoded_token": " looking"}, "1454": {"logprob": -2.716823101043701, "rank": 5, "decoded_token": " with"}}, {"2": {"logprob": -0.002247072057798505, "rank": 1, "decoded_token": ""}, "1531": {"logprob": -6.627246856689453, "rank": 2, "decoded_token": " The"}, "1032": {"logprob": -7.127246856689453, "rank": 3, "decoded_token": " "}, "3730": {"logprob": -9.877246856689453, "rank": 4, "decoded_token": " There"}, "1256": {"logprob": -11.127246856689453, "rank": 5, "decoded_token": " "}}]], [[1049, 1046, 1349, 7244, 10575, 1454, 2327, 94766, 32961, 53048, 41132, 3923, 1408, 1261, 32656, 4691, 1626, 1050, 1046, 1349, 15375, 24361, 4521, 1454, 122203, 27469, 94973, 2425, 1261, 16152, 1121, 21283, 1046, 2], "1. A black dog with floppy ears sits attentively on a wooden surface.\n2. 
A vast mountain range with rugged peaks stretches under a cloudy sky.", [{"1049": {"logprob": -0.42824622988700867, "rank": 1, "decoded_token": "1"}, "1045": {"logprob": -1.553246259689331, "rank": 2, "decoded_token": "-"}, "1065": {"logprob": -2.428246259689331, "rank": 3, "decoded_token": "A"}, "1784": {"logprob": -4.053246021270752, "rank": 4, "decoded_token": "The"}, "69957": {"logprob": -4.428246021270752, "rank": 5, "decoded_token": "Sure"}}, {"1046": {"logprob": -1.9788545614574105e-05, "rank": 1, "decoded_token": "."}, "1058": {"logprob": -11.750020027160645, "rank": 2, "decoded_token": ":"}, "3590": {"logprob": -12.125020027160645, "rank": 3, "decoded_token": ".A"}, "1065": {"logprob": -13.062520027160645, "rank": 4, "decoded_token": "A"}, "1041": {"logprob": -13.750020027160645, "rank": 5, "decoded_token": ")"}}, {"1349": {"logprob": -0.14020134508609772, "rank": 1, "decoded_token": " A"}, "1429": {"logprob": -2.3902013301849365, "rank": 2, "decoded_token": " \""}, "1603": {"logprob": -3.7652013301849365, "rank": 3, "decoded_token": " **"}, "11967": {"logprob": -4.890201568603516, "rank": 4, "decoded_token": " Image"}, "1531": {"logprob": -5.015201568603516, "rank": 5, "decoded_token": " The"}}, {"7244": {"logprob": -0.2003599852323532, "rank": 1, "decoded_token": " black"}, "38462": {"logprob": -3.075360059738159, "rank": 2, "decoded_token": " curious"}, "68076": {"logprob": -3.575360059738159, "rank": 3, "decoded_token": " cute"}, "4329": {"logprob": -3.887860059738159, "rank": 4, "decoded_token": " large"}, "6231": {"logprob": -4.32535982131958, "rank": 5, "decoded_token": " close"}}, {"10575": {"logprob": -0.18818901479244232, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.0631890296936035, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -3.1881890296936035, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -6.9381890296936035, "rank": 4, "decoded_token": " Lab"}, "8636": {"logprob": -7.3131890296936035, "rank": 5, "decoded_token": " lab"}}, {"1454": {"logprob": -0.5699259042739868, "rank": 1, "decoded_token": " with"}, "53048": {"logprob": -1.2574259042739868, "rank": 2, "decoded_token": " sits"}, "1395": {"logprob": -3.0699257850646973, "rank": 3, "decoded_token": " is"}, "22524": {"logprob": -3.6324257850646973, "rank": 4, "decoded_token": " lies"}, "18970": {"logprob": -3.7574257850646973, "rank": 5, "decoded_token": " sitting"}}, {"2327": {"logprob": -1.2377738952636719, "rank": 1, "decoded_token": " fl"}, "1261": {"logprob": -1.3627738952636719, "rank": 2, "decoded_token": " a"}, "17300": {"logprob": -1.9252738952636719, "rank": 3, "decoded_token": " soul"}, "100089": {"logprob": -2.675273895263672, "rank": 4, "decoded_token": " expressive"}, "6444": {"logprob": -3.237773895263672, "rank": 5, "decoded_token": " soft"}}, {"94766": {"logprob": -0.0025601964443922043, "rank": 1, "decoded_token": "oppy"}, "124603": {"logprob": -6.315060138702393, "rank": 2, "decoded_token": "uffy"}, "1484": {"logprob": -7.877560138702393, "rank": 3, "decoded_token": "op"}, "24897": {"logprob": -8.81506061553955, "rank": 4, "decoded_token": "appy"}, "102477": {"logprob": -9.69006061553955, "rank": 5, "decoded_token": "opping"}}, {"32961": {"logprob": -5.113947918289341e-05, "rank": 1, "decoded_token": " ears"}, "16962": {"logprob": -11.250051498413086, "rank": 2, "decoded_token": " ear"}, "5731": {"logprob": -11.812551498413086, "rank": 3, "decoded_token": " eyes"}, "3351": {"logprob": -12.000051498413086, "rank": 4, "decoded_token": " years"}, 
"42071": {"logprob": -13.062551498413086, "rank": 5, "decoded_token": " cheeks"}}, {"53048": {"logprob": -0.6179640889167786, "rank": 1, "decoded_token": " sits"}, "10637": {"logprob": -1.9929640293121338, "rank": 2, "decoded_token": " looks"}, "1321": {"logprob": -2.430464029312134, "rank": 3, "decoded_token": " and"}, "1395": {"logprob": -2.617964029312134, "rank": 4, "decoded_token": " is"}, "18970": {"logprob": -3.055464029312134, "rank": 5, "decoded_token": " sitting"}}, {"41132": {"logprob": -0.3746516704559326, "rank": 1, "decoded_token": " attent"}, "1408": {"logprob": -2.3121516704559326, "rank": 2, "decoded_token": " on"}, "106534": {"logprob": -2.3746516704559326, "rank": 3, "decoded_token": " calmly"}, "12276": {"logprob": -2.6246516704559326, "rank": 4, "decoded_token": " alert"}, "6482": {"logprob": -5.124651908874512, "rank": 5, "decoded_token": " patient"}}, {"3923": {"logprob": -8.463501580990851e-05, "rank": 1, "decoded_token": "ively"}, "1556": {"logprob": -9.50008487701416, "rank": 2, "decoded_token": "ive"}, "6655": {"logprob": -11.87508487701416, "rank": 3, "decoded_token": "atively"}, "3929": {"logprob": -14.00008487701416, "rank": 4, "decoded_token": "ently"}, "47885": {"logprob": -14.62508487701416, "rank": 5, "decoded_token": "edly"}}, {"1408": {"logprob": -0.06439964473247528, "rank": 1, "decoded_token": " on"}, "3675": {"logprob": -3.0643997192382812, "rank": 2, "decoded_token": " against"}, "1294": {"logprob": -4.939399719238281, "rank": 3, "decoded_token": " in"}, "7283": {"logprob": -5.689399719238281, "rank": 4, "decoded_token": " looking"}, "1044": {"logprob": -5.814399719238281, "rank": 5, "decoded_token": ","}}, {"1261": {"logprob": -0.2108541578054428, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -1.710854172706604, "rank": 2, "decoded_token": " wooden"}, "17253": {"logprob": -5.5858540534973145, "rank": 3, "decoded_token": " weather"}, "44130": {"logprob": -6.0858540534973145, "rank": 4, "decoded_token": " rust"}, "12603": {"logprob": -6.9608540534973145, "rank": 5, "decoded_token": " wood"}}, {"32656": {"logprob": -0.08556432276964188, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -2.710564374923706, "rank": 2, "decoded_token": " rust"}, "17253": {"logprob": -4.710564136505127, "rank": 3, "decoded_token": " weather"}, "12603": {"logprob": -5.960564136505127, "rank": 4, "decoded_token": " wood"}, "3403": {"logprob": -5.960564136505127, "rank": 5, "decoded_token": " text"}}, {"4691": {"logprob": -0.7751782536506653, "rank": 1, "decoded_token": " surface"}, "11237": {"logprob": -0.7751782536506653, "rank": 2, "decoded_token": " floor"}, "7042": {"logprob": -2.9001781940460205, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -4.1501784324646, "rank": 4, "decoded_token": " deck"}, "92504": {"logprob": -6.1501784324646, "rank": 5, "decoded_token": " backdrop"}}, {"1626": {"logprob": -0.12918435037136078, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -2.3791842460632324, "rank": 2, "decoded_token": ","}, "1046": {"logprob": -4.129184246063232, "rank": 3, "decoded_token": "."}, "1338": {"logprob": -5.129184246063232, "rank": 4, "decoded_token": ".\n\n"}, "7283": {"logprob": -5.629184246063232, "rank": 5, "decoded_token": " looking"}}, {"1050": {"logprob": -0.00017474555352237076, "rank": 1, "decoded_token": "2"}, "1256": {"logprob": -9.000174522399902, "rank": 2, "decoded_token": " "}, "1032": {"logprob": -10.875174522399902, "rank": 3, "decoded_token": " "}, "1293": {"logprob": -11.625174522399902, "rank": 
4, "decoded_token": " "}, "1051": {"logprob": -12.125174522399902, "rank": 5, "decoded_token": "3"}}, {"1046": {"logprob": -7.629365427419543e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -12.875007629394531, "rank": 2, "decoded_token": ".A"}, "1626": {"logprob": -13.062507629394531, "rank": 3, "decoded_token": ".\n"}, "1338": {"logprob": -14.562507629394531, "rank": 4, "decoded_token": ".\n\n"}, "1058": {"logprob": -14.812507629394531, "rank": 5, "decoded_token": ":"}}, {"1349": {"logprob": -0.558266282081604, "rank": 1, "decoded_token": " A"}, "11826": {"logprob": -1.495766282081604, "rank": 2, "decoded_token": " Maj"}, "37159": {"logprob": -2.2457661628723145, "rank": 3, "decoded_token": " Snow"}, "113465": {"logprob": -3.9957661628723145, "rank": 4, "decoded_token": " Rug"}, "1531": {"logprob": -3.9957661628723145, "rank": 5, "decoded_token": " The"}}, {"15375": {"logprob": -0.6446555852890015, "rank": 1, "decoded_token": " vast"}, "37849": {"logprob": -2.019655704498291, "rank": 2, "decoded_token": " breat"}, "61082": {"logprob": -2.394655704498291, "rank": 3, "decoded_token": " panor"}, "10726": {"logprob": -3.082155704498291, "rank": 4, "decoded_token": " scen"}, "2169": {"logprob": -3.207155704498291, "rank": 5, "decoded_token": " ser"}}, {"24361": {"logprob": -0.7034653425216675, "rank": 1, "decoded_token": " mountain"}, "127945": {"logprob": -1.9534653425216675, "rank": 2, "decoded_token": " mountainous"}, "1044": {"logprob": -2.078465461730957, "rank": 3, "decoded_token": ","}, "4521": {"logprob": -2.328465461730957, "rank": 4, "decoded_token": " range"}, "28035": {"logprob": -2.453465461730957, "rank": 5, "decoded_token": " landscape"}}, {"4521": {"logprob": -0.07058106362819672, "rank": 1, "decoded_token": " range"}, "28035": {"logprob": -2.6955809593200684, "rank": 2, "decoded_token": " landscape"}, "37691": {"logprob": -8.320581436157227, "rank": 3, "decoded_token": " valley"}, "12248": {"logprob": -9.445581436157227, "rank": 4, "decoded_token": " peak"}, "13327": {"logprob": -9.695581436157227, "rank": 5, "decoded_token": " scene"}}, {"1454": {"logprob": -1.1448894739151, "rank": 1, "decoded_token": " with"}, "94973": {"logprob": -1.1448894739151, "rank": 2, "decoded_token": " stretches"}, "2425": {"logprob": -1.8948894739151, "rank": 3, "decoded_token": " under"}, "1395": {"logprob": -2.5198893547058105, "rank": 4, "decoded_token": " is"}, "13875": {"logprob": -3.0198893547058105, "rank": 5, "decoded_token": " covered"}}, {"122203": {"logprob": -1.0288245677947998, "rank": 1, "decoded_token": " rugged"}, "58127": {"logprob": -1.6538245677947998, "rank": 2, "decoded_token": " jag"}, "27469": {"logprob": -2.1538245677948, "rank": 3, "decoded_token": " peaks"}, "23745": {"logprob": -2.6538245677948, "rank": 4, "decoded_token": " snow"}, "95746": {"logprob": -2.8413245677948, "rank": 5, "decoded_token": " rocky"}}, {"27469": {"logprob": -0.20564845204353333, "rank": 1, "decoded_token": " peaks"}, "24765": {"logprob": -2.580648422241211, "rank": 2, "decoded_token": " terrain"}, "130655": {"logprob": -2.955648422241211, "rank": 3, "decoded_token": ""}, "1044": {"logprob": -3.580648422241211, "rank": 4, "decoded_token": ","}, "61263": {"logprob": -4.455648422241211, "rank": 5, "decoded_token": " slopes"}}, {"94973": {"logprob": -1.0839273929595947, "rank": 1, "decoded_token": " stretches"}, "1321": {"logprob": -1.1464273929595947, "rank": 2, "decoded_token": " and"}, "2425": {"logprob": -1.7714273929595947, "rank": 3, "decoded_token": " under"}, "13875": {"logprob": 
-3.0839273929595947, "rank": 4, "decoded_token": " covered"}, "1395": {"logprob": -3.2714273929595947, "rank": 5, "decoded_token": " is"}}, {"2425": {"logprob": -0.9016233682632446, "rank": 1, "decoded_token": " under"}, "5669": {"logprob": -1.0266233682632446, "rank": 2, "decoded_token": " across"}, "1848": {"logprob": -1.9016233682632446, "rank": 3, "decoded_token": " out"}, "2203": {"logprob": -3.151623249053955, "rank": 4, "decoded_token": " into"}, "8994": {"logprob": -4.026623249053955, "rank": 5, "decoded_token": " towards"}}, {"1261": {"logprob": -0.00555459875613451, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -5.380554676055908, "rank": 2, "decoded_token": " an"}, "1278": {"logprob": -7.630554676055908, "rank": 3, "decoded_token": " the"}, "2136": {"logprob": -9.31805419921875, "rank": 4, "decoded_token": " over"}, "16152": {"logprob": -9.38055419921875, "rank": 5, "decoded_token": " cloud"}}, {"16152": {"logprob": -0.6862213015556335, "rank": 1, "decoded_token": " cloud"}, "6133": {"logprob": -1.4362213611602783, "rank": 2, "decoded_token": " clear"}, "18416": {"logprob": -2.6862213611602783, "rank": 3, "decoded_token": " haz"}, "27254": {"logprob": -3.0612213611602783, "rank": 4, "decoded_token": " partly"}, "4391": {"logprob": -3.1862213611602783, "rank": 5, "decoded_token": " light"}}, {"1121": {"logprob": -0.10446903109550476, "rank": 1, "decoded_token": "y"}, "4527": {"logprob": -2.854469060897827, "rank": 2, "decoded_token": "less"}, "1286": {"logprob": -3.479469060897827, "rank": 3, "decoded_token": "ed"}, "114525": {"logprob": -5.479468822479248, "rank": 4, "decoded_token": "-covered"}, "77187": {"logprob": -5.479468822479248, "rank": 5, "decoded_token": "-filled"}}, {"21283": {"logprob": -0.003459066851064563, "rank": 1, "decoded_token": " sky"}, "10991": {"logprob": -6.3784589767456055, "rank": 2, "decoded_token": " blue"}, "1044": {"logprob": -6.8784589767456055, "rank": 3, "decoded_token": ","}, "26549": {"logprob": -7.8784589767456055, "rank": 4, "decoded_token": " gray"}, "34052": {"logprob": -8.503458976745605, "rank": 5, "decoded_token": " grey"}}, {"1046": {"logprob": -0.01103890035301447, "rank": 1, "decoded_token": "."}, "1044": {"logprob": -4.636038780212402, "rank": 2, "decoded_token": ","}, "1338": {"logprob": -7.261038780212402, "rank": 3, "decoded_token": ".\n\n"}, "1294": {"logprob": -8.136038780212402, "rank": 4, "decoded_token": " in"}, "1454": {"logprob": -8.761038780212402, "rank": 5, "decoded_token": " with"}}, {"2": {"logprob": -9.059865078597795e-06, "rank": 1, "decoded_token": ""}, "1032": {"logprob": -11.625008583068848, "rank": 2, "decoded_token": " "}, "1256": {"logprob": -16.125009536743164, "rank": 3, "decoded_token": " "}, "1319": {"logprob": -17.375009536743164, "rank": 4, "decoded_token": " ("}, "1766": {"logprob": -18.750009536743164, "rank": 5, "decoded_token": " ["}}]], [[1049, 1046, 1349, 7244, 10575, 53048, 41132, 3923, 1408, 1261, 32656, 11237, 1626, 1050, 1046, 1349, 15375, 24361, 4521, 94973, 5669, 1278, 48932, 2425, 1261, 16152, 1121, 21283, 1626, 1051, 1046, 8342, 71284, 7377, 1394, 22140, 1294, 1278, 27208, 1513, 97558, 1626, 1052, 1046, 1349, 53301, 59396, 3549, 13335, 2645, 1261, 1295, 3506, 11223, 12097, 1046, 2], "1. A black dog sits attentively on a wooden floor.\n2. A vast mountain range stretches across the horizon under a cloudy sky.\n3. Surfers wait for waves in the ocean at sunset.\n4. 
A winding gravel path leads through a lush green park.", [{"1049": {"logprob": -0.05001257359981537, "rank": 1, "decoded_token": "1"}, "1045": {"logprob": -3.1750125885009766, "rank": 2, "decoded_token": "-"}, "69957": {"logprob": -5.925012588500977, "rank": 3, "decoded_token": "Sure"}, "11745": {"logprob": -6.425012588500977, "rank": 4, "decoded_token": "Here"}, "1065": {"logprob": -6.425012588500977, "rank": 5, "decoded_token": "A"}}, {"1046": {"logprob": -9.536697689327411e-06, "rank": 1, "decoded_token": "."}, "1058": {"logprob": -11.875009536743164, "rank": 2, "decoded_token": ":"}, "3590": {"logprob": -13.375009536743164, "rank": 3, "decoded_token": ".A"}, "1041": {"logprob": -14.750009536743164, "rank": 4, "decoded_token": ")"}, "1065": {"logprob": -15.687509536743164, "rank": 5, "decoded_token": "A"}}, {"1349": {"logprob": -0.12580634653568268, "rank": 1, "decoded_token": " A"}, "1429": {"logprob": -2.3758063316345215, "rank": 2, "decoded_token": " \""}, "1531": {"logprob": -4.6258063316345215, "rank": 3, "decoded_token": " The"}, "11967": {"logprob": -4.6258063316345215, "rank": 4, "decoded_token": " Image"}, "1603": {"logprob": -5.6258063316345215, "rank": 5, "decoded_token": " **"}}, {"7244": {"logprob": -0.15412142872810364, "rank": 1, "decoded_token": " black"}, "68076": {"logprob": -3.3416213989257812, "rank": 2, "decoded_token": " cute"}, "6231": {"logprob": -3.9666213989257812, "rank": 3, "decoded_token": " close"}, "38462": {"logprob": -4.216621398925781, "rank": 4, "decoded_token": " curious"}, "4329": {"logprob": -4.404121398925781, "rank": 5, "decoded_token": " large"}}, {"10575": {"logprob": -0.12086891382932663, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.3708689212799072, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -3.9958689212799072, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -7.683368682861328, "rank": 4, "decoded_token": " Lab"}, "8636": {"logprob": -7.808368682861328, "rank": 5, "decoded_token": " lab"}}, {"53048": {"logprob": -0.8729249238967896, "rank": 1, "decoded_token": " sits"}, "1454": {"logprob": -1.1229249238967896, "rank": 2, "decoded_token": " with"}, "1395": {"logprob": -2.4354248046875, "rank": 3, "decoded_token": " is"}, "18970": {"logprob": -2.6854248046875, "rank": 4, "decoded_token": " sitting"}, "22524": {"logprob": -3.6854248046875, "rank": 5, "decoded_token": " lies"}}, {"41132": {"logprob": -0.5888903737068176, "rank": 1, "decoded_token": " attent"}, "106534": {"logprob": -1.2763903141021729, "rank": 2, "decoded_token": " calmly"}, "12276": {"logprob": -2.838890314102173, "rank": 3, "decoded_token": " alert"}, "1408": {"logprob": -2.901390314102173, "rank": 4, "decoded_token": " on"}, "6482": {"logprob": -5.026390552520752, "rank": 5, "decoded_token": " patient"}}, {"3923": {"logprob": -9.16677454370074e-05, "rank": 1, "decoded_token": "ively"}, "1556": {"logprob": -9.625091552734375, "rank": 2, "decoded_token": "ive"}, "6655": {"logprob": -10.875091552734375, "rank": 3, "decoded_token": "atively"}, "3929": {"logprob": -13.125091552734375, "rank": 4, "decoded_token": "ently"}, "47885": {"logprob": -13.750091552734375, "rank": 5, "decoded_token": "edly"}}, {"1408": {"logprob": -0.052677519619464874, "rank": 1, "decoded_token": " on"}, "3675": {"logprob": -3.802677631378174, "rank": 2, "decoded_token": " against"}, "1454": {"logprob": -4.302677631378174, "rank": 3, "decoded_token": " with"}, "1294": {"logprob": -5.177677631378174, "rank": 4, "decoded_token": " in"}, "7283": {"logprob": 
-5.427677631378174, "rank": 5, "decoded_token": " looking"}}, {"1261": {"logprob": -0.36706605553627014, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -1.2420660257339478, "rank": 2, "decoded_token": " wooden"}, "17253": {"logprob": -4.617065906524658, "rank": 3, "decoded_token": " weather"}, "44130": {"logprob": -5.742065906524658, "rank": 4, "decoded_token": " rust"}, "12603": {"logprob": -6.617065906524658, "rank": 5, "decoded_token": " wood"}}, {"32656": {"logprob": -0.07824385166168213, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -2.8282437324523926, "rank": 2, "decoded_token": " rust"}, "17253": {"logprob": -4.703243732452393, "rank": 3, "decoded_token": " weather"}, "12603": {"logprob": -5.828243732452393, "rank": 4, "decoded_token": " wood"}, "3403": {"logprob": -5.953243732452393, "rank": 5, "decoded_token": " text"}}, {"11237": {"logprob": -0.5853750705718994, "rank": 1, "decoded_token": " floor"}, "4691": {"logprob": -1.0853750705718994, "rank": 2, "decoded_token": " surface"}, "7042": {"logprob": -2.7103750705718994, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -3.5853750705718994, "rank": 4, "decoded_token": " deck"}, "92504": {"logprob": -6.08537483215332, "rank": 5, "decoded_token": " backdrop"}}, {"1626": {"logprob": -0.7340722680091858, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -0.8590722680091858, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -3.359072208404541, "rank": 3, "decoded_token": " with"}, "7283": {"logprob": -3.609072208404541, "rank": 4, "decoded_token": " looking"}, "1321": {"logprob": -4.109072208404541, "rank": 5, "decoded_token": " and"}}, {"1050": {"logprob": -1.1324817933200393e-05, "rank": 1, "decoded_token": "2"}, "1051": {"logprob": -11.625011444091797, "rank": 2, "decoded_token": "3"}, "1256": {"logprob": -14.000011444091797, "rank": 3, "decoded_token": " "}, "1049": {"logprob": -14.625011444091797, "rank": 4, "decoded_token": "1"}, "1032": {"logprob": -14.625011444091797, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -2.50339189733495e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.56250286102295, "rank": 2, "decoded_token": ".A"}, "1626": {"logprob": -15.43750286102295, "rank": 3, "decoded_token": ".\n"}, "4700": {"logprob": -15.50000286102295, "rank": 4, "decoded_token": ".M"}, "3051": {"logprob": -16.000001907348633, "rank": 5, "decoded_token": ".S"}}, {"1349": {"logprob": -0.6769706010818481, "rank": 1, "decoded_token": " A"}, "11826": {"logprob": -1.9269706010818481, "rank": 2, "decoded_token": " Maj"}, "37159": {"logprob": -2.1144704818725586, "rank": 3, "decoded_token": " Snow"}, "27260": {"logprob": -2.6144704818725586, "rank": 4, "decoded_token": " Mountain"}, "113465": {"logprob": -2.8644704818725586, "rank": 5, "decoded_token": " Rug"}}, {"15375": {"logprob": -0.9251430034637451, "rank": 1, "decoded_token": " vast"}, "10726": {"logprob": -2.300143003463745, "rank": 2, "decoded_token": " scen"}, "4521": {"logprob": -2.362643003463745, "rank": 3, "decoded_token": " range"}, "122203": {"logprob": -2.425143003463745, "rank": 4, "decoded_token": " rugged"}, "61082": {"logprob": -2.800143003463745, "rank": 5, "decoded_token": " panor"}}, {"24361": {"logprob": -0.5277582406997681, "rank": 1, "decoded_token": " mountain"}, "127945": {"logprob": -1.902758240699768, "rank": 2, "decoded_token": " mountainous"}, "28035": {"logprob": -2.5277581214904785, "rank": 3, "decoded_token": " landscape"}, "4521": {"logprob": -2.5277581214904785, "rank": 4, 
"decoded_token": " range"}, "1044": {"logprob": -2.7777581214904785, "rank": 5, "decoded_token": ","}}, {"4521": {"logprob": -0.055658817291259766, "rank": 1, "decoded_token": " range"}, "28035": {"logprob": -2.9306588172912598, "rank": 2, "decoded_token": " landscape"}, "37691": {"logprob": -8.430658340454102, "rank": 3, "decoded_token": " valley"}, "13327": {"logprob": -9.055658340454102, "rank": 4, "decoded_token": " scene"}, "3719": {"logprob": -9.805658340454102, "rank": 5, "decoded_token": " view"}}, {"94973": {"logprob": -0.6880245208740234, "rank": 1, "decoded_token": " stretches"}, "2425": {"logprob": -1.7505245208740234, "rank": 2, "decoded_token": " under"}, "1395": {"logprob": -2.3130245208740234, "rank": 3, "decoded_token": " is"}, "1454": {"logprob": -2.6880245208740234, "rank": 4, "decoded_token": " with"}, "7038": {"logprob": -3.2505245208740234, "rank": 5, "decoded_token": " extends"}}, {"5669": {"logprob": -0.4545598328113556, "rank": 1, "decoded_token": " across"}, "2425": {"logprob": -1.4545598030090332, "rank": 2, "decoded_token": " under"}, "1848": {"logprob": -2.454559803009033, "rank": 3, "decoded_token": " out"}, "2203": {"logprob": -4.204559803009033, "rank": 4, "decoded_token": " into"}, "25136": {"logprob": -4.642059803009033, "rank": 5, "decoded_token": " beneath"}}, {"1278": {"logprob": -0.23015151917934418, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -1.6051515340805054, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -5.605151653289795, "rank": 3, "decoded_token": " an"}, "2425": {"logprob": -7.167651653289795, "rank": 4, "decoded_token": " under"}, "1454": {"logprob": -10.167651176452637, "rank": 5, "decoded_token": " with"}}, {"48932": {"logprob": -0.2797861397266388, "rank": 1, "decoded_token": " horizon"}, "21283": {"logprob": -2.0297861099243164, "rank": 2, "decoded_token": " sky"}, "3937": {"logprob": -3.2797861099243164, "rank": 3, "decoded_token": " image"}, "28035": {"logprob": -3.6547861099243164, "rank": 4, "decoded_token": " landscape"}, "3044": {"logprob": -3.7797861099243164, "rank": 5, "decoded_token": " sk"}}, {"2425": {"logprob": -0.28862035274505615, "rank": 1, "decoded_token": " under"}, "1044": {"logprob": -2.4136204719543457, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -2.5386204719543457, "rank": 3, "decoded_token": " with"}, "1626": {"logprob": -3.7886204719543457, "rank": 4, "decoded_token": ".\n"}, "1408": {"logprob": -3.9136204719543457, "rank": 5, "decoded_token": " on"}}, {"1261": {"logprob": -0.04524127021431923, "rank": 1, "decoded_token": " a"}, "16152": {"logprob": -4.045241355895996, "rank": 2, "decoded_token": " cloud"}, "1420": {"logprob": -4.045241355895996, "rank": 3, "decoded_token": " an"}, "2136": {"logprob": -6.107741355895996, "rank": 4, "decoded_token": " over"}, "6133": {"logprob": -6.357741355895996, "rank": 5, "decoded_token": " clear"}}, {"16152": {"logprob": -0.19613930583000183, "rank": 1, "decoded_token": " cloud"}, "6133": {"logprob": -2.883639335632324, "rank": 2, "decoded_token": " clear"}, "27254": {"logprob": -3.508639335632324, "rank": 3, "decoded_token": " partly"}, "18416": {"logprob": -3.883639335632324, "rank": 4, "decoded_token": " haz"}, "4391": {"logprob": -4.321139335632324, "rank": 5, "decoded_token": " light"}}, {"1121": {"logprob": -0.05146069824695587, "rank": 1, "decoded_token": "y"}, "1286": {"logprob": -3.8014607429504395, "rank": 2, "decoded_token": "ed"}, "77187": {"logprob": -4.5514607429504395, "rank": 3, "decoded_token": "-filled"}, "114525": {"logprob": 
-4.9264607429504395, "rank": 4, "decoded_token": "-covered"}, "4527": {"logprob": -4.9264607429504395, "rank": 5, "decoded_token": "less"}}, {"21283": {"logprob": -0.00033122775494121015, "rank": 1, "decoded_token": " sky"}, "10991": {"logprob": -8.875330924987793, "rank": 2, "decoded_token": " blue"}, "1044": {"logprob": -9.500330924987793, "rank": 3, "decoded_token": ","}, "26549": {"logprob": -10.500330924987793, "rank": 4, "decoded_token": " gray"}, "34052": {"logprob": -11.375330924987793, "rank": 5, "decoded_token": " grey"}}, {"1626": {"logprob": -0.00012683063687290996, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -9.500126838684082, "rank": 2, "decoded_token": ","}, "1046": {"logprob": -10.500126838684082, "rank": 3, "decoded_token": "."}, "1454": {"logprob": -10.875126838684082, "rank": 4, "decoded_token": " with"}, "1294": {"logprob": -13.375126838684082, "rank": 5, "decoded_token": " in"}}, {"1051": {"logprob": -3.2186455882765586e-06, "rank": 1, "decoded_token": "3"}, "1052": {"logprob": -12.75000286102295, "rank": 2, "decoded_token": "4"}, "1050": {"logprob": -15.00000286102295, "rank": 3, "decoded_token": "2"}, "1049": {"logprob": -17.000003814697266, "rank": 4, "decoded_token": "1"}, "1032": {"logprob": -17.937503814697266, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -1.9073468138230965e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -14.625001907348633, "rank": 2, "decoded_token": ".A"}, "5226": {"logprob": -15.625001907348633, "rank": 3, "decoded_token": ".D"}, "6847": {"logprob": -15.750001907348633, "rank": 4, "decoded_token": ".T"}, "4700": {"logprob": -16.750001907348633, "rank": 5, "decoded_token": ".M"}}, {"8342": {"logprob": -0.5928499102592468, "rank": 1, "decoded_token": " Sur"}, "1349": {"logprob": -1.6553499698638916, "rank": 2, "decoded_token": " A"}, "22468": {"logprob": -2.5303499698638916, "rank": 3, "decoded_token": " Several"}, "1488": {"logprob": -2.7178499698638916, "rank": 4, "decoded_token": " W"}, "15035": {"logprob": -3.2178499698638916, "rank": 5, "decoded_token": " People"}}, {"71284": {"logprob": -0.003268140833824873, "rank": 1, "decoded_token": "fers"}, "1102": {"logprob": -5.878268241882324, "rank": 2, "decoded_token": "f"}, "1726": {"logprob": -7.753268241882324, "rank": 3, "decoded_token": "fer"}, "61888": {"logprob": -12.315768241882324, "rank": 4, "decoded_token": "fline"}, "2119": {"logprob": -13.065768241882324, "rank": 5, "decoded_token": "fter"}}, {"7377": {"logprob": -1.4883846044540405, "rank": 1, "decoded_token": " wait"}, "1584": {"logprob": -1.7383846044540405, "rank": 2, "decoded_token": " are"}, "88014": {"logprob": -1.9258846044540405, "rank": 3, "decoded_token": " paddle"}, "1294": {"logprob": -1.9258846044540405, "rank": 4, "decoded_token": " in"}, "24434": {"logprob": -2.23838472366333, "rank": 5, "decoded_token": " ride"}}, {"1394": {"logprob": -0.6120346188545227, "rank": 1, "decoded_token": " for"}, "1294": {"logprob": -0.9870346188545227, "rank": 2, "decoded_token": " in"}, "1408": {"logprob": -2.737034559249878, "rank": 3, "decoded_token": " on"}, "6482": {"logprob": -4.487034797668457, "rank": 4, "decoded_token": " patient"}, "1321": {"logprob": -5.612034797668457, "rank": 5, "decoded_token": " and"}}, {"22140": {"logprob": -0.008224429562687874, "rank": 1, "decoded_token": " waves"}, "1278": {"logprob": -5.5082244873046875, "rank": 2, "decoded_token": " the"}, "1261": {"logprob": -5.6332244873046875, "rank": 3, "decoded_token": " a"}, "39460": {"logprob": -8.133224487304688, "rank": 
4, "decoded_token": " incoming"}, "1321": {"logprob": -9.758224487304688, "rank": 5, "decoded_token": " and"}}, {"1294": {"logprob": -0.3204176723957062, "rank": 1, "decoded_token": " in"}, "1408": {"logprob": -2.195417642593384, "rank": 2, "decoded_token": " on"}, "1513": {"logprob": -2.320417642593384, "rank": 3, "decoded_token": " at"}, "3016": {"logprob": -3.695417642593384, "rank": 4, "decoded_token": " while"}, "1435": {"logprob": -3.820417642593384, "rank": 5, "decoded_token": " as"}}, {"1278": {"logprob": -0.004615250043570995, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -6.192115306854248, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -6.942115306854248, "rank": 3, "decoded_token": " an"}, "40466": {"logprob": -7.317115306854248, "rank": 4, "decoded_token": " shallow"}, "26517": {"logprob": -7.879615306854248, "rank": 5, "decoded_token": " calm"}}, {"27208": {"logprob": -0.06491076946258545, "rank": 1, "decoded_token": " ocean"}, "7786": {"logprob": -3.439910888671875, "rank": 2, "decoded_token": " distance"}, "5124": {"logprob": -5.314910888671875, "rank": 3, "decoded_token": " early"}, "26517": {"logprob": -5.377410888671875, "rank": 4, "decoded_token": " calm"}, "11196": {"logprob": -5.377410888671875, "rank": 5, "decoded_token": " sea"}}, {"1513": {"logprob": -1.144903540611267, "rank": 1, "decoded_token": " at"}, "1435": {"logprob": -1.269903540611267, "rank": 2, "decoded_token": " as"}, "3184": {"logprob": -1.394903540611267, "rank": 3, "decoded_token": " during"}, "3016": {"logprob": -3.0199036598205566, "rank": 4, "decoded_token": " while"}, "6117": {"logprob": -3.1449036598205566, "rank": 5, "decoded_token": " near"}}, {"97558": {"logprob": -0.12556149065494537, "rank": 1, "decoded_token": " sunset"}, "11729": {"logprob": -2.875561475753784, "rank": 2, "decoded_token": " sun"}, "1266": {"logprob": -3.375561475753784, "rank": 3, "decoded_token": " d"}, "54507": {"logprob": -4.000561714172363, "rank": 4, "decoded_token": " dawn"}, "1261": {"logprob": -5.125561714172363, "rank": 5, "decoded_token": " a"}}, {"1626": {"logprob": -0.26737067103385925, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -2.2673707008361816, "rank": 2, "decoded_token": ","}, "3016": {"logprob": -2.7673707008361816, "rank": 3, "decoded_token": " while"}, "1454": {"logprob": -3.5173707008361816, "rank": 4, "decoded_token": " with"}, "6117": {"logprob": -4.142370700836182, "rank": 5, "decoded_token": " near"}}, {"1052": {"logprob": -2.9802276912960224e-06, "rank": 1, "decoded_token": "4"}, "1051": {"logprob": -13.37500286102295, "rank": 2, "decoded_token": "3"}, "1049": {"logprob": -14.00000286102295, "rank": 3, "decoded_token": "1"}, "1053": {"logprob": -14.56250286102295, "rank": 4, "decoded_token": "5"}, "1032": {"logprob": -16.750003814697266, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -1.6689286894688848e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.500001907348633, "rank": 2, "decoded_token": ".A"}, "6847": {"logprob": -16.562501907348633, "rank": 3, "decoded_token": ".T"}, "1044": {"logprob": -17.312501907348633, "rank": 4, "decoded_token": ","}, "1349": {"logprob": -17.500001907348633, "rank": 5, "decoded_token": " A"}}, {"1349": {"logprob": -0.004883386194705963, "rank": 1, "decoded_token": " A"}, "2048": {"logprob": -5.504883289337158, "rank": 2, "decoded_token": " An"}, "10638": {"logprob": -7.754883289337158, "rank": 3, "decoded_token": " Two"}, "111463": {"logprob": -9.754883766174316, "rank": 4, "decoded_token": " Trees"}, "1531": 
{"logprob": -10.692383766174316, "rank": 5, "decoded_token": " The"}}, {"53301": {"logprob": -1.5612412691116333, "rank": 1, "decoded_token": " winding"}, "15192": {"logprob": -1.7487412691116333, "rank": 2, "decoded_token": " narrow"}, "47945": {"logprob": -2.1237411499023438, "rank": 3, "decoded_token": " dirt"}, "2169": {"logprob": -2.5612411499023438, "rank": 4, "decoded_token": " ser"}, "59396": {"logprob": -2.6862411499023438, "rank": 5, "decoded_token": " gravel"}}, {"59396": {"logprob": -0.9024254083633423, "rank": 1, "decoded_token": " gravel"}, "3549": {"logprob": -1.1524254083633423, "rank": 2, "decoded_token": " path"}, "47945": {"logprob": -1.6524254083633423, "rank": 3, "decoded_token": " dirt"}, "14801": {"logprob": -3.1524252891540527, "rank": 4, "decoded_token": " pathway"}, "15551": {"logprob": -4.277425289154053, "rank": 5, "decoded_token": " stone"}}, {"3549": {"logprob": -0.021290099248290062, "rank": 1, "decoded_token": " path"}, "14801": {"logprob": -3.8962900638580322, "rank": 2, "decoded_token": " pathway"}, "33659": {"logprob": -7.896290302276611, "rank": 3, "decoded_token": " trail"}, "9480": {"logprob": -9.521289825439453, "rank": 4, "decoded_token": " road"}, "7368": {"logprob": -9.646289825439453, "rank": 5, "decoded_token": "path"}}, {"13335": {"logprob": -0.16593234241008759, "rank": 1, "decoded_token": " leads"}, "39985": {"logprob": -2.8534324169158936, "rank": 2, "decoded_token": " cuts"}, "1639": {"logprob": -3.9784324169158936, "rank": 3, "decoded_token": " me"}, "11500": {"logprob": -4.1034321784973145, "rank": 4, "decoded_token": " runs"}, "2645": {"logprob": -4.2909321784973145, "rank": 5, "decoded_token": " through"}}, {"2645": {"logprob": -0.05767015367746353, "rank": 1, "decoded_token": " through"}, "8994": {"logprob": -4.0576701164245605, "rank": 2, "decoded_token": " towards"}, "2396": {"logprob": -4.1826701164245605, "rank": 3, "decoded_token": " between"}, "2203": {"logprob": -4.5576701164245605, "rank": 4, "decoded_token": " into"}, "1317": {"logprob": -5.5576701164245605, "rank": 5, "decoded_token": " to"}}, {"1261": {"logprob": -0.017209367826581, "rank": 1, "decoded_token": " a"}, "11223": {"logprob": -4.892209529876709, "rank": 2, "decoded_token": " green"}, "1295": {"logprob": -5.017209529876709, "rank": 3, "decoded_token": " l"}, "23170": {"logprob": -6.767209529876709, "rank": 4, "decoded_token": " grass"}, "1420": {"logprob": -7.267209529876709, "rank": 5, "decoded_token": " an"}}, {"1295": {"logprob": -0.9430665969848633, "rank": 1, "decoded_token": " l"}, "11223": {"logprob": -1.3180665969848633, "rank": 2, "decoded_token": " green"}, "23170": {"logprob": -1.9430665969848633, "rank": 3, "decoded_token": " grass"}, "12097": {"logprob": -2.4430665969848633, "rank": 4, "decoded_token": " park"}, "26428": {"logprob": -3.3180665969848633, "rank": 5, "decoded_token": " garden"}}, {"3506": {"logprob": -6.556489552167477e-06, "rank": 1, "decoded_token": "ush"}, "1374": {"logprob": -12.000006675720215, "rank": 2, "decoded_token": "us"}, "90716": {"logprob": -15.625006675720215, "rank": 3, "decoded_token": "USH"}, "16938": {"logprob": -15.875006675720215, "rank": 4, "decoded_token": "usher"}, "13326": {"logprob": -17.1875057220459, "rank": 5, "decoded_token": "inden"}}, {"11223": {"logprob": -0.36697858572006226, "rank": 1, "decoded_token": " green"}, "1044": {"logprob": -1.366978645324707, "rank": 2, "decoded_token": ","}, "26428": {"logprob": -3.491978645324707, "rank": 3, "decoded_token": " garden"}, "12097": {"logprob": 
-4.116978645324707, "rank": 4, "decoded_token": " park"}, "23170": {"logprob": -5.866978645324707, "rank": 5, "decoded_token": " grass"}}, {"12097": {"logprob": -0.5570574402809143, "rank": 1, "decoded_token": " park"}, "3727": {"logprob": -1.9320573806762695, "rank": 2, "decoded_token": " field"}, "28035": {"logprob": -2.1820573806762695, "rank": 3, "decoded_token": " landscape"}, "26428": {"logprob": -2.4320573806762695, "rank": 4, "decoded_token": " garden"}, "4457": {"logprob": -2.8070573806762695, "rank": 5, "decoded_token": " area"}}, {"1046": {"logprob": -0.7940837144851685, "rank": 1, "decoded_token": "."}, "1454": {"logprob": -1.2940837144851685, "rank": 2, "decoded_token": " with"}, "8994": {"logprob": -2.794083595275879, "rank": 3, "decoded_token": " towards"}, "54410": {"logprob": -3.544083595275879, "rank": 4, "decoded_token": " lined"}, "2425": {"logprob": -3.544083595275879, "rank": 5, "decoded_token": " under"}}, {"2": {"logprob": -2.145764938177308e-06, "rank": 1, "decoded_token": ""}, "1032": {"logprob": -13.125001907348633, "rank": 2, "decoded_token": " "}, "1256": {"logprob": -16.000001907348633, "rank": 3, "decoded_token": " "}, "1293": {"logprob": -18.750001907348633, "rank": 4, "decoded_token": " "}, "1319": {"logprob": -19.687501907348633, "rank": 5, "decoded_token": " ("}}]]] \ No newline at end of file diff --git a/tests/models/fixtures/pixtral_chat_engine.json b/tests/models/fixtures/pixtral_chat_engine.json new file mode 100644 index 0000000000000..60e4ae6cebf59 --- /dev/null +++ b/tests/models/fixtures/pixtral_chat_engine.json @@ -0,0 +1 @@ +[[[1784, 3937, 6122, 1261, 7244, 10575, 18970, 1408, 1261, 32656, 4691, 1046, 2], "The image shows a black dog sitting on a wooden surface.", [{"1784": {"logprob": -0.11685245484113693, "rank": 1, "decoded_token": "The"}, "4380": {"logprob": -2.3668525218963623, "rank": 2, "decoded_token": "This"}, "1049": {"logprob": -4.741852283477783, "rank": 3, "decoded_token": "1"}, "117991": {"logprob": -5.991852283477783, "rank": 4, "decoded_token": "Certain"}, "1785": {"logprob": -5.991852283477783, "rank": 5, "decoded_token": "In"}}, {"3937": {"logprob": -0.2591013014316559, "rank": 1, "decoded_token": " image"}, "2158": {"logprob": -1.5091012716293335, "rank": 2, "decoded_token": " first"}, "3977": {"logprob": -5.884101390838623, "rank": 3, "decoded_token": " top"}, "7244": {"logprob": -6.259101390838623, "rank": 4, "decoded_token": " black"}, "8061": {"logprob": -6.759101390838623, "rank": 5, "decoded_token": " images"}}, {"6122": {"logprob": -0.9660423994064331, "rank": 1, "decoded_token": " shows"}, "51948": {"logprob": -1.466042399406433, "rank": 2, "decoded_token": " depicts"}, "6971": {"logprob": -1.466042399406433, "rank": 3, "decoded_token": " features"}, "25981": {"logprob": -2.8410425186157227, "rank": 4, "decoded_token": " displays"}, "8688": {"logprob": -2.8410425186157227, "rank": 5, "decoded_token": " contains"}}, {"1261": {"logprob": -0.0030613720882683992, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -6.253061294555664, "rank": 2, "decoded_token": " an"}, "2295": {"logprob": -7.878061294555664, "rank": 3, "decoded_token": " two"}, "2342": {"logprob": -7.878061294555664, "rank": 4, "decoded_token": " only"}, "1278": {"logprob": -8.628061294555664, "rank": 5, "decoded_token": " the"}}, {"7244": {"logprob": -0.17649099230766296, "rank": 1, "decoded_token": " black"}, "6231": {"logprob": -2.3014910221099854, "rank": 2, "decoded_token": " close"}, "4249": {"logprob": -3.4264910221099854, "rank": 3, 
"decoded_token": " single"}, "4329": {"logprob": -5.113990783691406, "rank": 4, "decoded_token": " large"}, "10575": {"logprob": -5.176490783691406, "rank": 5, "decoded_token": " dog"}}, {"10575": {"logprob": -0.10929587483406067, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.4842958450317383, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -4.109295845031738, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -7.296795845031738, "rank": 4, "decoded_token": " Lab"}, "7990": {"logprob": -7.484295845031738, "rank": 5, "decoded_token": " cat"}}, {"18970": {"logprob": -0.830376148223877, "rank": 1, "decoded_token": " sitting"}, "1454": {"logprob": -1.580376148223877, "rank": 2, "decoded_token": " with"}, "28528": {"logprob": -1.955376148223877, "rank": 3, "decoded_token": " lying"}, "7283": {"logprob": -2.205376148223877, "rank": 4, "decoded_token": " looking"}, "15866": {"logprob": -3.017876148223877, "rank": 5, "decoded_token": " standing"}}, {"1408": {"logprob": -0.08554735779762268, "rank": 1, "decoded_token": " on"}, "1321": {"logprob": -3.71054744720459, "rank": 2, "decoded_token": " and"}, "3675": {"logprob": -3.96054744720459, "rank": 3, "decoded_token": " against"}, "41132": {"logprob": -4.71054744720459, "rank": 4, "decoded_token": " attent"}, "1454": {"logprob": -5.08554744720459, "rank": 5, "decoded_token": " with"}}, {"1261": {"logprob": -0.540847897529602, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -0.915847897529602, "rank": 2, "decoded_token": " wooden"}, "12603": {"logprob": -5.4158477783203125, "rank": 3, "decoded_token": " wood"}, "3977": {"logprob": -5.4158477783203125, "rank": 4, "decoded_token": " top"}, "17253": {"logprob": -6.2908477783203125, "rank": 5, "decoded_token": " weather"}}, {"32656": {"logprob": -0.025753861293196678, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -4.400753974914551, "rank": 2, "decoded_token": " rust"}, "12603": {"logprob": -5.275753974914551, "rank": 3, "decoded_token": " wood"}, "3403": {"logprob": -5.400753974914551, "rank": 4, "decoded_token": " text"}, "17253": {"logprob": -6.963253974914551, "rank": 5, "decoded_token": " weather"}}, {"4691": {"logprob": -0.7265751957893372, "rank": 1, "decoded_token": " surface"}, "11237": {"logprob": -0.8515751957893372, "rank": 2, "decoded_token": " floor"}, "7042": {"logprob": -2.6015751361846924, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -5.2265753746032715, "rank": 4, "decoded_token": " deck"}, "1615": {"logprob": -5.7265753746032715, "rank": 5, "decoded_token": " pl"}}, {"1046": {"logprob": -0.4868825674057007, "rank": 1, "decoded_token": "."}, "1044": {"logprob": -1.9868825674057007, "rank": 2, "decoded_token": ","}, "1321": {"logprob": -2.3618826866149902, "rank": 3, "decoded_token": " and"}, "1454": {"logprob": -2.6118826866149902, "rank": 4, "decoded_token": " with"}, "7283": {"logprob": -2.7368826866149902, "rank": 5, "decoded_token": " looking"}}, {"2": {"logprob": -0.0026643513701856136, "rank": 1, "decoded_token": ""}, "1531": {"logprob": -6.502664566040039, "rank": 2, "decoded_token": " The"}, "1032": {"logprob": -6.877664566040039, "rank": 3, "decoded_token": " "}, "3730": {"logprob": -9.752664566040039, "rank": 4, "decoded_token": " There"}, "1256": {"logprob": -11.002664566040039, "rank": 5, "decoded_token": " "}}]], [[1049, 1046, 1349, 7244, 10575, 1454, 2327, 94766, 32961, 53048, 41132, 3923, 1408, 1261, 32656, 4691, 1626, 1050, 1046, 1349, 15375, 24361, 4521, 94973, 5669, 1278, 48932, 2425, 
1261, 16152, 1121, 21283, 1046, 2], "1. A black dog with floppy ears sits attentively on a wooden surface.\n2. A vast mountain range stretches across the horizon under a cloudy sky.", [{"1049": {"logprob": -0.42824622988700867, "rank": 1, "decoded_token": "1"}, "1045": {"logprob": -1.553246259689331, "rank": 2, "decoded_token": "-"}, "1065": {"logprob": -2.428246259689331, "rank": 3, "decoded_token": "A"}, "1784": {"logprob": -4.053246021270752, "rank": 4, "decoded_token": "The"}, "69957": {"logprob": -4.428246021270752, "rank": 5, "decoded_token": "Sure"}}, {"1046": {"logprob": -1.811964830267243e-05, "rank": 1, "decoded_token": "."}, "1058": {"logprob": -11.875018119812012, "rank": 2, "decoded_token": ":"}, "3590": {"logprob": -12.250018119812012, "rank": 3, "decoded_token": ".A"}, "1065": {"logprob": -13.062518119812012, "rank": 4, "decoded_token": "A"}, "1041": {"logprob": -13.750018119812012, "rank": 5, "decoded_token": ")"}}, {"1349": {"logprob": -0.13647246360778809, "rank": 1, "decoded_token": " A"}, "1429": {"logprob": -2.386472463607788, "rank": 2, "decoded_token": " \""}, "1603": {"logprob": -3.886472463607788, "rank": 3, "decoded_token": " **"}, "11967": {"logprob": -5.011472702026367, "rank": 4, "decoded_token": " Image"}, "1531": {"logprob": -5.011472702026367, "rank": 5, "decoded_token": " The"}}, {"7244": {"logprob": -0.18561004102230072, "rank": 1, "decoded_token": " black"}, "38462": {"logprob": -3.185610055923462, "rank": 2, "decoded_token": " curious"}, "68076": {"logprob": -3.623110055923462, "rank": 3, "decoded_token": " cute"}, "4329": {"logprob": -3.935610055923462, "rank": 4, "decoded_token": " large"}, "74168": {"logprob": -4.373109817504883, "rank": 5, "decoded_token": " gloss"}}, {"10575": {"logprob": -0.17297746241092682, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.1729774475097656, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -3.1729774475097656, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -6.985477447509766, "rank": 4, "decoded_token": " Lab"}, "8636": {"logprob": -7.360477447509766, "rank": 5, "decoded_token": " lab"}}, {"1454": {"logprob": -0.5785807967185974, "rank": 1, "decoded_token": " with"}, "53048": {"logprob": -1.2660808563232422, "rank": 2, "decoded_token": " sits"}, "1395": {"logprob": -3.016080856323242, "rank": 3, "decoded_token": " is"}, "22524": {"logprob": -3.578580856323242, "rank": 4, "decoded_token": " lies"}, "18970": {"logprob": -3.703580856323242, "rank": 5, "decoded_token": " sitting"}}, {"2327": {"logprob": -1.2709298133850098, "rank": 1, "decoded_token": " fl"}, "1261": {"logprob": -1.3959298133850098, "rank": 2, "decoded_token": " a"}, "17300": {"logprob": -1.8959298133850098, "rank": 3, "decoded_token": " soul"}, "100089": {"logprob": -2.6459298133850098, "rank": 4, "decoded_token": " expressive"}, "6444": {"logprob": -3.1459298133850098, "rank": 5, "decoded_token": " soft"}}, {"94766": {"logprob": -0.002432247158139944, "rank": 1, "decoded_token": "oppy"}, "124603": {"logprob": -6.377432346343994, "rank": 2, "decoded_token": "uffy"}, "1484": {"logprob": -7.877432346343994, "rank": 3, "decoded_token": "op"}, "24897": {"logprob": -8.877431869506836, "rank": 4, "decoded_token": "appy"}, "102477": {"logprob": -9.752431869506836, "rank": 5, "decoded_token": "opping"}}, {"32961": {"logprob": -5.113947918289341e-05, "rank": 1, "decoded_token": " ears"}, "16962": {"logprob": -11.312551498413086, "rank": 2, "decoded_token": " ear"}, "5731": {"logprob": -11.750051498413086, "rank": 3, 
"decoded_token": " eyes"}, "3351": {"logprob": -12.000051498413086, "rank": 4, "decoded_token": " years"}, "42071": {"logprob": -13.000051498413086, "rank": 5, "decoded_token": " cheeks"}}, {"53048": {"logprob": -0.6131591200828552, "rank": 1, "decoded_token": " sits"}, "10637": {"logprob": -1.9881591796875, "rank": 2, "decoded_token": " looks"}, "1321": {"logprob": -2.4256591796875, "rank": 3, "decoded_token": " and"}, "1395": {"logprob": -2.6756591796875, "rank": 4, "decoded_token": " is"}, "18970": {"logprob": -3.0506591796875, "rank": 5, "decoded_token": " sitting"}}, {"41132": {"logprob": -0.36187249422073364, "rank": 1, "decoded_token": " attent"}, "1408": {"logprob": -2.361872434616089, "rank": 2, "decoded_token": " on"}, "106534": {"logprob": -2.424372434616089, "rank": 3, "decoded_token": " calmly"}, "12276": {"logprob": -2.611872434616089, "rank": 4, "decoded_token": " alert"}, "6482": {"logprob": -5.174372673034668, "rank": 5, "decoded_token": " patient"}}, {"3923": {"logprob": -8.451581379631534e-05, "rank": 1, "decoded_token": "ively"}, "1556": {"logprob": -9.50008487701416, "rank": 2, "decoded_token": "ive"}, "6655": {"logprob": -11.87508487701416, "rank": 3, "decoded_token": "atively"}, "3929": {"logprob": -14.00008487701416, "rank": 4, "decoded_token": "ently"}, "47885": {"logprob": -14.75008487701416, "rank": 5, "decoded_token": "edly"}}, {"1408": {"logprob": -0.058125678449869156, "rank": 1, "decoded_token": " on"}, "3675": {"logprob": -3.1831257343292236, "rank": 2, "decoded_token": " against"}, "1294": {"logprob": -4.9331254959106445, "rank": 3, "decoded_token": " in"}, "7283": {"logprob": -5.8081254959106445, "rank": 4, "decoded_token": " looking"}, "1044": {"logprob": -5.9331254959106445, "rank": 5, "decoded_token": ","}}, {"1261": {"logprob": -0.21029606461524963, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -1.7102960348129272, "rank": 2, "decoded_token": " wooden"}, "17253": {"logprob": -5.710296154022217, "rank": 3, "decoded_token": " weather"}, "44130": {"logprob": -6.085296154022217, "rank": 4, "decoded_token": " rust"}, "12603": {"logprob": -6.960296154022217, "rank": 5, "decoded_token": " wood"}}, {"32656": {"logprob": -0.08548421412706375, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -2.710484266281128, "rank": 2, "decoded_token": " rust"}, "17253": {"logprob": -4.710484027862549, "rank": 3, "decoded_token": " weather"}, "12603": {"logprob": -5.960484027862549, "rank": 4, "decoded_token": " wood"}, "3403": {"logprob": -5.960484027862549, "rank": 5, "decoded_token": " text"}}, {"4691": {"logprob": -0.7172377109527588, "rank": 1, "decoded_token": " surface"}, "11237": {"logprob": -0.8422377109527588, "rank": 2, "decoded_token": " floor"}, "7042": {"logprob": -2.842237710952759, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -4.21723747253418, "rank": 4, "decoded_token": " deck"}, "92504": {"logprob": -6.21723747253418, "rank": 5, "decoded_token": " backdrop"}}, {"1626": {"logprob": -0.12971943616867065, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -2.3797194957733154, "rank": 2, "decoded_token": ","}, "1046": {"logprob": -4.129719257354736, "rank": 3, "decoded_token": "."}, "1338": {"logprob": -5.129719257354736, "rank": 4, "decoded_token": ".\n\n"}, "7283": {"logprob": -5.504719257354736, "rank": 5, "decoded_token": " looking"}}, {"1050": {"logprob": -0.00015698630886618048, "rank": 1, "decoded_token": "2"}, "1256": {"logprob": -9.125157356262207, "rank": 2, "decoded_token": " "}, "1032": {"logprob": 
-10.875157356262207, "rank": 3, "decoded_token": " "}, "1293": {"logprob": -11.750157356262207, "rank": 4, "decoded_token": " "}, "1051": {"logprob": -12.125157356262207, "rank": 5, "decoded_token": "3"}}, {"1046": {"logprob": -6.6756979322235566e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.062506675720215, "rank": 2, "decoded_token": ".A"}, "1626": {"logprob": -13.187506675720215, "rank": 3, "decoded_token": ".\n"}, "1338": {"logprob": -14.750006675720215, "rank": 4, "decoded_token": ".\n\n"}, "1058": {"logprob": -14.937506675720215, "rank": 5, "decoded_token": ":"}}, {"1349": {"logprob": -0.5863217115402222, "rank": 1, "decoded_token": " A"}, "11826": {"logprob": -1.4613217115402222, "rank": 2, "decoded_token": " Maj"}, "37159": {"logprob": -2.2113218307495117, "rank": 3, "decoded_token": " Snow"}, "113465": {"logprob": -3.8988218307495117, "rank": 4, "decoded_token": " Rug"}, "1531": {"logprob": -3.9613218307495117, "rank": 5, "decoded_token": " The"}}, {"15375": {"logprob": -0.639299213886261, "rank": 1, "decoded_token": " vast"}, "37849": {"logprob": -2.014299154281616, "rank": 2, "decoded_token": " breat"}, "61082": {"logprob": -2.389299154281616, "rank": 3, "decoded_token": " panor"}, "10726": {"logprob": -3.139299154281616, "rank": 4, "decoded_token": " scen"}, "2169": {"logprob": -3.201799154281616, "rank": 5, "decoded_token": " ser"}}, {"24361": {"logprob": -0.702845573425293, "rank": 1, "decoded_token": " mountain"}, "127945": {"logprob": -1.952845573425293, "rank": 2, "decoded_token": " mountainous"}, "1044": {"logprob": -2.077845573425293, "rank": 3, "decoded_token": ","}, "4521": {"logprob": -2.327845573425293, "rank": 4, "decoded_token": " range"}, "28035": {"logprob": -2.452845573425293, "rank": 5, "decoded_token": " landscape"}}, {"4521": {"logprob": -0.07058162242174149, "rank": 1, "decoded_token": " range"}, "28035": {"logprob": -2.6955816745758057, "rank": 2, "decoded_token": " landscape"}, "37691": {"logprob": -8.320581436157227, "rank": 3, "decoded_token": " valley"}, "12248": {"logprob": -9.445581436157227, "rank": 4, "decoded_token": " peak"}, "13327": {"logprob": -9.695581436157227, "rank": 5, "decoded_token": " scene"}}, {"94973": {"logprob": -1.1164050102233887, "rank": 1, "decoded_token": " stretches"}, "1454": {"logprob": -1.1789050102233887, "rank": 2, "decoded_token": " with"}, "2425": {"logprob": -1.8664050102233887, "rank": 3, "decoded_token": " under"}, "1395": {"logprob": -2.5539050102233887, "rank": 4, "decoded_token": " is"}, "13875": {"logprob": -2.9914050102233887, "rank": 5, "decoded_token": " covered"}}, {"5669": {"logprob": -0.3286789357662201, "rank": 1, "decoded_token": " across"}, "1848": {"logprob": -2.078678846359253, "rank": 2, "decoded_token": " out"}, "2425": {"logprob": -2.328678846359253, "rank": 3, "decoded_token": " under"}, "2203": {"logprob": -3.328678846359253, "rank": 4, "decoded_token": " into"}, "8994": {"logprob": -4.766179084777832, "rank": 5, "decoded_token": " towards"}}, {"1278": {"logprob": -0.039004355669021606, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -3.289004325866699, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -7.414004325866699, "rank": 3, "decoded_token": " an"}, "2425": {"logprob": -9.0390043258667, "rank": 4, "decoded_token": " under"}, "1454": {"logprob": -9.2265043258667, "rank": 5, "decoded_token": " with"}}, {"48932": {"logprob": -0.2659883201122284, "rank": 1, "decoded_token": " horizon"}, "21283": {"logprob": -2.140988349914551, "rank": 2, "decoded_token": " sky"}, 
"3937": {"logprob": -3.015988349914551, "rank": 3, "decoded_token": " image"}, "28035": {"logprob": -3.515988349914551, "rank": 4, "decoded_token": " landscape"}, "3044": {"logprob": -4.265988349914551, "rank": 5, "decoded_token": " sk"}}, {"2425": {"logprob": -0.5356141328811646, "rank": 1, "decoded_token": " under"}, "1044": {"logprob": -1.5356141328811646, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -1.7856141328811646, "rank": 3, "decoded_token": " with"}, "25136": {"logprob": -3.785614013671875, "rank": 4, "decoded_token": " beneath"}, "1408": {"logprob": -5.785614013671875, "rank": 5, "decoded_token": " on"}}, {"1261": {"logprob": -0.006081883795559406, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -5.506082057952881, "rank": 2, "decoded_token": " an"}, "16152": {"logprob": -7.631082057952881, "rank": 3, "decoded_token": " cloud"}, "6133": {"logprob": -7.881082057952881, "rank": 4, "decoded_token": " clear"}, "2136": {"logprob": -8.006081581115723, "rank": 5, "decoded_token": " over"}}, {"16152": {"logprob": -0.6749536991119385, "rank": 1, "decoded_token": " cloud"}, "6133": {"logprob": -1.4249536991119385, "rank": 2, "decoded_token": " clear"}, "18416": {"logprob": -2.8624536991119385, "rank": 3, "decoded_token": " haz"}, "27254": {"logprob": -2.9874536991119385, "rank": 4, "decoded_token": " partly"}, "4391": {"logprob": -3.2374536991119385, "rank": 5, "decoded_token": " light"}}, {"1121": {"logprob": -0.10860869288444519, "rank": 1, "decoded_token": "y"}, "4527": {"logprob": -2.9836087226867676, "rank": 2, "decoded_token": "less"}, "1286": {"logprob": -3.4836087226867676, "rank": 3, "decoded_token": "ed"}, "77187": {"logprob": -4.608608722686768, "rank": 4, "decoded_token": "-filled"}, "114525": {"logprob": -4.858608722686768, "rank": 5, "decoded_token": "-covered"}}, {"21283": {"logprob": -0.002785732736811042, "rank": 1, "decoded_token": " sky"}, "10991": {"logprob": -6.252785682678223, "rank": 2, "decoded_token": " blue"}, "1044": {"logprob": -7.627785682678223, "rank": 3, "decoded_token": ","}, "26549": {"logprob": -8.627785682678223, "rank": 4, "decoded_token": " gray"}, "34052": {"logprob": -9.377785682678223, "rank": 5, "decoded_token": " grey"}}, {"1046": {"logprob": -0.047878943383693695, "rank": 1, "decoded_token": "."}, "1044": {"logprob": -3.1728789806365967, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -5.547878742218018, "rank": 3, "decoded_token": " with"}, "1338": {"logprob": -7.172878742218018, "rank": 4, "decoded_token": ".\n\n"}, "1294": {"logprob": -9.172879219055176, "rank": 5, "decoded_token": " in"}}, {"2": {"logprob": -1.3351351299206726e-05, "rank": 1, "decoded_token": ""}, "1032": {"logprob": -11.25001335144043, "rank": 2, "decoded_token": " "}, "1256": {"logprob": -16.00001335144043, "rank": 3, "decoded_token": " "}, "1319": {"logprob": -17.25001335144043, "rank": 4, "decoded_token": " ("}, "1766": {"logprob": -18.50001335144043, "rank": 5, "decoded_token": " ["}}]], [[1049, 1046, 1349, 7244, 10575, 53048, 41132, 3923, 1408, 1261, 32656, 11237, 1626, 1050, 1046, 1349, 15375, 24361, 4521, 94973, 5669, 1278, 48932, 2425, 1261, 16152, 1121, 21283, 1626, 1051, 1046, 8342, 71284, 7377, 1394, 22140, 1294, 1278, 27208, 1513, 97558, 1626, 1052, 1046, 1349, 53301, 59396, 3549, 13335, 2645, 1261, 1295, 3506, 11223, 12097, 1046, 2], "1. A black dog sits attentively on a wooden floor.\n2. A vast mountain range stretches across the horizon under a cloudy sky.\n3. Surfers wait for waves in the ocean at sunset.\n4. 
A winding gravel path leads through a lush green park.", [{"1049": {"logprob": -0.05001257359981537, "rank": 1, "decoded_token": "1"}, "1045": {"logprob": -3.1750125885009766, "rank": 2, "decoded_token": "-"}, "69957": {"logprob": -5.925012588500977, "rank": 3, "decoded_token": "Sure"}, "11745": {"logprob": -6.425012588500977, "rank": 4, "decoded_token": "Here"}, "1065": {"logprob": -6.425012588500977, "rank": 5, "decoded_token": "A"}}, {"1046": {"logprob": -8.702239938429557e-06, "rank": 1, "decoded_token": "."}, "1058": {"logprob": -12.000008583068848, "rank": 2, "decoded_token": ":"}, "3590": {"logprob": -13.375008583068848, "rank": 3, "decoded_token": ".A"}, "1041": {"logprob": -14.750008583068848, "rank": 4, "decoded_token": ")"}, "1065": {"logprob": -15.687508583068848, "rank": 5, "decoded_token": "A"}}, {"1349": {"logprob": -0.14196155965328217, "rank": 1, "decoded_token": " A"}, "1429": {"logprob": -2.2669615745544434, "rank": 2, "decoded_token": " \""}, "1531": {"logprob": -4.516961574554443, "rank": 3, "decoded_token": " The"}, "11967": {"logprob": -4.516961574554443, "rank": 4, "decoded_token": " Image"}, "1603": {"logprob": -5.391961574554443, "rank": 5, "decoded_token": " **"}}, {"7244": {"logprob": -0.14889711141586304, "rank": 1, "decoded_token": " black"}, "68076": {"logprob": -3.398897171020508, "rank": 2, "decoded_token": " cute"}, "6231": {"logprob": -3.961397171020508, "rank": 3, "decoded_token": " close"}, "38462": {"logprob": -4.273897171020508, "rank": 4, "decoded_token": " curious"}, "4329": {"logprob": -4.398897171020508, "rank": 5, "decoded_token": " large"}}, {"10575": {"logprob": -0.12091328203678131, "rank": 1, "decoded_token": " dog"}, "116572": {"logprob": -2.37091326713562, "rank": 2, "decoded_token": " puppy"}, "119075": {"logprob": -3.99591326713562, "rank": 3, "decoded_token": " Labrador"}, "15812": {"logprob": -7.683413505554199, "rank": 4, "decoded_token": " Lab"}, "8636": {"logprob": -7.808413505554199, "rank": 5, "decoded_token": " lab"}}, {"53048": {"logprob": -0.8691943287849426, "rank": 1, "decoded_token": " sits"}, "1454": {"logprob": -1.1191942691802979, "rank": 2, "decoded_token": " with"}, "1395": {"logprob": -2.431694269180298, "rank": 3, "decoded_token": " is"}, "18970": {"logprob": -2.744194269180298, "rank": 4, "decoded_token": " sitting"}, "22524": {"logprob": -3.681694269180298, "rank": 5, "decoded_token": " lies"}}, {"41132": {"logprob": -0.5939557552337646, "rank": 1, "decoded_token": " attent"}, "106534": {"logprob": -1.2814557552337646, "rank": 2, "decoded_token": " calmly"}, "12276": {"logprob": -2.8439557552337646, "rank": 3, "decoded_token": " alert"}, "1408": {"logprob": -2.8439557552337646, "rank": 4, "decoded_token": " on"}, "6482": {"logprob": -4.968955993652344, "rank": 5, "decoded_token": " patient"}}, {"3923": {"logprob": -0.00010084597306558862, "rank": 1, "decoded_token": "ively"}, "1556": {"logprob": -9.500101089477539, "rank": 2, "decoded_token": "ive"}, "6655": {"logprob": -10.875101089477539, "rank": 3, "decoded_token": "atively"}, "3929": {"logprob": -13.000101089477539, "rank": 4, "decoded_token": "ently"}, "47885": {"logprob": -13.750101089477539, "rank": 5, "decoded_token": "edly"}}, {"1408": {"logprob": -0.056158196181058884, "rank": 1, "decoded_token": " on"}, "3675": {"logprob": -3.6811583042144775, "rank": 2, "decoded_token": " against"}, "1454": {"logprob": -4.306158065795898, "rank": 3, "decoded_token": " with"}, "1294": {"logprob": -5.181158065795898, "rank": 4, "decoded_token": " in"}, "7283": {"logprob": 
-5.431158065795898, "rank": 5, "decoded_token": " looking"}}, {"1261": {"logprob": -0.33056098222732544, "rank": 1, "decoded_token": " a"}, "32656": {"logprob": -1.3305609226226807, "rank": 2, "decoded_token": " wooden"}, "17253": {"logprob": -4.70556116104126, "rank": 3, "decoded_token": " weather"}, "44130": {"logprob": -5.83056116104126, "rank": 4, "decoded_token": " rust"}, "12603": {"logprob": -6.58056116104126, "rank": 5, "decoded_token": " wood"}}, {"32656": {"logprob": -0.07081110030412674, "rank": 1, "decoded_token": " wooden"}, "44130": {"logprob": -2.9458110332489014, "rank": 2, "decoded_token": " rust"}, "17253": {"logprob": -4.6958112716674805, "rank": 3, "decoded_token": " weather"}, "12603": {"logprob": -5.8208112716674805, "rank": 4, "decoded_token": " wood"}, "3403": {"logprob": -6.0708112716674805, "rank": 5, "decoded_token": " text"}}, {"11237": {"logprob": -0.6428436636924744, "rank": 1, "decoded_token": " floor"}, "4691": {"logprob": -1.0178437232971191, "rank": 2, "decoded_token": " surface"}, "7042": {"logprob": -2.642843723297119, "rank": 3, "decoded_token": " background"}, "28984": {"logprob": -3.517843723297119, "rank": 4, "decoded_token": " deck"}, "92504": {"logprob": -6.017843723297119, "rank": 5, "decoded_token": " backdrop"}}, {"1626": {"logprob": -0.7337945103645325, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -0.8587945103645325, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -3.3587944507598877, "rank": 3, "decoded_token": " with"}, "7283": {"logprob": -3.6087944507598877, "rank": 4, "decoded_token": " looking"}, "1321": {"logprob": -4.108794689178467, "rank": 5, "decoded_token": " and"}}, {"1050": {"logprob": -1.0132738680113107e-05, "rank": 1, "decoded_token": "2"}, "1051": {"logprob": -11.75001049041748, "rank": 2, "decoded_token": "3"}, "1256": {"logprob": -14.00001049041748, "rank": 3, "decoded_token": " "}, "1049": {"logprob": -14.62501049041748, "rank": 4, "decoded_token": "1"}, "1032": {"logprob": -14.62501049041748, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -2.861018856492592e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.43750286102295, "rank": 2, "decoded_token": ".A"}, "4700": {"logprob": -15.37500286102295, "rank": 3, "decoded_token": ".M"}, "1626": {"logprob": -15.37500286102295, "rank": 4, "decoded_token": ".\n"}, "3051": {"logprob": -15.87500286102295, "rank": 5, "decoded_token": ".S"}}, {"1349": {"logprob": -0.6794427633285522, "rank": 1, "decoded_token": " A"}, "11826": {"logprob": -1.9294427633285522, "rank": 2, "decoded_token": " Maj"}, "37159": {"logprob": -2.116942882537842, "rank": 3, "decoded_token": " Snow"}, "27260": {"logprob": -2.616942882537842, "rank": 4, "decoded_token": " Mountain"}, "113465": {"logprob": -2.866942882537842, "rank": 5, "decoded_token": " Rug"}}, {"15375": {"logprob": -0.9194075465202332, "rank": 1, "decoded_token": " vast"}, "10726": {"logprob": -2.294407606124878, "rank": 2, "decoded_token": " scen"}, "4521": {"logprob": -2.356907606124878, "rank": 3, "decoded_token": " range"}, "122203": {"logprob": -2.419407606124878, "rank": 4, "decoded_token": " rugged"}, "61082": {"logprob": -2.856907606124878, "rank": 5, "decoded_token": " panor"}}, {"24361": {"logprob": -0.5804797410964966, "rank": 1, "decoded_token": " mountain"}, "127945": {"logprob": -1.8304797410964966, "rank": 2, "decoded_token": " mountainous"}, "28035": {"logprob": -2.455479621887207, "rank": 3, "decoded_token": " landscape"}, "4521": {"logprob": -2.455479621887207, "rank": 4, "decoded_token": " 
range"}, "1044": {"logprob": -2.705479621887207, "rank": 5, "decoded_token": ","}}, {"4521": {"logprob": -0.0493546724319458, "rank": 1, "decoded_token": " range"}, "28035": {"logprob": -3.0493545532226562, "rank": 2, "decoded_token": " landscape"}, "37691": {"logprob": -8.424354553222656, "rank": 3, "decoded_token": " valley"}, "13327": {"logprob": -9.049354553222656, "rank": 4, "decoded_token": " scene"}, "3719": {"logprob": -9.799354553222656, "rank": 5, "decoded_token": " view"}}, {"94973": {"logprob": -0.6676871180534363, "rank": 1, "decoded_token": " stretches"}, "2425": {"logprob": -1.792687177658081, "rank": 2, "decoded_token": " under"}, "1395": {"logprob": -2.292687177658081, "rank": 3, "decoded_token": " is"}, "1454": {"logprob": -2.730187177658081, "rank": 4, "decoded_token": " with"}, "7038": {"logprob": -3.292687177658081, "rank": 5, "decoded_token": " extends"}}, {"5669": {"logprob": -0.4542117118835449, "rank": 1, "decoded_token": " across"}, "2425": {"logprob": -1.454211711883545, "rank": 2, "decoded_token": " under"}, "1848": {"logprob": -2.454211711883545, "rank": 3, "decoded_token": " out"}, "2203": {"logprob": -4.204211711883545, "rank": 4, "decoded_token": " into"}, "25136": {"logprob": -4.641711711883545, "rank": 5, "decoded_token": " beneath"}}, {"1278": {"logprob": -0.23009441792964935, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -1.6050944328308105, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -5.6050944328308105, "rank": 3, "decoded_token": " an"}, "2425": {"logprob": -7.2300944328308105, "rank": 4, "decoded_token": " under"}, "1454": {"logprob": -10.167593955993652, "rank": 5, "decoded_token": " with"}}, {"48932": {"logprob": -0.3072167932987213, "rank": 1, "decoded_token": " horizon"}, "21283": {"logprob": -1.932216763496399, "rank": 2, "decoded_token": " sky"}, "3937": {"logprob": -3.1822168827056885, "rank": 3, "decoded_token": " image"}, "28035": {"logprob": -3.6822168827056885, "rank": 4, "decoded_token": " landscape"}, "3044": {"logprob": -3.6822168827056885, "rank": 5, "decoded_token": " sk"}}, {"2425": {"logprob": -0.2914469838142395, "rank": 1, "decoded_token": " under"}, "1044": {"logprob": -2.4164469242095947, "rank": 2, "decoded_token": ","}, "1454": {"logprob": -2.5414469242095947, "rank": 3, "decoded_token": " with"}, "1626": {"logprob": -3.7914469242095947, "rank": 4, "decoded_token": ".\n"}, "1408": {"logprob": -3.7914469242095947, "rank": 5, "decoded_token": " on"}}, {"1261": {"logprob": -0.0460360012948513, "rank": 1, "decoded_token": " a"}, "1420": {"logprob": -3.9210360050201416, "rank": 2, "decoded_token": " an"}, "16152": {"logprob": -4.1085357666015625, "rank": 3, "decoded_token": " cloud"}, "2136": {"logprob": -6.1710357666015625, "rank": 4, "decoded_token": " over"}, "6133": {"logprob": -6.4210357666015625, "rank": 5, "decoded_token": " clear"}}, {"16152": {"logprob": -0.20367540419101715, "rank": 1, "decoded_token": " cloud"}, "6133": {"logprob": -2.8286755084991455, "rank": 2, "decoded_token": " clear"}, "27254": {"logprob": -3.5161755084991455, "rank": 3, "decoded_token": " partly"}, "18416": {"logprob": -3.8286755084991455, "rank": 4, "decoded_token": " haz"}, "4391": {"logprob": -4.328675270080566, "rank": 5, "decoded_token": " light"}}, {"1121": {"logprob": -0.05241352692246437, "rank": 1, "decoded_token": "y"}, "1286": {"logprob": -3.8024134635925293, "rank": 2, "decoded_token": "ed"}, "77187": {"logprob": -4.552413463592529, "rank": 3, "decoded_token": "-filled"}, "4527": {"logprob": -4.802413463592529, 
"rank": 4, "decoded_token": "less"}, "114525": {"logprob": -4.927413463592529, "rank": 5, "decoded_token": "-covered"}}, {"21283": {"logprob": -0.0003716255014296621, "rank": 1, "decoded_token": " sky"}, "10991": {"logprob": -8.750371932983398, "rank": 2, "decoded_token": " blue"}, "1044": {"logprob": -9.375371932983398, "rank": 3, "decoded_token": ","}, "26549": {"logprob": -10.375371932983398, "rank": 4, "decoded_token": " gray"}, "34052": {"logprob": -11.250371932983398, "rank": 5, "decoded_token": " grey"}}, {"1626": {"logprob": -0.00012730741582345217, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -9.500126838684082, "rank": 2, "decoded_token": ","}, "1046": {"logprob": -10.500126838684082, "rank": 3, "decoded_token": "."}, "1454": {"logprob": -10.875126838684082, "rank": 4, "decoded_token": " with"}, "1294": {"logprob": -13.250126838684082, "rank": 5, "decoded_token": " in"}}, {"1051": {"logprob": -3.2186455882765586e-06, "rank": 1, "decoded_token": "3"}, "1052": {"logprob": -12.75000286102295, "rank": 2, "decoded_token": "4"}, "1050": {"logprob": -15.00000286102295, "rank": 3, "decoded_token": "2"}, "1049": {"logprob": -16.937503814697266, "rank": 4, "decoded_token": "1"}, "1032": {"logprob": -17.875003814697266, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -1.6689286894688848e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -14.687501907348633, "rank": 2, "decoded_token": ".A"}, "5226": {"logprob": -15.687501907348633, "rank": 3, "decoded_token": ".D"}, "6847": {"logprob": -15.812501907348633, "rank": 4, "decoded_token": ".T"}, "48426": {"logprob": -16.812501907348633, "rank": 5, "decoded_token": ".The"}}, {"8342": {"logprob": -0.5730464458465576, "rank": 1, "decoded_token": " Sur"}, "1349": {"logprob": -1.6980464458465576, "rank": 2, "decoded_token": " A"}, "22468": {"logprob": -2.5730464458465576, "rank": 3, "decoded_token": " Several"}, "1488": {"logprob": -2.6980464458465576, "rank": 4, "decoded_token": " W"}, "15035": {"logprob": -3.1980464458465576, "rank": 5, "decoded_token": " People"}}, {"71284": {"logprob": -0.0033258858602494, "rank": 1, "decoded_token": "fers"}, "1102": {"logprob": -5.878325939178467, "rank": 2, "decoded_token": "f"}, "1726": {"logprob": -7.628325939178467, "rank": 3, "decoded_token": "fer"}, "61888": {"logprob": -12.253325462341309, "rank": 4, "decoded_token": "fline"}, "2119": {"logprob": -13.003325462341309, "rank": 5, "decoded_token": "fter"}}, {"7377": {"logprob": -1.4996429681777954, "rank": 1, "decoded_token": " wait"}, "1584": {"logprob": -1.7496429681777954, "rank": 2, "decoded_token": " are"}, "88014": {"logprob": -1.9371429681777954, "rank": 3, "decoded_token": " paddle"}, "1294": {"logprob": -1.9371429681777954, "rank": 4, "decoded_token": " in"}, "24434": {"logprob": -2.187142848968506, "rank": 5, "decoded_token": " ride"}}, {"1394": {"logprob": -0.6126739382743835, "rank": 1, "decoded_token": " for"}, "1294": {"logprob": -0.9876739382743835, "rank": 2, "decoded_token": " in"}, "1408": {"logprob": -2.7376739978790283, "rank": 3, "decoded_token": " on"}, "6482": {"logprob": -4.425173759460449, "rank": 4, "decoded_token": " patient"}, "1321": {"logprob": -5.612673759460449, "rank": 5, "decoded_token": " and"}}, {"22140": {"logprob": -0.00729279313236475, "rank": 1, "decoded_token": " waves"}, "1278": {"logprob": -5.632292747497559, "rank": 2, "decoded_token": " the"}, "1261": {"logprob": -5.757292747497559, "rank": 3, "decoded_token": " a"}, "39460": {"logprob": -8.257292747497559, "rank": 4, "decoded_token": " 
incoming"}, "1321": {"logprob": -9.757292747497559, "rank": 5, "decoded_token": " and"}}, {"1294": {"logprob": -0.3071398138999939, "rank": 1, "decoded_token": " in"}, "1408": {"logprob": -2.1821398735046387, "rank": 2, "decoded_token": " on"}, "1513": {"logprob": -2.4321398735046387, "rank": 3, "decoded_token": " at"}, "3016": {"logprob": -3.6821398735046387, "rank": 4, "decoded_token": " while"}, "1435": {"logprob": -3.8071398735046387, "rank": 5, "decoded_token": " as"}}, {"1278": {"logprob": -0.004646694287657738, "rank": 1, "decoded_token": " the"}, "1261": {"logprob": -6.1921467781066895, "rank": 2, "decoded_token": " a"}, "1420": {"logprob": -6.9421467781066895, "rank": 3, "decoded_token": " an"}, "40466": {"logprob": -7.2546467781066895, "rank": 4, "decoded_token": " shallow"}, "26517": {"logprob": -7.8796467781066895, "rank": 5, "decoded_token": " calm"}}, {"27208": {"logprob": -0.0658877044916153, "rank": 1, "decoded_token": " ocean"}, "7786": {"logprob": -3.440887689590454, "rank": 2, "decoded_token": " distance"}, "5124": {"logprob": -5.253387928009033, "rank": 3, "decoded_token": " early"}, "26517": {"logprob": -5.315887928009033, "rank": 4, "decoded_token": " calm"}, "11196": {"logprob": -5.378387928009033, "rank": 5, "decoded_token": " sea"}}, {"1513": {"logprob": -1.1504861116409302, "rank": 1, "decoded_token": " at"}, "1435": {"logprob": -1.2754861116409302, "rank": 2, "decoded_token": " as"}, "3184": {"logprob": -1.4004861116409302, "rank": 3, "decoded_token": " during"}, "3016": {"logprob": -2.9004859924316406, "rank": 4, "decoded_token": " while"}, "6117": {"logprob": -3.1504859924316406, "rank": 5, "decoded_token": " near"}}, {"97558": {"logprob": -0.12151996046304703, "rank": 1, "decoded_token": " sunset"}, "11729": {"logprob": -2.8715200424194336, "rank": 2, "decoded_token": " sun"}, "1266": {"logprob": -3.4965200424194336, "rank": 3, "decoded_token": " d"}, "54507": {"logprob": -3.9965200424194336, "rank": 4, "decoded_token": " dawn"}, "1261": {"logprob": -5.121520042419434, "rank": 5, "decoded_token": " a"}}, {"1626": {"logprob": -0.3073118329048157, "rank": 1, "decoded_token": ".\n"}, "1044": {"logprob": -2.182311773300171, "rank": 2, "decoded_token": ","}, "3016": {"logprob": -2.557311773300171, "rank": 3, "decoded_token": " while"}, "1454": {"logprob": -3.432311773300171, "rank": 4, "decoded_token": " with"}, "6117": {"logprob": -4.05731201171875, "rank": 5, "decoded_token": " near"}}, {"1052": {"logprob": -3.3378546504536644e-06, "rank": 1, "decoded_token": "4"}, "1051": {"logprob": -13.25000286102295, "rank": 2, "decoded_token": "3"}, "1049": {"logprob": -13.93750286102295, "rank": 3, "decoded_token": "1"}, "1053": {"logprob": -14.43750286102295, "rank": 4, "decoded_token": "5"}, "1032": {"logprob": -16.687503814697266, "rank": 5, "decoded_token": " "}}, {"1046": {"logprob": -1.6689286894688848e-06, "rank": 1, "decoded_token": "."}, "3590": {"logprob": -13.500001907348633, "rank": 2, "decoded_token": ".A"}, "6847": {"logprob": -16.437501907348633, "rank": 3, "decoded_token": ".T"}, "1044": {"logprob": -17.312501907348633, "rank": 4, "decoded_token": ","}, "1349": {"logprob": -17.375001907348633, "rank": 5, "decoded_token": " A"}}, {"1349": {"logprob": -0.004292916506528854, "rank": 1, "decoded_token": " A"}, "2048": {"logprob": -5.629292964935303, "rank": 2, "decoded_token": " An"}, "10638": {"logprob": -7.879292964935303, "rank": 3, "decoded_token": " Two"}, "111463": {"logprob": -10.004292488098145, "rank": 4, "decoded_token": " Trees"}, "1531": {"logprob": 
-10.879292488098145, "rank": 5, "decoded_token": " The"}}, {"53301": {"logprob": -1.5473321676254272, "rank": 1, "decoded_token": " winding"}, "15192": {"logprob": -1.7348321676254272, "rank": 2, "decoded_token": " narrow"}, "47945": {"logprob": -2.109832286834717, "rank": 3, "decoded_token": " dirt"}, "2169": {"logprob": -2.609832286834717, "rank": 4, "decoded_token": " ser"}, "59396": {"logprob": -2.672332286834717, "rank": 5, "decoded_token": " gravel"}}, {"59396": {"logprob": -0.8954829573631287, "rank": 1, "decoded_token": " gravel"}, "3549": {"logprob": -1.1454830169677734, "rank": 2, "decoded_token": " path"}, "47945": {"logprob": -1.6454830169677734, "rank": 3, "decoded_token": " dirt"}, "14801": {"logprob": -3.2704830169677734, "rank": 4, "decoded_token": " pathway"}, "15551": {"logprob": -4.270483016967773, "rank": 5, "decoded_token": " stone"}}, {"3549": {"logprob": -0.02117946185171604, "rank": 1, "decoded_token": " path"}, "14801": {"logprob": -3.896179437637329, "rank": 2, "decoded_token": " pathway"}, "33659": {"logprob": -8.14617919921875, "rank": 3, "decoded_token": " trail"}, "9480": {"logprob": -9.64617919921875, "rank": 4, "decoded_token": " road"}, "7368": {"logprob": -9.64617919921875, "rank": 5, "decoded_token": "path"}}, {"13335": {"logprob": -0.18962937593460083, "rank": 1, "decoded_token": " leads"}, "39985": {"logprob": -2.752129316329956, "rank": 2, "decoded_token": " cuts"}, "1639": {"logprob": -3.877129316329956, "rank": 3, "decoded_token": " me"}, "11500": {"logprob": -3.939629316329956, "rank": 4, "decoded_token": " runs"}, "2645": {"logprob": -4.189629554748535, "rank": 5, "decoded_token": " through"}}, {"2645": {"logprob": -0.05349981039762497, "rank": 1, "decoded_token": " through"}, "8994": {"logprob": -4.053499698638916, "rank": 2, "decoded_token": " towards"}, "2396": {"logprob": -4.303499698638916, "rank": 3, "decoded_token": " between"}, "2203": {"logprob": -4.678499698638916, "rank": 4, "decoded_token": " into"}, "1317": {"logprob": -5.678499698638916, "rank": 5, "decoded_token": " to"}}, {"1261": {"logprob": -0.017386287450790405, "rank": 1, "decoded_token": " a"}, "11223": {"logprob": -4.892386436462402, "rank": 2, "decoded_token": " green"}, "1295": {"logprob": -5.017386436462402, "rank": 3, "decoded_token": " l"}, "23170": {"logprob": -6.642386436462402, "rank": 4, "decoded_token": " grass"}, "1420": {"logprob": -7.267386436462402, "rank": 5, "decoded_token": " an"}}, {"1295": {"logprob": -0.9453322887420654, "rank": 1, "decoded_token": " l"}, "11223": {"logprob": -1.3203322887420654, "rank": 2, "decoded_token": " green"}, "23170": {"logprob": -1.9453322887420654, "rank": 3, "decoded_token": " grass"}, "12097": {"logprob": -2.4453322887420654, "rank": 4, "decoded_token": " park"}, "26428": {"logprob": -3.3203322887420654, "rank": 5, "decoded_token": " garden"}}, {"3506": {"logprob": -6.556489552167477e-06, "rank": 1, "decoded_token": "ush"}, "1374": {"logprob": -12.000006675720215, "rank": 2, "decoded_token": "us"}, "90716": {"logprob": -15.625006675720215, "rank": 3, "decoded_token": "USH"}, "16938": {"logprob": -15.875006675720215, "rank": 4, "decoded_token": "usher"}, "13326": {"logprob": -17.1875057220459, "rank": 5, "decoded_token": "inden"}}, {"11223": {"logprob": -0.3668670654296875, "rank": 1, "decoded_token": " green"}, "1044": {"logprob": -1.3668670654296875, "rank": 2, "decoded_token": ","}, "26428": {"logprob": -3.4918670654296875, "rank": 3, "decoded_token": " garden"}, "12097": {"logprob": -4.1168670654296875, "rank": 4, 
"decoded_token": " park"}, "23170": {"logprob": -5.8668670654296875, "rank": 5, "decoded_token": " grass"}}, {"12097": {"logprob": -0.5530153512954712, "rank": 1, "decoded_token": " park"}, "3727": {"logprob": -2.0530152320861816, "rank": 2, "decoded_token": " field"}, "28035": {"logprob": -2.1780152320861816, "rank": 3, "decoded_token": " landscape"}, "26428": {"logprob": -2.3030152320861816, "rank": 4, "decoded_token": " garden"}, "4457": {"logprob": -2.8030152320861816, "rank": 5, "decoded_token": " area"}}, {"1046": {"logprob": -0.7924000024795532, "rank": 1, "decoded_token": "."}, "1454": {"logprob": -1.2924000024795532, "rank": 2, "decoded_token": " with"}, "8994": {"logprob": -2.7923998832702637, "rank": 3, "decoded_token": " towards"}, "54410": {"logprob": -3.5423998832702637, "rank": 4, "decoded_token": " lined"}, "2425": {"logprob": -3.5423998832702637, "rank": 5, "decoded_token": " under"}}, {"2": {"logprob": -1.9073468138230965e-06, "rank": 1, "decoded_token": ""}, "1032": {"logprob": -13.250001907348633, "rank": 2, "decoded_token": " "}, "1256": {"logprob": -16.250001907348633, "rank": 3, "decoded_token": " "}, "1293": {"logprob": -19.000001907348633, "rank": 4, "decoded_token": " "}, "1319": {"logprob": -20.000001907348633, "rank": 5, "decoded_token": " ("}}]]] \ No newline at end of file diff --git a/tests/models/test_pixtral.py b/tests/models/test_pixtral.py deleted file mode 100644 index dc60cf7eae8b1..0000000000000 --- a/tests/models/test_pixtral.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. - -Run `pytest tests/models/test_mistral.py`. -""" -import pytest - -from vllm.sampling_params import SamplingParams - -pytestmark = pytest.mark.vlm - -MODELS = ["mistralai/Pixtral-12B-2409"] - - -@pytest.mark.skip( - reason= - "Model is too big, test passed on A100 locally but will OOM on CI machine." -) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: - image_urls = [ - "https://picsum.photos/id/237/200/300", - "https://picsum.photos/seed/picsum/200/300" - ] - expected = [ - "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa - "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa - ] - prompt = "Describe the image in one short sentence." 
- - sampling_params = SamplingParams(max_tokens=512, temperature=0.0) - - with vllm_runner(model, dtype=dtype, - tokenizer_mode="mistral") as vllm_model: - - for i, image_url in enumerate(image_urls): - messages = [ - { - "role": - "user", - "content": [{ - "type": "text", - "text": prompt - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }] - }, - ] - - outputs = vllm_model.model.chat(messages, - sampling_params=sampling_params) - assert outputs[0].outputs[0].text == expected[i] diff --git a/tests/models/utils.py b/tests/models/utils.py index 93ec03995094b..8e31a1d6eefed 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union -from vllm.sequence import Logprob, SampleLogprobs +from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] @@ -34,20 +34,47 @@ def check_outputs_equal( assert output_ids_0 == output_ids_1, fail_msg +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * List of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]]] -# Allow for tokens to be represented as str's rather than IDs +# Allow for tokens to be represented as str's rather than IDs; +# tuple of +# * Token string representations list +# * String +# * Optional list of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]], List[Dict[str, Logprob]]]]] +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * Optional list of top sample logprobs for each sampled token +# * Optional list of top prompt logprobs for each prompt token +# +# Allows prompt logprobs to be requested. +TokensTextLogprobsPromptLogprobs = Tuple[ + List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]], + Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]] + def check_logprobs_close( *, - outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], - outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], + outputs_0_lst: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + TextTextLogprobs]], + outputs_1_lst: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + TextTextLogprobs]], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, @@ -57,6 +84,18 @@ def check_logprobs_close( """Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. + How sample logprobs are compared: + * `always_check_logprobs == True`: set of highest-logprob token ids + must match between seq0 and seq1 at all sampled token offsets + * `always_check_logprobs == False`: highest-logprob token ids are + only compared at sampled token offsets for which generated token + ids don't match + + Prompt logprobs must be provided either for both input sequences, or + for neither. If prompt logprobs are provided, then highest-logprob + prompt token ids must match between seq0 and seq1 at all prompt token + offsets. 
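The comparison rule described above can be restated in isolation: prompt logprobs must be supplied for both sequences or for neither, and wherever both provide an entry the top-k token-id sets must agree at every prompt offset. A minimal sketch of that rule, with illustrative names, assuming plain `Dict[int, float]` entries:

```python
from typing import Dict, List, Optional

def prompt_logprobs_match(
    prompt_logprobs_0: Optional[List[Optional[Dict[int, float]]]],
    prompt_logprobs_1: Optional[List[Optional[Dict[int, float]]]],
) -> bool:
    """Sketch of the prompt-logprobs closeness rule described above."""
    if prompt_logprobs_0 is None or prompt_logprobs_1 is None:
        # Either both sequences provide prompt logprobs, or neither does.
        return prompt_logprobs_0 is None and prompt_logprobs_1 is None
    for elem_0, elem_1 in zip(prompt_logprobs_0, prompt_logprobs_1):
        if elem_0 is None or elem_1 is None:
            # Individual positions may be None, but only for both at once.
            if not (elem_0 is None and elem_1 is None):
                return False
        elif set(elem_0.keys()) != set(elem_1.keys()):
            # Top-k token-id choices must be identical at this offset.
            return False
    return True
```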
+ Args: outputs_0_lst: First sequence to compare outputs_0_lst: Second sequence to compare @@ -78,8 +117,65 @@ def check_logprobs_close( for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): - output_ids_0, output_str_0, logprobs_0 = outputs_0 - output_ids_1, output_str_1, logprobs_1 = outputs_1 + assert len(outputs_0) == len(outputs_1) + if len(outputs_0) == 3: + assert len(outputs_1) == 3 + # Break out tokens, text & sample logprobs + # (prompt logprobs were not provided) + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + elif len(outputs_0) == 4: + assert len(outputs_1) == 4 + # Break out tokens, text, sample logprobs & prompt logprobs + ( + output_ids_0, + output_str_0, + logprobs_0, + prompt_logprobs_0, + ) = outputs_0 + ( + output_ids_1, + output_str_1, + logprobs_1, + prompt_logprobs_1, + ) = outputs_1 + + # Test prompt logprobs closeness + if (prompt_logprobs_0 is not None + and prompt_logprobs_1 is not None): + # Both sequences' prompt logprobs lists are not `None`` + # (although individual list elements may be `None`); + # for each token's logprobs: + for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate( + zip(prompt_logprobs_0, prompt_logprobs_1)): + fail_msg = ( + f"Prompt logprobs test:" + f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}" + f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}") + + if logprobs_elem_0 is None: + # If the seq 0 token's logprobs are `None`, + # the seq 1 token's logprobs must be `None` + assert logprobs_elem_1 is None, fail_msg + else: + # If the seq 0 token's logprobs are not `None`, + # the seq 1 token's logprobs must not be `None` + assert logprobs_elem_1 is not None, fail_msg + # Logprobs check: top-k token choices must be the same + assert (set(logprobs_elem_0.keys()) == set( + logprobs_elem_1.keys())), fail_msg + else: + # Both sequence logprobs lists must be `None` + fail_msg = (f"Prompt logprobs test:" + f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}" + f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}") + + assert (prompt_logprobs_0 is None + and prompt_logprobs_1 is None), fail_msg + else: + raise ValueError(f"Outputs tuple must have 3 or 4 elements but " + f"{len(outputs_0)} elements were provided: " + f"{outputs_0}") if logprobs_0 is None: logprobs_0 = [None] * len(output_ids_0) diff --git a/tests/mq_llm_engine/__init__.py b/tests/mq_llm_engine/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/mq_llm_engine/test_abort.py b/tests/mq_llm_engine/test_abort.py new file mode 100644 index 0000000000000..782b508a57149 --- /dev/null +++ b/tests/mq_llm_engine/test_abort.py @@ -0,0 +1,67 @@ +"""Test that aborting is handled properly.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" +EXPECTED_TOKENS = 250 + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_id_to_be_aborted = "request-aborted" + request_ids_a = [f"request-a-{idx}" for idx in range(10)] + request_ids_b = 
[f"request-b-{idx}" for idx in range(10)] + + # Requests started before one to be aborted. + tasks = [] + for request_id in request_ids_a: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Aborted. + task_aborted = asyncio.create_task( + generate(client, request_id_to_be_aborted, EXPECTED_TOKENS)) + + # Requests started after one to be aborted. + for request_id in request_ids_b: + tasks.append( + asyncio.create_task( + generate(client, request_id, EXPECTED_TOKENS))) + + # Actually abort. + await asyncio.sleep(0.5) + await client.abort(request_id_to_be_aborted) + + # Confirm that we got all the EXPECTED tokens from the requests. + for task in tasks: + count, request_id = await task + assert count == EXPECTED_TOKENS, ( + f"{request_id} generated only {count} tokens") + + # Cancel task (this will hang indefinitely if not). + task_aborted.cancel() + + # Shutdown. + client.close() diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py new file mode 100644 index 0000000000000..49cfc5aa04c36 --- /dev/null +++ b/tests/mq_llm_engine/test_error_handling.py @@ -0,0 +1,244 @@ +"""Test that various errors are handled properly.""" + +import asyncio +import tempfile +import time +import uuid +from unittest.mock import Mock + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.entrypoints.openai.api_server import build_async_engine_client +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.lora.request import LoRARequest +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser + +MODEL = "google/gemma-1.1-2b-it" +ENGINE_ARGS = AsyncEngineArgs(model=MODEL) +RAISED_ERROR = KeyError +RAISED_VALUE = "foo" + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during first forward pass. + engine.engine.model_executor.execute_model = Mock( + side_effect=RAISED_ERROR(RAISED_VALUE)) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_evil_forward(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_forward) as engine: + + client = await engine.make_client() + + # Server should be healthy after initial probe. + await asyncio.sleep(2.0) + await client.check_health() + + # Throws an error in first forward pass. + with pytest.raises(RAISED_ERROR): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + assert client.errored + + # Engine is errored, should get ENGINE_DEAD_ERROR. + with pytest.raises(MQEngineDeadError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + assert client.errored + + await asyncio.sleep(1.0) + with pytest.raises(RAISED_ERROR): + await client.check_health() + assert client.errored + + # Shutdown. 
+ client.close() + + +def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs, + ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during first forward pass. + engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_health_check(tmp_socket): + with RemoteMQLLMEngine( + engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_model_executor_health) as engine: + + client = await engine.make_client() + assert client.is_running + + # Health probe should throw RAISED_ERROR. + await asyncio.sleep(15.) + + with pytest.raises(RAISED_ERROR): + await client.check_health() + assert client.errored + + # Generate call should throw ENGINE_DEAD_ERROR + with pytest.raises(MQEngineDeadError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id=uuid.uuid4()): + pass + + client.close() + + +def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Raise error during abort call. + engine.engine.abort_request = Mock(side_effect=RAISED_ERROR) + + # Run engine. + engine.start() + + +@pytest.mark.asyncio +async def test_failed_abort(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_abort) as engine: + + client = await engine.make_client() + assert client.is_running + + # Firsh check health should work. + await client.check_health() + + # Trigger an abort on the client side. + async def bad_abort_after_2s(): + await asyncio.sleep(2.0) + await client.abort(request_id="foo") + + # Trigger an abort in 2s from now. + abort_task = asyncio.create_task(bad_abort_after_2s()) + + # Exception in abort() will happen during this generation. + # This will kill the engine and should return ENGINE_DEAD_ERROR + # with reference to the original KeyError("foo") + with pytest.raises(MQEngineDeadError) as execinfo: + async for _ in client.generate( + inputs="Hello my name is", + sampling_params=SamplingParams(max_tokens=2000), + request_id=uuid.uuid4()): + pass + assert "KeyError" in repr(execinfo.value) + assert client.errored + + await abort_task + + # This should raise the original error. + with pytest.raises(RAISED_ERROR): + await client.check_health() + + client.close() + + +@pytest.mark.asyncio +async def test_bad_request(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + # Invalid request should fail, but not crash the server. + with pytest.raises(ValueError): + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-1", + lora_request=LoRARequest( + "invalid-lora", 1, + "invalid-path")): + pass + + # This request should be okay. + async for _ in client.generate(inputs="Hello my name is", + sampling_params=SamplingParams(), + request_id="abcd-2"): + pass + + # Shutdown. 
+ client.close() + + +@pytest.mark.asyncio +async def test_mp_crash_detection(monkeypatch): + + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + # When LLMEngine is loaded, it will crash. + def mock_init(): + raise ValueError + + monkeypatch.setattr(LLMEngine, "__init__", mock_init) + + start = time.perf_counter() + async with build_async_engine_client(args): + pass + end = time.perf_counter() + + assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s " + "if there is an error in the startup.") + + +@pytest.mark.asyncio +async def test_mp_cuda_init(): + # it should not crash, when cuda is initialized + # in the API server process + import torch + torch.cuda.init() + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + async with build_async_engine_client(args): + pass diff --git a/tests/mq_llm_engine/test_load.py b/tests/mq_llm_engine/test_load.py new file mode 100644 index 0000000000000..630c112d0f0c9 --- /dev/null +++ b/tests/mq_llm_engine/test_load.py @@ -0,0 +1,57 @@ +"""Test that the MQLLMEngine is able to handle 10k concurrent requests.""" + +import asyncio +import tempfile +import uuid + +import pytest + +from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate +from vllm.engine.arg_utils import AsyncEngineArgs + +MODEL = "google/gemma-1.1-2b-it" +NUM_EXPECTED_TOKENS = 10 +NUM_REQUESTS = 10000 + +# Scenarios to test for num generated token. +ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True) + + +@pytest.fixture(scope="function") +def tmp_socket(): + with tempfile.TemporaryDirectory() as td: + yield f"ipc://{td}/{uuid.uuid4()}" + + +@pytest.mark.asyncio +async def test_load(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket) as engine: + + client = await engine.make_client() + + request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)] + + # Create concurrent requests. + tasks = [] + for request_id in request_ids: + tasks.append( + asyncio.create_task( + generate(client, request_id, NUM_EXPECTED_TOKENS))) + + # Confirm that we got all the EXPECTED tokens from the requests. + failed_request_id = None + tokens = None + for task in tasks: + num_generated_tokens, request_id = await task + if (num_generated_tokens != NUM_EXPECTED_TOKENS + and failed_request_id is None): + failed_request_id = request_id + tokens = num_generated_tokens + + assert failed_request_id is None, ( + f"{failed_request_id} generated {tokens} but " + f"expected {NUM_EXPECTED_TOKENS}") + + # Shutdown. 
+ client.close() diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py new file mode 100644 index 0000000000000..e27fd77923412 --- /dev/null +++ b/tests/mq_llm_engine/utils.py @@ -0,0 +1,78 @@ +import asyncio +import multiprocessing +from typing import Callable, Tuple, Union + +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import MQLLMEngine +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + + +async def generate( + client: MQLLMEngineClient, + request_id: str, + num_tokens: int, + return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]: + + final_output = None + count = 0 + async for out in client.generate( + request_id=request_id, + inputs="Hello my name is Robert and", + sampling_params=SamplingParams(max_tokens=num_tokens, + temperature=0)): + + count += 1 + final_output = out + await asyncio.sleep(0.) + + if return_output: + return final_output + + # Confirm we generated all the tokens we expected. + return count, request_id + + +def run_normal(engine_args: AsyncEngineArgs, ipc_path: str): + # Make engine. + engine = MQLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.UNKNOWN_CONTEXT, + ipc_path=ipc_path) + + # Run engine. + engine.start() + + +class RemoteMQLLMEngine: + + def __init__(self, + engine_args: AsyncEngineArgs, + ipc_path: str, + run_fn: Callable = run_normal) -> None: + + self.engine_args = engine_args + self.ipc_path = ipc_path + context = multiprocessing.get_context("spawn") + self.proc = context.Process(target=run_fn, + args=(engine_args, ipc_path)) + self.proc.start() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.proc.kill() + + async def make_client(self) -> MQLLMEngineClient: + engine_config = self.engine_args.create_engine_config() + client = MQLLMEngineClient(self.ipc_path, engine_config) + while True: + try: + await client.setup() + break + except TimeoutError: + assert self.proc.is_alive() + return client diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 24ebb60a9cbfd..c5dc81cc25622 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -100,3 +100,95 @@ def test_multi_step_llm( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) +def test_multi_step_llm_w_prompt_logprobs( + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], +) -> None: + """Test prompt logprobs with multi-step scheduling via sync LLM Engine. + + Set up a vLLM engine instance w/ single-step scheduling as a ground-truth + reference. + + Prompt them with the same example prompts. 
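The `RemoteMQLLMEngine` helper added just above spawns `MQLLMEngine` in a separate "spawn" process and its `make_client()` coroutine retries `MQLLMEngineClient.setup()` until the engine answers over the IPC socket. A minimal usage sketch, assuming the same model the tests use; the request id and token count are illustrative:

```python
import asyncio
import tempfile
import uuid

from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs

async def main() -> None:
    with tempfile.TemporaryDirectory() as td:
        ipc_path = f"ipc://{td}/{uuid.uuid4()}"
        engine_args = AsyncEngineArgs(model="google/gemma-1.1-2b-it")
        with RemoteMQLLMEngine(engine_args=engine_args,
                               ipc_path=ipc_path) as engine:
            client = await engine.make_client()
            # generate() returns (num_generated_tokens, request_id).
            num_tokens, request_id = await generate(client, "demo-request", 16)
            print(f"{request_id} generated {num_tokens} tokens")
            client.close()

if __name__ == "__main__":
    asyncio.run(main())
```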
+ + Validate: + * All generated logprobs are all very close + + Args: + hf_runner: HF transformers model runner fixture + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + num_prompt_logprobs: number of logprobs to return for each prompt token; + note that this argument is not supported by the + OpenAI completions endpoint. + """ + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs, + num_prompt_logprobs=num_prompt_logprobs) + + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + ) as vllm_model: + single_step_vllm_outputs = vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs, + num_prompt_logprobs=num_prompt_logprobs) + + check_logprobs_close( + outputs_0_lst=single_step_vllm_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py index e9562d2048f06..68d05de904ba8 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_base.py @@ -5,7 +5,7 @@ def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors): - assert type(expected) == type(actual) + assert type(expected) == type(actual) # noqa: E721 if isinstance(expected, torch.Tensor): assert torch.equal(expected, actual) else: diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 3f0c6cbc051a7..36167cf95f589 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -10,6 +10,8 @@ from tests.quantization.utils import is_quant_method_supported +from ..utils import fork_new_process_for_each_test + models_4bit_to_test = [ ('huggyllama/llama-7b', 'quantize model inflight'), ] @@ -29,6 +31,7 @@ @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_4bit_to_test) +@fork_new_process_for_each_test def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -41,6 +44,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_pre_qaunt_4bit_to_test) +@fork_new_process_for_each_test def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -52,6 +56,7 @@ def 
test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_pre_quant_8bit_to_test) +@fork_new_process_for_each_test def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -59,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name) +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason='Test requires at least 2 GPUs.') +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_test) +@fork_new_process_for_each_test +def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = {"load_in_4bit": True} + validate_generated_texts(hf_runner, + vllm_runner, + example_prompts[:1], + model_name, + hf_model_kwargs, + vllm_tp_size=2) + + def log_generated_texts(prompts, outputs, runner_name): logged_texts = [] for i, (_, generated_text) in enumerate(outputs): @@ -75,32 +98,33 @@ def validate_generated_texts(hf_runner, vllm_runner, prompts, model_name, - hf_model_kwargs=None): - - if hf_model_kwargs is None: - hf_model_kwargs = {} - - # Run with HF runner - with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: - hf_outputs = llm.generate_greedy(prompts, 8) - hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") - - # Clean up the GPU memory for the next test - torch.cuda.synchronize() - gc.collect() - torch.cuda.empty_cache() + hf_model_kwargs=None, + vllm_tp_size=1): - #Run with vLLM runner + # NOTE: run vLLM first, as it requires a clean process + # when using distributed inference with vllm_runner(model_name, quantization='bitsandbytes', load_format='bitsandbytes', + tensor_parallel_size=vllm_tp_size, enforce_eager=True, gpu_memory_utilization=0.8) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") # Clean up the GPU memory for the next test - torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + if hf_model_kwargs is None: + hf_model_kwargs = {} + + # Run with HF runner + with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + hf_outputs = llm.generate_greedy(prompts, 8) + hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") + + # Clean up the GPU memory for the next test gc.collect() torch.cuda.empty_cache() diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 58864e83173f9..a0c1d7e24c503 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -86,9 +86,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, assert attn._k_scale == 1.0 assert attn._v_scale == 1.0 - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability >= 89 and not force_marlin: + if current_platform.has_device_capability(89) and not force_marlin: # For GPUs with hardware support, we keep weights in fp8 assert fc1.weight.dtype == torch.float8_e4m3fn else: diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 65bb80ed70c6a..061a077592e80 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -1,15 +1,15 @@ -import torch - from vllm.model_executor.layers.quantization import 
QUANTIZATION_METHODS from vllm.platforms import current_platform def is_quant_method_supported(quant_method: str) -> bool: # Currently, all quantization methods require Nvidia or AMD GPUs - if not torch.cuda.is_available(): + if not (current_platform.is_cuda() or current_platform.is_rocm()): return False capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - return (capability >= - QUANTIZATION_METHODS[quant_method].get_min_capability()) + assert capability is not None + + min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability() + + return capability.to_int() >= min_capability diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index fe413d1228021..3576a4834ebc3 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -66,8 +66,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, hashes.append([]) prompts = [prefix + prompt for prompt in sample_prompts] - seq_id = 0 - for prompt in prompts: + for seq_id, prompt in enumerate(prompts): hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) seq = Sequence(seq_id, @@ -83,8 +82,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, for idx in range(num_blocks): hashes[-1][-1].append(seq.hash_of_block(idx)) - seq_id += 1 - # Check that hashes made with two prefixes with different first blocks are # different everywhere. for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])): diff --git a/tests/test_logger.py b/tests/test_logger.py index 8f3d218416870..fadf66f2b61d4 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -111,7 +111,7 @@ def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist(): configuration occurs.""" with pytest.raises(RuntimeError) as ex_info: _configure_vllm_root_logger() - assert ex_info.type == RuntimeError + assert ex_info.type == RuntimeError # noqa: E721 assert "File does not exist" in str(ex_info) @@ -152,7 +152,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( logging_config_file.name): with pytest.raises(ValueError) as ex_info: _configure_vllm_root_logger() - assert ex_info.type == ValueError + assert ex_info.type == ValueError # noqa: E721 assert "Invalid logging config. 
Expected Dict, got" in str(ex_info) diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 7f3fb595321ad..69ab67abdd12b 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,5 +1,12 @@ +import os + from ..utils import compare_two_settings +# --enforce-eager on TPU causes graph compilation +# this times out default Health Check in the MQLLMEngine, +# so we set the timeout here to 30s +os.environ["VLLM_RPC_TIMEOUT"] = "30000" + def test_custom_dispatcher(): compare_two_settings("google/gemma-2b", diff --git a/tests/utils.py b/tests/utils.py index 6e5bc05b3901a..43825e8138362 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, List, Optional import openai +import pytest import requests from openai.types.completion import Completion from transformers import AutoTokenizer @@ -22,7 +23,8 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.model_executor.model_loader.loader import get_model_loader from vllm.platforms import current_platform -from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip +from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless, + get_open_port, is_hip) if current_platform.is_rocm(): from amdsmi import (amdsmi_get_gpu_vram_usage, @@ -117,7 +119,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.proc.terminate() try: - self.proc.wait(3) + self.proc.wait(8) except subprocess.TimeoutExpired: # force kill if needed self.proc.kill() @@ -356,12 +358,23 @@ def error_on_warning(): yield +def get_physical_device_indices(devices): + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") + if visible_devices is None: + return devices + + visible_indices = [int(x) for x in visible_devices.split(",")] + index_mapping = {i: physical for i, physical in enumerate(visible_indices)} + return [index_mapping[i] for i in devices if i in index_mapping] + + @_nvml() def wait_for_gpu_memory_to_clear(devices: List[int], threshold_bytes: int, timeout_s: float = 120) -> None: # Use nvml instead of pytorch to reduce measurement error from torch cuda # context. + devices = get_physical_device_indices(devices) start_time = time.time() while True: output: Dict[int, str] = {} @@ -441,6 +454,22 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: return wrapper +def multi_gpu_test(*, num_gpus: int): + """ + Decorate a test to be run only when multiple GPUs are available. + """ + test_selector = getattr(pytest.mark, f"distributed_{num_gpus}_gpus") + test_skipif = pytest.mark.skipif( + cuda_device_count_stateless() < num_gpus, + reason=f"Need at least {num_gpus} GPUs to run the test.", + ) + + def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: + return test_selector(test_skipif(fork_new_process_for_each_test(f))) + + return wrapper + + async def completions_with_server_args( prompts: List[str], model_name: str, @@ -464,6 +493,7 @@ async def completions_with_server_args( ''' outputs = None + max_wait_seconds = 240 * 3 # 240 is default with RemoteOpenAIServer(model_name, server_cli_args, max_wait_seconds=max_wait_seconds) as server: @@ -474,7 +504,7 @@ async def completions_with_server_args( stream=False, max_tokens=5, logprobs=num_logprobs) - assert outputs is not None + assert outputs is not None, "Completion API call failed." 
return outputs diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index fe76705746766..2f5c6c5a117f3 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -1,3 +1,4 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main -gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main \ No newline at end of file +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh old mode 100644 new mode 100755 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 32bff22f66a8b..c0654712b71b5 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,3 +1,4 @@ +import itertools from array import array from typing import List @@ -7,13 +8,9 @@ from vllm.engine.arg_utils import EngineArgs from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, SequenceData, SequenceGroupMetadata) -from vllm.utils import is_cpu +from vllm.utils import is_cpu, make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner - -# CUDA graph scenarios to test -# -# Currently CUDA graph is not supported -ENFORCE_EAGER = [True] +from vllm.worker.model_runner import _get_graph_batch_size BATCH_SIZES = [1, 4, 16, 64, 256] @@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args, reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_empty_seq_group(enforce_eager, ): +def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output for empty seq group list""" @@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ): max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_group_metadata_list: List[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( @@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ): "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_prepare_prompt( - batch_size, - enforce_eager, -): +def test_prepare_prompt(batch_size): ''' Test the ability of the encoder/decoder model runner subclass to produce prefill-phase model inputs & attention metadata. @@ -115,7 +107,7 @@ def test_prepare_prompt( max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_lens: List[int] = [] @@ -281,11 +273,7 @@ def test_prepare_prompt( "unsupported for encoder/ " "decoder models") @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) -def test_prepare_decode( - batch_size, - enforce_eager, -): +def test_prepare_decode(batch_size): ''' Test the ability of the encoder/decoder model runner subclass to produce decode-phase model inputs & attention metadata. 
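The `test_prepare_decode_cuda_graph` test added further below relies on `_get_graph_batch_size` (imported above from `vllm.worker.model_runner`) to round the decode batch up to the nearest captured CUDA graph size and to derive the padding amount. A sketch of that rounding; the bucket sizes (1, 2, 4, then multiples of 8) reflect vLLM's usual graph batch-size alignment and are stated here as an assumption rather than read from this diff:

```python
def get_graph_batch_size(batch_size: int, alignment: int = 8) -> int:
    """Round a batch size up to the nearest CUDA-graph-captured size."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + alignment - 1) // alignment * alignment

# The test pads seq_lens with 1s by this amount before capture/replay.
cuda_graph_pad_size = get_graph_batch_size(13) - 13  # -> 3
```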
@@ -311,7 +299,7 @@ def test_prepare_decode( max_num_batched_tokens=100000, max_num_seqs=100000, enable_chunked_prefill=False, - enforce_eager=enforce_eager, + enforce_eager=True, ) seq_lens: List[int] = [] @@ -428,7 +416,8 @@ def test_prepare_decode( expected, ) - # Cuda graph should is currently not supported for encoder/decoer. + # Model runner's CUDAGraph setting should be propagated to attention + # metadata. assert attn_metadata.use_cuda_graph is False # Verify the lengths of input tokens & positions @@ -464,8 +453,7 @@ def test_prepare_decode( # each sequence) in the decode phase expected_selected_token_indices = [] - selected_token_start_idx = 0 - for seq_len in seq_lens: + for selected_token_start_idx, seq_len in enumerate(seq_lens): # Compute the index offset of the final token in each # sequence's decoded outputs; since a single token is # decoded per iteration per sequence, then the length @@ -474,7 +462,6 @@ def test_prepare_decode( # generated tokens is 0 (i.e. the expected sampling index # for a given sequence is just `selected_token_start_idx`) expected_selected_token_indices.append(selected_token_start_idx) - selected_token_start_idx += 1 sampling_metadata = model_input.sampling_metadata actual = sampling_metadata.selected_token_indices @@ -484,3 +471,152 @@ def test_prepare_decode( dtype=actual.dtype, ) assert torch.equal(actual, expected) + + +@pytest.mark.parametrize("batch_size", list(range(1, 257))) +def test_prepare_decode_cuda_graph(batch_size): + """ + Tests that for encoder-decoder models with CUDA Graph capture and replay + enabled, the tensors used during the decode phase are correctly padded + for varying input batch sizes. + """ + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=False, + ) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + + model_input = model_runner.prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = attn_metadata.slot_mapping + encoder_input_tokens = model_input.encoder_input_tokens + encoder_input_positions = model_input.encoder_input_positions + cross_slot_mapping = attn_metadata.cross_slot_mapping + + # With CUDA Graph capture and replay enabled, the decoder and encoder + # input sequences will be padded. Create the expected padded tensors + # accordingly. 
+ graph_batch_size = _get_graph_batch_size(batch_size) + cuda_graph_pad_size = graph_batch_size - batch_size + padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) + padded_encoder_seq_lens = encoder_seq_lens + list( + itertools.repeat(1, cuda_graph_pad_size)) + + assert return_seq_lens == padded_seq_lens + assert len(slot_mapping) == len(input_tokens) + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify attention metadata + device = model_runner.device + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_decode_tokens > 0 + assert torch.equal( + attn_metadata.seq_lens_tensor, + torch.tensor(padded_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == padded_seq_lens + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(seq_lens) + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens + assert torch.equal( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens) + + # Verify block tables are correct for prompts + # - Decoder self-attention. Pad the block tables as expected. + expected = [block_tables[0] for _ in range(batch_size)] + expected.extend([[] for _ in range(cuda_graph_pad_size)]) + expected = make_tensor_with_pad( + expected, + max_len=64, + pad=0, + dtype=torch.int32, + device=model_runner.device, + ) + assert torch.equal( + attn_metadata.block_tables, + expected, + ) + # - Encoder/decoder cross-attention. Pad the cross-attention block tables + # as expected. + expected = [cross_block_table for _ in range(len(seq_group_metadata_list))] + expected.extend([[] for _ in range(cuda_graph_pad_size)]) + expected = make_tensor_with_pad( + expected, + max_len=64, + pad=0, + dtype=torch.int32, + device=model_runner.device, + ) + assert torch.equal( + attn_metadata.cross_block_tables, + expected, + ) + + # Model runner's CUDAGraph setting should be propagated to attention + # metadata. 
+ assert attn_metadata.use_cuda_graph is True + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == len(padded_seq_lens) + assert len(input_positions) == len(padded_seq_lens) + # -- An indirect check that model_input.input_tokens + # and model_input.input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + input_tokens, + input_positions, + ) + # - Encoder + assert len(encoder_input_tokens) == 0 + assert len(encoder_input_tokens) == 0 + # -- An indirect check that model_input.encoder_input_tokens + # and model_input.encoder_input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + encoder_input_tokens, + encoder_input_positions, + ) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index a20aa37bcc1e2..42b2337f46914 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -241,10 +241,8 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] - selected_token_start_idx = 0 - for _ in context_lens: + for selected_token_start_idx, _ in enumerate(context_lens): expected_selected_token_indices.append(selected_token_start_idx) - selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index efa02d36c4acd..ff5aa8bee3c27 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -17,6 +17,9 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) +if current_platform.is_rocm(): + import vllm._rocm_C # noqa: F401 + with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 @@ -127,6 +130,30 @@ def paged_attention_v2( blocksparse_block_size, blocksparse_head_sliding_step) +def paged_attention_rocm( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, +) -> None: + torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, + scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype) + + # pos encoding ops def rotary_embedding( positions: torch.Tensor, @@ -532,7 +559,7 @@ def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, num_bits: int) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * 2), + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), device=b_q_weight.device, dtype=b_q_weight.dtype) for e in range(num_experts): @@ -657,32 +684,43 @@ def scaled_fp8_quant( # int8 def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: 
bool = True +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ - Quantize the input tensor to int8 and return the quantized tensor and scale. + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. Args: input: The input tensor to be quantized to int8. scale: Optional scaling factor for the int8 quantization. When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales. + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - torch.ops._C.static_scaled_int8_quant(output, input, scale) - return output, scale + assert symmetric == ( + azp is + None), "azp must only be provided for asymmetric quantization." + torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, None # dynamic-per-token quantization. input_scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) - return output, input_scales + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales, input_azp # qqq ops @@ -730,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, silu_activation) -def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, bias_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool, + conv_state_indices: Optional[torch.Tensor], +) -> torch.Tensor: return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation) + silu_activation, + conv_state_indices) def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, @@ -832,12 +876,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, offsets, rank, full_nvlink) -def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int, - full_nvlink: bool) -> bool: - return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, - full_nvlink) - - def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 2156f6b18adb6..31fcc4c3256a8 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -27,29 +27,27 @@ def _reshape_activation_tensor( @staticmethod def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - x1, x2 = ipex_ops._reshape_activation_tensor(x) - ipex.llm.functional.silu_mul(x1, x2, out) + ipex.llm.functional.silu_and_mul(x, out) @staticmethod def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - x1, x2 = ipex_ops._reshape_activation_tensor(x) - ipex.llm.functional.gelu_mul(x1, x2, out, "none") + ipex.llm.functional.gelu_and_mul(x, out) @staticmethod def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - x1, x2 = 
ipex_ops._reshape_activation_tensor(x) - ipex.llm.functional.gelu_mul(x1, x2, out, "tanh") + ipex.llm.functional.gelu_and_mul(x, out) @staticmethod - def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: - out.copy_(torch.nn.functional.gelu(x)) + def gelu_fast(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) @staticmethod - def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: - out.copy_(torch.nn.functional.gelu(x)) + def gelu_new(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) - # TODO add implementation of gelu_quick here - # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + @staticmethod + def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_quick(x, out) @staticmethod def paged_attention_v1( @@ -160,29 +158,10 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] is_neox: bool, ) -> None: - if positions.dim() == 1: - positions = positions.unsqueeze(0) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - - rotary_dim = cos_sin_cache.size(1) - query = query.view(*query.shape[:-1], -1, head_size) - key = key.view(*key.shape[:-1], -1, head_size) - - query_rot = query[..., :rotary_dim] - key_rot = key[..., :rotary_dim] - - cos_sin = cos_sin_cache[positions.long()] - cos, sin = cos_sin.chunk(2, dim=-1) - - if is_neox: - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, - rotary_dim, is_neox, positions) + rot_dim = cos_sin_cache.size(1) + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim) @staticmethod def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, @@ -190,37 +169,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: - if positions.dim() == 1: - positions = positions.unsqueeze(0) - query = query.unsqueeze(0) - key = key.unsqueeze(0) - cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions) - rotary_dim = cos_sin_cache.size(1) - query = query.view(*query.shape[:-1], -1, head_size) - key = key.view(*key.shape[:-1], -1, head_size) - - query_rot = query[..., :rotary_dim] - key_rot = key[..., :rotary_dim] - - cos_sin = cos_sin_cache[torch.add(positions, - cos_sin_cache_offsets).long()] - cos, sin = cos_sin.chunk(2, dim=-1) - - if is_neox: - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - - ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, - rotary_dim, is_neox, positions) + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim, + cos_sin_cache_offsets) @staticmethod - def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> None: - tmp = ipex.llm.functional.rms_norm(input, weight, epsilon) - out.copy_(tmp) + def rms_norm(input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> torch.Tensor: + return ipex.llm.functional.rms_norm(input, weight, epsilon) @staticmethod def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, @@ -246,11 +203,14 @@ 
def varlen_attention( return_softmax: bool, gen_: torch.Generator, ) -> None: - ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q, - seqlen_k, max_seqlen_q, - max_seqlen_k, pdropout, - softmax_scale, zero_tensors, - is_causal, return_softmax, gen_) + ipex.llm.functional.varlen_attention(query.contiguous(), + key.contiguous(), + value.contiguous(), out, + seqlen_q.int(), seqlen_k.int(), + max_seqlen_q, max_seqlen_k, + pdropout, softmax_scale, + zero_tensors, is_causal, + return_softmax, gen_) @staticmethod def reshape_and_cache( diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index 6c5411f7d3d5c..1e9adca50093b 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -42,7 +42,7 @@ def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: def get_adapter(adapter_id: int, registered_adapters: Dict[int, Any]) -> Optional[Any]: - return registered_adapters.get(adapter_id, None) + return registered_adapters.get(adapter_id) ## worker functions diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index adc8390e6f9ec..2bc36ff18a96b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -156,18 +156,27 @@ def graph_clone(self, batch_size: int) -> "AttentionState[T]": ... @abstractmethod - def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T: + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: """Get attention metadata for CUDA graph capture of batch_size.""" ... @abstractmethod - def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]: + def get_graph_input_buffers( + self, + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: """Get attention-specific input buffers for CUDA graph capture.""" ... @abstractmethod - def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any], - attn_metadata: T) -> None: + def prepare_graph_input_buffers( + self, + input_buffers: Dict[str, Any], + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> None: """In-place modify input buffers dict for CUDA graph replay.""" ... 
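For readers unfamiliar with the azp (asymmetric zero-point) plumbing added to scaled_int8_quant earlier in this diff, the following is a minimal pure-PyTorch sketch of the per-token int8 quantization it exposes. The ref_dynamic_int8_quant helper and its rounding/clamping/zero-point conventions are illustrative assumptions only; the _C CUDA kernels remain the source of truth.

import torch

def ref_dynamic_int8_quant(x: torch.Tensor, symmetric: bool = True):
    # Illustrative sketch only -- not part of this patch and not the kernel
    # implementation; exact conventions (rounding, clamping, azp sign) may differ.
    x = x.float()
    if symmetric:
        # scale maps the largest magnitude in each token row onto 127
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
        return q, scale, None
    # asymmetric: map [row_min, row_max] onto [-128, 127] via a zero point (azp)
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min).clamp(min=1e-8) / 255.0
    azp = (-128 - torch.round(x_min / scale)).to(torch.int32)
    q = torch.clamp(torch.round(x / scale) + azp, -128, 127).to(torch.int8)
    return q, scale, azp

# Dequantization for the asymmetric case is approximately (q.float() - azp) * scale.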
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 58d62e02e8733..3a602fbfbbc04 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -172,7 +172,8 @@ def graph_clone(self, batch_size: int): state._prefill_wrapper = self._get_prefill_wrapper() return state - def graph_capture_get_metadata_for_batch(self, batch_size: int): + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): assert self._is_graph_capturing _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] @@ -232,12 +233,17 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): attn_metadata.begin_forward() return attn_metadata - def get_graph_input_buffers(self, attn_metadata): + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): return { "slot_mapping": attn_metadata.slot_mapping, } - def prepare_graph_input_buffers(self, input_buffers, attn_metadata): + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): return def begin_forward(self, model_input): @@ -597,9 +603,19 @@ def build(self, seq_lens: List[int], query_lens: List[int], # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] for i, block_table in enumerate(self.block_tables): if block_table: - input_block_tables[i, :len(block_table)] = block_table + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. 
+ input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + block_tables = torch.from_numpy(input_block_tables).to( device, non_blocking=True) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 64d60e4e47e48..113a2788eacd3 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -49,14 +49,18 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + from vllm._ipex_ops import ipex_ops as ops + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + from vllm._ipex_ops import ipex_ops as ops + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b0f4d0530b7f0..6bd276ade1d41 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -5,6 +5,7 @@ import torch import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import (CommonAttentionState, @@ -12,9 +13,13 @@ from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) +_PARTITION_SIZE = 256 +ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName + class ROCmFlashAttentionBackend(AttentionBackend): @@ -295,7 +300,7 @@ def __init__( else: # if not using triton, navi3x/navi21/navi10 do not use flash-attn # either - if torch.cuda.get_device_capability()[0] != 9: + if not current_platform.has_device_capability(90): self.use_naive_attn = True else: try: @@ -480,20 +485,61 @@ def forward( if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
- output[num_prefill_tokens:] = PagedAttention.forward_decode( - decode_query, - key_cache, - value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - k_scale, - v_scale, - ) + # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = decode_query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // self.num_kv_heads + use_custom = use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, self.kv_cache_dtype, + gqa_ratio, decode_meta.max_decode_seq_len) + if use_custom: + max_seq_len = decode_meta.max_decode_seq_len + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_rocm( + output[num_prefill_tokens:], + exp_sums, + max_logits, + tmp_output, + decode_query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + ) + else: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_decode_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + k_scale, + v_scale, + ) # Reshape the output tensor. 
return output.view(num_tokens, hidden_size) @@ -532,3 +578,14 @@ def _sdpa_attention( start = end return output + + +def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, kv_cache_dtype: str, + gqa_ratio: int, max_seq_len: int) -> bool: + # rocm custom page attention not support on navi (gfx1*) + return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and kv_cache_dtype == "auto" + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 0375d3488eb15..49fbb25f4547b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -33,10 +33,8 @@ def is_block_tables_empty(block_tables: Union[None, Dict]): """ if block_tables is None: return True - if isinstance(block_tables, dict) and all( - value is None for value in block_tables.values()): - return True - return False + return (isinstance(block_tables, dict) + and all(value is None for value in block_tables.values())) def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, @@ -304,7 +302,8 @@ def graph_clone(self, batch_size: int) -> "CommonAttentionState": assert self._is_graph_capturing return self.__class__(self.runner) - def graph_capture_get_metadata_for_batch(self, batch_size: int): + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): assert self._is_graph_capturing attn_metadata = self.runner.attn_backend.make_metadata( num_prefills=0, @@ -322,21 +321,121 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): block_tables=self._graph_block_tables[:batch_size], use_cuda_graph=True, ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + return attn_metadata - def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]: - return { + def get_graph_input_buffers( + self, + attn_metadata, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + input_buffers = { "slot_mapping": attn_metadata.slot_mapping, "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } - - def prepare_graph_input_buffers(self, input_buffers, - attn_metadata) -> None: + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + return input_buffers + + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. 
+ # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) def begin_forward(self, model_input) -> None: return + + def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, + attn_metadata): + """ + Updates the attention metadata parameters for CUDA graph capture in an + encoder-decoder model. + + This method modifies attention-related tensors and metadata required + for CUDA graph capture in encoder-decoder models. Specifically, it + updates the cross-attention and encoder sequence tensors in the + AttentionMetadata object. + """ + # During decode phase the cross_slot_mapping will be empty. Hence set + # an empty tensor for CUDA Graph capture. + attn_metadata.cross_slot_mapping = torch.tensor( + [], dtype=torch.int).cuda() + attn_metadata.cross_block_tables = torch.full( + (batch_size, self.runner.get_max_block_per_batch()), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens = torch.full((batch_size, ), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens_tensor = torch.full( + (batch_size, ), 1, dtype=torch.int).cuda() + attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + + def _add_additonal_input_buffers_for_enc_dec_model( + self, attn_metadata, input_buffers: Dict[str, Any]): + """ + Saves additional input buffers specific to the encoder-decoder model + from the attention metadata. + + This method extracts and stores encoder-decoder related input buffers + from the `attn_metadata` into the `input_buffers` dictionary. The + buffers include encoder sequence lengths, cross-slot mappings, and + cross-block tables, which are essential for the encoder-decoder model + during CUDA graph replay. + """ + input_buffers["encoder_seq_lens_tensor"] = ( + attn_metadata.decode_metadata.encoder_seq_lens_tensor) + input_buffers["cross_slot_mapping"] = ( + attn_metadata.decode_metadata.cross_slot_mapping) + input_buffers["cross_block_tables"] = ( + attn_metadata.decode_metadata.cross_block_tables) + + def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, + input_buffers: Dict[str, + Any]): + """ + Populates input buffers with data from the encoder-decoder model's + attention metadata. + + This method fills the input buffers with encoder-decoder specific + tensors. It copies data from the `attn_metadata` and keyword arguments + (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. + The copied data includes attention-related metadata as well as input + IDs and positional information for the encoder. 
+ """ + input_buffers["encoder_seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.encoder_seq_lens_tensor, + non_blocking=True) + input_buffers["cross_slot_mapping"].copy_( + attn_metadata.decode_metadata.cross_slot_mapping, + non_blocking=True) + input_buffers["cross_block_tables"].copy_( + attn_metadata.decode_metadata.cross_block_tables, + non_blocking=True) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py index e870a8e614d12..1ead541f391b5 100644 --- a/vllm/attention/ops/blocksparse_attention/interface.py +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -8,8 +8,7 @@ from .utils import (dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask) -IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() - and current_platform.get_device_capability()[0] >= 8) +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) if IS_COMPUTE_8_OR_ABOVE: from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd @@ -36,7 +35,7 @@ def __init__( use_spda = is_hip() or is_cpu() or not \ IS_COMPUTE_8_OR_ABOVE device = device or (torch.cuda.current_device() - if torch.cuda.is_available() else "cpu") + if current_platform.is_cuda_alike() else "cpu") device = torch.device(device) # NOTE: vllm CPU backend support BF16 instead of FP16. dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 558b2f3eeac7e..a2a649c8ebcfd 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -709,8 +709,7 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None): - cap = current_platform.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 + BLOCK = 128 if current_platform.has_device_capability(80) else 64 NUM_WARPS = 8 # need to reduce num. blocks when using fp32 diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 855586d4e5961..fbda263ba8e08 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -203,7 +203,7 @@ def which_attn_to_use( selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: - if current_platform.get_device_capability()[0] != 9: + if not current_platform.has_device_capability(90): # not Instinct series GPUs. logger.info("flash_attn is not supported on NAVI GPUs.") else: @@ -212,7 +212,7 @@ def which_attn_to_use( # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: - if current_platform.get_device_capability()[0] < 8: + if not current_platform.has_device_capability(80): # Volta and Turing NVIDIA GPUs. logger.info( "Cannot use FlashAttention-2 backend for Volta and Turing " diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py new file mode 100644 index 0000000000000..de0b1d8a75757 --- /dev/null +++ b/vllm/compilation/backends.py @@ -0,0 +1,156 @@ +import operator + +import torch +import torch.fx as fx + + +def fix_functionalization(graph: fx.Graph): + """ + Rewrite the graph module to replace the pattern involving + torch._higher_order_ops.auto_functionalize.auto_functionalized + with a direct call to the inplace custom op. 
+ + # TODO: check if PyTorch nightly has fixed this issue + """ + + # debug code, if we want to see the graph before the transformation + # with open("before.py", "w") as f: + # print(graph.python_code(root_module="self", verbose=True).src, file=f) + + nodes_to_remove = [] + + for node in graph.nodes: + # Identify the auto_functionalized node + if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa + if node.args[0] == torch.ops._C.rotary_embedding.default: + # manual replace for rotary_embedding + + # Now, collect the arguments + kwargs = node.kwargs + + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function(torch.ops._C.rotary_embedding.default, + kwargs=kwargs) + + # Remove the auto_functionalized node + # Since the node may have outputs, we need to handle its users + # Replace uses of the outputs (getitem nodes) with mm_node + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + for getitem_user in list(user.users): + if (getitem_user.op == 'call_function' + and getitem_user.target + == torch.ops.aten.slice_scatter.default): + # Replace the uses of slice_scatter node + # with mm_node + getitem_user.replace_all_uses_with(mm_node) + nodes_to_remove.append(getitem_user) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: + # manual replace for fused_add_rms_norm + # this is the most effective optimization for llama + # failing to do this will result in many unnecessary copies + + kwargs = node.kwargs + + input = kwargs['input'] + residual = kwargs['residual'] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + if user.args[1] == 1: + replace_node = input + elif user.args[1] == 2: + replace_node = residual + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.rms_norm.default: + # manual replace for rms_norm + + kwargs = node.kwargs + + input = kwargs['input'] + out = kwargs['out'] + weight = kwargs['weight'] + epsilon = kwargs['epsilon'] + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.rms_norm.default, + args=(out, input, weight, epsilon), + ) + + replace_node = out + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + 
nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.silu_and_mul.default: + # manual replace for silu_and_mul + + kwargs = node.kwargs + + input = kwargs['input'] + out = kwargs['out'] + + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.silu_and_mul.default, + args=(out, input), + ) + replace_node = out + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + # Remove the nodes all at once + for node in nodes_to_remove: + graph.erase_node(node) + + # debug code, if we want to see the graph after the transformation + # with open("after.py", "w") as f: + # print(graph.python_code(root_module="self", verbose=True).src, file=f) + + +def vllm_backend(graph, example_inputs): + from torch._inductor import config + current_config = config.shallow_copy_dict() + from torch._inductor.compile_fx import compile_fx + current_config['post_grad_custom_post_pass'] = fix_functionalization + return compile_fx(graph, example_inputs, config_patches=current_config) diff --git a/vllm/config.py b/vllm/config.py index 9684cea813134..7a15606836dcc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,8 +16,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config) -from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, - cuda_device_count_stateless, get_cpu_memory, is_cpu, +from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, is_hip, is_neuron, is_openvino, is_xpu, print_warning_once) @@ -96,15 +95,15 @@ class ModelConfig: enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. - If None, the user did not specify, so default to False - - except for encoder/decoder models, which currently require - eager mode. + If None, the user did not specify, so default to False. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. disable_sliding_window: Whether to disable sliding window. If True, we will disable the sliding window functionality of the model. 
If the model does not support sliding window, this argument is @@ -186,32 +185,8 @@ def __init__(self, self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - # Choose a default enforce_eager value if the user did not specify - # a value (enforce_eager is None) - if getattr(self.hf_config, 'is_encoder_decoder', False): - if self.enforce_eager is None: - # *Only for encoder/decoder models* and - # *only if enforce_eager is unset*, override - # to enforce_eager=True - # - # Add a logger message since it is *somewhat* non-intuitive that - # enforce_eager is True when the user has not specified its - # value. - logger.info("Forcing enforce_eager == True because " - "enforce_eager setting was unspecified and " - "CUDAGraph is not supported with encoder/ " - "decoder models.") - self.enforce_eager = True - - if not self.enforce_eager: - # Eager mode explicitly disabled by user for an encoder/ - # decoder model; however CUDAGRAPH + encoder/decoder is - # not currently supported - raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH) - elif self.enforce_eager is None: - # *Only for decoder-only models*, enforce_eager - # defaults to False if unset. This is intuitive - # so no logging message needed. + # Set enforce_eager to False if the value is unset. + if self.enforce_eager is None: self.enforce_eager = False if (not self.disable_sliding_window @@ -280,7 +255,10 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["awq", "gptq", "fp8"] + rocm_supported_quantization = [ + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8" + ] optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", @@ -379,7 +357,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - if self.enforce_eager: + if device_config.device_type == "cuda" and self.enforce_eager: logger.warning( "To see benefits of async output processing, enable CUDA " "graph. 
Since, enforce-eager is enabled, async output " @@ -418,12 +396,6 @@ def verify_with_parallel_config( "Pipeline parallelism is only supported for the following " f" architectures: {_PP_SUPPORTED_MODELS}.") - if self.quantization == "bitsandbytes" and ( - parallel_config.tensor_parallel_size > 1 - or parallel_config.pipeline_parallel_size > 1): - raise ValueError( - "BitAndBytes quantization with TP or PP is not supported yet.") - # Remove the constraint after the bitsandbytes issue is fixed: # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 if self.quantization == "bitsandbytes" and self.enforce_eager is False: @@ -1066,20 +1038,20 @@ class DeviceConfig: def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection - if is_neuron(): + if current_platform.is_cuda_alike(): + self.device_type = "cuda" + elif is_neuron(): self.device_type = "neuron" elif is_openvino(): self.device_type = "openvino" elif current_platform.is_tpu(): self.device_type = "tpu" - elif is_cpu(): + elif current_platform.is_cpu(): self.device_type = "cpu" elif is_xpu(): self.device_type = "xpu" else: - # We don't call torch.cuda.is_available() here to - # avoid initializing CUDA before workers are forked - self.device_type = "cuda" + raise RuntimeError("Failed to infer device type") else: # Device type is assigned explicitly self.device_type = device diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index a87e814cfb041..db67c95c32429 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -417,9 +417,7 @@ def get_prefix_cache_hit_rate(self) -> float: def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None - if block.content_hash in self._cached_blocks: - return True - return False + return block.content_hash in self._cached_blocks def promote_to_immutable_block(self, block: Block) -> BlockId: """Once a mutable block is full, it can be promoted to an immutable diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b06385b062e83..54818c7e3e9a6 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -399,9 +399,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: """ alloc_status = self._can_swap(seq_group, Device.CPU, SequenceStatus.RUNNING) - if alloc_status == AllocStatus.OK: - return True - return False + return alloc_status == AllocStatus.OK def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: """Returns the block id mapping (from GPU to CPU) generated by diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 6229f1d6ec788..d239d645edc14 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool: return True +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] @@ -224,8 +230,19 @@ def register_graph_buffers(self): ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - return ops.should_custom_ar(inp, self.max_size, self.world_size, - self.full_nvlink) + if self.disabled: + return False + 
inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False # all reduce, assuming inp tensor is IPC registered with register_buffer, # or, in the context of cuda graphs, register_graph_buffers diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index d4847542688c0..b507cd2e1cddb 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -196,7 +196,9 @@ def __init__( # see http://api.zeromq.org/3-3:zmq-setsockopt for more details self.local_socket.setsockopt(XPUB_VERBOSE, True) local_subscribe_port = get_open_port() - self.local_socket.bind(f"tcp://*:{local_subscribe_port}") + socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" + logger.debug("Binding to %s", socket_addr) + self.local_socket.bind(socket_addr) self.current_idx = 0 @@ -212,7 +214,8 @@ def __init__( self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() - self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") + socket_addr = f"tcp://*:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) else: remote_subscribe_port = None @@ -255,8 +258,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket = context.socket(SUB) self.local_socket.setsockopt_string(SUBSCRIBE, "") - self.local_socket.connect( - f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") + socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) self.remote_socket = None else: @@ -270,8 +274,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") - self.remote_socket.connect( - f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") + socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) return self diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6755b20eec9bb..df07842edfa56 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -21,11 +21,12 @@ """ import contextlib import pickle +import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch import torch @@ -34,6 +35,7 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.platforms import current_platform @dataclass @@ -69,6 +71,58 @@ def _split_tensor_dict( return metadata_list, tensor_list +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. 
+ Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + # looks like Python 3.8 does not understand `ReferenceType` + _groups[group.unique_name] = weakref.ref(group) # type: ignore + + +@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"]) +def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce(tensor) + + +@inplace_all_reduce.register_fake +def _(tensor: torch.Tensor, group_name: str) -> None: + return + + +@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) +def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce(tensor) + + +@outplace_all_reduce.register_fake +def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + class GroupCoordinator: """ PyTorch ProcessGroup wrapper for a group of processes. @@ -111,7 +165,11 @@ def __init__( use_custom_allreduce: bool, use_tpu_communicator: bool, use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) self.rank = torch.distributed.get_rank() self.local_rank = local_rank @@ -134,7 +192,7 @@ def __init__( assert self.cpu_group is not None assert self.device_group is not None - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") @@ -149,28 +207,24 @@ def __init__( from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator) - self.pynccl_comm: Optional[PyNcclCommunicator] + self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( group=self.cpu_group, device=self.device, ) - else: - self.pynccl_comm = None - self.ca_comm: Optional[CustomAllreduce] + self.ca_comm: Optional[CustomAllreduce] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, ) - else: - self.ca_comm = None from vllm.distributed.device_communicators.tpu_communicator import ( TpuCommunicator) - self.tpu_communicator: Optional[TpuCommunicator] + self.tpu_communicator: Optional[TpuCommunicator] = None if use_tpu_communicator and self.world_size > 1: self.tpu_communicator = TpuCommunicator(group=self.cpu_group) @@ -264,16 +318,46 @@ def graph_capture( def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. 
We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we need to figure out if the op is + in-place or out-of-place ahead of time. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if self.tpu_communicator is not None and \ + not self.tpu_communicator.disabled: + # TPU handles Dynamo with its own logic. + return self._all_reduce(input_) + + if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_): + return torch.ops.vllm.outplace_all_reduce( + input_, group_name=self.unique_name) + else: + torch.ops.vllm.inplace_all_reduce(input_, + group_name=self.unique_name) + return input_ + + def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + The actual all-reduce implementation. + NOTE: This operation will be applied in-place or out-of-place. Always assume this function modifies its input, but use the return value as the output. """ ca_comm = self.ca_comm - # Bypass the function if we are using only 1 GPU. - if self.world_size == 1: - return input_ - # For TPUs, use TPU communicator. tpu_comm = self.tpu_communicator if tpu_comm is not None and not tpu_comm.disabled: @@ -758,6 +842,7 @@ def init_world_group(ranks: List[int], local_rank: int, use_pynccl=False, use_custom_allreduce=False, use_tpu_communicator=False, + group_name="world", ) @@ -767,6 +852,7 @@ def init_model_parallel_group( backend: str, use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, ) -> GroupCoordinator: if use_custom_allreduce is None: use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE @@ -778,6 +864,7 @@ def init_model_parallel_group( use_custom_allreduce=use_custom_allreduce, use_tpu_communicator=True, use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, ) @@ -931,7 +1018,8 @@ def initialize_model_parallel( _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, - use_message_queue_broadcaster=True) + use_message_queue_broadcaster=True, + group_name="tp") # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // @@ -947,7 +1035,8 @@ def initialize_model_parallel( _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, - use_custom_allreduce=False) + use_custom_allreduce=False, + group_name="pp") def ensure_model_parallel_initialized( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6f58c39162087..4139eca9c1832 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,22 +44,36 @@ def nullable_str(val: str): def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: + """Parses a string containing comma separate key [str] to value [int] + pairs into a dictionary. + + Args: + val: String value to be parsed. + + Returns: + Dictionary with parsed values. 
+ """ if len(val) == 0: return None out_dict: Dict[str, int] = {} for item in val.split(","): - try: - key, value = item.split("=") - except TypeError as exc: - msg = "Each item should be in the form KEY=VALUE" - raise ValueError(msg) from exc + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts try: - out_dict[key] = int(value) + parsed_value = int(value) except ValueError as exc: msg = f"Failed to parse value of item {key}={value}" - raise ValueError(msg) from exc + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value return out_dict @@ -458,7 +472,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.max_seq_len_to_capture, help='Maximum sequence length covered by CUDA ' 'graphs. When a sequence has context length ' - 'larger than this, we fall back to eager mode.') + 'larger than this, we fall back to eager mode. ' + 'Additionally for encoder-decoder models, if the ' + 'sequence length of the encoder input is larger ' + 'than this, we fall back to the eager mode.') parser.add_argument('--disable-custom-all-reduce', action='store_true', default=EngineArgs.disable_custom_all_reduce, @@ -843,6 +860,13 @@ def create_engine_config(self) -> EngineConfig: device_config = DeviceConfig(device=self.device) model_config = self.create_model_config() + if model_config.is_multimodal_model: + if self.enable_prefix_caching: + logger.warning( + "--enable-prefix-caching is currently not " + "supported for multimodal models and has been disabled.") + self.enable_prefix_caching = False + cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else self.max_model_len, # neuron needs block_size = max_model_len @@ -874,7 +898,10 @@ def create_engine_config(self) -> EngineConfig: # If not explicitly set, enable chunked prefill by default for # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. - if use_long_context: + + # Chunked prefill is currently disabled for multimodal models by + # default. 
+ if use_long_context and not model_config.is_multimodal_model: is_gpu = device_config.device_type == "cuda" use_sliding_window = (model_config.get_sliding_window() is not None) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 362b0f3a44b02..34e7e05341f02 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,10 +1,10 @@ import asyncio import time +import weakref from functools import partial from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) - -from typing_extensions import assert_never +from weakref import ReferenceType import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, @@ -12,14 +12,12 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout -from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, - PromptComponents, SchedulerOutputState) +from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, - SingletonPromptInputs) -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -30,6 +28,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext +from vllm.utils import weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -403,139 +402,6 @@ async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() - async def _tokenize_prompt_async( - self, - prompt: str, - request_id: str, - lora_request: Optional[LoRARequest], - ) -> List[int]: - """Async version of :meth:`_tokenize_prompt`.""" - tokenizer = self.get_tokenizer_group( - missing_msg="prompts must be None if skip_tokenizer_init is True") - - return await tokenizer.encode_async(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - - async def _extract_prompt_components_async( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: - """Async version of :meth:`_extract_prompt_components`.""" - if isinstance(inputs, str): - prompt = inputs - prompt_token_ids = await self._tokenize_prompt_async( - prompt, - request_id=request_id, - lora_request=lora_request, - ) - multi_modal_data = None - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: - prompt = None - prompt_token_ids = inputs["prompt_token_ids"] - else: - # NOTE: This extra assignment is required to pass mypy - prompt = parsed_prompt = inputs["prompt"] - prompt_token_ids = await self._tokenize_prompt_async( - parsed_prompt, - request_id=request_id, - lora_request=lora_request, - ) - - multi_modal_data = inputs.get("multi_modal_data") - else: - assert_never(inputs) - - return prompt, prompt_token_ids, multi_modal_data - - 
async def _process_encoder_decoder_prompt_async( - self, - inputs: PromptInputs, - request_id: str, - ) -> EncoderDecoderLLMInputs: - """Async version of :meth:`_process_encoder_decoder_prompt`.""" - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents - - if is_explicit_encoder_decoder_prompt(inputs): - encoder_task = self._extract_prompt_components_async( - inputs["encoder_prompt"], - request_id=request_id, - ) - - if (decoder_input := inputs["decoder_prompt"]) is None: - encoder_comps = await encoder_task - decoder_comps = None, None, None - else: - decoder_task = self._extract_prompt_components_async( - decoder_input, - request_id=request_id, - ) - - encoder_comps, decoder_comps = await asyncio.gather( - encoder_task, decoder_task) - else: - encoder_comps = await self._extract_prompt_components_async( - inputs, - request_id=request_id, - ) - - decoder_comps = None, None, None - - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) - - async def _process_decoder_only_prompt_async( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: - """Async version of :meth:`_process_decoder_only_prompt`.""" - prompt_comps = await self._extract_prompt_components_async( - inputs, - request_id=request_id, - lora_request=lora_request, - ) - - return self._build_decoder_only_llm_inputs( - prompt_comps, - prompt_adapter_request=prompt_adapter_request, - ) - - async def process_model_inputs_async( - self, - inputs: PromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: - """Async version of :meth:`process_model_inputs`.""" - if self.is_encoder_decoder_model(): - # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - model_inputs = await self._process_encoder_decoder_prompt_async( - inputs, - request_id=request_id, - ) - else: - if is_explicit_encoder_decoder_prompt(inputs): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") - - # Decoder-only operation - model_inputs = await self._process_decoder_only_prompt_async( - inputs, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - - return self.input_processor(model_inputs) - async def add_request_async( self, request_id: str, @@ -553,12 +419,13 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() - processed_inputs = await self.process_model_inputs_async( + preprocessed_inputs = await self.input_preprocessor.preprocess_async( inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) + processed_inputs = self.input_processor(preprocessed_inputs) self._add_processed_request( request_id=request_id, @@ -586,9 +453,6 @@ class AsyncLLMEngine: method yields the outputs from the :class:`LLMEngine` to the caller. Args: - worker_use_ray: Whether to use Ray for model workers. Required for - distributed execution. Should be the same as - `parallel_config.worker_use_ray`. log_requests: Whether to log the requests. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. 
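As context for the AsyncLLMEngine docstring above (log_requests, start_engine_loop), a hedged usage sketch follows. The model name is a placeholder and the exact generate() argument order should be checked against this version of the API; the imports mirror those already used in this file.

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main():
    # Placeholder model; any small causal LM works for a smoke test.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    final_output = None
    # generate() is an async generator yielding incremental RequestOutputs;
    # with start_engine_loop=True the background loop starts on this first call.
    async for output in engine.generate("Hello, my name is",
                                        SamplingParams(max_tokens=16),
                                        request_id="example-0"):
        final_output = output
    if final_output is not None:
        print(final_output.outputs[0].text)

if __name__ == "__main__":
    asyncio.run(main())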
@@ -599,23 +463,22 @@ class AsyncLLMEngine: _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine def __init__(self, - worker_use_ray: bool, *args, log_requests: bool = True, start_engine_loop: bool = True, **kwargs) -> None: - self.worker_use_ray = worker_use_ray self.log_requests = log_requests self.engine = self._engine_class(*args, **kwargs) # This ensures quick processing of request outputs # so the append to asyncio queues is not delayed, # especially for multi-step. - # - self.use_process_request_outputs_callback = True + self.use_process_request_outputs_callback = ( + self.engine.model_config.use_async_output_proc) + if self.use_process_request_outputs_callback: self.engine.process_request_outputs_callback = \ - self.process_request_outputs + weak_bind(self.process_request_outputs) self.background_loop: Optional[asyncio.Future] = None # We need to keep a reference to unshielded @@ -628,6 +491,11 @@ def __init__(self, # Lazy initialized fields self._request_tracker: RequestTracker + def __del__(self): + if rt := getattr(self, "request_tracker", None): + # Wake up engine loop so that it will exit cleanly + rt.new_requests_event.set() + @classmethod def _get_executor_cls( cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: @@ -638,15 +506,12 @@ def _get_executor_cls( raise TypeError( "distributed_executor_backend must be a subclass of " f"ExecutorAsyncBase. Got {distributed_executor_backend}.") - if distributed_executor_backend.uses_ray: # type: ignore - initialize_ray_cluster(engine_config.parallel_config) executor_class = distributed_executor_backend elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "tpu": if distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync executor_class = RayTPUExecutorAsync else: @@ -667,11 +532,9 @@ def _get_executor_cls( from vllm.executor.xpu_executor import XPUExecutorAsync executor_class = XPUExecutorAsync elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync executor_class = RayXPUExecutorAsync elif distributed_executor_backend == "mp": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.multiproc_xpu_executor import ( MultiprocessingXPUExecutorAsync) executor_class = MultiprocessingXPUExecutorAsync @@ -679,7 +542,6 @@ def _get_executor_cls( raise RuntimeError( "Not supported distributed execution model on XPU device.") elif distributed_executor_backend == "ray": - initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync elif distributed_executor_backend == "mp": @@ -695,19 +557,23 @@ def _get_executor_cls( def from_engine_args( cls, engine_args: AsyncEngineArgs, + engine_config: Optional[EngineConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> "AsyncLLMEngine": """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. 
- engine_config = engine_args.create_engine_config() + if engine_config is None: + engine_config = engine_args.create_engine_config() executor_class = cls._get_executor_cls(engine_config) + if executor_class.uses_ray: + initialize_ray_cluster(engine_config.parallel_config) + # Create the async LLM engine. engine = cls( - executor_class.uses_ray, **engine_config.to_dict(), executor_class=executor_class, log_requests=not engine_args.disable_log_requests, @@ -735,9 +601,12 @@ def errored(self) -> bool: return self._errored_with is not None @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" - return None + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") def set_errored(self, exc: Exception) -> None: self._errored_with = exc @@ -764,7 +633,7 @@ def start_background_loop(self) -> None: self._request_tracker = RequestTracker() self._background_loop_unshielded = asyncio.get_event_loop( - ).create_task(self.run_engine_loop()) + ).create_task(self.run_engine_loop(weakref.ref(self))) self._background_loop_unshielded.add_done_callback( partial(_log_task_completion, error_callback=self._error_callback)) self.background_loop = asyncio.shield(self._background_loop_unshielded) @@ -834,9 +703,16 @@ def process_request_outputs(self, request_outputs) -> bool: async def _engine_abort(self, request_ids: Iterable[str]): self.engine.abort_request(request_ids) - async def run_engine_loop(self): + @staticmethod + async def run_engine_loop(engine_ref: ReferenceType): + """We use a weakref to the engine so that the running loop + doesn't prevent the engine being garbage collected.""" + engine: Optional["AsyncLLMEngine"] = engine_ref() + if not engine: + return + pipeline_parallel_size = \ - self.engine.parallel_config.pipeline_parallel_size + engine.engine.parallel_config.pipeline_parallel_size has_requests_in_progress = [False] * pipeline_parallel_size while True: if not any(has_requests_in_progress): @@ -847,11 +723,21 @@ async def run_engine_loop(self): # timeout, and unblocks the RPC thread in the workers so that # they can process any other queued control plane messages, # such as add/remove lora adapters. - await self.engine.stop_remote_worker_execution_loop_async() - await self._request_tracker.wait_for_new_requests() + await engine.engine.stop_remote_worker_execution_loop_async() + request_tracker = engine._request_tracker + # Allow engine to be garbage collected while + # waiting for new requests + del engine + await asyncio.sleep(0) + if engine_ref() is None: + return + await request_tracker.wait_for_new_requests() + engine = engine_ref() + if not engine: + return logger.debug("Got new requests!") requests_in_progress = [ - asyncio.create_task(self.engine_step(ve)) + asyncio.create_task(engine.engine_step(ve)) for ve in range(pipeline_parallel_size) ] has_requests_in_progress = [True] * pipeline_parallel_size @@ -869,19 +755,20 @@ async def run_engine_loop(self): result = task.result() virtual_engine = requests_in_progress.index(task) has_unfinished_requests = ( - self.engine.has_unfinished_requests_for_virtual_engine( + engine.engine. 
+ has_unfinished_requests_for_virtual_engine( virtual_engine)) if result or has_unfinished_requests: requests_in_progress[virtual_engine] = ( asyncio.create_task( - self.engine_step(virtual_engine))) + engine.engine_step(virtual_engine))) has_requests_in_progress[virtual_engine] = True else: has_requests_in_progress[virtual_engine] = False except asyncio.TimeoutError as exc: logger.error( "Engine iteration timed out. This should never happen!") - self.set_errored(exc) + engine.set_errored(exc) raise await asyncio.sleep(0) @@ -942,7 +829,7 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request to use + prompt_adapter_request: Prompt Adapter request to use for generation, if any. Yields: @@ -1156,7 +1043,17 @@ def remove_logger(self, logger_name: str) -> None: self.engine.remove_logger(logger_name=logger_name) async def start_profile(self) -> None: - self.engine.model_executor._run_workers("start_profile") + # using type instead of isinstance to check to avoid capturing + # inherited classes + if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 + self.engine.model_executor.start_profile() + else: + self.engine.model_executor._run_workers("start_profile") async def stop_profile(self) -> None: - self.engine.model_executor._run_workers("stop_profile") + # using type instead of isinstance to check to avoid capturing + # inherited classes + if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 + self.engine.model_executor.stop_profile() + else: + self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 92e46c7af5162..2743d5c7d2282 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,15 +1,15 @@ -import functools import time from collections import deque from contextlib import contextmanager from dataclasses import dataclass +from functools import partial from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Tuple, Type, Union +from typing import Set, Type, Union import torch -from typing_extensions import TypeVar, assert_never +from typing_extensions import TypeVar import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, @@ -26,20 +26,19 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase +from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptInputs, - SingletonPromptInputs) -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt + InputRegistry, LLMInputs, PromptInputs) +from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import 
SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -52,7 +51,7 @@ BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter, Device +from vllm.utils import Counter, Device, weak_bind from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -75,11 +74,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) -PromptComponents = Tuple[Optional[str], List[int], - Optional[MultiModalDataDict]] -DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional[MultiModalDataDict]] - @dataclass class SchedulerOutputState: @@ -150,7 +144,7 @@ class LLMEngine: decoding. executor_class: The model executor class for managing distributed execution. - prompt_adapter_config (Optional): The configuration related to serving + prompt_adapter_config (Optional): The configuration related to serving prompt adapters. log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection. @@ -225,9 +219,6 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, - # To improve performance, only final requests outputs may be required. - # If this set to true, then no intermediate outputs will be returned. - step_return_finished_only: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -295,7 +286,6 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats - self.step_return_finished_only = step_return_finished_only if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -317,6 +307,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.generation_config_fields = _load_generation_config_dict( model_config) + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( model_config) @@ -389,11 +382,16 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for _ in range(self.parallel_config.pipeline_parallel_size) ] - self.async_callbacks = [ - functools.partial(self._process_model_outputs, - ctx=self.scheduler_contexts[v_id]) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] # Currently used by AsyncLLMEngine to ensure quick append # of request outputs to asyncio queues @@ -575,19 +573,15 @@ def __del__(self): if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() - MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " - "skip_tokenizer_init is True") - def get_tokenizer_group( self, group_type: Type[_G] = BaseTokenizerGroup, - *, - 
missing_msg: str = MISSING_TOKENIZER_GROUP_MSG, ) -> _G: tokenizer_group = self.tokenizer if tokenizer_group is None: - raise ValueError(missing_msg) + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") if not isinstance(tokenizer_group, group_type): raise TypeError("Invalid type of tokenizer group. " f"Expected type: {group_type}, but " @@ -619,52 +613,6 @@ def _verify_args(self) -> None: self.prompt_adapter_config.verify_with_model_config( self.model_config) - def _get_bos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - if self.tokenizer is None: - logger.warning("Using None for BOS token id because tokenizer " - "is not initialized") - return None - - return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id - - def _get_eos_token_id(self, - lora_request: Optional[LoRARequest] = None - ) -> Optional[int]: - if self.tokenizer is None: - logger.warning("Using None for EOS token id because tokenizer " - "is not initialized") - return None - - return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - - def _get_decoder_start_token_id(self) -> Optional[int]: - ''' - Obtain the decoder start token id employed by an encoder/decoder - model. Returns None for non-encoder/decoder models or if the - model config is unavailable. - ''' - - if not self.is_encoder_decoder_model(): - logger.warning("Using None for decoder start token id because " - "this is not an encoder/decoder model.") - return None - - if (self.model_config is None or self.model_config.hf_config is None): - logger.warning("Using None for decoder start token id because " - "model config is not available.") - return None - - dec_start_token_id = getattr(self.model_config.hf_config, - 'decoder_start_token_id', None) - if dec_start_token_id is None: - logger.warning("Falling back on for decoder start token id " - "because decoder start token id is not available.") - dec_start_token_id = self._get_bos_token_id() - - return dec_start_token_id - def _add_processed_request( self, request_id: str, @@ -679,7 +627,7 @@ def _add_processed_request( # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - eos_token_id = self._get_eos_token_id(lora_request) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) @@ -729,334 +677,6 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - _LLMInputComponentsType = Tuple[str, List[int]] - - def _prepare_decoder_input_ids_for_generation( - self, - decoder_input_ids: Optional[List[int]], - ) -> List[int]: - """ - Prepares `decoder_input_ids` for generation with encoder-decoder models. 
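
The weakref changes earlier in this diff (run_engine_loop now takes weakref.ref(self), and the async callbacks are built with weak_bind) exist so that the perpetual background task does not hold a strong reference that keeps the engine alive forever. A minimal standalone sketch of that pattern, with toy names and nothing beyond the standard library, not part of the patch itself:

import asyncio
import weakref


class ToyEngine:
    """Illustrative stand-in for the engine; not part of the patch."""

    def __init__(self) -> None:
        # The loop holds only a weak reference, so dropping the last strong
        # reference to the engine lets it be garbage collected and the loop exit.
        self._loop_task = asyncio.create_task(
            ToyEngine.run_loop(weakref.ref(self)))

    @staticmethod
    async def run_loop(engine_ref: weakref.ReferenceType) -> None:
        while True:
            engine = engine_ref()
            if engine is None:
                return  # Engine was garbage collected: stop quietly.
            await engine.step()
            del engine            # Drop the strong reference between steps.
            await asyncio.sleep(0)

    async def step(self) -> None:
        await asyncio.sleep(0.01)


async def main() -> None:
    engine = ToyEngine()
    await asyncio.sleep(0.05)
    del engine                    # The loop notices and exits on its own.
    await asyncio.sleep(0.05)


if __name__ == "__main__":
    asyncio.run(main())
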
- - Based on - - https://github.com/huggingface/transformers/blob/ - 4037a2b5b1278736e566aec12e169100275545ea/ - src/transformers/generation/utils.py - - specifically GenerationMixin._prepare_decoder_input_ids_for_generation() - - Arguments: - - * decoder_input_ids: input token ids to preprocess - - Returns: - - * Processed token list - """ - - decoder_start_token_id = self._get_decoder_start_token_id() - assert decoder_start_token_id is not None - - if decoder_input_ids is None: - # no decoder prompt input -> - # use decoder_start_token_id as decoder_input_ids - decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): - decoder_input_ids = [decoder_start_token_id] + decoder_input_ids - - return decoder_input_ids - - def _tokenize_prompt( - self, - prompt: str, - request_id: str, - lora_request: Optional[LoRARequest], - ) -> List[int]: - ''' - Wrapper around application of the model's tokenizer. - - Arguments: - - * prompt - * request_id - * lora_request - - Returns: - - * prompt token ids - ''' - - tokenizer = self.get_tokenizer_group( - missing_msg="prompts must be None if skip_tokenizer_init is True") - - return tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - - def _extract_prompt_components( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: - ''' - Extract the components of any single encoder or decoder input prompt. - - Arguments: - - * request_id - * inputs: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts - - Returns: - - * prompt - * prompt_token_ids - * multi_modal_data - ''' - - if isinstance(inputs, str): - prompt = inputs - prompt_token_ids = self._tokenize_prompt( - prompt, - request_id=request_id, - lora_request=lora_request, - ) - multi_modal_data = None - elif isinstance(inputs, dict): - if "prompt_token_ids" in inputs: - prompt = None - prompt_token_ids = inputs["prompt_token_ids"] - else: - # NOTE: This extra assignment is required to pass mypy - prompt = parsed_prompt = inputs["prompt"] - prompt_token_ids = self._tokenize_prompt( - parsed_prompt, - request_id=request_id, - lora_request=lora_request, - ) - - multi_modal_data = inputs.get("multi_modal_data") - else: - assert_never(inputs) - - return prompt, prompt_token_ids, multi_modal_data - - def _apply_prompt_adapter( - self, - prompt_token_ids: List[int], - prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> List[int]: - if prompt_adapter_request: - prompt_token_ids = ( - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens - + prompt_token_ids) - - return prompt_token_ids - - def _get_default_enc_dec_decoder_prompt(self) -> List[int]: - ''' - Specifically for encoder/decoder models: - generate a default decoder prompt for when - the user specifies only the encoder prompt. - - Encoder/decoder models utilize the decoder - prompt in different ways; as new models are - added, it is intended that this function - will be extended to produce differing - default decoder prompts, depending on the - model variety. - - Absent a special case, the default behavior - of this method is to mirror the behavior of - the HuggingFace (HF) GenerationMixin for a None - decoder prompt, which is to employ a logit processor - setting to force the first decoded token to be . - Here, this behavior is approximated by having the - "default" decoder prompt be . 
- - However, it is possible that in the future - other models may have different or more - complex logic for the default decoder prompt. - This motivates having a special helper method - for default decoder prompts. - - Returns: - - * prompt_token_ids - ''' - - bos_token_id = self._get_bos_token_id() - assert bos_token_id is not None - return [bos_token_id] - - def _build_enc_dec_llm_inputs( - self, - encoder_comps: PromptComponents, - decoder_comps: DecoderPromptComponents, - ) -> EncoderDecoderLLMInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal encoder-decoder models are " - "not supported yet") - - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) - - return EncoderDecoderLLMInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, - ) - - def _process_encoder_decoder_prompt( - self, - inputs: PromptInputs, - request_id: str, - ) -> EncoderDecoderLLMInputs: - ''' - For encoder/decoder models only: - Process an input prompt into an - :class:`EncoderDecoderLLMInputs` instance. - - There are two types of input prompts: - singleton prompts which carry only the - encoder prompt, and explicit encoder/decoder - prompts which carry both the encoder and the - decoder prompts as member variables. - - This function handles the following scenarios: - * Singleton encoder prompt: extract encoder prompt - token ids & infer default decoder prompt token ids - * Explicit encoder/decoder prompt: extract encoder - and decoder prompt token ids - - Note that for Explicit encoder/decoder prompts, - each sub-prompt (encoder or decoder prompt) can - have any possible singleton type; thus this - method relies on helper functions to obtain - token ids for the sub-prompts. - - Arguments: - - * inputs: an input prompt - * request_id - - Returns: - - * :class:`EncoderDecoderLLMInputs` instance - ''' - - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents - - if is_explicit_encoder_decoder_prompt(inputs): - encoder_comps = self._extract_prompt_components( - inputs["encoder_prompt"], - request_id=request_id, - ) - - if (decoder_input := inputs["decoder_prompt"]) is None: - decoder_comps = None, None, None - else: - decoder_comps = self._extract_prompt_components( - decoder_input, - request_id=request_id, - ) - else: - encoder_comps = self._extract_prompt_components( - inputs, - request_id=request_id, - ) - - decoder_comps = None, None, None - - return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) - - def _build_decoder_only_llm_inputs( - self, - prompt_comps: PromptComponents, - prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> LLMInputs: - prompt, prompt_token_ids, multi_modal_data = prompt_comps - - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, prompt_adapter_request=prompt_adapter_request) - - return LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data) - - def _process_decoder_only_prompt( - self, - inputs: SingletonPromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> LLMInputs: - ''' - For decoder-only models: - Process an input prompt into an :class:`LLMInputs` instance. 
- - Arguments: - - * inputs: input prompt - * request_id - * lora_request - * prompt_adapter_request - - Returns: - - * :class:`LLMInputs` instance - ''' - - prompt_comps = self._extract_prompt_components( - inputs, - request_id=request_id, - lora_request=lora_request, - ) - - return self._build_decoder_only_llm_inputs( - prompt_comps, - prompt_adapter_request=prompt_adapter_request, - ) - - def process_model_inputs( - self, - inputs: PromptInputs, - request_id: str, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: - - if self.is_encoder_decoder_model(): - # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - model_inputs = self._process_encoder_decoder_prompt( - inputs, - request_id=request_id, - ) - else: - if is_explicit_encoder_decoder_prompt(inputs): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") - - # Decoder-only operation - model_inputs = self._process_decoder_only_prompt( - inputs, - request_id=request_id, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, - ) - - return self.input_processor(model_inputs) - def add_request( self, request_id: str, @@ -1115,12 +735,13 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.process_model_inputs( + preprocessed_inputs = self.input_preprocessor.preprocess( inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, ) + processed_inputs = self.input_processor(preprocessed_inputs) self._add_processed_request( request_id=request_id, @@ -1253,8 +874,8 @@ def has_unfinished_requests_for_virtual_engine( """ return self.scheduler[virtual_engine].has_unfinished_seqs() + @staticmethod def _process_sequence_group_outputs( - self, seq_group: SequenceGroup, outputs: List[EmbeddingSequenceGroupOutput], ) -> None: @@ -1273,7 +894,7 @@ def _process_model_outputs(self, ctx: The virtual engine context to work on request_id: If provided, then only this request is going to be processed - + """ now = time.time() @@ -1378,7 +999,8 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create(seq_group) - ctx.request_outputs.append(request_output) + if request_output: + ctx.request_outputs.append(request_output) # When we process a single request, we skip it for the next time, # and invoke the request output callback (if there was final output) @@ -1415,14 +1037,19 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - if (seq_group.is_finished() - if self.step_return_finished_only else True): - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create(seq_group) + if request_output: ctx.request_outputs.append(request_output) for seq_group in scheduler_outputs.ignored_seq_groups: + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + request_output = RequestOutputFactory.create(seq_group) - ctx.request_outputs.append(request_output) + if request_output: + ctx.request_outputs.append(request_output) # Immediately process request outputs here (if callback is given) if (ctx.request_outputs @@ -1435,7 +1062,8 @@ def _process_model_outputs(self, # 
LLMEngine/AsyncLLMEngine directly if is_async: # Log stats. - self.do_log_stats(scheduler_outputs, outputs, finished_before) + self.do_log_stats(scheduler_outputs, outputs, finished_before, + skip) # Tracing self.do_tracing(scheduler_outputs) @@ -1661,6 +1289,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # torch.distributed ops which may otherwise timeout, and unblocks # the RPC thread in the workers so that they can process any other # queued control plane messages, such as add/remove lora adapters. + logger.debug("Stopping remote worker execution loop.") self.model_executor.stop_remote_worker_execution_loop() return ctx.request_outputs @@ -1742,18 +1371,20 @@ def remove_logger(self, logger_name: str) -> None: def do_log_stats(self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None) -> None: + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: stats = self._get_stats(scheduler_outputs, model_output, - finished_before) + finished_before, skip) for logger in self.stat_loggers.values(): logger.log(stats) def _get_stats(self, scheduler_outputs: Optional[SchedulerOutputs], model_output: Optional[List[SamplerOutput]] = None, - finished_before: Optional[List[int]] = None) -> Stats: + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> Stats: """Get Stats to be Logged to Prometheus. Args: @@ -1761,6 +1392,10 @@ def _get_stats(self, the scheduled batch, model_output: Optional, used to emit speculative decoding metrics which are created by the workers. + finished_before: Optional, indices of sequences that were finished + before. These sequences will be ignored. + skip: Optional, indices of sequences that were preempted. These + sequences will be ignored. 
""" now = time.time() @@ -1835,6 +1470,11 @@ def _get_stats(self, actual_num_batched_tokens -= 1 continue + # Currently, skip == preempted sequences, so we need to skip + # their log stats + if skip and idx in skip: + continue + group_was_prefill = idx < scheduler_outputs.num_prefill_groups seq_group = scheduled_seq_group.seq_group @@ -1964,10 +1604,20 @@ def check_health(self) -> None: self.model_executor.check_health() def start_profile(self) -> None: - self.model_executor.start_profile() + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: # noqa: E721 + self.model_executor.start_profile() + else: + self.model_executor._run_workers("start_profile") def stop_profile(self) -> None: - self.model_executor.stop_profile() + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: # noqa: E721 + self.model_executor.stop_profile() + else: + self.model_executor._run_workers("stop_profile") def is_tracing_enabled(self) -> bool: return self.tracer is not None @@ -2041,7 +1691,7 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: metrics.model_execute_time) def is_encoder_decoder_model(self): - return self.model_config.is_encoder_decoder_model + return self.input_preprocessor.is_encoder_decoder_model() def is_embedding_model(self): return self.model_config.is_embedding_model diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 0000000000000..ba5c6e15fc821 --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,73 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +VLLM_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCHealthRequest: + pass + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest, + RPCStartupRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + "find the original error") + + return MQEngineDeadError( + "Engine loop is not running. 
Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py new file mode 100644 index 0000000000000..18b620c74ddf9 --- /dev/null +++ b/vllm/engine/multiprocessing/client.py @@ -0,0 +1,452 @@ +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, + Union) + +import cloudpickle +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptInputs +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQLLMEngineClient: + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. + + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: EngineConfig): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + enable_lora=bool(engine_config.lora_config), + ) + + # Send RPCGenerateRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for ack of check_health requests. + self.health_socket: Socket = self.context.socket(zmq.constants.PULL) + self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + + @staticmethod + def is_unsupported_config(engine_args: AsyncEngineArgs): + if engine_args.pipeline_parallel_size > 1: + return True + + is_embedding = ModelConfig( + model=engine_args.model, + revision=engine_args.revision, + tokenizer=engine_args.model, + tokenizer_mode="auto", + trust_remote_code=engine_args.trust_remote_code, + quantization=engine_args.quantization, + seed=0, + dtype="auto").embedding_mode + + return is_embedding + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_check_health_loop(self, timeout: int): + """Background loop that continually probes the RPCServer for health. + + The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which + the MQLLMEngine server is blocking on. + + The Server replies on the HEALTH_SOCKET (rather than on the + OUTPUT_SOCKET such that the messages are not intermingled with + output streaming). + """ + + try: + while True: + if await self.health_socket.poll(timeout=timeout) == 0: + # Wakeup every N seconds and do a health probe. + await self._send_one_way_rpc_request( + RPCHealthRequest(), self.input_socket) + + # Wait for ack from the health socket. + await self._await_ack(error_message="Health check failed.", + socket=self.health_socket) + else: + # Server sent a health status message unprompted. + await self._check_success( + error_message="Health check failed.", + socket=self.health_socket) + + logger.debug("Health probe successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. 
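
The client sockets above connect to counterparts that MQLLMEngine binds further down in this diff: requests flow over a PUSH/PULL pair, while outputs and health acks come back on separate PULL sockets, all on ipc endpoints built from one base path plus the IPC_*_EXT suffixes. A stripped-down sketch of that wiring, independent of vLLM (the endpoint names below are made up, and the health socket is omitted):

import tempfile

import zmq

base = f"ipc://{tempfile.gettempdir()}/toy_mq_engine"

ctx = zmq.Context()

# Engine side: binds, pulls requests, pushes outputs.
engine_input = ctx.socket(zmq.PULL)
engine_input.bind(f"{base}_input_socket")
engine_output = ctx.socket(zmq.PUSH)
engine_output.bind(f"{base}_output_socket")

# Client side: connects the mirror-image socket types.
client_input = ctx.socket(zmq.PUSH)
client_input.connect(f"{base}_input_socket")
client_output = ctx.socket(zmq.PULL)
client_output.connect(f"{base}_output_socket")

client_input.send(b"request")
assert engine_input.recv() == b"request"
engine_output.send(b"output")
assert client_output.recv() == b"output"

ctx.destroy(linger=0)
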
+ if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MPLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + + if request_id is None: + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + else: + # Put each output into the appropriate steam. + for request_output in request_outputs: + queue = self.output_queues.get( + request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + self.health_loop = asyncio.create_task( + self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. 
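
run_output_handler_loop above acts as a demultiplexer: one socket-reading task routes every RequestOutput into a per-request asyncio.Queue, which generate() later drains. The core routing pattern, reduced to standard-library pieces with illustrative names only:

import asyncio


async def route_outputs(source: asyncio.Queue, queues: dict) -> None:
    # One reader task fans messages out to per-request queues.
    while True:
        request_id, payload = await source.get()
        queue = queues.get(request_id)
        if queue is not None:
            queue.put_nowait(payload)


async def consume(request_id: str, queues: dict) -> None:
    queue: asyncio.Queue = asyncio.Queue()
    queues[request_id] = queue
    try:
        for _ in range(2):
            print(request_id, "->", await queue.get())
    finally:
        queues.pop(request_id)


async def main() -> None:
    source: asyncio.Queue = asyncio.Queue()
    queues: dict = {}
    router = asyncio.create_task(route_outputs(source, queues))
    consumer = asyncio.create_task(consume("req-1", queues))
    await asyncio.sleep(0)           # let the consumer register its queue
    source.put_nowait(("req-1", "partial output"))
    source.put_nowait(("req-1", "final output"))
    await consumer
    router.cancel()


if __name__ == "__main__":
    asyncio.run(main())
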
+ frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats(self): + """Ignore do_log_stats (handled on MQLLMEngine polling)""" + pass + + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. + """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncGenerator[RequestOutput, None]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + # If already dead, error out. + if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # 1) Create output queue for this requests. 
+ queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if sampling_params.logits_processors: + # Defensive shallow copy + sampling_params = copy.copy(sampling_params) + logits_processors = sampling_params.logits_processors + sampling_params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + + # 3) Send the RPCGenerateRequest to the MQLLMEngine. + parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def encode(self, *args, + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py new file mode 100644 index 0000000000000..70cd6e5cb6000 --- /dev/null +++ b/vllm/engine/multiprocessing/engine.py @@ -0,0 +1,321 @@ +import pickle +import signal +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq + +from vllm import AsyncEngineArgs, LLMEngine +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCGenerateRequest, + RPCHealthRequest, RPCStartupRequest, + RPCStartupResponse) +# yapf: enable +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MQLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The :class:`LLMEngine.generate` is kicked off when a new + RPCGenerateRequest is received by the input_socket. 
+ + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()`, and sends the RequestOutputs back over + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. + """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + self.engine = LLMEngine(*args, **kwargs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send health status back to client. + self.health_socket = self.ctx.socket(zmq.constants.PUSH) + self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + + engine_config = engine_args.create_engine_config() + + executor_class = LLMEngine._get_executor_cls(engine_config) + + return cls( + ipc_path=ipc_path, + use_async_sockets=engine_config.model_config.use_async_output_proc, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. 
+ self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + while True: + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + + try: + return self.engine.step() + except SystemExit: + raise + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCGenerateRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + lprocs = cloudpickle.loads(frames[1].buffer) + request.sampling_params.logits_processors = lprocs + self._handle_generate_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCHealthRequest): + self._handle_health_request() + else: + raise ValueError("Unknown RPCRequest Type: {request}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e + + def _handle_generate_request(self, request: RPCGenerateRequest): + """Handle RPCGenerateRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + inputs=request.inputs, + params=request.sampling_params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request) + + if self.log_requests: + logger.info("Added request %s.", request.request_id) + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. 
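
The split pickling above (generate() detaching logits_processors on the client, handle_new_input() reattaching them from a second frame) exists because the request itself travels via the standard pickle module, while user-supplied logits processors may be closures that only the slower cloudpickle can serialize. A minimal illustration of the difference; the toy processor below is made up:

import pickle

import cloudpickle


def make_processor(boost: float):
    # A closure standing in for a user-supplied logits processor.
    def processor(token_ids, logits):
        return [score + boost for score in logits]

    return processor


proc = make_processor(0.5)

try:
    pickle.dumps(proc)
except Exception as exc:          # plain pickle cannot handle the closure
    print("pickle failed:", exc)

restored = cloudpickle.loads(cloudpickle.dumps(proc))
print(restored([], [1.0, 2.0]))   # [1.5, 2.5]
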
+ is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + def _handle_health_request(self): + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + + # Raises error if unhealthy. + self.engine.check_health() + self._send_healthy() + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send List of RequestOutput to RPCClient.""" + if outputs: + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + error_bytes = pickle.dumps(error) + self.health_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, + ipc_path: str): + + def signal_handler(*_) -> None: + # Interrupt server on sigterm + raise KeyboardInterrupt("MQLLMEngine terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) + engine.start() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 34ae79f5fa8df..70444faa670a2 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -14,8 +14,8 @@ @runtime_checkable -class AsyncEngineClient(Protocol): - """Protocol class for Clients to AsyncLLMEngine""" +class EngineClient(Protocol): + """Protocol class for Clients to Engine""" @property def is_running(self) -> bool: @@ -30,8 +30,8 @@ def errored(self) -> bool: ... @property - def limit_concurrency(self) -> Optional[int]: - """Maximum number of concurrently running requests.""" + def dead_error(self) -> BaseException: + ... 
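
The protocol change above replaces the old limit_concurrency hint with a dead_error property, so every client implementation reports why its engine stopped in a uniform way. A hypothetical minimal implementation of that errored/dead_error contract, for illustration only:

from typing import Optional


class ToyEngineClient:
    """Hypothetical client used only to show the errored/dead_error contract."""

    def __init__(self) -> None:
        self._errored_with: Optional[BaseException] = None

    def set_errored(self, exc: BaseException) -> None:
        # Record only the first fatal error, as the patch does.
        if self._errored_with is None:
            self._errored_with = exc

    @property
    def errored(self) -> bool:
        return self._errored_with is not None

    @property
    def dead_error(self) -> BaseException:
        # Callers raise this instead of a bare RuntimeError so the message
        # points back at the original failure.
        return RuntimeError("Engine loop is not running. Inspect the "
                            "stacktrace to find the original error: "
                            f"{self._errored_with!r}.")


client = ToyEngineClient()
client.set_errored(ValueError("boom"))
if client.errored:
    print(client.dead_error)
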
def generate( self, diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 3598872b65bb0..5dcf50bd1b0a1 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -4,19 +4,18 @@ from typing import Any import uvicorn -from fastapi import FastAPI, Response +from fastapi import FastAPI, Request, Response from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.multiprocessing import MQEngineDeadError from vllm.logger import init_logger from vllm.utils import find_process_using_port logger = init_logger(__name__) -async def serve_http(app: FastAPI, engine: AsyncEngineClient, - **uvicorn_kwargs: Any): +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): logger.info("Available routes are:") for route in app.routes: methods = getattr(route, "methods", None) @@ -27,18 +26,9 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient, logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) - # Set concurrency limits in uvicorn if running in multiprocessing mode - # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536). - if engine.limit_concurrency is not None: - logger.info( - "Launching Uvicorn with --limit_concurrency %s. To avoid this " - "limit at the expense of performance run with " - "--disable-frontend-multiprocessing", engine.limit_concurrency) - uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency - config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) - _add_shutdown_handlers(app, server, engine) + _add_shutdown_handlers(app, server) loop = asyncio.get_running_loop() @@ -64,19 +54,19 @@ async def dummy_shutdown() -> None: logger.debug( "port %s is used by process %s launched with command:\n%s", port, process, " ".join(process.cmdline())) - logger.info("Gracefully stopping http server") + logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() -def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server, - engine: AsyncEngineClient) -> None: +def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: """Adds handlers for fatal errors that should crash the server""" @app.exception_handler(RuntimeError) - async def runtime_error_handler(_, __): + async def runtime_error_handler(request: Request, __): """On generic runtime error, check to see if the engine has died. It probably has, in which case the server will no longer be able to handle requests. Trigger a graceful shutdown with a SIGTERM.""" + engine = request.app.state.engine_client if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored and not engine.is_running): logger.fatal("AsyncLLMEngine has failed, terminating server " @@ -91,7 +81,7 @@ async def runtime_error_handler(_, __): return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) @app.exception_handler(AsyncEngineDeadError) - async def engine_dead_handler(_, __): + async def async_engine_dead_handler(_, __): """Kill the server if the async engine is already dead. It will not handle any further requests.""" if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: @@ -100,3 +90,14 @@ async def engine_dead_handler(_, __): server.should_exit = True return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(MQEngineDeadError) + async def mq_engine_dead_handler(_, __): + """Kill the server if the mq engine is already dead. 
It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("MQLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b1d9f386b6c3e..248b070611cd2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ from contextlib import contextmanager -from typing import ClassVar, List, Optional, Sequence, Union, cast, overload +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast, + overload) from tqdm import tqdm @@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -88,7 +89,9 @@ class LLM: to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode. + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. disable_custom_all_reduce: See ParallelConfig **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) @@ -137,9 +140,7 @@ def __init__( LLM constructor. Note: if enforce_eager is unset (enforce_eager is None) - it defaults to False for decoder-only models and True - for encoder/decoder models, since encoder/decoder models - do not currently support CUDAGraph. + it defaults to False. ''' if "disable_log_stats" not in kwargs: @@ -357,6 +358,7 @@ def chat( lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = True, + tools: Optional[List[Dict[str, Any]]] = None, ) -> List[RequestOutput]: """ Generate responses for a chat conversation. @@ -401,6 +403,7 @@ def chat( messages=messages, chat_template=chat_template, add_generation_prompt=add_generation_prompt, + tools=tools, ) else: prompt = apply_hf_chat_template( @@ -408,6 +411,7 @@ def chat( conversation=conversation, chat_template=chat_template, add_generation_prompt=add_generation_prompt, + tools=tools, ) inputs: PromptInputs @@ -642,14 +646,12 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") - if isinstance(params, list): - params = [ - self._add_guided_processor(param, guided_options) - if isinstance(param, SamplingParams) else param - for param in params - ] - elif isinstance(params, SamplingParams): - params = self._add_guided_processor(params, guided_options) + for sp in params if isinstance(params, list) else (params, ): + if isinstance(sp, SamplingParams): + self._add_guided_processor(sp, guided_options) + + # We only care about the final output + sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. for i, request_inputs in enumerate(inputs): @@ -709,9 +711,6 @@ def _run_engine( f"output: {0:.2f} toks/s"), ) - # In the loop below, only finished outputs are used - self.llm_engine.step_return_finished_only = True - # Run the engine. 
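
The launcher changes earlier in this diff stop threading the engine through the shutdown handlers and instead read it from request.app.state.engine_client, which the server stores on the application state. A generic FastAPI sketch of that pattern; the handler body here is illustrative, not the patch's:

from fastapi import FastAPI, Request, Response

app = FastAPI()
app.state.engine_client = object()   # stand-in for the real engine client


@app.exception_handler(RuntimeError)
async def runtime_error_handler(request: Request, exc: RuntimeError) -> Response:
    # No module-level global needed: the client travels with the app.
    client = request.app.state.engine_client
    print("engine client at failure time:", client, "error:", exc)
    return Response(status_code=500)
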
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_in_toks = 0 @@ -724,6 +723,7 @@ def _run_engine( if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput + assert output.prompt_token_ids is not None total_in_toks += len(output.prompt_token_ids) in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( @@ -735,9 +735,6 @@ def _run_engine( f"output: {out_spd:.2f} toks/s") pbar.update(1) - # Restore original behavior - self.llm_engine.step_return_finished_only = False - if use_tqdm: pbar.close() # Sort the outputs by request ID. diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d8704d5e24964..1b9eb30252417 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,16 +4,21 @@ import multiprocessing import os import re +import signal +import socket import tempfile from argparse import Namespace from contextlib import asynccontextmanager +from functools import partial from http import HTTPStatus from typing import AsyncIterator, Optional, Set +import uvloop from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State from starlette.routing import Mount from typing_extensions import assert_never @@ -21,7 +26,9 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -39,8 +46,6 @@ TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient -from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -54,12 +59,6 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds -async_engine_client: AsyncEngineClient -engine_args: AsyncEngineArgs -openai_serving_chat: OpenAIServingChat -openai_serving_completion: OpenAIServingCompletion -openai_serving_embedding: OpenAIServingEmbedding -openai_serving_tokenization: OpenAIServingTokenization prometheus_multiproc_dir: tempfile.TemporaryDirectory # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) @@ -68,49 +67,42 @@ _running_tasks: Set[asyncio.Task] = set() -def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: Optional[str]) -> bool: - return ModelConfig(model=model_name, - tokenizer=model_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - quantization=quantization, - seed=0, - dtype="auto").embedding_mode - - @asynccontextmanager async def lifespan(app: FastAPI): - - async def _force_log(): - while True: - await asyncio.sleep(10) - await async_engine_client.do_log_stats() - - if not engine_args.disable_log_stats: - task = asyncio.create_task(_force_log()) - _running_tasks.add(task) - task.add_done_callback(_running_tasks.remove) - - yield + 
try: + if app.state.log_stats: + engine_client: EngineClient = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10.) + await engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del app.state @asynccontextmanager async def build_async_engine_client( - args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: + args: Namespace) -> AsyncIterator[Optional[EngineClient]]: - # Context manager to handle async_engine_client lifecycle + # Context manager to handle engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit - global engine_args engine_args = AsyncEngineArgs.from_cli_args(args) - # Backend itself still global for the silly lil' health handler - global async_engine_client - async with build_async_engine_client_from_engine_args( engine_args, args.disable_frontend_multiprocessing) as engine: - - async_engine_client = engine # type: ignore[assignment] yield engine @@ -118,26 +110,35 @@ async def build_async_engine_client( async def build_async_engine_client_from_engine_args( engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> AsyncIterator[Optional[AsyncEngineClient]]: +) -> AsyncIterator[Optional[EngineClient]]: """ - Create AsyncEngineClient, either: + Create EngineClient, either: - in-process using the AsyncLLMEngine Directly - multiprocess using AsyncLLMEngine RPC Returns the Client or None if the creation failed. """ - # If manually triggered or embedding model, use AsyncLLMEngine in process. - # TODO: support embedding model via RPC. - if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, - engine_args.quantization) + # Fall back + # TODO: fill out feature matrix. + if (MQLLMEngineClient.is_unsupported_config(engine_args) or disable_frontend_multiprocessing): - engine_client = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - try: - yield engine_client - finally: - engine_client.shutdown_background_loop() + engine_config = engine_args.create_engine_config() + uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), + "uses_ray", False) + + build_engine = partial(AsyncLLMEngine.from_engine_args, + engine_args=engine_args, + engine_config=engine_config, + usage_context=UsageContext.OPENAI_API_SERVER) + if uses_ray: + # Must run in main thread with ray for its signal handlers to work + engine_client = build_engine() + else: + engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_engine) + + yield engine_client return # Otherwise, use the multiprocessing AsyncLLMEngine. @@ -158,56 +159,60 @@ async def build_async_engine_client_from_engine_args( "and vLLM will properly handle cleanup.") # Select random path for IPC. - rpc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for RPC Path.", - rpc_path) - - # Build RPCClient, which conforms to AsyncEngineClient Protocol. - # NOTE: Actually, this is not true yet. We still need to support - # embedding models via RPC (see TODO above) - rpc_client = AsyncEngineRPCClient(rpc_path) + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) - # Start RPCServer in separate process (holds the AsyncLLMEngine). 
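The in-process path above builds the engine through a partial and, when Ray is not involved, runs that construction in the default executor so the event loop is not blocked. A minimal sketch of the pattern, assuming a stand-in constructor in place of AsyncLLMEngine.from_engine_args:

import asyncio
from functools import partial

def expensive_constructor(config: dict) -> object:
    # Stand-in for engine construction, which can block for many seconds.
    return object()

async def build_without_blocking(config: dict) -> object:
    build = partial(expensive_constructor, config)
    # None selects the default ThreadPoolExecutor; other coroutines
    # (health checks, logging) keep running while this executes.
    return await asyncio.get_running_loop().run_in_executor(None, build)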
- context = multiprocessing.get_context("spawn") + # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, # so we need to spawn a new process - rpc_server_process = context.Process( - target=run_rpc_server, - args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)) - rpc_server_process.start() - logger.info("Started engine process with PID %d", - rpc_server_process.pid) + context = multiprocessing.get_context("spawn") + + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path)) + engine_process.start() + logger.info("Started engine process with PID %d", engine_process.pid) + + # Build RPCClient, which conforms to EngineClient Protocol. + # NOTE: Actually, this is not true yet. We still need to support + # embedding models via RPC (see TODO above) + engine_config = engine_args.create_engine_config() + mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) try: while True: try: - await rpc_client.setup() + await mp_engine_client.setup() break except TimeoutError: - if not rpc_server_process.is_alive(): - logger.error( - "RPCServer process died before responding " - "to readiness probe") + if not engine_process.is_alive(): + logger.error("Engine process died before responding " + "to readiness probe") yield None return - yield rpc_client # type: ignore[misc] + yield mp_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated - rpc_server_process.terminate() + engine_process.terminate() # Close all open connections to the backend - rpc_client.close() + mp_engine_client.close() - # Wait for server process to join - rpc_server_process.join() + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() # Lazy import for prometheus multiprocessing. # We need to set PROMETHEUS_MULTIPROC_DIR environment variable # before prometheus_client is imported. 
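The engine-process lifecycle above (spawn context, start, terminate, join with a timeout, kill as a last resort) follows a standard multiprocessing pattern. A self-contained sketch with a stand-in worker in place of run_mp_engine:

import multiprocessing
import time

def worker(ipc_path: str) -> None:
    # Placeholder for the engine loop; ignores the path and just idles.
    while True:
        time.sleep(1)

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")  # fresh interpreter, no inherited CUDA context
    proc = ctx.Process(target=worker, args=("ipc:///tmp/demo",))
    proc.start()

    proc.terminate()           # ask the child to exit
    proc.join(4)               # wait up to 4 seconds for it to stop
    if proc.exitcode is None:  # still running: force kill
        proc.kill()
        proc.join()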
# See https://prometheus.github.io/client_python/multiprocess/ from prometheus_client import multiprocess - multiprocess.mark_process_dead(rpc_server_process.pid) + multiprocess.mark_process_dead(engine_process.pid) router = APIRouter() @@ -239,16 +244,36 @@ def mount_metrics(app: FastAPI): app.routes.append(metrics_route) +def chat(request: Request) -> OpenAIServingChat: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> OpenAIServingCompletion: + return request.app.state.openai_serving_completion + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def embedding(request: Request) -> OpenAIServingEmbedding: + return request.app.state.openai_serving_embedding + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + @router.get("/health") -async def health() -> Response: +async def health(raw_request: Request) -> Response: """Health check.""" - await async_engine_client.check_health() + await engine_client(raw_request).check_health() return Response(status_code=200) @router.post("/tokenize") -async def tokenize(request: TokenizeRequest): - generator = await openai_serving_tokenization.create_tokenize(request) +async def tokenize(request: TokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_tokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -259,8 +284,8 @@ async def tokenize(request: TokenizeRequest): @router.post("/detokenize") -async def detokenize(request: DetokenizeRequest): - generator = await openai_serving_tokenization.create_detokenize(request) +async def detokenize(request: DetokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_detokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -271,8 +296,8 @@ async def detokenize(request: DetokenizeRequest): @router.get("/v1/models") -async def show_available_models(): - models = await openai_serving_completion.show_available_models() +async def show_available_models(raw_request: Request): + models = await completion(raw_request).show_available_models() return JSONResponse(content=models.model_dump()) @@ -286,7 +311,7 @@ async def show_version(): async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): - generator = await openai_serving_chat.create_chat_completion( + generator = await chat(raw_request).create_chat_completion( request, raw_request) if isinstance(generator, ErrorResponse): @@ -301,7 +326,7 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): - generator = await openai_serving_completion.create_completion( + generator = await completion(raw_request).create_completion( request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -314,7 +339,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): - generator = await openai_serving_embedding.create_embedding( + generator = await embedding(raw_request).create_embedding( request, raw_request) if isinstance(generator, 
ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -331,16 +356,16 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): "used for local development!") @router.post("/start_profile") - async def start_profile(): + async def start_profile(raw_request: Request): logger.info("Starting profiler...") - await async_engine_client.start_profile() + await engine_client(raw_request).start_profile() logger.info("Profiler started.") return Response(status_code=200) @router.post("/stop_profile") - async def stop_profile(): + async def stop_profile(raw_request: Request): logger.info("Stopping profiler...") - await async_engine_client.stop_profile() + await engine_client(raw_request).stop_profile() logger.info("Profiler stopped.") return Response(status_code=200) @@ -351,13 +376,14 @@ async def stop_profile(): "This should ONLY be used for local development!") @router.post("/v1/load_lora_adapter") - async def load_lora_adapter(request: LoadLoraAdapterRequest): - response = await openai_serving_chat.load_lora_adapter(request) + async def load_lora_adapter(request: LoadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).load_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) - response = await openai_serving_completion.load_lora_adapter(request) + response = await completion(raw_request).load_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) @@ -365,13 +391,14 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest): return Response(status_code=200, content=response) @router.post("/v1/unload_lora_adapter") - async def unload_lora_adapter(request: UnloadLoraAdapterRequest): - response = await openai_serving_chat.unload_lora_adapter(request) + async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).unload_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) - response = await openai_serving_completion.unload_lora_adapter(request) + response = await completion(raw_request).unload_lora_adapter(request) if isinstance(response, ErrorResponse): return JSONResponse(content=response.model_dump(), status_code=response.code) @@ -380,7 +407,13 @@ async def unload_lora_adapter(request: UnloadLoraAdapterRequest): def build_app(args: Namespace) -> FastAPI: - app = FastAPI(lifespan=lifespan) + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) app.include_router(router) app.root_path = args.root_path @@ -396,7 +429,8 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - err = openai_serving_chat.create_error_response(message=str(exc)) + chat = app.state.openai_serving_chat + err = chat.create_error_response(message=str(exc)) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -428,31 +462,27 @@ async def authentication(request: Request, call_next): return app -async def init_app( - async_engine_client: AsyncEngineClient, +def init_app_state( + engine_client: EngineClient, + model_config: ModelConfig, + state: State, args: Namespace, -) -> FastAPI: - app = build_app(args) - +) -> None: if 
args.served_model_name is not None: served_model_names = args.served_model_name else: served_model_names = [args.model] - model_config = await async_engine_client.get_model_config() - if args.disable_log_requests: request_logger = None else: request_logger = RequestLogger(max_log_len=args.max_log_len) - global openai_serving_chat - global openai_serving_completion - global openai_serving_embedding - global openai_serving_tokenization + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats - openai_serving_chat = OpenAIServingChat( - async_engine_client, + state.openai_serving_chat = OpenAIServingChat( + engine_client, model_config, served_model_names, args.response_role, @@ -463,8 +493,8 @@ async def init_app( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser) - openai_serving_completion = OpenAIServingCompletion( - async_engine_client, + state.openai_serving_completion = OpenAIServingCompletion( + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -472,39 +502,49 @@ async def init_app( request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) - openai_serving_embedding = OpenAIServingEmbedding( - async_engine_client, + state.openai_serving_embedding = OpenAIServingEmbedding( + engine_client, model_config, served_model_names, request_logger=request_logger, ) - openai_serving_tokenization = OpenAIServingTokenization( - async_engine_client, + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, model_config, served_model_names, lora_modules=args.lora_modules, request_logger=request_logger, chat_template=args.chat_template, ) - app.root_path = args.root_path - - return app async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - async with build_async_engine_client(args) as async_engine_client: + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + temp_socket.bind(("", args.port)) + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: # If None, creation of the client failed and we exit. - if async_engine_client is None: + if engine_client is None: return - app = await init_app(async_engine_client, args) + app = build_app(args) + + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) + + temp_socket.close() shutdown_task = await serve_http( app, - engine=async_engine_client, host=args.host, port=args.port, log_level=args.uvicorn_log_level, @@ -528,4 +568,4 @@ async def run_server(args, **uvicorn_kwargs) -> None: parser = make_arg_parser(parser) args = parser.parse_args() - asyncio.run(run_server(args)) + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 7ccee0b6b55b7..bbb0823de9a51 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -190,6 +190,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'ID numbers being printed in log.' 
'\n\nDefault: Unlimited') + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" + ) + return parser diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 374196044b7e8..7e9f53b1816d1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -12,7 +12,8 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import LogitsProcessor, SamplingParams +from vllm.sampling_params import (LogitsProcessor, RequestOutputKind, + SamplingParams) from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -316,6 +317,8 @@ def to_sampling_params( length_penalty=self.length_penalty, logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, ) @model_validator(mode="before") @@ -559,6 +562,8 @@ def to_sampling_params( length_penalty=self.length_penalty, logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, ) @model_validator(mode="before") diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py deleted file mode 100644 index efc7e43afdcc9..0000000000000 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Mapping, Optional, Union - -from vllm.inputs import PromptInputs -from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams - -# Success string used for RPC instructions. -VLLM_RPC_SUCCESS_STR = "SUCCESS" - -# Minimum value of ZMQ.SOCKET_LIMIT to run mp. -VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000 - -# HWM is set to Infinity. 
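The protocol.py hunks above now request RequestOutputKind.DELTA for streaming requests and FINAL_ONLY otherwise, so each streamed output carries only newly generated text and tokens. A rough, non-API-exact sketch of how a consumer accumulates such deltas per choice, mirroring the bookkeeping added to serving_chat.py and serving_completion.py further below (the result objects are duck-typed placeholders):

from typing import Iterable, List, Tuple

def collect_deltas(stream: Iterable, num_choices: int) -> Tuple[List[str], List[int]]:
    texts = [""] * num_choices
    token_counts = [0] * num_choices
    for result in stream:                       # one streamed batch of outputs
        for output in result.outputs:
            i = output.index
            texts[i] += output.text             # delta text, not cumulative
            token_counts[i] += len(output.token_ids)
    return texts, token_counts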
-VLLM_RPC_ZMQ_HWM = 0 - - -@dataclass -class RPCGenerateRequest: - inputs: PromptInputs - sampling_params: SamplingParams - request_id: str - lora_request: Optional[LoRARequest] = None - trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - - -@dataclass -class RPCAbortRequest: - request_id: str - - -class RPCUtilityRequest(Enum): - IS_SERVER_READY = 1 - GET_MODEL_CONFIG = 2 - GET_DECODING_CONFIG = 3 - GET_PARALLEL_CONFIG = 4 - GET_SCHEDULER_CONFIG = 5 - GET_LORA_CONFIG = 6 - DO_LOG_STATS = 7 - IS_SERVER_HEALTHY = 8 - IS_TRACING_ENABLED = 9 - START_PROFILE = 10 - STOP_PROFILE = 11 - - -RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, - RPCUtilityRequest] diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py deleted file mode 100644 index 9b88db746be5c..0000000000000 --- a/vllm/entrypoints/openai/rpc/client.py +++ /dev/null @@ -1,451 +0,0 @@ -import asyncio -import pickle -from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Iterator, Mapping, Optional -from uuid import uuid4 - -import cloudpickle -import zmq -import zmq.asyncio -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -# yapf: disable -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_SOCKET_LIMIT_CUTOFF, - VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -# yapf: enable -from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS -from vllm.inputs import PromptInputs -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - -logger = init_logger(__name__) - -# Path used for inprocess proxy. -INPROC_PROXY_PATH = f"inproc://{uuid4()}" - - -class RPCClientClosedError(Exception): - """Exception class raised when the client is used post-close. - - The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which - causes a ZMQError and creates a huge stack trace. - So, we throw this error such that we can suppress it. - """ - - -class AsyncEngineRPCClient: - """ - RPCClient that connects to the RPCServer wrapping AsyncLLMEngine. 
- - The overall design mirrors the Asynchronous Client Server Pattern - https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern - - On startup, the RPCClient: - - makes DEALER socket (to_rpc_server) that connects to the RPCServer - via ipc, which uses unix sockets under the hood - (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html) - - makes ROUTER socket (from_api_server) that binds to a random - inproc address, which uses memory under the hood - (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html) - - runs a proxy in a background asyncio task between - from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, ) - - Each request handled by the asyncio api_server calls generate(): - - make a DEALER socket that connects to from_api_server via inproc - - send a RCPGenerateRequest to the inproc socket - - background proxy forwards the request from inproc -> ipc - - RPCServer responds to the request one token at a time over ipc - - background proxy forwards the response from ipc -> inproc - - The connection looks like this: - DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER - - Message routing is performed via identities that are managed by the - ROUTER socket. ROUTER sockets track every connection it has and - tells the caller about these. The way it tells the caller is to stick - the connection identity in front of each message received. When we - send the message via a ROUTER, we first send an identity frame. - See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope - for more details on connection identities. - - This proxy design enables us to use a single unix socket, which - improves performance by avoiding syscalls (~5%) and avoids resource limits - such as ulimit, which defaults to 1024 on ubuntu. - - Note: we run set_hwm(0) on each socket, which sets the HWM to inf, - which is required to avoid dropping messages under high load. - This is generally not advisable. However, since we are in control - of both sides of the connection + failure on either side is - catastrophic to the overall system health and memory profiling - suggests limited memory overhead relative to asyncio, we will - proceed for now. - - See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks - for more details on high water marks. - """ - - def __init__(self, rpc_path: str): - self.context = zmq.asyncio.Context() - self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS - self._errored = False - - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - assert isinstance(socket_limit, int) - if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests vLLM can process. Launch " - "vLLM with --disable-frontend-multiprocessing and open a " - "GitHub issue so we can investigate.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - - # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) - self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.to_rpc_server.bind(rpc_path) - - # In process proxy to RPC Server (uses memory-based messaging). 
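The removed docstring above describes a relay between per-request inproc DEALER sockets, a ROUTER, and a single ipc DEALER connection. A minimal pyzmq sketch of that relay, with illustrative endpoint names that are not taken from the code:

import asyncio
import zmq
import zmq.asyncio

async def relay(src: zmq.asyncio.Socket, dst: zmq.asyncio.Socket) -> None:
    # Forward every multipart message unchanged; ROUTER prepends an identity
    # frame so replies can be routed back to the right DEALER.
    while True:
        frames = await src.recv_multipart(copy=False)
        await dst.send_multipart(frames, copy=False)

async def run_proxy() -> None:
    ctx = zmq.asyncio.Context()
    frontend = ctx.socket(zmq.ROUTER)   # per-request DEALER sockets connect here
    frontend.bind("inproc://frontend-demo")
    backend = ctx.socket(zmq.DEALER)    # single ipc link toward the engine side
    backend.bind("ipc:///tmp/backend-demo")
    try:
        await asyncio.gather(relay(frontend, backend), relay(backend, frontend))
    finally:
        frontend.close(linger=0)
        backend.close(linger=0)
        ctx.destroy()

Running run_proxy() until cancelled reproduces the data path DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER sketched in the docstring.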
- self.from_api_server: Socket = self.context.socket( - zmq.constants.ROUTER) - self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) - self.from_api_server.bind(INPROC_PROXY_PATH) - - # Asyncio background task for the proxy. - self.proxy_in_task = asyncio.create_task( - self.run_proxy(self.from_api_server, self.to_rpc_server)) - self.proxy_out_task = asyncio.create_task( - self.run_proxy(self.to_rpc_server, self.from_api_server)) - - # Since we open 1 inproc socket per request, we have a hard cap on - # the number of requests that can run in vLLM w. frontend - # mulitprocessing. This value is used uvicorn to launch - # with --limit-concurrency to return 503 when server is overloaded. - # We need 2 sockets per request - 2: - # 1 for generate(), 1 for abort(), do_log_stats(), check_health() - self.limit_concurrency = socket_limit // 2 - 2 - - async def run_proxy(self, socket_from: Socket, socket_to: Socket): - """Background task that runs a proxy""" - while True: - frames = await socket_from.recv_multipart(copy=False) - await socket_to.send_multipart(frames, copy=False) - - async def setup(self): - """Setup the client before it starts sending server requests.""" - - # Wait until server is ready. - await self._wait_for_server_rpc() - - # Get the configs. - self.model_config = await self._get_model_config_rpc() - self.decoding_config = await self._get_decoding_config_rpc() - self.tracing_flag = await self._is_tracing_enabled_rpc() - - # Create the tokenizer group. - # TODO: refactor OAI server to avoid needing this info. - self.tokenizer = init_tokenizer_from_configs( - model_config=self.model_config, - scheduler_config=(await self._get_scheduler_config_rpc()), - parallel_config=(await self._get_parallel_config_rpc()), - enable_lora=bool(await self._get_lora_config_rpc()), - ) - - def close(self): - """Destroy the ZeroMQ Context.""" - # Close all sockets associated with this context and - # then terminate the context. - self.from_api_server.close() - self.to_rpc_server.close() - self.context.destroy() - - @contextmanager - def to_proxy_socket(self) -> Iterator[Socket]: - # Connect to the RPCServer via the proxy. - - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(VLLM_RPC_ZMQ_HWM) - try: - socket.connect(INPROC_PROXY_PATH) - yield socket - finally: - socket.close(linger=0) - - async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, - expected_type: Any, - error_message: str) -> Any: - """Send an RPC request that is expecting data back.""" - - with self.to_proxy_socket() as socket: - # Ping RPCServer with a request. - await socket.send_multipart((cloudpickle.dumps(request), ), - copy=False) - - # Make sure the server responds - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - # Await the data from the Server. 
- frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - data = pickle.loads(frame.buffer) - - if isinstance(data, Exception): - # Re-raise exceptions returned by the server - raise data - - if not isinstance(data, expected_type): - # LoRAConfig can be None. - if expected_type == LoRAConfig and data is None: - pass - elif isinstance(data, Exception): - logger.error(error_message) - raise data - else: - raise ValueError(error_message) - - return data - - async def _send_one_way_rpc_request(self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[Socket] = None): - """Send one-way RPC request to trigger an action.""" - - async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - - await socket.send_multipart((cloudpickle.dumps(request), )) - - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") - - frame = await socket.recv(copy=False) - assert isinstance(frame, Frame) - return pickle.loads(frame.buffer) - - # Make a new socket connection. - if socket is None: - with self.to_proxy_socket() as socket: - response = await do_rpc_call(socket, request) - - # Use existing socket connection. - else: - response = await do_rpc_call(socket, request) - - if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: - if isinstance(response, Exception): - logger.error(error_message) - raise response - raise ValueError(error_message) - - async def get_tokenizer(self, lora_request: LoRARequest): - return await self.tokenizer.get_lora_tokenizer_async(lora_request) - - async def get_decoding_config(self) -> DecodingConfig: - return self.decoding_config - - async def get_model_config(self) -> ModelConfig: - return self.model_config - - async def is_tracing_enabled(self) -> bool: - return self.tracing_flag - - async def _wait_for_server_rpc(self): - """Wait for the RPCServer to start up.""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server") - - async def _get_model_config_rpc(self) -> ModelConfig: - """Get the ModelConfig object from the RPC Server""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_MODEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server") - - async def _get_decoding_config_rpc(self) -> DecodingConfig: - """Get DecodingConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_DECODING_CONFIG, - expected_type=DecodingConfig, - error_message="Could not get DecodingConfig from RPC Server") - - async def _get_parallel_config_rpc(self) -> ParallelConfig: - """Get ParallelConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_PARALLEL_CONFIG, - expected_type=ParallelConfig, - error_message="Could not get ParallelConfig from RPC Server") - - async def _get_scheduler_config_rpc(self) -> SchedulerConfig: - """Get SchedulerConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server") - - async def _get_lora_config_rpc(self) -> LoRAConfig: - """Get LoRAConfig from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_LORA_CONFIG, - expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server") - - async def 
_is_tracing_enabled_rpc(self) -> bool: - """Get is_tracing_enabled flag from the RPCServer""" - - return await self._send_get_data_rpc_request( - RPCUtilityRequest.IS_TRACING_ENABLED, - expected_type=bool, - error_message="Could not get is_tracing_enabled from RPC Server") - - async def abort(self, request_id: str): - """Send an ABORT_REQUEST signal to the RPC Server""" - - # Suppress timeouts as well. - # In cases where the server is busy processing requests and a very - # large volume of abort requests arrive, it is likely that the server - # will not be able to ack all of them in time. We have seen this when - # we abort 20k requests at once while another 2k are processing- many - # of them time out, but we see the server successfully abort all of the - # requests. - # In this case we assume that the server has received or will receive - # these abort requests, and ignore the timeout. This prevents a massive - # wall of `TimeoutError` stack traces. - with suppress(RPCClientClosedError, TimeoutError): - await self._send_one_way_rpc_request( - request=RPCAbortRequest(request_id), - error_message=f"RPCAbortRequest {request_id} failed") - - async def do_log_stats(self): - """Send a DO_LOG_STATS signal to the RPC Server""" - with suppress(RPCClientClosedError): - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.DO_LOG_STATS, - error_message="RPCRequest DO_LOG_STATS failed.") - - @property - def is_running(self) -> bool: - return not self._errored - - @property - def is_stopped(self) -> bool: - return self._errored - - @property - def errored(self) -> bool: - return self._errored - - async def generate( - self, - inputs: PromptInputs, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncGenerator[RequestOutput, None]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - - finished = False - try: - with self.to_proxy_socket() as socket: - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart((cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)), )) - - # Stream back the results from the RPC Server. - while not finished: - message = await socket.recv(copy=False) - assert isinstance(message, Frame) - request_output = pickle.loads(message.buffer) - - if isinstance(request_output, Exception): - # On exception, check if the server is still healthy - # possibly setting the `errored` property. - if not self._errored: - try: - await self.check_health(socket=socket) - except Exception as e: - self._errored = True - logger.exception(repr(e)) - - # NB: do before raising here so that the flag is set - # by the time the caller receives this exception - raise request_output - - finished = request_output.finished - yield request_output - - finally: - # Request was canceled by the client. 
- if not finished and not self._errored: - await self.abort(request_id) - - async def check_health(self, socket: Optional[Socket] = None) -> None: - """Raise if unhealthy""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_HEALTHY, - error_message="Got Unhealthy response from RPC Server", - socket=socket) - - async def encode(self, *args, - **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: - raise NotImplementedError( - "Embeddings not supported with multiprocessing backend") - - async def start_profile(self) -> None: - """Start profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.START_PROFILE, - error_message="RPCRequest START_PROFILE failed.") - - async def stop_profile(self) -> None: - """Stop profiling the engine""" - - await self._send_one_way_rpc_request( - request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py deleted file mode 100644 index bebc2faedb680..0000000000000 --- a/vllm/entrypoints/openai/rpc/server.py +++ /dev/null @@ -1,237 +0,0 @@ -import asyncio -import pickle -import signal -from typing import Any, Coroutine, Union - -import cloudpickle -import uvloop -import zmq -import zmq.asyncio -from typing_extensions import Never -from zmq import Frame # type: ignore[attr-defined] -from zmq.asyncio import Socket - -from vllm import AsyncEngineArgs, AsyncLLMEngine -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) -from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR, - VLLM_RPC_ZMQ_HWM, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) -from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext - -logger = init_logger(__name__) - -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - - -class AsyncEngineRPCServer: - - def __init__(self, async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - # Initialize engine first. - self.engine = AsyncLLMEngine.from_engine_args( - async_engine_args, usage_context=usage_context) - - # Initialize context. - self.context = zmq.asyncio.Context() - - # Init socket. - self.socket: Socket = self.context.socket(zmq.constants.DEALER) - self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) - self.socket.connect(rpc_path) - - def cleanup(self): - """Cleanup all resources.""" - self.socket.close() - self.context.destroy() - self.engine.shutdown_background_loop() - # Clear the engine reference so that it can be GC'ed. 
- del self.engine - - async def get_config(self, identity, request): - try: - config: CONFIG_TYPE - if request == RPCUtilityRequest.GET_MODEL_CONFIG: - config = await self.engine.get_model_config() - elif request == RPCUtilityRequest.GET_DECODING_CONFIG: - config = await self.engine.get_decoding_config() - elif request == RPCUtilityRequest.GET_LORA_CONFIG: - config = await self.engine.get_lora_config() - elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: - config = await self.engine.get_scheduler_config() - elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: - config = await self.engine.get_parallel_config() - else: - raise ValueError("Unknown Config Request: %s", request) - - await self.socket.send_multipart((identity, pickle.dumps(config)), - copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def is_tracing_enabled(self, identity): - """Send the is_tracing_enabled flag""" - tracing_flag = await self.engine.is_tracing_enabled() - - await self.socket.send_multipart( - (identity, pickle.dumps(tracing_flag))) - - async def do_log_stats(self, identity): - """Log stats and confirm success.""" - await self.engine.do_log_stats() - - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def is_server_ready(self, identity): - """Notify the client that we are ready.""" - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - async def abort(self, identity, request: RPCAbortRequest): - """Abort request and notify the client of success.""" - try: - # Abort the request in the llm engine. - await self.engine.abort(request.request_id) - result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR - except Exception as e: - result = e - await self.socket.send_multipart((identity, pickle.dumps(result))) - - async def generate(self, identity, generate_request: RPCGenerateRequest): - try: - results_generator = self.engine.generate( - generate_request.inputs, - sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id, - lora_request=generate_request.lora_request, - trace_headers=generate_request.trace_headers, - prompt_adapter_request=generate_request.prompt_adapter_request) - - async for request_output in results_generator: - await self.socket.send_multipart( - (identity, pickle.dumps(request_output)), copy=False) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def check_health(self, identity): - try: - await self.engine.check_health() - await self.socket.send_multipart( - (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) - - except Exception as e: - await self.socket.send_multipart((identity, pickle.dumps(e)), - copy=False) - - async def start_profile(self, identity): - logger.info("Starting profiler...") - await self.engine.start_profile() - logger.info("Profiler started.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - async def stop_profile(self, identity): - logger.info("Stopping profiler...") - await self.engine.stop_profile() - logger.info("Profiler stopped.") - - await self.socket.send_multipart(( - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR), - )) - - def _make_handler_coro(self, identity, - message: Frame) -> Coroutine[Any, Any, Never]: - """Route the zmq message to the handler coroutine.""" - - request = cloudpickle.loads(message.buffer) - - if isinstance(request, RPCGenerateRequest): - return 
self.generate(identity, request) - - elif isinstance(request, RPCAbortRequest): - return self.abort(identity, request) - - elif isinstance(request, RPCUtilityRequest): - if request in [ - RPCUtilityRequest.GET_MODEL_CONFIG, - RPCUtilityRequest.GET_PARALLEL_CONFIG, - RPCUtilityRequest.GET_DECODING_CONFIG, - RPCUtilityRequest.GET_SCHEDULER_CONFIG, - RPCUtilityRequest.GET_LORA_CONFIG - ]: - return self.get_config(identity, request) - elif request == RPCUtilityRequest.DO_LOG_STATS: - return self.do_log_stats(identity) - elif request == RPCUtilityRequest.IS_SERVER_READY: - return self.is_server_ready(identity) - elif request == RPCUtilityRequest.IS_SERVER_HEALTHY: - return self.check_health(identity) - elif request == RPCUtilityRequest.IS_TRACING_ENABLED: - return self.is_tracing_enabled(identity) - elif request == RPCUtilityRequest.START_PROFILE: - return self.start_profile(identity) - elif request == RPCUtilityRequest.STOP_PROFILE: - return self.stop_profile(identity) - else: - raise ValueError(f"Unknown RPCUtilityRequest type: {request}") - - else: - raise ValueError(f"Unknown RPCRequest type: {request}") - - async def run_server_loop(self): - """Inner RPC Server Loop""" - - running_tasks = set() - while True: - # Wait for a request. - identity, message = await self.socket.recv_multipart(copy=False) - - # Process the request async. - task = asyncio.create_task( - self._make_handler_coro(identity, message)) - - # We need to keep around a strong reference to the task, - # to avoid the task disappearing mid-execution as running tasks - # can be GC'ed. Below is a common "fire-and-forget" tasks - # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task - running_tasks.add(task) - task.add_done_callback(running_tasks.discard) - - -async def run_server(server: AsyncEngineRPCServer): - # Put the server task into the asyncio loop. - loop = asyncio.get_running_loop() - server_task = loop.create_task(server.run_server_loop()) - - # Interruption handling. - def signal_handler() -> None: - # Kill the server on interrupt / terminate - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("vLLM ZMQ RPC Server was interrupted.") - finally: - # Clean up all resources. 
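Both the removed run_server_loop above and the new FastAPI lifespan hook keep a strong reference to fire-and-forget tasks so the event loop cannot garbage-collect them mid-execution. A generic sketch of that pattern; spawn() is an illustrative helper name, not vLLM API:

import asyncio
from typing import Coroutine, Set

_background_tasks: Set[asyncio.Task] = set()

def spawn(coro: Coroutine) -> asyncio.Task:
    task = asyncio.create_task(coro)
    # asyncio only keeps weak references to running tasks, so hold one here
    # and drop it once the task completes.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task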
- server.cleanup() - - -def run_rpc_server(async_engine_args: AsyncEngineArgs, - usage_context: UsageContext, rpc_path: str): - server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path) - uvloop.run(run_server(server)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8ac4caffb37f0..b84898dc39b0f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,7 +9,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, apply_hf_chat_template, apply_mistral_chat_template, @@ -45,7 +45,7 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -57,7 +57,7 @@ def __init__(self, return_tokens_as_token_ids: bool = False, enable_auto_tools: bool = False, tool_parser: Optional[str] = None): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -105,6 +105,12 @@ async def create_chat_completion( logger.error("Error with model %s", error_check_ret) return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + try: ( lora_request, @@ -112,8 +118,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) conversation, mm_data_future = parse_chat_messages_futures( request.messages, model_config, tokenizer) @@ -123,7 +128,8 @@ async def create_chat_completion( ] prompt: Union[str, List[int]] - if isinstance(tokenizer, MistralTokenizer): + is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) + if is_mistral_tokenizer: prompt = apply_mistral_chat_template( tokenizer, messages=request.messages, @@ -159,10 +165,10 @@ async def create_chat_completion( return self.create_error_response( "tool_choice = \"required\" is not supported!") - # "auto" tools requires --enable-auto-tool-choice - # and --tool-call-parser - if request.tool_choice == "auto" and not ( + if not is_mistral_tokenizer and request.tool_choice == "auto" and not ( self.enable_auto_tools and self.tool_parser is not None): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( "\"auto\" tool choice requires " "--enable-auto-tool-choice and --tool-call-parser to be set") @@ -206,8 +212,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -215,7 +221,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): 
log_tracing_disabled_warning() - result_generator = self.async_engine_client.generate( + result_generator = self.engine_client.generate( engine_inputs, sampling_params, request_id, @@ -246,8 +252,7 @@ async def create_chat_completion( def get_chat_request_role(self, request: ChatCompletionRequest) -> str: if request.add_generation_prompt: return self.response_role - else: - return request.messages[-1]["role"] + return request.messages[-1]["role"] async def chat_completion_stream_generator( self, @@ -264,15 +269,37 @@ async def chat_completion_stream_generator( # Send response for each token for each request.n (index) num_choices = 1 if request.n is None else request.n - previous_texts = [""] * num_choices previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + tool_parser: Optional[ToolParser] = self.tool_parser( tokenizer) if self.tool_parser else None + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = ( + not tool_choice_function_name + and self._should_stream_with_auto_tool_parsing(request)) + + all_previous_token_ids: Optional[List[List[int]]] + if tool_choice_auto: + # These are only required in "auto" tool choice case + previous_texts = [""] * num_choices + all_previous_token_ids = [[]] * num_choices + else: + previous_texts, all_previous_token_ids = None, None + try: async for res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). @@ -305,10 +332,10 @@ async def chat_completion_stream_generator( and request.stream_options.include_usage): # if continuous usage stats are requested, add it if request.stream_options.continuous_usage_stats: - prompt_tokens = len(res.prompt_token_ids) - usage = UsageInfo(prompt_tokens=prompt_tokens, - completion_tokens=0, - total_tokens=prompt_tokens) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) chunk.usage = usage # otherwise don't else: @@ -344,12 +371,10 @@ async def chat_completion_stream_generator( request.stream_options.include_usage): if (request.stream_options. 
continuous_usage_stats): - prompt_tokens = len( - res.prompt_token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=prompt_tokens) + total_tokens=num_prompt_tokens) chunk.usage = usage else: chunk.usage = None @@ -360,65 +385,66 @@ async def chat_completion_stream_generator( first_iteration = False for output in res.outputs: - i = output.index if finish_reason_sent[i]: continue - delta_token_ids = output.token_ids[previous_num_tokens[i]:] - out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None - if request.logprobs and request.top_logprobs is not None: - assert out_logprobs is not None, ( + assert output.logprobs is not None, ( "Did not output logprobs") logprobs = self._create_chat_logprobs( - token_ids=delta_token_ids, - top_logprobs=out_logprobs, + token_ids=output.token_ids, + top_logprobs=output.logprobs, tokenizer=tokenizer, num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None - delta_text = output.text[len(previous_texts[i]):] - delta_message: Optional[DeltaMessage] = None + delta_text = output.text + delta_message: Optional[DeltaMessage] # handle streaming deltas for tools with named tool_choice - if (request.tool_choice and type(request.tool_choice) is - ChatCompletionNamedToolChoiceParam): + if tool_choice_function_name: delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(function=DeltaFunctionCall( - name=request.tool_choice.function.name, + name=tool_choice_function_name, arguments=delta_text), index=i) ]) # handle streaming deltas for tools with "auto" tool choice - elif (self._should_stream_with_auto_tool_parsing(request) - and tool_parser): + elif tool_choice_auto: + assert previous_texts is not None + assert all_previous_token_ids is not None + assert tool_parser is not None + #TODO optimize manipulation of these lists + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + list( + output.token_ids) + delta_message = ( tool_parser.extract_tool_calls_streaming( - previous_text=previous_texts[i], - current_text=output.text, + previous_text=previous_text, + current_text=current_text, delta_text=delta_text, - previous_token_ids= \ - output.token_ids[ - :-1 * len(delta_token_ids) - ], - current_token_ids=output.token_ids, - delta_token_ids=delta_token_ids - ) - ) + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids)) + + # update the previous values for the next iteration + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) # set the previous values for the next iteration - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + previous_num_tokens[i] += len(output.token_ids) # if the message delta is None (e.g. 
because it was a # "control token" for tool calls or the parser otherwise @@ -445,13 +471,12 @@ async def chat_completion_stream_generator( # handle usage stats if requested & if continuous if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) + if request.stream_options.continuous_usage_stats: completion_tokens = len(output.token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + + total_tokens=num_prompt_tokens + completion_tokens, ) chunk.usage = usage @@ -482,7 +507,7 @@ async def chat_completion_stream_generator( tool_parser.prev_tool_call_arr[index].get( "arguments", {})) - # get what we've streamed so for for arguments + # get what we've streamed so far for arguments # for the current tool actual_call = tool_parser.streamed_args_for_tool[ index] @@ -500,7 +525,6 @@ async def chat_completion_stream_generator( ]) # Send the finish response for each request.n only once - prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, @@ -518,13 +542,12 @@ async def chat_completion_stream_generator( model=model_name) if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): - prompt_tokens = len(res.prompt_token_ids) + if request.stream_options.continuous_usage_stats: completion_tokens = len(output.token_ids) usage = UsageInfo( - prompt_tokens=prompt_tokens, + prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=prompt_tokens + + total_tokens=num_prompt_tokens + completion_tokens, ) chunk.usage = usage @@ -538,10 +561,11 @@ async def chat_completion_stream_generator( # is sent, send the usage if (request.stream_options and request.stream_options.include_usage): + completion_tokens = previous_num_tokens[i] final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + previous_num_tokens[i], + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, ) final_usage_chunk = ChatCompletionStreamResponse( @@ -680,6 +704,7 @@ async def chat_completion_full_generator( or "") choice.message.content = full_message + assert final_res.prompt_token_ids is not None num_prompt_tokens = len(final_res.prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) @@ -789,9 +814,9 @@ def _should_check_for_unstreamed_tool_arg_tokens( return bool( # if there is a delta message that includes tool calls which # include a function that has arguments - self.enable_auto_tools and self.tool_parser and delta_message + output.finish_reason is not None + and self.enable_auto_tools and self.tool_parser and delta_message and delta_message.tool_calls and delta_message.tool_calls[0] and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None - and output.finish_reason is not None ) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 34f1200753f8d..14fa60243c584 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.protocol 
import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -52,7 +52,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -78,6 +78,12 @@ async def create_completion( if error_check_ret is not None: return error_check_ret + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + # Return error for unsupported features. if request.suffix is not None: return self.create_error_response( @@ -95,8 +101,7 @@ async def create_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -124,8 +129,8 @@ async def create_completion( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = ( - await self.async_engine_client.is_tracing_enabled()) + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -133,7 +138,7 @@ async def create_completion( raw_request.headers): log_tracing_disabled_warning() - generator = self.async_engine_client.generate( + generator = self.engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, @@ -223,9 +228,10 @@ async def completion_stream_generator( tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n - previous_texts = [""] * num_choices * num_prompts + previous_text_lens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts + num_prompt_tokens = [0] * num_prompts try: async for prompt_idx, res in result_generator: @@ -233,6 +239,10 @@ async def completion_stream_generator( prompt_logprobs = res.prompt_logprobs prompt_text = res.prompt + # Prompt details are excluded from later streamed outputs + if res.prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + delta_token_ids: GenericSequence[int] out_logprobs: Optional[GenericSequence[Optional[Dict[ int, Logprob]]]] @@ -244,6 +254,7 @@ async def completion_stream_generator( assert request.max_tokens is not None if request.echo and request.max_tokens == 0: + assert prompt_token_ids is not None assert prompt_text is not None # only return the prompt delta_text = prompt_text @@ -252,6 +263,7 @@ async def completion_stream_generator( has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): + assert prompt_token_ids is not None assert prompt_text is not None assert 
prompt_logprobs is not None # echo the prompt and first token @@ -266,11 +278,9 @@ async def completion_stream_generator( has_echoed[i] = True else: # return just the delta - delta_text = output.text[len(previous_texts[i]):] - delta_token_ids = output.token_ids[ - previous_num_tokens[i]:] - out_logprobs = output.logprobs[previous_num_tokens[ - i]:] if output.logprobs else None + delta_text = output.text + delta_token_ids = output.token_ids + out_logprobs = output.logprobs if request.logprobs is not None: assert out_logprobs is not None, ( @@ -280,13 +290,13 @@ async def completion_stream_generator( top_logprobs=out_logprobs, num_output_top_logprobs=request.logprobs, tokenizer=tokenizer, - initial_text_offset=len(previous_texts[i]), + initial_text_offset=previous_text_lens[i], ) else: logprobs = None - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + previous_text_lens[i] += len(output.text) + previous_num_tokens[i] += len(output.token_ids) finish_reason = output.finish_reason stop_reason = output.stop_reason @@ -307,8 +317,8 @@ async def completion_stream_generator( and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats or output.finish_reason is not None): - prompt_tokens = len(prompt_token_ids) - completion_tokens = len(output.token_ids) + prompt_tokens = num_prompt_tokens[prompt_idx] + completion_tokens = previous_num_tokens[i] usage = UsageInfo( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, @@ -356,6 +366,7 @@ def request_output_to_completion_response( for final_res in final_res_batch: prompt_token_ids = final_res.prompt_token_ids + assert prompt_token_ids is not None prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt @@ -411,9 +422,9 @@ def request_output_to_completion_response( ) choices.append(choice_data) + num_generated_tokens += len(output.token_ids) + num_prompt_tokens += len(prompt_token_ids) - num_generated_tokens += sum( - len(output.token_ids) for output in final_res.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 12ec6be03cd62..f111a3a8277b5 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -8,7 +8,7 @@ from typing_extensions import assert_never from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -71,13 +71,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -118,8 +118,7 @@ async def create_embedding( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer( - lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() @@ -144,7 +143,7 @@ async def create_embedding( "Prompt adapter is not supported " "for embedding models") - generator = 
self.async_engine_client.encode( + generator = self.engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index ac74527441cd9..72f9381abc7db 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -64,7 +64,7 @@ class OpenAIServing: def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -75,7 +75,7 @@ def __init__( ): super().__init__() - self.async_engine_client = async_engine_client + self.engine_client = engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -159,7 +159,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.async_engine_client.get_decoding_config() + decoding_config = await self.engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 6e802b71ae2b4..8f8862897fc4e 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,7 +1,7 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.protocol import AsyncEngineClient +from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (apply_hf_chat_template, apply_mistral_chat_template, load_chat_template, @@ -29,7 +29,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - async_engine_client: AsyncEngineClient, + engine_client: EngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -37,7 +37,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(async_engine_client=async_engine_client, + super().__init__(engine_client=engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -66,7 +66,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) prompt: Union[str, List[int]] if isinstance(request, TokenizeChatRequest): @@ -132,7 +132,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.async_engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, diff --git a/vllm/envs.py b/vllm/envs.py index b3678399fe207..43c7aa8af85b2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -57,9 +57,10 @@ VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False 
VLLM_TEST_FORCE_FP8_MARLIN: bool = False - VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 + VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False @@ -202,6 +203,11 @@ def get_default_config_root(): (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in ("true", "1")), + # Internal flag to control whether we use custom op, + # or use the native pytorch implementation + "VLLM_TEST_COMPILE_NO_CUSTOM_OPS": + lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")), + # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( @@ -387,8 +393,8 @@ def get_default_config_root(): # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "VLLM_RPC_GET_DATA_TIMEOUT_MS": - lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), + "VLLM_RPC_TIMEOUT": + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), # a list of plugin names to load, separated by commas. # if this is not set, it means all plugins will be loaded diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 7380b73ad6548..9ad240ef60820 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -106,6 +106,7 @@ def _init_executor(self) -> None: )) for rank in range(1, world_size) ] + self.worker_monitor = None if world_size != 1 or is_async: if is_async: async_worker_list = self.workers + [self.driver_worker] diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 9c6d4051eb3f8..cc535e99a06ef 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -1,8 +1,5 @@ import asyncio import os -import signal -import threading -import weakref from functools import partial from typing import Any, List, Optional @@ -108,17 +105,6 @@ def _init_executor(self) -> None: # Set up signal handlers to shutdown the executor cleanly # sometimes gc does not work well - # Use weakref to avoid holding a reference to self - ref = weakref.ref(self) - - def shutdown(signum, frame): - if executor := ref(): - executor.shutdown() - - if threading.current_thread() is threading.main_thread(): - signal.signal(signal.SIGINT, shutdown) - signal.signal(signal.SIGTERM, shutdown) - self.driver_worker = self._create_worker( distributed_init_method=distributed_init_method) self._run_workers("init_device") diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index 28c8e8699f083..5bef76b90d332 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -120,7 +120,8 @@ def run(self) -> None: logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode) # Cleanup any remaining workers - logger.info("Killing local vLLM worker processes") + if logger: + logger.info("Killing local vLLM worker processes") for worker in self.workers: worker.kill_worker() # Must be done after worker task queues are all closed @@ -167,6 +168,8 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], self.tasks[task_id] = future try: self._task_queue.put((task_id, method, args, kwargs)) + except SystemExit: + raise except BaseException as e: del self.tasks[task_id] raise ChildProcessError("worker died") from e @@ -221,6 +224,10 @@ def _run_worker_process( try: 
executor = getattr(worker, method) output = executor(*args, **kwargs) + except SystemExit: + raise + except KeyboardInterrupt: + break except BaseException as e: tb = traceback.format_exc() logger.error( diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b124fe2e08ea6..9433dce842b09 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -437,8 +437,10 @@ def _check_ray_adag_installation(self): required_version = version.parse("2.35") current_version = version.parse( pkg_resources.get_distribution("ray").version) - if current_version < required_version: - raise ValueError(f"Ray version {required_version} or greater is " + # TODO: update the constraint once we adapt to the backward + # incompatible API change from ray 2.36 + if current_version != required_version: + raise ValueError(f"Ray version {required_version} is " f"required, but found {current_version}") import importlib.util diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 8c8b5f741488b..d02fecb46f007 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -26,6 +26,8 @@ class RayTPUExecutor(TPUExecutor): + uses_ray: bool = True + def __init__(self, *args, **kwargs): # This is non-None when the execute model loop is running # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. @@ -68,8 +70,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", ) assert self.speculative_config is None - worker_module_name = "vllm.worker.tpu_worker" - worker_class_name = "TPUWorker" + if self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_tpu_worker" + worker_class_name = "MultiStepTPUWorker" + else: + worker_module_name = "vllm.worker.tpu_worker" + worker_class_name = "TPUWorker" # GKE does not fetch environment information from metadata server # and instead sets these from within the Ray process. 
Therefore we diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 0af8ba41e24d5..972649dedf33e 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -62,11 +62,17 @@ def _create_worker( rank: int = 0, distributed_init_method: Optional[str] = None, ): - from vllm.worker.tpu_worker import TPUWorker - - worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank, - distributed_init_method)) - return worker + if self.scheduler_config.is_multi_step: + from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker + worker = MultiStepTPUWorker(**self._get_worker_kwargs( + local_rank, rank, distributed_init_method)) + return worker + else: + from vllm.worker.tpu_worker import TPUWorker + + worker = TPUWorker(**self._get_worker_kwargs( + local_rank, rank, distributed_init_method)) + return worker def initialize_cache( self, diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index b5e8ef7860598..ac9d355c64c80 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,8 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptInputs) + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -60,8 +61,38 @@ def parse_and_batch_prompt( for elem in prompt ] - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") + raise TypeError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class ParsedStrPrompt(TypedDict): + type: Literal["str"] + content: str + + +class ParsedTextPrompt(TypedDict): + type: Literal["text"] + content: TextPrompt + + +class ParsedTokensPrompt(TypedDict): + type: Literal["tokens"] + content: TokensPrompt + + +def parse_singleton_prompt( + inputs: SingletonPromptInputs, +) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: + if isinstance(inputs, str): + return ParsedStrPrompt(type="str", content=inputs) + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: + return ParsedTokensPrompt(type="tokens", + content=inputs) # type: ignore + elif "prompt" in inputs: + return ParsedTextPrompt(type="text", content=inputs) + + raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py new file mode 100644 index 0000000000000..be2aa5f8cb7d0 --- /dev/null +++ b/vllm/inputs/preprocess.py @@ -0,0 +1,536 @@ +import asyncio +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) +from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + +logger = init_logger(__name__) + +PromptComponents = Tuple[Optional[str], List[int], + Optional["MultiModalDataDict"]] +DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], + Optional["MultiModalDataDict"]] + + +class InputPreprocessor: + + def __init__( + self, + model_config: ModelConfig, + tokenizer: 
Optional[BaseTokenizerGroup], + ) -> None: + super().__init__() + + self.model_config = model_config + self.tokenizer = tokenizer + + def get_tokenizer_group(self) -> BaseTokenizerGroup: + if self.tokenizer is None: + raise ValueError("You cannot pass text prompts when " + "`skip_tokenizer_init` is True") + + return self.tokenizer + + def get_bos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for BOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + + def get_eos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def get_decoder_start_token_id(self) -> Optional[int]: + ''' + Obtain the decoder start token id employed by an encoder/decoder + model. Returns None for non-encoder/decoder models or if the + model config is unavailable. + ''' + + if not self.is_encoder_decoder_model(): + logger.warning("Using None for decoder start token id because " + "this is not an encoder/decoder model.") + return None + + if (self.model_config is None or self.model_config.hf_config is None): + logger.warning("Using None for decoder start token id because " + "model config is not available.") + return None + + dec_start_token_id = getattr(self.model_config.hf_config, + 'decoder_start_token_id', None) + if dec_start_token_id is None: + logger.warning("Falling back on <BOS> for decoder start token id " + "because decoder start token id is not available.") + dec_start_token_id = self.get_bos_token_id() + + return dec_start_token_id + + def _get_default_enc_dec_decoder_prompt(self) -> List[int]: + ''' + Specifically for encoder/decoder models: + generate a default decoder prompt for when + the user specifies only the encoder prompt. + + Encoder/decoder models utilize the decoder + prompt in different ways; as new models are + added, it is intended that this function + will be extended to produce differing + default decoder prompts, depending on the + model variety. + + Absent a special case, the default behavior + of this method is to mirror the behavior of + the HuggingFace (HF) GenerationMixin for a None + decoder prompt, which is to employ a logit processor + setting to force the first decoded token to be <BOS>. + Here, this behavior is approximated by having the + "default" decoder prompt be <BOS>. + + However, it is possible that in the future + other models may have different or more + complex logic for the default decoder prompt. + This motivates having a special helper method + for default decoder prompts. + + Returns: + + * prompt_token_ids + ''' + + bos_token_id = self.get_bos_token_id() + assert bos_token_id is not None + return [bos_token_id] + + def _prepare_decoder_input_ids_for_generation( + self, + decoder_input_ids: Optional[List[int]], + ) -> List[int]: + """ + Prepares `decoder_input_ids` for generation with encoder-decoder models.
+ + Based on + + https://github.com/huggingface/transformers/blob/ + 4037a2b5b1278736e566aec12e169100275545ea/ + src/transformers/generation/utils.py + + specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + + Arguments: + + * decoder_input_ids: input token ids to preprocess + + Returns: + + * Processed token list + """ + + decoder_start_token_id = self.get_decoder_start_token_id() + assert decoder_start_token_id is not None + + if decoder_input_ids is None: + # no decoder prompt input -> + # use decoder_start_token_id as decoder_input_ids + decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + + return decoder_input_ids + + def _apply_prompt_adapter( + self, + prompt_token_ids: List[int], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> List[int]: + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return prompt_token_ids + + def _tokenize_prompt( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """ + Apply the model's tokenizer to a text prompt, returning the + corresponding token IDs. + """ + tokenizer = self.get_tokenizer_group() + + return tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + async def _tokenize_prompt_async( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """Async version of :meth:`_tokenize_prompt`.""" + tokenizer = self.get_tokenizer_group() + + return await tokenizer.encode_async(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + def _extract_prompt_components( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + ''' + Extract the components of any single encoder or decoder input prompt. 
+ + Arguments: + + * request_id + * inputs: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * prompt + * prompt_token_ids + * multi_modal_data + ''' + + parsed = parse_singleton_prompt(inputs) + + if parsed["type"] == "str": + prompt = parsed["content"] + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + elif parsed["type"] == "tokens": + prompt = None + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + elif parsed["type"] == "text": + prompt = parsed["content"]["prompt"] + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = parsed["content"].get("multi_modal_data") + else: + assert_never(parsed) + + return prompt, prompt_token_ids, multi_modal_data + + async def _extract_prompt_components_async( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + """Async version of :meth:`_extract_prompt_components`.""" + parsed = parse_singleton_prompt(inputs) + + if parsed["type"] == "str": + prompt = parsed["content"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + elif parsed["type"] == "tokens": + prompt = None + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + elif parsed["type"] == "text": + prompt = parsed["content"]["prompt"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = parsed["content"].get("multi_modal_data") + else: + assert_never(parsed) + + return prompt, prompt_token_ids, multi_modal_data + + def _build_enc_dec_llm_inputs( + self, + encoder_comps: PromptComponents, + decoder_comps: DecoderPromptComponents, + ) -> EncoderDecoderLLMInputs: + encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + + if encoder_mm_data is not None or decoder_mm_data is not None: + raise ValueError("Multi-modal encoder-decoder models are " + "not supported yet") + + decoder_prompt_ids = ( + self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + + return EncoderDecoderLLMInputs( + prompt_token_ids=decoder_prompt_ids, + prompt=decoder_prompt, + encoder_prompt_token_ids=encoder_prompt_ids, + encoder_prompt=encoder_prompt, + ) + + def _process_encoder_decoder_prompt( + self, + inputs: PromptInputs, + request_id: str, + ) -> EncoderDecoderLLMInputs: + ''' + For encoder/decoder models only: + Process an input prompt into an + :class:`EncoderDecoderLLMInputs` instance. + + There are two types of input prompts: + singleton prompts which carry only the + encoder prompt, and explicit encoder/decoder + prompts which carry both the encoder and the + decoder prompts as member variables. 
+ + This function handles the following scenarios: + * Singleton encoder prompt: extract encoder prompt + token ids & infer default decoder prompt token ids + * Explicit encoder/decoder prompt: extract encoder + and decoder prompt token ids + + Note that for Explicit encoder/decoder prompts, + each sub-prompt (encoder or decoder prompt) can + have any possible singleton type; thus this + method relies on helper functions to obtain + token ids for the sub-prompts. + + Arguments: + + * inputs: an input prompt + * request_id + + Returns: + + * :class:`EncoderDecoderLLMInputs` instance + ''' + + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(inputs): + encoder_comps = self._extract_prompt_components( + inputs["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := inputs["decoder_prompt"]) is None: + decoder_comps = None, None, None + else: + decoder_comps = self._extract_prompt_components( + decoder_input, + request_id=request_id, + ) + else: + encoder_comps = self._extract_prompt_components( + inputs, + request_id=request_id, + ) + + decoder_comps = None, None, None + + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + + async def _process_encoder_decoder_prompt_async( + self, + inputs: PromptInputs, + request_id: str, + ) -> EncoderDecoderLLMInputs: + """Async version of :meth:`_process_encoder_decoder_prompt`.""" + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(inputs): + encoder_task = self._extract_prompt_components_async( + inputs["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := inputs["decoder_prompt"]) is None: + encoder_comps = await encoder_task + decoder_comps = None, None, None + else: + decoder_task = self._extract_prompt_components_async( + decoder_input, + request_id=request_id, + ) + + encoder_comps, decoder_comps = await asyncio.gather( + encoder_task, decoder_task) + else: + encoder_comps = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + ) + + decoder_comps = None, None, None + + return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) + + def _build_decoder_only_llm_inputs( + self, + prompt_comps: PromptComponents, + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> LLMInputs: + prompt, prompt_token_ids, multi_modal_data = prompt_comps + + prompt_token_ids = self._apply_prompt_adapter( + prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=multi_modal_data) + + def _process_decoder_only_prompt( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + ''' + For decoder-only models: + Process an input prompt into an :class:`LLMInputs` instance. 
+ + Arguments: + + * inputs: input prompt + * request_id + * lora_request + * prompt_adapter_request + + Returns: + + * :class:`LLMInputs` instance + ''' + + prompt_comps = self._extract_prompt_components( + inputs, + request_id=request_id, + lora_request=lora_request, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _process_decoder_only_prompt_async( + self, + inputs: SingletonPromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + """Async version of :meth:`_process_decoder_only_prompt`.""" + prompt_comps = await self._extract_prompt_components_async( + inputs, + request_id=request_id, + lora_request=lora_request, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + def preprocess( + self, + inputs: PromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Preprocess the input prompt.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return self._process_encoder_decoder_prompt( + inputs, + request_id=request_id, + ) + + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return self._process_decoder_only_prompt( + inputs, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + async def preprocess_async( + self, + inputs: PromptInputs, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Async version of :meth:`preprocess`.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return await self._process_encoder_decoder_prompt_async( + inputs, + request_id=request_id, + ) + + if is_explicit_encoder_decoder_prompt(inputs): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return await self._process_decoder_only_prompt_async( + inputs, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + def is_encoder_decoder_model(self): + return self.model_config.is_encoder_decoder_model diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 49247cd5de42a..9102b5e19ebec 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,5 +1,6 @@ import torch.nn as nn +import vllm.envs as envs from vllm.platforms import current_platform from vllm.utils import is_cpu, is_hip, is_xpu @@ -53,6 +54,10 @@ def forward_gaudi(self, *args, **kwargs): def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. 
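# Editor's sketch (not part of the patch): the lines added just below make every
# CustomOp fall back to its PyTorch-native forward when the new
# VLLM_TEST_COMPILE_NO_CUSTOM_OPS environment variable is set. A minimal,
# self-contained illustration of that dispatch pattern; the _ToyOp class and its
# methods are hypothetical stand-ins, not vLLM APIs.
import os

class _ToyOp:

    def forward_native(self, x):
        # Plain PyTorch-style implementation (placeholder arithmetic).
        return x + 1

    def forward_cuda(self, x):
        # Custom-kernel implementation (placeholder arithmetic).
        return x + 1

    def dispatch_forward(self):
        # An env flag short-circuits dispatch to the native path, mirroring the
        # VLLM_TEST_COMPILE_NO_CUSTOM_OPS check introduced in this hunk.
        if int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")):
            return self.forward_native
        return self.forward_cuda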
+ + if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS: + return self.forward_native + if is_hip(): return self.forward_hip elif is_cpu(): diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 7161e83952a3d..f4fe8a7307c04 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -6,6 +6,7 @@ from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest) from vllm.sampling_params import LogitsProcessor +from vllm.transformers_utils.tokenizer import MistralTokenizer async def get_guided_decoding_logits_processor( @@ -15,12 +16,23 @@ async def get_guided_decoding_logits_processor( request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'outlines' is currently not supported " + "for Mistral tokenizer. Please consider contributing to the " + "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'lm-format-enforcer' is currently not " + "supported for Mistral tokenizer. Please consider contributing " + "to the 'lm-format-enforcer' project if you are interested " + "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( @@ -37,12 +49,23 @@ def get_local_guided_decoding_logits_processor( # request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'outlines' is currently not supported " + "for Mistral tokenizer. Please consider contributing to the " + "'outlines' project if you are interested in this feature.") # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( guided_options, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + if isinstance(tokenizer, MistralTokenizer): + raise NotImplementedError( + "Guided decoding with 'lm-format-enforcer' is currently not " + "supported for Mistral tokenizer. 
Please consider contributing " + "to the 'lm-format-enforcer' project if you are interested " + "in this feature.") from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 554dcc0ed43ed..c28bd71c9f682 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -67,9 +67,9 @@ def __call__(self, input_ids: List[int], instruction = self._guide.get_next_instruction( state=self._fsm_state[seq_id]) - if type(instruction) == Generate: + if type(instruction) == Generate: # noqa: E721 allowed_tokens = instruction.tokens - elif type(instruction) == Write: + elif type(instruction) == Write: # noqa: E721 # TODO: support fast forward tokens allowed_tokens = [instruction.tokens[0]] else: diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 4c14fe476ee4a..43056786d35c9 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -114,9 +114,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: from vllm._ipex_ops import ipex_ops as ops - out = torch.empty_like(x) - ops.gelu_new(out, x) - return out + return ops.gelu_new(x) class FastGELU(CustomOp): @@ -136,9 +134,7 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: from vllm._ipex_ops import ipex_ops as ops - out = torch.empty_like(x) - ops.gelu_fast(out, x) - return out + return ops.gelu_fast(x) class QuickGELU(CustomOp): @@ -155,6 +151,13 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_quick(out, x) return out + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + out = torch.empty_like(x) + ops.gelu_quick(out, x) + return out + # TODO implement forward_xpu for QuickGELU # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 200a6148978aa..866b18d725a8c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -7,18 +7,21 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) +from vllm.scalar_type import scalar_types def single_marlin_moe( - hidden_states: torch.Tensor, - w: torch.Tensor, - scales: torch.Tensor, - gating_output: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + num_bits: int = 8, +) -> torch.Tensor: """ This function computes the multiplication of hidden_states with expert weights used in Marlin MoE, using weights w and top-k gating mechanism. 
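# Editor's sketch (not part of the patch): in this file the new num_bits argument
# drives two things: the logical output width recovered from the packed weight
# tensor, N = w.shape[2] // (num_bits // 2), and the Marlin scalar type handed to
# marlin_gemm_moe (uint4b8 for 4-bit weights, uint8b128 for 8-bit). The helper
# below is hypothetical and only restates that mapping.
from vllm.scalar_type import scalar_types

def _marlin_moe_weight_params(packed_width: int, num_bits: int = 8):
    # Only 4- and 8-bit expert weights are accepted by these kernels.
    assert num_bits in (4, 8)
    n = packed_width // (num_bits // 2)
    scalar_type = (scalar_types.uint4b8
                   if num_bits == 4 else scalar_types.uint8b128)
    return n, scalar_type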
@@ -36,6 +39,7 @@ def single_marlin_moe( - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -48,10 +52,11 @@ def single_marlin_moe( assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w.is_contiguous(), "Expert weights must be contiguous" assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] M, K = hidden_states.shape E = w.shape[0] - N = w.shape[2] // 2 + N = w.shape[2] // (num_bits // 2) topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) @@ -76,10 +81,13 @@ def single_marlin_moe( device="cuda", requires_grad=False) + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, - False) + g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, + block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -98,6 +106,7 @@ def fused_marlin_moe( override_config: Optional[Dict[str, Any]] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + num_bits: int = 8, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -122,6 +131,7 @@ def fused_marlin_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. 
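# Editor's sketch (not part of the patch): the asserts updated in the hunk below
# tie the activation hidden size K to both packed expert weights, now
# parameterized by num_bits: K == w1.shape[1] * 16 and
# K == w2.shape[2] // (num_bits // 2). For a hypothetical K = 4096 that means
# w1.shape[1] == 256, and w2.shape[2] == 8192 for 4-bit weights or 16384 for the
# default 8-bit weights. A small standalone check mirroring those asserts:
def _check_marlin_moe_shapes(k: int, w1_dim1: int, w2_dim2: int,
                             num_bits: int = 8) -> None:
    assert num_bits in (4, 8)
    assert k == w1_dim1 * 16, "Hidden size mismatch w1"
    assert k == w2_dim2 // (num_bits // 2), "Hidden size mismatch w2"

# Example: passes for the 8-bit default.
_check_marlin_moe_shapes(k=4096, w1_dim1=256, w2_dim2=16384)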
@@ -131,13 +141,14 @@ def fused_marlin_moe( 0], "Number of tokens mismatch" assert hidden_states.shape[ 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] M, K = hidden_states.shape E = w1.shape[0] @@ -165,6 +176,9 @@ def fused_marlin_moe( device="cuda", requires_grad=False) + scalar_type = (scalar_types.uint4b8 + if num_bits == 4 else scalar_types.uint8b128) + intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), device=hidden_states.device, @@ -181,6 +195,7 @@ def fused_marlin_moe( g_idx1, perm1, workspace, + scalar_type, M, 2 * N, K, @@ -204,6 +219,7 @@ def fused_marlin_moe( g_idx2, perm2, workspace, + scalar_type, M, K, N, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bd13d8fecbb96..3e01112eaa14d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -410,6 +410,7 @@ def fused_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids @@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype, diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index e3d588efd9b6d..14f60e9172f29 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -82,14 +82,11 @@ def forward_xpu( self.variance_epsilon, ) return x, residual - out = torch.empty_like(x) - ops.rms_norm( - out, + return ops.rms_norm( x, self.weight.data, self.variance_epsilon, ) - return out def extra_repr(self) -> str: s = f"hidden_size={self.weight.data.size(0)}" diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index cea768469aeb8..568892778abe2 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -530,8 +530,11 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) # Special case for AQLM codebooks. 
elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -899,8 +902,13 @@ def weight_loader(self, else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 @@ -1000,6 +1008,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) # Special case for GGUF is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -1015,7 +1024,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data - if input_dim is not None: + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if input_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[input_dim] start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 413c8bc227ae8..196d81267f32f 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py from typing import Optional @@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, - activation: Optional[str] = None): + activation: Optional[str] = None, + conv_state_indices: Optional[torch.Tensor] = None): """ x: (batch, dim) conv_state: (batch, dim, width) weight: (dim, width) bias: (dim,) + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. out: (batch, dim) """ @@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor, raise NotImplementedError("activation must be None, silu, or swish") activation_bool = activation in ["silu", "swish"] return ops.causal_conv1d_update(x, conv_state, weight, bias, - activation_bool) + activation_bool, conv_state_indices) diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py deleted file mode 100644 index 4a429e329567d..0000000000000 --- a/vllm/model_executor/layers/ops/rand.py +++ /dev/null @@ -1,157 +0,0 @@ -from typing import Optional, Union - -import torch -import triton -import triton.language as tl - - -def seeded_uniform( - *size, - seeds: torch.Tensor, - out: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str]] = None, - pin_memory: Optional[bool] = False, -) -> torch.Tensor: - """Similar to torch.rand, but allows for seeds to be set per row. - - seeds must be a 1d tensor. 
The output tensor may be 1d, 2d, or 3d. - If it is 3d, the additional seeds needed will be derived automatically - in a deterministic fashion: - [ - row 0: [columns_with_seed_0], [columns_with_seed0^1], ... - ] - """ - n_dims = len(size) - - if n_dims > 3: - raise ValueError("seeded_uniform only supports up to 3D tensors") - - if out is None: - out = torch.empty(*size, - dtype=dtype, - device=device, - pin_memory=pin_memory) - elif out.shape != size: - raise ValueError("shape of out and size must be the same") - - if n_dims == 3: - n_rows, n_3d, n_cols = out.shape - stride_row = out.stride(0) - stride_3d = out.stride(1) - elif n_dims == 2: - n_rows, n_cols = out.shape - n_3d = 1 - stride_row = out.stride(0) - stride_3d = 1 - else: - n_cols = out.shape[0] - n_rows = 1 - n_3d = 1 - stride_row = 1 - stride_3d = 1 - - if seeds.ndim != 1: - raise ValueError("seeds must be a 1D tensor") - - if seeds.numel() != n_rows: - raise ValueError( - "seeds must have the same number of elements as out has rows") - - # The philox PRNG Triton uses generates 4 random numbers at once. - # Therefore, the most efficient use of it is to divide the - # block size by 4, and then save the generated random numbers to - # each of the 4 slices of the tensor. - full_block_size = triton.next_power_of_2(n_cols) - philox_block_size = max(full_block_size // 4, 1) - n_slices = full_block_size // philox_block_size - num_warps = 4 - # Manual tuning. This seems to give best performance on A100 for - # simple kernels like this. - if philox_block_size >= 8192: - num_warps = 32 - elif philox_block_size >= 4096: - num_warps = 16 - elif philox_block_size >= 2048: - num_warps = 8 - - _seeded_uniform_triton[(n_rows, n_3d)]( - out, - seeds, - stride_row, - stride_3d, - seeds.stride(0), - n_rows, - n_3d, - n_cols, - n_slices=n_slices, - num_warps=num_warps, - block_size=philox_block_size, - ) - return out - - -@triton.jit -def _seeded_uniform_triton( - out_ptr: torch.Tensor, - seed_ptr: torch.Tensor, - out_row_stride: int, - out_3d_stride: int, - seed_row_stride: int, - n_rows: int, - n_3d: int, - n_cols: int, - n_slices: tl.constexpr, - block_size: tl.constexpr, -): - """ - Generate a random float32 number in [0, 1) for each element in the output - tensor. The random numbers in a row generated using the seed for that row. - - Args: - out_ptr: The output tensor. - seed_ptr: The per-row seeds to use for random number generation. - out_row_stride: The stride between rows of the output tensor. - out_3d_stride: The stride between 3D slices of the output tensor. - seed_row_stride: The stride between rows of the seed tensor. - n_rows: The number of rows in the output tensor. - n_3d: The size of second dimension of the output tensor, - if output tensor is 3D. - n_cols: The number of columns in the output tensor. - n_slices: The number of philox outputs to use. - """ - tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") - - # Get the row index. - row_idx = tl.program_id(axis=0) - three_d_idx = tl.program_id(axis=1) - - philox_offsets = tl.arange(0, block_size) - # Get the seed for the current element. - seed = tl.load(seed_ptr + row_idx * seed_row_stride) - if three_d_idx > 0: - seed ^= three_d_idx - # Generate random numbers in [0, 1). 
- out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) - - output_row_start_ptr = (out_ptr + row_idx * out_row_stride + - three_d_idx * out_3d_stride) - out1_offsets = philox_offsets - tl.store(output_row_start_ptr + out1_offsets, - out1, - mask=out1_offsets < n_cols) - if n_slices > 1: - out2_offsets = tl.arange(block_size, block_size * 2) - tl.store(output_row_start_ptr + out2_offsets, - out2, - mask=out2_offsets < n_cols) - if n_slices > 2: - out3_offsets = tl.arange(block_size * 2, block_size * 3) - tl.store(output_row_start_ptr + out3_offsets, - out3, - mask=out3_offsets < n_cols) - if n_slices > 3: - out4_offsets = tl.arange(block_size * 3, block_size * 4) - tl.store(output_row_start_ptr + out4_offsets, - out4, - mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py deleted file mode 100644 index fb88a05daf482..0000000000000 --- a/vllm/model_executor/layers/ops/sample.py +++ /dev/null @@ -1,394 +0,0 @@ -from typing import Optional, Tuple - -import torch -import triton -import triton.language as tl - -from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.triton_utils.sample import get_num_triton_sampler_splits - -_EPS: tl.constexpr = 1e-6 - - -def _multi_split_sample( - probs: torch.Tensor, - seeds: torch.Tensor, - n_splits: int, - sampled_tokens_size: Tuple[int, int], - sampled_logprobs_size: Tuple[int, int], - sample_indices: torch.Tensor, - logprobs: torch.Tensor, - *, - modify_greedy_probs: bool = False, - save_logprobs: bool = False, -): - """Sample tokens where vocab size is split into multiple parts - (too large for Triton otherwise).""" - assert seeds.ndim == 2 and seeds.shape[0] == n_splits - split_probs = probs.tensor_split(n_splits, 1) - split_logprobs = logprobs.tensor_split(n_splits, 1) - sampled_tokens_tmp = [ - torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) - for _ in range(n_splits) - ] - sampled_logprobs_tmp = [ - torch.empty(sampled_logprobs_size, - dtype=probs.dtype, - device=probs.device) for _ in range(n_splits) - ] - # We are purposefuly using sampled_tokens_size as we need to always - # save modified probs in this case. - sampled_modified_probs_tmp = [ - torch.empty(sampled_tokens_size, - dtype=probs.dtype, - device=probs.device) for _ in range(n_splits) - ] - for i in range(n_splits): - n_samples = sample_indices.shape[0] - n_cols = split_probs[i].shape[1] - n_best = sampled_tokens_tmp[i].shape[1] - uniform_noise = seeded_uniform(n_samples, - n_best, - n_cols, - seeds=seeds[i].flatten(), - device=split_probs[i].device, - dtype=split_probs[i].dtype) - # TODO(yard1): See if we can remove the contiguous() calls. - # Will need kernel support. - _sample( - split_probs[i].contiguous(), - split_logprobs[i].contiguous(), - sample_indices, - sampled_tokens_tmp[i], - sampled_logprobs_tmp[i], - sampled_modified_probs_tmp[i], - seeds[i], - uniform_noise, - modify_greedy_probs=False, - save_logprobs=save_logprobs, - save_modified_probs=True, - ) - if i > 0: - # Add offset to sampled tokens - sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) - sampled_tokens = torch.stack(sampled_tokens_tmp) - sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) - # Reduce the results from the splits. 
- sampled_modified_probs, indices = torch.max(sampled_modified_probs, - dim=0, - keepdim=True) - sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) - if save_logprobs: - sampled_logprobs = torch.stack(sampled_logprobs_tmp) - sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) - else: - sampled_logprobs = None - sampled_modified_probs = sampled_modified_probs.squeeze(0) - - if modify_greedy_probs: - # We need to modify the greedy probs for the sampled tokens. - # We can't do this in the kernel as we need to know the - # sampled tokens. - probs.fill_(0.0) - probs.scatter_(1, sampled_tokens, 1.0) - - return (sampled_tokens, sampled_logprobs, sampled_modified_probs) - - -def sample( - probs: torch.Tensor, - seeds: torch.Tensor, - *, - max_best_of: int = 1, - sample_indices: Optional[torch.Tensor] = None, - logprobs: Optional[torch.Tensor] = None, - modify_greedy_probs: bool = False, - save_logprobs: bool = False, - _save_modified_probs: bool = False, # pylint: disable=invalid-name -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - """Sample tokens from probs. with per-sequence seeds. - - Can sample from a subset of sequences through sample_indices. - - Args: - probs: Probabilities to sample from. - shape = [batch_size, vocab_size] - seeds: Per-sequence seed values. - shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] - max_best_of: Number of samples to generate per sequence. - Sequence seed will be incremented by 1 each time. - sample_indices: Indices of sequences to sample from. - If not provided, will sample from all sequences. - shape = [n] - logprobs: Log-probabilities of the sampled tokens. - Only used for saving the logprobs if save_logprobs is True. - shape = [batch_size, vocab_size] - modify_greedy_probs: Whether to modify the greedy probabilities - for speculative sampling (sampled token = 1.0, - everything else = 0.0). - save_logprobs: Whether to save the log-probabilities of the - sampled tokens to a tensor. - _save_modified_probs: Whether to save the modified probabilities - (including gumbel noise) of the sampled tokens to a tensor. - DOES NOT include the modification done by modify_greedy_probs - (because we want to use the unmodified probs to pick the best - split in case of multi-split sampling). - This is exposed only for testing. - - Returns: - sampled_tokens: shape = [n, max_best_of] - sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None - sampled_modified_probs: shape = [n, max_best_of] - if save_modified_probs else None - """ - if sample_indices is None: - sample_indices = torch.arange(0, probs.shape[0], device=probs.device) - - sampled_tokens_size = (sample_indices.size(0), max_best_of) - if save_logprobs: - if logprobs is None: - raise ValueError( - "logprobs tensor must be provided if save_logprobs is True") - sampled_logprobs_size = sampled_tokens_size - else: - # Empty tensors to invoke the kernel - sampled_logprobs_size = (0, 0) - logprobs = probs - - assert logprobs is not None - if _save_modified_probs: - sampled_modified_probs_size = sampled_tokens_size - else: - # Empty tensors to invoke the kernel - sampled_modified_probs_size = (0, 0) - - # If the number of columns in probs is too large for Triton to handle, - # we split the tensor and sample from each split separately, and then - # do an argmax+gather to combine the results. 
- n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if n_splits > 1: - (sampled_tokens, sampled_logprobs, - sampled_modified_probs) = _multi_split_sample( - probs, - seeds, - n_splits, - sampled_tokens_size, - sampled_logprobs_size, - sample_indices, - logprobs=logprobs, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs) - else: - sampled_tokens = torch.empty(sampled_tokens_size, - dtype=torch.long, - device=probs.device) - sampled_logprobs = torch.empty(sampled_logprobs_size, - dtype=probs.dtype, - device=probs.device) - sampled_modified_probs = torch.empty(sampled_modified_probs_size, - dtype=probs.dtype, - device=probs.device) - n_samples = sample_indices.shape[0] - n_cols = probs.shape[1] - uniform_noise = seeded_uniform(n_samples, - max_best_of, - n_cols, - seeds=seeds.flatten(), - device=probs.device, - dtype=probs.dtype) - - _sample( - probs, - logprobs, - sample_indices, - sampled_tokens, - sampled_logprobs, - sampled_modified_probs, - seeds, - uniform_noise, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=_save_modified_probs, - ) - return (sampled_tokens, sampled_logprobs if save_logprobs else None, - sampled_modified_probs if _save_modified_probs else None) - - -def _sample(probs: torch.Tensor, - logprobs: torch.Tensor, - sample_indices: torch.Tensor, - output_samples: torch.Tensor, - output_logprobs: torch.Tensor, - output_modified_probs: torch.Tensor, - seeds: torch.Tensor, - uniform_noise: torch.Tensor, - *, - modify_greedy_probs: bool = False, - save_logprobs: bool = True, - save_modified_probs: bool = False) -> torch.Tensor: - """Sample tokens from probs. - - Args: - probs [batch_size, vocab_size]: probs to sample from. - logprobs [batch_size, vocab_size]: logprobs (used when - save_logprobsis True). - sample_indices [n]: Indices of the samples to use for each row of probs. - output_samples [n, n_best]: Output tensor to store samples in. - output_logprobs [n, n_best]: Output tensor to store logprobs in. - output_modified_probs [n, n_best]: Output tensor to store - probs of chosen tokens in (modified with noise). - seeds [n]: Seeds to use for sampling. If the seed is 0, we use - greedy sampling. Note this is ONLY used for determining - whether to use random sampling or not. The actual random - noise should be passed as uniform_noise. - uniform_noise [batch_size, n_best, vocab_size]: Uniform - noise to use for random sampling (will be converted - to exponential gumbel noise by the kernel). - modify_greedy_probs: If True, we modify the probs tensor in-place - to encode the sampling method used for each row. This is used - in speculative decoding. Only applies in greedy decoding. - save_logprobs: If True, we save the logprobs of the sampled tokens - in the output_logprobs tensor. - save_modified_probs: If True, we save the modified probs (with noise) - of the sampled tokens in the output_modified_probs tensor. - DOES NOT include the modification done by modify_greedy_probs - (because we want to use the unmodified probs to pick the best - split in case of multi-split sampling). - """ - n_samples = sample_indices.shape[0] - n_cols = probs.shape[1] - n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 - - # The block size is the smallest power of two greater than the number of - # columns in probs - block_size = triton.next_power_of_2(n_cols) - num_warps = 4 - # Manual tuning. This seems to give best performance on A100 for - # simple kernels like this. 
- if block_size >= 8192: - num_warps = 32 - elif block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - - # Enqueue kernel. The 1D launch grid is simple: we have one kernel - # instance per row of the probs matrix - _sample_triton[(n_samples, n_best)]( - sample_indices, - output_samples, - output_logprobs, - output_modified_probs, - probs, - logprobs, - seeds, - uniform_noise, - output_samples.stride(0), - probs.stride(0), - uniform_noise.stride(0), - uniform_noise.stride(1) if n_best > 1 else 1, - n_samples, - n_cols, - n_best, - num_warps=num_warps, - block_size=block_size, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=save_modified_probs, - ) - return output_samples, output_logprobs, output_modified_probs - - -@triton.jit -def _uniform_to_exponential(uniform_noise): - """Convert uniform samples to exponential samples.""" - # tl.rand returns values in [0, 1), so we clamp lower bound - # to _EPS to avoid log(0) and thus division by 0 later - lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) - uniform_noise = tl.maximum(uniform_noise, lb) - # Use the inversion method to turn uniform samples - # into exponential samples - exponential_noise = -tl.log(uniform_noise) - return exponential_noise - - -@triton.jit -def _sample_triton( - sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, - output_logprobs_ptr: torch.Tensor, - output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, - logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, - uniform_noise_ptr: torch.Tensor, output_row_stride: int, - probs_row_stride: int, uniform_noise_row_stride: int, - uniform_noise_best_stride: int, n_samples: int, n_cols: int, - n_best: int, block_size: tl.constexpr, - modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, - save_modified_probs: tl.constexpr): - # The rows are independent, so we parallelize across those - sample_idx = tl.program_id(0) - best_idx = tl.program_id(1) - - # Load the row index from DRAM - row_idx = tl.load(sample_indices_ptr + sample_idx) - seed = tl.load(seeds_ptr + sample_idx) - uses_random_sampling = seed != 0 - - # The stride represents how much we need to increase the - # pointer to advance 1 row - row_start_ptr = probs_ptr + row_idx * probs_row_stride - - # The block size is the next power of two greater than n_cols, - # so we can fit each row in a single block - col_offsets = tl.arange(0, block_size) - - # Load the row into SRAM, using a mask since block_size may be > than n_cols - row = tl.load(row_start_ptr + col_offsets, - mask=col_offsets < n_cols, - other=float("-inf")) - - if uses_random_sampling: - uniform_noise_start_ptr = (uniform_noise_ptr + - sample_idx * uniform_noise_row_stride + - best_idx * uniform_noise_best_stride) - uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, - mask=col_offsets < n_cols, - other=0.5) - exponential_noise = _uniform_to_exponential(uniform_noise) - row /= exponential_noise - - sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) - # clamp sampled token to n_cols - 1 - # this should not be necessary, but we do it - # just in case - if sampled_token >= n_cols: - sampled_token = n_cols - 1 - # Write back output to DRAM - output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + - best_idx) - tl.store(output_row_start_ptr, sampled_token) - - if modify_greedy_probs: # noqa - if not uses_random_sampling: - # Set the probability of the sampled token to 1, all other - # tokens to zero. 
This is used in speculative decoding where - # the sampling method must be encoded within the sampled - # probability distributions. - row = tl.where(col_offsets == sampled_token, 1.0, 0.0) - tl.store(row_start_ptr + col_offsets, - row, - mask=col_offsets < n_cols) - - if save_modified_probs: - output_row_start_ptr = (output_modified_probs_ptr + - sample_idx * output_row_stride + best_idx) - tl.store(output_row_start_ptr, sampled_value) - - if save_logprobs: - # Load the row into SRAM, using a mask since block_size - # may be > than n_cols - sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + - sampled_token) - # Write back output to DRAM - output_row_start_ptr = (output_logprobs_ptr + - sample_idx * output_row_stride + best_idx) - tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eee6a8f7cff49..eed01953fb4af 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -110,9 +110,9 @@ def get_scaled_act_names(self) -> List[str]: def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. quant_method = quant_config.get("quant_method", "").lower() - num_bits = quant_config.get("bits", None) - group_size = quant_config.get("group_size", None) - has_zp = quant_config.get("zero_point", None) + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + has_zp = quant_config.get("zero_point") if quant_method != "awq": return False diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index b5b2570966600..e536fae45c845 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast import torch from pydantic import BaseModel @@ -79,8 +79,8 @@ def get_quant_method( @classmethod def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": target_scheme_map: Dict[str, Any] = dict() - ignore: List[str] = config.get("ignore", None) - quant_format: str = config.get("format", None) + ignore = cast(List[str], config.get("ignore")) + quant_format = cast(str, config.get("format")) # The quant_config has multiple config_groups, each containing # an input_activations key with details about how the activations are @@ -116,10 +116,10 @@ def get_config_filenames(cls) -> List[str]: def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: - capability = current_platform.get_device_capability() # type: ignore + capability_tuple = current_platform.get_device_capability() - if capability is not None: - capability = capability[0] * 10 + capability[1] + if capability_tuple is not None: + capability = capability_tuple.to_int() supported = capability >= min_capability if error and not supported: raise RuntimeError( @@ -200,7 +200,7 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, is_per_tensor_or_channel_weight = (weight_quant.strategy in [ QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL ]) - if not (is_symmetric_weight and is_static_weight + if not (is_symmetric_weight and is_static_weight # noqa: SIM103 and is_per_tensor_or_channel_weight): 
return False @@ -333,7 +333,7 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """ - Use the CompressedTensorsScheme associated with each layer to create + Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. See LinearMethodBase for param details """ @@ -352,8 +352,8 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None): """ - Use the output of create_weights and the CompressedTensorsScheme - associated with the layer to apply the forward pass with the + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the layer input. See LinearMethodBase for param details """ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 49c29c2775cb6..7dee2fca81153 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -6,6 +6,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat) from vllm.model_executor.utils import set_weight_attrs @@ -38,10 +40,11 @@ def __init__( if not (self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits == 4): + and self.num_bits in WNA16_SUPPORTED_BITS): raise ValueError("For Fused MoE layers, only ", f"{CompressionFormat.pack_quantized.value} ", - "is supported for 4 bits") + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -292,4 +295,5 @@ def apply( topk_ids, w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, + num_bits=self.num_bits, ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 8a3d24e2fd258..5931ec36c97d5 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -8,10 +8,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale) + apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter) +from vllm.utils import is_hip __all__ = ["CompressedTensorsW8A8Fp8"] @@ -39,16 +41,37 @@ def process_weights_after_loading(self, layer) -> None: logical_widths=layer.logical_widths, ) + if is_hip(): + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=max_w_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = 
Parameter(input_scale, + requires_grad=False) + layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) # If channelwise, scales are already lined up, so just transpose. elif self.strategy == QuantizationStrategy.CHANNEL: weight = layer.weight + + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter - layer.weight_scale = Parameter(layer.weight_scale.data, - requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) else: raise ValueError(f"Unknown quantization strategy {self.strategy}") diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 3ccf1af9eb898..f26907176ad1a 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -15,10 +15,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear) + apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, ModelWeightParameter) from vllm.platforms import current_platform +from vllm.utils import is_hip logger = init_logger(__name__) @@ -32,9 +33,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 + self.use_marlin = not current_platform.has_device_capability(89) @classmethod def get_name(cls) -> str: @@ -127,8 +126,18 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = Parameter(layer.weight.data, requires_grad=False) weight = layer.weight - layer.weight = Parameter(weight.t(), requires_grad=False) + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=None) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) if self.quant_config.use_marlin: prepare_fp8_layer_for_marlin(layer) # Activations not quantized for marlin. 
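A recurring change in this diff, starting with fbgemm_fp8 above and continuing in fp8, marlin_utils, w8a8_utils, and the model loader below, is replacing the manual capability[0] * 10 + capability[1] arithmetic with has_device_capability(...) / to_int() calls. As orientation, such a helper reduces to roughly the following; this is a simplified sketch, not the actual vllm.platforms code.

```python
from typing import NamedTuple, Optional


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def to_int(self) -> int:
        # (8, 0) -> 80, (8, 9) -> 89: the same packing as the old manual math.
        return self.major * 10 + self.minor


def has_device_capability(capability: Optional[DeviceCapability],
                          threshold: int) -> bool:
    # An unknown platform / missing device counts as "not capable", which is
    # why call sites in this diff fall back to -1 when the tuple is None.
    return capability is not None and capability.to_int() >= threshold


# e.g. the FP8-hardware check rewritten above:
#   use_marlin = not has_device_capability(get_device_capability(), 89)
```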
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 32affe06b89b7..b5feb55db0e74 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -120,9 +120,8 @@ def __init__(self, quant_config: Fp8Config): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm if is_hip(): self.use_marlin = False diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index a6a1ed5b0dee5..dc83017bcc7f9 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -55,7 +55,10 @@ def get_scaled_act_names(self) -> List[str]: def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: # use dequantize mulmat for IQmatrix, mmq for k-quants - if qweight_type >= 16: + if x.shape[0] == 1: + # enable mmvq in contiguous batching + y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + elif qweight_type >= 16: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) weight = ops.ggml_dequantize(qweight, qweight_type, *shape) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 3617a32f80fc1..5a1b2d701ab0d 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -132,10 +132,10 @@ def get_scaled_act_names(self) -> List[str]: def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): # Extract data from quant config. 
quant_method = quant_config.get("quant_method", "").lower() - num_bits = quant_config.get("bits", None) - group_size = quant_config.get("group_size", None) - sym = quant_config.get("sym", None) - desc_act = quant_config.get("desc_act", None) + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") if quant_method != "gptq": return False @@ -611,4 +611,5 @@ def apply( topk_ids, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, + num_bits=self.quant_config.quant_type.size_bits, ).to(orig_dtype) diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py index c3434214a1cde..5bc3737520865 100644 --- a/vllm/model_executor/layers/quantization/qqq.py +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -260,7 +260,7 @@ def apply( size_k = x_2d.shape[1] size_n = s_ch.shape[1] - x_int8, s_tok = ops.scaled_int8_quant(x_2d) + x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 699d5f1844146..fea94cf7322ad 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool, device_capability: Optional[int] = None ): if device_capability is None: - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) if device_capability < 80: return [] @@ -52,8 +53,9 @@ def _check_marlin_supported( device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: if device_capability is None: - major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) supported_types = query_marlin_supported_quant_types( has_zp, device_capability) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 5f9d8658a342f..8b3dfaae971c3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -10,8 +10,7 @@ def is_fp8_marlin_supported(): - capability = current_platform.get_device_capability() - return capability[0] >= 8 + return current_platform.has_device_capability(80) def apply_fp8_marlin_linear( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index a54e3cae73b14..fb263d121fe55 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,19 +6,18 @@ from vllm.platforms import current_platform from vllm.utils import is_hip -# scaled_mm in pytorch on rocm has a bug that requires always -# providing scaling factor for result. This value is created -# as global value to avoid multiple tensor allocations, and -# can be removed once pytorch fixes the bug. 
-TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None +# Input scaling factors are no longer optional in _scaled_mm starting +# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale +TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None def cutlass_fp8_supported() -> bool: # cutlass is not supported on Rocm if is_hip(): return False - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() return ops.cutlass_scaled_mm_supports_fp8(capability) @@ -130,19 +129,17 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output = torch._scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - scale_result=TORCH_SCALED_MM_SCALE_RESULT, - bias=bias) - # Since in torch 2.5, scaled_mm only returns single value - # This should be removed when vllm-nvidia also moves to 2.5 - if is_hip(): - return torch.narrow(output, 0, 0, input.shape[0]) - return torch.narrow(output[0], 0, 0, input.shape[0]) + output = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + return torch.narrow(output[0], 0, 0, input.shape[0]) + return torch.narrow(output, 0, 0, input.shape[0]) else: # Fallback for channelwise case, where we use unfused DQ @@ -160,12 +157,23 @@ def apply_fp8_linear( # For the scaled_mm fallback case, we break this down, since it # does not support s_w being a vector. + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=torch.float32) + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] # Unpad (undo num_token_padding) output = torch.narrow(output, 0, 0, input.shape[0]) x_scale = torch.narrow(x_scale, 0, 0, input.shape[0]) @@ -188,7 +196,7 @@ def apply_int8_linear( # ops.scaled_int8_quant supports both dynamic and static quant. # * dynamic, layer.input_scale is None and x_scale computed from x. # * static, layer.input_scale is scalar and x_scale is input_scale. 
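The apply_fp8_linear changes above amount to a small version shim around the private torch._scaled_mm API: scales stop being optional in torch 2.5, so a dummy identity scale is passed, and the return type changed from a tuple to a plain tensor. Condensed into a standalone helper for illustration (this is not vLLM code, and the exact tuple layout in older torch is treated as an assumption here):

```python
from typing import Optional

import torch

# Dummy identity scale, mirroring TORCH_DEVICE_IDENTITY above.
_IDENTITY = torch.ones(1)


def scaled_mm_compat(a: torch.Tensor,
                     b: torch.Tensor,
                     scale_a: Optional[torch.Tensor] = None,
                     scale_b: Optional[torch.Tensor] = None,
                     out_dtype: torch.dtype = torch.float32,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Call torch._scaled_mm with mandatory scales and a uniform return type."""
    scale_a = _IDENTITY.to(a.device) if scale_a is None else scale_a
    scale_b = _IDENTITY.to(a.device) if scale_b is None else scale_b
    out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                           out_dtype=out_dtype, bias=bias)
    # torch < 2.5 returns a tuple (output first); torch >= 2.5 returns the
    # output tensor directly.
    if isinstance(out, tuple):
        out = out[0]
    return out
```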
- x_q, x_scale = ops.scaled_int8_quant(input, input_scale) + x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale) return ops.cutlass_scaled_mm(x_q, weight, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c00da106734ae..487f5a3d2a441 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,12 +10,6 @@ import torch import torch.nn as nn -from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.model_executor.layers.ops.sample import sample as sample_triton - import vllm.envs as envs from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, @@ -23,6 +17,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -740,7 +735,7 @@ def _sample_with_torch( ) -> SampleReturnType: '''Torch-oriented _sample() implementation. - Single-step scheduling: + Single-step scheduling: * Perform GPU-side sampling computation * Immediately Pythonize sampling result @@ -777,7 +772,7 @@ def _sample_with_torch( # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] + sample_indices = categorized_sample_indices[sampling_type] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -863,88 +858,6 @@ def _sample_with_torch( ) -def _sample_with_triton_kernel( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, -) -> SampleResultType: - categorized_seq_group_ids: Dict[SamplingType, - List[int]] = {t: [] - for t in SamplingType} - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata: Dict[SamplingType, - Tuple[List[int], List[SequenceGroupToSample], - torch.Tensor, torch.Tensor]] = {} - max_best_of_in_batch = 1 - - # Counterintiutively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. 
- for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] - sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups, - sample_indices, - sampled_token_indices) - if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, - SamplingType.RANDOM_SEED): - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_best_of_in_batch = max(max_best_of_in_batch, - sampling_params.best_of) - elif sampling_type == SamplingType.BEAM: - beam_search_logprobs = logprobs[sample_indices] - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - sampled_tokens, _, _ = sample_triton( - probs=probs, - seeds=sampling_tensors.sampling_seeds, - max_best_of=max_best_of_in_batch, - sample_indices=sampling_tensors.sample_indices, - logprobs=logprobs, - # don't save logprobs because we have logic for that below - # TODO: use this instead of the CPU-based logic below - save_logprobs=False, - ) - - # GPU<->CPU sync happens in the loop below. - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups, sample_indices, - sampled_token_indices) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample( - seq_groups, sampled_tokens[sampled_token_indices][:, 0]) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample( - seq_groups, sampled_tokens[sampled_token_indices]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - return sample_results - - def _sample( probs: torch.Tensor, logprobs: torch.Tensor, @@ -974,10 +887,6 @@ def _sample( modify_greedy_probs=modify_greedy_probs, ) - # TODO: Enable once Triton kernel & associated code is faster. 
- # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, - # sampling_tensors) - def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index ac869e56ce198..f0d2a9e7f06be 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -22,6 +22,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -95,10 +97,10 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() # type: ignore + capability_tuple = current_platform.get_device_capability() - if capability is not None: - capability = capability[0] * 10 + capability[1] + if capability_tuple is not None: + capability = capability_tuple.to_int() if capability < quant_config.get_min_capability(): raise ValueError( f"The quantization method {model_config.quantization} " @@ -689,6 +691,8 @@ def save_model( class BitsAndBytesModelLoader(BaseModelLoader): """Model loader to load model weights with BitAndBytes quantization.""" + # TODO: these module names are for Llama only, + # change so that it works with other models as well default_target_modules = [ "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", "o_proj" @@ -911,13 +915,44 @@ def _parse_quant_state(param_name: str, def _unquantized_generator(self, hf_weights_files, use_safetensors, quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): if any(target_module in weight_name for target_module in self.target_modules): weight_name = weight_name.replace(".weight", ".qweight") + + # weight partitions of different modules occur at + # different dimensions + # TODO: these module names are for Llama only, + # change so that it works with other models as well + if 'down_proj' in weight_name or 'o_proj' in weight_name: + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., + start_index:end_index] + + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, + ...] 
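The slicing just added above encodes a simple rule: down_proj and o_proj are row-parallel, so each rank takes a slice of the input (last) dimension, while the remaining projections are column-parallel and are sliced along the output (first) dimension; only that shard is then quantized to 4 bit. A minimal sketch of the rule, assuming evenly divisible sizes (the helper name is made up):

```python
import torch


def tp_shard(weight: torch.Tensor, tp_rank: int, tp_size: int,
             shard_last_dim: bool) -> torch.Tensor:
    """Return the slice of `weight` owned by tensor-parallel rank `tp_rank`."""
    dim = weight.dim() - 1 if shard_last_dim else 0
    total = weight.size(dim)
    assert total % tp_size == 0, "sketch assumes an evenly divisible size"
    shard = total // tp_size
    # Return a contiguous slice, since the narrowed view may not be contiguous.
    return weight.narrow(dim, tp_rank * shard, shard).contiguous()


# down_proj / o_proj:  tp_shard(w, rank, size, shard_last_dim=True)
# q/k/v/gate/up proj:  tp_shard(w, rank, size, shard_last_dim=False)
```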
+ # bitsandbytes requires data in GPU - loaded_weight = weight_tensor.cuda().data + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + with set_default_torch_dtype(torch.float32): processed_weight, quant_state = quantize_4bit( loaded_weight, @@ -958,6 +993,13 @@ def _load_weights(self, model_config: ModelConfig, f"BitsAndBytes loader does not support {quant_method} " "quantization") + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with TP is not supported." + "Please try with PP.") + load_8bit = False if pre_quant: load_8bit = quant_config.get('load_in_8bit', False) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 3aac5cd2b43a5..36f33d6d139ee 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: "inferred as vLLM models, so setting vllm_tensorized=True is " "only necessary for models serialized prior to this change.") return True - if (".vllm_tensorized_marker" in deserializer): - return True - return False + return ".vllm_tensorized_marker" in deserializer def serialize_vllm_model( diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 0052489d99dc4..2bfe6ea09bd62 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,13 +23,7 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. 
- mixtral_supported = ["fp8", "compressed-tensors"] - # for gptq_marlin, only run fused MoE for int4 - if model_config.quantization == "gptq_marlin": - hf_quant_config = getattr(model_config.hf_config, - "quantization_config", None) - if hf_quant_config and hf_quant_config.get("bits") == 4: - mixtral_supported.append("gptq_marlin") + mixtral_supported = ["fp8", "compressed-tensors", "gptq_marlin"] if (model_config.quantization is not None and model_config.quantization not in mixtral_supported diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 2c01eb380c375..591007e787f47 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -43,6 +43,7 @@ "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), @@ -59,6 +60,7 @@ "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "SolarForCausalLM": ("solar", "SolarForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), @@ -90,12 +92,12 @@ "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), - "UltravoxModel": ("ultravox", "UltravoxModel"), - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + "UltravoxModel": ("ultravox", "UltravoxModel"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 9b4c4be7fcb09..cbdacf779b089 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -848,11 +848,13 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, ) -> torch.Tensor: r""" Args: diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index b0325e8b616c8..5f365bbc30670 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -428,7 +428,8 @@ def compute_logits( sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) - logits /= self.config.logits_scaling + if logits is not None: + logits /= self.config.logits_scaling return logits def sample( diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index a135118bc748e..963ad7553fe1d 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -270,38 +270,47 @@ def 
__init__( ) -> None: super().__init__() self.config = config + self.cache_config = cache_config + self.quant_config = quant_config self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self._init_attn_block() + self._init_ffn_block() + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) self.self_attn = MiniCPMAttention( hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, + num_heads=self.config.num_attention_heads, + num_kv_heads=self.config.num_key_value_heads, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, ) + + def _init_ffn_block(self): + self.post_attention_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: self.mlp = MiniCPMMLP( hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, + intermediate_size=self.config.intermediate_size, + hidden_act=self.config.hidden_act, + quant_config=self.quant_config, ) else: - self.mlp = MiniCPMMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mlp = MiniCPMMoE( + num_experts=self.config.num_experts, + top_k=self.config.num_experts_per_tok, + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size) def forward( self, @@ -344,6 +353,8 @@ def __init__( ) -> None: super().__init__() self.config = config + self.cache_config = cache_config + self.quant_config = quant_config self.padding_idx = config.pad_token_id lora_vocab = (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) if lora_config else 0 @@ -354,11 +365,15 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) + self._init_layers() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def _init_layers(self): self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) + MiniCPMDecoderLayer(self.config, self.cache_config, + self.quant_config) + for _ in range(self.config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -431,13 +446,11 @@ def __init__( self.config = config self.lora_config = lora_config + self.cache_config = cache_config + self.quant_config = quant_config self.num_experts = getattr(self.config, "num_experts", 0) - 
self.quant_config = quant_config - self.model = MiniCPMModel(config, - cache_config, - quant_config, - lora_config=lora_config) + self._init_model() unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -458,6 +471,12 @@ def __init__( config.vocab_size) self.sampler = Sampler() + def _init_model(self): + self.model = MiniCPMModel(config=self.config, + cache_config=self.cache_config, + quant_config=self.quant_config, + lora_config=self.lora_config) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py new file mode 100644 index 0000000000000..a048a3dba0415 --- /dev/null +++ b/vllm/model_executor/models/minicpm3.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2024 The ModelBest team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
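The attention module defined below routes queries and keys/values through low-rank "a"/"b" projections (a layout similar to DeepSeek-V2's latent attention) before splitting out the rotary and non-rotary parts. A shape-only sketch with placeholder sizes, ignoring the layernorms and rotary embedding, may help when reading the real forward pass; none of these numbers come from MiniCPM3's config.

```python
import torch
from torch import nn

# Placeholder sizes only; not MiniCPM3's actual configuration.
hidden, q_rank, kv_rank = 1024, 256, 128
heads, nope, rope, v_dim = 8, 64, 32, 64
tokens = 3

q_a = nn.Linear(hidden, q_rank, bias=False)                     # q_a_proj
q_b = nn.Linear(q_rank, heads * (nope + rope), bias=False)      # q_b_proj
kv_a = nn.Linear(hidden, kv_rank + rope, bias=False)            # kv_a_proj_with_mqa
kv_b = nn.Linear(kv_rank, heads * (nope + v_dim), bias=False)   # kv_b_proj

x = torch.randn(tokens, hidden)
q = q_b(q_a(x)).view(tokens, heads, nope + rope)        # per-head q (nope + rope)
latent = kv_a(x)                                        # compressed kv + shared k_pe
kv_latent, k_pe = latent.split([kv_rank, rope], dim=-1)
kv = kv_b(kv_latent).view(tokens, heads, nope + v_dim)
k_nope, v = kv.split([nope, v_dim], dim=-1)
print(q.shape, k_nope.shape, k_pe.shape, v.shape)
# torch.Size([3, 8, 96]) torch.Size([3, 8, 64]) torch.Size([3, 32]) torch.Size([3, 8, 64])
```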
+"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, + MiniCPMForCausalLM, + MiniCPMModel) + + +class MiniCPM3Attention(nn.Module): + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + + tp_size = get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config) + + self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size, + self.kv_lora_rank + + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config) + # O projection. 
+ self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config) + + self.rotary_emb = get_rope( + self.qk_rope_head_dim, + rotary_dim=self.qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + q, _ = self.q_a_proj(hidden_states) + q = self.q_a_layernorm(q) + q, _ = self.q_b_proj(q) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states) + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv, _ = self.kv_b_proj(kv_a) + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_pe = latent_cache[:, :, self.kv_lora_rank:] + + q_pe, k_pe = self.rotary_emb( + positions, + q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim), + k_pe.reshape(-1, self.qk_rope_head_dim)) + q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim) + k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim) + + q[..., self.qk_nope_head_dim:] = q_pe + + k = torch.empty_like(q) + + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + + q = q.reshape(-1, self.num_local_heads * self.qk_head_dim) + k = k.view(-1, self.num_local_heads * self.qk_head_dim) + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], + value=0).view(-1, self.num_local_heads * self.qk_head_dim) + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.num_local_heads, + self.qk_head_dim)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.self_attn = MiniCPM3Attention( + config=self.config, + hidden_size=self.hidden_size, + num_heads=self.config.num_attention_heads, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, + ) + + +class MiniCPM3Model(MiniCPMModel): + + def _init_layers(self): + self.layers = nn.ModuleList([ + MiniCPM3DecoderLayer(self.config, self.cache_config, + self.quant_config) + for _ in range(self.config.num_hidden_layers) + ]) + + +class MiniCPM3ForCausalLM(MiniCPMForCausalLM): + + def _init_model(self): + self.model = MiniCPM3Model(config=self.config, + cache_config=self.cache_config, + quant_config=self.quant_config, + lora_config=self.lora_config) diff --git a/vllm/model_executor/models/minicpmv.py 
b/vllm/model_executor/models/minicpmv.py index f8be9490ee55d..f0fc950defed7 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -884,7 +884,7 @@ def __new__( version = str(config.version).split(".") version = tuple([int(x) for x in version]) # Dispatch class based on version - instance_class = _SUPPORT_VERSION.get(version, None) + instance_class = _SUPPORT_VERSION.get(version) if instance_class is None: raise ValueError( "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 25bc0590c745c..5036f55803c20 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -600,7 +600,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader( param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id, ) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 010cf85f45e07..682b78bbed093 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,4 +1,3 @@ -import math from array import array from dataclasses import dataclass, fields from itertools import tee @@ -15,11 +14,12 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, tokenizer = cached_get_tokenizer( ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode) - mm_encoder = tokenizer.instruct.mm_encoder - mm_config = ctx.model_config.multimodal_config - max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1) + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + patch_size = mm_encoder.mm_config.image_patch_size + image_token_id = mm_encoder.special_ids.img - # approximate image size - size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size) + mm_config = ctx.model_config.multimodal_config + num_images = mm_config.limit_per_prompt.get("image", 1) + # dummy size + size = 256 image = Image.new("RGB", (size, size), color=0) - img_chunk = ImageChunk(image=image) - tokens = mm_encoder(img_chunk).tokens - token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE, - tokens) + image_feature_size = (size**2) // (patch_size**2) + + num_image_tokens = image_feature_size * num_images + + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [image_token_id]) * num_image_tokens + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - num_image_tokens) seq_data = SequenceData(token_ids) - mm_data = {"image": max_num_images_per_request * [image]} + mm_data = {"image": num_images * [image]} return seq_data, mm_data @@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext, return MultiModalInputs({"images": images}) 
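The dummy-data sizing above is plain arithmetic: a square dummy image of side size contributes (size / patch_size) ** 2 image tokens, and the prompt is padded with zero token ids up to seq_len. With illustrative numbers (the patch size here is an assumption, not a value read from a Pixtral config):

```python
# Illustrative numbers only.
size, patch_size, num_images, seq_len = 256, 16, 1, 1024

image_feature_size = (size ** 2) // (patch_size ** 2)  # 256 image tokens
num_image_tokens = image_feature_size * num_images     # 256
num_padding_tokens = seq_len - num_image_tokens        # 768 filler ids
print(image_feature_size, num_image_tokens, num_padding_tokens)
```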
-def merge_multimodal_embeddings(input_ids: torch.Tensor, - inputs_embeds: torch.Tensor, - image_features: Optional[List[torch.Tensor]], - image_id: int) -> torch.Tensor: - text_locations = input_ids != image_id - image_locations = input_ids == image_id - - seq_len = input_ids.shape[0] +def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is not None and "image" in multi_modal_data: + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) - N_txt = text_locations.sum().item() - _, D_txt = inputs_embeds.shape - N_img, D_img = image_features.shape + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + image_token_id = mm_encoder.special_ids.img - assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal " - "to image features dim {D_img}") - assert (seq_len == N_txt + - N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img " - f"{(N_txt, N_img, image_locations.sum().item())}") + if image_token_id not in llm_inputs['prompt_token_ids']: + raise ValueError( + (f"You've passed {llm_inputs=} without {image_token_id=}" + " Make sure to process your input via mistral_common's" + " tokenizer or pass a chat completion request. For more" + " For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.")) - inputs_embeds[image_locations, :] = image_features - return inputs_embeds + return llm_inputs @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) +@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral) class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, @@ -201,11 +206,21 @@ def _parse_and_validate_image_input( return None if isinstance(images, torch.Tensor): - # always take last images - images = [images[-1][i] for i in range(images.size(1))] + # if passed as batch take all images + N, B, C, W, H = images.shape + images = images.reshape(N * B, C, W, H) + images = [images[i] for i in range(images.size(0))] elif isinstance(images, list): - # always take last images - images = [images[-1][i] for i in range(len(images[0]))] + # if passed as list flatten lists of tensors + flatten_images = [] + for imgs_per_req in images: + imgs_per_req = [ + imgs_per_req[i] for i in range(imgs_per_req.size(0)) + ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req + + flatten_images.extend(imgs_per_req) + + images = flatten_images return images @@ -439,7 +454,7 @@ def forward( return x -def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: +def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor: positions = torch.cat([ torch.stack( torch.meshgrid( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 3f8c590a39b00..a9a0329e99f08 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -207,7 +207,7 @@ def __init__( selected_backend = backend_name_to_enum(backend_by_env_var) if selected_backend is None: # For Volta and Turing GPUs, use xformers instead. 
- device_available = current_platform.get_device_capability()[0] >= 8 + device_available = current_platform.has_device_capability(80) if device_available: from transformers.utils import is_flash_attn_2_available @@ -1055,6 +1055,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -1078,6 +1081,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1) try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] except KeyError: print(params_dict.keys()) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py new file mode 100644 index 0000000000000..16e576d0ac29c --- /dev/null +++ b/vllm/model_executor/models/solar.py @@ -0,0 +1,580 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
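SolarModel.forward further below caches hidden states and residuals at the layers listed in bskcn_1 / bskcn_2 and blends them back in at the bskcn_3 / bskcn_4 layers with a weight bskcn_tv. Numerically this is just a convex combination of the saved and current activations; a toy sketch with placeholder values:

```python
import torch

tv = 0.1                             # stands in for self.config.bskcn_tv[...]
saved_hidden = torch.randn(2, 8)     # captured at a bskcn_1 / bskcn_2 layer
saved_residual = torch.randn(2, 8)
hidden = torch.randn(2, 8)           # state arriving at a bskcn_3 / bskcn_4 layer
residual = torch.randn(2, 8)

hidden = saved_hidden * tv + hidden * (1 - tv)
residual = saved_residual * tv + residual * (1 - tv)
```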
+"""Inference-only Solar model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.utils import (PPMissingLayer, + is_pp_missing_parameter, + make_layers) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + + +class SolarMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SolarAttention(nn.Module): + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class SolarDecoderLayer(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] \ + = config.original_max_position_embeddings + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = SolarAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = SolarMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + 
eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class SolarModel(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: SolarDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + bskcn_h_1 = None + bskcn_h_2 = None + bskcn_r_1 = None + bskcn_r_2 = None + bskcn_tv = (self.config.bskcn_tv[0] + if self.training else self.config.bskcn_tv[1]) + + for i in range(self.start_layer, self.end_layer): + if i in self.config.bskcn_1: + bskcn_h_1 = hidden_states.clone() + bskcn_r_1 = residual.clone() + if i in self.config.bskcn_2: + bskcn_h_2 = hidden_states.clone() + bskcn_r_2 = residual.clone() + if i in self.config.bskcn_3: + hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv) + if i in self.config.bskcn_4: + hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv) + layer = self.layers[i] + hidden_states, residual = 
layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class SolarForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = SolarModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + "residual": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, 
shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index a085779bc61a7..97d36d31f2b11 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,4 +1,3 @@ -import random from array import array from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -8,15 +7,10 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) -from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad, - maybe_expand_dim) + is_pin_memory_available, make_tensor_with_pad) _SAMPLING_EPS = 1e-5 -_SEED_0_REPLACEMENT = 3403598558 -# Some triton sampler related code is guarded before it is ready. -_USE_TRITON_SAMPLER = False @dataclass @@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int): generator=None, is_prompt=True, prompt_logprob_indices=[], - sample_indices=[]) + sample_indices=[], + ) class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations - """ + """Used to cache SamplingMetadata objects between scheduler iterations""" def __init__(self): self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} @@ -124,12 +118,12 @@ def sample(logits): The first tuple is [1, 2] (sampled index within original logit), and the second tuple is [0, 1] (sampled index within pruned logit). num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling + reuse_sampling_tensors: Indicates if we want to reuse sampling tensors that are part of the sampler forward pass. Currently, it is mainly used for multi-step decode. 
- + """ def __init__( @@ -165,16 +159,19 @@ def prepare( num_prompts, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, device, generators, cache) - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory) + selected_token_indices = async_tensor_h2d( + selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory, + ) categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory), 2, 2) + t: async_tensor_h2d( + seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory, + ) for t, seq_ids in categorized_sample_indices.items() } @@ -201,8 +198,8 @@ def _prepare_seq_groups( device: str, generators: Optional[Dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, -) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ - SamplingType, List[Tuple[int, int]]], int]: +) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType, + List[int]], int, ]: """Prepare sequence groups and indices for sampling. Args: @@ -233,16 +230,13 @@ def _prepare_seq_groups( # Sampling type -> ( # indices to sample/prompt logprob within pruned output logits, # indices to sample within pruned logits) - categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = { + categorized_sample_indices: Dict[SamplingType, List[int]] = { t: [] for t in SamplingType } # Index of logits to compute logprob. Logits include both prompt logprob # and sample logprob indices. logit_idx = 0 - # Index to sample from a sample tensor. It is used by triton sample kernel. - # See `_sample_with_triton_kernel` for more details. - sample_idx = 0 # Total number of prompts from given sequence groups. num_prompts = 0 @@ -264,10 +258,10 @@ def _prepare_seq_groups( # If the current seq group is in decode stage, it is None. 
seq_len: Optional[int] = None query_len: Optional[int] = None - prompt_logprob_indices: List[int] = \ - sample_obj.prompt_logprob_indices if cache is not None else [] - sample_indices: List[int] = \ - sample_obj.sample_indices if cache is not None else [] + prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices + if cache is not None else []) + sample_indices: List[int] = (sample_obj.sample_indices + if cache is not None else []) do_sample = seq_group_metadata.do_sample if seq_group_metadata.is_prompt: @@ -333,11 +327,8 @@ def sample(logits): if do_sample: sample_indices.extend(range(logit_idx, logit_idx + sample_len)) categorized_sample_indices[sampling_params.sampling_type].extend( - list( - zip(range(logit_idx, logit_idx + sample_len), - range(sample_idx, sample_idx + sample_len)))) + list(range(logit_idx, logit_idx + sample_len))) logit_idx += sample_len - sample_idx += sample_len if cache is not None: sample_obj.sampling_params = sampling_params @@ -356,7 +347,8 @@ def sample(logits): generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices)) + sample_indices=list(sample_indices), + ) seq_groups.append(sample_obj) @@ -378,9 +370,6 @@ class SamplingTensors: presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor - sampling_seeds: torch.Tensor - sample_indices: torch.Tensor - extra_seeds: Optional[torch.Tensor] prompt_tokens: torch.Tensor output_tokens: torch.Tensor @@ -391,15 +380,7 @@ def from_sampling_metadata( vocab_size: int, device: torch.device, dtype: torch.dtype, - *, - extra_seeds_to_generate: int = 0, - extra_entropy: Optional[Tuple[int, ...]] = None ) -> Tuple["SamplingTensors", bool, bool, bool]: - """ - extra_seeds_to_generate: extra seeds to generate using the - user-defined seed for each sequence. - extra_entropy: extra entropy to use when generating seeds. - """ prompt_tokens: List[array] = [] output_tokens: List[array] = [] top_ks: List[int] = [] @@ -409,19 +390,10 @@ def from_sampling_metadata( presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] - sampling_seeds: List[int] = [] - sample_indices: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False - if _USE_TRITON_SAMPLER: - prompt_best_of: List[int] = [] - - # We need one base seed per Triton slice. 
- seeds_to_generate = (extra_seeds_to_generate + - get_num_triton_sampler_splits(vocab_size)) - assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -452,7 +424,7 @@ def from_sampling_metadata( do_penalties = True is_prompt = seq_group.is_prompt - if (is_prompt and sampling_params.prompt_logprobs is not None): + if is_prompt and sampling_params.prompt_logprobs is not None: # For tokens in the prompt that we only need to get # their logprobs query_len = seq_group.query_len @@ -477,28 +449,6 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - if _USE_TRITON_SAMPLER: - if is_prompt: - prompt_best_of.append(sampling_params.best_of) - query_len = seq_group.query_len - assert query_len is not None - - seed = sampling_params.seed - is_greedy = sampling_params.sampling_type == SamplingType.GREEDY - - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - extra_entropy = extra_entropy or () - seq_seeds = cls._get_sequence_seeds( - seed, - seq_data.get_len(), - *extra_entropy, - seq_id, - seeds_to_generate=seeds_to_generate, - is_greedy=is_greedy) - sampling_seeds.append(seq_seeds) - sample_indices.extend(seq_group.sample_indices) - if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -518,23 +468,37 @@ def from_sampling_metadata( output_tokens.append(seq_data.output_token_ids_array) sampling_tensors = SamplingTensors.from_lists( - temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, sampling_seeds, - sample_indices, prompt_tokens, output_tokens, vocab_size, - extra_seeds_to_generate, device, dtype) + temperatures, + top_ps, + top_ks, + min_ps, + presence_penalties, + frequency_penalties, + repetition_penalties, + prompt_tokens, + output_tokens, + vocab_size, + device, + dtype, + ) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod - def from_lists(cls, temperatures: List[float], top_ps: List[float], - top_ks: List[int], min_ps: List[float], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], - sampling_seeds: List[int], sample_indices: List[int], - prompt_tokens: List[array], output_tokens: List[array], - vocab_size: int, extra_seeds_to_generate: int, - device: torch.device, - dtype: torch.dtype) -> "SamplingTensors": + def from_lists( + cls, + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], + min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[array], + output_tokens: List[array], + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() @@ -603,34 +567,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.int, pin_memory=pin_memory, ) - sample_indices_t = torch.tensor( - sample_indices, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) - # need to transpose and make contiguous to - # copy the tensor correctly. - # [batch_size, n_seeds] -> [n_seeds, batch_size] - sampling_seeds_t = torch.tensor( - sampling_seeds, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ).t().contiguous() - # Because the memory is pinned, we can do non-blocking # transfer to device. 
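Even with the Triton seed tensors removed, the rewritten `SamplingTensors.from_lists` keeps the existing pattern of building every tensor on the CPU with pinned memory and then issuing non-blocking host-to-device copies (the comment above notes that performance is poor without pinning). A small sketch of that pattern follows, assuming only PyTorch; `async_tensor_h2d_sketch` is a simplified stand-in for the helper vLLM uses, not its actual implementation.

```python
import torch


def async_tensor_h2d_sketch(values, dtype, device, pin_memory: bool) -> torch.Tensor:
    # Build the tensor on the CPU first; pinned (page-locked) host memory is
    # what allows the copy below to overlap with other GPU work.
    t = torch.tensor(values, dtype=dtype, device="cpu", pin_memory=pin_memory)
    return t.to(device=device, non_blocking=True)


if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Pinning requires a CUDA runtime, so fall back gracefully on CPU-only hosts.
    temperatures = async_tensor_h2d_sketch([1.0, 0.7, 0.0], torch.float,
                                           device, pin_memory=use_cuda)
    print(temperatures.device, temperatures)
```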
- # How many seeds the sample operation itself will need. - num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate - sampling_seeds_gpu = sampling_seeds_t.to(device=device, - non_blocking=True) - extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:] - if not extra_seeds_gpu.numel(): - extra_seeds_gpu = None - sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] - return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -644,38 +583,4 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], non_blocking=True), prompt_tokens=prompt_t.to(device=device, non_blocking=True), output_tokens=output_t.to(device=device, non_blocking=True), - sampling_seeds=sampling_seeds_gpu, - sample_indices=sample_indices_t.to(device=device, - non_blocking=True), - extra_seeds=extra_seeds_gpu, ) - - @staticmethod - def _get_sequence_seeds( - seed: int, - *extra_entropy: int, - seeds_to_generate: int, - is_greedy: bool, - ): - """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" - if not is_greedy: - if seed is None: - randint_fn = random.randint - else: - generator = random.Random(str((seed, ) + extra_entropy)) - randint_fn = generator.randint - lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max - # If the user/random sets seed = 0 but request should - # have sampling, we need to change it to something - # else. We use a constant in that case. - # This way we don't need to create and load a bool - # matrix in the sampling kernel, which reduces CPU - # overhead and latency. - seq_seeds = [ - randint_fn(lo, hi) or _SEED_0_REPLACEMENT - for _ in range(seeds_to_generate) - ] - else: - # For the kernel, seed == 0 means greedy decoding. - seq_seeds = [0] * seeds_to_generate - return seq_seeds diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 336bc1cd005cf..d7eec818cbba4 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,17 +1,13 @@ """Utils for model executor.""" -import random from typing import Any, Dict, Optional -import numpy as np import torch +from vllm.utils import seed_everything + def set_random_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + seed_everything(seed) def set_weight_attrs( diff --git a/vllm/outputs.py b/vllm/outputs.py index e091b576f5972..85ea9196b25df 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -5,6 +5,7 @@ from typing import Union from vllm.lora.request import LoRARequest +from vllm.sampling_params import RequestOutputKind from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceStatus) @@ -92,7 +93,7 @@ def __init__( self, request_id: str, prompt: Optional[str], - prompt_token_ids: List[int], + prompt_token_ids: Optional[List[int]], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, @@ -113,19 +114,26 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": - if seq_group.sampling_params is None: + def from_seq_group(cls, + seq_group: SequenceGroup) -> Optional["RequestOutput"]: + sampling_params = seq_group.sampling_params + if sampling_params is None: raise ValueError( "Sampling parameters are missing for a CompletionRequest.") + finished = seq_group.is_finished() + if 
sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( + not finished): + return None + seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs else: # Get the top-n sequences. - n = seq_group.sampling_params.n - if seq_group.sampling_params.use_beam_search: + n = sampling_params.n + if sampling_params.use_beam_search: sorting_key = lambda seq: seq.get_beam_search_score( - seq_group.sampling_params.length_penalty) + sampling_params.length_penalty) else: sorting_key = lambda seq: seq.get_cumulative_logprob() sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) @@ -135,26 +143,49 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # NOTE: We need omit logprobs here explicitly because the sequence # always has the logprobs of the sampled tokens even if the # logprobs are not requested. - include_logprobs = seq_group.sampling_params.logprobs is not None - text_buffer_length = seq_group.sampling_params.output_text_buffer_length - outputs = [ - CompletionOutput( - seqs.index(seq), - seq.get_output_text_to_return(text_buffer_length), - seq.data._output_token_ids, - seq.get_cumulative_logprob() if include_logprobs else None, - seq.output_logprobs if include_logprobs else None, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) for seq in top_n_seqs - ] + include_logprobs = sampling_params.logprobs is not None + text_buffer_length = sampling_params.output_text_buffer_length + delta = sampling_params.output_kind == RequestOutputKind.DELTA + + outputs = [] + include_prompt = True + for seq in top_n_seqs: + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) + output_token_ids = seq.get_output_token_ids_to_return(delta) + output_logprobs = seq.output_logprobs if include_logprobs else None + + if delta: + # Slice logprobs delta if applicable + if output_logprobs: + output_logprobs = output_logprobs[-len(output_token_ids):] + # Don't include prompt if this is after the first output + # containing decode token ids + if include_prompt and seq.get_output_len() > len( + output_token_ids): + include_prompt = False + + outputs.append( + CompletionOutput( + seqs.index(seq), output_text, output_token_ids, + seq.get_cumulative_logprob() if include_logprobs else None, + output_logprobs, + SequenceStatus.get_finished_reason(seq.status), + seq.stop_reason)) # Every sequence in the sequence group should have the same prompt. 
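The delta path above only works because each sequence remembers how much of its output has already been returned; the slicing of `output_logprobs` relies on `output_token_ids` containing exactly the new tokens. Here is a toy sketch of that bookkeeping, with `DeltaTracker` standing in for the offset fields that `Sequence` gains later in this patch; it is an illustration of the idea, not vLLM code.

```python
from typing import List, Sequence


class DeltaTracker:
    """Toy stand-in for per-sequence offset bookkeeping: each call returns
    only the text/tokens produced since the previous call."""

    def __init__(self) -> None:
        self.output_text = ""
        self.output_token_ids: List[int] = []
        self._last_text_offset = 0
        self._last_token_offset = 0

    def append(self, text: str, token_ids: Sequence[int]) -> None:
        self.output_text += text
        self.output_token_ids.extend(token_ids)

    def text_delta(self) -> str:
        start, self._last_text_offset = self._last_text_offset, len(self.output_text)
        return self.output_text[start:]

    def token_delta(self) -> List[int]:
        start, self._last_token_offset = (self._last_token_offset,
                                          len(self.output_token_ids))
        return self.output_token_ids[start:]


if __name__ == "__main__":
    seq = DeltaTracker()
    seq.append("Hello", [15496])
    print(seq.text_delta(), seq.token_delta())  # -> Hello [15496]
    seq.append(" world", [995])
    print(seq.text_delta(), seq.token_delta())  # -> " world" [995]
    print(seq.text_delta(), seq.token_delta())  # -> "" []
```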
- prompt = seq_group.prompt - prompt_token_ids = seq_group.prompt_token_ids - encoder_prompt = seq_group.encoder_prompt - encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids - prompt_logprobs = seq_group.prompt_logprobs - finished = seq_group.is_finished() + if include_prompt: + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + encoder_prompt = seq_group.encoder_prompt + encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + else: + prompt = None + prompt_token_ids = None + encoder_prompt = None + encoder_prompt_token_ids = None + prompt_logprobs = None finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) return cls(seq_group.request_id, diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 4736e898b6a52..9b348f3e17a5f 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -6,10 +6,10 @@ class CpuPlatform(Platform): _enum = PlatformEnum.CPU - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: return "cpu" - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 8d18527e7c973..a9978d5d84d7c 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -11,7 +11,7 @@ from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum logger = init_logger(__name__) @@ -96,19 +96,20 @@ def device_id_to_physical_device_id(device_id: int) -> int: class CudaPlatform(Platform): _enum = PlatformEnum.CUDA - @staticmethod - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: physical_device_id = device_id_to_physical_device_id(device_id) - return get_physical_device_capability(physical_device_id) + major, minor = get_physical_device_capability(physical_device_id) + return DeviceCapability(major=major, minor=minor) - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) return get_physical_device_name(physical_device_id) - @staticmethod + @classmethod @with_nvml_context - def is_full_nvlink(physical_device_ids: List[int]) -> bool: + def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 676f4c9fccf5a..360590d7d5eb6 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,5 @@ import enum -from typing import Optional, Tuple +from typing import NamedTuple, Optional, Tuple, Union import torch @@ -12,6 +12,23 @@ class PlatformEnum(enum.Enum): UNSPECIFIED = enum.auto() +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. 
+ """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + class Platform: _enum: PlatformEnum @@ -27,16 +44,47 @@ def is_tpu(self) -> bool: def is_cpu(self) -> bool: return self._enum == PlatformEnum.CPU - @staticmethod - def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]: + def is_cuda_alike(self) -> bool: + """Stateless version of :func:`torch.cuda.is_available`.""" + return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + + @classmethod + def get_device_capability( + cls, + device_id: int = 0, + ) -> Optional[DeviceCapability]: + """Stateless version of :func:`torch.cuda.get_device_capability`.""" return None - @staticmethod - def get_device_name(device_id: int = 0) -> str: + @classmethod + def has_device_capability( + cls, + capability: Union[Tuple[int, int], int], + device_id: int = 0, + ) -> bool: + """ + Test whether this platform is compatible with a device capability. + + The ``capability`` argument can either be: + + - A tuple ``(major, minor)``. + - An integer ````. (See :meth:`DeviceCapability.to_int`) + """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + + if isinstance(capability, tuple): + return current_capability >= capability + + return current_capability.to_int() >= capability + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: raise NotImplementedError - @staticmethod - def inference_mode(): + @classmethod + def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. This wrapper is recommended because some hardware backends such as TPU diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 28525e8ff8811..b6a19eca01745 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,12 +1,11 @@ import os from functools import lru_cache -from typing import Tuple import torch from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum logger = init_logger(__name__) @@ -20,12 +19,13 @@ class RocmPlatform(Platform): _enum = PlatformEnum.ROCM - @staticmethod + @classmethod @lru_cache(maxsize=8) - def get_device_capability(device_id: int = 0) -> Tuple[int, int]: - return torch.cuda.get_device_capability(device_id) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) - @staticmethod + @classmethod @lru_cache(maxsize=8) - def get_device_name(device_id: int = 0) -> str: + def get_device_name(cls, device_id: int = 0) -> str: return torch.cuda.get_device_name(device_id) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 393fc230da0b9..b30bccb103af3 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -6,6 +6,10 @@ class TpuPlatform(Platform): _enum = PlatformEnum.TPU - @staticmethod - def inference_mode(): + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def inference_mode(cls): return torch.no_grad() diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 765f74fe7356f..7939688ef0da3 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,4 +1,5 @@ import logging +from typing import Callable, Optional, Union import vllm.envs as envs @@ -29,3 +30,15 @@ def load_general_plugins(): except Exception: logger.exception("Failed to load general plugin: %s", 
plugin.name) + + +_torch_compile_backend: Optional[Union[Callable, str]] = None + + +def set_torch_compile_backend(backend: Union[Callable, str]): + global _torch_compile_backend + _torch_compile_backend = backend + + +def get_torch_compile_backend() -> Optional[Union[Callable, str]]: + return _torch_compile_backend diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py index 989cc5a0f87c8..4cde2a0254b90 100644 --- a/vllm/prompt_adapter/utils.py +++ b/vllm/prompt_adapter/utils.py @@ -8,13 +8,15 @@ from huggingface_hub.utils import EntryNotFoundError from safetensors.torch import load_file as safe_load_file +from vllm.platforms import current_platform + WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" # Get current device name based on available devices def infer_device() -> str: - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): return "cuda" return "cpu" diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index c83ed5cca6791..5edbc8e424e81 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,6 +1,6 @@ """Sampling parameters for text generation.""" import copy -from enum import IntEnum +from enum import Enum, IntEnum from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -33,6 +33,15 @@ class SamplingType(IntEnum): to sample from.""" +class RequestOutputKind(Enum): + # Return entire output so far in every RequestOutput + CUMULATIVE = 0 + # Return only deltas in each RequestOutput + DELTA = 1 + # Do not return intermediate RequestOuputs + FINAL_ONLY = 2 + + class SamplingParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -147,6 +156,7 @@ class SamplingParams( logits_processors: Optional[Any] = None include_stop_str_in_output: bool = False truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE # The below fields are not supposed to be used as an input. # They are set in post_init. @@ -182,6 +192,7 @@ def from_optional( logits_processors: Optional[List[LogitsProcessor]] = None, truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None, + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, ) -> "SamplingParams": return SamplingParams( n=1 if n is None else n, @@ -213,6 +224,7 @@ def from_optional( spaces_between_special_tokens=spaces_between_special_tokens, logits_processors=logits_processors, truncate_prompt_tokens=truncate_prompt_tokens, + output_kind=output_kind, ) def __post_init__(self) -> None: @@ -317,6 +329,9 @@ def _verify_args(self) -> None: raise ValueError( "stop strings are only supported when detokenize is True. " "Set detokenize=True to use stop.") + if self.best_of != self.n and self.output_kind == ( + RequestOutputKind.DELTA): + raise ValueError("best_of must equal n to use output_kind=DELTA") def _verify_beam_search(self) -> None: if self.best_of == 1: diff --git a/vllm/scripts.py b/vllm/scripts.py index e557961a335bf..231a18e99f3d7 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -1,11 +1,11 @@ # The CLI entrypoint to vLLM. import argparse -import asyncio import os import signal import sys from typing import List, Optional +import uvloop from openai import OpenAI from openai.types.chat import ChatCompletionMessageParam @@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None: # EngineArgs expects the model name to be passed as --model. 
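To make the new `RequestOutputKind` concrete, here is a toy reconstruction of the enum and of the `_verify_args` rule added above (delta streaming is rejected unless `best_of == n`). `ToySamplingParams` is a plain dataclass used purely for illustration, not the real msgspec-based `SamplingParams`.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class RequestOutputKind(Enum):
    CUMULATIVE = 0   # return the entire output so far in every RequestOutput
    DELTA = 1        # return only what was generated since the last output
    FINAL_ONLY = 2   # emit a single RequestOutput once the request finishes


@dataclass
class ToySamplingParams:
    n: int = 1
    best_of: Optional[int] = None
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    def __post_init__(self) -> None:
        if self.best_of is None:
            self.best_of = self.n
        # Mirrors the new check: delta streaming only makes sense when every
        # generated sequence is returned, i.e. best_of == n.
        if (self.best_of != self.n
                and self.output_kind == RequestOutputKind.DELTA):
            raise ValueError("best_of must equal n to use output_kind=DELTA")


if __name__ == "__main__":
    ToySamplingParams(n=2, best_of=2, output_kind=RequestOutputKind.DELTA)
    try:
        ToySamplingParams(n=1, best_of=4, output_kind=RequestOutputKind.DELTA)
    except ValueError as err:
        print("rejected:", err)
```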
args.model = args.model_tag - asyncio.run(run_server(args)) + uvloop.run(run_server(args)) def interactive_cli(args: argparse.Namespace) -> None: diff --git a/vllm/sequence.py b/vllm/sequence.py index 135586831e680..07ceccf123541 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,8 +5,9 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, - Optional, Set, Tuple, Union, cast) +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union, cast import msgspec import torch @@ -407,6 +408,10 @@ def __init__( self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None + # These are used to keep track of delta outputs + self._last_token_ids_offset: int = 0 + self._last_output_text_offset: int = 0 + # Used for incremental detokenization self.prefix_offset = 0 self.read_offset = 0 @@ -462,11 +467,37 @@ def prompt_adapter_id(self) -> int: return self.prompt_adapter_request.prompt_adapter_id \ if self.prompt_adapter_request else 0 - def get_output_text_to_return(self, buffer_length: int): + def get_output_text_to_return(self, buffer_length: int, + delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + # We return the full output text if the sequence is finished. truncate = buffer_length and not self.is_finished() - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) + if not delta: + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + length = len(self.output_text) + if truncate: + length -= buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" + + def get_output_token_ids_to_return(self, + delta: bool) -> GenericSequence[int]: + """If delta is True, only new tokens since the last call to + this method are returned""" + if not delta: + return self.get_output_token_ids() + length = self.get_output_len() + last_offset = self._last_token_ids_offset + if last_offset < length: + self._last_token_ids_offset = length + return self.data._output_token_ids[last_offset:] + return () def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 1e403637d2388..cf64af72a14a5 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -183,10 +183,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): return False # TODO: Add soft-tuning prompt adapter support - if self.prompt_adapter_config: - return False - - return True + return not self.prompt_adapter_config @torch.inference_mode() def execute_model( diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index ad4e2dc879d7b..89ccaba70e93c 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -104,13 +104,10 @@ def _should_collect_rejsample_metrics(self, now: float) -> bool: if self._rank != 0: return False - if (now - self._last_metrics_collect_time < - self._rejsample_metrics_collect_interval_s): - return False - return True + return now - self._last_metrics_collect_time >= 
self._rejsample_metrics_collect_interval_s # noqa: E501 def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: - """Copy rejection/typical-acceptance sampling metrics + """Copy rejection/typical-acceptance sampling metrics (number of accepted tokens, etc) to CPU asynchronously. Returns a CUDA event recording when the copy is complete. diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3c269bc10cdf8..1744935d624fb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -24,7 +24,7 @@ JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, RWConfig, - UltravoxConfig) + SolarConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file @@ -50,6 +50,7 @@ "exaone": ExaoneConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, + "solar": SolarConfig, "ultravox": UltravoxConfig, # Granite can be removed from here once we have upgraded to # transformers 4.45+ diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8381c5227584e..ea4fc8ad21f35 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -13,6 +13,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ @@ -27,6 +28,7 @@ "ExaoneConfig", "MLPSpeculatorConfig", "NemotronConfig", + "SolarConfig", "UltravoxConfig", # Granite can be removed from here once we have upgraded to # transformers 4.45+ diff --git a/vllm/transformers_utils/configs/solar.py b/vllm/transformers_utils/configs/solar.py new file mode 100644 index 0000000000000..d5113bf01695a --- /dev/null +++ b/vllm/transformers_utils/configs/solar.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Solar model configuration""" + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class SolarConfig(PretrainedConfig): + r""" + This is the configuration class to store + the configuration of a [`SolarModel`]. + It is used to instantiate an LLaMA model + according to the specified arguments, + defining the model architecture. + Instantiating a configuration with the + defaults will yield a similar + configuration to that of the LLaMA-7B. 
+ Configuration objects inherit from [`PretrainedConfig`] + and can be used to control the model outputs. + Read the documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. + Defines the number of different tokens + that can be represented by the `inputs_ids` + passed when calling [`SolarModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer + in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that + should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, + the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model + will use Multi Query Attention (MQA) + otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed + by meanpooling all the original heads within that group. + For more details checkout [this paper] + (https://arxiv.org/pdf/2305.13245.pdf). + If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) + in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Solar 1 supports up to 2048 tokens, + Solar 2 up to 4096, CodeSolar up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of + the truncated_normal_initializer for initializing + all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return + the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank + used during pretraining. + Please refer to [this + document](https://huggingface.co/docs/ + transformers/main/ + perf_train_gpu_many#tensor-parallelism) + to understand more about it. This value is + necessary to ensure exact reproducibility + of the pretraining results. + Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for + the RoPE embeddings. + Currently supports two scaling + strategies: linear and dynamic. + Their scaling factor must be a float greater than 1. + The expected format is + `{"type": strategy name, "factor": scaling factor}`. 
+ When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/ + dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking + API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value + and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj + layers in the MLP layers. + sliding_window (`int`, *optional*, defaults to 2047): + Sliding window attention window size. If not specified, + will default to `2047`. + ```python + >>> from transformers import SolarModel, SolarConfig + >>> # Initializing a Solar-pro style configuration + >>> configuration = SolarConfig() + >>> # Initializing a model from the Solar-pro style configuration + >>> model = SolarModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "solar" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + sliding_window=2047, + bskcn_1=None, + bskcn_2=None, + bskcn_3=None, + bskcn_4=None, + bskcn_tv=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.sliding_window = sliding_window + self.bskcn_1 = bskcn_1 if bskcn_1 is not None else [12, 20, 32, 44] + self.bskcn_2 = bskcn_2 if bskcn_2 is not None else [20, 32] + self.bskcn_3 = bskcn_3 if bskcn_3 is not None else [16, 24, 36, 48] + self.bskcn_4 = bskcn_4 if bskcn_4 is not None else [28, 40] + self.bskcn_tv = bskcn_tv if bskcn_tv is not None else [0.9, 0.8] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if (not isinstance(self.rope_scaling, dict) + or len(self.rope_scaling) != 2): + raise ValueError( + "`rope_scaling` must be a dictionary with two fields," + " `type` and `factor`, " + f"got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", + "dynamic", + ]: + raise ValueError(f"`rope_scaling`'s type field must be one of " + f"['linear', 'dynamic'], got {rope_scaling_type}") + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0): + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1," + f" got {rope_scaling_factor}") diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index ea1910ed20ec3..7a228a3efa6e8 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -165,10 +165,9 @@ def apply_chat_template(self, messages: List["ChatCompletionMessageParam"], tools: Optional[Dict[str, Any]] = None, **kwargs) -> List[int]: - assert tools is None, "`tools` are not yet supported." - request = ChatCompletionRequest( - messages=messages) # type: ignore[type-var] + request = ChatCompletionRequest(messages=messages, + tools=tools) # type: ignore[type-var] encoded = self.mistral.encode_chat_completion(request) # encode-decode to get clean prompt @@ -176,7 +175,8 @@ def apply_chat_template(self, def convert_tokens_to_string(self, tokens: List[str]) -> str: if isinstance(self.tokenizer, Tekkenizer): - return "".join(tokens) + return "".join(t for t in tokens + if t not in self.tokenizer._all_special_tokens) else: return self.tokenizer.decode(tokens) # type: ignore[arg-type] diff --git a/vllm/triton_utils/libentry.py b/vllm/triton_utils/libentry.py index ae00af44a048a..4335c7adfc13b 100644 --- a/vllm/triton_utils/libentry.py +++ b/vllm/triton_utils/libentry.py @@ -35,8 +35,8 @@ def key(self, spec_args, dns_args, const_args): dns_key = [ arg.dtype if hasattr( arg, "data_ptr") else type(arg) if not isinstance(arg, int) - else "i32" if -(2**31) <= arg and arg <= 2**31 - - 1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64" + else "i32" if arg >= -(2**31) and arg <= 2**31 - + 1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64" for arg in dns_args ] # const args passed by position diff --git a/vllm/triton_utils/sample.py b/vllm/triton_utils/sample.py deleted file mode 100644 index 401e4d28a3c99..0000000000000 --- a/vllm/triton_utils/sample.py +++ /dev/null @@ -1,13 +0,0 @@ -import math - -# This is a hardcoded limit in Triton (max block size). -MAX_TRITON_N_COLS = 131072 - - -def get_num_triton_sampler_splits(n_cols: int) -> int: - """Get the number of splits to use for Triton sampling. - - Triton has a limit on the number of columns it can handle, so we need to - split the tensor and call the kernel multiple times if it's too large. 
- """ - return math.ceil(n_cols / MAX_TRITON_N_COLS) diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 515e0a4d8abe7..7fadfd5dfffb4 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.connections import global_http_connection +from vllm.platforms import current_platform from vllm.version import __version__ as VLLM_VERSION _config_home = envs.VLLM_CONFIG_ROOT @@ -151,7 +152,7 @@ def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, extra_kvs: Dict[str, Any]) -> None: # Platform information - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): device_property = torch.cuda.get_device_properties(0) self.gpu_count = torch.cuda.device_count() self.gpu_type = device_property.name diff --git a/vllm/utils.py b/vllm/utils.py index aba243071b69a..060b387ec7834 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,6 +5,7 @@ import enum import gc import os +import random import socket import subprocess import sys @@ -12,6 +13,7 @@ import threading import uuid import warnings +import weakref from asyncio import FIRST_COMPLETED, ensure_future from functools import lru_cache, partial, wraps from platform import uname @@ -31,6 +33,7 @@ import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -70,10 +73,6 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not " - "currently supported with encoder/" - "decoder models.") - STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " "currently supported with encoder/" "decoder models.") @@ -97,7 +96,6 @@ "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, - "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU @@ -269,7 +267,7 @@ def clear(self): class PyObjectCache: - """Used to cache python objects to avoid object allocations + """Used to cache python objects to avoid object allocations across scheduler iterations. """ @@ -288,7 +286,7 @@ def _grow_cache(self): self._obj_cache.append(self._obj_builder()) def get_object(self): - """Returns a pre-allocated cached object. If there is not enough + """Returns a pre-allocated cached object. If there is not enough objects, then the cache size will double. """ if self._index >= len(self._obj_cache): @@ -377,6 +375,22 @@ def get_cpu_memory() -> int: return psutil.virtual_memory().total +def seed_everything(seed: int) -> None: + """ + Set the seed of each random module. 
+ + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + + if current_platform.is_cuda_alike(): + torch.cuda.manual_seed_all(seed) + + if is_xpu(): + torch.xpu.manual_seed_all(seed) + + def random_uuid() -> str: return str(uuid.uuid4().hex) @@ -638,9 +652,7 @@ def create_kv_caches_with_random_flash( seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) @@ -682,9 +694,7 @@ def create_kv_caches_with_random( f"Does not support key cache of type fp8 with head_size {head_size}" ) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -754,7 +764,7 @@ def __init__(self, device: Optional[torch.types.Device] = None): def current_memory_usage(self) -> float: # Return the memory usage in bytes. - if torch.cuda.is_available(): + if current_platform.is_cuda_alike(): torch.cuda.reset_peak_memory_stats(self.device) mem = torch.cuda.max_memory_allocated(self.device) elif is_xpu(): @@ -836,15 +846,6 @@ def async_tensor_h2d( return t.to(device=target_device, non_blocking=True) -def maybe_expand_dim(tensor: torch.Tensor, - target_dims: int, - size: int = 1) -> torch.Tensor: - """Expand the tensor to the target_dims.""" - if tensor.ndim < target_dims: - tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) - return tensor - - def get_dtype_size(dtype: torch.dtype) -> int: """Get the size of the data type in bytes.""" return torch.tensor([], dtype=dtype).element_size() @@ -1069,7 +1070,7 @@ def _cuda_device_count_stateless( def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. - + This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" @@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) +def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: + """Make an instance method that weakly references + its associated instance and no-ops once that + instance is collected.""" + ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined] + unbound = bound_method.__func__ # type: ignore[attr-defined] + + def weak_bound(*args, **kwargs) -> None: + if inst := ref(): + unbound(inst, *args, **kwargs) + + return weak_bound + + #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f: Callable[P, None]) -> Callable[P, None]: @@ -1121,10 +1136,10 @@ def parse_args(self, args=None, namespace=None): def _pull_args_from_config(args: List[str]) -> List[str]: """Method to pull arguments specified in the config file into the command-line args variable. - - The arguments in config file will be inserted between + + The arguments in config file will be inserted between the argument list. 
- + example: ```yaml port: 12323 @@ -1135,21 +1150,21 @@ def _pull_args_from_config(args: List[str]) -> List[str]: --config config.yaml -tp 2 $: args = [ "serve,chat,complete", - "facebook/opt-12B", - '--config', 'config.yaml', + "facebook/opt-12B", + '--config', 'config.yaml', '-tp', '2' ] $: args = [ "serve,chat,complete", - "facebook/opt-12B", - '--port', '12323', - '--tensor-parallel-size', '4', + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', '-tp', '2' ] ``` Please note how the config args are inserted after the sub command. - this way the order of priorities is maintained when these are args + this way the order of priorities is maintained when these are args parsed by super(). """ assert args.count( @@ -1175,7 +1190,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]: @staticmethod def _load_config_file(file_path: str) -> List[str]: - """Loads a yaml file and returns the key value pairs as a + """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml port: 12323 @@ -1186,7 +1201,7 @@ def _load_config_file(file_path: str) -> List[str]: '--port': '12323', '--tensor-parallel-size': '4' ] - + """ extension: str = file_path.split('.')[-1] diff --git a/vllm/version.py b/vllm/version.py index 1f492a24bf078..0ddc7fb99ad45 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -2,6 +2,7 @@ try: import vllm.commit_id + __commit__ = vllm.commit_id.__commit__ except Exception as e: warnings.warn(f"Failed to read commit hash:\n{e}", @@ -9,4 +10,4 @@ stacklevel=2) __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.6.1" +__version__ = "0.6.1.post2" diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index d6189d82d51d9..09dab0135f390 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -1,4 +1,5 @@ import dataclasses +import itertools from typing import Any, Dict, List, Optional, Tuple, Type, cast import torch @@ -24,7 +25,8 @@ from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + ModelInputForGPUWithSamplingMetadata, + _get_graph_batch_size) from vllm.worker.model_runner_base import ( _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict) @@ -178,7 +180,15 @@ def execute_model( raise ValueError("num_steps > 1 is not supported in " "EncoderDecoderModelRunner") - model_executable = self.model + if (model_input.attn_metadata is not None + and model_input.attn_metadata.prefill_metadata is None + and model_input.attn_metadata.decode_metadata.use_cuda_graph): + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[ + model_input.virtual_engine][graph_batch_size] + else: + model_executable = self.model seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -200,6 +210,9 @@ def execute_model( if not self.is_driver_worker: return [] + if model_input.async_callback is not None: + model_input.async_callback() + # Sample the next token. 
output: SamplerOutput = self.model.sample( logits=logits, @@ -231,14 +244,12 @@ def prepare_model_input( """ model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) - ( attn_metadata, encoder_input_tokens_tensor, encoder_input_positions_tensor, ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, model_input)) - # Inject attn_metadata encoder/cross-attention fields & # encoder input tokens/positions into model_input. # Frozen dataclass fields cannot be modified, so use @@ -437,11 +448,29 @@ def _prepare_encoder_model_input_tensors( cross_block_tables.append([] if ( cross_block_table is None) else cross_block_table) - # Convert cross-attention block tables to encoder input tensor + if (model_input.attn_metadata is not None + and model_input.attn_metadata.use_cuda_graph): + # We will be using CUDA graph replay for this decode. + max_len_of_block_table = self.get_max_block_per_batch() + batch_size = len(encoder_seq_lens) + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + cuda_graph_pad_size = graph_batch_size - batch_size + # extend the cross_block_tables and encoder_seq_lens to match + # the graph_batch_size. + cross_block_tables.extend([[] + for _ in range(cuda_graph_pad_size) + ]) + encoder_seq_lens.extend( + itertools.repeat(1, cuda_graph_pad_size)) + + else: + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + cross_block_tables = make_tensor_with_pad( cross_block_tables, - max_len=max( - len(block_table) for block_table in cross_block_tables), + max_len=max_len_of_block_table, pad=0, dtype=torch.int32, device=self.device, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index acb7bafefc204..e8c472df8b5fc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -243,6 +243,7 @@ def __init__( prefix_cache_hit: bool = False, reinit: bool = False, reinit_use_defaults: bool = False, + encoder_seq_len: int = 0, ): if reinit: assert len(self.seq_ids) == len(seq_ids) # type: ignore @@ -256,6 +257,7 @@ def __init__( self.block_tables = block_tables self.computed_block_nums = computed_block_nums self.n_seqs = n_seqs + self.encoder_seq_len = encoder_seq_len if reinit: if len(self.seq_ids) == 1 and reinit_use_defaults: @@ -702,6 +704,11 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): assert n_seqs == 1 self.decode_only = False + encoder_seq_len = 0 + + if self.runner.model_config.is_encoder_decoder_model: + encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() + inter_data = self.init_cached_inter_data( request_id=seq_group_metadata.request_id, seq_ids=seq_ids, @@ -709,7 +716,8 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): block_tables=seq_group_metadata.block_tables, computed_block_nums=seq_group_metadata.computed_block_nums, reinit=True, - reinit_use_defaults=True) + reinit_use_defaults=True, + encoder_seq_len=encoder_seq_len) self.inter_data_list.append(inter_data) @@ -719,11 +727,15 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): for per_seq_group_fn in self.per_seq_group_compute_fns: per_seq_group_fn(inter_data, seq_group_metadata) - def _use_captured_graph(self, batch_size: int, - max_decode_seq_len: int) -> bool: + def _use_captured_graph(self, + batch_size: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> bool: return (self.decode_only and not self.runner.model_config.enforce_eager - and batch_size <= 
self.runner.max_batchsize_to_capture - and max_decode_seq_len <= self.runner.max_seq_len_to_capture) + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.runner.max_seq_len_to_capture + and max_encoder_seq_len <= self.runner.max_seq_len_to_capture + and batch_size <= self.runner.max_batchsize_to_capture) def build(self) -> ModelInputForGPU: """Finalize the builder intermediate data and @@ -763,15 +775,18 @@ def build(self) -> ModelInputForGPU: input_positions.extend(cur_input_positions) seq_lens = [] + query_lens = [] max_decode_seq_len = 0 + max_encoder_seq_len = 0 for inter_data in self.inter_data_list: seq_lens.extend(inter_data.seq_lens) + query_lens.extend(inter_data.query_lens) if not inter_data.is_prompt: max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens)) - query_lens = [] - for inter_data in self.inter_data_list: - query_lens.extend(inter_data.query_lens) + if self.runner.model_config.is_encoder_decoder_model: + max_encoder_seq_len = max(max_encoder_seq_len, + inter_data.encoder_seq_len) # Mapping from request IDs to sequence IDs. Used for Jamba models # that manages the cache by itself. @@ -781,8 +796,10 @@ def build(self) -> ModelInputForGPU: } batch_size = len(input_tokens) - use_captured_graph = self._use_captured_graph(batch_size, - max_decode_seq_len) + use_captured_graph = self._use_captured_graph( + batch_size, + max_decode_seq_len, + max_encoder_seq_len=max_encoder_seq_len) # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. @@ -1064,10 +1081,13 @@ def load_model(self) -> None: "This may lead to less accurate results!") if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): + from vllm.compilation.backends import vllm_backend + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() or vllm_backend self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend="eager") + backend=backend) def save_sharded_state( self, @@ -1361,7 +1381,9 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: for batch_size in reversed(batch_size_capture_list): attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( - batch_size)) + batch_size, + is_encoder_decoder_model=self.model_config. + is_encoder_decoder_model)) if self.lora_config: lora_mapping = LoRAMapping( @@ -1377,10 +1399,10 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: ) self.set_active_prompt_adapters( set(), prompt_adapter_mapping) - graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name(), - self.attn_state.graph_clone(batch_size)) + self.attn_state.graph_clone(batch_size), + self.model_config.is_encoder_decoder_model) capture_inputs = { "input_ids": @@ -1417,6 +1439,12 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.model.get_seqlen_agnostic_capture_inputs( batch_size) }) + if self.model_config.is_encoder_decoder_model: + # add the additional inputs to capture for + # encoder-decoder models. + self._update_inputs_to_capture_for_enc_dec_model( + capture_inputs) + graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1427,6 +1455,24 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: # This usually takes < 10 seconds. 
logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
 
+    def _update_inputs_to_capture_for_enc_dec_model(self,
+                                                    capture_inputs: Dict[str,
+                                                                         Any]):
+        """
+        Updates the set of input tensors needed for CUDA graph capture in an
+        encoder-decoder model.
+
+        This method modifies the provided `capture_inputs` dictionary by
+        adding tensors specific to encoder-decoder models that
+        need to be captured for CUDA graph replay.
+        """
+        # During the decode phase encoder_input_ids and encoder_positions are
+        # unset. Do the same thing for graph capture.
+        capture_inputs["encoder_input_ids"] = torch.tensor(
+            [], dtype=torch.long).cuda()
+        capture_inputs["encoder_positions"] = torch.tensor(
+            [], dtype=torch.long).cuda()
+
     @property
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
@@ -1626,7 +1672,7 @@ def execute_model(
 class CUDAGraphRunner:
 
     def __init__(self, model: nn.Module, backend_name: str,
-                 attn_state: AttentionState):
+                 attn_state: AttentionState, is_encoder_decoder_model: bool):
         self.model = model
         self.backend_name = backend_name
         self.attn_state = attn_state
@@ -1635,6 +1681,7 @@ def __init__(self, model: nn.Module, backend_name: str,
         self.output_buffers: Dict[str, torch.Tensor] = {}
 
         self._graph: Optional[torch.cuda.CUDAGraph] = None
+        self._is_encoder_decoder_model = is_encoder_decoder_model
 
     @property
     def graph(self):
@@ -1668,8 +1715,9 @@ def capture(
             intermediate_tensors=intermediate_inputs,
             **kwargs,
         )
+        # Wait for the warm-up operations to finish before proceeding with
+        # graph capture.
        torch.cuda.synchronize()
-
         # Capture the graph.
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
@@ -1701,10 +1749,14 @@ def capture(
 
         # Save the input and output buffers.
         self.input_buffers = {
-            "input_ids": input_ids,
-            "positions": positions,
-            "kv_caches": kv_caches,
-            **self.attn_state.get_graph_input_buffers(attn_metadata),
+            "input_ids":
+            input_ids,
+            "positions":
+            positions,
+            "kv_caches":
+            kv_caches,
+            **self.attn_state.get_graph_input_buffers(
+                attn_metadata, self._is_encoder_decoder_model),
             **kwargs,
         }
         if intermediate_inputs is not None:
@@ -1734,8 +1786,8 @@ def forward(
         self.input_buffers["positions"].copy_(positions, non_blocking=True)
         self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                  non_blocking=True)
-        self.attn_state.prepare_graph_input_buffers(self.input_buffers,
-                                                    attn_metadata)
+        self.attn_state.prepare_graph_input_buffers(
+            self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
         if "seqlen_agnostic_capture_inputs" in self.input_buffers:
             self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                       **kwargs)
@@ -1749,6 +1801,12 @@ def forward(
             if key != "model_execute_time" and key != "model_forward_time":
                 self.input_buffers[key].copy_(intermediate_tensors[key],
                                               non_blocking=True)
+        if self._is_encoder_decoder_model:
+            self.input_buffers["encoder_input_ids"].copy_(
+                kwargs['encoder_input_ids'], non_blocking=True)
+            self.input_buffers["encoder_positions"].copy_(
+                kwargs['encoder_positions'], non_blocking=True)
+
         # Run the graph.
         self.graph.replay()
         # Return the output tensor.
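For readers following the `CUDAGraphRunner` changes above, the standard PyTorch capture/replay pattern they build on is summarized in the minimal sketch below. The toy model and the buffer names (`static_input`, `static_output`) are illustrative only and are not part of this patch.

```python
# Minimal, self-contained sketch of CUDA graph capture and replay
# (illustrative only; names here are not taken from the vLLM code base).
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()

    # Static buffer: a replayed graph always reads from the memory
    # addresses recorded at capture time, so inputs must be copied
    # into the exact tensor that was captured.
    static_input = torch.zeros(8, 16, device="cuda")

    # Warm up on a side stream before capture, then synchronize.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # Capture one forward pass into the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Replay: copy fresh data into the captured input buffer, replay,
    # then read the result out of the captured output buffer.
    new_batch = torch.randn(8, 16, device="cuda")
    static_input.copy_(new_batch, non_blocking=True)
    graph.replay()
    result = static_output.clone()
```

The `forward` changes above follow the same discipline: the new `encoder_input_ids` and `encoder_positions` values are copied into the captured `input_buffers` rather than passed in as fresh tensors, since a replayed graph only ever sees the buffers recorded during capture.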
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index 94d2507968382..975b88c0e79a2 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -3,11 +3,13 @@
 from abc import ABC, abstractmethod
 from datetime import datetime
 from functools import wraps
-from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
+                    Optional, Type, TypeVar)
 
 import torch
+from torch import is_tensor
 
+from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
@@ -17,6 +19,8 @@
     from vllm.attention.backends.abstract import AttentionBackend
     from vllm.model_executor import SamplingMetadata
 
+logger = init_logger(__name__)
+
 T = TypeVar('T', bound="BroadcastableModelInput")
 
 
@@ -113,6 +117,8 @@ def _wrapper(*args, **kwargs):
         except Exception as err:
             timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
             filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
+            logger.info("Writing input of failed execution to %s...",
+                        filename)
             with open(filename, "wb") as filep:
                 dumped_inputs = {
                     k: v
@@ -122,7 +128,19 @@ def _wrapper(*args, **kwargs):
                 for i, arg in enumerate(args):
                     if i not in (exclude_args or []):
                         dumped_inputs[f"arg_{i}"] = arg
+
+                # Only persist dtype and shape for kvcache tensors
+                # (can be way too big otherwise)
+                if (kv_caches := dumped_inputs.get("kv_caches")) \
+                    and isinstance(kv_caches, Iterable):
+                    dumped_inputs["kv_caches"] = [(t.dtype, t.shape)
+                                                  for t in kv_caches
+                                                  if is_tensor(t)]
+
                 pickle.dump(dumped_inputs, filep)
+            logger.info(
+                "Completed writing input of failed execution to %s.",
+                filename)
             raise type(err)(
                 f"Error in model execution (input dumped to {filename}): "
                 f"{str(err)}") from err
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index b900eb5a610ff..ebcafbbab119a 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -614,34 +614,66 @@ def _pythonize_sampler_output(
 
     frozen_model_input = model_input.frozen_model_input
     assert frozen_model_input.sampling_metadata is not None
+    sampling_metadata = frozen_model_input.sampling_metadata
     # samples generation should have been skipped
     assert not output.outputs
 
     pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
 
-    # CPU GPU sync
-    pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
+    # We guarantee output tensors are ready, so it is safe to
+    # pythonize the sampler output & obtain CPU-side logprobs.
+    #
+    # However, we should check whether logprobs pythonization may
+    # be skipped entirely, i.e. because no logprobs were requested
+    # or pythonization was not deferred. To that end,
+    #
+    # * `prompt_logprobs_are_requested_for_prefill` signals that
+    #   there are *any* prefill-phase requests which specify that
+    #   prompt logprobs should be returned.
+    #
+    # * `any_logprobs_are_requested` signals that there are any
+    #   requests which (1) specify that sample logprobs should be
+    #   returned, or (2) are in the prefill phase AND specify that
+    #   prompt logprobs should be returned.
+    #
+    # Later on, these flags cause adjustments to the pythonization
+    # process to accommodate logprobs.
+ + seq_groups = sampling_metadata.seq_groups + prompt_logprobs_are_requested_for_prefill = any([ + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt + for sg in seq_groups + ]) + any_logprobs_are_requested = ( + prompt_logprobs_are_requested_for_prefill + or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) + + if prompt_logprobs_are_requested_for_prefill: + # CPU GPU sync, after gathering *only* sampled tokens (since + # requesting prompt logprobs leads `sampled_token_ids` to + # include prompt token ids in addition to sampled token ids.) + sample_idx_tensor = torch.tensor( + [sdx for sg in seq_groups for sdx in sg.sample_indices]) + pinned_buffer = pinned_buffer.copy_( + sampled_token_ids[sample_idx_tensor, :], non_blocking=False) + else: + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, + non_blocking=False) # this will not block as the tensors are already on CPU samples_list = pinned_buffer.tolist() - sampling_metadata = frozen_model_input.sampling_metadata - skip_sampler_cpu_output = ( frozen_model_input.sampling_metadata.skip_sampler_cpu_output) - # We are guaranteed output tensors are ready, so it is safe to - # pythonize the sampler output & obtain CPU-side logprobs. - # - # However this computation may be skipped entirely - # if no pythonization was deferred. - seq_groups = sampling_metadata.seq_groups - logprobs_are_requested = any([ - sg.sampling_params.logprobs is not None - or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups - ]) + # *Don't* skip logprobs pythonization *if*: + # * Any requests require logprobs to be returned in this + # iteration AND + # * These requests are being scheduled in a fashion which + # defers pythonization (i.e. multi-step scheduling.) 
do_pythonize_logprobs = (skip_sampler_cpu_output - and logprobs_are_requested) + and any_logprobs_are_requested) ( prompt_logprobs, sample_logprobs, @@ -666,7 +698,7 @@ def _pythonize_sampler_output( prompt_logprobs[sgdx], sample_logprobs[sgdx], ) - elif logprobs_are_requested: + elif any_logprobs_are_requested: ( group_prompt_logprobs, group_sample_logprobs, @@ -696,7 +728,7 @@ def _pythonize_sampler_output( seq_output.parent_seq_id = seq_ids[parent_id] seq_output.output_token = next_token_id - if logprobs_are_requested: + if any_logprobs_are_requested: seq_output.logprobs = group_sample_logprobs[tdx] else: logprobs = next(iter(seq_output.logprobs.values())) @@ -714,7 +746,7 @@ def _pythonize_sampler_output( seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, (group_sample_logprobs[tdx] - if logprobs_are_requested else { + if any_logprobs_are_requested else { next_token_id: Logprob(logprob=float('inf'), rank=None, @@ -722,12 +754,12 @@ def _pythonize_sampler_output( }))) if cache is not None: completion_seq_group_output.prompt_logprobs = \ - group_prompt_logprobs if logprobs_are_requested else None + group_prompt_logprobs if any_logprobs_are_requested else None output.outputs.append(completion_seq_group_output) else: output.outputs.append( CompletionSequenceGroupOutput( seq_outputs, (group_prompt_logprobs - if logprobs_are_requested else None))) + if any_logprobs_are_requested else None))) assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_tpu_worker.py b/vllm/worker/multi_step_tpu_worker.py new file mode 100644 index 0000000000000..e654f7172b266 --- /dev/null +++ b/vllm/worker/multi_step_tpu_worker.py @@ -0,0 +1,105 @@ +import dataclasses +from typing import Dict, Optional, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict +from vllm.sequence import ExecuteModelRequest +from vllm.worker.tpu_model_runner import ModelInputForTPU +from vllm.worker.tpu_worker import TPUWorker +from vllm.worker.worker_base import WorkerInput + + +class MultiStepTPUWorker(TPUWorker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cached_model_input: Optional[ModelInputForTPU] = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]: + assert self.is_driver_worker + assert execute_model_req.virtual_engine == 0 + + is_first_multi_step = execute_model_req.is_first_multi_step + is_last_step = execute_model_req.is_last_step + if is_first_multi_step: + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + worker_input = dataclasses.replace( + worker_input, + num_steps=execute_model_req.num_lookahead_slots + 1) + model_input: ModelInputForTPU = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input = dataclasses.replace( + model_input, + async_callback=execute_model_req.async_callback) + else: + assert self.cached_model_input is not None + model_input = self.cached_model_input + worker_input = WorkerInput() + model_input = dataclasses.replace( + model_input, + is_first_multi_step=is_first_multi_step, + is_last_step=is_last_step) + + if self.do_metadata_broadcast: + if is_first_multi_step: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + 
model_input.as_broadcastable_tensor_dict())
+                broadcast_tensor_dict(broadcast_data, src=0)
+            else:
+                broadcast_data = {
+                    "is_first_multi_step": is_first_multi_step,
+                    "is_last_step": is_last_step,
+                }
+                broadcast_tensor_dict(broadcast_data, src=0)
+
+        # Returning an empty dict here to keep this compatible with
+        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
+        return model_input, worker_input, {}
+
+    def prepare_input(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None,
+    ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
+                                                            torch.Tensor]]]:
+        if self.is_driver_worker:
+            if execute_model_req is None:
+                if self.do_metadata_broadcast:
+                    broadcast_tensor_dict({}, src=0)
+                return None
+
+            model_input, worker_input, _ = self._get_driver_input_and_broadcast(
+                execute_model_req)
+            if model_input.is_first_multi_step:
+                self.cached_model_input = model_input
+            return model_input, worker_input, {}
+        else:
+            broadcast_data = broadcast_tensor_dict(src=0)
+            if not broadcast_data:
+                return None
+
+            if len(broadcast_data) == 2:
+                assert self.cached_model_input is not None
+                self.cached_model_input = dataclasses.replace(
+                    self.cached_model_input,
+                    is_first_multi_step=broadcast_data["is_first_multi_step"],
+                    is_last_step=broadcast_data["is_last_step"])
+                empty_worker_input = WorkerInput()
+                return self.cached_model_input, empty_worker_input, {}
+
+            worker_input = WorkerInput.from_broadcasted_tensor_dict(
+                broadcast_data)
+            model_input = (
+                self.model_runner.
+                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
+            self.cached_model_input = model_input
+            return model_input, worker_input, {}
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index db306bc743d3a..575769ca1aa4a 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -51,6 +51,8 @@ class ModelInputForTPU(ModelRunnerInputBase):
     num_samples: int
     best_of: List[int]
     seq_groups: List[List[int]]
+    is_first_multi_step: bool = True
+    is_last_step: bool = True
     virtual_engine: int = 0
     async_callback: Optional[Callable] = None
 
@@ -65,6 +67,8 @@ def as_broadcastable_tensor_dict(
             "num_samples": self.num_samples,
             "best_of": self.best_of,
             "seq_groups": self.seq_groups,
+            "is_first_multi_step": self.is_first_multi_step,
+            "is_last_step": self.is_last_step,
             "virtual_engine": self.virtual_engine,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
@@ -118,6 +122,7 @@ def __init__(
             self.block_size,
             False,
         )
+        self.cached_step_outputs: List[torch.Tensor] = []
 
     def load_model(self) -> None:
         self.device = self.device_config.device
@@ -518,97 +523,159 @@ def execute_model(
         num_steps: int = 1,
     ) -> List[SamplerOutput]:
         assert intermediate_tensors is None
-        if num_steps > 1:
-            raise ValueError(
-                "TPUModelRunner does not support multi-step execution.")
-
-        def _execute_model(*args):
-            """Move input args from CPU to device and execute the model."""
-
-            new_args = []
-            for arg in args:
-                if isinstance(arg, torch.Tensor):
-                    arg = arg.to(self.device)
-                elif isinstance(arg, AttentionMetadata):
-                    arg.slot_mapping = arg.slot_mapping.to(self.device)
-                    if getattr(arg, "block_tables", None) is not None:
-                        arg.block_tables = arg.block_tables.to(self.device)
-                    if getattr(arg, "context_lens", None) is not None:
-                        arg.context_lens = arg.context_lens.to(self.device)
-                new_args.append(arg)
-            return self.model(*new_args, is_prompt=is_prompt)
-
-        num_prefills = model_input.attn_metadata.num_prefills
-        is_prompt = num_prefills > 0
+        if not 
model_input.is_first_multi_step: + if not model_input.is_last_step: + return [] + + use_async_out_proc = model_input.async_callback is not None + sampler_outputs = [] + num_outputs = len(self.cached_step_outputs) + for i in range(num_outputs): + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + sampler_outputs.append(sampler_output) + + if i < num_outputs - 1 and use_async_out_proc: + assert model_input.async_callback is not None + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] + ctx.append_output( + outputs=[sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) + model_input.async_callback() + if use_async_out_proc: + return [sampler_outputs[-1]] + else: + return sampler_outputs + + is_prompt = model_input.attn_metadata.num_prefills > 0 if is_prompt: + assert num_steps == 1 # NOTE(woosuk): Since the FlashAttention kernel does not support # ragged inputs, we split the prompts into different batches and # process them separately. This is a temporary hack that should be # optimized by using SplashAttention. - next_token_ids = [] orig_slot_mapping = model_input.attn_metadata.slot_mapping batch_size = model_input.input_lens.shape[0] start_idx = 0 + next_token_ids = [] for i in range(batch_size): # Get the actual prefill_len. prefill_len = model_input.input_lens[i:i + 1].item() prefill_len = _get_padded_prefill_len(prefill_len) end_idx = start_idx + prefill_len - model_input.attn_metadata.slot_mapping = orig_slot_mapping[ - None, start_idx:end_idx] - model_input.attn_metadata.num_prefills = 1 - output_token_ids = _execute_model( - model_input.token_ids[None, start_idx:end_idx], - model_input.position_ids[None, start_idx:end_idx], - model_input.attn_metadata, model_input.input_lens[i:i + 1], - model_input.t[i:i + 1], model_input.p[i:i + 1], - model_input.num_samples, kv_caches) - if i == 0 and model_input.async_callback is not None: - model_input.async_callback() - # Retrieve the outputs to CPU. - next_token_ids += output_token_ids.cpu().tolist() + token_ids = model_input.token_ids[None, start_idx:end_idx].to( + self.device) + position_ids = model_input.position_ids[None, + start_idx:end_idx].to( + self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.num_prefills = 1 + attn_metadata.slot_mapping = orig_slot_mapping[ + None, start_idx:end_idx].to(self.device) + input_lens = model_input.input_lens[i:i + 1].to(self.device) + t = model_input.t[i:i + 1].to(self.device) + p = model_input.p[i:i + 1].to(self.device) + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=True) + next_token_ids.append(output_token_ids[0]) start_idx = end_idx - else: - # Execute the model. - output_token_ids = _execute_model( - model_input.token_ids, model_input.position_ids, - model_input.attn_metadata, model_input.input_lens, - model_input.t, model_input.p, model_input.num_samples, - kv_caches) + if model_input.async_callback is not None: model_input.async_callback() # Retrieve the outputs to CPU. - next_token_ids = output_token_ids.cpu().tolist() - - # NOTE(woosuk): Minimal code to construct the sampler outputs. - # The TPU backend does not reuse the sampler, since the TPU backend - # does not support the advanced sampling parameters such as logprobs. 
- zero_logprob = Logprob(0.0) - batch_idx = 0 - sampler_outputs = [] - for seq_group in model_input.seq_groups: - seq_ids = seq_group - seq_outputs = [] - if is_prompt: + next_token_ids = [ + output_token_ids.cpu().tolist() + for output_token_ids in next_token_ids + ] + + # NOTE(woosuk): Minimal code to construct the sampler outputs. + # The TPU backend does not reuse the sampler, since the TPU backend + # does not support advanced sampling parameters such as logprobs. + zero_logprob = Logprob(0.0) + sampler_outputs = [] + for i, seq_group in enumerate(model_input.seq_groups): + seq_ids = seq_group assert len(seq_ids) == 1 seq_id = seq_ids[0] - for i in range(model_input.best_of[batch_idx]): - next_token_id = next_token_ids[batch_idx][i] + seq_outputs = [] + for j in range(model_input.best_of[i]): + next_token_id = next_token_ids[i][j] seq_outputs.append( SequenceOutput(seq_id, next_token_id, {next_token_id: zero_logprob})) - batch_idx += 1 - else: - for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx] - seq_outputs.append( - SequenceOutput(seq_id, next_token_id, - {next_token_id: zero_logprob})) - batch_idx += 1 - sampler_outputs.append( - CompletionSequenceGroupOutput(seq_outputs, None)) - return [SamplerOutput(sampler_outputs)] + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return [SamplerOutput(sampler_outputs)] + else: + token_ids = model_input.token_ids.to(self.device) + position_ids = model_input.position_ids.to(self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( + self.device) + attn_metadata.block_tables = attn_metadata.block_tables.to( + self.device) + attn_metadata.context_lens = attn_metadata.context_lens.to( + self.device) + t = model_input.t.to(self.device) + p = model_input.p.to(self.device) + input_lens = model_input.input_lens.to(self.device) + for i in range(num_steps): + slot_mapping = attn_metadata.slot_mapping + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=False) + self.cached_step_outputs.append(output_token_ids) + + if i < num_steps - 1: + # Prepare the inputs for the next step. + token_ids = output_token_ids.unsqueeze(dim=1).int() + position_ids = position_ids + 1 + attn_metadata.context_lens = attn_metadata.context_lens + 1 + + block_tables = attn_metadata.block_tables + block_number = block_tables.gather( + 1, + position_ids.long() // self.block_size) + block_offset = position_ids % self.block_size + + is_padding = slot_mapping == _PAD_SLOT_ID + slot_mapping = block_number * self.block_size + block_offset + slot_mapping = slot_mapping.long() + slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, + slot_mapping) + attn_metadata.slot_mapping = slot_mapping + + if model_input.async_callback is not None: + model_input.async_callback() + + if num_steps > 1: + return [] + # Retrieve the outputs to CPU. 
+ next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + return [sampler_output] class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): @@ -756,3 +823,24 @@ def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) return logits + + +def _make_decode_output( + next_token_ids: List[int], + seq_groups: List[List[int]], +) -> SamplerOutput: + zero_logprob = Logprob(0.0) + sampler_outputs = [] + batch_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group + seq_outputs = [] + for seq_id in seq_ids: + next_token_id = next_token_ids[batch_idx] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + batch_idx += 1 + sampler_outputs.append(CompletionSequenceGroupOutput( + seq_outputs, None)) + return SamplerOutput(sampler_outputs) diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index d73023e8e1724..a58b80e4f2adb 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -47,10 +47,6 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) - if not enc_dec_mr.model_config.enforce_eager: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH']) - if enc_dec_mr.prompt_adapter_config is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 52092dc2dc291..3851843afc960 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -454,14 +454,20 @@ def init_worker_distributed_environment( def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. - if torch_dtype == torch.bfloat16: - compute_capability = current_platform.get_device_capability() - if compute_capability[0] < 8: + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " - f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}. " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.")
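As a closing illustration of the bfloat16 guard in the hunk above, the sketch below shows the same check written against torch's own device-query helpers instead of vLLM's `current_platform` layer; the function name `check_bfloat16_support` is introduced here for illustration and does not exist in the code base.

```python
# Simplified, standalone version of the bfloat16 capability check
# (illustrative only; queries torch directly rather than vLLM's
# current_platform abstraction, so it assumes a CUDA device exists).
import torch


def check_bfloat16_support(device_index: int = 0) -> None:
    if not torch.cuda.is_available():
        raise ValueError("No CUDA-capable device is available.")
    # get_device_capability() returns (major, minor), e.g. (8, 0) for A100.
    major, minor = torch.cuda.get_device_capability(device_index)
    gpu_name = torch.cuda.get_device_name(device_index)
    if (major, minor) < (8, 0):
        raise ValueError(
            "Bfloat16 is only supported on GPUs with compute capability "
            f"of at least 8.0. Your {gpu_name} GPU has compute capability "
            f"{major}.{minor}. You can use float16 instead by explicitly "
            "setting the `dtype` flag in CLI, for example: --dtype=half.")
```

The tuple comparison keeps the threshold readable; the patched worker code additionally handles platforms where a capability cannot be queried at all, which is what the `compute_str` branch above covers.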