diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index eec2a51e2f8fd..3db77d5f16022 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -9,8 +9,11 @@ steps: - image: badouralix/curl-jq command: - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + - wait + - label: "A100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: A100 plugins: @@ -41,20 +44,43 @@ steps: - name: devshm emptyDir: medium: Memory - # - label: "H100" - # agents: - # queue: H100 - # plugins: - # - docker#v5.11.0: - # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - # command: - # - bash - # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh - # mount-buildkite-agent: true - # propagate-environment: true - # ipc: host - # gpus: all - # environment: - # - VLLM_USAGE_SOURCE - # - HF_TOKEN + - label: "H200" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: H200 + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: 4,5,6,7 + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN + + - label: "H100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: H100 + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: all # see CUDA_VISIBLE_DEVICES for actual GPUs used + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 7cf05610b9953..9d3646e2f6a15 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -157,6 +157,18 @@ def results_to_json(latency, throughput, serving): throughput_results, serving_results) + for df in [latency_results, serving_results, throughput_results]: + if df.empty: + continue + + # Sort all dataframes by their respective "Test name" columns + df.sort_values(by="Test name", inplace=True) + + # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", + # we want to turn it into "8xGPUTYPE" + df["GPU"] = df["GPU"].apply( + lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}") + # get markdown tables latency_md_table = tabulate(latency_results, headers='keys', diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index d397b05cdff23..0d16a83781ab2 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -6,6 +6,7 @@ # Do not set -e, as the mixtral 8x22B model tends to crash occasionally # and we still want to see other benchmarking results even when mixtral crashes. +set -x set -o pipefail check_gpus() { @@ -85,11 +86,7 @@ kill_gpu_processes() { ps -aux lsof -t -i:8000 | xargs -r kill -9 - pkill -f pt_main_thread - # this line doesn't work now - # ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 - pkill -f python3 - pkill -f /usr/bin/python3 + pgrep python3 | xargs -r kill -9 # wait until GPU memory usage smaller than 1GB @@ -289,7 +286,7 @@ run_serving_tests() { # run the server echo "Running test case $test_name" echo "Server command: $server_command" - eval "$server_command" & + bash -c "$server_command" & server_pid=$! # wait until the server is alive @@ -322,7 +319,7 @@ run_serving_tests() { echo "Running test case $test_name with qps $qps" echo "Client command: $client_command" - eval "$client_command" + bash -c "$client_command" # record the benchmarking commands jq_output=$(jq -n \ diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 46e368ee0de1d..5e79984c9f7b9 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -85,7 +85,6 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_encoder_decoder_attn.py \ --ignore=kernels/test_flash_attn.py \ --ignore=kernels/test_flashinfer.py \ - --ignore=kernels/test_gguf.py \ --ignore=kernels/test_int8_quant.py \ --ignore=kernels/test_machete_gemm.py \ --ignore=kernels/test_mamba_ssm.py \ diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh index 5d7a0bff90963..bc06838d804ff 100755 --- a/.buildkite/run-cpu-test-ppc64le.sh +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -4,49 +4,11 @@ # It serves a sanity check for compilation and basic model usage. set -ex -# Try building the docker image -docker build -t cpu-test -f Dockerfile.ppc64le . - # Setup cleanup -remove_docker_container() { docker rm -f cpu-test || true; } +remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; } trap remove_docker_container EXIT remove_docker_container -# Run the image, setting --shm-size=4g for tensor parallel. -source /etc/environment -#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test -docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN="$HF_TOKEN" --name cpu-test cpu-test - -function cpu_tests() { - set -e - - # Run basic model test - docker exec cpu-test bash -c " - set -e - pip install pytest pytest-asyncio \ - decord einops librosa peft Pillow sentence-transformers soundfile \ - transformers_stream_generator matplotlib datamodel_code_generator - pip install torchvision --index-url https://download.pytorch.org/whl/cpu - pytest -v -s tests/models/decoder_only/language -m cpu_model - pytest -v -s tests/models/embedding/language -m cpu_model - pytest -v -s tests/models/encoder_decoder/language -m cpu_model - pytest -v -s tests/models/decoder_only/audio_language -m cpu_model - pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" - - # online inference - docker exec cpu-test bash -c " - set -e - python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & - timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 - python3 benchmarks/benchmark_serving.py \ - --backend vllm \ - --dataset-name random \ - --model facebook/opt-125m \ - --num-prompts 20 \ - --endpoint /v1/completions \ - --tokenizer facebook/opt-125m" -} +# Try building the docker image +docker build -t cpu-test -f Dockerfile.ppc64le . -# All of CPU tests are expected to be finished less than 25 mins. -export -f cpu_tests -timeout 25m bash -c "cpu_tests" diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 14756b5964aaf..4f1729d46dae2 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -13,26 +13,27 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test -f Dockerfile. numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . # Setup cleanup -remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; } +remove_docker_container() { docker rm -f cpu-test-"$NUMA_NODE" cpu-test-avx2-"$NUMA_NODE" || true; } trap remove_docker_container EXIT remove_docker_container # Run the image, setting --shm-size=4g for tensor parallel. docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ - --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test + --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \ - --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2 + --cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2-"$NUMA_NODE" cpu-test-avx2 function cpu_tests() { set -e + export NUMA_NODE=$2 # offline inference - docker exec cpu-test-avx2 bash -c " + docker exec cpu-test-avx2-"$NUMA_NODE" bash -c " set -e python3 examples/offline_inference.py" # Run basic model test - docker exec cpu-test bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pip install pytest pytest-asyncio \ decord einops librosa peft Pillow sentence-transformers soundfile \ @@ -45,20 +46,26 @@ function cpu_tests() { pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" # Run compressed-tensor test - docker exec cpu-test bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -s -v \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" # Run AWQ test - docker exec cpu-test bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -s -v \ tests/quantization/test_ipex_quant.py" + # Run chunked-prefill and prefix-cache test + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -s -v -k cpu_model \ + tests/basic_correctness/test_chunked_prefill.py" + # online inference - docker exec cpu-test bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c " set -e export VLLM_CPU_KVCACHE_SPACE=10 export VLLM_CPU_OMP_THREADS_BIND=$1 @@ -75,4 +82,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 25 mins. export -f cpu_tests -timeout 25m bash -c "cpu_tests $CORE_RANGE" +timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 24bf223fb12c0..f5591f1098534 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -9,8 +9,7 @@ # label(str): the name of the test. emoji allowed. # fast_check(bool): whether to run this on each commit on fastcheck pipeline. # fast_check_only(bool): run this test on fastcheck pipeline only -# nightly(bool): run this test in nightly pipeline only -# optional(bool): never run this test by default (i.e. need to unblock manually) +# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run. # command(str): the single command to run for tests. incompatible with commands. # commands(list): the list of commands to run for test. incompatbile with command. # mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] @@ -51,7 +50,9 @@ steps: - tests/multimodal - tests/test_utils - tests/worker + - tests/test_lazy_torch_compile.py commands: + - python3 test_lazy_torch_compile.py - pytest -v -s mq_llm_engine # MQLLMEngine - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py @@ -171,7 +172,7 @@ steps: - vllm/ - tests/v1 commands: - - pytest -v -s v1 + - VLLM_USE_V1=1 pytest -v -s v1 - label: Examples Test # 15min working_dir: "/vllm-workspace/examples" @@ -229,7 +230,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore lora/test_long_context.py lora/test_chatglm3_tp.py lora/test_llama_tp.py parallelism: 4 - label: "PyTorch Fullgraph Smoke Test" # 9min @@ -333,10 +334,9 @@ steps: commands: - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model - - pytest -v -s models/embedding/vision_language -m core_model - label: Language Models Test (Extended) # 50min - nightly: true + optional: true source_file_dependencies: - vllm/ - tests/models/decoder_only/language @@ -345,7 +345,6 @@ steps: commands: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' - - pytest -v -s models/embedding/vision_language -m 'not core_model' - label: Multi-Modal Models Test (Standard) # 26min #mirror_hardwares: [amd] @@ -358,11 +357,12 @@ steps: commands: - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' + - pytest -v -s models/embedding/vision_language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model - label: Multi-Modal Models Test (Extended) # 1h15m - nightly: true + optional: true source_file_dependencies: - vllm/ - tests/models/decoder_only/audio_language @@ -375,6 +375,7 @@ steps: # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' + - pytest -v -s models/embedding/vision_language -m 'not core_model' - pytest -v -s models/encoder_decoder/language -m 'not core_model' - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' @@ -429,6 +430,9 @@ steps: - vllm/model_executor/models/ - tests/distributed/ - vllm/compilation + - vllm/worker/worker_base.py + - vllm/worker/worker.py + - vllm/worker/model_runner.py commands: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py @@ -442,6 +446,7 @@ steps: - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py - label: Multi-step Tests (4 GPUs) # 36min working_dir: "/vllm-workspace/tests" @@ -474,18 +479,23 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py -- label: LoRA Long Context (Distributed) # 11min - # This test runs llama 13B, so it is required to run on 4 GPUs. +- label: LoRA TP Test (Distributed) num_gpus: 4 soft_fail: true source_file_dependencies: - vllm/lora - - tests/lora/test_long_context + - tests/lora commands: # FIXIT: find out which code initialize cuda before running the test # before the fix, we need to use spawn to test it - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # This test runs llama 13B, so it is required to run on 4 GPUs. - pytest -v -s -x lora/test_long_context.py + # There is some Tensor Parallelism related processing logic in LoRA that + # requires multi-GPU testing for validation. + - pytest -v -s -x lora/test_chatglm3_tp.py + - pytest -v -s -x lora/test_llama_tp.py + - label: Weight Loading Multiple GPU Test # 33min working_dir: "/vllm-workspace/tests" @@ -513,6 +523,7 @@ steps: - label: Distributed Tests (A100) # optional gpu: a100 + optional: true num_gpus: 4 source_file_dependencies: - vllm/ @@ -526,6 +537,7 @@ steps: - label: LM Eval Large Models # optional gpu: a100 + optional: true num_gpus: 4 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: diff --git a/.buildkite/upload-wheels.sh b/.buildkite/upload-wheels.sh index 541b395eddbe7..7345dd4e66b29 100644 --- a/.buildkite/upload-wheels.sh +++ b/.buildkite/upload-wheels.sh @@ -25,7 +25,12 @@ echo "Version: $version" # If the version contains "dev", rename it to v1.0.0.dev for consistency if [[ $version == *dev* ]]; then - new_version="1.0.0.dev" + suffix="${version##*.}" + if [[ $suffix == cu* ]]; then + new_version="1.0.0.dev+${suffix}" + else + new_version="1.0.0.dev" + fi new_wheel="${wheel/$version/$new_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 71f4e520135d4..d1f6105a47166 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,2 +1,2 @@ github: [vllm-project] -open_collective: [vllm] +open_collective: vllm diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 4f54eea564ecb..683b70cd89989 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -15,6 +15,8 @@ updates: allow: - dependency-type: "all" ignore: + - dependency-name: "*" + update-types: ["version-update:semver-patch"] - dependency-name: "torch" - dependency-name: "torchvision" - dependency-name: "xformers" @@ -24,9 +26,6 @@ updates: - dependency-name: "ray[adag]" - dependency-name: "lm-eval" groups: - patch-update: - applies-to: version-updates - update-types: ["patch"] minor-update: applies-to: version-updates update-types: ["minor"] diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml index 68149d2dc019f..67454c40d4b53 100644 --- a/.github/workflows/clang-format.yml +++ b/.github/workflows/clang-format.yml @@ -6,6 +6,7 @@ on: push: branches: - main + - develop paths: - '**/*.h' - '**/*.cpp' @@ -15,6 +16,7 @@ on: pull_request: branches: - main + - develop paths: - '**/*.h' - '**/*.cpp' diff --git a/.github/workflows/png-lint.yml b/.github/workflows/png-lint.yml new file mode 100644 index 0000000000000..4932af943a07b --- /dev/null +++ b/.github/workflows/png-lint.yml @@ -0,0 +1,37 @@ +name: Lint PNG exports from excalidraw +on: + push: + branches: + - "main" + paths: + - '*.excalidraw.png' + - '.github/workflows/png-lint.yml' + pull_request: + branches: + - "main" + paths: + - '*.excalidraw.png' + - '.github/workflows/png-lint.yml' + +env: + LC_ALL: en_US.UTF-8 + +defaults: + run: + shell: bash + +permissions: + contents: read + +jobs: + actionlint: + runs-on: ubuntu-latest + steps: + - name: "Checkout" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + + - name: "Run png-lint.sh to check excalidraw exported images" + run: | + tools/png-lint.sh diff --git a/.github/workflows/sphinx-lint.yml b/.github/workflows/sphinx-lint.yml new file mode 100644 index 0000000000000..e0bb24276a653 --- /dev/null +++ b/.github/workflows/sphinx-lint.yml @@ -0,0 +1,32 @@ +name: Lint documentation + +on: + push: + branches: + - main + paths: + - "docs/**" + pull_request: + branches: + - main + paths: + - "docs/**" + +jobs: + sphinx-lint: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-lint.txt + - name: Linting docs + run: tools/sphinx-lint.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 440588241d6a0..38dcf9a591bff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,7 +34,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") +set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") @@ -230,6 +230,7 @@ set(VLLM_EXT_SRC "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" + "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" "csrc/torch_bindings.cpp") @@ -240,7 +241,19 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") - FetchContent_Declare( + # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided + if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) + set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) + endif() + + if(VLLM_CUTLASS_SRC_DIR) + if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) + get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) + endif() + message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") + FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) + else() + FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git GIT_TAG v3.5.1 @@ -250,7 +263,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE GIT_SHALLOW TRUE - ) + ) + endif() FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC @@ -258,7 +272,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/gguf/gguf_kernel.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") @@ -270,7 +283,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" ${CUDA_ARCHS}) + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS}) if (MARLIN_ARCHS) set(MARLIN_SRCS "csrc/quantization/fp8/fp8_marlin.cu" @@ -321,8 +334,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. - cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.6;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) @@ -453,7 +466,7 @@ set_gencode_flags_for_srcs( CUDA_ARCHS "${CUDA_ARCHS}") if(VLLM_GPU_LANG STREQUAL "CUDA") - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) set(MARLIN_MOE_SRC "csrc/moe/marlin_kernels/marlin_moe_kernel.h" @@ -569,7 +582,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 5259c586c403a4e4d8bf69973c159b40cc346fb9 + GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/Dockerfile b/Dockerfile index 220dbe26712ec..682f046d4b6ec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -191,6 +191,10 @@ ADD . /vllm-workspace/ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt +# install development dependencies (for testing) +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install -e tests/vllm_test_utils + # enable fast downloads from hf (for testing) RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install hf_transfer diff --git a/Dockerfile.arm b/Dockerfile.arm new file mode 100644 index 0000000000000..093ee2209222f --- /dev/null +++ b/Dockerfile.arm @@ -0,0 +1,62 @@ +# This vLLM Dockerfile is used to construct an image that can build and run vLLM on ARM CPU platform. + +FROM ubuntu:22.04 AS cpu-test-arm + +ENV CCACHE_DIR=/root/.cache/ccache + +ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update -y \ + && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ + && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 + +# tcmalloc provides better memory allocation efficiency, e.g., holding memory in caches to speed up access of commonly-used objects. +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install py-cpuinfo # Use this to gather CPU info and optimize based on ARM Neoverse cores + +# Set LD_PRELOAD for tcmalloc on ARM +ENV LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libtcmalloc_minimal.so.4" + +RUN echo 'ulimit -c 0' >> ~/.bashrc + +WORKDIR /workspace + +ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" +ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \ + pip install --upgrade pip && \ + pip install -r requirements-build.txt + +FROM cpu-test-arm AS build + +WORKDIR /workspace/vllm + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \ + --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \ + pip install -v -r requirements-cpu.txt + +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi + +# Disabling AVX512 specific optimizations for ARM +ARG VLLM_CPU_DISABLE_AVX512="true" +ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ + pip install dist/*.whl && \ + rm -rf dist + +WORKDIR /workspace/ + +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] \ No newline at end of file diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 287b4958da4e5..ebe226cf6d148 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -16,7 +16,7 @@ RUN --mount=type=cache,target=/var/cache/apt \ # intel-openmp provides additional performance improvement vs. openmp # tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects. RUN --mount=type=cache,target=/root/.cache/pip \ - pip install intel-openmp + pip install intel-openmp==2025.0.1 ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so" @@ -62,4 +62,8 @@ WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks +# install development dependencies (for testing) +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -e tests/vllm_test_utils + ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.hpu b/Dockerfile.hpu index d18fc016387bf..87e0c1a6a934e 100644 --- a/Dockerfile.hpu +++ b/Dockerfile.hpu @@ -11,6 +11,9 @@ ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils + WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 2143315d2a078..76dbd4c04d3f3 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -38,4 +38,7 @@ ENV VLLM_TARGET_DEVICE neuron RUN --mount=type=bind,source=.git,target=.git \ pip install --no-build-isolation -v -e . +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils + CMD ["/bin/bash"] diff --git a/Dockerfile.openvino b/Dockerfile.openvino index a05ff452cd36e..8bd188ffde408 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -22,4 +22,7 @@ RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVIC COPY examples/ /workspace/examples COPY benchmarks/ /workspace/benchmarks +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils + CMD ["/bin/bash"] diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index b19c6ddec7948..971248577983f 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -29,6 +29,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=bind,source=.git,target=.git \ VLLM_TARGET_DEVICE=cpu python3 setup.py install +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils + WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks diff --git a/Dockerfile.rocm b/Dockerfile.rocm index a9594833c3c99..d46b5f3beaa5a 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -101,5 +101,8 @@ ENV TOKENIZERS_PARALLELISM=false # Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils + CMD ["/bin/bash"] diff --git a/Dockerfile.tpu b/Dockerfile.tpu index 0a507b6ecdf60..b617932a85b47 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -22,4 +22,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ -r requirements-tpu.txt RUN python3 setup.py develop +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils + CMD ["/bin/bash"] diff --git a/Dockerfile.xpu b/Dockerfile.xpu index 63bc682770422..a374f20d7d949 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -64,5 +64,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ENV VLLM_USAGE_SOURCE production-docker-image \ TRITON_XPU_PROFILE 1 - +# install development dependencies (for testing) +RUN python3 -m pip install -e tests/vllm_test_utils ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/README.md b/README.md index 0ef073210d070..cfeb24cbb5823 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,9 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 -- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing). +- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! -- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/sessioncatalog?tab.day=20241001&search.sessiontracks=1719251906298001uzJ2) from other vLLM contributors and users! +- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! - [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). - [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 25c8b1bbf3e22..c3fed56e8a956 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -54,6 +54,7 @@ async def async_request_tgi( "do_sample": True, "temperature": 0.01, # TGI does not accept 0.0 temperature. "top_p": 0.99, # TGI does not accept 1.0 top_p. + "truncate": request_func_input.prompt_len, # TGI does not accept ignore_eos flag. } payload = { diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 6d33096ca1d11..5e9381f712e10 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -54,13 +54,30 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): print(f"cost time {end_time - start_time}") -def sample_requests( +@dataclasses.dataclass +class Request: + prompt: str + prompt_len: int + output_len: int + + +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str: + vocab = tokenizer.get_vocab() + # Remove the special tokens. + vocab = { + k: v + for k, v in vocab.items() if k not in tokenizer.all_special_ids + } + return random.choices(list(vocab.values()), k=length) + + +def sample_requests_from_dataset( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, input_length_range: Tuple[int, int], fixed_output_len: Optional[int], -) -> List[Tuple[str, int, int]]: +) -> List[Request]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -77,31 +94,55 @@ def sample_requests( random.shuffle(dataset) min_len, max_len = input_length_range + assert min_len >= 0 and max_len >= min_len, "input_length_range too small" # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] + filtered_requests: List[Request] = [] + for i in range(len(dataset)): - if len(filtered_dataset) == num_requests: + if len(filtered_requests) == num_requests: break # Tokenize the prompts and completions. - prompt = dataset[i][0] - prompt_token_ids = tokenizer(prompt).input_ids + prompt_token_ids = tokenizer(dataset[i][0]).input_ids + prompt = tokenizer.decode(prompt_token_ids) completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: - # Prune too short sequences. - continue + output_len = (len(completion_token_ids) + if fixed_output_len is None else fixed_output_len) if min_len <= prompt_len <= max_len: - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_requests.append(Request(prompt, prompt_len, output_len)) + + return filtered_requests + + +def sample_requests_from_random( + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: Tuple[int, int], + fixed_output_len: Optional[int], + prefix_len: int, +) -> List[Request]: - return filtered_dataset + requests = [] + prefix_token_ids = sample_tokens(tokenizer, prefix_len) + min_len, max_len = input_length_range + + for i in range(num_requests): + unique_part_token_ids = sample_tokens( + tokenizer, + random.randint(min_len - prefix_len, max_len - prefix_len)) + prompt_token_ids = prefix_token_ids + unique_part_token_ids + prompt = tokenizer.decode(prompt_token_ids) + prompt_len = len(prompt_token_ids) + assert (min_len <= prompt_len <= max_len + ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + requests.append(Request(prompt, prompt_len, fixed_output_len)) + return requests -def repeat_and_sort_requests(requests: List[Tuple[str, int, int]], +def repeat_and_sort_requests(requests: List[Request], repeat_count: int, sort: bool = False) -> List[str]: repeated_requests = requests * repeat_count @@ -109,7 +150,7 @@ def repeat_and_sort_requests(requests: List[Tuple[str, int, int]], repeated_requests.sort(key=lambda x: x[1]) else: random.shuffle(repeated_requests) - return [req[0] for req in repeated_requests] + return [req.prompt for req in repeated_requests] def main(args): @@ -117,9 +158,12 @@ def main(args): input_length_range = tuple(map(int, args.input_length_range.split(':'))) random.seed(args.seed) if args.dataset_path is not None: - print(f"Start to sample {args.num_prompts} prompts" + if args.prefix_len > 0: + raise ValueError("prefix-len is not supported when " + "dataset-path is provided.") + print(f"Start to sample {args.num_prompts} prompts " f"from {args.dataset_path}") - filtered_datasets = sample_requests( + filtered_requests = sample_requests_from_dataset( dataset_path=args.dataset_path, num_requests=args.num_prompts, tokenizer=tokenizer, @@ -127,9 +171,22 @@ def main(args): fixed_output_len=args.output_len, ) else: - prompt_len = len(tokenizer(PROMPT).input_ids) - filtered_datasets = [(PROMPT, prompt_len, args.output_len) - ] * args.num_prompts + print(f"Start to sample {args.num_prompts} prompts from random") + filtered_requests = sample_requests_from_random( + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + prefix_len=args.prefix_len, + ) + + # Print some helpful stats of the requests. + print(f"Sampled {len(filtered_requests)} requests.") + prompt_lens = [req.prompt_len for req in filtered_requests] + print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}") + print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}") + print(f"Min Prompt Length: {min(prompt_lens)}") + print(f"Max Prompt Length: {max(prompt_lens)}") engine_args = EngineArgs.from_cli_args(args) @@ -137,8 +194,8 @@ def main(args): sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) - print("Testing filtered datasets") - prompts = repeat_and_sort_requests(filtered_datasets, + print("Testing filtered requests") + prompts = repeat_and_sort_requests(filtered_requests, repeat_count=args.repeat_count, sort=args.sort) @@ -161,20 +218,29 @@ def main(args): parser.add_argument('--output-len', type=int, default=10) parser.add_argument('--num-prompts', type=int, - default=1, + required=True, help="Number of the prompts sampled from dataset") parser.add_argument('--repeat-count', type=int, - default=100, + default=1, help='Number of times to repeat each prompt') parser.add_argument('--sort', action='store_true', help='Sort prompts by input length') parser.add_argument('--input-length-range', type=str, - default='128:256', + required=True, help='Range of input lengths for sampling prompts,' 'specified as "min:max" (e.g., "128:256").') + parser.add_argument( + "--prefix-len", + type=int, + default=0, + help="Specifies the length of a common prefix to be " + "added to the input prompt. The input-length-range will " + "subtract this length when filtering prompts. Only used " + "when dataset-path is not provided.", + ) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index e9fc037a46965..3256692142c5e 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -199,6 +199,56 @@ def sample_sonnet_requests( return sampled_requests +def sample_mmmu_pro_vision_requests( + dataset, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + sampled_requests: List[Tuple[str, int, int, Dict[str, + Collection[str]]]] = [] + for data in dataset: + if len(sampled_requests) == num_requests: + break + + # MMMU-Pro vision direct prompt + # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5 + prompt = ( + "Answer with the option letter from the given choices directly. " + "The last line of your response should be of the following " + "format: 'Answer: $LETTER' (without quotes) where LETTER is one of " + "options.") + + prompt_token_ids = tokenizer(prompt).input_ids + if fixed_output_len is None: + # Default max output len is set to 128 + print("--hf-output-len is not provided. Using default value 128.") + fixed_output_len = 128 + + prompt_len = len(prompt_token_ids) + output_len = fixed_output_len + + assert isinstance( + data["image"], + Image), ("Input image format must be `PIL.Image.Image`, " + f"given {type(data['image'])}.") + image: Image = data["image"] + image = image.convert("RGB") + image_data = io.BytesIO() + image.save(image_data, format='JPEG') + image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") + mm_content = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + sampled_requests.append((prompt, prompt_len, output_len, mm_content)) + + return sampled_requests + + def sample_hf_requests( dataset_path: str, dataset_subset: str, @@ -208,6 +258,21 @@ def sample_hf_requests( random_seed: int, fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + + # Special case for MMMU-Pro vision dataset + if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision': + assert dataset_split == "test" + dataset = load_dataset(dataset_path, + name=dataset_subset, + split=dataset_split, + streaming=True) + assert "image" in dataset.features, ( + "MMMU/MMMU_Pro vision dataset must have 'image' column.") + filter_func = lambda x: isinstance(x["image"], Image) + dataset = dataset.shuffle(seed=random_seed).filter(filter_func) + return sample_mmmu_pro_vision_requests(dataset, num_requests, + tokenizer, fixed_output_len) + dataset = load_dataset(dataset_path, name=dataset_subset, split=dataset_split, diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh new file mode 100644 index 0000000000000..2924ea4a49f54 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -0,0 +1,144 @@ +#!/bin/bash + +# benchmark the overhead of disaggregated prefill. +# methodology: +# - send all request to prefill vLLM instance. It will buffer KV cache. +# - then send all request to decode instance. +# - The TTFT of decode instance is the overhead. + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill -f pt_main_thread + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=10 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + wait_for_server 8100 + wait_for_server 8200 + + # let the prefill instance finish prefill + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate "inf" + + + # send the request to decode. + # The TTFT of this command will be the overhead of disagg prefill impl. + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate "$qps" + kill_gpu_processes + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=1 + default_output_len=1 + benchmark $default_qps $default_output_len + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh new file mode 100644 index 0000000000000..d8d9e976dce76 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -0,0 +1,164 @@ +#!/bin/bash + +# Requirement: 8x H100 GPUs. + + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 + for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done + sleep 1 +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +launch_chunked_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + wait_for_server 8100 + wait_for_server 8200 + python3 round_robin_proxy.py & + sleep 1 +} + + +launch_disagg_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + wait_for_server 8100 + wait_for_server 8200 + python3 disagg_prefill_proxy_server.py & + sleep 1 +} + + +benchmark() { + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=100 + qps=$1 + prefix_len=50 + input_len=1024 + output_len=$2 + tag=$3 + + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename "$tag"-qps-"$qps".json \ + --request-rate "$qps" + + sleep 2 + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx matplotlib aiohttp + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt so that we can sample 2048 tokens for input + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_output_len=6 + + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + launch_chunked_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len chunked_prefill + done + kill_gpu_processes + + launch_disagg_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len disagg_prefill + done + kill_gpu_processes + + python3 visualize_benchmark_results.py + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py new file mode 100644 index 0000000000000..4058b1c0a3b79 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -0,0 +1,61 @@ +import os + +import aiohttp +from quart import Quart, make_response, request + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + + +async def forward_request(url, data): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + async with session.post(url=url, json=data, + headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked( + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request['max_tokens'] = 1 + + # finish prefill + async for _ in forward_request('http://localhost:8100/v1/completions', + prefill_request): + continue + + # return decode + generator = forward_request('http://localhost:8200/v1/completions', + original_request_data) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == '__main__': + app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py new file mode 100644 index 0000000000000..6eb5f63980070 --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -0,0 +1,60 @@ +import asyncio +import itertools + +import aiohttp +from aiohttp import web + + +class RoundRobinProxy: + + def __init__(self, target_ports): + self.target_ports = target_ports + self.port_cycle = itertools.cycle(self.target_ports) + + async def handle_request(self, request): + target_port = next(self.port_cycle) + target_url = f"http://localhost:{target_port}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + try: + # Forward the request + async with session.request( + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, + ) as response: + # Start sending the response + resp = web.StreamResponse(status=response.status, + headers=response.headers) + await resp.prepare(request) + + # Stream the response content + async for chunk in response.content.iter_any(): + await resp.write(chunk) + + await resp.write_eof() + return resp + + except Exception as e: + return web.Response(text=f"Error: {str(e)}", status=500) + + +async def main(): + proxy = RoundRobinProxy([8100, 8200]) + app = web.Application() + app.router.add_route('*', '/{path:.*}', proxy.handle_request) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8000) + await site.start() + + print("Proxy server started on http://localhost:8000") + + # Keep the server running + await asyncio.Event().wait() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py new file mode 100644 index 0000000000000..e59d8bb0e6c8c --- /dev/null +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -0,0 +1,46 @@ +import json + +import matplotlib.pyplot as plt +import pandas as pd + +if __name__ == "__main__": + + data = [] + for name in ['disagg_prefill', 'chunked_prefill']: + for qps in [2, 4, 6, 8]: + with open(f"results/{name}-qps-{qps}.json") as f: + x = json.load(f) + x['name'] = name + x['qps'] = qps + data.append(x) + + df = pd.DataFrame.from_dict(data) + dis_df = df[df['name'] == 'disagg_prefill'] + chu_df = df[df['name'] == 'chunked_prefill'] + + plt.style.use('bmh') + plt.rcParams['font.size'] = 20 + + for key in [ + 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', + 'median_itl_ms', 'p99_itl_ms' + ]: + + fig, ax = plt.subplots(figsize=(11, 7)) + plt.plot(dis_df['qps'], + dis_df[key], + label='disagg_prefill', + marker='o', + linewidth=4) + plt.plot(chu_df['qps'], + chu_df[key], + label='chunked_prefill', + marker='o', + linewidth=4) + ax.legend() + + ax.set_xlabel('QPS') + ax.set_ylabel(key) + ax.set_ylim(bottom=0) + fig.savefig(f'results/{key}.png') + plt.close(fig) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 665b50bf18cf0..46bab74ae8adf 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -2,8 +2,10 @@ import copy import itertools import math +import os import pickle as pkl import time +from dataclasses import dataclass from itertools import product from typing import Callable, Iterable, List, Optional, Tuple @@ -15,11 +17,12 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales) + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, + marlin_zero_points) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, pack_rows, quantize_weights) + pack_rows, quantize_weights) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser @@ -27,149 +30,350 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] DEFAULT_TP_SIZES = [1] +NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False) + +if NVTX_PROFILE: + import nvtx + + +def terse_type_name(dt): + return { + torch.bfloat16: "bf16", + torch.float16: "fp16", + torch.int8: "int8", + torch.float8_e4m3fn: "fp8", + torch.bfloat16: "bf16", + torch.float: "float", + torch.int: "int", + }[dt] + + +@dataclass +class BenchmarkTensors: + w_ref: torch.Tensor + a: torch.Tensor + + w_q: torch.Tensor + group_size: Optional[int] + wtype: ScalarType + w_g_s: torch.Tensor + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +def rand_data(shape, dtype=torch.float16, scale=1): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype) + else: + return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") + + +def quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) -def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor: w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # make col major - return ops.machete_prepack_B(w_q, wtype) + return w_ref, w_q, w_s, w_zp -def make_bench_tensors( - atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, - k: int -) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor, - torch.tensor]]]: - assert wtype.is_integer(), "TODO: support floating point weights" +def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> List[BenchmarkTensors]: + m, n, k = shape # we want to make sure that weights don't fit into L2 cache between runs so # we construct enough weights to exceed L2 cache, which is 50mb on a H100 # so we target total weight size > 2*50mb - num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits)) - - a = torch.randn((m, k), device="cuda", dtype=atype) * 5 - weights = [ - torch.randn((k, n), device="cuda", dtype=atype) - for _ in range(num_weights) - ] - quanitized_weights = [ - quantize_weights(w, wtype, group_size) for w in weights - ] - - return a, quanitized_weights + num_weights = math.ceil(2 * 50 * 1024**2 * 8 / + (k * n * types.weight_type.size_bits)) + + a = rand_data((m, k), types.act_type, scale=5) + + benchmark_tensors: List[BenchmarkTensors] = [] + for _ in range(num_weights): + w = rand_data((k, n), types.act_type, scale=5) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + types.group_zero_type is not None) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + w_ref = w_ref.to(torch.float32) + + w_ch_s = None if types.channel_scale_type is None else\ + rand_data((n,), types.channel_scale_type) + w_tok_s = None if types.token_scale_type is None else\ + rand_data((m,), types.token_scale_type) + + benchmark_tensors.append( + BenchmarkTensors(w_ref=w_ref, + a=a, + w_q=w_q_packed, + wtype=types.weight_type, + w_g_s=w_s, + w_g_zp=w_zp, + group_size=group_size, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s)) + + return benchmark_tensors + + +def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable: + a = bt.a + w = bt.w_ref.to(bt.a.dtype) # use float reference tensor + if a.dtype not in [torch.float16, torch.bfloat16]: + a = a.to(torch.float16) + w = w.to(torch.float16) + return lambda: torch.matmul(a, w) + + +def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: + if bt.w_ch_s is not None and bt.w_tok_s is not None: + scale_a = bt.w_tok_s.to(torch.float32) + scale_b = bt.w_ch_s.to(torch.float32) + else: + scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() + return lambda: ops.cutlass_scaled_mm( + bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) + + +def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: + device = bt.a.device + + workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + if bt.w_g_zp is None: + w_zp = torch.empty(0, dtype=torch.int, device=device) + else: + w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.wtype.size_bits) + + if bt.group_size is None: + w_s = torch.tensor([], device="cuda", dtype=torch.half) + else: + w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.group_size) + + sort_indices = torch.empty(0, dtype=torch.int, device=device) + g_idx = torch.empty(0, dtype=torch.int, device=device) + w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.wtype.size_bits) + + if bt.a.dtype.is_floating_point: + assert bt.w_ch_s is None + assert bt.w_tok_s is None + assert bt.group_size is not None + + fn = lambda: ops.gptq_marlin_gemm(a=bt.a, + b_q_weight=w_q, + b_scales=w_s, + b_zeros=w_zp, + g_idx=g_idx, + perm=sort_indices, + workspace=workspace.scratch, + b_q_type=bt.wtype, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + is_k_full=True, + is_zp_float=False) + else: + assert bt.a.dtype == torch.int8 + assert bt.wtype == scalar_types.uint4b8 + + if bt.w_ch_s is not None: + s_ch = bt.w_ch_s.to(torch.float32) + else: + s_ch = torch.ones(bt.w_ref.shape[1], + dtype=torch.float32, + device=device) + + if bt.w_tok_s is not None: + s_tok = bt.w_tok_s.to(torch.float32) + else: + s_tok = torch.ones(bt.a.shape[0], + dtype=torch.float32, + device=device) + + fn = lambda: ops.marlin_qqq_gemm(a=bt.a, + b_q_weight=w_q, + s_group=w_s, + s_tok=s_tok, + s_ch=s_ch, + workspace=workspace.scratch, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0]) + + return fn + + +def machete_create_bench_fn(bt: BenchmarkTensors, + out_type=torch.dtype, + schedule=None) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, + None if bt.w_g_s is None else bt.w_g_s.dtype) + + w_g_zp = bt.w_g_zp + if w_g_zp is not None: + w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype)) + + return lambda: ops.machete_mm( + a=bt.a, + b_q=bt.w_q, + b_type=bt.wtype, + b_group_scales=bt.w_g_s, + b_group_zeros=w_g_zp, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + out_type=out_type, + schedule=schedule, + ) # impl - # bench -def bench_fn(label: str, sub_label: str, description: str, - fn: Callable) -> TMeasurement: - min_run_time = 1 - return TBenchmark.Timer( - stmt="fn()", + +def bench_fns(label: str, sub_label: str, description: str, + fns: List[Callable]): + + min_run_time = 1 if not NVTX_PROFILE else 0.1 + res = TBenchmark.Timer( + stmt=""" + for fn in fns: + fn() + """, globals={ - "fn": fn + "fns": fns }, label=label, sub_label=sub_label, description=description, ).blocked_autorange(min_run_time=min_run_time) + if NVTX_PROFILE: + with nvtx.annotate("mm-bench"), nvtx.annotate( + f"{label}|{sub_label}|{description}"): + fns[0]() -def loop_over_weights( - a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor, - torch.tensor, torch.tensor]], - fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor], - None]): - for w_ref, w_q, w_s, _ in weights: - fn(a, w_ref, w_q, w_s) + return res _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None -def bench(atype: torch.dtype, - wtype: ScalarType, +def bench(types: TypeConfig, group_size: int, m: int, k: int, n: int, label: str, sub_label: str, - benchmark_marlinv1: bool = True, - sweep_schedules: bool = True) -> Iterable[TMeasurement]: - global _SWEEP_SCHEDULES_RESULTS - - a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) - sub_label += f", L={len(weights)}" - - weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp) - for w_ref, w_q, w_s, w_zp in weights] + sweep_schedules: bool = True) -> List[TMeasurement]: + benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) + sub_label += f", L={len(benchmark_tensors)}" + + name_type_string = f"W{types.weight_type}"+\ + f"-A{terse_type_name(types.act_type)}" + if types.group_scale_type is not None: + name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" + if types.group_zero_type is not None: + name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}" + if group_size is not None: + name_type_string += f"-G{group_size}" + if types.channel_scale_type is not None: + name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}" + if types.token_scale_type is not None: + name_type_string += f"-TS{terse_type_name(types.token_scale_type)}" timers = [] # pytorch impl timers.append( - bench_fn( - label, sub_label, "torch.matmul", lambda: loop_over_weights( - a, - weights, - lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref), - ))) + bench_fns( + label, sub_label, "torch.matmul (fp16)", + [torch_matmul_f16_create_bench_fn(bt) + for bt in benchmark_tensors])) - if benchmark_marlinv1: - w_ref = weights[0][0] - - w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device) - sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device) - g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device) - - def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor: - w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape) - return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape, - wtype.size_bits) - - def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: - return marlin_permute_scales(w_s, *w_ref.shape, group_size) - - weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q), - marlinv1_permute_scales(w_s), w_zp) - for w_ref, w_q, w_s, w_zp in weights] - - workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) - - # marlinv1 + if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: + timers.append( + bench_fns( + label, sub_label, + f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ + cutlass_scaled_mm_create_bench_fn(bt) + for bt in benchmark_tensors + ])) + + if types.act_type != torch.float8_e4m3fn: timers.append( - bench_fn( - label, sub_label, "marlin_orig", lambda: loop_over_weights( - a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops. - gptq_marlin_gemm(a, - w_q, - w_s, - w_zp_empty, - g_idx, - sort_indices, - workspace.scratch, - wtype, - size_m=a.shape[0], - size_n=w_ref.shape[1], - size_k=w_ref.shape[0], - is_k_full=True)))) + bench_fns(label, sub_label, f"marlin ({name_type_string})", + [marlin_create_bench_fn(bt) + for bt in benchmark_tensors])) # machete timers.append( - bench_fn( - label, sub_label, "machete_heuristic", lambda: loop_over_weights( - a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm( - a, w_q, wtype, b_scales=w_s, b_group_size=group_size)))) + bench_fns(label, sub_label, f"machete ({name_type_string})", [ + machete_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ])) if sweep_schedules: + global _SWEEP_SCHEDULES_RESULTS + print("Finding best schedule for machete") best = None best_schedule = None - schedules = ops.machete_supported_schedules(wtype) + schedules = ops.machete_supported_schedules( + a_type=types.act_type, + b_type=types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_zero_type, + token_scales_type=types.token_scale_type, + channel_scales_type=types.channel_scale_type, + out_type=types.output_type) + + if schedules is None or len(schedules) == 0: + raise ValueError("No schedules found to sweep") + for schedule in reversed(schedules): schedule_M = int(schedule.split("_")[0].split("x")[1]) @@ -177,16 +381,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: continue - def run(a, _, w_q, w_s, schedule=schedule): - ops.machete_gemm(a, - w_q, - wtype, - w_s, - b_group_size=group_size, - schedule=schedule) - - res = bench_fn(label, sub_label, "machete_best", - lambda: loop_over_weights(a, weights_machete, run)) + res = bench_fns(label, sub_label, "machete_best", [ + machete_create_bench_fn( + bt, out_type=types.output_type, schedule=schedule) + for bt in benchmark_tensors + ]) results_row = { "M": m, @@ -213,25 +412,33 @@ def run(a, _, w_q, w_s, schedule=schedule): # runner -def print_timers(timers: Iterable[TMeasurement]): +def print_timers(timers: List[TMeasurement]): compare = TBenchmark.Compare(timers) compare.print() -def run(dtype: torch.dtype, sweep_schedules: bool, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: +def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + types = TypeConfig( + act_type=args.act_type, + weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ + else scalar_types.uint4, + output_type=args.out_type, + group_scale_type=args.group_scale_type, + group_zero_type=args.group_zero_type, + channel_scale_type=args.channel_scale_type, + token_scale_type=args.token_scale_type, + ) - results = [] + results: List[TMeasurement] = [] for m, k, n in MKNs: - timers = bench(dtype, - scalar_types.uint4b8, - 128, + timers = bench(types, + args.group_size, m, k, n, - f"{dtype}-gemm", + f"{args.act_type}-gemm", f"MKN=({m}x{k}x{n})", - sweep_schedules=sweep_schedules) + sweep_schedules=args.sweep_schedules) print_timers(timers) results.extend(timers) @@ -240,7 +447,7 @@ def run(dtype: torch.dtype, sweep_schedules: bool, # output makers def make_output( - data: Iterable[TMeasurement], + data: List[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, timestamp=None, @@ -262,7 +469,6 @@ def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -306,33 +512,49 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: for k, n in KNs: MKNs.append((m, k, n)) - data = run(args.dtype, args.sweep_schedules, MKNs) + data = run(args, MKNs) model_bench_data.append(data) + type_string = f"{args.act_type}" + # Print all results for data, model_tp in zip(model_bench_data, models_tps): model, tp_size = model_tp - print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print(f"== Results {type_string} {model}-TP{tp_size} ====") print_timers(data) - timestamp = int(time.time()) + timestr = time.strftime("%Y%m%d-%H%M%S") - all_data = [] + all_results = [] for d in model_bench_data: - all_data.extend(d) + all_results.extend(d) + # pickle all data - with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) + with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: + args_dict = vars(args) + args_dict.pop("func") + pkl.dump({ + "args": args_dict, + "results": all_results, + }, f) if __name__ == "__main__": def to_torch_dtype(dt): - if dt == "bfloat16": - return torch.bfloat16 - if dt == "float16": - return torch.float16 - raise ValueError("unsupported dtype") + return { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "int8": torch.int8, + "float8_e4m3fn": torch.float8_e4m3fn, + "int": torch.int, + "float": torch.float, + }[dt] + + class ToTorchDtype(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, to_torch_dtype(values)) parser = FlexibleArgumentParser( description=""" @@ -352,12 +574,42 @@ def to_torch_dtype(dt): """, # noqa: E501 formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument( - "--dtype", - type=to_torch_dtype, + "--act-type", + action=ToTorchDtype, required=True, - help="Available options are ['bfloat16', 'float16']", + choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], + ) + parser.add_argument( + "--group-scale-type", + action=ToTorchDtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--group-zero-type", + type=to_torch_dtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--channel-scale-type", + action=ToTorchDtype, + choices=['float'], + ) + parser.add_argument( + "--token-scale-type", + action=ToTorchDtype, + choices=['float'], + ) + parser.add_argument( + "--out-type", + action=ToTorchDtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--group-size", + type=int, + help="Available options are ['None', '-1', '128'], default=128", + default=128, ) parser.add_argument( "--sweep-schedules", diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 536c133bb3341..8fb44e3a3dbd8 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index de608fd05af70..7d0bd84150a27 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -20,10 +20,11 @@ args = parser.parse_args() with open(args.filename, 'rb') as f: - data: List[TMeasurement] = pickle.load(f) + data = pickle.load(f) + raw_results: List[TMeasurement] = data["results"] results = defaultdict(lambda: list()) - for v in data: + for v in raw_results: result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) if result is not None: KN = result.group(1) diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index 25ec9d6028627..51f24f3ba1774 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -40,4 +40,10 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], + "meta-llama/Llama-3.1-405b-hf": [ + ([16384, 18432], 1), + ([16384, 16384], 0), + ([16384, 106496], 1), + ([53248, 16384], 0), + ], } diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 5912c5c02ede7..68f7ca1af05ad 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -16,9 +16,14 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc") # # Check the compile flags # + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") + list(APPEND CXX_COMPILE_FLAGS + "-mf16c" + ) +endif() list(APPEND CXX_COMPILE_FLAGS "-fopenmp" - "-mf16c" "-DVLLM_CPU_EXTENSION") execute_process(COMMAND cat /proc/cpuinfo @@ -53,6 +58,8 @@ find_isa(${CPUINFO} "avx2" AVX2_FOUND) find_isa(${CPUINFO} "avx512f" AVX512_FOUND) find_isa(${CPUINFO} "POWER10" POWER10_FOUND) find_isa(${CPUINFO} "POWER9" POWER9_FOUND) +find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support +find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support if (AVX512_FOUND AND NOT AVX512_DISABLED) list(APPEND CXX_COMPILE_FLAGS @@ -72,9 +79,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) else() message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() + elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") message(WARNING "vLLM CPU backend using AVX2 ISA") + elseif (POWER9_FOUND OR POWER10_FOUND) message(STATUS "PowerPC detected") # Check for PowerPC VSX support @@ -82,8 +91,20 @@ elseif (POWER9_FOUND OR POWER10_FOUND) "-mvsx" "-mcpu=native" "-mtune=native") + +elseif (ASIMD_FOUND) + message(STATUS "ARMv8 or later architecture detected") + if(ARM_BF16_FOUND) + message(STATUS "BF16 extension detected") + set(MARCH_FLAGS "-march=armv8.2-a+bf16+dotprod+fp16") + add_compile_definitions(ARM_BF16_SUPPORT) + else() + message(WARNING "BF16 functionality is not available") + set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16") + endif() + list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS}) else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") + message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.") endif() # @@ -153,4 +174,4 @@ define_gpu_extension_target( WITH_SOABI ) -message(STATUS "Enabling C extension.") +message(STATUS "Enabling C extension.") \ No newline at end of file diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index e73eca1b345fd..e21832ba7582f 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -24,12 +24,20 @@ struct KernelVecType { template <> struct KernelVecType { +#ifdef __powerpc64__ + // Power architecture-specific vector types + using q_load_vec_type = vec_op::FP32Vec8; + using k_load_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures, including x86 using q_load_vec_type = vec_op::FP16Vec8; - using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP16Vec16; + using v_load_vec_type = vec_op::FP16Vec16; +#endif + using q_vec_type = vec_op::FP32Vec16; using k_vec_type = vec_op::FP32Vec16; using qk_acc_vec_type = vec_op::FP32Vec16; - using v_load_vec_type = vec_op::FP16Vec16; }; #ifdef __AVX512BF16__ @@ -43,6 +51,21 @@ struct KernelVecType { using v_load_vec_type = vec_op::BF16Vec16; }; #else + #ifdef __aarch64__ + #ifndef ARM_BF16_SUPPORT + // pass + #else +template <> +struct KernelVecType { + using q_load_vec_type = vec_op::BF16Vec8; + using q_vec_type = vec_op::FP32Vec16; + using k_load_vec_type = vec_op::BF16Vec16; + using k_vec_type = vec_op::FP32Vec16; + using qk_acc_vec_type = vec_op::FP32Vec16; + using v_load_vec_type = vec_op::BF16Vec16; +}; + #endif + #else template <> struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; @@ -52,6 +75,7 @@ struct KernelVecType { using qk_acc_vec_type = vec_op::FP32Vec16; using v_load_vec_type = vec_op::BF16Vec16; }; + #endif #endif template @@ -771,4 +795,4 @@ void paged_attention_v2( CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl) }); -} +} \ No newline at end of file diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index 0213be09105ed..28db0479748bf 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -1,4 +1,3 @@ - #ifndef CPU_TYPES_HPP #define CPU_TYPES_HPP @@ -8,8 +7,11 @@ #elif defined(__POWER9_VECTOR__) //ppc implementation #include "cpu_types_vsx.hpp" +#elif defined(__aarch64__) + //arm implementation + #include "cpu_types_arm.hpp" #else #warning "unsupported vLLM cpu implementation" #endif -#endif +#endif \ No newline at end of file diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp new file mode 100644 index 0000000000000..73e0f8cb2e0fb --- /dev/null +++ b/csrc/cpu/cpu_types_arm.hpp @@ -0,0 +1,515 @@ +#include +#include +#include + +namespace vec_op { + +#ifdef ARM_BF16_SUPPORT + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) +#endif + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +#define FORCE_INLINE __attribute__((always_inline)) inline + +namespace { + template + constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); + }; +}; + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }; +}; + +struct FP32Vec8; +struct FP32Vec16; + +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + float16x8_t reg; + + explicit FP16Vec8(const void *ptr) + : reg(vld1q_f16(static_cast(ptr))) {}; + + explicit FP16Vec8(const FP32Vec8 &); + + void save(void *ptr) const { + vst1q_f16(static_cast<__fp16 *>(ptr), reg); + } +}; + +struct FP16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + float16x8x2_t reg; + + explicit FP16Vec16(const void *ptr) { + reg.val[0] = vld1q_f16(reinterpret_cast(ptr)); + reg.val[1] = vld1q_f16(reinterpret_cast(ptr) + 8); + } + + explicit FP16Vec16(const FP32Vec16& vec); + + void save(void *ptr) const { + vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); + vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); + } + + void save(void *ptr, const int elem_num) const { + int full_blocks = elem_num / 8; + int remainder = elem_num % 8; + + if (full_blocks > 0) { + vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); + if (full_blocks > 1) { + vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]); + } + } + + if (remainder > 0) { + float16x8_t temp = reg.val[full_blocks]; + for (int i = 0; i < remainder; ++i) { + reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = vgetq_lane_f16(temp, i); + } + } + } +}; + + +#ifdef ARM_BF16_SUPPORT +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + bfloat16x8_t reg; + + explicit BF16Vec8(const void *ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec8(bfloat16x8_t data) : reg(data) {}; + + explicit BF16Vec8(const FP32Vec8 &); + + explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {}; + + void save(void *ptr) const { *reinterpret_cast(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + bfloat16x8x2_t reg; + + explicit BF16Vec16(const void *ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {}; + + explicit BF16Vec16(const FP32Vec16 &); + + explicit BF16Vec16(float32x4x4_t v) : reg({ + vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]), + vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3]) + }){}; + + void save(void *ptr) const { *reinterpret_cast(ptr) = reg; }; +}; + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + bfloat16x8x4_t reg; + + explicit BF16Vec32(const void *ptr) + : reg(*reinterpret_cast(ptr)) {}; + + explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {}; + + explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({ + vec8_data.reg, + vec8_data.reg, + vec8_data.reg, + vec8_data.reg + }) {}; + + void save(void *ptr) const { *reinterpret_cast(ptr) = reg; }; +}; +#endif + +struct FP32Vec4 : public Vec { + constexpr static int VEC_ELEM_NUM = 4; + + union AliasReg { + float32x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + float32x4_t reg; + + explicit FP32Vec4(float v) : reg(vdupq_n_f32(v)) {}; + + explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {}; + + explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {}; + + explicit FP32Vec4(float32x4_t data) : reg(data) {}; + + explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}; +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + float32x4x2_t reg; + float values[VEC_ELEM_NUM]; + }; + + float32x4x2_t reg; + + explicit FP32Vec8(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v)}) {}; + + explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}; + + explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {}; + + explicit FP32Vec8(float32x4x2_t data) : reg(data) {}; + + explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}; + + explicit FP32Vec8(const FP16Vec8 &v) { + reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg)); + reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg)); + }; + + explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {}; + + #ifdef ARM_BF16_SUPPORT + + explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {}; + + explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {}; + + #endif + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float answer = 0; + unroll_loop([&answer, &ar](int i) { answer += ar.values[i]; }); + + return answer; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + + float32x2_t exp_vec0 = {expf(ar.values[0]), expf(ar.values[1])}; + float32x2_t exp_vec1 = {expf(ar.values[2]), expf(ar.values[3])}; + float32x2_t exp_vec2 = {expf(ar.values[4]), expf(ar.values[5])}; + float32x2_t exp_vec3 = {expf(ar.values[6]), expf(ar.values[7])}; + + float32x4_t result0 = vcombine_f32(exp_vec0, exp_vec1); + float32x4_t result1 = vcombine_f32(exp_vec2, exp_vec3); + + float32x4x2_t result; + result.val[0] = result0; + result.val[1] = result1; + + return FP32Vec8(result); + } + + FP32Vec8 tanh() const { + AliasReg ar; + ar.reg = reg; + + float32x2_t tanh_vec0 = {tanhf(ar.values[0]), tanhf(ar.values[1])}; + float32x2_t tanh_vec1 = {tanhf(ar.values[2]), tanhf(ar.values[3])}; + float32x2_t tanh_vec2 = {tanhf(ar.values[4]), tanhf(ar.values[5])}; + float32x2_t tanh_vec3 = {tanhf(ar.values[6]), tanhf(ar.values[7])}; + + float32x4_t result0 = vcombine_f32(tanh_vec0, tanh_vec1); + float32x4_t result1 = vcombine_f32(tanh_vec2, tanh_vec3); + + float32x4x2_t result; + result.val[0] = result0; + result.val[1] = result1; + + return FP32Vec8(result); + } + + FP32Vec8 er() const { + AliasReg ar; + ar.reg = reg; + + float32x2_t er_vec0 = {static_cast(erf(ar.values[0])), static_cast(erf(ar.values[1]))}; + float32x2_t er_vec1 = {static_cast(erf(ar.values[2])), static_cast(erf(ar.values[3]))}; + float32x2_t er_vec2 = {static_cast(erf(ar.values[4])), static_cast(erf(ar.values[5]))}; + float32x2_t er_vec3 = {static_cast(erf(ar.values[6])), static_cast(erf(ar.values[7]))}; + + float32x4_t result0 = vcombine_f32(er_vec0, er_vec1); + float32x4_t result1 = vcombine_f32(er_vec2, er_vec3); + + float32x4x2_t result; + result.val[0] = result0; + result.val[1] = result1; + + return FP32Vec8(result); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])})); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])})); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])})); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])})); + } + + void save(float *ptr) const { + vst1q_f32(ptr, reg.val[0]); + vst1q_f32(ptr + 4, reg.val[1]); + } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + float32x4x4_t reg; + float values[VEC_ELEM_NUM]; + }; + + float32x4x4_t reg; + + explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {} + + explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {} + + explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {} + + explicit FP32Vec16(float32x4x4_t data) : reg(data) {} + + explicit FP32Vec16(const FP32Vec8 &data) { + reg.val[0] = data.reg.val[0]; + reg.val[1] = data.reg.val[1]; + reg.val[2] = data.reg.val[0]; + reg.val[3] = data.reg.val[1]; + } + + explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} + + explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {} + + #ifdef ARM_BF16_SUPPORT + explicit FP32Vec16(bfloat16x8x2_t v) : reg({ + vcvtq_low_f32_bf16(v.val[0]), + vcvtq_high_f32_bf16(v.val[0]), + vcvtq_low_f32_bf16(v.val[1]), + vcvtq_high_f32_bf16(v.val[1]) + }) {}; + #endif + + explicit FP32Vec16(const FP32Vec4 &data) { + reg.val[0] = data.reg; + reg.val[1] = data.reg; + reg.val[2] = data.reg; + reg.val[3] = data.reg; + }; + + #ifdef ARM_BF16_SUPPORT + explicit FP32Vec16(const BF16Vec16 &v) : reg({ + vcvtq_low_f32_bf16(v.reg.val[0]), + vcvtq_high_f32_bf16(v.reg.val[0]), + vcvtq_low_f32_bf16(v.reg.val[1]), + vcvtq_high_f32_bf16(v.reg.val[1]) + }) {}; + + explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}; + #endif + + explicit FP32Vec16(const FP16Vec16 &v) { + reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0])); + reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0])); + reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); + reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); + }; + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(float32x4x4_t({ + vaddq_f32(reg.val[0], b.reg.val[0]), + vaddq_f32(reg.val[1], b.reg.val[1]), + vaddq_f32(reg.val[2], b.reg.val[2]), + vaddq_f32(reg.val[3], b.reg.val[3])})); + }; + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(float32x4x4_t({ + vmulq_f32(reg.val[0], b.reg.val[0]), + vmulq_f32(reg.val[1], b.reg.val[1]), + vmulq_f32(reg.val[2], b.reg.val[2]), + vmulq_f32(reg.val[3], b.reg.val[3])})); + }; + + FP32Vec16 operator-(const FP32Vec16 &b) const { + return FP32Vec16(float32x4x4_t({ + vsubq_f32(reg.val[0], b.reg.val[0]), + vsubq_f32(reg.val[1], b.reg.val[1]), + vsubq_f32(reg.val[2], b.reg.val[2]), + vsubq_f32(reg.val[3], b.reg.val[3]) + })); + }; + + FP32Vec16 operator/(const FP32Vec16 &b) const { + return FP32Vec16(float32x4x4_t({ + vdivq_f32(reg.val[0], b.reg.val[0]), + vdivq_f32(reg.val[1], b.reg.val[1]), + vdivq_f32(reg.val[2], b.reg.val[2]), + vdivq_f32(reg.val[3], b.reg.val[3]) + })); + }; + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float answer = 0; + unroll_loop([&answer, &ar](int i) { answer += ar.values[i]; }); + + return answer; + }; + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + + AliasReg ar; + ar.reg = reg; + float answer = 0; + const int start = idx * group_size; + unroll_loop( + [&answer, &start, ar](int i) { answer += ar.values[start + i]; }); + + return answer; + }; + + void save(float *ptr) const { + vst1q_f32(ptr, reg.val[0]); + vst1q_f32(ptr + 4, reg.val[1]); + vst1q_f32(ptr + 8, reg.val[2]); + vst1q_f32(ptr + 12, reg.val[3]); + }; +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +template <> struct VecType { using vec_type = FP16Vec8; }; + +#ifdef ARM_BF16_SUPPORT +template <> struct VecType { using vec_type = BF16Vec8; }; +#endif + +template void storeFP32(float v, T *ptr) { *ptr = v; } + +template <> inline void storeFP32(float v, c10::Half *ptr) { + *reinterpret_cast<__fp16 *>(ptr) = v; +} + +inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) { + float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]); + float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]); + float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]); + float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]); + + reg.val[0] = vcombine_f16(low_0, high_0); + reg.val[1] = vcombine_f16(low_1, high_1); +}; + +inline FP16Vec8 :: FP16Vec8(const FP32Vec8 &v) { + float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]); + float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]); + + reg = vcombine_f16(lower_half, upper_half); +}; + +inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) { + + acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]); + acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]); + acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]); + acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a.reg.val[3], b.reg.val[3]); +}; + +#ifdef ARM_BF16_SUPPORT +inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) { + + float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0])); + float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0])); + float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1])); + float32x4_t a1_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[1])); + + float32x4_t b0_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[0])); + float32x4_t b0_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[0])); + float32x4_t b1_low = vcvt_f32_bf16(vget_low_bf16(b.reg.val[1])); + float32x4_t b1_high = vcvt_f32_bf16(vget_high_bf16(b.reg.val[1])); + + acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a0_low, b0_low); + acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a0_high, b0_high); + acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a1_low, b1_low); + acc.reg.val[3] = vfmaq_f32(acc.reg.val[3], a1_high, b1_high); +}; +#endif + +#ifdef ARM_BF16_SUPPORT +inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {}; + +inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({ + vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]), + vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), v.reg.val[3]) + }){}; +#endif + +inline void prefetch(const void *addr) { + __builtin_prefetch(addr, 0, 1); +}; + +#ifdef ARM_BF16_SUPPORT +template <> +inline void storeFP32(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v); +}; +#endif +}; \ No newline at end of file diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index f42fa2361a2db..d9aed657a3113 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -25,7 +25,13 @@ struct KernelVecType { template <> struct KernelVecType { +#ifdef __powerpc64__ + // Power architecture-specific vector type + using load_vec_type = vec_op::FP32Vec16; +#else + // Fallback for other architectures using load_vec_type = vec_op::FP16Vec16; +#endif using azp_adj_load_vec_type = vec_op::INT32Vec16; using cvt_vec_type = vec_op::FP32Vec16; }; diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh index 1842fab8b2cac..f61fe3ceb978a 100644 --- a/csrc/cutlass_extensions/cute_utils.cuh +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { // is the layout f(x) = x template CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { return true; - else { + } else { constexpr auto coalesced_layout = coalesce(Layout{}); if constexpr (rank(coalesced_layout) == 1 && stride<0>(coalesced_layout) == 1) { diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp similarity index 99% rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp index d407d66ab2aa6..7aa87feb4cce2 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -52,6 +52,7 @@ // clang-format off #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cute/tensor.hpp" namespace cutlass::epilogue::threadblock { diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp similarity index 100% rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp new file mode 100644 index 0000000000000..c69e87999ae71 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -0,0 +1,317 @@ +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. + + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c2x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { + protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +}; // namespace vllm::c2x \ No newline at end of file diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp new file mode 100644 index 0000000000000..95764ecddc79f --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -0,0 +1,315 @@ +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c3x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +}; // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 4fcfcd311aa91..a5beea1a35e49 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -35,6 +35,35 @@ class MixedInputKernelScheduleType(enum.Enum): } } +VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = { + **DataTypeSize, # type: ignore + **{ + VLLMDataType.u4b8: 4, + VLLMDataType.u8b128: 8, + } +} + +VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + VLLMDataType.u4b8: "vllm::kU4B8", + VLLMDataType.u8b128: "vllm::kU8B128", + DataType.u4: "vllm::kU4", + DataType.u8: "vllm::kU8", + DataType.s4: "vllm::kS4", + DataType.s8: "vllm::kS8", + DataType.f16: "vllm::kFloat16", + DataType.bf16: "vllm::kBfloat16", +} + +VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + DataType.u8: "at::ScalarType::Byte", + DataType.s8: "at::ScalarType::Char", + DataType.e4m3: "at::ScalarType::Float8_e4m3fn", + DataType.s32: "at::ScalarType::Int", + DataType.f16: "at::ScalarType::Half", + DataType.bf16: "at::ScalarType::BFloat16", + DataType.f32: "at::ScalarType::Float", +} + VLLMKernelScheduleTag: Dict[Union[ MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh index 2ad914f8e9868..90f226cf64c0a 100644 --- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -3,6 +3,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass_extensions/vllm_custom_types.cuh" #include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_type_utils.cuh" // this file extends: // https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h @@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter { CUTLASS_DEVICE static result_type convert(source_type const& source) { - CUTE_INVALID_CONTROL_PATH( - "InterleavedNumericArrayConverter not implemented\n"); + if (cute::elect_one_sync()) { + if constexpr (std::is_same_v) { + printf( + "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", + nameof_v, nameof_v, N); + } else { + printf( + "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " + "implemented\n", + nameof_v, nameof_v, N, size(IlvBlkLayout{})); + } + __brkpt(); + } return {}; } @@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter< result_type operator()(source_type const& s) const { return convert(s); } }; -// TODO (LucasWilkinson): Implement -// for Array <= Array - -// .... - template struct ArrayConverterPacked32Bit { using result_type = Array; @@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit { using ScalarConverter = NumericConverter; template - CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) { if constexpr (sizeof(PackedSrc) == 1) { - return static_cast(reinterpret_cast(source)); + return Array{reinterpret_cast(src)}; } else if constexpr (sizeof(PackedSrc) == 2) { - return static_cast(reinterpret_cast(source)); + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 4) { + return Array{reinterpret_cast(src)}; } else { - static_assert(sizeof(PackedSrc) == 4); - return reinterpret_cast(source); + static_assert(sizeof(PackedSrc) == 8); + return reinterpret_cast const&>(src); } } @@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit { static_assert(std::is_same_v); static_assert(std::is_same_v); - return RegConvert32bit::template convert(to_reg(source)); + return RegConvert32bit::template convert(to_regs(source)); } friend class detail::VectorizedConverter; @@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit { } }; +// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed +// into 2 32bit register. +template +CUTLASS_DEVICE cutlass::AlignedArray lut_4bit_to_8bit_convert( + uint32_t src) { + cutlass::AlignedArray r; + // Determines if the value is in the top half of the LUT if set or + // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move + // into bit position 0x4 of each nibble so when or'd with final_prmt_base it + // selects the correct candidate. When elements in final_prmt_base + // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements + // are < 0x4, the low candidate is selected (i.e. LUT[0:7]) + uint32_t high_bit = (src & 0x88888888) >> 1; + + // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT + // (selects correct high or low candidate) + const uint32_t final_prmt_base = 0x32103210; + + // Ignore the high bit when indexing into LUT, for each 4bit value + // we index into both the high and low candidates then use + // high_bit | final_prmt_base to select the correct candidate + uint32_t lut_idx = (src & 0x77777777); + + auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) | + (uint32_t(d) << 24); + }; + + static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3); + static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7); + static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11); + static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) { + uint32_t final_prmt_idx = final_prmt_base | high_bit; + + // This uses a look up table to convert packed int4s to packed int8s, + // using the int4 value as the index to prmt. It first select both the + // high and low candidates, then uses the high bit (i.e. `high_bit`) to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 low, high;\n" + " prmt.b32 low, %1, %2, %5;\n" + " prmt.b32 high, %3, %4, %5;\n" + " prmt.b32 %0, low, high, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx), + "r"(final_prmt_idx)); + } + + return r; +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s + auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, // + 0xFC, 0xFD, 0xFE, 0xFF, // + 0x00, 0x01, 0x02, 0x03, // + 0x04, 0x05, 0x06, 0x07>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s + auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, // + 0xC8, 0xC4, 0xC0, 0xB8, // + 0x00, 0x38, 0x40, 0x44, // + 0x48, 0x4A, 0x4C, 0x4E>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + // for Array <= Array template struct NumericArrayConverter { @@ -148,7 +282,8 @@ struct NumericArrayConverter { struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -417,7 +554,8 @@ struct NumericArrayConverter { struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; // Hold output FP16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray { private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; PackedResultType r; // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of @@ -513,7 +652,8 @@ struct NumericArrayConverter { private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src_reg = src_[0]; // Hold output BF16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -788,6 +930,61 @@ struct NumericArrayConverter { #endif +// for Array <= Array +// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 + template + CUTLASS_DEVICE static PackedResultType convert( + Array src) { + // Hold output int8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray< + uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>; + RegArray r; + + static constexpr uint32_t MAGIC_BIAS_ = 0x64806480; + auto MAGIC_BIAS = *reinterpret_cast(&MAGIC_BIAS_); + + *reinterpret_cast(&src[0]) = + __hadd2(*reinterpret_cast(&src[0]), MAGIC_BIAS); + + if constexpr (src_regs > 1) { + *reinterpret_cast(&src[1]) = + __hadd2(*reinterpret_cast(&src[1]), MAGIC_BIAS); + } + + static_assert(PackedResultType::kElements <= 4); + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]), + "n"(MASK_0246)); + + uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK); + + return reinterpret_cast(int8s); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/csrc/cutlass_extensions/vllm_type_utils.cuh b/csrc/cutlass_extensions/vllm_type_utils.cuh new file mode 100644 index 0000000000000..500ed508c8303 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_type_utils.cuh @@ -0,0 +1,42 @@ +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" +#include "cuda_bf16.h" + +#include "cutlass_extensions/vllm_custom_types.cuh" + +namespace cutlass { + +template +struct nameof { + static constexpr char const* value = "unknown"; +}; + +template +inline constexpr auto nameof_v = nameof::value; + +#define NAMEOF_TYPE(T) \ + template <> \ + struct nameof { \ + static constexpr char const* value = #T; \ + }; + +NAMEOF_TYPE(float_e4m3_t) +NAMEOF_TYPE(float_e5m2_t) +NAMEOF_TYPE(half_t) +NAMEOF_TYPE(nv_bfloat16) +NAMEOF_TYPE(bfloat16_t) +NAMEOF_TYPE(float) + +NAMEOF_TYPE(int4b_t) +NAMEOF_TYPE(int8_t) +NAMEOF_TYPE(int32_t) +NAMEOF_TYPE(int64_t) + +NAMEOF_TYPE(vllm_uint4b8_t) +NAMEOF_TYPE(uint4b_t) +NAMEOF_TYPE(uint8_t) +NAMEOF_TYPE(vllm_uint8b128_t) +NAMEOF_TYPE(uint32_t) +NAMEOF_TYPE(uint64_t) + +}; // namespace cutlass \ No newline at end of file diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 405ba213628f6..e14ad972a0a45 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -16,21 +16,17 @@ #include "quantization/fp8/nvidia/quant_utils.cuh" #endif -#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ - defined(__gfx941__) || defined(__gfx942__)) - #define __HIP__MI300_MI250__ -#endif - namespace vllm { -// TODO(woosuk): Further optimize this kernel. +// This kernel uses the _f16Vec to represent vectorized data. +// A conversion to/from float should exist template __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, - const int hidden_size, const int vec_hidden_size) { + const size_t hidden_size, const size_t vec_hidden_size) { __shared__ float s_variance; float v8_variance_sum = 0.0f; @@ -46,7 +42,7 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] reinterpret_cast*>(weight); // Compute variance. Be careful, hidden_size should multiple of 4. - for (int idx = tx; idx < vec_hidden_size; idx += num_threads) { + for (size_t idx = tx; idx < vec_hidden_size; idx += num_threads) { _f16Vec temp = input_v[idx]; v8_variance_sum += temp.sum_squares(); } @@ -64,7 +60,7 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] variance = s_variance; - for (int idx = tx; idx < vec_hidden_size; idx += num_threads) { + for (size_t idx = tx; idx < vec_hidden_size; idx += num_threads) { _f16Vec temp = input_v[idx]; temp *= variance; temp *= weight_v[idx]; @@ -72,19 +68,19 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] } } +// Non vectorized kernel for unusual shapes/types without conversion template __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, - const int hidden_size, const int vec_hidden_size) { + const size_t hidden_size, const size_t) { __shared__ float s_variance; float variance = 0.0f; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = - (float)input[blockIdx.x * static_cast(hidden_size) + idx]; + for (size_t idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } @@ -97,10 +93,9 @@ rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] } __syncthreads(); - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = - (float)input[blockIdx.x * static_cast(hidden_size) + idx]; - out[blockIdx.x * static_cast(hidden_size) + idx] = + for (size_t idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance)) * weight[idx]; } } @@ -234,21 +229,18 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] int num_tokens = input.numel() / hidden_size; int vec_size = 16 / input.element_size(); int vec_hidden_size = hidden_size / vec_size; + bool can_run_vectorize = (hidden_size % vec_size) == 0; dim3 grid(num_tokens); - dim3 block(std::min(vec_hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - -#ifdef __HIP__MI300_MI250__ - if (vec_size % 8 == 0) { + if (vec_size % 8 == 0 && can_run_vectorize) { + dim3 block(std::min(vec_hidden_size, 1024)); LAUNCH_RMS_NORM(8); } else { + dim3 block(std::min(hidden_size, 1024)); LAUNCH_RMS_NORM(0); } -#else - LAUNCH_RMS_NORM(0); -#endif } #define LAUNCH_FUSED_ADD_RMS_NORM(width) \ diff --git a/csrc/ops.h b/csrc/ops.h index f21446a9563c7..665c859ac2950 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -131,6 +131,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, int64_t thx, int64_t thy); torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); +#endif torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n); @@ -141,6 +142,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); +#ifndef USE_ROCM bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index ee801e16573d4..dbb72e8bbd3f5 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -8,6 +8,10 @@ #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp" + +using namespace vllm; + /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for NVIDIA GPUs with SM versions prior to sm90 (Hopper). @@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales); } } @@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales); } } @@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { assert(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } } else { @@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_fp8_dispatch< - cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_fp8_dispatch( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales); } } @@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index 6329ff63623e2..d03242f44ab1d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,6 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "broadcast_load_epilogue_c2x.hpp" #include "common.hpp" // clang-format on @@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - template - using ColOrScalarLoad = - cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = - cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using RowOrZeroLoad = - cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - // it would technically work but no use case as data_ptr is never nullptr - static_assert(!std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - static_assert(std::is_same_v>); - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch._scaled_mm. - - A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : protected ScaledEpilogueBase { - protected: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 292c9e4b34e1c..33581a63d4c3d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -23,11 +23,12 @@ #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "common.hpp" // clang-format on using namespace cute; +using namespace vllm; /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for @@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - template - using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; - - // Don't want to support nullptr by default - template - using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // Don't want to support nullptr by default - template - using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - static_assert(!std::is_same_v> && - !std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - static_assert(std::is_same_v> || - std::is_same_v>); - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch.scaled_mm_. - - A and B may be both either int8 or fp8_e4m3. A can be - quantized per-tensor or per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, @@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == c.dtype(), "currently bias dtype must match output dtype ", c.dtype()); - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( c, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm90_epilogue(c, a, b, a_scales, - b_scales); + return cutlass_scaled_mm_sm90_epilogue( + c, a, b, a_scales, b_scales); } } @@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h index fba94fd1d157b..d42205a6571db 100644 --- a/csrc/quantization/gguf/ggml-common.h +++ b/csrc/quantization/gguf/ggml-common.h @@ -1,7 +1,7 @@ // copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h #define QK_K 256 #define K_QUANTS_PER_ITERATION 2 -#define WARP_SIZE 32 +#define WARP_SIZE_GGUF 32 #define K_SCALE_SIZE 12 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 #define CUDA_QUANTIZE_BLOCK_SIZE 256 @@ -1112,4 +1112,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { #endif return c; } + +static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) { + uint32_t neq = a^b; + return !(neq & 0xff000000) * 0xff000000 | + !(neq & 0x00ff0000) * 0x00ff0000 | + !(neq & 0x0000ff00) * 0x0000ff00 | + !(neq & 0x000000ff) * 0x000000ff; +} + +static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) { + return (static_cast(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) + + (static_cast(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) + + (static_cast(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) + + (static_cast(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0); +} #endif // defined(USE_ROCM) diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 37e4de4e14dd3..5f0eaf5a973fb 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -4,6 +4,8 @@ #include #include +#include "cuda_compat.h" + #include "ggml-common.h" #include "vecdotq.cuh" #include "dequantize.cuh" @@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x, #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); - sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); + amax = fmaxf(amax, VLLM_SHFL_XOR_SYNC_WIDTH(amax, mask, 32)); + sum += VLLM_SHFL_XOR_SYNC_WIDTH(sum, mask, 32); } const float d = amax / 127; diff --git a/csrc/quantization/gguf/mmq.cuh b/csrc/quantization/gguf/mmq.cuh index d13efd5965313..c935faa07df0c 100644 --- a/csrc/quantization/gguf/mmq.cuh +++ b/csrc/quantization/gguf/mmq.cuh @@ -10,7 +10,7 @@ static __device__ __forceinline__ void mul_mat_q( const int blocks_per_row_x = ncols_x / qk; const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = WARP_SIZE / qi; + const int blocks_per_warp = WARP_SIZE_GGUF / qi; const int & ncols_dst = ncols_y; @@ -27,10 +27,10 @@ static __device__ __forceinline__ void mul_mat_q( allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); - __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; - __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; + __shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF]; + __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF/QI8_1]; - float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; + float sum[mmq_y/WARP_SIZE_GGUF][mmq_x/nwarps] = {{0.0f}}; for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { @@ -39,26 +39,26 @@ static __device__ __forceinline__ void mul_mat_q( #pragma unroll for (int ir = 0; ir < qr; ++ir) { - const int kqs = ir*WARP_SIZE + threadIdx.x; + const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x; const int kbxd = kqs / QI8_1; #pragma unroll for (int i = 0; i < mmq_x; i += nwarps) { const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; - const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; + const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF; tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); } #pragma unroll for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1); const int col_y_eff = min(col_y_0 + ids, ncols_y-1); // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE_GGUF/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF/QI8_1) + kby]; if (need_sum) { *dsi_dst = *dsi_src; } else { @@ -70,12 +70,12 @@ static __device__ __forceinline__ void mul_mat_q( __syncthreads(); // #pragma unroll // unrolling this loop causes too much register pressure - for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { + for (int k = ir*WARP_SIZE_GGUF/qr; k < (ir+1)*WARP_SIZE_GGUF/qr; k += vdr) { #pragma unroll for (int j = 0; j < mmq_x; j += nwarps) { #pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - sum[i/WARP_SIZE][j/nwarps] += vec_dot( + for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) { + sum[i/WARP_SIZE_GGUF][j/nwarps] += vec_dot( tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, threadIdx.x + i, threadIdx.y + j, k); } @@ -93,12 +93,12 @@ static __device__ __forceinline__ void mul_mat_q( } #pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { + for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) { const int row_dst = row_dst_0 + threadIdx.x + i; if (row_dst >= nrows_dst) { continue; } - dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]); + dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]); } } } @@ -115,7 +115,7 @@ static __device__ __forceinline__ void mul_mat_q( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2) #endif mul_mat_q4_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -140,7 +140,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -165,7 +165,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2) #endif mul_mat_q4_1( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -190,7 +190,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -215,7 +215,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2) #endif mul_mat_q5_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -240,7 +240,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -265,7 +265,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2) #endif mul_mat_q5_1( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -289,7 +289,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -314,7 +314,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2) #endif mul_mat_q8_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -338,7 +338,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -363,7 +363,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2) #endif mul_mat_q2_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -387,7 +387,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -412,7 +412,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2) #endif mul_mat_q3_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -438,7 +438,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -463,7 +463,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2) #endif mul_mat_q4_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -487,7 +487,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -512,7 +512,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2) #endif mul_mat_q5_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -537,7 +537,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -562,7 +562,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2) #endif mul_mat_q6_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -586,7 +586,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index b221ae7896138..b01e939808a3f 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -28,8 +28,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) { + tmp += VLLM_SHFL_XOR_SYNC(tmp, mask); } if (threadIdx.x == 0) { diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index d5af345a6b26f..e00422637c65b 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -43,7 +43,7 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( const int * v, const int * u, const float & d4, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -68,7 +68,7 @@ template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( const int * v, const int * u, const half2 & dm4, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -95,7 +95,7 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -128,7 +128,7 @@ template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -162,7 +162,7 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( const int * v, const int * u, const float & d8_0, const float & d8_1) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -176,7 +176,7 @@ template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( const int * v, const int * u, const half2 & dm8, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; @@ -202,7 +202,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -230,7 +230,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float & d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi_d = 0; int sumi_m = 0; @@ -267,7 +267,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, const int & scale_offset, const float & d3, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf = 0.0f; @@ -301,7 +301,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d3, const float & d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -326,7 +326,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -351,7 +351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -382,7 +382,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -413,7 +413,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -445,7 +445,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf = 0.0f; #pragma unroll @@ -465,7 +465,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const float & d6, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; #pragma unroll @@ -507,8 +507,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI4_0) + mmq_y/QI4_0]; *x_ql = tile_x_qs; *x_dm = (half2 *) tile_x_d; } @@ -529,11 +529,11 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbx] = bxi->d; } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_0; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -543,7 +543,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d); } } @@ -559,13 +559,13 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_0) % WARP_SIZE_GGUF]; } return vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], u, x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i/QI4_0 + k/QI4_0], + y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q4_1_q8_1( @@ -587,8 +587,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI4_1) + mmq_y/QI4_1]; *x_ql = tile_x_qs; *x_dm = tile_x_dm; } @@ -608,10 +608,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_1; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -621,7 +621,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; } } @@ -634,13 +634,13 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_1) % WARP_SIZE_GGUF]; } return vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], u, x_dm[i * (WARP_SIZE_GGUF/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q5_0_q8_1( @@ -664,8 +664,8 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI5_0) + mmq_y/QI5_0]; *x_ql = tile_x_ql; *x_dm = (half2 *) tile_x_d; @@ -697,7 +697,7 @@ template static __device__ __forceinlin qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+0] = qs0; int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 @@ -706,10 +706,10 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+1] = qs1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_0; const int kbxd = k % blocks_per_tile_x_row; float * x_dmf = (float *) x_dm; @@ -722,7 +722,7 @@ template static __device__ __forceinlin } const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d); } } @@ -730,7 +730,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; + const int index_bx = i * (WARP_SIZE_GGUF/QI5_0) + i/QI5_0 + k/QI5_0; const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; @@ -738,12 +738,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_0) % WARP_SIZE_GGUF]; } return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q5_1_q8_1( @@ -767,8 +767,8 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI5_1) + mmq_y/QI5_1]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -801,7 +801,7 @@ template static __device__ __forceinlin qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+0] = qs0; int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 @@ -809,10 +809,10 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+1] = qs1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_1; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -825,7 +825,7 @@ template static __device__ __forceinlin const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; } } @@ -833,18 +833,18 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; + const int index_bx = i * (WARP_SIZE_GGUF/QI5_1) + + i/QI5_1 + k/QI5_1; int u[2*VDR_Q5_1_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_1) % WARP_SIZE_GGUF]; } return vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q8_0_q8_1( @@ -865,8 +865,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI8_0) + mmq_y/QI8_0]; *x_ql = tile_x_qs; *x_dm = (half2 *) tile_x_d; @@ -889,10 +889,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_int8(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI8_0; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -903,7 +903,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d); } } @@ -914,8 +914,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( const float * y_df = (const float *) y_ds; return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[j * WARP_SIZE_GGUF + k], x_dmf[i * (WARP_SIZE_GGUF/QI8_0) + i/QI8_0 + k/QI8_0], + y_df[j * (WARP_SIZE_GGUF/QI8_1) + k/QI8_1]); } static __device__ __forceinline__ float vec_dot_q2_K_q8_1( @@ -942,9 +942,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI2_K) + mmq_y/QI2_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/4) + mmq_y/4]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -967,10 +967,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI2_K; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -981,18 +981,18 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI2_K) + i / QI2_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF/4); if (need_check) { i = min(i, i_max); } - const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); + const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/4)) / (QI2_K/4); + x_sc[i * (WARP_SIZE_GGUF/4) + i / 4 + k % (WARP_SIZE_GGUF/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); } } @@ -1005,7 +1005,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); #pragma unroll @@ -1013,10 +1013,10 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; } - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/4) + i/4 + kbx*4]) + ky/4; - const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; - return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (QR2_K*k) % WARP_SIZE_GGUF; + return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE_GGUF/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q3_K_q8_1( @@ -1047,10 +1047,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; - __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI3_K) + mmq_y/QI3_K]; + __shared__ int tile_x_qh[mmq_y * (WARP_SIZE_GGUF/2) + mmq_y/2]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/4) + mmq_y/4]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1073,10 +1073,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI3_K; const int kbxd = k % blocks_per_tile_x_row; float * x_dmf = (float *) x_dm; @@ -1087,27 +1087,27 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + int i = i0 + i_offset * 2 + k / (WARP_SIZE_GGUF/2); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/2)) / (QI3_K/2); // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); + x_qh[i * (WARP_SIZE_GGUF/2) + i / 2 + k % (WARP_SIZE_GGUF/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF/4); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/4)) / (QI3_K/4); const int ksc = k % (QI3_K/4); @@ -1121,7 +1121,7 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; + x_sc[i * (WARP_SIZE_GGUF/4) + i / 4 + k % (WARP_SIZE_GGUF/4)] = sc; } } @@ -1134,24 +1134,24 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE_GGUF/4) + i/4 + kbx*4)) + ky/4; int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); const int shift = 2 * ((ky % 32) / 8); const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vh = x_qh[i * (WARP_SIZE_GGUF/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); const int vlh = (vh << 2) & 0x04040404; v[l] = __vsubss4(vll, vlh); } - const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; - return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (k*QR3_K) % WARP_SIZE_GGUF; + return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE_GGUF/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q4_K_q8_1( @@ -1200,9 +1200,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI4_K) + mmq_y/QI4_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1225,10 +1225,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll @@ -1238,27 +1238,27 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI4_K) + i / QI4_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = k % (WARP_SIZE/8); + const int ksc = k % (WARP_SIZE_GGUF/8); // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + ksc] = scales8; } } @@ -1267,11 +1267,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/16]) + 2*((k % 16) / 8); - const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; - return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (QR4_K*k) % WARP_SIZE_GGUF; + return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE_GGUF/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q5_K_q8_1( @@ -1321,9 +1321,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI5_K) + mmq_y/QI5_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1360,11 +1360,11 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq1] = ql1 | qh1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll @@ -1376,40 +1376,40 @@ template static __device__ __forceinlin } const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI5_K) + i / QI5_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / (QI5_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = k % (WARP_SIZE/8); + const int ksc = k % (WARP_SIZE_GGUF/8); // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + ksc] = scales8; } } static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; - const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; + const int index_x = i * (QR5_K*WARP_SIZE_GGUF + 1) + QR5_K*k; + const int index_y = j * WARP_SIZE_GGUF + (QR5_K*k) % WARP_SIZE_GGUF; return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + x_dm[i * (WARP_SIZE_GGUF/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q6_K_q8_1( @@ -1439,9 +1439,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI6_K) + mmq_y/QI6_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1478,11 +1478,11 @@ template static __device__ __forceinlin const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); } - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI6_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 float * x_dmf = (float *) x_dm; @@ -1496,20 +1496,20 @@ template static __device__ __forceinlin const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / 4; - x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + k % (WARP_SIZE_GGUF/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); } } @@ -1519,11 +1519,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/8]); - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; - const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; - return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); + const int index_x = i * (QR6_K*WARP_SIZE_GGUF + 1) + QR6_K*k; + const int index_y = j * WARP_SIZE_GGUF + (QR6_K*k) % WARP_SIZE_GGUF; + return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE_GGUF/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( @@ -1582,7 +1582,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq2_s * bq2 = (const block_iq2_s *) vbq; const int ib32 = iqs; @@ -1619,7 +1619,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq; const int ib32 = iqs; @@ -1646,7 +1646,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq3_s * bq2 = (const block_iq3_s *) vbq; const int ib32 = iqs; @@ -1671,7 +1671,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const int qs_packed = get_int_b2(bq1->qs, iqs); @@ -1703,7 +1703,7 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq1_m * bq1 = (const block_iq1_m *) vbq; @@ -1763,7 +1763,7 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq4_nl * bq = (const block_iq4_nl *) vbq; @@ -1788,7 +1788,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq; const uint8_t * values = (const uint8_t *)kvalues_iq4nl; diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 6dbf9594e8492..0c698ced7713d 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -54,9 +54,10 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -82,7 +83,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& workspace, vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp) { + bool is_k_full, bool has_zp, bool is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -516,10 +517,11 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -692,8 +694,10 @@ __global__ void Marlin( int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; constexpr int zp_tb_groups = s_tb_groups; constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; int zp_gl_rd_delta = zp_gl_stride; @@ -768,9 +772,16 @@ __global__ void Marlin( constexpr int num_ints_per_thread = 8 / pack_factor; int zp_sh_rd; if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } } // Precompute which thread should not read memory in which iterations; this is @@ -832,6 +843,7 @@ __global__ void Marlin( FragS act_frag_s[2][4][4]; // For act-order int frag_qzp[2][num_ints_per_thread]; // Zero-points FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ // Zero accumulators. auto zero_accums = [&]() { @@ -1126,7 +1138,7 @@ __global__ void Marlin( // has_zp implies AWQ, which doesn't have act_order, static_assert(!has_zp || group_blocks != 0); - if constexpr (has_zp) { + if constexpr (has_zp && !is_zp_float) { int pipe = full_pipe % stages; if constexpr (group_blocks == -1) { @@ -1170,11 +1182,44 @@ __global__ void Marlin( } } } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { - if constexpr (has_zp) { + if constexpr (has_zp && !is_zp_float) { FragB frag_zp_0; FragB frag_zp_1; int zp_quant_0, zp_quant_1; @@ -1219,10 +1264,14 @@ __global__ void Marlin( frag_b1 = dequant(b_quant_1); // Apply zero-point to frag_b0 - if constexpr (has_zp) { + if constexpr (has_zp && !is_zp_float) { sub_zp(frag_b0, frag_zp[j], 0); } + else if constexpr (has_zp && is_zp_float && group_blocks != -1) { + sub_zp(frag_b0, frag_zpf[k % 2][j], 0); + } + // Apply scale to frag_b0 if constexpr (has_act_order) { scale4(frag_b0, act_frag_s[k % 2][0][j], @@ -1235,10 +1284,14 @@ __global__ void Marlin( } // Apply zero-point to frag_b1 - if constexpr (has_zp) { + if constexpr (has_zp && !is_zp_float) { sub_zp(frag_b1, frag_zp[j], 1); } + else if constexpr (has_zp && is_zp_float && group_blocks != -1) { + sub_zp(frag_b1, frag_zpf[k % 2][j], 1); + } + // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], @@ -1510,7 +1563,7 @@ __global__ void Marlin( fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); } - if constexpr (has_zp && group_blocks == -1) { + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_zp_to_shared(); } @@ -1697,23 +1750,27 @@ __global__ void Marlin( } #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \ + IS_ZP_FLOAT) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ - num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + if constexpr (!IS_ZP_FLOAT || std::is_same::value) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ + num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ + } \ } typedef struct { @@ -1905,51 +1962,96 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ + false) #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ + false) \ \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + false) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false) + + // We currently have 4-bit models only with group_blocks == 4 + #define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ + true) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true) template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, @@ -1958,7 +2060,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool use_fp32_reduce) { + int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) { if (has_zp) { TORCH_CHECK( q_type == vllm::kU4 || q_type == vllm::kU8, @@ -2111,6 +2213,11 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, AWQ_CALL_IF(vllm::kU8, 8, 8, 256) AWQ_CALL_IF(vllm::kU8, 8, 4, 128) AWQ_CALL_IF(vllm::kU8, 4, 8, 128) + + HQQ_CALL_IF(vllm::kU4, 16, 4, 256) + HQQ_CALL_IF(vllm::kU4, 8, 8, 256) + HQQ_CALL_IF(vllm::kU4, 8, 4, 128) + HQQ_CALL_IF(vllm::kU4, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]", ", has_act_order = ", has_act_order, @@ -2135,7 +2242,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, - bool use_fp32_reduce) { + bool use_fp32_reduce, bool is_zp_float) { vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); if (has_zp) { TORCH_CHECK( @@ -2148,6 +2255,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, b_q_type.str()); } + if (has_zp && is_zp_float) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + int pack_factor = 32 / b_q_type.size_bits(); // Verify A @@ -2257,12 +2370,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, if (has_zp) { int rank = b_zeros.sizes().size(); TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2"); - TORCH_CHECK(b_zeros.size(0) == num_groups, - "b_zeros dim 0 = ", b_zeros.size(0), - " is not num_groups = ", num_groups); - TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", b_zeros.size(1), - " is not size_n / pack_factor = ", size_n / pack_factor); + if (is_zp_float) { + TORCH_CHECK(b_zeros.size(1) == size_n, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not size_n = ", size_n); + TORCH_CHECK(num_groups == b_zeros.size(0), + "b_zeros dim 0 = ", b_zeros.size(0), + " is not num_groups = ", num_groups); + TORCH_CHECK(num_groups != -1, "num_groups must be != -1"); + } else { + TORCH_CHECK(b_zeros.size(0) == num_groups, + "b_zeros dim 0 = ", b_zeros.size(0), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not size_n / pack_factor = ", size_n / pack_factor); + } } // Verify workspace size @@ -2282,7 +2405,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), @@ -2291,7 +2414,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index d126af1849024..ac63afe79a255 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -3,8 +3,10 @@ import os import shutil from collections.abc import Iterable -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from copy import deepcopy +from dataclasses import dataclass, fields +from functools import reduce +from typing import Dict, List, Optional, Tuple, Union import jinja2 # yapf conflicts with isort for this block @@ -14,7 +16,10 @@ MixedInputKernelScheduleType, TileSchedulerTag, TileSchedulerType, VLLMDataType, - VLLMDataTypeNames, VLLMDataTypeTag, + VLLMDataTypeNames, + VLLMDataTypeSize, VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, VLLMKernelScheduleTag) # yapf: enable @@ -27,49 +32,125 @@ #include "../machete_mm_launcher.cuh" namespace machete { -using GemmDispatcher_ = GemmDispatcher< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints - -{% for s in schedules %}extern torch::Tensor -impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args); -{% endfor %} -template <> -torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) { + +{% for impl_config in impl_configs %} +{% set type_sig = gen_type_sig(impl_config.types) -%} +{% for s in impl_config.schedules %} +extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs); +{%- endfor %} + +torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) { [[maybe_unused]] auto M = args.A.size(0); [[maybe_unused]] auto N = args.B.size(1); [[maybe_unused]] auto K = args.A.size(1); - if (!args.schedule) { - {%- for cond, s in heuristic %} + if (!args.maybe_schedule) { + {%- for cond, s in impl_config.heuristic %} {%if cond is not none%}if ({{cond}}) {%- else %}else {%- endif %} - return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %} + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %} } - {% for s in schedules %} - if (*args.schedule == "{{ gen_sch_name(s) }}") { - return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args); - } - {% endfor %} + {%- for s in impl_config.schedules %} + if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}") + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args); + {%- endfor %} TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " - "schedule = ", *args.schedule); + "schedule = ", *args.maybe_schedule); } +{%- endfor %} + -template <> -std::vector GemmDispatcher_::supported_schedules() { - return { - {% for s in schedules -%} - "{{ gen_sch_name(s) }}"{{ ", - " if not loop.last }}{%- endfor %} - }; +static inline std::optional maybe_scalartype( + c10::optional const& t) { + if (!t) { + return std::nullopt; + } else { + return t->scalar_type(); + }; +} + +torch::Tensor mm_dispatch(MMArgs args) { + auto out_type = args.maybe_out_type.value_or(args.A.scalar_type()); + auto a_type = args.A.scalar_type(); + auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales); + auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros); + auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales); + auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set type_sig = gen_type_sig(t) -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!maybe_g_scales_type{%endif%} + && {%if t.b_group_zeropoint != void -%} + maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!maybe_g_zeros_type{%endif%} + && {%if t.b_channel_scale != void -%} + maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}} + {%- else %}!maybe_ch_scales_type{%endif%} + && {%if t.a_token_scale != void -%} + maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}} + {%- else %}!maybe_tok_scales_type{%endif%} + ) { + return mm_dispatch_{{type_sig}}(args); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "machete_mm(..) is not implemented for " + "a_type=", args.A.scalar_type(), + ", b_type=", args.b_type.str(), + ", out_type=", out_type, + ", with_group_scale_type=", maybe_g_scales_type + ? toString(*maybe_g_scales_type) : "None", + ", with_group_zeropoint_type=", maybe_g_zeros_type + ? toString(*maybe_g_zeros_type) : "None", + ", with_channel_scale_type=", maybe_ch_scales_type + ? toString(*maybe_ch_scales_type) : "None", + ", with_token_scale_type=", maybe_tok_scales_type + ? toString(*maybe_tok_scales_type) : "None", + "; implemented types are: \\n", + {%- for impl_config in impl_configs %} + {% set t = impl_config.types -%} + "\\t{{gen_type_option_name(t)}}\\n", + {%- endfor %} + ""); } +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args) { + auto out_type = args.maybe_out_type.value_or(args.a_type); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set schs = impl_config.schedules -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && args.a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!args.maybe_group_scales_type{%endif%} + && {%if t.b_group_zeropoint != void-%} + args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!args.maybe_group_zeros_type{%endif%} + ) { + return { + {%- for s in impl_config.schedules %} + "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %} + {%- endfor %} + }; + } + {%- endfor %} + + return {}; +}; + }; // namespace machete """ @@ -77,20 +158,10 @@ #include "../machete_mm_launcher.cuh" namespace machete { -template -using Kernel = MacheteKernelTemplate< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, - Config, with_C, with_scales, with_zeropoints>; - -{% for sch in schedules %} -{% set schedule_name = gen_sch_name(sch) -%} -struct sch_{{schedule_name}} { + +{% for sch in unique_schedules(impl_configs) %} +{% set sch_sig = gen_sch_sig(sch) -%} +struct sch_{{sch_sig}} { using TileShapeNM = Shape<{{ to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; using ClusterShape = Shape<{{ @@ -101,27 +172,34 @@ using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; }; - +{% endfor %} + +{% for impl_config in impl_configs %} +{% set t = impl_config.types -%} +{% set schs = impl_config.schedules -%} +{% set type_sig = gen_type_sig(t) -%} + +template +using Kernel_{{type_sig}} = MacheteKernelTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[t.b]}}, // ElementB + {{DataTypeTag[t.out]}}, // ElementD + {{DataTypeTag[t.accumulator]}}, // Accumulator + {{DataTypeTag[t.b_group_scale]}}, // GroupScaleT + {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT + {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT + {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + Sch>; + +{% for sch in schs %} +{% set sch_sig = gen_sch_sig(sch) -%} torch::Tensor -impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) { - bool with_C = args.C.has_value(), with_scales = args.scales.has_value(), - with_zeropoints = args.zeros.has_value(); - - {% for s in specializations %} - if (with_C == {{s.with_C|lower}} - && with_zeropoints == {{s.with_zeropoints|lower}} - && with_scales == {{s.with_scales|lower}}) { - return run_impl>(args); - }{% endfor %} - - TORCH_CHECK_NOT_IMPLEMENTED( - false, "for the sake of compile times and binary size machete_mm(..) is " - " not implemented for with_C=", with_C, ", with_scales=", with_scales, - ", with_zeropoints=", with_zeropoints, - " (for {{type_name}}_sch_{{schedule_name}})"); +impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) { + return run_impl>(args); } -{% endfor %} +{%- endfor %} +{%- endfor %} }; // namespace machete """ @@ -130,26 +208,34 @@ #include "../machete_prepack_launcher.cuh" namespace machete { -using PrepackBDispatcher_ = PrepackBDispatcher< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints - -using PrepackedLayoutB = PrepackedLayoutBTemplate< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - cutlass::layout::ColumnMajor, - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>; - -template <> -torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) { - return prepack_impl(B); + +torch::Tensor prepack_B_dispatch(PrepackBArgs args) { + auto convert_type = args.maybe_group_scales_type.value_or(args.a_type); + {%- for t in types %} + {% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %} + if (args.a_type == {{TorchTypeTag[t.a]}} + && args.b_type.size_bits() == {{t.b_num_bits}} + && convert_type == {{TorchTypeTag[t.convert]}}) { + return prepack_impl< + PrepackedLayoutBTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[b_type]}}, // ElementB + {{DataTypeTag[t.convert]}}, // ElementConvert + {{DataTypeTag[t.accumulator]}}, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput> + >(args.B); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED(false, + "prepack_B_dispatch(..) is not implemented for " + "atype = ", args.a_type, + ", b_type = ", args.b_type.str(), + ", with_group_scales_type= ", args.maybe_group_scales_type ? + toString(*args.maybe_group_scales_type) : "None"); } + }; // namespace machete """ @@ -166,32 +252,34 @@ class ScheduleConfig: tile_scheduler: TileSchedulerType -@dataclass +@dataclass(frozen=True) class TypeConfig: - element_a: DataType - element_b: Union[DataType, VLLMDataType] - element_b_scale: DataType - element_b_zeropoint: DataType - element_d: DataType + a: DataType + b: Union[DataType, VLLMDataType] + b_group_scale: DataType + b_group_zeropoint: DataType + b_channel_scale: DataType + a_token_scale: DataType + out: DataType accumulator: DataType -@dataclass -class Specialization: - with_C: bool - with_zeropoints: bool - with_scales: bool +@dataclass(frozen=True) +class PrepackTypeConfig: + a: DataType + b_num_bits: int + convert: DataType + accumulator: DataType @dataclass class ImplConfig: - type_config: TypeConfig - schedule_configs: List[ScheduleConfig] - specializations: List[Specialization] + types: TypeConfig + schedules: List[ScheduleConfig] heuristic: List[Tuple[Optional[str], ScheduleConfig]] -def generate_schedule_name(schedule_config: ScheduleConfig) -> str: +def generate_sch_sig(schedule_config: ScheduleConfig) -> str: tile_shape = ( f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" ) @@ -209,40 +297,34 @@ def generate_schedule_name(schedule_config: ScheduleConfig) -> str: f"_{epilogue_schedule}_{tile_scheduler}") -# mostly unique shorter schedule_name -def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str: +# mostly unique shorter sch_sig +def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: kernel_terse_names_replace = { "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", "TmaWarpSpecializedCooperative_": "TmaCoop_", "StreamKScheduler": "streamK", } - schedule_name = generate_schedule_name(schedule_config) + sch_sig = generate_sch_sig(schedule_config) for orig, terse in kernel_terse_names_replace.items(): - schedule_name = schedule_name.replace(orig, terse) - return schedule_name + sch_sig = sch_sig.replace(orig, terse) + return sch_sig # unique type_name -def generate_type_signature(kernel_type_config: TypeConfig): - element_a = VLLMDataTypeNames[kernel_type_config.element_a] - element_b = VLLMDataTypeNames[kernel_type_config.element_b] - element_d = VLLMDataTypeNames[kernel_type_config.element_d] - accumulator = VLLMDataTypeNames[kernel_type_config.accumulator] - element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale] - element_zeropoint = VLLMDataTypeNames[ - kernel_type_config.element_b_zeropoint] - - return (f"{element_a}{element_b}{element_d}" - f"{accumulator}{element_scale}{element_zeropoint}") - +def generate_type_signature(kernel_types: TypeConfig): + return str("".join([ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ])) -# non-unique shorter type_name -def generate_terse_type_signature(kernel_type_config: TypeConfig): - element_a = VLLMDataTypeNames[kernel_type_config.element_a] - element_b = VLLMDataTypeNames[kernel_type_config.element_b] - return f"{element_a}{element_b}" +def generate_type_option_name(kernel_types: TypeConfig): + return ", ".join([ + f"{field.name.replace('b_', 'with_')+'_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ]) def is_power_of_two(n): @@ -263,13 +345,36 @@ def _to_cute_constant(value: int): return _to_cute_constant(value) +def unique_schedules(impl_configs: List[ImplConfig]): + return list( + set(sch for impl_config in impl_configs + for sch in impl_config.schedules)) + + +def unsigned_type_with_bitwidth(num_bits): + return { + 4: DataType.u4, + 8: DataType.u8, + 16: DataType.u16, + 32: DataType.u32, + 64: DataType.u64, + }[num_bits] + + template_globals = { + "void": DataType.void, "DataTypeTag": VLLMDataTypeTag, + "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag, + "TorchTypeTag": VLLMDataTypeTorchDataTypeTag, "KernelScheduleTag": VLLMKernelScheduleTag, "EpilogueScheduleTag": EpilogueScheduleTag, "TileSchedulerTag": TileSchedulerTag, "to_cute_constant": to_cute_constant, - "gen_sch_name": generate_terse_schedule_name, + "gen_sch_sig": generate_terse_sch_sig, + "gen_type_sig": generate_type_signature, + "unique_schedules": unique_schedules, + "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, + "gen_type_option_name": generate_type_option_name } @@ -284,42 +389,82 @@ def create_template(template_str): prepack_dispatch_template = create_template(PREPACK_TEMPLATE) -def create_sources(impl_config: ImplConfig, num_impl_files=1): +def create_sources(impl_configs: List[ImplConfig], num_impl_files=8): sources = [] - type_name = generate_type_signature(impl_config.type_config) - terse_type_name = generate_terse_type_signature(impl_config.type_config) - sources.append(( - f"machete_mm_{terse_type_name}", - mm_dispatch_template.render(type_name=type_name, - type_config=impl_config.type_config, - schedules=impl_config.schedule_configs, - heuristic=impl_config.heuristic), + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), )) + prepack_types = [] + for impl_config in impl_configs: + convert_type = impl_config.types.a \ + if impl_config.types.b_group_scale == DataType.void \ + else impl_config.types.b_group_scale + prepack_types.append( + PrepackTypeConfig( + a=impl_config.types.a, + b_num_bits=VLLMDataTypeSize[impl_config.types.b], + convert=convert_type, + accumulator=impl_config.types.accumulator, + )) + + def prepacked_type_key(prepack_type: PrepackTypeConfig): + # For now we we can just use the first accumulator type seen since + # the tensor core shapes/layouts don't vary based on accumulator + # type so we can generate less code this way + return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) + + unique_prepack_types = [] + prepack_types_seen = set() + for prepack_type in prepack_types: + key = prepacked_type_key(prepack_type) + if key not in prepack_types_seen: + unique_prepack_types.append(prepack_type) + prepack_types_seen.add(key) + sources.append(( - f"machete_prepack_{terse_type_name}", - prepack_dispatch_template.render( - type_name=type_name, - type_config=impl_config.type_config, - ), + "machete_prepack", + prepack_dispatch_template.render(types=unique_prepack_types, ), )) - num_schedules = len(impl_config.schedule_configs) - schedules_per_file = math.ceil(num_schedules / num_impl_files) - for part, i in enumerate(range(0, num_schedules, schedules_per_file)): - file_schedules = impl_config.schedule_configs[i:i + schedules_per_file] + # Split up impls across files + num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) + num_impls_per_file = math.ceil(num_impls / num_impl_files) + + files_impls: List[List[ImplConfig]] = [[]] + + curr_num_impls_assigned = 0 + curr_impl_in_file = 0 + curr_impl_configs = deepcopy(list(reversed(impl_configs))) + + while curr_num_impls_assigned < num_impls: + room_left_in_file = num_impls_per_file - curr_impl_in_file + if room_left_in_file == 0: + files_impls.append([]) + room_left_in_file = num_impls_per_file + curr_impl_in_file = 0 + + curr_ic = curr_impl_configs[-1] + if len(curr_ic.schedules) >= room_left_in_file: + # Break apart the current impl config + tmp_ic = deepcopy(curr_ic) + tmp_ic.schedules = curr_ic.schedules[:room_left_in_file] + curr_ic.schedules = curr_ic.schedules[room_left_in_file:] + files_impls[-1].append(tmp_ic) + else: + files_impls[-1].append(curr_ic) + curr_impl_configs.pop() + curr_num_impls_assigned += len(files_impls[-1][-1].schedules) + curr_impl_in_file += len(files_impls[-1][-1].schedules) + for part, file_impls in enumerate(files_impls): sources.append(( - f"machete_mm_{terse_type_name}_impl_part{part}", - mm_impl_template.render( - type_name=type_name, - type_config=impl_config.type_config, - schedules=file_schedules, - specializations=impl_config.specializations, - ), + f"machete_mm_impl_part{part+1}", + mm_impl_template.render(impl_configs=file_impls), )) + return sources @@ -328,187 +473,169 @@ def generate(): # about how this works SCRIPT_DIR = os.path.dirname(__file__) - schedule_common_params = dict( + sch_common_params = dict( kernel_schedule=TmaMI, epilogue_schedule=TmaCoop, tile_scheduler=TileSchedulerType.StreamK, ) - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - default_heuristic = [ + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + default_tile_heuristic_config = { #### M = 257+ - ( - "M > 256 && K <= 16384 && N <= 4096", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 256", - ScheduleConfig( - tile_shape_mn=(128, 256), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + "M > 256": ((128, 256), (2, 1, 1)), #### M = 129-256 - ( - "M > 128 && K <= 4096 && N <= 4096", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 128 && K <= 8192 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 128", - ScheduleConfig( - tile_shape_mn=(128, 256), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + "M > 128": ((128, 256), (2, 1, 1)), #### M = 65-128 - ( - "M > 64 && K <= 4069 && N <= 4069", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64 && K <= 4069 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64 && K >= 8192 && N >= 12288", - ScheduleConfig( - tile_shape_mn=(256, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), #### M = 33-64 - ( - "M > 32 && K <= 6144 && N <= 6144", - ScheduleConfig( - tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 32 && K >= 16384 && N >= 12288", - ScheduleConfig( - tile_shape_mn=(256, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 32", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), #### M = 17-32 - ( - "M > 16 && K <= 12288 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 16", - ScheduleConfig( - tile_shape_mn=(256, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), #### M = 1-16 - ( - "N >= 26624", - ScheduleConfig( - tile_shape_mn=(256, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), - ( - None, - ScheduleConfig( - tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + default_heuristic = [ + (cond, ScheduleConfig(*tile_config, + **sch_common_params)) # type: ignore + for cond, tile_config in default_tile_heuristic_config.items() ] - # Do not use schedules = list(set(...)) because we need to make sure - # the output list is deterministic; otherwise the generated kernel file - # will be non-deterministic and causes ccache miss. - schedules = [] - for _, schedule_config in default_heuristic: - if schedule_config not in schedules: - schedules.append(schedule_config) + def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): + # Do not use schedules = list(set(...)) because we need to make sure + # the output list is deterministic; otherwise the generated kernel file + # will be non-deterministic and causes ccache miss. + schedules = [] + for _, schedule_config in heuristic: + if schedule_config not in schedules: + schedules.append(schedule_config) + return schedules impl_configs = [] GPTQ_kernel_type_configs = list( TypeConfig( - element_a=element_a, - element_b=element_b, - element_b_scale=element_a, - element_b_zeropoint=element_a, - element_d=element_a, + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, accumulator=DataType.f32, - ) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for element_a in (DataType.f16, DataType.bf16)) - - GPTQ_kernel_specializations = [ - Specialization(with_C=False, with_zeropoints=False, with_scales=True) - ] + ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16)) impl_configs += [ - ImplConfig(x[0], x[1], x[2], x[3]) - for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules), - itertools.repeat(GPTQ_kernel_specializations), + ImplConfig(x[0], x[1], x[2]) + for x in zip(GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), itertools.repeat(default_heuristic)) ] AWQ_kernel_type_configs = list( TypeConfig( - element_a=element_a, - element_b=element_b, - element_b_scale=element_a, - element_b_zeropoint=element_a, - element_d=element_a, + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=a, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, accumulator=DataType.f32, - ) for element_b in (DataType.u4, DataType.u8) - for element_a in (DataType.f16, DataType.bf16)) + ) for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16)) + + impl_configs += [ + ImplConfig(x[0], x[1], x[2]) + for x in zip(AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic)) + ] - AWQ_kernel_specializations = [ - Specialization(with_C=False, with_zeropoints=True, with_scales=True) + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # TODO (LucasWilkinson): Further tuning required + qqq_tile_heuristic_config = { + #### M = 257+ + # ((128, 256), (2, 1, 1)) Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # "M > 256": ((128, 256), (2, 1, 1)), + "M > 256": ((128, 128), (2, 1, 1)), + #### M = 129-256 + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # ((128, 256), (2, 1, 1)) Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + # "M > 128": ((128, 256), (2, 1, 1)), + "M > 128": ((128, 128), (2, 1, 1)), + #### M = 65-128 + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), + #### M = 33-64 + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), + #### M = 17-32 + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), + #### M = 1-16 + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + qqq_heuristic = [ + (cond, ScheduleConfig(*tile_config, + **sch_common_params)) # type: ignore + for cond, tile_config in qqq_tile_heuristic_config.items() + ] + + QQQ_kernel_types = [ + *(TypeConfig( + a=DataType.s8, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.s32, + ) for b_group_scale in (DataType.f16, DataType.void)), + *(TypeConfig( + a=DataType.e4m3, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.f32, + ) for b_group_scale in (DataType.f16, DataType.void)), ] impl_configs += [ - ImplConfig(x[0], x[1], x[2], x[3]) - for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules), - itertools.repeat(AWQ_kernel_specializations), - itertools.repeat(default_heuristic)) + ImplConfig(x[0], x[1], x[2]) + for x in zip(QQQ_kernel_types, + itertools.repeat(get_unique_schedules(qqq_heuristic)), + itertools.repeat(qqq_heuristic)) ] output_dir = os.path.join(SCRIPT_DIR, "generated") @@ -521,12 +648,11 @@ def generate(): os.makedirs(output_dir) # Render each group of configurations into separate files - for impl_config in impl_configs: - for filename, code in create_sources(impl_config): - filepath = os.path.join(output_dir, f"{filename}.cu") - with open(filepath, "w") as output_file: - output_file.write(code) - print(f"Rendered template to {filepath}") + for filename, code in create_sources(impl_configs): + filepath = os.path.join(output_dir, f"{filename}.cu") + with open(filepath, "w") as output_file: + output_file.write(code) + print(f"Rendered template to {filepath}") if __name__ == "__main__": diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index e8e7b14de0da1..816f33a1078e5 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -171,6 +171,10 @@ struct MacheteCollectiveMma { make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + using SmemLayoutAtomARowMajor = decltype(rs_smem_selector(TileShape_MNK{})), @@ -288,14 +292,7 @@ struct MacheteCollectiveMma { static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutACopy = decltype(tile_to_shape( - SmemLayoutAtomARowMajor{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), - Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - + // Tile along modes in a way that maximizes the TMA box size using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), @@ -428,12 +425,12 @@ struct MacheteCollectiveMma { // clang-format on // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) - using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy( make_shape(int32_t(0), int32_t(0), int32_t(0))))); using ATensor = decltype(make_tensor( get_logical_ptr(static_cast(nullptr)), - shape(GmemLayoutA::TVbNbKL_to_offset( + shape(GmemLayoutA::TVbNbKL_to_offset_copy( make_shape(int32_t(0), int32_t(0), int32_t(0)))), PrepackedStrideA{})); @@ -450,8 +447,8 @@ struct MacheteCollectiveMma { static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { return make_tma_copy( - GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), - shape(SmemLayoutA{}(_, _, cute::Int<0>{})), + GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}), + shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})), size<1>(ClusterShape{})); // mcast along N mode for this M load, if any } @@ -584,7 +581,7 @@ struct MacheteCollectiveMma { typename Params::TMA_Scale tma_load_scale; typename Params::TMA_Zero tma_load_zero; - auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); tma_load_a = make_tma_copy_A( make_logical_tensor(ptr_A, shape(layout), stride(layout))); @@ -722,7 +719,7 @@ struct MacheteCollectiveMma { // (TILE_V,TILE_B,m,k,l) auto make_gA_mkl = [&]() { // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) - auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); return local_tile(mA_mkl, make_shape(size<0>(layout), PPBlocksPerTile_MK{}), diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 4d41b8d291484..d4d19ae5deec7 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -21,6 +21,8 @@ #include "cutlass_extensions/cute_utils.cuh" #include "cutlass_extensions/vllm_numeric_conversion.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/torch_utils.hpp" #include "machete_collective_builder.cuh" #include "machete_prepacked_layout.cuh" #include "machete_interleaving_utils.cuh" @@ -37,27 +39,42 @@ using namespace cute; // W is quantized, in this situation or right-hand operand is quantized so // we compute the transpose to move it to the left-hand side. template + typename AccumulatorT, typename GroupScaleT, typename GroupZeroT, + typename ChannelScaleT, typename TokenScaleT, class KernelSchedule, + typename ScheduleConfig> struct MacheteKernelTemplate { + static constexpr bool with_C = false; // not ever used + static constexpr bool with_group_scales = !std::is_same_v; + static constexpr bool with_group_zeropoints = + !std::is_same_v; + static constexpr bool with_channel_scales = + !std::is_same_v; + static constexpr bool with_token_scales = !std::is_same_v; + using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; using ElementD = ElementD_; using ElementC = cute::conditional_t; - using ElementZ = ZeroT; - using ElementS = ScaleT; - - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation + using ElementAccumulator = AccumulatorT; using ElementCompute = AccumulatorT; // For Epilogue + // Use dummy values when we don't have scales or zeropoints + using ElementZGroup = + cute::conditional_t; + using ElementSGroup = + cute::conditional_t; + using ElementConvertGroup = + cute::conditional_t; + using ElementSChannel = + cute::conditional_t; + using ElementSToken = + cute::conditional_t; using BTypeTuple = cute::conditional_t< - with_scales, - cute::conditional_t, - cute::tuple>, + with_group_scales, + cute::conditional_t, + cute::tuple>, ElementB>; using LayoutA = cutlass::layout::RowMajor; @@ -71,8 +88,8 @@ struct MacheteKernelTemplate { using StrideA = cutlass::detail::TagToStrideA_t; using StrideC = cutlass::detail::TagToStrideA_t; using StrideD = cutlass::detail::TagToStrideA_t; - using StrideS = cutlass::detail::TagToStrideA_t; - using StrideZ = StrideS; + using StrideSGroup = cutlass::detail::TagToStrideA_t; + using StrideZGroup = StrideSGroup; using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; @@ -85,8 +102,8 @@ struct MacheteKernelTemplate { using OperatorClass = cutlass::arch::OpClassTensorOp; using PrepackedLayoutB = - PrepackedLayoutBTemplate; + PrepackedLayoutBTemplate; static int constexpr TileShapeK = 128 * 8 / cutlass::sizeof_bits::value; @@ -103,12 +120,42 @@ struct MacheteKernelTemplate { using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; using TileScheduler = typename ScheduleConfig::TileScheduler; + static_assert( + (!with_channel_scales && !with_token_scales) || + ((with_channel_scales && with_token_scales) && + std::is_same_v), + "Currently token and channel scales (if present) must be the same type"); + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + // Currently only supports float scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + static_assert((with_channel_scales || with_token_scales) || + (std::is_same_v && + std::is_same_v), + "Currently token and channel scales (if present) must be float " + "(and if one is present the other must be too)"); + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using EVTCompute = + std::conditional_t; + + // EVTCompute using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, - AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, - EpilogueSchedule>::CollectiveOp; + ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::VLLMCollectiveBuilder< @@ -131,26 +178,44 @@ struct MacheteKernelTemplate { using MainloopArguments = typename GemmKernel::MainloopArguments; using EpilogueArguments = typename GemmKernel::EpilogueArguments; - template static Arguments create_arguments( cudaStream_t stream, - ElementA const* A_ptr, // A is an MxK matrix - Layout const& layout_A, - ElementB const* B_ptr, // B is an KxN prepacked matrix - ElementD* D_ptr, // D is an MxN matrix - Layout const& layout_D, - ElementC const* C_ptr, // C is an MxN matrix - std::optional> const& layout_C, - ElementS const* S_ptr, // S is an scale_KxN matrix - std::optional> const& layout_S, - ElementZ const* Z_ptr, // Z is an scale_KxN matrix - std::optional> const& layout_Z, - ElementCompute alpha, ElementCompute beta, - std::optional maybe_group_size) { - static_assert(!with_zeropoints || with_scales); - - int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); + torch::Tensor const& A, // MxK matrix + torch::Tensor const& B, // KxN prepacked matrix + torch::Tensor& D, // MxN matrix + c10::optional const& maybe_g_scales, // scale_KxN matrix + c10::optional const& maybe_g_zeros, // scale_KxN matrix + c10::optional maybe_group_size, + c10::optional const& maybe_ch_scales, // len N vector + c10::optional const& maybe_tok_scales) // len M vector + { + static_assert(!with_group_zeropoints || with_group_scales); + + int M = A.size(0), N = B.size(1), K = A.size(1); + TORCH_CHECK(D.size(0) == M && D.size(1) == N); + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_S_group = + maybe_make_cute_layout(maybe_g_scales, "group_scales"); + auto layout_Z_group = + maybe_make_cute_layout(maybe_g_zeros, "group_zeros"); + int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0; + int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0; + + auto unwrap = [](auto const& t) { + return t ? t->const_data_ptr() : nullptr; + }; + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto S_group_ptr = + static_cast(unwrap(maybe_g_scales)); + auto Z_group_ptr = static_cast(unwrap(maybe_g_zeros)); + auto S_channel_ptr = + static_cast(unwrap(maybe_ch_scales)); + auto S_token_ptr = + static_cast(unwrap(maybe_tok_scales)); int const group_size = maybe_group_size == -1 ? K : maybe_group_size.value_or(K); @@ -159,26 +224,28 @@ struct MacheteKernelTemplate { TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); - if constexpr (with_C) { - TORCH_CHECK(C_ptr && layout_C); + if constexpr (with_group_scales) { + TORCH_CHECK(S_group_ptr && layout_S_group); + TORCH_CHECK((size<0>(*layout_S_group) == scale_k && + size<1>(*layout_S_group) == N)); } else { - TORCH_CHECK(!C_ptr, "C not supported"); + TORCH_CHECK(!S_group_ptr, "Scales not supported"); } - if constexpr (with_scales) { - TORCH_CHECK(S_ptr && layout_S); - TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N)); + if constexpr (with_group_zeropoints) { + TORCH_CHECK(Z_group_ptr && layout_Z_group); + TORCH_CHECK((size<0>(*layout_Z_group) == scale_k && + size<1>(*layout_Z_group) == N)); + TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group, + "Scales and zeros must have the same layout"); } else { - TORCH_CHECK(!S_ptr, "Scales not supported"); + TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported"); } - if constexpr (with_zeropoints) { - TORCH_CHECK(Z_ptr && layout_Z); - TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N)); - TORCH_CHECK(layout_S && *layout_Z == *layout_S, - "Scales and zeros must have the same layout"); - } else { - TORCH_CHECK(!Z_ptr, "Zeropoints not supported"); + if constexpr (with_channel_scales || with_token_scales) { + TORCH_CHECK( + (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) && + (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1)); } // Transpose A and D @@ -186,24 +253,33 @@ struct MacheteKernelTemplate { // for B (which is At) auto stride_At = layout_A.stride(); auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); - auto stride_Ct = stride_Dt; - if (layout_C) { - stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride(); - } MainloopArguments mainloop_arguments{}; - EpilogueArguments epilogue_arguments{ - {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt}; + // {Accum, C, C_layout, D, D} + EpilogueArguments epilogue_arguments{}; + + if constexpr (with_channel_scales || with_token_scales) { + epilogue_arguments = + EpilogueArguments{ChTokScalesEpilogue::prepare_args( + *maybe_ch_scales, *maybe_tok_scales), + nullptr, + {}, + D_ptr, + stride_Dt}; + } else { + epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt}; + } - if constexpr (with_scales && with_zeropoints) { - auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); - mainloop_arguments = - MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, - S_ptr, stride_S, group_size, Z_ptr}; - } else if constexpr (with_scales) { - auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + if constexpr (with_group_scales && with_group_zeropoints) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); mainloop_arguments = MainloopArguments{ - B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size}; + B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size, Z_group_ptr}; + } else if constexpr (with_group_scales) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size}; } else { mainloop_arguments = MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index 60a4ed60535b7..4b0da5b303e0c 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -5,73 +5,61 @@ #include "machete_mm_kernel.cuh" #include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" namespace machete { -struct PyTorchArguments { +struct MMArgs { torch::Tensor const& A; torch::Tensor const& B; - c10::optional const& scales; - c10::optional const& zeros; - c10::optional group_size; - c10::optional const& C; - c10::optional alpha; - c10::optional beta; - c10::optional schedule; + vllm::ScalarType const& b_type; + c10::optional const& maybe_out_type; + c10::optional const& maybe_group_scales; + c10::optional const& maybe_group_zeros; + c10::optional maybe_group_size; + c10::optional const& maybe_channel_scales; + c10::optional const& maybe_token_scales; + c10::optional maybe_schedule; }; +struct SupportedSchedulesArgs { + at::ScalarType a_type; + vllm::ScalarType b_type; + c10::optional maybe_group_scales_type; + c10::optional maybe_group_zeros_type; + c10::optional maybe_channel_scales_type; + c10::optional maybe_token_scales_type; + c10::optional maybe_out_type; +}; + +torch::Tensor mm_dispatch(MMArgs args); + +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args); + template -torch::Tensor run_impl(PyTorchArguments args) { +torch::Tensor run_impl(MMArgs args) { const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); auto device = args.A.device(); auto stream = at::cuda::getCurrentCUDAStream(device.index()); - using EleA = typename MacheteKernel::ElementA; - using EleB = typename MacheteKernel::ElementB; - using EleC = typename MacheteKernel::ElementC; - using EleD = typename MacheteKernel::ElementD; - using EleScale = typename MacheteKernel::ElementS; - using EleZero = typename MacheteKernel::ElementZ; - - using StrideA = typename MacheteKernel::StrideA; - using StrideC = typename MacheteKernel::StrideC; - using StrideD = typename MacheteKernel::StrideD; - using StrideS = typename MacheteKernel::StrideS; - using StrideZ = typename MacheteKernel::StrideZ; - int M = args.A.size(0); int N = args.B.size(1); int K = args.A.size(1); // Allocate output - torch::Tensor D = - torch::empty({M, N}, torch::TensorOptions() - .dtype(equivalent_scalar_type_v) - .device(device)); - - auto const &A = args.A, &B = args.B; - auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; - - auto layout_A = make_cute_layout(A, "A"); - auto layout_D = make_cute_layout(D, "D"); - auto layout_C = maybe_make_cute_layout(C, "C"); - auto layout_S = maybe_make_cute_layout(scales, "scales"); - auto layout_Z = maybe_make_cute_layout(zeros, "zeros"); - - auto A_ptr = static_cast(A.const_data_ptr()); - auto B_ptr = static_cast(B.const_data_ptr()); - auto D_ptr = static_cast(D.mutable_data_ptr()); - auto C_ptr = static_cast(C ? C->const_data_ptr() : nullptr); - auto S_ptr = - static_cast(scales ? scales->const_data_ptr() : nullptr); - auto Z_ptr = - static_cast(zeros ? zeros->const_data_ptr() : nullptr); + torch::Tensor D = torch::empty( + {M, N}, + torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); auto arguments = MacheteKernel::create_arguments( - stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, - layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), - args.group_size); + stream, // + args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros, + args.maybe_group_size, args.maybe_channel_scales, + args.maybe_token_scales); TORCH_CHECK(MacheteKernel::can_implement(arguments), "Machete kernel cannot be run with these arguments"); @@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) { return D; }; -template -struct GemmDispatcher { - static torch::Tensor dispatch(PyTorchArguments args); - static std::vector supported_schedules(); -}; - }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_kernel.cuh b/csrc/quantization/machete/machete_prepack_kernel.cuh index f23483f928b47..d002355ca49d6 100644 --- a/csrc/quantization/machete/machete_prepack_kernel.cuh +++ b/csrc/quantization/machete/machete_prepack_kernel.cuh @@ -6,31 +6,49 @@ namespace machete { -template -static __global__ void prepack_B_kernel(BInTensor B_in, - BTiledOutTensor B_tiled_out) { - auto tB_in = local_tile(B_in, TileShapeNKL{}, - make_coord(blockIdx.x, blockIdx.y, blockIdx.z)); - auto tB_out = B_tiled_out(make_coord(_, _), - make_coord(blockIdx.x, blockIdx.y), blockIdx.z); +template +static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) { + auto constexpr block_size = + Int{}; + auto constexpr eles_per_thread = Int{}; + static_assert(block_size % threads == 0, + "block_size must be divisible by the number of threads"); - auto tiled_copy = make_tiled_copy(Copy_Atom{}, - Layout, Stride<_32, _1>>{}, - Layout>{}); + // Which pre-packed are we responsible for + auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z); + auto tB_in = local_tile( + B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}), + blk_coord); - auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + // Find the start offset in the output for this pre-packed block + auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in)); - Tensor thr_tile_S = thr_copy.partition_S(tB_in); - Tensor thr_tile_D = thr_copy.partition_D(tB_out); + // Tensor representing a 1:1 mapping to the output space in 1D + auto tB_out_linear = + make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord), + make_layout(make_shape(block_size))); + // Mapping from output space (1D) to input space + auto tB_in_linear = make_tensor( + tB_in.data(), + tB_in.layout() + .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset())) + .with_shape(make_shape(block_size))); + + // Tile for this specific thread (could have used a TiledCopy but these work + // best with 2d layouts, this is a simple 1d layout so local_tile is enough, + // we are also not that concerned with performance for this kernel) + auto thr_tB_in_linear = + local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x); + auto thr_tB_out_linear = + local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x); // Construct a register-backed Tensor with the same shape as each thread's // partition - auto fragment = make_tensor(shape(thr_tile_D)); + auto fragment = make_tensor(shape(thr_tB_in_linear)); - // Copy from GMEM to RMEM and from RMEM to GMEM - copy(tiled_copy, thr_tile_S, fragment); - copy(Copy_Atom{}, fragment, thr_tile_D); + copy(thr_tB_in_linear, fragment); + copy(Copy_Atom{}, fragment, thr_tB_out_linear); } template @@ -44,18 +62,15 @@ static void prepack_B_template( TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); - TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0); auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); - auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout); auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); - auto B_tiled_out = - make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset); - prepack_B_kernel - <<>>(B_in, B_tiled_out); + prepack_B_kernel<128, PrepackedLayoutB> + <<>>(B_in, B_out_ptr); } }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index a33d8f9484cfe..3486d28be2126 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -2,9 +2,17 @@ #include "machete_prepack_kernel.cuh" #include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" namespace machete { +struct PrepackBArgs { + torch::Tensor const& B; + at::ScalarType a_type; + vllm::ScalarType b_type; + c10::optional maybe_group_scales_type; +}; + template torch::Tensor prepack_impl(torch::Tensor const B) { const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); @@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) { return D; }; -template -struct PrepackBDispatcher { - static torch::Tensor dispatch(torch::Tensor B); -}; +torch::Tensor prepack_B_dispatch(PrepackBArgs args); }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 78e2cc5eec7d8..680a858a893c1 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {}; // The contract here is that the `TiledMma` determined below matches the one // ultimately used in the kernel. (this is also why the other element types are // required along with the kernel schedule) -template // clang-format on @@ -49,20 +49,27 @@ struct PrepackedLayoutBTemplate { using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; - using ElementD = ElementD_; - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation + using ElementAccumulator = AccumulatorT; using ElementMma = MmaType; - // Only use interleaved layouts for subbyte weights, prmt instructions makes - // non-interleaved layouts for 8bit+ weights efficient enough we don't need - // iterleaved layouts + // Interleave for 4bit bit types when we are not upconverting to fp8 or int8, + // in those cases case we use a LUT using prmt instructions to upconvert and + // is more efficient if the data is not interleaved For 8bit+ prmt + // instructions makes non-interleaved layouts efficient enough we don't need + // iterleaved layouts (and can reuse more of the existing cutlass converts) + static constexpr bool should_interleave = + sizeof_bits_v <= 4 && + !std::is_same_v && + !std::is_same_v; + + // Only use interleaved layouts for subbyte weights, using IlvdBlkLayout = std::conditional_t< std::is_same_v, - std::conditional_t <= 4, - decltype(get_interleaved_blk_layout< - ElementB, sizeof_bits_v, 32>()), - void>, + std::conditional_t< + should_interleave, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, IlvBlkLayout_>; // TODO (LucasWilkinson): compare the performance for other sizes @@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate { // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} auto frgV = get<1, 0>(layout_no_interleave); auto ilvdBlk = IlvdBlkLayout{}; - static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4"); + static_assert(size(frgV) % size(ilvdBlk) == 0, + "FrgV must be divisible by size(ilvdBlk)"); auto ilvd_FrgV = make_layout( make_shape(shape(ilvdBlk), Int{}), make_stride(stride(ilvdBlk), size(ilvdBlk))); @@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate { return group<1, 3>(result(_, repeat(result)>(_))); } + // ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy( + Shape_NKL shape_mkl) { + auto layout = TVbNbKL_to_offset(shape_mkl); + return make_layout(coalesce(get<0>(layout)), get<1>(layout), + get<2>(layout)); + } + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) template CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( @@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate { return group<1, 3>(result(_, repeat(result)>(_))); } + // (BlocksN, BlocksK, L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) { + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + auto stride = size(PPBlockShape_NK{}); + + // (BlocksN, BlocksK, L) -> (storage_idx) + return make_layout(blocks_shape, compact_col_major(blocks_shape, stride)); + } + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L) template CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index 9f9073ded6191..da2c2fb0d3e77 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -8,89 +8,61 @@ namespace machete { using namespace vllm; -// -// Utils (type dispatching) -// - -template -static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { - if (type == vllm::kU4) { - return fn(cutlass::uint4b_t{}); - } else if (type == vllm::kU8) { - return fn(cutlass::uint8_t{}); - } else if (type == vllm::kU4B8) { - return fn(cutlass::vllm_uint4b8_t{}); - } else if (type == vllm::kU8B128) { - return fn(cutlass::vllm_uint8b128_t{}); - } else { - TORCH_CHECK(false, "Unsupported type ", type.str()); - } -} - -#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \ - AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__) - -#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, \ - AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__)) - -// -// Interface -// - -std::vector supported_schedules(ScalarTypeId const btype_id) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - vllm::ScalarType b_type = ScalarType::from_id(btype_id); - return scalar_type_dispatch(b_type, [&](auto BType) { - return GemmDispatcher::supported_schedules(); +std::vector supported_schedules( + at::ScalarType a_type, int64_t b_type_id, + c10::optional maybe_group_scales_type, + c10::optional maybe_group_zeros_type, + c10::optional maybe_channel_scales_type, + c10::optional maybe_token_scales_type, + c10::optional maybe_out_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return supported_schedules_dispatch({ + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type, + .maybe_group_zeros_type = maybe_group_zeros_type, + .maybe_channel_scales_type = maybe_channel_scales_type, + .maybe_token_scales_type = maybe_token_scales_type, + .maybe_out_type = maybe_out_type, }); -#else - TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif } -torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, - ScalarTypeId const btype_id, - c10::optional const& scales, - c10::optional const& zeros, - c10::optional group_size, - c10::optional const& C, - c10::optional alpha, c10::optional beta, - c10::optional schedule) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - ScalarType const btype = ScalarType::from_id(btype_id); - auto args = PyTorchArguments{.A = A, - .B = B, - .scales = scales, - .zeros = zeros, - .group_size = group_size, - .C = C, - .alpha = alpha, - .beta = beta, - .schedule = schedule}; - - return scalar_type_dispatch(btype, [&](auto BType) { - return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( - A.scalar_type(), "machete_gemm", [&] { - using ComputeType = equivalent_cutlass_type_t; - return GemmDispatcher::dispatch(args); - }); - }); -#else - TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif +torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, + int64_t b_type_id, + c10::optional const& maybe_out_type, + c10::optional const& maybe_group_scales, + c10::optional const& maybe_group_zeros, + c10::optional maybe_group_size, + c10::optional const& maybe_channel_scales, + c10::optional const& maybe_token_scales, + c10::optional maybe_schedule) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return mm_dispatch({.A = A, + .B = B, + .b_type = b_type, + .maybe_out_type = maybe_out_type, + .maybe_group_scales = maybe_group_scales, + .maybe_group_zeros = maybe_group_zeros, + .maybe_group_size = maybe_group_size, + .maybe_channel_scales = maybe_channel_scales, + .maybe_token_scales = maybe_token_scales, + .maybe_schedule = maybe_schedule}); } -torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) { - ScalarType const btype = ScalarType::from_id(btype_id); - return scalar_type_dispatch(btype, [&](auto BType) { - return PrepackBDispatcher::dispatch(B); - }); +torch::Tensor prepack_B( + torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id, + c10::optional const& maybe_group_scales_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return prepack_B_dispatch( + {.B = B, + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type}); } TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("machete_prepack_B", &prepack_B); - m.impl("machete_gemm", &gemm); + m.impl("machete_mm", &mm); } // use CatchAll since supported_schedules has no tensor arguments diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index a33e2660d760e..17837351324be 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -296,13 +296,9 @@ __global__ void Marlin_24( // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. - if (group_blocks != -1) { - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } else { - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; // Note that in the original Marlin kernel + // this is (threadIdx.x % 32) / 4 // Precompute which thread should not read memory in which iterations; this is // needed if there are more threads than required for a certain tilesize or @@ -910,13 +906,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, // than better compute utilization thread_k = 128; thread_m = 128; - } else if (prob_n <= 256) { + } else { thread_k = 64; thread_m = 256; - } else { - thread_k = 32; - thread_m = 512; } + // Also had + // if prob_n > 256 + // thread_k = 32; + // thread_m = 512; + // but this is broken, + // TODO(Lucas, Alex M): figure out why } int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction @@ -1079,6 +1078,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, // Verify A device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + TORCH_CHECK(a.dtype() == torch::kFloat16, + "A is not float16, currently only float16 is supported"); // Verify B device and strides TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); @@ -1091,6 +1092,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, // Verify scales device and strides TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + TORCH_CHECK(b_scales.dtype() == torch::kFloat16, + "A is not float16, currently only float16 is supported"); // Alloc C matrix const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3aec9350f017a..d6dab3c916e96 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -209,13 +209,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // conditionally compiled so impl in source file // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. - ops.def("machete_supported_schedules(int btype) -> str[]"); ops.def( - "machete_gemm(Tensor A, Tensor B, int btype, " - " Tensor? scales, Tensor? zeros, int? group_size, " - " Tensor? C, float? alpha, float? beta, str? schedule)" - "-> Tensor"); - ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor"); + "machete_supported_schedules(" + " ScalarType a_type," + " int b_type," + " ScalarType? maybe_group_scales_type," + " ScalarType? maybe_group_zeros_type," + " ScalarType? maybe_channel_scales_type," + " ScalarType? maybe_token_scales_type," + " ScalarType? maybe_out_type" + ") -> str[]"); + ops.def( + "machete_mm(" + " Tensor A," + " Tensor B," + " int b_type," + " ScalarType? out_type," + " Tensor? group_scales," + " Tensor? group_zeros," + " int? group_size," + " Tensor? channel_scales," + " Tensor? token_scales," + " str? schedule" + ") -> Tensor"); + ops.def( + "machete_prepack_B(" + " Tensor B," + " ScalarType a_type," + " int b_type," + " ScalarType? group_scales_type" + ") -> Tensor"); // conditionally compiled so impl registration is in source file ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); @@ -227,7 +250,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " "int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " - "bool has_zp, bool use_fp32_reduce) -> Tensor"); + "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor"); // conditionally compiled so impl registration is in source file // gptq_marlin repack from GPTQ. @@ -241,6 +264,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file +#endif // Dequantization for GGML. ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor"); @@ -257,6 +281,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); +#ifndef USE_ROCM // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. ops.def( "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " diff --git a/docs/dev-docker/README.md b/docs/dev-docker/README.md new file mode 100644 index 0000000000000..ad08078936c8d --- /dev/null +++ b/docs/dev-docker/README.md @@ -0,0 +1,428 @@ +# vllm FP8 Latency and Throughput benchmarks on AMD MI300x + +Documentation for vLLM Inferencing on AMD Instinct platforms. + +## Overview + +vLLM is a toolkit and library for large language model (LLM) inference and serving. It deploys the PagedAttention algorithm, which reduces memory consumption and increases throughput by leveraging dynamic key and value allocation in GPU memory. vLLM also incorporates many recent LLM acceleration and quantization algorithms, such as fp8 GeMM, fp8 KV cache, continuous batching, flash attention, hip graph, tensor parallel, GPTQ, AWQ, and token speculation. In addition, AMD implements high-performance custom kernels and modules in vLLM to enhance performance further. + +This documentation shows some reference performance numbers and the steps to reproduce it for the popular Llama 3.1 series models from Meta with a pre-built AMD vLLM docker optimized for an AMD Instinct™ MI300X accelerator. + +It includes: + + - ROCm™ 6.2.2 + + - vLLM 0.6.3 + + - PyTorch 2.5dev (nightly) + +## System configuration + +The performance data below was measured on a server with MI300X accelerators with the following system configuration. The performance might vary with different system configurations. + +| System | MI300X with 8 GPUs | +|---|---| +| BKC | 24.11 | +| ROCm | version ROCm 6.2.2 | +| amdgpu | build 2009461 | +| OS | Ubuntu 22.04 | +| Linux Kernel | 5.15.0-117-generic | +| BMCVersion | C2789.BC.0809.00 | +| BiosVersion | C2789.5.BS.1C11.AG.1 | +| CpldVersion | 02.02.00 | +| DCSCMCpldVersion | 02.02.00 | +| CX7 | FW 28.40.1000 | +| RAM | 1 TB | +| Host CPU | Intel(R) Xeon(R) Platinum 8480C | +| Cores | 224 | +| VRAM | 192 GB | +| Power cap | 750 W | +| SCLK/MCLK | 2100 Mhz / 1300 Mhz | + +## Pull latest + +You can pull the image with `docker pull rocm/vllm-dev:20241114-tuned` + +### What is New + + - MoE optimizations for Mixtral 8x22B, FP16 + - Llama 3.2 stability improvements + + +Gemms are tuned using PyTorch's Tunable Ops feature (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/cuda/tunable/README.md) +The gemms are automatically enabled in the docker image, and all stored gemm configs are kept in /app/_gemm_csv in the same image + +### Reproducing benchmark results + +### Use pre-quantized models + +To make it easier to run fp8 Llama 3.1 models on MI300X, the quantized checkpoints are available on AMD Huggingface space as follows + +- https://huggingface.co/amd/Meta-Llama-3.1-8B-Instruct-FP8-KV +- https://huggingface.co/amd/Meta-Llama-3.1-70B-Instruct-FP8-KV +- https://huggingface.co/amd/Meta-Llama-3.1-405B-Instruct-FP8-KV +- https://huggingface.co/amd/grok-1-FP8-KV + +Currently these models are private. Please join https://huggingface.co/amd to access. + +Download the model you want to run. + +These FP8 quantized checkpoints were generated with AMD’s Quark Quantizer. For more information about Quark, please refer to https://quark.docs.amd.com/latest/quark_example_torch_llm_gen.html + +### Quantize your own models +This step is optional for you to use quantized models on your own. Take Llama 3.1 405B as an example. + +Download the Model View the Meta-Llama-3.1-405B model at https://huggingface.co/meta-llama/Meta-Llama-3.1-405B. Ensure that you have been granted access, and apply for it if you do not have access. + +If you do not already have a HuggingFace token, open your user profile (https://huggingface.co/settings/profile), select "Access Tokens", press "+ Create New Token", and create a new Read token. + +Install the `huggingface-cli` (if not already available on your system) and log in with the token you created earlier and download the model. The instructions in this document assume that the model will be stored under `/data/llama-3.1`. You can store the model in a different location, but then you'll need to update other commands accordingly. The model is quite large and will take some time to download; it is recommended to use tmux or screen to keep your session running without getting disconnected. + + sudo pip install -U "huggingface_hub[cli]" + + huggingface-cli login + +Enter the token you created earlier; you do NOT need to save it as a git credential + +Create the directory for Llama 3.1 models (if it doesn't already exist) + + sudo mkdir -p /data/llama-3.1 + + sudo chmod -R a+w /data/llama-3.1 + +Download the model + + huggingface-cli download meta-llama/Meta-Llama-3.1-405B-Instruct --exclude "original/*" --local-dir /data/llama-3.1/Meta-Llama-3.1-405B-Instruct + +Similarly, you can download Meta-Llama-3.1-70B and Meta-Llama-3.1-8B. + +[Download and install Quark](https://quark.docs.amd.com/latest/install.html) + +Run the quantization script in the example folder using the following command line: +export MODEL_DIR = [local model checkpoint folder] or meta-llama/Meta-Llama-3.1-405B-Instruct +#### single GPU + python3 quantize_quark.py \ + --model_dir $MODEL_DIR \ + --output_dir Meta-Llama-3.1-405B-Instruct-FP8-KV \ + --quant_scheme w_fp8_a_fp8 \ + --kv_cache_dtype fp8 \ + --num_calib_data 128 \ + --model_export quark_safetensors \ + --no_weight_matrix_merge + +#### If model size is too large for single GPU, please use multi GPU instead. + python3 quantize_quark.py \ + --model_dir $MODEL_DIR \ + --output_dir Meta-Llama-3.1-405B-Instruct-FP8-KV \ + --quant_scheme w_fp8_a_fp8 \ + --kv_cache_dtype fp8 \ + --num_calib_data 128 \ + --model_export quark_safetensors \ + --no_weight_matrix_merge \ + --multi_gpu + + +### Launch AMD vLLM Docker + +Download and launch the docker, + + docker run -it --rm --ipc=host --network=host --group-add render \ + --privileged --security-opt seccomp=unconfined \ + --cap-add=CAP_SYS_ADMIN --cap-add=SYS_PTRACE \ + --device=/dev/kfd --device=/dev/dri --device=/dev/mem \ + -v /data/llama-3.1:/data/llm \ + docker pull rocm/vllm-dev:20241114-tuned + +### Benchmark with AMD vLLM Docker + +There are some system settings to be configured for optimum performance on MI300X. + +#### NUMA balancing setting + +To optimize performance, disable automatic NUMA balancing. Otherwise, the GPU might hang until the periodic balancing is finalized. For further details, refer to the AMD Instinct MI300X system optimization guide. + +Disable automatic NUMA balancing + + sh -c 'echo 0 > /proc/sys/kernel/numa_balancing' + +Check if NUMA balancing is disabled (returns 0 if disabled) + + cat /proc/sys/kernel/numa_balancing + 0 + +#### LLM performance settings + +Some environment variables enhance the performance of the vLLM kernels and PyTorch's tunableOp on the MI300X accelerator. The settings below are already preconfigured in the Docker image. See the AMD Instinct MI300X workload optimization guide for more information. + +##### vLLM performance environment variables + + export VLLM_USE_TRITON_FLASH_ATTN=0 + export NCCL_MIN_NCHANNELS=112 + export VLLM_FP8_PADDING=1 + +You can set both PYTORCH_TUNABLEOP_ENABLED and PYTORCH_TUNABLEOP_TUNING to 1 to performance GEMM tuning for the 1st benchmark run. +It will take some time to complete the tuning during the benchmark. After tuning, it will generate several csv files as the performance lookup database. For the subsequent benchmark runs, you can keep + +PYTORCH_TUNABLEOP_ENABLED as 1 and set +PYTORCH_TUNABLEOP_TUNING to 0 to use the selected kernels. + +##### vLLM engine performance settings +vLLM provides a number of engine options which can be changed to improve performance. +Refer https://docs.vllm.ai/en/stable/models/engine_args.html for the complete list of vLLM engine options. +Below is a list of options which are useful: +- **--max-model-len** : Maximum context length supported by the model instance. Can be set to a lower value than model configuration value to improve performance and gpu memory utilization. +- **--max-num-batched-tokens** : The maximum prefill size, i.e., how many prompt tokens can be packed together in a single prefill. Set to a higher value to improve prefill performance at the cost of higher gpu memory utilization. 65536 works well for LLama models. +- **--max-num-seqs** : The maximum decode batch size. Set to a value higher than the default(256) to improve decode throughput. Higher values will also utilize more KV cache memory. Too high values can cause KV cache space to run out which will lead to decode preemption. 512/1024 works well for LLama models. +- **--max-seq-len-to-capture** : Maximum sequence length for which Hip-graphs are captured and utilized. It's recommended to use Hip-graphs for the best decode performance. The default value of this parameter is 8K, which is lower than the large context lengths supported by recent models such as LLama. Set this parameter to max-model-len or maximum context length supported by the model for best performance. +- **--gpu-memory-utilization** : The ratio of GPU memory reserved by a vLLM instance. Default value is 0.9. It's recommended to set this to 0.99 to increase KV cache space. + +Note: vLLM's server creation command line (vllm serve) supports the above parameters as command line arguments. However, vLLM's benchmark_latency and benchmark_throughput command lines may not include all of these flags as command line arguments. In that case, it might be necessary to add these parameters to the LLMEngine instance constructor inside the benchmark script. + +##### Online Gemm Tuning +Online Gemm tuning for small decode batch sizes can improve performance in some cases. e.g. Llama 70B upto Batch size 8 + +If you want to do limited online tuning use --enforce-eager and tune for particular batch sizes. See example below. + + export PYTORCH_TUNABLEOP_TUNING=1 + export PYTORCH_TUNABLEOP_ENABLED=1 + export PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS=100 + export PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS=10 + export PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE=1024 + export PYTORCH_TUNABLEOP_FILENAME=/app/tuned_gemm_csv/bench_latency_tune_device_%d_full.csv + + Run the following command for BS=1/2/4/8: + + python /app/vllm/benchmarks/benchmark_latency.py \ + --model \ + --quantization fp8 \ + --kv-cache-dtype fp8 \ + --dtype float16 \ + --max-model-len 8192 \ + --num-iters-warmup 5 \ + --num-iters 5 \ + --tensor-parallel-size 8 \ + --input-len 4096 \ + --output-len 512 \ + --batch-size \ + --num-scheduler-steps 10 \ + --enforce-eager + +The tuned file will be generated for device 0 only at /app/tuned_gemm_csv/bench_latency_tune_device_0_full.csv. Copy this file to /app/tuned_gemm_csv/bench_latency_tune_device__full.csv for D=1 through 7. + +After the above steps, retain the environment variables set earlier, but set export PYTORCH_TUNABLEOP_TUNING=0 to disable online tuning, and use the tuned solutions. + +##### Latency Benchmark + +Benchmark Meta-Llama-3.1-405B FP8 with input 128 tokens, output 128 tokens, batch size 32 and tensor parallelism 8 as an example, + + python /app/vllm/benchmarks/benchmark_latency.py \ + --model /data/llm/Meta-Llama-3.1-405B-Instruct-FP8-KV \ + --quantization fp8 \ + --kv-cache-dtype fp8 \ + --dtype half \ + --gpu-memory-utilization 0.99 \ + --distributed-executor-backend mp \ + --tensor-parallel-size 8 \ + --batch size 32 \ + --input-len 128 \ + --output-len 128 + +If you want to run Meta-Llama-3.1-405B FP16, please run + + python /app/vllm/benchmarks/benchmark_latency.py \ + --model /data/llm/Meta-Llama-3.1-405B-Instruct \ + --dtype float16 \ + --gpu-memory-utilization 0.99 \ + --distributed-executor-backend mp \ + --tensor-parallel-size 8 \ + --batch size 32 \ + --input-len 128 \ + --output-len 128 + +You can change various input-len, output-len, batch size and run the benchmark as well. When output-len is 1, it measures prefill latency (TTFT). +Decoding latency (TPOT) can be calculated based on the measured latency. + +For more information about the parameters, please run + + /app/vllm/benchmarks/benchmark_latency.py -h + +##### Throughput Benchmark + +Benchmark Meta-Llama-3.1-405B FP8 with input 128 tokens, output 128 tokens and tensor parallelism 8 as an example, + + python /app/vllm/benchmarks/benchmark_throughput.py \ + --model /data/llm/Meta-Llama-3.1-405B-Instruct-FP8-KV \ + --quantization fp8 \ + --kv-cache-dtype fp8 \ + --dtype half \ + --gpu-memory-utilization 0.99 \ + --num-prompts 2000 \ + --distributed-executor-backend mp \ + --num-scheduler-steps 10 \ + --tensor-parallel-size 8 \ + --input-len 128 \ + --output-len 128 + +If you want to run Meta-Llama-3.1-405B FP16, please run + + python /app/vllm/benchmarks/benchmark_throughput.py \ + --model /data/llm/Meta-Llama-3.1-405B-Instruct \ + --dtype float16 \ + --gpu-memory-utilization 0.99 \ + --num-prompts 2000 \ + --distributed-executor-backend mp \ + --num-scheduler-steps 10 \ + --tensor-parallel-size 8 \ + --input-len 128 \ + --output-len 128 \ + --swapspace 16 \ + --max-model-length 8192 \ + --max-num-batched-tokens 65536 \ + --gpu-memory-utilization 0.99 + +For fp8 quantized Llama3.18B/70B models: + + Recommend TP:1 for Llama3.1-8B, 8 for Llama3.1-70B + Recommend NSCHED: 10 for Llama3.1-8B, 8 for Llama3.1-70B + +You can change various input-len, output-len, num-prompts and run the benchmark as well. +Please note num-scheduler-step is a new feature added in vLLM 0.6.0. It can improve the decoding latency and throughput, however, it may increase the prefill latency. + +For more information about the parameters, please run + + /app/vllm/benchmarks/benchmark_throughput.py -h + +Tensor parallelism (TP) parameters depends on the model size. For Llama 3.1 70B and 8B model, TP 1 can be used as well for MI300X. In general, TP 8 and 1 is recommended to achieve the optimum performance. + +##### Online Server Benchmark + +Make the following changes if required + +/app/vllm/benchmarks/backend_request_func.py + +line 242 + "ignore_eos": True, + +/app/vllm/benchmarks/benchmark_serving.py +line 245 - interval = np.random.exponential(1.0 / request_rate) +line 245 + ## interval = np.random.exponential(1.0 / request_rate) +line 246 + interval = 1.0 / request_rate + +Benchmark Meta-Llama-3.1-70B with input 4096 tokens, output 512 tokens and tensor parallelism 8 as an example, + + vllm serve /data/llm/Meta-Llama-3.1-70B-Instruct-FP8-KV \ + --swap-space 16 \ + --disable-log-requests \ + --quantization fp8 \ + --kv-cache-dtype fp8 \ + --dtype float16 \ + --max-model-len 8192 \ + --tensor-parallel-size 8 \ + --max-num-batched-tokens 65536 \ + --gpu-memory-utilization 0.99 \ + --num_scheduler-steps 10 + +Change port (for example --port 8005) if port=8000 is currently being used by other processes. + +run client in a separate terminal. Use port_id from previous step else port-id=8000. + + python /app/vllm/benchmarks/benchmark_serving.py \ + --port 8000 \ + --model /data/llm/Meta-Llama-3.1-70B-Instruct-FP8-KV \ + --dataset-name random \ + --random-input-len 4096 \ + --random-output-len 512 \ + --request-rate 1 \ + --num-prompts 500 \ + --percentile-metrics ttft,tpot,itl,e2el + +Once all prompts are processed, terminate the server gracefully (ctrl+c). + +##### CPX mode + +Currently only CPX-NPS1 mode is supported. So ONLY tp=1 is supported in CPX mode. +But multiple instances can be started simultaneously (if needed) in CPX-NPS1 mode. + +Set GPUs in CPX mode + + rocm-smi --setcomputepartition cpx + +Example of running Llama3.1-8B on 1 CPX-NPS1 GPU with input 4096 and output 512. As mentioned above, tp=1. + + HIP_VISIBLE_DEVICES=0 \ + python3 /app/vllm/benchmarks/benchmark_throughput.py \ + --max-model-len 4608 \ + --num-scheduler-steps 10 \ + --num-prompts 100 \ + --model /data/llm/Meta-Llama-3.1-70B-Instruct-FP8-KV \ + --input-len 4096 \ + --output-len 512 \ + --dtype float16 \ + --tensor-parallel-size 1 \ + --output-json \ + --quantization fp8 \ + --gpu-memory-utilization 0.99 + +Set GPU to SPX mode. + + rocm-smi --setcomputepartition spx + +### Speculative Decoding + +Speculative decoding is one of the key features in vLLM. It has been supported on MI300. Here below is an example of the performance benchmark w/wo speculative decoding for Llama 3.1 405B with Llama 3.1 8B as the draft model. + +Without Speculative Decoding - + + python benchmark_latency.py --model /models/models--amd--Meta-Llama-3.1-405B-Instruct-FP8-KV/ --max-model-len 26720 -tp 8 --batch-size 1 --use-v2-block-manager --input-len 1024 --output-len 128 + +With Speculative Decoding - + + python benchmark_latency.py --model /models/models--amd--Meta-Llama-3.1-405B-Instruct-FP8-KV/ --max-model-len 26720 -tp 8 --batch-size 1 --use-v2-block-manager --input-len 1024 --output-len 128 --speculative-model /models/models--amd--Meta-Llama-3.1-8B-Instruct-FP8-KV/ --num-speculative-tokens 5 + +You should see some performance improvement about the e2e latency. + +### MMLU_PRO_Biology Accuracy Eval + +### fp16 +vllm (pretrained=models--meta-llama--Meta-Llama-3.1-405B-Instruct/snapshots/069992c75aed59df00ec06c17177e76c63296a26,dtype=float16,tensor_parallel_size=8), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 64 + +| Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr| +|-------|------:|--------------|-----:|-----------|---|-----:|---|-----:| +|biology| 0|custom-extract| 5|exact_match|↑ |0.8466|± |0.0135| + +### fp8 +vllm (pretrained=models--meta-llama--Meta-Llama-3.1-405B-Instruct/snapshots/069992c75aed59df00ec06c17177e76c63296a26,dtype=float16,quantization=fp8,quantized_weights_path=/llama.safetensors,tensor_parallel_size=8), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 32 + +| Tasks |Version| Filter |n-shot| Metric | |Value| |Stderr| +|-------|------:|--------------|-----:|-----------|---|----:|---|-----:| +|biology| 0|custom-extract| 5|exact_match|↑ |0.848|± |0.0134| + + +## Performance + +### LLaMA2/3 *MLPerf* 70B + +Please refer to the MLPerf instructions for recreating the MLPerf numbers. + +## Version + +### Release Notes +20240906a: Legacy quantization formats required `--quantization fp8_rocm` as a flag instead of `--quantization fp8` + +Updated: + +vLLM: https://github.com/ROCm/vllm/commit/5362727ec366c1542b2be7a520e7c44e5cc3ce30 +### Docker Manifest + +To reproduce the release docker: + +``` +git clone https://github.com/ROCm/vllm.git +cd vllm +git checkout 5362727ec366c1542b2be7a520e7c44e5cc3ce30 +docker build -f Dockerfile.rocm -t --build-arg BUILD_HIPBLASLT=1 --build-arg USE_CYTHON=1 . +``` + +For details on all the dependencies, please refer to: https://github.com/ROCm/vllm/blob/5362727ec366c1542b2be7a520e7c44e5cc3ce30/Dockerfile.rocm + + + diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index e3e35844405ac..8ea240f59c38f 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -12,7 +12,7 @@ pydantic >= 2.8 torch py-cpuinfo transformers -mistral_common >= 1.3.4 +mistral_common >= 1.5.0 aiohttp starlette openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args diff --git a/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png b/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png new file mode 100644 index 0000000000000..bbf46286cfe5d Binary files /dev/null and b/docs/source/assets/design/arch_overview/entrypoints.excalidraw.png differ diff --git a/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png b/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png new file mode 100644 index 0000000000000..ade1d602a9187 Binary files /dev/null and b/docs/source/assets/design/arch_overview/llm_engine.excalidraw.png differ diff --git a/docs/source/automatic_prefix_caching/details.md b/docs/source/automatic_prefix_caching/details.md index 2d3214e28ed93..17f806217aa65 100644 --- a/docs/source/automatic_prefix_caching/details.md +++ b/docs/source/automatic_prefix_caching/details.md @@ -25,7 +25,7 @@ With this mapping, we can add another indirection in vLLM’s KV cache managemen This design achieves automatic prefix caching without the need of maintaining a tree structure among the KV blocks. More specifically, all of the blocks are independent of each other and can be allocated and freed by itself, which enables us to manages the KV cache as ordinary caches in operating system. -# Generalized Caching Policy +## Generalized Caching Policy Keeping all the KV blocks in a hash table enables vLLM to cache KV blocks from earlier requests to save memory and accelerate the computation of future requests. For example, if a new request shares the system prompt with the previous request, the KV cache of the shared prompt can directly be used for the new request without recomputation. However, the total KV cache space is limited and we have to decide which KV blocks to keep or evict when the cache is full. diff --git a/docs/source/design/arch_overview.rst b/docs/source/design/arch_overview.rst new file mode 100644 index 0000000000000..bc3f509f0a66e --- /dev/null +++ b/docs/source/design/arch_overview.rst @@ -0,0 +1,274 @@ +.. _arch_overview: + +Architecture Overview +====================== + +This document provides an overview of the vLLM architecture. + +.. contents:: Table of Contents + :local: + :depth: 2 + +Entrypoints +----------- + +vLLM provides a number of entrypoints for interacting with the system. The +following diagram shows the relationship between them. + +.. image:: /assets/design/arch_overview/entrypoints.excalidraw.png + :alt: Entrypoints Diagram + +LLM Class +^^^^^^^^^ + +The LLM class provides the primary Python interface for doing offline inference, +which is interacting with a model without using a separate model inference +server. + +Here is a sample of `LLM` class usage: + +.. code-block:: python + + from vllm import LLM, SamplingParams + + # Define a list of input prompts + prompts = [ + "Hello, my name is", + "The capital of France is", + "The largest ocean is", + ] + + # Define sampling parameters + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # Initialize the LLM engine with the OPT-125M model + llm = LLM(model="facebook/opt-125m") + + # Generate outputs for the input prompts + outputs = llm.generate(prompts, sampling_params) + + # Print the generated outputs + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +More API details can be found in the :doc:`Offline Inference +` section of the API docs. + +The code for the `LLM` class can be found in `vllm/entrypoints/llm.py +`_. + +OpenAI-compatible API server +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The second primary interface to vLLM is via its OpenAI-compatible API server. +This server can be started using the `vllm serve` command. + +.. code-block:: bash + + vllm serve + +The code for the `vllm` CLI can be found in `vllm/scripts.py +`_. + +Sometimes you may see the API server entrypoint used directly instead of via the +`vllm` CLI command. For example: + +.. code-block:: bash + + python -m vllm.entrypoints.openai.api_server --model + +That code can be found in `vllm/entrypoints/openai/api_server.py +`_. + +More details on the API server can be found in the :doc:`OpenAI Compatible +Server ` document. + +LLM Engine +---------- + +The `LLMEngine` and `AsyncLLMEngine` classes are central to the functioning of +the vLLM system, handling model inference and asynchronous request processing. + +.. image:: /assets/design/arch_overview/llm_engine.excalidraw.png + :alt: LLMEngine Diagram + +LLMEngine +^^^^^^^^^ + +The `LLMEngine` class is the core component of the vLLM engine. It is +responsible for receiving requests from clients and generating outputs from the +model. The `LLMEngine` includes input processing, model execution (possibly +distributed across multiple hosts and/or GPUs), scheduling, and output +processing. + +- **Input Processing**: Handles tokenization of input text using the specified + tokenizer. + +- **Scheduling**: Chooses which requests are processed in each step. + +- **Model Execution**: Manages the execution of the language model, including + distributed execution across multiple GPUs. + +- **Output Processing**: Processes the outputs generated by the model, decoding the + token IDs from a language model into human-readable text. + +The code for `LLMEngine` can be found in `vllm/engine/llm_engine.py`_. + +.. _vllm/engine/llm_engine.py: https://github.com/vllm-project/vllm/tree/main/vllm/engine/llm_engine.py + +AsyncLLMEngine +^^^^^^^^^^^^^^ + +The `AsyncLLMEngine` class is an asynchronous wrapper for the `LLMEngine` class. +It uses `asyncio` to create a background loop that continuously processes +incoming requests. The `AsyncLLMEngine` is designed for online serving, where it +can handle multiple concurrent requests and stream outputs to clients. + +The OpenAI-compatible API server uses the `AsyncLLMEngine`. There is also a demo +API server that serves as a simpler example in +`vllm/entrypoints/api_server.py`_. + +.. _vllm/entrypoints/api_server.py: https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/api_server.py + +The code for `AsyncLLMEngine` can be found in `vllm/engine/async_llm_engine.py`_. + +.. _vllm/engine/async_llm_engine.py: https://github.com/vllm-project/vllm/tree/main/vllm/engine/async_llm_engine.py + +Worker +------ + +A worker is a process that runs the model inference. vLLM follows the common +practice of using one process to control one accelerator device, such as GPUs. +For example, if we use tensor parallelism of size 2 and pipeline parallelism of +size 2, we will have 4 workers in total. Workers are identified by their +``rank`` and ``local_rank``. ``rank`` is used for global orchestration, while +``local_rank`` is mainly used for assigning the accelerator device and accessing +local resources such as the file system and shared memory. + +Model Runner +------------ + +Every worker has one model runner object, responsible for loading and running +the model. Much of the model execution logic resides here, such as preparing +input tensors and capturing cudagraphs. + +Model +----- + +Every model runner object has one model object, which is the actual +``torch.nn.Module`` instance. See :ref:`huggingface_integration` for how various +configurations affect the class we ultimately get. + +Class Hierarchy +--------------- + +The following figure shows the class hierarchy of vLLM: + + .. figure:: /assets/design/hierarchy.png + :alt: query + :width: 100% + :align: center + +There are several important design choices behind this class hierarchy: + +1. **Extensibility**: All classes in the hierarchy accept a configuration object +containing all the necessary information. The `VllmConfig +`__ +class is the main configuration object that is passed around. The class +hierarchy is quite deep, and every class needs to read the configuration it is +interested in. By encapsulating all configurations in one object, we can easily +pass the configuration object around and access the configuration we need. +Suppose we want to add a new feature (this is often the case given how fast the +field of LLM inference is evolving) that only touches the model runner. We will +have to add a new configuration option in the `VllmConfig` class. Since we pass +the whole config object around, we only need to add the configuration option to +the `VllmConfig` class, and the model runner can access it directly. We don't +need to change the constructor of the engine, worker, or model class to pass the +new configuration option. + +2. **Uniformity**: The model runner needs a unified interface to create and +initialize the model. vLLM supports more than 50 types of popular open-source +models. Each model has its own initialization logic. If the constructor +signature varies with models, the model runner does not know how to call the +constructor accordingly, without complicated and error-prone inspection logic. +By making the constructor of the model class uniform, the model runner can +easily create and initialize the model without knowing the specific model type. +This is also useful for composing models. Vision-language models often consist +of a vision model and a language model. By making the constructor uniform, we +can easily create a vision model and a language model and compose them into a +vision-language model. + +.. note:: + + To support this change, all vLLM models' signatures have been updated to: + + .. code-block:: python + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + + To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: + + .. code-block:: python + + class MyOldModel(nn.Module): + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + ... + + from vllm.config import VllmConfig + class MyNewModel(MyOldModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + super().__init__(config, cache_config, quant_config, lora_config, prefix) + + if __version__ >= "0.6.4": + MyModel = MyNewModel + else: + MyModel = MyOldModel + + This way, the model can work with both old and new versions of vLLM. + +3. **Sharding and Quantization at Initialization**: Certain features require +changing the model weights. For example, tensor parallelism needs to shard the +model weights, and quantization needs to quantize the model weights. There are +two possible ways to implement this feature. One way is to change the model +weights after the model is initialized. The other way is to change the model +weights during the model initialization. vLLM chooses the latter. The first +approach is not scalable to large models. Suppose we want to run a 405B model +(with roughly 810GB weights) with 16 H100 80GB GPUs. Ideally, every GPU should +only load 50GB weights. If we change the model weights after the model is +initialized, we need to load the full 810GB weights to every GPU and then shard +the weights, leading to a huge memory overhead. Instead, if we shard the weights +during the model initialization, every layer will only create a shard of the +weights it needs, leading to a much smaller memory overhead. The same idea +applies to quantization. Note that we also add an additional argument ``prefix`` +to the model's constructor so that the model can initialize itself differently +based on the prefix. This is useful for non-uniform quantization, where +different parts of the model are quantized differently. The ``prefix`` is +usually an empty string for the top-level model and a string like ``"vision"`` +or ``"language"`` for the sub-models. In general, it matches the name of the +module's state dict in the checkpoint file. + +One disadvantage of this design is that it is hard to write unit tests for +individual components in vLLM because every component needs to be initialized by +a complete config object. We solve this problem by providing a default +initialization function that creates a default config object with all fields set +to ``None``. If the component we want to test only cares about a few fields in +the config object, we can create a default config object and set the fields we +care about. This way, we can test the component in isolation. Note that many +tests in vLLM are end-to-end tests that test the whole system, so this is not a +big problem. + +In summary, the complete config object ``VllmConfig`` can be treated as an +engine-level global state that is shared among all vLLM classes. diff --git a/docs/source/design/class_hierarchy.rst b/docs/source/design/class_hierarchy.rst deleted file mode 100644 index 58a888b17ba53..0000000000000 --- a/docs/source/design/class_hierarchy.rst +++ /dev/null @@ -1,74 +0,0 @@ -.. _class_hierarchy: - -vLLM's Class Hierarchy -======================= - -This document describes the class hierarchy of vLLM. We will explain the relationships between the core classes, their responsibilities, and the design choices behind them to make vLLM more modular and extensible. - -1. **Entrypoints**: vLLM has two entrypoints: `command line usage `__ with ``vllm serve`` for launching an OpenAI-API compatible server, and `library-style usage `__ with the ``vllm.LLM`` class for running inference in a Python script. These are user-facing entrypoints that end-users interact with. Under the hood, both create an engine object to handle model inference. - -2. **Engine**: Each vLLM instance contains one engine object, orchestrating and serving as the control plane for model inference. Depending on the configuration, the engine can create multiple workers to handle the inference workload. - -3. **Worker**: A worker is a process that runs the model inference. vLLM follows the common practice of using one process to control one accelerator device, such as GPUs. For example, if we use tensor parallelism of size 2 and pipeline parallelism of size 2, we will have 4 workers in total. Workers are identified by their ``rank`` and ``local_rank``. ``rank`` is used for global orchestration, while ``local_rank`` is mainly used for assigning the accelerator device and accessing local resources such as the file system and shared memory. - -4. **Model Runner**: Every worker has one model runner object, responsible for loading and running the model. Much of the model execution logic resides here, such as preparing input tensors and capturing cudagraphs. - -5. **Model**: Every model runner object has one model object, which is the actual ``torch.nn.Module`` instance. See :ref:`huggingface_integration` for how various configurations affect the class we ultimately get. - -The following figure shows the class hierarchy of vLLM: - - .. figure:: ../assets/design/hierarchy.png - :alt: query - :width: 100% - :align: center - -There are several important design choices behind this class hierarchy: - -1. **Extensibility**: All classes in the hierarchy accept a configuration object containing all the necessary information. The `VllmConfig `__ class is the main configuration object that is passed around. The class hierarchy is quite deep, and every class needs to read the configuration it is interested in. By encapsulating all configurations in one object, we can easily pass the configuration object around and access the configuration we need. Suppose we want to add a new feature (this is often the case given how fast the field of LLM inference is evolving) that only touches the model runner. We will have to add a new configuration option in the `VllmConfig` class. Since we pass the whole config object around, we only need to add the configuration option to the `VllmConfig` class, and the model runner can access it directly. We don't need to change the constructor of the engine, worker, or model class to pass the new configuration option. - -2. **Uniformity**: The model runner needs a unified interface to create and initialize the model. vLLM supports more than 50 types of popular open-source models. Each model has its own initialization logic. If the constructor signature varies with models, the model runner does not know how to call the constructor accordingly, without complicated and error-prone inspection logic. By making the constructor of the model class uniform, the model runner can easily create and initialize the model without knowing the specific model type. This is also useful for composing models. Vision-language models often consist of a vision model and a language model. By making the constructor uniform, we can easily create a vision model and a language model and compose them into a vision-language model. - -.. note:: - - To support this change, all vLLM models' signatures have been updated to: - - .. code-block:: python - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - - To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: - - .. code-block:: python - - class MyOldModel(nn.Module): - def __init__( - self, - config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prefix: str = "", - ) -> None: - ... - - from vllm.config import VllmConfig - class MyNewModel(MyOldModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - super().__init__(config, cache_config, quant_config, lora_config, prefix) - - if __version__ >= "0.6.4": - MyModel = MyNewModel - else: - MyModel = MyOldModel - - This way, the model can work with both old and new versions of vLLM. - -3. **Sharding and Quantization at Initialization**: Certain features require changing the model weights. For example, tensor parallelism needs to shard the model weights, and quantization needs to quantize the model weights. There are two possible ways to implement this feature. One way is to change the model weights after the model is initialized. The other way is to change the model weights during the model initialization. vLLM chooses the latter. The first approach is not scalable to large models. Suppose we want to run a 405B model (with roughly 810GB weights) with 16 H100 80GB GPUs. Ideally, every GPU should only load 50GB weights. If we change the model weights after the model is initialized, we need to load the full 810GB weights to every GPU and then shard the weights, leading to a huge memory overhead. Instead, if we shard the weights during the model initialization, every layer will only create a shard of the weights it needs, leading to a much smaller memory overhead. The same idea applies to quantization. Note that we also add an additional argument ``prefix`` to the model's constructor so that the model can initialize itself differently based on the prefix. This is useful for non-uniform quantization, where different parts of the model are quantized differently. The ``prefix`` is usually an empty string for the top-level model and a string like ``"vision"`` or ``"language"`` for the sub-models. In general, it matches the name of the module's state dict in the checkpoint file. - -One disadvantage of this design is that it is hard to write unit tests for individual components in vLLM because every component needs to be initialized by a complete config object. We solve this problem by providing a default initialization function that creates a default config object with all fields set to ``None``. If the component we want to test only cares about a few fields in the config object, we can create a default config object and set the fields we care about. This way, we can test the component in isolation. Note that many tests in vLLM are end-to-end tests that test the whole system, so this is not a big problem. - -In summary, the complete config object ``VllmConfig`` can be treated as an engine-level global state that is shared among all vLLM classes. diff --git a/docs/source/design/plugin_system.rst b/docs/source/design/plugin_system.rst index bfca702b9267a..5a96cc8b3a464 100644 --- a/docs/source/design/plugin_system.rst +++ b/docs/source/design/plugin_system.rst @@ -8,7 +8,7 @@ The community frequently requests the ability to extend vLLM with custom feature How Plugins Work in vLLM ------------------------ -Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`class_hierarchy`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins `__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work. +Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`arch_overview`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins `__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work. How vLLM Discovers Plugins -------------------------- @@ -59,4 +59,4 @@ Guidelines for Writing Plugins Compatibility Guarantee ----------------------- -vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. \ No newline at end of file +vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development. diff --git a/docs/source/getting_started/arm-installation.rst b/docs/source/getting_started/arm-installation.rst new file mode 100644 index 0000000000000..7b457df92c11d --- /dev/null +++ b/docs/source/getting_started/arm-installation.rst @@ -0,0 +1,50 @@ +.. _installation_arm: + +Installation for ARM CPUs +========================= + +vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform. This guide provides installation instructions specific to ARM. For additional details on supported features, refer to the x86 platform documentation covering: + +* CPU backend inference capabilities +* Relevant runtime environment variables +* Performance optimization tips + +ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. +Contents: + +1. :ref:`Requirements ` +2. :ref:`Quick Start with Dockerfile ` +3. :ref:`Building from Source ` + +.. _arm_backend_requirements: + +Requirements +------------ + +* **Operating System**: Linux or macOS +* **Compiler**: gcc/g++ >= 12.3.0 (optional, but recommended) +* **Instruction Set Architecture (ISA)**: NEON support is required + +.. _arm_backend_quick_start_dockerfile: + +Quick Start with Dockerfile +--------------------------- + +You can quickly set up vLLM on ARM using Docker: + +.. code-block:: console + + $ docker build -f Dockerfile.arm -t vllm-cpu-env --shm-size=4g . + $ docker run -it \ + --rm \ + --network=host \ + --cpuset-cpus= \ + --cpuset-mems= \ + vllm-cpu-env + +.. _build_arm_backend_from_source: + +Building from Source +-------------------- + +To build vLLM from source on Ubuntu 22.04 or other Linux distributions, follow a similar process as with x86. Testing has been conducted on AWS Graviton3 instances for compatibility. diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 69530fd778c55..649de1cd9b53c 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -5,11 +5,11 @@ Installation with CPU vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features: -- Tensor Parallel (``-tp = N``) -- Quantization (``INT8 W8A8, AWQ``) - -.. note:: - More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon. +- Tensor Parallel +- Model Quantization (``INT8 W8A8, AWQ``) +- Chunked-prefill +- Prefix-caching +- FP8-E5M2 KV-Caching (TODO) Table of contents: diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 77bf550601346..0c1afcbd7c0b9 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -86,7 +86,6 @@ If GPU/CPU communication cannot be established, you can use the following Python from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank) - pynccl.disabled = False s = torch.cuda.Stream() with torch.cuda.stream(s): diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 68c1a56660fa4..249e08278ff8f 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -4,7 +4,7 @@ Installation with Intel® Gaudi® AI Accelerators This README provides instructions on running vLLM with Intel Gaudi devices. Requirements and Installation -============================= +----------------------------- Please follow the instructions provided in the `Gaudi Installation Guide `__ @@ -13,7 +13,7 @@ please follow the methods outlined in the `Optimizing Training Platform Guide `__. Requirements ------------- +~~~~~~~~~~~~ - OS: Ubuntu 22.04 LTS - Python: 3.10 @@ -22,7 +22,7 @@ Requirements Quick start using Dockerfile ----------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: console $ docker build -f Dockerfile.hpu -t vllm-hpu-env . @@ -34,10 +34,10 @@ Quick start using Dockerfile Build from source ------------------ +~~~~~~~~~~~~~~~~~ Environment verification -~~~~~~~~~~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^^^^^^^^^^ To verify that the Intel Gaudi software was correctly installed, run: @@ -53,7 +53,7 @@ Verification `__ @@ -107,7 +107,7 @@ Supported Features - Attention with Linear Biases (ALiBi) Unsupported Features -==================== +-------------------- - Beam search - LoRA adapters @@ -115,7 +115,7 @@ Unsupported Features - Prefill chunking (mixed-batch inferencing) Supported Configurations -======================== +------------------------ The following configurations have been validated to be function with Gaudi2 devices. Configurations that are not listed may or may not work. @@ -152,10 +152,10 @@ Gaudi2 devices. Configurations that are not listed may or may not work. with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling Performance Tuning -================== +------------------ Execution modes ---------------- +~~~~~~~~~~~~~~~ Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via ``PT_HPU_LAZY_MODE`` environment variable), and ``--enforce-eager`` flag. @@ -184,7 +184,7 @@ Currently in vLLM for HPU we support four execution modes, depending on selected Bucketing mechanism -------------------- +~~~~~~~~~~~~~~~~~~~ Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler `__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``. @@ -233,7 +233,7 @@ As an example, if a request of 3 sequences, with max sequence length of 412 come Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. Warmup ------- +~~~~~~ Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: @@ -257,7 +257,7 @@ This example uses the same buckets as in *Bucketing mechanism* section. Each out Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. HPU Graph capture ------------------ +~~~~~~~~~~~~~~~~~ `HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. @@ -321,7 +321,7 @@ Each described step is logged by vLLM server, as follows (negative values corres Recommended vLLM Parameters ---------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~ - We recommend running inference on Gaudi 2 with ``block_size`` of 128 for BF16 data type. Using default values (16, 32) might lead to @@ -333,7 +333,7 @@ Recommended vLLM Parameters If you encounter out-of-memory issues, see troubleshooting section. Environment variables ---------------------- +~~~~~~~~~~~~~~~~~~~~~ **Diagnostic and profiling knobs:** @@ -380,7 +380,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM - ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPU Graphs Troubleshooting: Tweaking HPU Graphs -==================================== +------------------------------------ If you experience device out-of-memory issues or want to attempt inference at higher batch sizes, try tweaking HPU Graphs by following diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index f02626bda4c64..e3dbbc9affe66 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -170,6 +170,18 @@ To build vLLM using an existing PyTorch installation: $ pip install -e . --no-build-isolation +Use the local cutlass for compilation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Currently, before starting the build process, vLLM fetches cutlass code from GitHub. However, there may be scenarios where you want to use a local version of cutlass instead. +To achieve this, you can set the environment variable VLLM_CUTLASS_SRC_DIR to point to your local cutlass directory. + +.. code-block:: console + + $ git clone https://github.com/vllm-project/vllm.git + $ cd vllm + $ VLLM_CUTLASS_SRC_DIR=/path/to/cutlass pip install -e . + + Troubleshooting ~~~~~~~~~~~~~~~ diff --git a/docs/source/index.rst b/docs/source/index.rst index 3b2698a8845ed..0692e949f1c77 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -67,6 +67,7 @@ Documentation getting_started/openvino-installation getting_started/cpu-installation getting_started/gaudi-installation + getting_started/arm-installation getting_started/neuron-installation getting_started/tpu-installation getting_started/xpu-installation @@ -101,6 +102,7 @@ Documentation models/engine_args models/lora models/vlm + models/structured_outputs models/spec_decode models/performance @@ -156,7 +158,7 @@ Documentation :maxdepth: 2 :caption: Design - design/class_hierarchy + design/arch_overview design/huggingface_integration design/plugin_system design/input_processing/model_inputs_index diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index a70ebf99c746f..df06d736ca86b 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -38,41 +38,70 @@ For instance, vLLM's `OPT model Union[Tuple, CausalLMOutputWithPast]: - + positions: torch.Tensor, - + kv_caches: List[torch.Tensor], - + attn_metadata: AttentionMetadata, - + ) -> Optional[SamplerOutput]: - -1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. -2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture. +To ensure compatibility with vLLM, your model must meet the following requirements: + +Initialization Code +^^^^^^^^^^^^^^^^^^^ + +All vLLM modules within the model must include a ``prefix`` argument in their constructor. This ``prefix`` is typically the full name of the module in the model's state dictionary and is crucial for: + +* Runtime support: vLLM's attention operators are registered in a model's state by their full names. Each attention operator must have a unique prefix as its layer name to avoid conflicts. +* Non-uniform quantization support: A quantized checkpoint can selectively quantize certain layers while keeping others in full precision. By providing the ``prefix`` during initialization, vLLM can match the current layer's ``prefix`` with the quantization configuration to determine if the layer should be initialized in quantized mode. + +The initialization code should look like this: + +.. code-block:: python + + from torch import nn + from vllm.config import VllmConfig + from vllm.attention import Attention + + class MyAttention(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str): + super().__init__() + self.attn = Attention(prefix=f"{prefix}.attn") + + class MyDecoderLayer(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str): + super().__init__() + self.self_attn = MyAttention(prefix=f"{prefix}.self_attn") + + class MyModel(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str): + super().__init__() + self.layers = nn.ModuleList( + [MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") for i in range(vllm_config.model_config.hf_config.num_hidden_layers)] + ) + + class MyModelForCausalLM(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.model = MyModel(vllm_config, prefix=f"{prefix}.model") + +Computation Code +^^^^^^^^^^^^^^^^ + +Rewrite the :meth:`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat ``input_ids`` and ``positions`` as flattened tensors with a single batch size dimension, without a max-sequence length dimension. + +.. code-block:: python + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ... .. note:: Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM. +For reference, check out the `LLAMA model `__. vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out the `vLLM models `__ directory for more examples. 3. (Optional) Implement tensor parallelism and quantization support ------------------------------------------------------------------- diff --git a/docs/source/models/structured_outputs.rst b/docs/source/models/structured_outputs.rst new file mode 100644 index 0000000000000..484e1f17d191e --- /dev/null +++ b/docs/source/models/structured_outputs.rst @@ -0,0 +1,267 @@ +.. _structured_outputs: + +Structured Outputs +================== + +vLLM supports the generation of structured outputs using `outlines `_ or `lm-format-enforcer `_ as backends for the guided decoding. +This document shows you some examples of the different options that are available to generate structured outputs. + + +Online Inference (OpenAI API) +----------------------------- + +You can generate structured outputs using the OpenAI's `Completions `_ and `Chat `_ API. + +The following parameters are supported, which must be added as extra parameters: + +- ``guided_choice``: the output will be exactly one of the choices. +- ``guided_regex``: the output will follow the regex pattern. +- ``guided_json``: the output will follow the JSON schema. +- ``guided_grammar``: the output will follow the context free grammar. +- ``guided_whitespace_pattern``: used to override the default whitespace pattern for guided json decoding. +- ``guided_decoding_backend``: used to select the guided decoding backend to use. + +You can see the complete list of supported parameters on the `OpenAI Compatible Server `_ page. + +Now let´s see an example for each of the cases, starting with the ``guided_choice``, as it´s the easiest one: + +.. code-block:: python + + from openai import OpenAI + client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", + ) + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} + ], + extra_body={"guided_choice": ["positive", "negative"]}, + ) + print(completion.choices[0].message.content) + + +The next example shows how to use the ``guided_regex``. The idea is to generate an email address, given a simple regex template: + +.. code-block:: python + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n", + } + ], + extra_body={"guided_regex": "\w+@\w+\.com\n", "stop": ["\n"]}, + ) + print(completion.choices[0].message.content) + +One of the most relevant features in structured text generation is the option to generate a valid JSON with pre-defined fields and formats. +For this we can use the ``guided_json`` parameter in two different ways: + +- Using directly a `JSON Schema `_ +- Defining a `Pydantic model `_ and then extracting the JSON Schema from it (which is normally an easier option). + +The next example shows how to use the ``guided_json`` parameter with a Pydantic model: + +.. code-block:: python + + from pydantic import BaseModel + from enum import Enum + + class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + + class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + + json_schema = CarDescription.model_json_schema() + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate a JSON with the brand, model and car_type of the most iconic car from the 90's", + } + ], + extra_body={"guided_json": json_schema}, + ) + print(completion.choices[0].message.content) + +.. tip:: + While not strictly necessary, normally it´s better to indicate in the prompt that a JSON needs to be generated and which fields and how should the LLM fill them. + This can improve the results notably in most cases. + + +Finally we have the ``guided_grammar``, which probably is the most difficult one to use but it´s really powerful, as it allows us to define complete languages like SQL queries. +It works by using a context free EBNF grammar, which for example we can use to define a specific format of simplified SQL queries, like in the example below: + +.. code-block:: python + + simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + """ + + completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[ + { + "role": "user", + "content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.", + } + ], + extra_body={"guided_grammar": simplified_sql_grammar}, + ) + print(completion.choices[0].message.content) + +The complete code of the examples can be found on `examples/openai_chat_completion_structured_outputs.py `_. + +Experimental Automatic Parsing (OpenAI API) +-------------------------------------------- + +This section covers the OpenAI beta wrapper over the ``client.chat.completions.create()`` method that provides richer integrations with Python specific types. + +At the time of writing (``openai==1.54.4``), this is a "beta" feature in the OpenAI client library. Code reference can be found `here `_. + +For the following examples, vLLM was setup using ``vllm serve meta-llama/Llama-3.1-8B-Instruct`` + +Here is a simple example demonstrating how to get structured output using Pydantic models: + +.. code-block:: python + + from pydantic import BaseModel + from openai import OpenAI + + + class Info(BaseModel): + name: str + age: int + + + client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="dummy") + completion = client.beta.chat.completions.parse( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "My name is Cameron, I'm 28. What's my name and age?"}, + ], + response_format=Info, + extra_body=dict(guided_decoding_backend="outlines"), + ) + + message = completion.choices[0].message + print(message) + assert message.parsed + print("Name:", message.parsed.name) + print("Age:", message.parsed.age) + +Output: + +.. code-block:: console + + ParsedChatCompletionMessage[Testing](content='{"name": "Cameron", "age": 28}', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[], parsed=Testing(name='Cameron', age=28)) + Name: Cameron + Age: 28 + + +Here is a more complex example using nested Pydantic models to handle a step-by-step math solution: + +.. code-block:: python + + from typing import List + from pydantic import BaseModel + from openai import OpenAI + + + class Step(BaseModel): + explanation: str + output: str + + + class MathResponse(BaseModel): + steps: List[Step] + final_answer: str + + + client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="dummy") + completion = client.beta.chat.completions.parse( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=[ + {"role": "system", "content": "You are a helpful expert math tutor."}, + {"role": "user", "content": "Solve 8x + 31 = 2."}, + ], + response_format=MathResponse, + extra_body=dict(guided_decoding_backend="outlines"), + ) + + message = completion.choices[0].message + print(message) + assert message.parsed + for i, step in enumerate(message.parsed.steps): + print(f"Step #{i}:", step) + print("Answer:", message.parsed.final_answer) + +Output: + +.. code-block:: console + + ParsedChatCompletionMessage[MathResponse](content='{ "steps": [{ "explanation": "First, let\'s isolate the term with the variable \'x\'. To do this, we\'ll subtract 31 from both sides of the equation.", "output": "8x + 31 - 31 = 2 - 31"}, { "explanation": "By subtracting 31 from both sides, we simplify the equation to 8x = -29.", "output": "8x = -29"}, { "explanation": "Next, let\'s isolate \'x\' by dividing both sides of the equation by 8.", "output": "8x / 8 = -29 / 8"}], "final_answer": "x = -29/8" }', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[], parsed=MathResponse(steps=[Step(explanation="First, let's isolate the term with the variable 'x'. To do this, we'll subtract 31 from both sides of the equation.", output='8x + 31 - 31 = 2 - 31'), Step(explanation='By subtracting 31 from both sides, we simplify the equation to 8x = -29.', output='8x = -29'), Step(explanation="Next, let's isolate 'x' by dividing both sides of the equation by 8.", output='8x / 8 = -29 / 8')], final_answer='x = -29/8')) + Step #0: explanation="First, let's isolate the term with the variable 'x'. To do this, we'll subtract 31 from both sides of the equation." output='8x + 31 - 31 = 2 - 31' + Step #1: explanation='By subtracting 31 from both sides, we simplify the equation to 8x = -29.' output='8x = -29' + Step #2: explanation="Next, let's isolate 'x' by dividing both sides of the equation by 8." output='8x / 8 = -29 / 8' + Answer: x = -29/8 + +Offline Inference +----------------- + +Offline inference allows for the same types of guided decoding. +To use it, we´ll need to configure the guided decoding using the class ``GuidedDecodingParams`` inside ``SamplingParams``. +The main available options inside ``GuidedDecodingParams`` are: + +- ``json`` +- ``regex`` +- ``choice`` +- ``grammar`` +- ``backend`` +- ``whitespace_pattern`` + +These parameters can be used in the same way as the parameters from the Online Inference examples above. +One example for the usage of the ``choices`` parameter is shown below: + +.. code-block:: python + + from vllm import LLM, SamplingParams + from vllm.sampling_params import GuidedDecodingParams + + llm = LLM(model="HuggingFaceTB/SmolLM2-1.7B-Instruct") + + guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"]) + sampling_params = SamplingParams(guided_decoding=guided_decoding_params) + outputs = llm.generate( + prompts="Classify this sentiment: vLLM is wonderful!", + sampling_params=sampling_params, + ) + print(outputs[0].outputs[0].text) + +A complete example with all options can be found in `examples/offline_inference_structured_outputs.py `_. diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index e902d393f2f70..9f3b6f59068e2 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -139,6 +139,11 @@ Text Generation - :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc. - ✅︎ - ✅︎ + * - :code:`GlmForCausalLM` + - GLM-4 + - :code:`THUDM/glm-4-9b-chat-hf`, etc. + - ✅︎ + - ✅︎ * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. @@ -177,7 +182,7 @@ Text Generation * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. - - + - ✅︎ - ✅︎ * - :code:`JAISLMHeadModel` - Jais @@ -234,6 +239,11 @@ Text Generation - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc. - - ✅︎ + * - :code:`OLMo2ForCausalLM` + - OLMo2 + - :code:`allenai/OLMo2-7B-1124`, etc. + - + - ✅︎ * - :code:`OLMoEForCausalLM` - OLMoE - :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc. @@ -304,6 +314,11 @@ Text Generation - :code:`upstage/solar-pro-preview-instruct`, etc. - ✅︎ - ✅︎ + * - :code:`TeleChat2ForCausalLM` + - TeleChat2 + - :code:`TeleAI/TeleChat2-3B`, :code:`TeleAI/TeleChat2-7B`, :code:`TeleAI/TeleChat2-35B`, etc. + - ✅︎ + - ✅︎ * - :code:`XverseForCausalLM` - XVERSE - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. @@ -325,6 +340,11 @@ Text Embedding - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`BertModel` + - BERT-based + - :code:`BAAI/bge-base-en-v1.5`, etc. + - + - * - :code:`Gemma2Model` - Gemma2-based - :code:`BAAI/bge-multilingual-gemma2`, etc. @@ -337,9 +357,19 @@ Text Embedding - ✅︎ * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM` - Qwen2-based - - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc. + - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. - ✅︎ - ✅︎ + * - :code:`RobertaModel`, :code:`RobertaForMaskedLM` + - RoBERTa-based + - :code:`sentence-transformers/all-roberta-large-v1`, :code:`sentence-transformers/all-roberta-large-v1`, etc. + - + - + * - :code:`XLMRobertaModel` + - XLM-RoBERTa-based + - :code:`intfloat/multilingual-e5-large`, etc. + - + - .. important:: Some model architectures support both generation and embedding tasks. @@ -348,6 +378,17 @@ Text Embedding .. tip:: You can override the model's pooling method by passing :code:`--override-pooler-config`. +.. note:: + :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. + You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`. + +.. note:: + Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention. + You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly. + + On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention + despite being described otherwise on its model card. + Reward Modeling --------------- @@ -360,12 +401,21 @@ Reward Modeling - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`LlamaForCausalLM` + - Llama-based + - :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc. + - ✅︎ + - ✅︎ * - :code:`Qwen2ForRewardModel` - Qwen2-based - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc. - ✅︎ - ✅︎ +.. important:: + For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, + e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + .. note:: As an interim measure, these models are supported in both offline and online inference via Embeddings API. @@ -390,6 +440,36 @@ Classification .. note:: As an interim measure, these models are supported in both offline and online inference via Embeddings API. +Sentence Pair Scoring +--------------------- + +.. list-table:: + :widths: 25 25 50 5 5 + :header-rows: 1 + + * - Architecture + - Models + - Example HF Models + - :ref:`LoRA ` + - :ref:`PP ` + * - :code:`BertForSequenceClassification` + - BERT-based + - :code:`cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. + - + - + * - :code:`RobertaForSequenceClassification` + - RoBERTa-based + - :code:`cross-encoder/quora-roberta-base`, etc. + - + - + * - :code:`XLMRobertaForSequenceClassification` + - XLM-RoBERTa-based + - :code:`BAAI/bge-reranker-v2-m3`, etc. + - + - + +.. note:: + These models are supported in both offline and online inference via Score API. Multimodal Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -424,6 +504,12 @@ Text Generation - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`AriaForConditionalGeneration` + - Aria + - T + I + - :code:`rhymes-ai/Aria` + - + - ✅︎ * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - T + I\ :sup:`E` @@ -561,10 +647,10 @@ Text Generation | :sup:`+` Multiple items can be inputted per text prompt for this modality. .. note:: - vLLM currently only supports adding LoRA to the language backbone of multimodal models. + vLLM currently only supports adding LoRA to the language backbone of multimodal models. .. note:: - For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. + The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 Multimodal Embedding @@ -615,6 +701,9 @@ At vLLM, we are committed to facilitating the integration and support of third-p 2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results. +.. tip:: + When comparing the output of :code:`model.generate` from HuggingFace Transformers with the output of :code:`llm.generate` from vLLM, note that the former reads the model's generation config file (i.e., `generation_config.json `__) and applies the default parameters for generation, while the latter only uses the parameters passed to the function. Ensure all sampling parameters are identical when comparing outputs. + 3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback. 4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use. diff --git a/docs/source/quantization/fp8_e5m2_kvcache.rst b/docs/source/quantization/fp8_e5m2_kvcache.rst index 9ae07bcd3b991..b2d824427f786 100644 --- a/docs/source/quantization/fp8_e5m2_kvcache.rst +++ b/docs/source/quantization/fp8_e5m2_kvcache.rst @@ -4,7 +4,7 @@ FP8 E5M2 KV Cache ================== The int8/int4 quantization scheme requires additional scale GPU memory storage, which reduces the expected GPU memory benefits. -The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bflaot16 and fp8 to each other. +The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bfloat16 and fp8 to each other. Here is an example of how to enable this feature: diff --git a/docs/source/quantization/supported_hardware.rst b/docs/source/quantization/supported_hardware.rst index 9bf0cdb80376d..09f8e7112cf0c 100644 --- a/docs/source/quantization/supported_hardware.rst +++ b/docs/source/quantization/supported_hardware.rst @@ -27,7 +27,7 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✅︎ - ✗ - - ✗ + - ✅︎ - ✅︎ - ✗ - ✗ @@ -38,8 +38,8 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✅︎ - ✗ - - ✗ - - ✗ + - ✅︎ + - ✅︎ - ✗ - ✗ * - Marlin (GPTQ/AWQ/FP8) @@ -129,4 +129,4 @@ Notes: Please note that this compatibility chart may be subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. -For the most up-to-date information on hardware support and quantization methods, please check the `quantization directory `_ or consult with the vLLM development team. \ No newline at end of file +For the most up-to-date information on hardware support and quantization methods, please check the `quantization directory `_ or consult with the vLLM development team. diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/serving/compatibility_matrix.rst index f629b3ca78318..a93632ff36fb8 100644 --- a/docs/source/serving/compatibility_matrix.rst +++ b/docs/source/serving/compatibility_matrix.rst @@ -39,12 +39,13 @@ Feature x Feature - :abbr:`prmpt adptr (Prompt Adapter)` - :ref:`SD ` - CUDA graph + - :abbr:`emd (Embedding Models)` - :abbr:`enc-dec (Encoder-Decoder Models)` - :abbr:`logP (Logprobs)` - :abbr:`prmpt logP (Prompt Logprobs)` - :abbr:`async output (Async Output Processing)` - multi-step - - :abbr:`MM (Multimodal)` + - :abbr:`mm (Multimodal)` - best-of - beam-search - :abbr:`guided dec (Guided Decoding)` @@ -64,6 +65,7 @@ Feature x Feature - - - + - * - :ref:`APC ` - ✅ - @@ -80,6 +82,7 @@ Feature x Feature - - - + - * - :ref:`LoRA ` - `✗ `__ - ✅ @@ -96,6 +99,7 @@ Feature x Feature - - - + - * - :abbr:`prmpt adptr (Prompt Adapter)` - ✅ - ✅ @@ -112,8 +116,9 @@ Feature x Feature - - - + - * - :ref:`SD ` - - ✗ + - ✅ - ✅ - ✗ - ✅ @@ -128,6 +133,7 @@ Feature x Feature - - - + - * - CUDA graph - ✅ - ✅ @@ -144,6 +150,24 @@ Feature x Feature - - - + - + * - :abbr:`emd (Embedding Models)` + - ✗ + - ✗ + - ✗ + - ✗ + - ✗ + - ✗ + - + - + - + - + - + - + - + - + - + - * - :abbr:`enc-dec (Encoder-Decoder Models)` - ✗ - `✗ `__ @@ -151,6 +175,7 @@ Feature x Feature - ✗ - `✗ `__ - ✅ + - ✅ - - - @@ -166,7 +191,8 @@ Feature x Feature - ✅ - ✅ - ✅ - - ✅ + - ✅ + - ✗ - ✅ - - @@ -183,7 +209,8 @@ Feature x Feature - ✅ - `✗ `__ - ✅ - - ✅ + - ✗ + - ✅ - ✅ - - @@ -199,6 +226,7 @@ Feature x Feature - ✅ - ✗ - ✅ + - ✗ - ✗ - ✅ - ✅ @@ -215,6 +243,7 @@ Feature x Feature - ✅ - ✗ - ✅ + - ✗ - ✗ - ✅ - `✗ `__ @@ -224,14 +253,15 @@ Feature x Feature - - - - * - :abbr:`MM (Multimodal)` - - `✗ `__ + * - :abbr:`mm (Multimodal)` + - ✅ - `✗ `__ - `✗ `__ - ? - ? - ✅ - - ✗ + - ✅ + - ✅ - ✅ - ✅ - ✅ @@ -247,6 +277,7 @@ Feature x Feature - ✅ - `✗ `__ - ✅ + - ✗ - ✅ - ✅ - ✅ @@ -263,6 +294,7 @@ Feature x Feature - ✅ - `✗ `__ - ✅ + - ✗ - ✅ - ✅ - ✅ @@ -279,6 +311,7 @@ Feature x Feature - ? - ✅ - ✅ + - ✗ - ? - ✅ - ✅ @@ -311,7 +344,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - - ✗ + - ✅ - ✅ * - :ref:`APC ` - `✗ `__ @@ -319,7 +352,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - - ✗ + - ✅ - ✅ * - :ref:`LoRA ` - ✅ @@ -353,6 +386,14 @@ Feature x Hardware - ✅ - ✗ - ✅ + * - :abbr:`emd (Embedding Models)` + - ✅ + - ✅ + - ✅ + - ✅ + - ✅ + - ✅ + - ? * - :abbr:`enc-dec (Encoder-Decoder Models)` - ✅ - ✅ @@ -361,7 +402,7 @@ Feature x Hardware - ✅ - ✅ - ✗ - * - :abbr:`logP (Logprobs)` + * - :abbr:`mm (Multimodal)` - ✅ - ✅ - ✅ @@ -369,7 +410,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - * - :abbr:`prmpt logP (Prompt Logprobs)` + * - :abbr:`logP (Logprobs)` - ✅ - ✅ - ✅ @@ -377,29 +418,29 @@ Feature x Hardware - ✅ - ✅ - ✅ - * - :abbr:`async output (Async Output Processing)` + * - :abbr:`prmpt logP (Prompt Logprobs)` - ✅ - ✅ - ✅ - ✅ - ✅ - - ✗ - - ✗ - * - multi-step - ✅ - ✅ + * - :abbr:`async output (Async Output Processing)` - ✅ - ✅ - ✅ - - `✗ `__ - ✅ - * - :abbr:`MM (Multimodal)` - ✅ + - ✗ + - ✗ + * - multi-step - ✅ - ✅ - ✅ - ✅ - ✅ + - `✗ `__ - ✅ * - best-of - ✅ diff --git a/docs/source/serving/metrics.rst b/docs/source/serving/metrics.rst index 15e57bd3fec65..231111cd7b738 100644 --- a/docs/source/serving/metrics.rst +++ b/docs/source/serving/metrics.rst @@ -2,9 +2,34 @@ Production Metrics ================== vLLM exposes a number of metrics that can be used to monitor the health of the -system. These metrics are exposed via the `/metrics` endpoint on the vLLM +system. These metrics are exposed via the ``/metrics`` endpoint on the vLLM OpenAI compatible API server. +You can start the server using Python, or using [Docker](deploying_with_docker.rst): + +.. code-block:: console + + $ vllm serve unsloth/Llama-3.2-1B-Instruct + +Then query the endpoint to get the latest metrics from the server: + +.. code-block:: console + + $ curl http://0.0.0.0:8000/metrics + + # HELP vllm:iteration_tokens_total Histogram of number of tokens per engine_step. + # TYPE vllm:iteration_tokens_total histogram + vllm:iteration_tokens_total_sum{model_name="unsloth/Llama-3.2-1B-Instruct"} 0.0 + vllm:iteration_tokens_total_bucket{le="1.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="8.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="16.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="32.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="64.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="128.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="256.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="512.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + ... + The following metrics are exposed: .. literalinclude:: ../../../vllm/engine/metrics.py diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 79d032bf8b211..c39cef85897ed 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -44,6 +44,148 @@ We currently support the following OpenAI APIs: - This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst). - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.* +## Score API for Cross Encoder Models + +vLLM supports *cross encoders models* at the **/v1/score** endpoint, which is not an OpenAI API standard endpoint. You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). + +A ***Cross Encoder*** takes exactly two sentences / texts as input and either predicts a score or label for this sentence pair. It can for example predict the similarity of the sentence pair on a scale of 0 … 1. + +### Example of usage for a pair of a string and a list of texts + +In this case, the model will compare the first given text to each of the texts containing the list. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "text_1": "What is the capital of France?", + "text_2": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693570, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 0.001094818115234375 + ] + }, + { + "index": 1, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + +### Example of usage for a pair of two lists of texts + +In this case, the model will compare the one by one, making pairs by same index correspondent in each list. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "encoding_format": "float", + "text_1": [ + "What is the capital of Brazil?", + "What is the capital of France?" + ], + "text_2": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693447, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 1 + ] + }, + { + "index": 1, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + +### Example of usage for a pair of two strings + +In this case, the model will compare the strings of texts. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "encoding_format": "float", + "text_1": "What is the capital of France?", + "text_2": "The capital of France is Paris." +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693447, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + ## Extra Parameters vLLM supports a set of parameters that are not part of the OpenAI API. diff --git a/examples/disaggregated_prefill.sh b/examples/disaggregated_prefill.sh new file mode 100644 index 0000000000000..87155273a81d1 --- /dev/null +++ b/examples/disaggregated_prefill.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling +# We will launch 2 vllm instances (1 for prefill and 1 for decode), +# and then transfer the KV cache between them. + +echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧" +sleep 1 + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'cleanup' INT + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 + pkill -f python + echo "Cleanup complete. Exiting." + exit 0 +} + +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + +# install quart first -- required for disagg prefill proxy serve +if python3 -c "import quart" &> /dev/null; then + echo "Quart is already installed." +else + echo "Quart is not installed. Installing..." + python3 -m pip install quart +fi + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +# You can also adjust --kv-ip and --kv-port for distributed inference. + +# prefilling instance, which is the KV producer +CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & + +# decoding instance, which is the KV consumer +CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +# launch a proxy server that opens the service at port 8000 +# the workflow of this proxy: +# - send the request to prefill vLLM instance (port 8100), change max_tokens +# to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM +# instance +# NOTE: the usage of this API is subject to change --- in the future we will +# introduce "vllm connect" to connect between prefill and decode instances +python3 ../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & +sleep 1 + +# serve two example requests +output1=$(curl -X POST -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "San Francisco is a", +"max_tokens": 10, +"temperature": 0 +}') + +output2=$(curl -X POST -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "Santa Clara is a", +"max_tokens": 10, +"temperature": 0 +}') + + +# Cleanup commands +pgrep python | xargs kill -9 +pkill -f python + +echo "" + +sleep 1 + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉" +echo "" diff --git a/examples/logging_configuration.md b/examples/logging_configuration.md index 0d278b0392403..9ac8b13cd5eaf 100644 --- a/examples/logging_configuration.md +++ b/examples/logging_configuration.md @@ -118,7 +118,7 @@ configuration for the root vLLM logger and for the logger you wish to silence: { "formatters": { "vllm": { - "class": "vllm.logging.NewLineFormatter", + "class": "vllm.logging_utils.NewLineFormatter", "datefmt": "%m-%d %H:%M:%S", "format": "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" } diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 391ac6b9b6b03..23cc6e8539431 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,80 +1,22 @@ -from dataclasses import asdict - from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.utils import FlexibleArgumentParser - - -def get_prompts(num_prompts: int): - # The default sample prompts. - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - if num_prompts != len(prompts): - prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] - - return prompts - - -def main(args): - # Create prompts - prompts = get_prompts(args.num_prompts) - - # Create a sampling params object. - sampling_params = SamplingParams(n=args.n, - temperature=args.temperature, - top_p=args.top_p, - top_k=args.top_k, - max_tokens=args.max_tokens) - - # Create an LLM. - # The default model is 'facebook/opt-125m' - engine_args = EngineArgs.from_cli_args(args) - llm = LLM(**asdict(engine_args)) - - # Generate texts from the prompts. - # The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -if __name__ == '__main__': - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - group = parser.add_argument_group("SamplingParams options") - group.add_argument("--num-prompts", - type=int, - default=4, - help="Number of prompts used for inference") - group.add_argument("--max-tokens", - type=int, - default=16, - help="Generated output length for sampling") - group.add_argument('--n', - type=int, - default=1, - help='Number of generated sequences per prompt') - group.add_argument('--temperature', - type=float, - default=0.8, - help='Temperature for text generation') - group.add_argument('--top-p', - type=float, - default=0.95, - help='top_p for text generation') - group.add_argument('--top-k', - type=int, - default=-1, - help='top_k for text generation') - args = parser.parse_args() - main(args) +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="facebook/opt-125m") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file diff --git a/examples/offline_inference_cli.py b/examples/offline_inference_cli.py new file mode 100644 index 0000000000000..391ac6b9b6b03 --- /dev/null +++ b/examples/offline_inference_cli.py @@ -0,0 +1,80 @@ +from dataclasses import asdict + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def get_prompts(num_prompts: int): + # The default sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + if num_prompts != len(prompts): + prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] + + return prompts + + +def main(args): + # Create prompts + prompts = get_prompts(args.num_prompts) + + # Create a sampling params object. + sampling_params = SamplingParams(n=args.n, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + max_tokens=args.max_tokens) + + # Create an LLM. + # The default model is 'facebook/opt-125m' + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**asdict(engine_args)) + + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == '__main__': + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + group = parser.add_argument_group("SamplingParams options") + group.add_argument("--num-prompts", + type=int, + default=4, + help="Number of prompts used for inference") + group.add_argument("--max-tokens", + type=int, + default=16, + help="Generated output length for sampling") + group.add_argument('--n', + type=int, + default=1, + help='Number of generated sequences per prompt') + group.add_argument('--temperature', + type=float, + default=0.8, + help='Temperature for text generation') + group.add_argument('--top-p', + type=float, + default=0.95, + help='top_p for text generation') + group.add_argument('--top-k', + type=int, + default=-1, + help='top_k for text generation') + + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 7d5ef128bc8e0..ae158eef2ca4c 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -10,7 +10,7 @@ # Create an LLM. model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) -# Generate embedding. The output is a list of EmbeddingRequestOutputs. +# Generate embedding. The output is a list of PoolingRequestOutputs. outputs = model.encode(prompts) # Print the outputs. for output in outputs: diff --git a/examples/offline_inference_structured_outputs.py b/examples/offline_inference_structured_outputs.py new file mode 100644 index 0000000000000..00d864606eeff --- /dev/null +++ b/examples/offline_inference_structured_outputs.py @@ -0,0 +1,78 @@ +from enum import Enum + +from pydantic import BaseModel + +from vllm import LLM, SamplingParams +from vllm.sampling_params import GuidedDecodingParams + +llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100) + +# Guided decoding by Choice (list of possible options) +guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"]) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +outputs = llm.generate( + prompts="Classify this sentiment: vLLM is wonderful!", + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) + +# Guided decoding by Regex +guided_decoding_params = GuidedDecodingParams(regex="\w+@\w+\.com\n") +sampling_params = SamplingParams(guided_decoding=guided_decoding_params, + stop=["\n"]) +prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") +outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) +print(outputs[0].outputs[0].text) + + +# Guided decoding by JSON using Pydantic schema +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +json_schema = CarDescription.model_json_schema() + +guided_decoding_params = GuidedDecodingParams(json=json_schema) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") +outputs = llm.generate( + prompts=prompt, + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) + +# Guided decoding by Grammar +simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ +""" +guided_decoding_params = GuidedDecodingParams(grammar=simplified_sql_grammar) +sampling_params = SamplingParams(guided_decoding=guided_decoding_params) +prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") +outputs = llm.generate( + prompts=prompt, + sampling_params=sampling_params, +) +print(outputs[0].outputs[0].text) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 11af6880e1b5a..f08f22eec164a 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -402,6 +402,23 @@ def run_idefics3(question: str, modality: str): return llm, prompt, stop_token_ids +# Aria +def run_aria(question: str, modality: str): + assert modality == "image" + model_name = "rhymes-ai/Aria" + + llm = LLM(model=model_name, + tokenizer_mode="slow", + trust_remote_code=True, + dtype="bfloat16") + + prompt = (f"<|im_start|>user\n<|img|>\n{question}" + "<|im_end|>\n<|im_start|>assistant\n") + + stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -423,6 +440,7 @@ def run_idefics3(question: str, modality: str): "molmo": run_molmo, "glm4v": run_glm4v, "idefics3": run_idefics3, + "aria": run_aria, } diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index dc12df8d78211..788b604cfd4a0 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -321,6 +321,25 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: ) +def load_aria(question, image_urls: List[str]) -> ModelRequestData: + model_name = "rhymes-ai/Aria" + llm = LLM(model=model_name, + tokenizer_mode="slow", + trust_remote_code=True, + dtype="bfloat16", + limit_mm_per_prompt={"image": len(image_urls)}) + placeholders = "<|img|>\n" * len(image_urls) + prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n" + "<|im_start|>assistant\n") + stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None) + + model_example_map = { "phi3_v": load_phi3v, "h2ovl_chat": load_h2onvl, @@ -330,6 +349,7 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: "qwen_vl_chat": load_qwenvl_chat, "mllama": load_mllama, "idefics3": load_idefics3, + "aria": load_aria, } diff --git a/examples/openai_chat_completion_structured_outputs.py b/examples/openai_chat_completion_structured_outputs.py new file mode 100644 index 0000000000000..8c059c7ca07ce --- /dev/null +++ b/examples/openai_chat_completion_structured_outputs.py @@ -0,0 +1,94 @@ +from enum import Enum + +from openai import OpenAI +from pydantic import BaseModel + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="-", +) + +# Guided decoding by Choice (list of possible options) +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": "Classify this sentiment: vLLM is wonderful!" + }], + extra_body={"guided_choice": ["positive", "negative"]}, +) +print(completion.choices[0].message.content) + +# Guided decoding by Regex +prompt = ("Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n") + +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={ + "guided_regex": "\w+@\w+\.com\n", + "stop": ["\n"] + }, +) +print(completion.choices[0].message.content) + + +# Guided decoding by JSON using Pydantic schema +class CarType(str, Enum): + sedan = "sedan" + suv = "SUV" + truck = "Truck" + coupe = "Coupe" + + +class CarDescription(BaseModel): + brand: str + model: str + car_type: CarType + + +json_schema = CarDescription.model_json_schema() + +prompt = ("Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's") +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_json": json_schema}, +) +print(completion.choices[0].message.content) + +# Guided decoding by Grammar +simplified_sql_grammar = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ +""" + +prompt = ("Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table.") +completion = client.chat.completions.create( + model="Qwen/Qwen2.5-3B-Instruct", + messages=[{ + "role": "user", + "content": prompt, + }], + extra_body={"guided_grammar": simplified_sql_grammar}, +) +print(completion.choices[0].message.content) diff --git a/examples/openai_cross_encoder_score.py b/examples/openai_cross_encoder_score.py new file mode 100644 index 0000000000000..8c32eea5dd252 --- /dev/null +++ b/examples/openai_cross_encoder_score.py @@ -0,0 +1,58 @@ +"""Examples Python client Score for Cross Encoder Models +""" + +import argparse +import json +import pprint + +import requests + + +def post_http_request(prompt: json, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3") + args = parser.parse_args() + api_url = f"http://{args.host}:{args.port}/v1/score" + + model_name = args.model + + text_1 = "What is the capital of France?" + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 is string and text_2 is a list:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) + + text_1 = [ + "What is the capital of Brazil?", "What is the capital of France?" + ] + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 and text_2 are lists:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) + + text_1 = "What is the capital of Brazil?" + text_2 = "The capital of Brazil is Brasilia." + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 and text_2 are strings:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) \ No newline at end of file diff --git a/examples/tool_chat_template_granite.jinja b/examples/tool_chat_template_granite.jinja index 2cc19e77188dc..467dcb2d10237 100644 --- a/examples/tool_chat_template_granite.jinja +++ b/examples/tool_chat_template_granite.jinja @@ -21,11 +21,7 @@ {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|> ' }} {%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %} - {{- '<|start_of_role|>assistant<|end_of_role|>' }} - {% for tc in message.tool_calls %} - {{- '<|tool_call|> ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }} - {% endfor %} - {{- '<|end_of_text|> + {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message.tool_calls|map(attribute='function')|list|tojson(indent=4) + '<|end_of_text|> ' }} {%- elif message['role'] == 'assistant' %} {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|> diff --git a/examples/tool_chat_template_llama3.1_json.jinja b/examples/tool_chat_template_llama3.1_json.jinja index c24a7e51335ef..033830936a56b 100644 --- a/examples/tool_chat_template_llama3.1_json.jinja +++ b/examples/tool_chat_template_llama3.1_json.jinja @@ -19,10 +19,18 @@ {#- This block extracts the system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %} - {%- set system_message = messages[0]['content']|trim %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} {%- set messages = messages[1:] %} {%- else %} - {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- if tools is not none %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} {%- endif %} {#- System message #} @@ -33,8 +41,8 @@ {{- "Cutting Knowledge Date: December 2023\n" }} {{- "Today Date: " + date_string + "\n\n" }} {%- if tools is not none and not tools_in_user_message %} - {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} {{- "Do not use variables.\n\n" }} {%- for t in tools %} {{- t | tojson(indent=4) }} @@ -48,7 +56,11 @@ {%- if tools_in_user_message and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if messages | length != 0 %} - {%- set first_user_message = messages[0]['content']|trim %} + {%- if messages[0]['content'] is string %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- else %} + {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} + {%- endif %} {%- set messages = messages[1:] %} {%- else %} {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} @@ -56,7 +68,7 @@ {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} {{- "Given the following functions, please respond with a JSON for a function call " }} {{- "with its proper arguments that best answers the given prompt.\n\n" }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} {{- "Do not use variables.\n\n" }} {%- for t in tools %} {{- t | tojson(indent=4) }} @@ -67,7 +79,17 @@ {%- for message in messages %} {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] | trim}} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|eot_id|>' }} {%- elif 'tool_calls' in message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception("This model only supports single tool-calls at once!") }} @@ -81,10 +103,14 @@ {{- "<|eot_id|>" }} {%- elif message.role == "tool" or message.role == "ipython" %} {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} - {%- if message.content is mapping %} - {{- message.content | tojson }} - {%- else %} + {%- if message.content is string %} {{- { "output": message.content } | tojson }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- { "output": content['text'] } | tojson }} + {%- endif %} + {%- endfor %} {%- endif %} {{- "<|eot_id|>" }} {%- endif %} diff --git a/examples/tool_chat_template_llama3.2_json.jinja b/examples/tool_chat_template_llama3.2_json.jinja index 7e24777726a35..39f902c1c3c40 100644 --- a/examples/tool_chat_template_llama3.2_json.jinja +++ b/examples/tool_chat_template_llama3.2_json.jinja @@ -16,38 +16,70 @@ {%- set tools = none %} {%- endif %} +{#- Find out if there are any images #} +{% set image_ns = namespace(has_images=false) %} +{%- for message in messages %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {%- set image_ns.has_images = true %} + {%- endif %} + {%- endfor %} +{%- endfor %} + + {#- This block extracts the system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %} - {%- set system_message = messages[0]['content']|trim %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {#- Support vLLM's transforming of a content string to JSON. #} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} {%- set messages = messages[1:] %} {%- else %} - {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- if tools is not none %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} {%- endif %} -{#- System message #} -{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} -{%- if tools is not none %} - {{- "Environment: ipython\n" }} +{#- Including an image is not compatible with a system message #} +{%- if image_ns.has_images and not system_message == "" %} + {{- raise_exception("Prompting with images is incompatible with system messages and tool use.") }} {%- endif %} -{{- "Cutting Knowledge Date: December 2023\n" }} -{{- "Today Date: " + date_string + "\n\n" }} -{%- if tools is not none and not tools_in_user_message %} - {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} - {{- "Do not use variables.\n\n" }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} + + +{#- System message, if there are no images #} +{%- if not image_ns.has_images %} + {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} + {%- if tools is not none %} + {{- "Environment: ipython\n" }} + {%- endif %} + {{- "Cutting Knowledge Date: December 2023\n" }} + {{- "Today Date: " + date_string + "\n\n" }} + {%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {%- endif %} + {{- system_message }} + {{- "<|eot_id|>" }} {%- endif %} -{{- system_message }} -{{- "<|eot_id|>" }} {#- Custom tools are passed in a user message with some extra guidance #} {%- if tools_in_user_message and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if messages | length != 0 %} - {%- set first_user_message = messages[0]['content']|trim %} + {%- if messages[0]['content'] is string %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- else %} + {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} + {%- endif %} {%- set messages = messages[1:] %} {%- else %} {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} @@ -55,7 +87,7 @@ {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} {{- "Given the following functions, please respond with a JSON for a function call " }} {{- "with its proper arguments that best answers the given prompt.\n\n" }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} {{- "Do not use variables.\n\n" }} {%- for t in tools %} {{- t | tojson(indent=4) }} @@ -66,7 +98,19 @@ {%- for message in messages %} {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] | trim}} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|eot_id|>' }} {%- elif 'tool_calls' in message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception("This model only supports single tool-calls at once!") }} @@ -80,10 +124,14 @@ {{- "<|eot_id|>" }} {%- elif message.role == "tool" or message.role == "ipython" %} {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} - {%- if message.content is mapping %} - {{- message.content | tojson }} - {%- else %} + {%- if message.content is string %} {{- { "output": message.content } | tojson }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- { "output": content['text'] } | tojson }} + {%- endif %} + {%- endfor %} {%- endif %} {{- "<|eot_id|>" }} {%- endif %} diff --git a/format.sh b/format.sh index a57882d2ac3f9..0b196de9d0773 100755 --- a/format.sh +++ b/format.sh @@ -41,6 +41,7 @@ MYPY_VERSION=$(mypy --version | awk '{print $2}') CODESPELL_VERSION=$(codespell --version) ISORT_VERSION=$(isort --vn) CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') +SPHINX_LINT_VERSION=$(sphinx-lint --version | awk '{print $2}') # # params: tool name, tool version, required version tool_version_check() { @@ -57,6 +58,7 @@ tool_version_check "mypy" "$MYPY_VERSION" tool_version_check "isort" "$ISORT_VERSION" tool_version_check "codespell" "$CODESPELL_VERSION" tool_version_check "clang-format" "$CLANGFORMAT_VERSION" +tool_version_check "sphinx-lint" "$SPHINX_LINT_VERSION" YAPF_FLAGS=( '--recursive' @@ -299,6 +301,10 @@ echo 'vLLM shellcheck:' tools/shellcheck.sh echo 'vLLM shellcheck: Done' +echo 'excalidraw png check:' +tools/png-lint.sh +echo 'excalidraw png check: Done' + if ! git diff --quiet &>/dev/null; then echo echo "🔍🔍There are files changed by the format checker or by you that are not added and committed:" @@ -309,3 +315,7 @@ if ! git diff --quiet &>/dev/null; then else echo "✨🎉 Format check passed! Congratulations! 🎉✨" fi + +echo 'vLLM sphinx-lint:' +tools/sphinx-lint.sh +echo 'vLLM sphinx-lint: Done' diff --git a/pyproject.toml b/pyproject.toml index 80d221513e0ed..89fdfc26df097 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,4 +98,5 @@ markers = [ "quant_model: run this model test under Quantized category", "distributed_2_gpus: run this test only in distributed tests for 2 GPUs", "skip_v1: do not run this test with v1", + "optional: optional tests that are automatically skipped, include --optional to run them", ] diff --git a/requirements-common.txt b/requirements-common.txt index f62ad66a1ecc4..02e3d65fb774c 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -20,7 +20,7 @@ tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 typing_extensions >= 4.10 -filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs pyzmq msgspec diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 749b03a0603d8..db8ad9d3a015d 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -1,6 +1,7 @@ # Common dependencies -r requirements-common.txt -# Dependencies for x86_64 CPUs -torch == 2.5.1+cpu; platform_machine != "ppc64le" -torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch +# Dependencies for CPUs +torch==2.5.1+cpu; platform_machine != "ppc64le" and platform_machine != "aarch64" +torch==2.5.1; platform_machine == "aarch64" +torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch \ No newline at end of file diff --git a/requirements-lint.txt b/requirements-lint.txt index f9132bbf96437..711bb50a0e936 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -6,6 +6,7 @@ ruff==0.6.5 codespell==2.3.0 isort==5.13.2 clang-format==18.1.5 +sphinx-lint==1.0.0 # type checking mypy==1.11.1 diff --git a/requirements-test.in b/requirements-test.in index 76f6de2f77c34..44972866ddc4b 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -20,7 +20,7 @@ timm # required for internvl test torch==2.5.1 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[opencv] >= 1.4.4 # required for pixtral test +mistral_common[opencv] >= 1.5.0 # required for pixtral test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.4 # required for model evaluation test diff --git a/requirements-test.txt b/requirements-test.txt index 65695111e4dc5..a59b85023948b 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -217,7 +217,7 @@ mbstrdecoder==1.1.3 # dataproperty # pytablewriter # typepy -mistral-common[opencv]==1.4.4 +mistral-common[opencv]==1.5.1 # via # -r requirements-test.in # mistral-common diff --git a/requirements-tpu.txt b/requirements-tpu.txt index f9a0770804e55..b8f0b15469e77 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -16,8 +16,8 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.6.0.dev20241028+cpu -torchvision==0.20.0.dev20241028+cpu -torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl -jaxlib==0.4.32.dev20240829 -jax==0.4.32.dev20240829 +torch==2.6.0.dev20241126+cpu +torchvision==0.20.0.dev20241126+cpu +torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl +jaxlib==0.4.36.dev20241122 +jax==0.4.36.dev20241122 diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 7f16baa65a644..fcba253d159f3 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -14,11 +14,12 @@ from vllm.platforms import current_platform from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from ..conftest import VllmRunner from ..models.utils import check_outputs_equal from ..utils import multi_gpu_test MODELS = [ - "facebook/opt-125m", + "google/gemma-2-2b-it", "meta-llama/Llama-3.2-1B", ] @@ -42,8 +43,6 @@ def test_vllm_gc_ed(): @pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, - vllm_runner, - example_prompts, model: str, backend: str, dtype: str, @@ -54,15 +53,27 @@ def test_models( if backend == "FLASHINFER" and current_platform.is_rocm(): pytest.skip("Flashinfer does not support ROCm/HIP.") + if backend == "XFORMERS" and model == "google/gemma-2-2b-it": + pytest.skip( + "XFORMERS does not support gemma2 with full context length.") + os.environ["VLLM_ATTENTION_BACKEND"] = backend + # 5042 tokens for gemma2 + # gemma2 has alternating sliding window size of 4096 + # we need a prompt with more than 4096 tokens to test the sliding window + prompt = "The following numbers of the sequence " + ", ".join( + str(i) for i in range(1024)) + " are:" + example_prompts = [prompt] + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - with vllm_runner(model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7) as vllm_model: + with VllmRunner(model, + max_model_len=8192, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index cc5bc2aca27c9..469d18a4dd7af 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -12,6 +12,7 @@ import pytest from tests.kernels.utils import override_backend_env_variable +from vllm.platforms import current_platform from ..models.utils import check_logprobs_close, check_outputs_equal from ..utils import multi_gpu_test @@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache( # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("dtype", ["half"]) def test_with_prefix_caching( vllm_runner, max_tokens: int, enforce_eager: bool, chunk_size: int, tensor_parallel_size: int, + dtype: str, ) -> None: """ Checks exact match decode with and without prefix caching @@ -233,7 +236,7 @@ def test_with_prefix_caching( for enable in (True, False): with vllm_runner( model, - dtype="half", + dtype=dtype, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=True, enable_prefix_caching=enable, @@ -260,3 +263,61 @@ def test_with_prefix_caching( name_0="w/o prefix caching", name_1="with prefix caching", ) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"]) +@pytest.mark.cpu_model +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") +def test_models_cpu( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + attention_backend: str, + monkeypatch, +) -> None: + test_models( + hf_runner, + vllm_runner, + example_prompts, + model, + dtype, + max_tokens, + chunked_prefill_token_size, + enforce_eager, + 1, + attention_backend, + monkeypatch, + ) + + +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("chunk_size", [30, 32]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.cpu_model +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") +def test_with_prefix_caching_cpu( + vllm_runner, + max_tokens: int, + enforce_eager: bool, + chunk_size: int, + dtype: str, +) -> None: + test_with_prefix_caching( + vllm_runner, + max_tokens, + enforce_eager, + chunk_size, + 1, + dtype, + ) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 9d5c68274374e..8fa10e5bd1b37 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -1,7 +1,9 @@ from copy import deepcopy -from typing import Callable +from typing import Callable, Union -import torch +from torch import fx + +from vllm.compilation.inductor_pass import InductorPass class TestBackend: @@ -11,19 +13,21 @@ class TestBackend: It also saves the graph before and after the custom passes for inspection. """ - def __init__(self, *args: Callable[[torch.fx.Graph], None]): - self.custom_passes = args + def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], + None]]): + self.custom_passes = list(passes) from torch._inductor import config self.current_config = config.shallow_copy_dict() + self.current_config['force_disable_caches'] = True self.current_config['post_grad_custom_post_pass'] = self.post_pass - def __call__(self, graph: torch.fx.GraphModule, example_inputs): + def __call__(self, graph: fx.GraphModule, example_inputs): from torch._inductor.compile_fx import compile_fx return compile_fx(graph, example_inputs, config_patches=self.current_config) - def post_pass(self, graph: torch.fx.Graph): + def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) for pass_ in self.custom_passes: pass_(graph) diff --git a/tests/compile/piecewise/piecewise_compilation_config.json b/tests/compile/piecewise/piecewise_compilation_config.json deleted file mode 100644 index 798a34e8dd92d..0000000000000 --- a/tests/compile/piecewise/piecewise_compilation_config.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "use_cudagraph": true, - "non_cudagraph_ops": ["silly.attention"], - "cudagraph_copy_inputs": true -} \ No newline at end of file diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 45f56cbbd4b16..7ef502abee345 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -2,7 +2,6 @@ Test the piecewise compilation with a simple model so that we can exactly calculate the expected output and side effects. """ -import os import torch from torch import nn @@ -11,8 +10,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import CompilationLevel, VllmConfig -from vllm.plugins import set_current_vllm_config +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) from vllm.utils import direct_register_custom_op global_counter = 0 @@ -77,12 +76,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_simple_piecewise_compile(): - directory = os.path.dirname(__file__) - config = os.path.join(directory, "piecewise_compilation_config.json") - os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) - - vllm_config = VllmConfig() + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_copy_inputs=True, + )) with set_current_vllm_config(vllm_config): model = SillyModel(vllm_config=vllm_config, prefix='') @@ -109,6 +108,3 @@ def test_simple_piecewise_compile(): output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) - - # clean up to avoid side effects for other tests - del os.environ["VLLM_TORCH_COMPILE_CONFIG"] diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 8032304e95806..dbd5a3bbffeab 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -6,7 +6,6 @@ if the config `tractable_init` is set to True. Otherwise, the weights are initialized randomly with a fixed seed. """ -import os from dataclasses import dataclass from typing import Optional, Tuple @@ -17,8 +16,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig -from vllm.plugins import set_compilation_config, set_current_vllm_config +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -254,23 +253,17 @@ def run_model(llama_config, split_attn: bool = False) -> torch.Tensor: if use_compile: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str( - CompilationLevel.PIECEWISE) - + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + ) if split_attn: - set_compilation_config( - CompilationConfig( - use_cudagraph=True, - non_cudagraph_ops=["silly.attention"], - )) - else: - set_compilation_config(CompilationConfig(use_cudagraph=True, )) + compilation_config.splitting_ops = ["silly.attention"] else: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str( - CompilationLevel.NO_COMPILATION) - set_compilation_config(None) + compilation_config = CompilationConfig( + level=CompilationLevel.NO_COMPILATION, ) - vllm_config = VllmConfig() + vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): model = LlamaModel(config=llama_config, vllm_config=vllm_config, @@ -288,10 +281,6 @@ def run_model(llama_config, input_ids[:2].zero_() output = model(input_ids[:2], positions[:2]) - # manual cleanup - del os.environ["VLLM_TORCH_COMPILE_LEVEL"] - set_compilation_config(None) - output = output.cpu() if llama_config.tractable_init: @@ -361,7 +350,6 @@ def test_toy_llama(): @torch.inference_mode def benchmark(): - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) from triton.testing import do_bench # similar to llama 3.1-8B @@ -387,15 +375,16 @@ def benchmark(): for piecewise in [False, True]: if piecewise: - set_compilation_config( - CompilationConfig( - use_cudagraph=True, - non_cudagraph_ops=["silly.attention"], - )) + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + ) else: - set_compilation_config(None) + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, ) - vllm_config = VllmConfig() + vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): model = LlamaModel(config=llama_config, vllm_config=vllm_config, diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 08747ebc58b75..99781c55b672e 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -62,6 +62,16 @@ class TestSetting: method="encode", fullgraph=True, ), + # encoder-based embedding model (BERT) + TestSetting( + model="BAAI/bge-base-en-v1.5", + model_args=["--task", "embedding"], + pp_size=1, + tp_size=1, + attn_backend="XFORMERS", + method="encode", + fullgraph=True, + ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", @@ -96,31 +106,36 @@ def test_compile_correctness(test_setting: TestSetting): final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \ ["-tp", str(tp_size)] + all_args: List[List[str]] = [] all_envs: List[Optional[Dict[str, str]]] = [] for level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE, ]: - all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)}) + all_args.append(final_args + [f"-O{level}"]) + all_envs.append({}) # inductor will change the output, so we only compare if the output # is close, not exactly the same. compare_all_settings( - model, [final_args] * 2, + model, + all_args, all_envs, method=method if method != "generate" else "generate_close") all_envs.clear() + all_args.clear() for level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE, ]: - all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)}) + all_args.append(final_args + [f"-O{level}"]) + all_envs.append({}) if level != CompilationLevel.DYNAMO_ONCE and not fullgraph: # "DYNAMO_ONCE" will always use fullgraph all_envs[-1][ "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore - compare_all_settings(model, [final_args] * 3, all_envs, method=method) + compare_all_settings(model, all_args * 3, all_envs, method=method) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py new file mode 100644 index 0000000000000..5036189077be2 --- /dev/null +++ b/tests/compile/test_functionalization.py @@ -0,0 +1,95 @@ +import pytest +import torch + +import vllm.envs as envs +from vllm import LLM, SamplingParams +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fusion import (FusionPass, find_auto_fn, + find_auto_fn_maybe) +from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.compilation.vllm_inductor_pass import is_func +from vllm.config import CompilationConfig + +from .backend import TestBackend + +OPS_IN_MODEL = [ + torch.ops._C.rotary_embedding.default, + torch.ops._C.fused_add_rms_norm.default, + torch.ops._C.silu_and_mul.default, +] + +RMS_OP = torch.ops._C.rms_norm.default + +RMS_QUANT_OPS = { + "static_fp8": [ + torch.ops._C.rms_norm_static_fp8_quant.default, + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + ], +} + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +@pytest.mark.parametrize("model", + ["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"]) +@pytest.mark.parametrize("do_fusion", [True, False]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", + reason="Only test on CUDA") +def test_fix_functionalization(model: str, do_fusion: bool): + torch.set_default_device("cuda") + + config = CompilationConfig.PassConfig(enable_fusion=do_fusion, + enable_reshape=True) + reshape_pass = RedundantReshapesPass(config) + fusion_pass = FusionPass.instance(config) + + passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass] + func_pass = FixFunctionalizationPass(config) + backend_func = TestBackend(*passes, func_pass) + backend_no_func = TestBackend(*passes) + + # instantiate a full engine and manually compile the model 2x + # (with and without FixFunctionalizationPass) + llm = LLM(model=model, enforce_eager=True) + model_runner = llm.llm_engine.model_executor.driver_worker.model_runner + orig_model = model_runner.model + # TODO mark inputs dynamic? (currently torch.compile is triggered 4x) + # Can only do that by using the decorator but then we'd have to instantiate + # 2 LLM instances. + + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + model_runner.model = torch.compile(orig_model, + fullgraph=True, + backend=backend_func) + gen_func = llm.generate(prompts, sampling_params) + + model_runner.model = torch.compile(orig_model, + fullgraph=True, + backend=backend_no_func) + gen_no_func = llm.generate(prompts, sampling_params) + + for output_func, output_no_func in zip(gen_func, gen_no_func): + assert output_func.outputs[0].text == output_no_func.outputs[0].text + + # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion, + # and replaced by fused quantized ops in RMS_QUANT_OPS. + ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"] + if do_fusion else [RMS_OP]) + + for op in ops: + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, + op) is None # noqa: E501 + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in ops: + if is_func(node, op): + found[op] = True + assert all(found[op] for op in ops) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4db79b070fd8d..f92ec8d0de5f1 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -38,12 +38,6 @@ def forward(self, x): return y3 -# Init does pattern registration, which can only happen once -config = CompilationConfig(enable_fusion=True) -reshape_pass = RedundantReshapesPass(config) -fusion_pass = FusionPass.instance(config) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @@ -58,6 +52,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): pytest.skip("Only test eps=1e-5 for now") # Reshape pass is needed for the fusion pass to work + config = CompilationConfig.PassConfig(enable_fusion=True, + enable_reshape=True) + reshape_pass = RedundantReshapesPass(config) + fusion_pass = FusionPass.instance(config) + backend = TestBackend(reshape_pass, fusion_pass) model = TestModel(hidden_size, eps) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py new file mode 100644 index 0000000000000..03e7535093c5d --- /dev/null +++ b/tests/compile/test_pass_manager.py @@ -0,0 +1,35 @@ +import pickle + +import pytest +import torch +from torch._inductor.codecache import BypassFxGraphCache + +from vllm.compilation.config import CompilationConfig +from vllm.compilation.inductor_pass import (CallableInductorPass, + as_inductor_pass) +from vllm.compilation.pass_manager import PostGradPassManager + + +def simple_callable(graph: torch.fx.Graph): + pass + + +@as_inductor_pass(files=(__file__, )) +def callable_decorated(graph: torch.fx.Graph): + pass + + +@pytest.mark.parametrize( + "works, callable", + [(False, simple_callable), (True, callable_decorated), + (True, CallableInductorPass(simple_callable, "simple_callable"))]) +def test_pass_manager(works: bool, callable): + config = CompilationConfig().pass_config + pass_manager = PostGradPassManager([callable]) + pass_manager.configure(config) # Adds default passes + + if works: + pickle.dumps(pass_manager) + else: + with pytest.raises(BypassFxGraphCache): + pickle.dumps(pass_manager) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 729f10676888b..7c92d165d05f7 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -65,7 +65,6 @@ def check_full_graph_support(model, optimization_level, tp_size=1): # make sure these models can be captured in full graph mode - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level) os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" # The base meta llama uses too much memory. @@ -86,6 +85,7 @@ def check_full_graph_support(model, enforce_eager=True, tensor_parallel_size=tp_size, disable_custom_all_reduce=True, + compilation_config=optimization_level, **model_kwargs) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/conftest.py b/tests/conftest.py index 0dc1cc6e83c18..d6be8f5b00af8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -263,8 +263,8 @@ def __init__( dtype: str = "half", *, model_kwargs: Optional[Dict[str, Any]] = None, - is_embedding_model: bool = False, is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM, postprocess_inputs: Callable[..., BatchEncoding] = identity, @@ -282,6 +282,14 @@ def __init__( device="cpu", trust_remote_code=True, ).to(dtype=torch_dtype)) + elif is_cross_encoder: + # Lazy init required for AMD CI + from sentence_transformers import CrossEncoder + self.model = CrossEncoder(model_name, + device="cpu", + trust_remote_code=True) + self.model.model = self.wrap_device(self.model.model)\ + .to(dtype=torch_dtype) else: model_kwargs = model_kwargs if model_kwargs is not None else {} self.model = self.wrap_device( @@ -625,6 +633,9 @@ def generate_encoder_decoder_greedy_logprobs_limit( def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) + def predict(self, prompts: List[List[str]]) -> torch.Tensor: + return self.model.predict(prompts, convert_to_tensor=True) + def __enter__(self): return self @@ -645,6 +656,7 @@ def __init__( model_name: str, task: TaskOption = "auto", tokenizer_name: Optional[str] = None, + tokenizer_mode: str = "auto", # Use smaller max model length, otherwise bigger model cannot run due # to kv cache size limit. max_model_len: int = 1024, @@ -661,6 +673,7 @@ def __init__( model=model_name, task=task, tokenizer=tokenizer_name, + tokenizer_mode=tokenizer_mode, trust_remote_code=True, dtype=dtype, swap_space=swap_space, @@ -831,6 +844,7 @@ def generate_greedy_logprobs( audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, stop_token_ids: Optional[List[int]] = None, + stop: Optional[List[str]] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( @@ -838,7 +852,8 @@ def generate_greedy_logprobs( max_tokens=max_tokens, logprobs=num_logprobs, prompt_logprobs=num_prompt_logprobs, - stop_token_ids=stop_token_ids) + stop_token_ids=stop_token_ids, + stop=stop) return self.generate_w_logprobs(prompts, greedy_logprobs_params, @@ -898,6 +913,14 @@ def encode( req_outputs = self.model.encode(inputs) return [req_output.outputs.embedding for req_output in req_outputs] + def score( + self, + text_1: Union[str, List[str]], + text_2: Union[str, List[str]], + ) -> List[List[float]]: + req_outputs = self.model.score(text_1, text_2) + return [req_output.outputs.embedding for req_output in req_outputs] + def __enter__(self): return self @@ -1010,3 +1033,22 @@ def dummy_gemma2_embedding_path(): with open(json_path, "w") as f: json.dump(config, f) return _dummy_gemma2_embedding_path + + +# Add the flag `--optional` to allow run tests +# that are marked with @pytest.mark.optional +def pytest_addoption(parser): + parser.addoption("--optional", + action="store_true", + default=False, + help="run optional test") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--optional"): + # --optional given in cli: do not skip optional tests + return + skip_optional = pytest.mark.skip(reason="need --optional option to run") + for item in items: + if "optional" in item.keywords: + item.add_marker(skip_optional) diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index fdae1ae689164..78ef24f08db86 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -3,6 +3,7 @@ import pytest +from tests.kernels.utils import override_backend_env_variable from vllm import LLM, SamplingParams from ....test_utils import xfail_if_rocm62 @@ -29,8 +30,9 @@ @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, - batch_size, seed): + batch_size, seed, backend, monkeypatch): """ The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then asks for value of one of them (which is outside the sliding window). @@ -39,6 +41,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, Additionally, we compare the results of the v1 and v2 managers. """ + override_backend_env_variable(monkeypatch, backend) + sampling_params = SamplingParams( max_tokens=1024, ignore_eos=True, @@ -86,7 +90,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, + backend, monkeypatch): """ This is similar to test_sliding_window_retrival, however, it doesn't compare against the v1 block manager since v1 doesn't support @@ -95,6 +101,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): The results with and without chunked prefill are not the same due to numerical instabilities. """ + override_backend_env_variable(monkeypatch, backend) + sampling_params = SamplingParams( max_tokens=10, ignore_eos=True, diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index d325b9606843e..bbeb4b3a58f2a 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -5,9 +5,14 @@ import pytest +from tests.core.utils import create_dummy_sequence +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator -from vllm.core.block.prefix_caching_block import (PrefixCachingBlock, +from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, + PrefixCachingBlock, PrefixCachingBlockAllocator) +from vllm.sequence import Logprob +from vllm.utils import Device class TestPrefixCachingBlock: @@ -726,18 +731,71 @@ def test_touch_block(): token_ids=common_token_ids, allocator=allocator, ) - block_ids = [block.block_id for block in blocks] + block_hashes = [block.content_hash for block in blocks] # The allocated blocks should be marked as touched # but not computed. - computed_block_ids = allocator.get_computed_block_ids( - [], block_ids, skip_last_block_id=False) + computed_block_ids = allocator.find_cached_blocks_prefix( + block_hashes) assert len(computed_block_ids) == 0 allocator.mark_blocks_as_computed([]) - computed_block_ids = allocator.get_computed_block_ids( - [], block_ids, skip_last_block_id=False) + computed_block_ids = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes) assert len(computed_block_ids) == common_blocks + @staticmethod + def test_find_cached_blocks_prefix(): + """ + This test verifies the behavior of find_cached_blocks_prefix. + """ + block_size = 4 + num_blocks = 8 + total_test_blocks = 12 + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + + token_ids = list(range(total_test_blocks * block_size)) + block_tokens_seq1 = token_ids[:num_blocks * block_size] + blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=block_tokens_seq1, + allocator=allocator, + ) + block_hashes_seq1 = [block.content_hash for block in blocks_seq1] + allocator.mark_blocks_as_computed([]) + + # All blocks should be cached. + cached_blocks_seq1 = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks_seq1) == num_blocks + + # Free the first sequence. + for block in blocks_seq1: + allocator.free(block) + + # All blocks should be still be cached if not required to be allocated. + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks) == num_blocks + + block_tokens_seq2 = token_ids[num_blocks * block_size:] + blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=block_tokens_seq2, + allocator=allocator, + ) + block_hashes_seq2 = [block.content_hash for block in blocks_seq2] + allocator.mark_blocks_as_computed([]) + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq2) + assert len(cached_blocks) == len(blocks_seq2) + + # Half of the blocks from seq1 should still be cached. + num_evicted_blocks = len(blocks_seq2) + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks + @staticmethod def create_immutable_chain( block_size: int, @@ -762,3 +820,114 @@ def create_immutable_chain( blocks.append(prev_block) return blocks + + +class TestComputedBlocksTracker: + + @staticmethod + def _get_mock_allocator(): + return MagicMock(spec=PrefixCachingBlockAllocator) + + @staticmethod + def test_get_num_cached_tokens(): + """ + Test it correctly computes the number of cached tokens for a given + sequence: + + - The cache token count is derived from the number of cached blocks. + - The cache token count is updated when the allocator is updated. + - When a sequence is removed, the cache token count should be updated + accordingly. + + # TODO(rickyx): This behaviour for prefill sequence is a hack until + we fix the computed blocks tracking. + - The cache token count for prefill sequence doesn't change while + the sequence is in continuous prefill (chunked prefill). + """ + block_size = 4 + mock_allocator = TestComputedBlocksTracker._get_mock_allocator() + tracker = ComputedBlocksTracker( + allocator=mock_allocator, + block_size=block_size, + enable_caching=True, + ) + + # Not yet allocated. + tokens = [0, 1, 2, 3, 4, 5] + seq1 = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + mock_allocator.find_cached_blocks_prefix.return_value = [] + assert tracker.get_num_cached_tokens(seq1) == 0 + + mock_allocator.find_cached_blocks_prefix.return_value = [ + None + ] # 1 block cached. + # Result is cached for prefill sequence. + assert tracker.get_num_cached_tokens(seq1) == 0 + + # Mark the sequence as non-prefill. + seq1.data.update_num_computed_tokens(len(tokens)) # 6 tokens computed. + assert not seq1.is_prefill() + + # Recomputes for decoding sequence. + assert tracker.get_num_cached_tokens(seq1) == 4 + + # Append new tokens to the sequence. + num_new_tokens = 3 + for i in range(num_new_tokens): + seq1.append_token_id(i, {i: Logprob(logprob=0.0)}) + + assert tracker.get_num_cached_tokens(seq1) == 4 + + # Update the allocator. + mock_allocator.find_cached_blocks_prefix.return_value = [ + None + ] * 2 # 2 blocks cached. + assert tracker.get_num_cached_tokens(seq1) == 8 + + # Remove the sequence. + tracker.remove_seq(seq1.seq_id) + + # Re-create the sequence with the same request id to simulate recompute. + seq1 = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + mock_allocator.find_cached_blocks_prefix.return_value = [ + ] # no cached block + assert tracker.get_num_cached_tokens(seq1) == 0 + + @staticmethod + def test_correct_block_hash(): + """ + Test that the block hash is correctly computed for a sequence (should + match the underlying block allocator's block hash). So the number of + cached tokens is correctly retrieved. + """ + block_size = 4 + allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching", + num_gpu_blocks=16, + num_cpu_blocks=16, + block_size=block_size, + ) + gpu_allocator = allocator._allocators[Device.GPU] + + tracker = ComputedBlocksTracker( + allocator=allocator, + block_size=block_size, + enable_caching=True, + ) + + tokens = list(range(block_size * 4)) # 4 blocks. + seq = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + _ = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=tokens, + allocator=gpu_allocator, + ) + allocator.mark_blocks_as_computed([]) + + assert tracker.get_num_cached_tokens(seq) == len(tokens) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index acd82065ae457..eaaf004df38b2 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -413,6 +413,45 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots): assert out.num_batched_tokens == max_num_batched_tokens +@pytest.mark.parametrize("num_scheduler_steps", [1, 5]) +def test_chunked_prefill_spec_prefill(num_scheduler_steps): + """Verify that the num_lookahead_slots is set appropriately for an all""" + """prefill batch depending on whether multi-step scheduling is enabled""" + """or not""" + block_size = 4 + max_seqs = 30 + max_model_len = 200 + max_num_batched_tokens = 30 + num_lookahead_slots = 4 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + num_lookahead_slots=num_lookahead_slots, + num_scheduler_steps=num_scheduler_steps, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 + scheduler = Scheduler(scheduler_config, cache_config, None) + + _, seq_group = create_dummy_prompt("1", + prompt_length=30, + block_size=block_size) + scheduler.add_seq_group(seq_group) + _, out = schedule_and_update_computed_tokens(scheduler) + # The request is chunked. + # prefill scheduled now. + assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == max_num_batched_tokens + print(out.num_lookahead_slots) + assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else + num_lookahead_slots) + + def test_chunked_prefill_max_seqs(): block_size = 4 max_seqs = 2 diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 5ff32be611592..8f6de84e566e7 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -12,9 +12,9 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SequenceGroup -from .utils import (append_new_token, append_new_token_seq_group, - create_dummy_prompt, get_sequence_groups, - schedule_and_update_computed_tokens) +from .utils import (append_new_token, append_new_token_seq, + append_new_token_seq_group, create_dummy_prompt, + get_sequence_groups, schedule_and_update_computed_tokens) def test_scheduler_add_seq_group(): @@ -305,6 +305,8 @@ def initialize_scheduler( block_size=4, num_cpu_blocks=8, num_gpu_blocks=8, + enable_prefix_caching=False, + enable_chunked_prefill=False, ): block_size = block_size scheduler_config = SchedulerConfig( @@ -312,8 +314,15 @@ def initialize_scheduler( max_num_batched_tokens=max_token_budget, max_num_seqs=max_num_seqs, max_model_len=max_model_len, + enable_chunked_prefill=enable_chunked_prefill, + ) + cache_config = CacheConfig( + block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=enable_prefix_caching, ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = num_cpu_blocks cache_config.num_gpu_blocks = num_gpu_blocks scheduler = Scheduler(scheduler_config, cache_config, lora_config) @@ -800,3 +809,165 @@ def test_scheduling_budget(): assert budget.num_curr_seqs == 0 budget.subtract_num_seqs(seq_group.request_id, 2) assert budget.num_curr_seqs == 0 + + +@pytest.mark.parametrize("enable_prefix_caching", [True, False]) +def test_prefix_caching_aware_prefills(enable_prefix_caching): + """ + Test the below scenario: + + For 3 sequences, seqA, seqB, seqC, share the first block as prefix. + + The test verifies the below scenarios: + 1. SeqA is first scheduled. + 2. SeqB and SeqC can be prefilled together in a single schedule round + even though there are not enough token budgets to prefill both without + considering prefix caching. + """ + + block_size = 4 + max_num_batched_tokens = 12 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_token_budget=max_num_batched_tokens, + max_num_seqs=max_seq_group, + max_model_len=max_num_batched_tokens, + enable_prefix_caching=enable_prefix_caching, + ) + + seqA_tokens = list(range(8)) + num_shared_tokens = 4 + seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range( + 12, 16)) # Shared prefix first 4. + seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range( + 16, 20)) # Shared prefix first 4. + + seqA, seqA_group = create_dummy_prompt("0", + prompt_tokens=seqA_tokens, + block_size=block_size) + seqB, seqB_group = create_dummy_prompt("1", + prompt_tokens=seqB_tokens, + block_size=block_size) + seqC, seqC_group = create_dummy_prompt("2", + prompt_tokens=seqC_tokens, + block_size=block_size) + + # Schedule seqA prefill. + scheduler.add_seq_group(seqA_group) + metas, out, _ = scheduler.schedule() + assert (len(out.scheduled_seq_groups) == 1 + and out.scheduled_seq_groups[0].seq_group == seqA_group) + assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens) + + # Schedule seqA decode. + append_new_token_seq_group(len(seqA_tokens), seqA_group, 999) + metas, out, _ = scheduler.schedule() + + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 1 + + # Schedule seqB and seqC prefills should work with prefix caching. + scheduler.add_seq_group(seqB_group) + scheduler.add_seq_group(seqC_group) + metas, out, _ = scheduler.schedule() + + if enable_prefix_caching: + assert len(out.scheduled_seq_groups) == 2 + assert set([ + out.scheduled_seq_groups[0].seq_group, + out.scheduled_seq_groups[1].seq_group, + ]) == set([seqB_group, seqC_group]) + assert len(metas) == 2 + for meta in metas: + assert meta.token_chunk_size == 8 + assert (len(meta.computed_block_nums) == num_shared_tokens // + block_size) # 1 Block for the 8 tokens. + else: + assert len(out.scheduled_seq_groups) == 1 + assert len(metas) == 1 + assert metas[0].token_chunk_size == 8 + assert len(metas[0].computed_block_nums) == 0 # No blocks computed. + + +def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( +): + """ + This test verifies that we don't schedule new prefills if there's already + a continuous prefill in progress even though the new prefills with shared + prefix can fit in the token budget: + + - SeqA is being chunked prefill. + - SeqB with the same prompt shouldn't be scheduled for prefill even though + there's enough token budget to prefill the cached tokens. + - Neither should seqC be scheduled. + + - When seqA is in decoding phase, seqB and seqC can be scheduled. + - Entire seqB should be prefilled since it's a full prefix cache hit. + - SeqC would be partially prefilled with the prefix shared, and the + remaining unique tokens would be prefilled (rounded down to be + block-size aligned). + """ + + block_size = 2 + max_num_batched_tokens = 4 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_token_budget=max_num_batched_tokens, + max_num_seqs=max_seq_group, + max_model_len=100, + enable_prefix_caching=True, + enable_chunked_prefill=True, + ) + + seqA_tokens = list(range(8)) + seqB_tokens = seqA_tokens + seqC_shared_prefix_len = 4 + seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20)) + + seqA, seqA_group = create_dummy_prompt("0", + prompt_tokens=seqA_tokens, + block_size=block_size) + seqB, seqB_group = create_dummy_prompt("1", + prompt_tokens=seqB_tokens, + block_size=block_size) + + # Chunked prefill seqA. + scheduler.add_seq_group(seqA_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 4 + + # seqB should not be scheduled with ongoing prefills. + scheduler.add_seq_group(seqB_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 4 + + # both seqB and seqC can now be scheduled with seqA is over. + # seqA is in decoding phase. + append_new_token_seq(seqA, 999) + seqC, seqC_group = create_dummy_prompt("2", + prompt_tokens=seqC_tokens, + block_size=block_size) + scheduler.add_seq_group(seqC_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 3 + + metas = {meta.request_id: meta for meta in metas} + assert metas[seqA_group.request_id].token_chunk_size == 1 # Decode + assert (metas[seqB_group.request_id].token_chunk_size == 8 + ) # Fully cached prefill + assert ( + metas[seqC_group.request_id].token_chunk_size == 6 + ), "A partial prefix of C (4 tokens) should be prefilled, with the " + "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " + "then be rounded down to 2 tokens on block size, thus 6 tokens in total." diff --git a/tests/core/utils.py b/tests/core/utils.py index cd0caa4704e11..277368b57b938 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,17 +1,20 @@ import time -from typing import List, Optional +from collections import defaultdict +from typing import Any, Dict, List, Optional from typing import Sequence as GenericSequence from typing import Tuple from vllm import SamplingParams +from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.inputs import EncoderDecoderInputs, token_inputs from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob, Sequence, SequenceGroup +from vllm.sequence import (Logprob, Sequence, SequenceGroup, + SequenceGroupMetadata) def create_dummy_prompt( request_id: str, - prompt_length: int, + prompt_length: int = -1, block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, best_of: int = 1, @@ -26,6 +29,7 @@ def create_dummy_prompt( # Create dummy prompt sequence with tokens 0...block_size-1 # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) + prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), inputs=token_inputs(prompt_tokens, prompt=prompt_str), @@ -42,6 +46,15 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_sequence(request_id: int, token_ids: List[int], + block_size: int) -> Sequence: + return Sequence( + seq_id=request_id, + inputs=token_inputs(token_ids), + block_size=block_size, + ) + + def create_dummy_prompt_encoder_decoder( request_id: str, decoder_prompt_length: int, @@ -194,12 +207,40 @@ def append_new_token(out, token_id: int): def schedule_and_update_computed_tokens(scheduler): metas, out, _ = scheduler.schedule() - for s, meta in zip(out.scheduled_seq_groups, metas): - s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + for s in out.scheduled_seq_groups: + s.seq_group.update_num_computed_tokens(s.token_chunk_size) return metas, out +def append_new_token_seq(seq: Sequence, token_id: int): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): seq_group.update_num_computed_tokens(token_chunk_size) for seq in seq_group.get_seqs(): seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + +class SchedulerProxy: + """ + A proxy class to forward calls to the scheduler. + """ + + def __init__(self, scheduler: Scheduler): + self.scheduler_ = scheduler + self.call_history: Dict[str, List[Any]] = defaultdict(list) + + def __getattr__(self, name: str) -> Any: + + def wrapper(*args, **kwargs): + result = getattr(self.scheduler_, name)(*args, **kwargs) + self.call_history[name].append((args, kwargs, result)) + return result + + return wrapper + + def last_schedule_ret( + self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]: + _, _, ret = self.call_history["schedule"][-1] + return ret diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index c49ed9802cde8..386877e0e0a2c 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -167,6 +167,7 @@ def iter_params(self, model_name: str): "mosaicml/mpt-7b": PPTestSettings.fast(), "nvidia/Minitron-8B-Base": PPTestSettings.fast(), "allenai/OLMo-1B-hf": PPTestSettings.fast(), + "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(), "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(), "facebook/opt-iml-max-1.3b": PPTestSettings.fast(), "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True), diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index e0e424439e3a5..fb24d6bc2c100 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -60,7 +60,7 @@ def worker_fn(): tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) with pynccl_comm.change_state(enable=True): - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == pynccl_comm.world_size @@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn(): with pynccl_comm.change_state(enable=True): # two groups can communicate independently if torch.distributed.get_rank() in [0, 1]: - pynccl_comm.all_reduce(tensor) - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == 4 else: - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == 2 @@ -140,14 +140,80 @@ def worker_fn_with_cudagraph(): with torch.cuda.graph( graph, stream=pynccl_comm.stream), pynccl_comm.change_state( enable=True): - # operation during the graph capture is recorded but not executed - # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - pynccl_comm.all_reduce(a) + a_out = pynccl_comm.all_reduce(a) pynccl_comm.stream.synchronize() - assert a.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() pynccl_comm.stream.synchronize() - assert a.mean().cpu().item() == pynccl_comm.world_size**1 + assert a_out.mean().cpu().item() == pynccl_comm.world_size**1 + + +@worker_fn_wrapper +def all_gather_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + num_elems = 1000 + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * num_elems + result = torch.zeros(num_elems * world_size, + dtype=torch.float32, + device=device) + + expected = torch.cat([ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ]).to(device) + + with pynccl_comm.change_state(enable=True): + pynccl_comm.all_gather(result, tensor) + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_all_gather(): + distributed_run(all_gather_worker_fn, 2) + + +@worker_fn_wrapper +def reduce_scatter_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + num_elems = 1000 + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * num_elems + assert (num_elems % world_size == 0) + result = torch.zeros(num_elems // world_size, + dtype=torch.float32, + device=device) + + # Calculate expected result for this rank's chunk + scattered_size = num_elems // world_size + all_tensors = [ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ] + expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size] + for tensor in all_tensors).to(device) + + with pynccl_comm.change_state(enable=True): + pynccl_comm.reduce_scatter(result, tensor) + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_reduce_scatter(): + distributed_run(reduce_scatter_worker_fn, 2) @pytest.mark.skipif(torch.cuda.device_count() < 2, diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 686b697c98e03..5fb1ae7b29fd2 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2): rank=rank, world_size=WORLD_SIZE) pynccl1 = PyNcclCommunicator(pg1, device=rank) - pynccl1.disabled = False if rank <= 2: pg2 = StatelessProcessGroup.create(host="127.0.0.1", port=port2, rank=rank, world_size=3) pynccl2 = PyNcclCommunicator(pg2, device=rank) - pynccl2.disabled = False data = torch.tensor([rank]).cuda() pynccl1.all_reduce(data) pg1.barrier() diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 7b1be5a9802fd..de78d41ad12eb 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -31,6 +31,53 @@ def test_limit_mm_per_prompt_parser(arg, expected): assert args.limit_mm_per_prompt == expected +def test_compilation_config(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + + # default value + args = parser.parse_args([]) + assert args.compilation_config is None + + # set to O3 + args = parser.parse_args(["-O3"]) + assert args.compilation_config.level == 3 + + # set to O 3 (space) + args = parser.parse_args(["-O", "3"]) + assert args.compilation_config.level == 3 + + # set to O 3 (equals) + args = parser.parse_args(["-O=3"]) + assert args.compilation_config.level == 3 + + # set to json + args = parser.parse_args(["--compilation-config", '{"level": 3}']) + assert args.compilation_config.level == 3 + + # set to json + args = parser.parse_args(['--compilation-config={"level": 3}']) + assert args.compilation_config.level == 3 + + +def test_prefix_cache_default(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args = parser.parse_args([]) + + engine_args = EngineArgs.from_cli_args(args=args) + assert (not engine_args.enable_prefix_caching + ), "prefix caching defaults to off." + + # with flag to turn it on. + args = parser.parse_args(["--enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert engine_args.enable_prefix_caching + + # with disable flag to turn it off. + args = parser.parse_args(["--no-enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert not engine_args.enable_prefix_caching + + def test_valid_pooling_config(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) args = parser.parse_args([ diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index 4c9f796e5ed71..41163809237e9 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -3,7 +3,7 @@ import pytest -from vllm import LLM, EmbeddingRequestOutput, PoolingParams +from vllm import LLM, PoolingParams, PoolingRequestOutput from vllm.distributed import cleanup_dist_env_and_memory MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @@ -43,8 +43,8 @@ def llm(): cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: List[EmbeddingRequestOutput], - o2: List[EmbeddingRequestOutput]): +def assert_outputs_equal(o1: List[PoolingRequestOutput], + o2: List[PoolingRequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py index cbfb0cc32c1ce..2c53676c5f5dd 100644 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ b/tests/entrypoints/llm/test_lazy_outlines.py @@ -1,12 +1,13 @@ import sys +from contextlib import nullcontext + +from vllm_test_utils import BlameResult, blame from vllm import LLM, SamplingParams from vllm.distributed import cleanup_dist_env_and_memory -def test_lazy_outlines(sample_regex): - """If users don't use guided decoding, outlines should not be imported. - """ +def run_normal(): prompts = [ "Hello, my name is", "The president of the United States is", @@ -25,13 +26,12 @@ def test_lazy_outlines(sample_regex): generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - # make sure outlines is not imported - assert 'outlines' not in sys.modules - # Destroy the LLM object and free up the GPU memory. del llm cleanup_dist_env_and_memory() + +def run_lmfe(sample_regex): # Create an LLM with guided decoding enabled. llm = LLM(model="facebook/opt-125m", enforce_eager=True, @@ -51,5 +51,26 @@ def test_lazy_outlines(sample_regex): generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +def test_lazy_outlines(sample_regex): + """If users don't use guided decoding, outlines should not be imported. + """ # make sure outlines is not imported - assert 'outlines' not in sys.modules + module_name = "outlines" + # In CI, we only check finally if the module is imported. + # If it is indeed imported, we can rerun the test with `use_blame=True`, + # which will trace every function call to find the first import location, + # and help find the root cause. + # We don't run it in CI by default because it is slow. + use_blame = False + context = blame( + lambda: module_name in sys.modules) if use_blame else nullcontext() + with context as result: + run_normal() + run_lmfe(sample_regex) + if use_blame: + assert isinstance(result, BlameResult) + print(f"the first import location is:\n{result.trace_stack}") + assert module_name not in sys.modules, ( + f"Module {module_name} is imported. To see the first" + f" import location, run the test with `use_blame=True`.") diff --git a/tests/entrypoints/openai/test_async_tokenization.py b/tests/entrypoints/openai/test_async_tokenization.py new file mode 100644 index 0000000000000..fcce8b46c4344 --- /dev/null +++ b/tests/entrypoints/openai/test_async_tokenization.py @@ -0,0 +1,137 @@ +import asyncio +import contextlib +import random +import time +from typing import Callable + +import openai +import pytest +import pytest_asyncio +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def server(): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + "--max-num-seqs", + "128", + "--load-format", + "dummy", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ids=["completion", "chat"], + argnames=["create_func_gen", "content_body"], + argvalues=[ + (lambda x: x.completions.create, { + "prompt": " ".join(['A'] * 10_000) + }), + (lambda x: x.chat.completions.create, { + "messages": [{ + "role": "user", + "content": " ".join(['A'] * 10_000) + }] + }), + ], +) +async def test_with_and_without_truncate( + server: RemoteOpenAIServer, + client: openai.AsyncOpenAI, + create_func_gen: Callable, + content_body: dict, +): + create_func = create_func_gen(client) + body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} + + num_requests = 10 + truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] * + (num_requests - num_requests // 2)) + random.shuffle(truncate_prompt_tokens) + + bodies = [{ + **body, "extra_body": { + 'truncate_prompt_tokens': t + } + } for t in truncate_prompt_tokens] + + async def get_status_code(**kwargs): + try: + await create_func(**kwargs) + return 200 + except openai.APIStatusError as e: + return e.status_code + + responses = await asyncio.gather(*[get_status_code(**b) for b in bodies]) + assert 500 not in responses + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ids=["single completion", "multiple completions", "chat"], + argnames=["create_func_gen", "content_body"], + argvalues=[ + (lambda x: x.completions.create, { + "prompt": " ".join(['A'] * 300_000) + }), + (lambda x: x.completions.create, { + "prompt": [" ".join(['A'] * 300_000)] * 2 + }), + (lambda x: x.chat.completions.create, { + "messages": [{ + "role": "user", + "content": " ".join(['A'] * 300_000) + }] + }), + ], +) +async def test_healthcheck_response_time( + server: RemoteOpenAIServer, + client: openai.AsyncOpenAI, + create_func_gen: Callable, + content_body: dict, +): + num_requests = 50 + + create_func = create_func_gen(client) + body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} + + def get_response_time(url): + start_time = time.monotonic() + res = requests.get(url) + end_time = time.monotonic() + assert res.status_code == 200 + return end_time - start_time + + no_load_response_time = get_response_time(server.url_for("health")) + tasks = [ + asyncio.create_task(create_func(**body)) for _ in range(num_requests) + ] + await asyncio.sleep(1) # give the tasks a chance to start running + load_response_time = get_response_time(server.url_for("health")) + + with contextlib.suppress(openai.APIStatusError): + await asyncio.gather(*tasks) + + assert load_response_time < 100 * no_load_response_time + assert load_response_time < 0.1 diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 8d13f64dce01c..8d23a2be6f9bb 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -829,6 +829,20 @@ async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI, "name": "nondefined_function_name" } }) + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema + } + }], + tool_choice={}) @pytest.mark.asyncio @@ -899,19 +913,19 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): @pytest.mark.asyncio -async def test_extra_fields(client: openai.AsyncOpenAI): - with pytest.raises(BadRequestError) as exc_info: - await client.chat.completions.create( - model=MODEL_NAME, - messages=[{ - "role": "system", - "content": "You are a helpful assistant.", - "extra_field": "0", - }], # type: ignore - temperature=0, - seed=0) - - assert "extra_forbidden" in exc_info.value.message +async def test_extra_fields_allowed(client: openai.AsyncOpenAI): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?", + "extra_field": "0", + }], # type: ignore + temperature=0, + seed=0) + + content = resp.choices[0].message.content + assert content is not None @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_chat_echo.py b/tests/entrypoints/openai/test_chat_echo.py new file mode 100644 index 0000000000000..223ac5b41aa83 --- /dev/null +++ b/tests/entrypoints/openai/test_chat_echo.py @@ -0,0 +1,79 @@ +from typing import NamedTuple + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# # any model with a chat template should work here +MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 + + +@pytest.fixture(scope="module") +def server(): + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--enforce-eager", + "--max-model-len", + "4080", + "--chat-template", + DUMMY_CHAT_TEMPLATE, + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +class TestCase(NamedTuple): + model_name: str + echo: bool + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_case", + [ + TestCase(model_name=MODEL_NAME, echo=True), + TestCase(model_name=MODEL_NAME, echo=False) + ], +) +async def test_chat_session_with_echo_and_continue_final_message( + client: openai.AsyncOpenAI, test_case: TestCase): + saying: str = "Here is a common saying about apple. An apple a day, keeps" + # test echo with continue_final_message parameter + chat_completion = await client.chat.completions.create( + model=test_case.model_name, + messages=[{ + "role": "user", + "content": "tell me a common saying" + }, { + "role": "assistant", + "content": saying + }], + extra_body={ + "echo": test_case.echo, + "continue_final_message": True, + "add_generation_prompt": False + }) + assert chat_completion.id is not None + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "stop" + + message = choice.message + if test_case.echo: + assert message.content is not None and saying in message.content + else: + assert message.content is not None and saying not in message.content + assert message.role == "assistant" diff --git a/tests/entrypoints/openai/test_root_path.py b/tests/entrypoints/openai/test_root_path.py new file mode 100644 index 0000000000000..20f7960619efb --- /dev/null +++ b/tests/entrypoints/openai/test_root_path.py @@ -0,0 +1,103 @@ +import contextlib +import os +from typing import Any, List, NamedTuple + +import openai # use the official client for correctness check +import pytest + +from ...utils import RemoteOpenAIServer + +# # any model with a chat template should work here +MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 +API_KEY = "abc-123" +ERROR_API_KEY = "abc" +ROOT_PATH = "llm" + + +@pytest.fixture(scope="module") +def server(): + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--enforce-eager", + "--max-model-len", + "4080", + "--root-path", # use --root-path=/llm for testing + "/" + ROOT_PATH, + "--chat-template", + DUMMY_CHAT_TEMPLATE, + ] + envs = os.environ.copy() + + envs["VLLM_API_KEY"] = API_KEY + with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server: + yield remote_server + + +class TestCase(NamedTuple): + model_name: str + base_url: List[str] + api_key: str + expected_error: Any + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + model_name=MODEL_NAME, + base_url=["v1"], # http://localhost:8000/v1 + api_key=ERROR_API_KEY, + expected_error=openai.AuthenticationError), + TestCase( + model_name=MODEL_NAME, + base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 + api_key=ERROR_API_KEY, + expected_error=openai.AuthenticationError), + TestCase( + model_name=MODEL_NAME, + base_url=["v1"], # http://localhost:8000/v1 + api_key=API_KEY, + expected_error=None), + TestCase( + model_name=MODEL_NAME, + base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 + api_key=API_KEY, + expected_error=None), + ], +) +async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, + test_case: TestCase): + saying: str = "Here is a common saying about apple. An apple a day, keeps" + ctx = contextlib.nullcontext() + if test_case.expected_error is not None: + ctx = pytest.raises(test_case.expected_error) + with ctx: + client = openai.AsyncOpenAI( + api_key=test_case.api_key, + base_url=server.url_for(*test_case.base_url), + max_retries=0) + chat_completion = await client.chat.completions.create( + model=test_case.model_name, + messages=[{ + "role": "user", + "content": "tell me a common saying" + }, { + "role": "assistant", + "content": saying + }], + extra_body={ + "continue_final_message": True, + "add_generation_prompt": False + }) + + assert chat_completion.id is not None + assert len(chat_completion.choices) == 1 + choice = chat_completion.choices[0] + assert choice.finish_reason == "stop" + message = choice.message + assert len(message.content) > 0 + assert message.role == "assistant" diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py new file mode 100644 index 0000000000000..7565ff7192f67 --- /dev/null +++ b/tests/entrypoints/openai/test_score.py @@ -0,0 +1,93 @@ +import pytest +import requests + +from vllm.entrypoints.openai.protocol import ScoreResponse + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "BAAI/bge-reranker-v2-m3" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_str_text_2_list(server: RemoteOpenAIServer, + model_name: str): + text_1 = "What is the capital of France?" + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + assert score.data[0].score[0] <= 0.01 + assert score.data[1].score[0] >= 0.9 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_list_text_2_list(server: RemoteOpenAIServer, + model_name: str): + text_1 = [ + "What is the capital of the United States?", + "What is the capital of France?" + ] + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + assert score.data[0].score[0] <= 0.01 + assert score.data[1].score[0] >= 0.9 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_str_text_2_str(server: RemoteOpenAIServer, + model_name: str): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 1 + assert score.data[0].score[0] >= 0.9 diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 72477e048eafa..996e60bfee592 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -766,8 +766,8 @@ def test_resolve_content_format_hf_defined(model, expected_format): ("tool_chat_template_granite_20b_fc.jinja", "string"), ("tool_chat_template_hermes.jinja", "string"), ("tool_chat_template_internlm2_tool.jinja", "string"), - ("tool_chat_template_llama3.1_json.jinja", "string"), - ("tool_chat_template_llama3.2_json.jinja", "string"), + ("tool_chat_template_llama3.1_json.jinja", "openai"), + ("tool_chat_template_llama3.2_json.jinja", "openai"), ("tool_chat_template_mistral_parallel.jinja", "string"), ("tool_chat_template_mistral.jinja", "string")], ) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 169ce040d370c..d37f95d48d5b2 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -5,6 +5,7 @@ from tests.kernels.utils import override_backend_env_variable from vllm.attention.selector import which_attn_to_use +from vllm.platforms import cpu, cuda, openvino, rocm from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL @@ -19,26 +20,28 @@ def test_env(name: str, device: str, monkeypatch): override_backend_env_variable(monkeypatch, name) if device == "cpu": - with patch("vllm.attention.selector.current_platform.is_cpu", - return_value=True): + with patch("vllm.attention.selector.current_platform", + cpu.CpuPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "TORCH_SDPA" elif device == "hip": - with patch("vllm.attention.selector.current_platform.is_rocm", - return_value=True): + with patch("vllm.attention.selector.current_platform", + rocm.RocmPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "ROCM_FLASH" elif device == "openvino": - with patch("vllm.attention.selector.current_platform.is_openvino", - return_value=True): + with patch("vllm.attention.selector.current_platform", + openvino.OpenVinoPlatform()): backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == "OPENVINO" else: - backend = which_attn_to_use(16, torch.float16, torch.float16, 16, - False) + with patch("vllm.attention.selector.current_platform", + cuda.CudaPlatform()): + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, + False) assert backend.name == name diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c990c87ecbb38..6bc2e59ac6dff 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -18,6 +18,7 @@ from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.selector import (_Backend, _cached_get_attn_backend, global_force_attn_backend_context_manager) +from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context from vllm.platforms import current_platform @@ -597,6 +598,7 @@ def _run_encoder_attention_test( encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, test_pt: TestPoint, + vllm_config: VllmConfig, ) -> torch.Tensor: ''' Run encoder attention. @@ -626,7 +628,7 @@ def _run_encoder_attention_test( attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be @@ -651,6 +653,7 @@ def _run_decoder_self_attention_test( decoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, test_pt: TestPoint, + vllm_config: VllmConfig, ) -> torch.Tensor: ''' Run decoder self-attention test. @@ -680,7 +683,7 @@ def _run_decoder_self_attention_test( kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be @@ -704,6 +707,7 @@ def _run_encoder_decoder_cross_attention_test( cross_test_params: Optional[PhaseTestParameters], attn_metadata: AttentionMetadata, test_pt: TestPoint, + vllm_config: VllmConfig, ) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -751,7 +755,7 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be @@ -842,7 +846,9 @@ def test_encoder_only( # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + test_rsrcs = _make_test_resources(test_pt) # Construct encoder attention test params (only used # during prefill) @@ -866,7 +872,8 @@ def test_encoder_only( test_rsrcs.attn, enc_test_params, prephase_attn_metadata, - test_pt=test_pt)) + test_pt=test_pt, + vllm_config=vllm_config)) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, @@ -961,7 +968,9 @@ def test_e2e_enc_dec_attn( # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + test_rsrcs = _make_test_resources(test_pt) # Construct encoder attention test params (only used # during prefill) @@ -1012,7 +1021,8 @@ def test_e2e_enc_dec_attn( enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_test_params, prephase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, @@ -1024,7 +1034,8 @@ def test_e2e_enc_dec_attn( test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, @@ -1038,7 +1049,8 @@ def test_e2e_enc_dec_attn( prephase_dec_test_params, prephase_cross_test_params, prephase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, @@ -1062,7 +1074,8 @@ def test_e2e_enc_dec_attn( test_rsrcs, decphase_dec_test_params, decphase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, @@ -1076,7 +1089,8 @@ def test_e2e_enc_dec_attn( decphase_dec_test_params, None, decphase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index a20c73345218f..1ae78d7b46c5b 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -71,6 +71,7 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) +@pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -81,6 +82,7 @@ def ref_paged_attn( @pytest.mark.parametrize("sliding_window", [None, 256]) @torch.inference_mode() def test_flash_attn_with_paged_kv( + use_out: bool, kv_lens: List[int], num_heads: Tuple[int, int], head_size: int, @@ -116,17 +118,22 @@ def test_flash_attn_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) + q = query.unsqueeze(1) + out = torch.empty_like(q) if use_out else None output = flash_attn_with_kvcache( - q=query.unsqueeze(1), + q=q, k_cache=key_cache, v_cache=value_cache, + out=out, softmax_scale=scale, causal=True, block_table=block_tables, cache_seqlens=kv_lens_tensor, softcap=soft_cap if soft_cap is not None else 0, window_size=window_size, - ).squeeze(1) + ) + output = output if not use_out else out + output = output.squeeze(1) ref_output = ref_paged_attn(query=query, key_cache=key_cache, @@ -141,7 +148,10 @@ def test_flash_attn_with_paged_kv( f"{torch.max(torch.abs(output - ref_output))}" -@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("use_out", [True, False]) +@pytest.mark.parametrize("seq_lens", + [[(1, 1328), (5, 18), + (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -151,6 +161,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @torch.inference_mode() def test_varlen_with_paged_kv( + use_out: bool, seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], head_size: int, @@ -197,10 +208,12 @@ def test_varlen_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) + out = torch.empty_like(query) if use_out else None output = flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, + out=out, cu_seqlens_q=cu_query_lens, cu_seqlens_k=cu_kv_lens, max_seqlen_q=max_query_len, @@ -211,6 +224,7 @@ def test_varlen_with_paged_kv( block_table=block_tables, softcap=soft_cap if soft_cap is not None else 0, ) + output = output if not use_out else out ref_output = ref_paged_attn( query=query, diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py deleted file mode 100644 index 59c0a24753c3b..0000000000000 --- a/tests/kernels/test_machete_gemm.py +++ /dev/null @@ -1,284 +0,0 @@ -"""Tests for the machete kernel. - -Run `pytest tests/kernels/test_machete_gemm.py`. -""" - -import math -from typing import Optional, Tuple - -import pytest -import torch - -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) -from vllm.platforms import current_platform -from vllm.scalar_type import ScalarType, scalar_types - -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -MNK_SHAPES = [ - (1, 128, 128), - (1, 512, 1024), - (1, 4096, 4096), - (1, 8192, 28672), - (13, 8192, 4096), - (26, 4096, 8192), - (64, 4096, 4096), - (64, 8192, 28672), - (257, 128, 4096), - (257, 4224, 4160), - (257, 4096, 4096), - (1024, 4096, 8192), - (1024, 8192, 4096), -] - -ACT_TYPES = [torch.float16, torch.bfloat16] -WTYPE_ZEROPOINTS = [ - # GPTQ style - (scalar_types.uint4b8, False), - (scalar_types.uint8b128, False), - # AWQ style - (scalar_types.uint4, True), - (scalar_types.uint8, True), -] - -# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel -# unit tests to a common utility function. Currently the use of -# `is_quant_method_supported` conflates kernels with quantization methods -# an assumption which is breaking down as quantizations methods can have -# have kernels and some kernels support multiple quantization methods. -IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) - - -def rand_data(shape, dtype=torch.float16): - return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3) - - -def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): - return zps if zps is None else -1 * s * (zps.to(s.dtype)) - - -def machete_quantize_and_pack(w: torch.Tensor, - wtype: ScalarType, - group_size: int, - zero_points: bool = False): - assert wtype.is_integer(), "TODO: support floating point weights" - - w_ref, w_q, w_s, w_zp = quantize_weights( - w, - wtype, - group_size, - zero_points=zero_points, - # to match how the kernel applies zps - ref_zero_points_after_scales=True) - - w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # convert to col major - w_q_machete = ops.machete_prepack_B(w_q, wtype) - - opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id)) - - return w_ref, w_q_machete, w_s, w_zp - - -def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor, - wtype: ScalarType, group_size: int, - zero_points: bool): - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - output = ops.machete_gemm( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_all_schedules(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - print(f"MNK = {m} {n} {k}") - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - w = rand_data((k, n), atype) - - w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack( - w, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - for schedule in ops.machete_supported_schedules(wtype): - print(f"Testing schedule {schedule}") - output = ops.machete_gemm( - a, - b_q=w_q_machete, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - schedule=schedule, - ) - - opcheck( - torch.ops._C.machete_gemm, - (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints( - w_zp, w_s), group_size, None, None, None, schedule)) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\ - f"Schedule failed {schedule}" - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_heuristic(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - b = rand_data((k, n), atype) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working on other devices -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_machete_devices(device: str): - m, n, k = 512, 4096, 4096 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - print(f"MNK = {m} {n} {k}, device = {device}") - - a = rand_data((m, k), torch.float16).to(device) - b = rand_data((k, n), torch.float16).to(device) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working with a subset of A and B -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_subset(): - big_m, big_n, big_k = 1024, 1024, 1024 - m, n, k = 512, 512, 512 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - whole_a = rand_data((big_m, big_k), torch.float16) - whole_b = rand_data((big_k, big_n), torch.float16) - - a = whole_a[0:m, 0:k] - b = whole_b[0:k, 0:n] - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test to make sure cuda graphs work -class MacheteLayer(torch.nn.Module): - - def __init__(self, **kwargs): - super().__init__() - self.kwargs = kwargs - - def forward(self, a): - return ops.machete_gemm(**self.kwargs) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_cuda_graph(): - m, n, k = 512, 4096, 4096 - - a = rand_data((m, k), torch.float16) - b = rand_data((k, n), torch.float16) - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - # Construct a trivial model with a single layer that calls a machete kernel - model = MacheteLayer( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - output_ref = torch.matmul(a, w_ref) - - # Run the model with a cuda graph - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - output = model(a) - output.zero_() - g.replay() - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/kernels/test_machete_mm.py b/tests/kernels/test_machete_mm.py new file mode 100644 index 0000000000000..1c6eb2dd9a228 --- /dev/null +++ b/tests/kernels/test_machete_mm.py @@ -0,0 +1,406 @@ +"""Tests for the machete kernel. + +Run `pytest tests/kernels/test_machete_mm.py`. +""" + +import math +from dataclasses import dataclass, fields +from typing import List, Optional, Tuple + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (1, 8192, 28672), + (13, 8192, 4096), + (26, 4096, 8192), + (64, 4096, 4096), + (64, 8192, 28672), + (257, 128, 4096), + (257, 4224, 4160), + (257, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), +] + +GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: Optional[torch.Tensor] + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +# NOTE: None "Scale Type" means the act type is floating point +# None "Output Type" means the output type is the same as the act type +TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype], + Optional[torch.dtype], bool] +TEST_TYPES = [ + # GPTQ style + *(TypeConfig(act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] + for a_type in [torch.float16, torch.bfloat16]), + # AWQ style + *(TypeConfig(act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=a_type, + channel_scale_type=None, + token_scale_type=None) + for w_type in [scalar_types.uint4, scalar_types.uint8] + for a_type in [torch.float16, torch.bfloat16]), + # QQQ style + *(TypeConfig(act_type=torch.int8, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float) + for group_scale_type in [None, torch.float16]), + *(TypeConfig(act_type=torch.float8_e4m3fn, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float) + for group_scale_type in [None, torch.float16]), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +def rand_data(shape, dtype=torch.float16, scale=1, offset=0): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - offset).to(dtype) + else: + return torch.randint(-8, 7, shape, dtype=dtype, device="cuda") + + +def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): + return zps if zps is None else -1 * s * (zps.to(s.dtype)) + + +def group_size_valid(shape: Tuple[int, int, int], + group_size: Optional[int]) -> bool: + return group_size is None or group_size == -1 or group_size % shape[2] == 0 + + +def machete_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) + + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype) + opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype)) + + return w_ref, w_q_machete, w_s, w_zp + + +def create_test_tensors(shape: Tuple[int, int, int], + types: TypeConfig, + group_size: Optional[int], + subset_stride_factor: Optional[int] = None) -> Tensors: + m, n, k = shape + factor = subset_stride_factor or 1 + + print("create_test_tensors, shape:", shape, "types:", types, "group_size:", + group_size) + + a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2) + w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1) + + if factor > 1: + a = a[0:m, 0:k] + w = w[0:k, 0:n] + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + types.group_zero_type is not None) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + w_ch_s = None if types.channel_scale_type is None else\ + rand_data((n,), types.channel_scale_type) + w_tok_s = None if types.token_scale_type is None else\ + rand_data((m,), types.token_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_g_zp=maybe_convert_zeropoints(w_zp, w_s), + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +# None stype means scales use the same dtype as a +def machete_mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + output_ref = torch.matmul(tensors.a_ref, tensors.w_ref) + output_ref_type = output_ref.dtype + + if tensors.w_ch_s is not None: + output_ref = (output_ref.to(tensors.w_ch_s.dtype) * + tensors.w_ch_s.unsqueeze(0)).to(output_ref_type) + if tensors.w_tok_s is not None: + output_ref = (output_ref.to(tensors.w_tok_s.dtype) * + tensors.w_tok_s.unsqueeze(1)).to(output_ref_type) + + output = ops.machete_mm( + a=tensors.a, + b_q=tensors.w_q, + b_type=types.weight_type, + b_group_scales=tensors.w_g_s, + b_group_zeros=tensors.w_g_zp, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + out_type=types.output_type, + schedule=schedule, + ) + + print(output) + print(output_ref) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if tensors.w_g_zp is not None\ + else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1) + rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1 + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=rtol, + atol=atol) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +def test_machete_all_schedules(shape, types: TypeConfig): + + group_sizes: List[Optional[int]] = [] + if types.group_scale_type is None: + group_sizes = [None] + else: + group_sizes = GROUP_SIZES_TO_TEST + + for group_size in group_sizes: + if not group_size_valid(shape, group_size): + continue + + tensors = create_test_tensors(shape, types, group_size) + print(f"MNK = {shape}") + for schedule in ops.machete_supported_schedules( + types.act_type, + types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_scale_type, + out_type=types.output_type): + print(f"Testing schedule {schedule}") + machete_mm_test_helper(types, tensors, group_size, schedule) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +def test_machete_heuristic(shape, types: TypeConfig): + group_sizes: List[Optional[int]] = [] + if types.group_scale_type is None: + group_sizes = [None] + else: + group_sizes = GROUP_SIZES_TO_TEST + + for group_size in group_sizes: + if not group_size_valid(shape, group_size): + continue + + tensors = create_test_tensors(shape, types, group_size) + machete_mm_test_helper(types, tensors, group_size) + + +# Test working on other devices +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_machete_devices(device: str): + group_size = 128 + + type_config = TypeConfig(act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + + tensors = create_test_tensors((512, 4096, 4096), type_config, group_size) + + for field in fields(Tensors): + tensor = getattr(tensors, field.name) + if isinstance(tensor, torch.Tensor): + setattr(tensors, field.name, tensor.to(device)) + + machete_mm_test_helper(type_config, tensors, group_size) + + +# Test working with a subset of A and B +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_subset(): + group_size = 128 + + type_config = TypeConfig(act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + + tensors = create_test_tensors((512, 4096, 4096), + type_config, + group_size, + subset_stride_factor=2) + machete_mm_test_helper(type_config, tensors, group_size) + + +# Test to make sure cuda graphs work +class MacheteLayer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.machete_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = rand_data((m, k), torch.float16) + b = rand_data((k, n), torch.float16) + wtype = scalar_types.uint4b8 + stype = torch.float16 + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + a.dtype, b, wtype, stype, group_size, zero_points) + + # Construct a trivial model with a single layer that calls a machete kernel + model = MacheteLayer( + b_q=w_q_packed, + b_type=wtype, + b_group_scales=w_s, + b_group_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + output_ref = torch.matmul(a, w_ref) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + output.zero_() + g.replay() + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index b6dd68cc51a9f..5e047f4b099f1 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -29,6 +29,7 @@ marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) +from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] @@ -40,6 +41,8 @@ MARLIN_24_K_CHUNKS = [128] MARLIN_24_N_CHUNKS = [512] +HQQ_SUPPORTED_GROUP_SIZES = [64] + MNK_FACTORS = [ (1, 1, 1), (1, 4, 8), @@ -47,6 +50,8 @@ (13, 17, 67), (26, 37, 13), (67, 13, 11), + (257, 13, 11), + (658, 13, 11), ] DTYPES = [torch.float16, torch.bfloat16] @@ -226,7 +231,7 @@ def test_gptq_marlin_gemm( torch.ops._C.gptq_marlin_gemm, (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1], - a_input.shape[1], is_k_full, False, use_fp32_reduce), + a_input.shape[1], is_k_full, False, use_fp32_reduce, False), test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.gptq_marlin_gemm( @@ -244,6 +249,7 @@ def test_gptq_marlin_gemm( is_k_full=is_k_full, has_zp=False, use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, ) output_ref = torch.matmul(a_input, w_ref) @@ -441,6 +447,7 @@ def test_awq_marlin_gemm( is_k_full=is_k_full, has_zp=has_zp, use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, ) output_ref = torch.matmul(a_input, w_ref) @@ -451,6 +458,87 @@ def test_awq_marlin_gemm( assert max_diff < 0.04 +@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) +@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) +def test_hqq_marlin_gemm( + k_chunk, + n_chunk, + group_size, + mnk_factors, + use_fp32_reduce, +): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + quant_type = scalar_types.uint4 + + a_input = rand_data((size_m, size_k)) + dev = a_input.device + + b_weight = torch.randint(0, + 10, (size_n, size_k), + dtype=torch.uint8, + device=dev) + scale = rand_data((size_n, size_k // group_size)) + zero = rand_data((size_n, size_k // group_size)) + + gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n) + + sort_indices = torch.empty(0, dtype=torch.int, device=dev) + marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, + 4).to(dev) + marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n, + group_size).to(dev) + marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n, + group_size).to(dev) + + g_idx = marlin_make_empty_g_idx(dev) + g_idx_sort_indices = marlin_make_empty_g_idx(dev) + + workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + output = ops.gptq_marlin_gemm( + a_input, + marlin_w_q, + marlin_s, + marlin_zp, + g_idx, + g_idx_sort_indices, + workspace.scratch, + quant_type, + a_input.shape[0], + b_weight.shape[0], + a_input.shape[1], + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=True, + ) + + b_flat = b_weight.reshape(-1, group_size) + zp_flat = zero.reshape(-1, 1) + s_flat = scale.reshape(-1, 1) + dequant = (b_flat - zp_flat) * s_flat + + output_ref = torch.matmul(a_input, + dequant.reshape(b_weight.shape).transpose(1, 0)) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + + assert max_diff < 0.04 + + @pytest.mark.skipif(not is_quant_method_supported("qqq"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index a8a187ebaede4..3fdb7996ba4e0 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -40,6 +40,13 @@ def test_contexted_kv_attention( kv_cache_dtype: str, device: str, ) -> None: + + if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( + 89): + pytest.skip( + 'Triton limitation: fp8e4nv data type is not supported on CUDA' + ' arch < 89') + current_platform.seed_everything(0) torch.set_default_device(device) @@ -235,6 +242,13 @@ def test_contexted_kv_attention_alibi( kv_cache_dtype: str, device: str, ) -> None: + + if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( + 89): + pytest.skip( + 'Triton limitation: fp8e4nv data type is not supported on CUDA' + ' arch < 89') + current_platform.seed_everything(0) torch.set_default_device(device) @@ -462,3 +476,52 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) + + +# These tests are optional to only run when explicitly invoked +# +# pytest -v -s --optional \ +# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32 +# +# These tests are useful to test model dtype float32 on Turing devices. +# We skip them to not increase the time when running tests on CI +@pytest.mark.optional +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) +@torch.inference_mode() +def test_contexted_kv_attention_f32( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + sliding_window: int, + dtype: torch.dtype, + kv_cache_dtype: str, + device: str, +) -> None: + test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size, + sliding_window, dtype, kv_cache_dtype, device) + + +@pytest.mark.optional +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi_f32( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + device: str, +) -> None: + test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size, + dtype, kv_cache_dtype, device) diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py new file mode 100644 index 0000000000000..adc6150edece6 --- /dev/null +++ b/tests/kv_transfer/disagg_test.py @@ -0,0 +1,119 @@ +import os +import subprocess +import sys +import time +from subprocess import Popen + +import pytest +import requests +import torch + + +# Fixture to set up environment variables and teardown servers after tests +@pytest.fixture(scope="module", autouse=True) +def setup_servers(): + if torch.cuda.device_count() < 4: + pytest.skip("Skipping test: fewer than 4 GPUs available") + + # Set up environment variables + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", + shell=True).decode().strip() + os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP + + # Start prefill instance + prefill_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8100", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "1000", + "--kv-transfer-config", + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\ + '"kv_rank":0,"kv_parallel_size":2}', + ] + prefill_env = os.environ.copy() + prefill_env["CUDA_VISIBLE_DEVICES"] = "0" + prefill_proc = Popen(prefill_cmd, env=prefill_env) + + # Start decode instance + decode_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8200", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "1000", + "--kv-transfer-config", + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\ + '"kv_rank":1,"kv_parallel_size":2}', + ] + decode_env = os.environ.copy() + decode_env["CUDA_VISIBLE_DEVICES"] = "1" + decode_proc = Popen(decode_cmd, env=decode_env) + + # Wait for servers to be ready + assert wait_for_server(8100), "Prefill server did not start in time" + assert wait_for_server(8200), "Decode server did not start in time" + + # Yield to the test function and handle teardown after tests + yield + + # Cleanup: kill the processes + prefill_proc.terminate() + decode_proc.terminate() + + # Additional cleanup if needed + prefill_proc.wait() + decode_proc.wait() + + +# Helper function to wait for server +def wait_for_server(port, timeout=240): + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/v1/completions") + if response.status_code in [200, 405]: + return True + except requests.ConnectionError: + time.sleep(1) + return False + + +# Test function to send curl requests and validate responses +@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) +def test_disaggregated_prefilling(prompt): + # Send to prefill + response = requests.post("http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + }) + assert response.status_code == 200 + + # Send to decode + response = requests.post("http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + }) + assert response.status_code == 200 diff --git a/tests/kv_transfer/module_test.py b/tests/kv_transfer/module_test.py new file mode 100644 index 0000000000000..355461919cd7c --- /dev/null +++ b/tests/kv_transfer/module_test.py @@ -0,0 +1,64 @@ +import subprocess +import sys + +import pytest +import torch + + +def run_python_script(script_name, timeout): + script_name = f'kv_transfer/{script_name}' + try: + # Start both processes asynchronously using Popen + process0 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": + "0"}, # Set the RANK environment variable for process 0 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + process1 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": + "1"}, # Set the RANK environment variable for process 1 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + # Wait for both processes to complete, with a timeout + process0.wait(timeout=timeout) + process1.wait(timeout=timeout) + + # Check the return status of both processes + if process0.returncode != 0: + pytest.fail( + f"Test {script_name} failed for RANK=0, {process0.returncode}") + if process1.returncode != 0: + pytest.fail( + f"Test {script_name} failed for RANK=1, {process1.returncode}") + + except subprocess.TimeoutExpired: + # If either process times out, terminate both and fail the test + process0.terminate() + process1.terminate() + pytest.fail(f"Test {script_name} timed out") + except Exception as e: + pytest.fail(f"Test {script_name} failed with error: {str(e)}") + + +# Define the test cases using pytest's parametrize +@pytest.mark.parametrize( + "script_name,timeout", + [ + ("test_lookup_buffer.py", + 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120) # First test case with a 120-second timeout + ]) +def test_run_python_script(script_name, timeout): + # Check the number of GPUs + if torch.cuda.device_count() < 2: + pytest.skip( + f"Skipping test {script_name} because <2 GPUs are available") + + # Run the test if there are at least 2 GPUs + run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py new file mode 100644 index 0000000000000..96b0e58713332 --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -0,0 +1,160 @@ +import os +import random + +import torch +from tqdm import tqdm + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe + +# TODO: the test depends on a lot of fields in the current implementation. +# We should have standard interface instead direct field access + + +def test_run(my_rank, buffer, device): + + # buffer should be empty in the beginning + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + print("My rank: %d, device: %s" % (my_rank, device)) + + # insert + tokens = torch.tensor([1, 2, 3]).to(device) + roi = (tokens > 0) + if my_rank == 0: + key = 2.0 * torch.ones([5, 6]).to(device) + value = 3.0 * torch.ones([5, 6]).to(device) + + placeholder = torch.tensor([1]).to(device) + + buffer.insert(tokens, roi, key, value, placeholder) + + torch.distributed.barrier() + + # drop_select + if my_rank == 1: + tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) + assert torch.allclose(tokens, tok) + assert torch.allclose(roi, roi_) + assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device)) + assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device)) + torch.distributed.barrier() + + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + print("Test run passed!") + + +def stress_test(my_rank, buf, device): + + torch.distributed.barrier() + torch.manual_seed(100) + + reqs = [ + ( + torch.rand(100).to(device), # tokens + torch.ones(100).bool().to(device), # roi + torch.rand(100).to(device), # key + torch.rand(100).to(device), # value + torch.rand(100).to(device), # hidden + ) for i in tqdm(range(200)) + ] + + random.seed(my_rank) + random.shuffle(reqs) + + torch.distributed.barrier() + + n = 0 + + # the buffer size can only store 100 reqs + # so the sender will occasionally block to wait for the receiver. + for req in tqdm(reqs): + if my_rank == 0: + buf.insert(*req) + else: + tok, roi, k, v, h = req + tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) + + if tok_ is None: + assert roi_ is None + assert k_ is None + assert v_ is None + assert h_ is None + n += 1 + else: + assert torch.allclose(tok, tok_) + assert torch.allclose(roi, roi_) + assert torch.allclose(k, k_) + assert torch.allclose(v, v_) + assert torch.allclose(h, h_) + print('Rank %d done' % my_rank) + torch.distributed.barrier() + + if my_rank == 0: + x = torch.tensor([0]) + torch.distributed.recv(x, 1) + # the # of None received is the kv that are not selected + assert x.item() == len(buf.buffer) + # and the size of the buffer should be 2000 * buffer len + print(buf.buffer_size) + assert buf.buffer_size == 1700 * len(buf.buffer) + else: + torch.distributed.send(torch.tensor([n]), 0) + + print("Passed stress test!") + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + torch.distributed.init_process_group( + backend='gloo', + init_method='tcp://localhost:12398', + world_size=2, + rank=my_rank, + ) + + print("initialized! My rank is %d" % my_rank) + + config = KVTransferConfig( + kv_connector='PyNcclConnector', + kv_buffer_device='cuda', + kv_buffer_size=1e9, + kv_rank=my_rank, + kv_role="kv_both", # this arg doesn't matter in this test + kv_parallel_size=2, + kv_ip="127.0.0.1", + kv_port=12345, + ) + + data_pipe = PyNcclPipe( + local_rank=my_rank, + config=config, + device="cuda", + port_offset=0, + ) + cpu_pipe = PyNcclPipe( + local_rank=my_rank, + config=config, + device="cpu", + port_offset=1, + ) + + buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000) + + test_run(my_rank, buffer, data_pipe.device) + + stress_test(my_rank, buffer, data_pipe.device) + + buffer.close() + data_pipe.close() + cpu_pipe.close() + print('Done') diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh new file mode 100644 index 0000000000000..09d7ee018c3f4 --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.sh @@ -0,0 +1,3 @@ +#!/bin/bash +RANK=0 python test_lookup_buffer.py & +RANK=1 python test_lookup_buffer.py & \ No newline at end of file diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py new file mode 100644 index 0000000000000..65973bf10a4d7 --- /dev/null +++ b/tests/kv_transfer/test_send_recv.py @@ -0,0 +1,155 @@ +import os +import time +from typing import List + +import torch +from tqdm import tqdm + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe + + +def test_run(my_rank, pipe): + # test run + x = torch.tensor([1]).to(pipe.device) + y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) + if my_rank == 0: + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + + else: + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + + assert torch.allclose(x, x2) + assert torch.allclose(y, y2) + + +def stress_test(my_rank, pipe): + + torch.distributed.barrier() + + tensors: List[torch.Tensor] = [] + + torch.manual_seed(0) + + for i in tqdm(range(500)): + mean = torch.rand(1).item() * 100 + std = torch.rand(1).item() * 100 + size = torch.randint(900, 1000, (2, )) + x = torch.normal(mean * 1.0, std * 1.0, + size=size.tolist()).to(pipe.device) + + # 5% probability of sending a None + if torch.rand(1).item() < 0.05: + tensors.append(None) + tensors.append(None) + tensors.append(None) + else: + tensors.append(x) + tensors.append(x.mean().unsqueeze(0)) + tensors.append(x.std().unsqueeze(0)) + + torch.distributed.barrier() + + for i in tqdm(range(500)): + if my_rank == int((i % 10) > 3): + pipe.send_tensor(tensors[3 * i]) + pipe.send_tensor(tensors[3 * i + 1]) + pipe.send_tensor(tensors[3 * i + 2]) + else: + x = pipe.recv_tensor() + mean = pipe.recv_tensor() + std = pipe.recv_tensor() + + if x is None: + assert mean is None + assert std is None + else: + assert torch.allclose(x, tensors[3 * i]) + assert x.mean() == mean[0] + assert x.std() == std[0] + + torch.distributed.barrier() + + +def latency_test(my_rank, pipe, nelement, ntensor): + + latencies = [] + + torch.distributed.barrier() + + for i in tqdm(range(500)): + + tensors = [] + + if my_rank == 0: + # create tensor + tensors = [ + torch.rand(nelement).to(pipe.device) for _ in range(ntensor) + ] + + torch.distributed.barrier() + + if my_rank == 0: + t = torch.tensor([time.time()], + dtype=torch.float64).to(pipe.device) + for tensor in tensors: + pipe.send_tensor(tensor) + pipe.send_tensor(t) + else: + for _ in range(ntensor): + pipe.recv_tensor() + t = pipe.recv_tensor() + latencies.append(time.time() - t.item()) + + torch.distributed.barrier() + + print('Latency test passed.') + print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + torch.distributed.init_process_group( + backend='gloo', + init_method='tcp://localhost:12398', + world_size=2, + rank=my_rank, + ) + + config = KVTransferConfig( + kv_connector='PyNcclConnector', + kv_buffer_device='cuda', + kv_buffer_size=1e9, + kv_rank=my_rank, + kv_role="kv_both", # this arg doesn't matter in this test + kv_parallel_size=2, + kv_ip="127.0.0.1", + kv_port=12345, + ) + + pipe = PyNcclPipe( + local_rank=my_rank, + config=config, + ) + + test_run(my_rank, pipe) + stress_test(my_rank, pipe) + + # Use this function if you want to test the latency of pipe impl. + # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh new file mode 100644 index 0000000000000..1e89e246b4992 --- /dev/null +++ b/tests/kv_transfer/test_send_recv.sh @@ -0,0 +1,3 @@ +#!/bin/bash +RANK=0 python3 test_send_recv.py & +RANK=1 python3 test_send_recv.py & \ No newline at end of file diff --git a/tests/lora/test_chatglm3.py b/tests/lora/test_chatglm3_tp.py similarity index 56% rename from tests/lora/test_chatglm3.py rename to tests/lora/test_chatglm3_tp.py index de4cbea80924e..f17464573459f 100644 --- a/tests/lora/test_chatglm3.py +++ b/tests/lora/test_chatglm3_tp.py @@ -1,12 +1,21 @@ from typing import List import vllm +from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest +from ..utils import multi_gpu_test + MODEL_PATH = "THUDM/chatglm3-6b" PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 +EXPECTED_LORA_OUTPUT = [ + "SELECT count(*) FROM singer", + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 + "SELECT name , country , age FROM singer ORDER BY age", +] + def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: prompts = [ @@ -20,7 +29,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 ), ] - print(prompts) sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( prompts, @@ -37,23 +45,58 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@fork_new_process_for_each_test def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, enable_lora=True, max_loras=4, max_lora_rank=64, + tensor_parallel_size=1, trust_remote_code=True) - expected_lora_output = [ - "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 - "SELECT name , country , age FROM singer ORDER BY age", - ] + output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, chatglm3_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] + +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_chatglm3_lora_tp4(chatglm3_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False) + + output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, chatglm3_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] + + +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i] == expected_lora_output[i] + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] output2 = do_sample(llm, chatglm3_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i] == expected_lora_output[i] + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py deleted file mode 100644 index e2a4f1ed0496a..0000000000000 --- a/tests/lora/test_llama.py +++ /dev/null @@ -1,146 +0,0 @@ -from typing import List - -import pytest -import ray - -import vllm -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_PATH = "meta-llama/Llama-2-7b-hf" - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: - prompts = [ - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 - ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=256, - stop=["[/assistant]"]) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) - # Print the outputs. - generated_texts: List[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -@pytest.mark.parametrize("tp_size", [1, 2, 4]) -def test_llama_lora(sql_lora_files, tp_size, num_gpus_available): - if num_gpus_available < tp_size: - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") - - llm = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=tp_size) - - expected_no_lora_output = [ - "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501 - "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501 - "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501 - ] - expected_lora_output = [ - " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 - " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 - " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 - " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 - " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 - " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 - ] - - print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output - - print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output - - print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output - - print("lora 2") - assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output - - print("removing lora") - - -def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available): - if num_gpus_available < 4: - pytest.skip("Not enough GPUs for tensor parallelism 4") - - llm_tp1 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=1) - output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) - - del llm_tp1 - cleanup_dist_env_and_memory() - - llm_tp2 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=2) - output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) - - del llm_tp2 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp2 - - llm_tp4 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=4) - output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) - - del llm_tp4 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp4 - - -def test_llama_lora_warmup(sql_lora_files): - """Test that the LLM initialization works with a warmup LORA path and - is more conservative""" - - @ray.remote(num_gpus=1) - def get_num_gpu_blocks_lora(): - llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) - num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks - return num_gpu_blocks_lora_warmup - - @ray.remote(num_gpus=1) - def get_num_gpu_blocks_no_lora(): - llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) - num_gpu_blocks_no_lora_warmup = ( - llm.llm_engine.cache_config.num_gpu_blocks) - return num_gpu_blocks_no_lora_warmup - - num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) - num_gpu_blocks_no_lora_warmup = ray.get( - get_num_gpu_blocks_no_lora.remote()) - assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( - "The warmup with lora should be more " - "conservative than without lora, therefore the number of " - "memory blocks for the KV cache should be " - "less when using lora than when not using lora") diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py new file mode 100644 index 0000000000000..aae6310a2a213 --- /dev/null +++ b/tests/lora/test_llama_tp.py @@ -0,0 +1,161 @@ +from typing import List + +import ray + +import vllm +from tests.utils import fork_new_process_for_each_test +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + +EXPECTED_NO_LORA_OUTPUT = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501 + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501 + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501 +] +EXPECTED_LORA_OUTPUT = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@fork_new_process_for_each_test +def test_llama_lora(sql_lora_files): + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1) + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT + + print("removing lora") + + +@fork_new_process_for_each_test +def test_llama_lora_warmup(sql_lora_files): + """Test that the LLM initialization works with a warmup LORA path and + is more conservative""" + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_lora(): + llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) + num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_lora_warmup + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_no_lora(): + llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) + num_gpu_blocks_no_lora_warmup = ( + llm.llm_engine.cache_config.num_gpu_blocks) + return num_gpu_blocks_no_lora_warmup + + num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) + num_gpu_blocks_no_lora_warmup = ray.get( + get_num_gpu_blocks_no_lora.remote()) + assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( + "The warmup with lora should be more " + "conservative than without lora, therefore the number of " + "memory blocks for the KV cache should be " + "less when using lora than when not using lora") + + +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_llama_lora_tp4(sql_lora_files): + + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4, + ) + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT + + print("removing lora") + + +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): + + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4, + fully_sharded_loras=True, + ) + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT + + print("removing lora") diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 52b82f25d23e1..3b20033271d26 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -6,12 +6,13 @@ import pytest import torch -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice -from vllm.lora.ops.sgmv_shrink import sgmv_shrink +# Enable custom op register +import vllm.lora.ops.bgmv_expand +import vllm.lora.ops.bgmv_expand_slice +import vllm.lora.ops.bgmv_shrink +import vllm.lora.ops.sgmv_expand +import vllm.lora.ops.sgmv_expand_slice +import vllm.lora.ops.sgmv_shrink # noqa: F401 from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, @@ -37,6 +38,16 @@ def assert_close(a, b): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) +# Unlike test_punica_sizes.py, we directly utilize custom op for +# testing, which verifies the correct registration of these ops. +bgmv_expand = torch.ops.vllm.bgmv_expand +bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice +bgmv_shrink = torch.ops.vllm.bgmv_shrink +sgmv_expand = torch.ops.vllm.sgmv_expand +sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice +sgmv_shrink = torch.ops.vllm.sgmv_shrink + + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index c3219bc50646b..0a3aba255fd76 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,15 +1,13 @@ -import os from typing import List import pytest -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.plugins import set_current_vllm_config # Registered subclass for test @@ -53,9 +51,8 @@ class Relu3(ReLUSquaredActivation): ]) def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], default_on: bool): - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) vllm_config = VllmConfig(compilation_config=CompilationConfig( - custom_ops=env.split(","))) + level=torch_level, custom_ops=env.split(","))) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 6542689c3f277..87a05b3011393 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -275,6 +275,44 @@ def test_state_cleanup( "could be related to finished_requests_ids") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_multistep( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + # This test is verifying that multistep works correctly + #on mamba-like models + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 10, 1) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_multistep_correctness(vllm_runner, model: str, dtype: str, + max_tokens: int, example_prompts) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_outputs_multistep = vllm_model.generate_greedy( + example_prompts, max_tokens) + + with vllm_runner(model, num_scheduler_steps=1, + max_num_seqs=2) as vllm_model: + vllm_outputs_single_step = vllm_model.generate_greedy( + example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_multistep, + outputs_1_lst=vllm_outputs_single_step, + name_0="vllm_outputs_multistep", + name_1="vllm_outputs_single_step", + ) + + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 78eab8d5354fd..01e208347bff4 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -283,3 +283,39 @@ def test_state_cleanup( except ValueError: pytest.fail("Mamba inner state wasn't cleaned up between states, " "could be related to finished_requests_ids") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_multistep( + vllm_runner, + model: str, + dtype: str, + example_prompts, +) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_model.generate_greedy([example_prompts[0]] * 10, 1) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_multistep_correctness(vllm_runner, model: str, dtype: str, + max_tokens: int, example_prompts) -> None: + with vllm_runner(model, num_scheduler_steps=8, + max_num_seqs=2) as vllm_model: + vllm_outputs_multistep = vllm_model.generate_greedy( + example_prompts, max_tokens) + + with vllm_runner(model, num_scheduler_steps=1, + max_num_seqs=2) as vllm_model: + vllm_outputs_single_step = vllm_model.generate_greedy( + example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_multistep, + outputs_1_lst=vllm_outputs_single_step, + name_0="vllm_outputs_multistep", + name_1="vllm_outputs_single_step", + ) diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_internvl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_internvl.py new file mode 100644 index 0000000000000..af0c2aa211998 --- /dev/null +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_internvl.py @@ -0,0 +1,206 @@ +"""Tests for InternVL's multimodal preprocessing kwargs.""" +from typing import Callable, Optional + +import pytest +from transformers import AutoTokenizer + +from vllm.inputs import InputContext, token_inputs +from vllm.multimodal import MultiModalRegistry + +from .....conftest import _ImageAssets +from ....utils import build_model_context + +models = ["OpenGVLab/InternVL2-2B"] + + +# Wrap lazy imports to avoid initializing CUDA during test collection +@pytest.fixture() +def input_processor_for_internvl(): + from vllm.model_executor.models.internvl import InternVLInputPipeline + + pipeline = InternVLInputPipeline('', '', '') + return pipeline.input_processor + + +@pytest.fixture() +def dummy_data_for_internvl(): + from vllm.model_executor.models.internvl import InternVLInputPipeline + + pipeline = InternVLInputPipeline('', '', '') + return pipeline.dummy_data + + +@pytest.fixture() +def get_max_internvl_image_tokens(): + from vllm.model_executor.models.internvl import ( + get_max_internvl_image_tokens) + return get_max_internvl_image_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("max_dynamic_patch", [1, 4]) +@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) +def test_input_mapper_override( + model: str, + image_assets: _ImageAssets, + max_dynamic_patch: int, + dynamic_image_size: Optional[bool], +): + mm_processor_kwargs = { + "max_dynamic_patch": max_dynamic_patch, + } + if dynamic_image_size is not None: + mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size + + expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 + if dynamic_image_size is False: + expected_num_patches = 1 + + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image.resize((448 * 2, 448 * 2)) + vllm_result = mm_registry.map_input( + ctx.model_config, + {"image": image}, + ) + assert vllm_result["pixel_values"].size(1) == expected_num_patches + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None]) +@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) +def test_max_tokens_override( + get_max_internvl_image_tokens: Callable, + model: str, + max_dynamic_patch: Optional[int], + dynamic_image_size: Optional[bool], +): + """Ensure get_max_internvl_image_tokens handles mm_processor_kwargs.""" + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + if max_dynamic_patch is None: + max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch + expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 + if dynamic_image_size is False: + expected_num_patches = 1 + expected_max_tokens = 256 * expected_num_patches + + actual_max_tokens = get_max_internvl_image_tokens( + ctx=InputContext(ctx.model_config), + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + assert expected_max_tokens == actual_max_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None]) +@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) +def test_dummy_data_override( + dummy_data_for_internvl: Callable, + model: str, + num_imgs: int, + max_dynamic_patch: Optional[int], + dynamic_image_size: Optional[bool], +): + """Ensure dummy_data_for_internvl handles kwargs properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the dummy data func. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + if max_dynamic_patch is None: + max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch + expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 + if dynamic_image_size is False: + expected_num_patches = 1 + expected_max_tokens = 256 * expected_num_patches + + dummy_data = dummy_data_for_internvl( + ctx=ctx, + seq_len=8192, # Should be bigger than num_imgs * toks_per_img + mm_counts={"image": num_imgs}, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + sequence_data = dummy_data.seq_data + + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + image_token_id = tokenizer.encode('', + add_special_tokens=False)[0] + + # Ensure we have the right number of placeholders per size + img_tok_count = sequence_data.get_token_ids().count(image_token_id) + assert img_tok_count == expected_max_tokens * num_imgs + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("max_dynamic_patch", [1, 4]) +@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_input_processor_override( + input_processor_for_internvl: Callable, + image_assets: _ImageAssets, + model: str, + num_imgs: int, + max_dynamic_patch: int, + dynamic_image_size: Optional[bool], +): + """Ensure input_processor_for_internvl handles kwargs properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the custom input processor. + expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 + if dynamic_image_size is False: + expected_num_patches = 1 + + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + expected_toks_per_img = 256 * expected_num_patches + + # Build the image str / prompt based on the number of images we pass + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + placeholders = "" if num_imgs == 1 else "\n".join( + f"Image-{i}: \n" for i in range(1, num_imgs + 1)) + prompt = placeholders + images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs + + inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) + + processed_inputs = input_processor_for_internvl( + ctx, + inputs, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + + # Ensure we have the right number of placeholders per num_crops size + image_token_id = tokenizer.encode('', + add_special_tokens=False)[0] + img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 3f6d8ef42cd5f..dbb0b4d350d10 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -8,6 +8,7 @@ import pytest import transformers from transformers import AutoModelForVision2Seq +from transformers.utils import is_flash_attn_2_available from vllm.platforms import current_platform from vllm.utils import cuda_device_count_stateless, identity @@ -134,6 +135,35 @@ marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), #### Extended model tests + "aria": VLMTestInfo( + models=["rhymes-ai/Aria"], + tokenizer_mode="slow", + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + ), + dtype="bfloat16", + prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|img|>\n", + max_model_len=4096, + max_num_seqs=2, + single_image_prompts=IMAGE_ASSETS.prompts({ + "stop_sign": "Please describe the image shortly.", + "cherry_blossom": "Please infer the season with reason.", + }), + multi_image_prompt="Describe the two images shortly.", # noqa: E501 + postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"), + stop_str=["<|im_end|>"], + image_size_factors=[(0.10, 0.15)], + max_tokens=64, + marks=[ + pytest.mark.skipif( + not is_flash_attn_2_available(), + reason="Model needs flash-attn for numeric convergence.", + ), + large_gpu_mark(min_gb=64), + ], + ), "blip2": VLMTestInfo( models=["Salesforce/blip2-opt-2.7b"], test_type=VLMTestType.IMAGE, @@ -295,16 +325,29 @@ ) ], ), - "minicpmv": VLMTestInfo( + "minicpmv_25": VLMTestInfo( models=["openbmb/MiniCPM-Llama3-V-2_5"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + test_type=VLMTestType.IMAGE, prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 img_idx_to_prompt=lambda idx: "(./)\n", max_model_len=4096, max_num_seqs=2, get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id], postprocess_inputs=model_utils.wrap_inputs_post_processor, - hf_output_post_proc=model_utils.minicmpv_trunc_hf_output, + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, + ), + "minicpmv_26": VLMTestInfo( + models=["openbmb/MiniCPM-V-2_6"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "(./)\n", + max_model_len=4096, + max_num_seqs=2, + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + postprocess_inputs=model_utils.ignore_inputs_post_processor( + "image_sizes" + ), + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, ), # Tests for phi3v currently live in another file because of a bug in # transformers. Once this issue is fixed, we can enable them here instead. diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index 7e8c6dabb15af..88349ef9a3a69 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -29,6 +29,8 @@ def run_test( postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], comparator: Callable[..., None], get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]], + stop_str: Optional[List[str]], + tokenizer_mode: str, limit_mm_per_prompt: Dict[str, int], model_kwargs: Optional[Dict[str, Any]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], @@ -50,11 +52,14 @@ def run_test( # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - vllm_kwargs = {} + vllm_kwargs: Dict[str, Any] = {} if get_stop_token_ids is not None: vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer) + if stop_str: + vllm_kwargs["stop"] = stop_str with vllm_runner(model, + tokenizer_mode=tokenizer_mode, max_model_len=max_model_len, max_num_seqs=max_num_seqs, dtype=dtype, @@ -85,6 +90,8 @@ def run_test( hf_kwargs = {} if use_tokenizer_eos: hf_kwargs["eos_token_id"] = tokenizer.eos_token_id + if stop_str: + hf_kwargs["stop_strings"] = stop_str with hf_model, torch.no_grad(): for prompts, media in inputs: @@ -138,4 +145,4 @@ def process_runner_outputs( def process_outputs(output_processor, model, outputs_per_image): """Applies a model specific post-processor function to a runner's output""" return [[output_processor(res, model) for res in outputs] - for outputs in outputs_per_image] + for outputs in outputs_per_image] \ No newline at end of file diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 849857b4232e7..15f15dd7d8030 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -170,7 +170,7 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, ####### Post-processors for HF outputs -def minicmpv_trunc_hf_output(hf_output: RunnerOutput, +def minicpmv_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|eot_id|>"): @@ -197,6 +197,17 @@ def process(hf_inputs: BatchEncoding, dtype: str): return process +def ignore_inputs_post_processor( + hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: + """Gets a handle to a post processor which ignores a given key.""" + + def process(hf_inputs: BatchEncoding, dtype: str): + del hf_inputs[hf_inp_key] + return hf_inputs + + return process + + def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str): return {"model_inputs": hf_inputs} diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index 8459476dc2d07..d410fa8c653ce 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -97,6 +97,9 @@ class VLMTestInfo(NamedTuple): # Optional callable which gets a list of token IDs from the model tokenizer get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None + # Optional list of strings to stop generation, useful when stop tokens are + # not special tokens in the tokenizer + stop_str: Optional[List[str]] = None # Exposed options for HF runner model_kwargs: Optional[Dict[str, Any]] = None @@ -148,6 +151,8 @@ class VLMTestInfo(NamedTuple): marks: Optional[List[MarkDecorator]] = None + tokenizer_mode: str = "auto" + def get_non_parametrized_runner_kwargs(self): """Returns a dictionary of expandable kwargs for items that are used in all test types, which are NOT used when creating the parametrized @@ -166,8 +171,10 @@ def get_non_parametrized_runner_kwargs(self): "postprocess_inputs": self.postprocess_inputs, "comparator": self.comparator, "get_stop_token_ids": self.get_stop_token_ids, + "stop_str": self.stop_str, "model_kwargs": self.model_kwargs, "patch_hf_runner": self.patch_hf_runner, + "tokenizer_mode": self.tokenizer_mode } diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index c3f351ef707be..5ef8540265d14 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -4,6 +4,8 @@ """ import pytest +from vllm.config import PoolerConfig + from ..utils import check_embeddings_close @@ -21,6 +23,7 @@ marks=[pytest.mark.core_model]), pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), + pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"), ], ) @pytest.mark.parametrize("dtype", ["half"]) @@ -31,6 +34,13 @@ def test_models( model, dtype: str, ) -> None: + vllm_extra_kwargs = {} + if model == "ssmits/Qwen2-7B-Instruct-embed-base": + vllm_extra_kwargs["override_pooler_config"] = \ + PoolerConfig(pooling_type="MEAN") + if model == "Alibaba-NLP/gte-Qwen2-7B-instruct": + vllm_extra_kwargs["hf_overrides"] = {"is_causal": False} + # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" # sentence_transformers will strip the input texts, see: @@ -43,8 +53,11 @@ def test_models( is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, task="embedding", dtype=dtype, - max_model_len=None) as vllm_model: + with vllm_runner(model, + task="embedding", + dtype=dtype, + max_model_len=None, + **vllm_extra_kwargs) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) # This test is for verifying whether the model's extra_repr # can be printed correctly. diff --git a/tests/models/embedding/language/test_scoring.py b/tests/models/embedding/language/test_scoring.py new file mode 100644 index 0000000000000..30fa5ea7b36c0 --- /dev/null +++ b/tests/models/embedding/language/test_scoring.py @@ -0,0 +1,95 @@ +"""Compare the embedding outputs of HF and vLLM models. + +Run `pytest tests/models/embedding/language/test_embedding.py`. +""" +import math + +import pytest + +MODELS = [ + "cross-encoder/ms-marco-MiniLM-L-6-v2", # Bert + "BAAI/bge-reranker-v2-m3", # Roberta +] + +TEXTS_1 = [ + "What is the capital of France?", + "What is the capital of Germany?", +] + +TEXTS_2 = [ + "The capital of France is Paris.", + "The capital of Germany is Berlin.", +] + + +@pytest.fixture(scope="module", params=MODELS) +def model_name(request): + yield request.param + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): + + text_pair = [TEXTS_1[0], TEXTS_2[0]] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict([text_pair]).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) + + assert len(vllm_outputs) == 1 + assert len(hf_outputs) == 1 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): + + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[0], TEXTS_2[1]], + ] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict(text_pairs).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01) + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str): + + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[1], TEXTS_2[1]], + ] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict(text_pairs).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01) diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py index fd1c44d9c117e..f96c7d2b176db 100644 --- a/tests/models/embedding/utils.py +++ b/tests/models/embedding/utils.py @@ -24,7 +24,7 @@ def check_embeddings_close( dim=0) fail_msg = (f"Test{prompt_idx}:" - f"\n{name_0}:\t{embeddings_0!r}" - f"\n{name_1}:\t{embeddings_1!r}") + f"\n{name_0}:\t{embeddings_0[:16]!r}" + f"\n{name_1}:\t{embeddings_1[:16]!r}") assert sim >= 1 - tol, fail_msg diff --git a/tests/models/registry.py b/tests/models/registry.py index 3848367b6126c..461f453d8b1c3 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -43,6 +43,8 @@ class _HfExamplesInfo: trust_remote_code=True), "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", trust_remote_code=True), + "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria", + trust_remote_code=True), "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", @@ -61,6 +63,7 @@ class _HfExamplesInfo: "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), + "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), "GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"), @@ -91,6 +94,7 @@ class _HfExamplesInfo: "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), + "Olmo2ForCausalLM": _HfExamplesInfo("shanearora/OLMo-7B-1124-hf"), "OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), "OPTForCausalLM": _HfExamplesInfo("facebook/opt-iml-max-1.3b"), "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", @@ -112,6 +116,8 @@ class _HfExamplesInfo: "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"), + "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", + trust_remote_code=True), "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", is_available_online=False, trust_remote_code=True), @@ -135,6 +141,7 @@ class _HfExamplesInfo: "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 + "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"), # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), @@ -143,6 +150,13 @@ class _HfExamplesInfo: "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 } +_CROSS_ENCODER_EXAMPLE_MODELS = { + # [Text-only] + "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 + "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 +} + _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501 @@ -195,6 +209,7 @@ class _HfExamplesInfo: _EXAMPLE_MODELS = { **_TEXT_GENERATION_EXAMPLE_MODELS, **_EMBEDDING_EXAMPLE_MODELS, + **_CROSS_ENCODER_EXAMPLE_MODELS, **_MULTIMODAL_EXAMPLE_MODELS, **_SPECULATIVE_DECODING_EXAMPLE_MODELS, } diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index b8312c2d9b7cc..2a072737db043 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) def test_can_initialize(model_arch): - if (model_arch == "Idefics3ForConditionalGeneration" + if (model_arch in {"Idefics3ForConditionalGeneration", "GlmForCausalLM"} and transformers.__version__ < "4.46.0"): pytest.skip(reason="Model introduced in HF >= 4.46.0") diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index e462dae3dc688..b5368aab3ecf1 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -3,11 +3,11 @@ import pytest import torch.cuda -from vllm.model_executor.models import (is_embedding_model, +from vllm.model_executor.models import (is_pooling_model, is_text_generation_model, supports_multimodal) -from vllm.model_executor.models.registry import (_EMBEDDING_MODELS, - _MULTIMODAL_MODELS, +from vllm.model_executor.models.adapters import as_embedding_model +from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, _TEXT_GENERATION_MODELS, ModelRegistry) @@ -23,28 +23,34 @@ def test_registry_imports(model_arch): model_cls, _ = ModelRegistry.resolve_model_cls(model_arch) if model_arch in _SPECULATIVE_DECODING_MODELS: - pass # Ignore these models which do not have a unified format - else: - assert is_text_generation_model(model_cls) is ( - model_arch in _TEXT_GENERATION_MODELS - or model_arch in _MULTIMODAL_MODELS) + return # Ignore these models which do not have a unified format - assert is_embedding_model(model_cls) is (model_arch - in _EMBEDDING_MODELS) + if (model_arch in _TEXT_GENERATION_MODELS + or model_arch in _MULTIMODAL_MODELS): + assert is_text_generation_model(model_cls) - assert supports_multimodal(model_cls) is (model_arch - in _MULTIMODAL_MODELS) + # All vLLM models should be convertible to an embedding model + embed_model = as_embedding_model(model_cls) + assert is_pooling_model(embed_model) + + if model_arch in _MULTIMODAL_MODELS: + assert supports_multimodal(model_cls) @fork_new_process_for_each_test -@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [ - ("LlamaForCausalLM", False, False), - ("MllamaForConditionalGeneration", True, False), - ("LlavaForConditionalGeneration", True, True), +@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [ + ("LlamaForCausalLM", False, False, False), + ("MllamaForConditionalGeneration", True, False, False), + ("LlavaForConditionalGeneration", True, True, False), + ("BertForSequenceClassification", False, False, True), + ("RobertaForSequenceClassification", False, False, True), + ("XLMRobertaForSequenceClassification", False, False, True), ]) -def test_registry_is_multimodal(model_arch, is_mm, init_cuda): +def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): assert ModelRegistry.is_multimodal_model(model_arch) is is_mm + assert ModelRegistry.is_cross_encoder_model(model_arch) is is_ce + if init_cuda and current_platform.is_cuda_alike(): assert not torch.cuda.is_initialized() diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py new file mode 100644 index 0000000000000..b2367060c6c1b --- /dev/null +++ b/tests/multimodal/test_processing.py @@ -0,0 +1,370 @@ +from typing import cast + +import pytest +from transformers import BatchFeature + +from vllm.multimodal.processing import (PromptReplacement, find_text_matches, + find_token_matches, iter_token_matches, + iter_token_runs, replace_text_matches) +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import full_groupby + + +# yapf: disable +@pytest.mark.parametrize( + ("token_ids", "expected"), + [ + ([], []), + ( + [32000, 32000, 32000], + [{ "token_id": 32000, "start_idx": 0, "length": 3 }], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [ + { "token_id": 9833, "start_idx": 0, "length": 1 }, + { "token_id": 28747, "start_idx": 1, "length": 1 }, + { "token_id": 32000, "start_idx": 2, "length": 3 }, + { "token_id": 9833, "start_idx": 5, "length": 1 }, + { "token_id": 28747, "start_idx": 6, "length": 1 }, + { "token_id": 32000, "start_idx": 7, "length": 2 }, + { "token_id": 918, "start_idx": 9, "length": 1 }, + ], + ), + ], +) +# yapf: enable +def test_iter_token_runs(token_ids, expected): + result = list(iter_token_runs(token_ids)) + + # Only displayed on error + print("result:", result) + + # Manually constructed results + assert [item._asdict() for item in result] == expected + + # Invariants + assert sum(run_info.length for run_info in result) == len(token_ids) + + +# yapf: disable +@pytest.mark.parametrize( + ("token_ids", "match_ids", "expected"), + [ + ([], [], [{ "start_idx": 0, "end_idx": 0 }]), + ([], [32000], []), + ( + [32000, 32000, 32000], + [32000], + [ + { "start_idx": 0, "end_idx": 1 }, + { "start_idx": 1, "end_idx": 2 }, + { "start_idx": 2, "end_idx": 3 }, + ], + ), + ( + [32000, 32000, 32000], + [32000, 32000], + [{ "start_idx": 0, "end_idx": 2 }], + ), + ( + [32000, 32000, 32000], + [32000, 32000, 32000], + [{ "start_idx": 0, "end_idx": 3 }], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 32000], + [ + { "start_idx": 1, "end_idx": 3 }, + { "start_idx": 6, "end_idx": 8 }, + ], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 32000, 32000, 32000], + [ + { "start_idx": 1, "end_idx": 5 }, + ], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 0, 32000], + [], + ), + ], +) +# yapf: enable +def test_iter_token_matches(token_ids, match_ids, expected): + result = list(iter_token_matches(token_ids, match_ids)) + + # Manually constructed results + assert [item._asdict() for item in result] == expected + + # Invariants + match_lens = [end - start for start, end in result] + print("match_lens:", match_lens) # Only displayed on error + assert all(match_len == len(match_ids) for match_len in match_lens) + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "expected_by_key"), + [ + ( + [], + { + "pattern_1": [], + "pattern_2": [32000], + }, + { + "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], + "pattern_2": [], + } + ), + ( + [32000, 32000, 32000, 32000], + { + "pattern_1": [32000], + "pattern_2": [32000, 32000], + "pattern_3": [32000, 32000, 32000], + }, + { + "pattern_1": [ + { "start_idx": 0, "end_idx": 1 }, + { "start_idx": 1, "end_idx": 2 }, + { "start_idx": 2, "end_idx": 3 }, + { "start_idx": 3, "end_idx": 4 }, + ], + "pattern_2": [ + { "start_idx": 0, "end_idx": 2 }, + { "start_idx": 2, "end_idx": 4 }, + ], + "pattern_3": [ + { "start_idx": 0, "end_idx": 3 }, + ], + }, + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + { + "pattern_1": [28747, 32000], + "pattern_2": [28747, 32000, 32000, 32000], + "pattern_3": [28747, 0, 32000], + }, + { + "pattern_1": [ + { "start_idx": 1, "end_idx": 3 }, + { "start_idx": 6, "end_idx": 8 }, + ], + "pattern_2": [ + { "start_idx": 1, "end_idx": 5 }, + ], + "pattern_3": [], + }, + ), + ], +) +# yapf: enable +def test_find_token_matches(prompt, target_by_key, expected_by_key): + # Should not be used since there is nothing to convert to token IDs + mock_tokenizer = cast(AnyTokenizer, object()) + + result = find_token_matches( + prompt, + [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ], + ) + + # Only displayed on error + print("result:", result) + + # Manually constructed results + result_groups = dict(full_groupby(result, key=lambda x: x.modality)) + assert { + key: [ + dict(start_idx=item.start_idx, end_idx=item.end_idx) + for item in result_groups.get(key, []) + ] + for key in expected_by_key + } == expected_by_key + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "expected_by_key"), + [ + # Detokenized test cases of `test_find_token_matches` + # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf + ( + "", + { + "pattern_1": "", + "pattern_2": "", + }, + { + "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], + "pattern_2": [], + } + ), + ( + "", + { + "pattern_1": "", + "pattern_2": "", + "pattern_3": "", + }, + { + "pattern_1": [ + { "start_idx": 0, "end_idx": 7 }, + { "start_idx": 7, "end_idx": 14 }, + { "start_idx": 14, "end_idx": 21 }, + { "start_idx": 21, "end_idx": 28 }, + ], + "pattern_2": [ + { "start_idx": 0, "end_idx": 14 }, + { "start_idx": 14, "end_idx": 28 }, + ], + "pattern_3": [ + { "start_idx": 0, "end_idx": 21 }, + ], + }, + ), + ( + "Image:Image:!", + { + "pattern_1": "Image:", + "pattern_2": "Image:", + "pattern_3": "Image:", + }, + { + "pattern_1": [ + { "start_idx": 0, "end_idx": 13 }, + { "start_idx": 27, "end_idx": 40 }, + ], + "pattern_2": [ + { "start_idx": 0, "end_idx": 27 }, + ], + "pattern_3": [], + }, + ), + # Test regex escape + ( + "<|image|><|image|>", + { + "pattern_1": "<|image|>", + "pattern_2": "<|image|>", + "pattern_3": "<|image|><|image|>", + }, + { + "pattern_1": [ + { "start_idx": 0, "end_idx": 9 }, + { "start_idx": 16, "end_idx": 25 }, + ], + "pattern_2": [ + { "start_idx": 0, "end_idx": 16 }, + { "start_idx": 16, "end_idx": 32 }, + ], + "pattern_3": [ + { "start_idx": 0, "end_idx": 25 }, + ], + }, + ), + ], +) +# yapf: enable +def test_find_text_matches(prompt, target_by_key, expected_by_key): + # Should not be used since there is nothing to convert to text + mock_tokenizer = cast(AnyTokenizer, object()) + + result = find_text_matches( + prompt, + [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ], + ) + + # Only displayed on error + print("result:", result) + + # Manually constructed results + result_groups = dict(full_groupby(result, key=lambda x: x.modality)) + assert { + key: [ + dict(start_idx=item.start_idx, end_idx=item.end_idx) + for item in result_groups.get(key, []) + ] + for key in expected_by_key + } == expected_by_key + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"), + [ + ( + "Image:Image:!", + { + # We use `` before `Image:` to test matches that + # occur out of order + "pattern_1": "", + "pattern_2": "Image:", + "pattern_3": "!", + }, + { + # Test whether target is confused with repl_unit + "pattern_1": ("", 1), + # Test empty repl_unit + "pattern_2": ("", 1), + # Test multiple repl_count + "pattern_3": ("?", 2), + }, + { + # Test no replacement + 0: "Image:Image:!", + # Test single replacement + 1: "Image:??", + # Test repeated replacement + 2: "??", + }, + ), + ] +) +# yapf: enable +def test_find_replace_text( + prompt, + target_by_key, + repl_by_key, + expected_by_mm_count, +): + # Should not be used since there is nothing to convert to text + mock_tokenizer = cast(AnyTokenizer, object()) + + matches = find_text_matches( + prompt, + [ + PromptReplacement(target, *repl_by_key[key]) \ + .bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ], + ) + result_by_mm_count = { + mm_count: replace_text_matches( + prompt, + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), + ) + for mm_count in expected_by_mm_count + } + + # Only displayed on error + print("matches:", matches) + print("result_by_mm_count:", result_by_mm_count) + + # Manually constructed results + assert result_by_mm_count == expected_by_mm_count diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 9869c8123f001..fd82fb0c55fd7 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -139,7 +139,8 @@ def test_repeat_and_pad_placeholder_tokens(model): 2, "", [32000, 32000, 32000], - [{ "offset": 0, "length": 2 }]), + [{ "offset": 0, "length": 2 }], + ), ( "", [3, 2], diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 21958b1640204..d676eacffb056 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -1,13 +1,34 @@ -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Tuple, Union import torch +import torch.nn as nn from vllm.attention import AttentionMetadata -from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel -from vllm.sequence import IntermediateTensors +from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.models.gemma2 import Gemma2Model +from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput -class MyGemma2Embedding(Gemma2EmbeddingModel): +class MyGemma2Embedding(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.model = Gemma2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self._pooler = Pooler.from_config_with_defaults( + vllm_config.model_config.pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False, + ) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -18,7 +39,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = super().forward( + hidden_states = self.model( input_ids, positions, kv_caches, @@ -32,3 +53,17 @@ def forward( # Return all-zero embeddings return torch.zeros_like(hidden_states) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + return self.model.load_weights(weights) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 50723dbb610ac..8d16710f14585 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,10 +2,15 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. """ + import pytest +from tests.conftest import VllmRunner +from tests.core.utils import SchedulerProxy, create_dummy_prompt from tests.kernels.utils import override_backend_env_variable from vllm import SamplingParams, TokensPrompt +from vllm.core.scheduler import Scheduler +from vllm.engine.llm_engine import LLMEngine from ..models.utils import check_outputs_equal @@ -27,6 +32,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("cached_position", [0, 1]) +@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) @pytest.mark.parametrize("block_size", [16]) def test_mixed_requests( hf_runner, @@ -37,6 +43,7 @@ def test_mixed_requests( dtype: str, max_tokens: int, cached_position: int, + enable_chunked_prefill: bool, block_size: int, monkeypatch, ) -> None: @@ -55,6 +62,7 @@ def test_mixed_requests( model, dtype=dtype, enable_prefix_caching=True, + enable_chunked_prefill=enable_chunked_prefill, block_size=block_size, ) as vllm_model: # Run the first prompt so the cache is populated @@ -72,13 +80,13 @@ def test_mixed_requests( block_size) * block_size else: expected_num_cached_tokens = 0 - assert req_outputs[ - i].num_cached_tokens == expected_num_cached_tokens + assert ( + req_outputs[i].num_cached_tokens == expected_num_cached_tokens) - vllm_outputs = [ - (output.prompt_token_ids + list(output.outputs[0].token_ids), - output.prompt + output.outputs[0].text) for output in req_outputs - ] + vllm_outputs = [( + output.prompt_token_ids + list(output.outputs[0].token_ids), + output.prompt + output.outputs[0].text, + ) for output in req_outputs] check_outputs_equal( outputs_0_lst=hf_outputs, @@ -105,3 +113,89 @@ def test_unstable_prompt_sequence( for prompt in UNSTABLE_PROMPT_SEQUENCE: vllm_model.generate(TokensPrompt(prompt_token_ids=prompt), SamplingParams(max_tokens=1)) + + +@pytest.mark.parametrize("model", MODELS) +def test_fully_cached_prefill_needs_uncached_token(model): + block_size = 16 + max_num_batched_tokens = 16 + num_output_tokens = 5 + # Make a vllm engine + runner = VllmRunner( + model_name=model, + gpu_memory_utilization=0.7, + enable_chunked_prefill=True, + enforce_eager=True, + enable_prefix_caching=True, + block_size=block_size, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_batched_tokens, + ) + engine: LLMEngine = runner.model.llm_engine + + scheduler: Scheduler = SchedulerProxy(engine.scheduler[0]) # type: ignore + engine.scheduler[0] = scheduler + + # SeqA + seqA_tokens = list(range(2 * block_size)) + seqA, seq_groupA = create_dummy_prompt( + request_id="0", + prompt_tokens=seqA_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + scheduler.add_seq_group(seq_groupA) + + assert seqA.data.get_num_computed_tokens() == 0 + + # Prefill seqA + while not seqA.is_finished(): + engine.step() + + # seqB + seqB_tokens = [t + 1 for t in seqA_tokens] # shift by 1 + seqB, seq_groupB = create_dummy_prompt( + request_id="1", + prompt_tokens=seqB_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + # seqC is the same as seqA + seqC, seq_groupC = create_dummy_prompt( + request_id="2", + prompt_tokens=seqA_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + scheduler.add_seq_group(seq_groupB) + scheduler.add_seq_group(seq_groupC) + + # Even seqC is fully cached, it should not be prefilled since we + # require at least 1 uncached token. + engine.step() + + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupB.request_id) + assert (sched_out.scheduled_seq_groups[0].token_chunk_size == + max_num_batched_tokens) + + # When seqB is finished, seqC could be prefilled. + while not seqB.is_finished(): + engine.step() + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupB.request_id) + + engine.step() + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupC.request_id) + assert sched_out.scheduled_seq_groups[0].token_chunk_size == len( + seqA_tokens) diff --git a/tests/quantization/test_ipex_quant.py b/tests/quantization/test_ipex_quant.py index d541efcefcac3..68a73f0f8ab48 100644 --- a/tests/quantization/test_ipex_quant.py +++ b/tests/quantization/test_ipex_quant.py @@ -1,5 +1,5 @@ """Test model set-up and inference for quantized HF models supported - on the CPU backend using IPEX (including AWQ). + on the CPU/GPU backend using IPEX (including AWQ/GPTQ). Validating the configuration and printing results for manual checking. @@ -11,13 +11,15 @@ from vllm.platforms import current_platform MODELS = [ - "casperhansen/llama-3-8b-instruct-awq", + "AMead10/Llama-3.2-1B-Instruct-AWQ", + "shuyuej/Llama-3.2-1B-Instruct-GPTQ", # with g_idx ] DTYPE = ["bfloat16"] -@pytest.mark.skipif(not current_platform.is_cpu(), - reason="only supports the CPU backend.") +@pytest.mark.skipif(not current_platform.is_cpu() + and not current_platform.is_xpu(), + reason="only supports Intel CPU/XPU backend.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", DTYPE) def test_ipex_quant(vllm_runner, model, dtype): diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 061a077592e80..8ebd8dd2be0d5 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -1,4 +1,4 @@ -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization import get_quantization_config from vllm.platforms import current_platform @@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool: capability = current_platform.get_device_capability() assert capability is not None - min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability() + min_capability = get_quantization_config(quant_method).get_min_capability() return capability.to_int() >= min_capability diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index a3f0464e79675..af8397c235f48 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -50,49 +50,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): with pytest.raises(ValueError, match="cannot be larger than"): get_output_from_llm_generator(test_llm_generator, prompts, sampling_params) - - -@pytest.mark.parametrize("common_llm_kwargs", - [{ - "model": "meta-llama/Llama-2-7b-chat-hf", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "enable_chunked_prefill": "True", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "tensor_parallel_size": 2, - "speculative_draft_tensor_parallel_size": 2, - }, - { - "tensor_parallel_size": 4, - "speculative_draft_tensor_parallel_size": 4, - }, - { - "tensor_parallel_size": 8, - "speculative_draft_tensor_parallel_size": 8, - }, -]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one( - test_llm_generator): - """Verify that speculative decoding fails if chunked prefill is enabled for - draft model with tensor parallelism of more than 1. - """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - with pytest.raises(ValueError, match="with tensor parallel size 1"): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 25562ca85adf4..02cba92795142 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -115,3 +115,60 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, max_output_len=32, seed=seed, temperature=0.0) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [[ + # Skip cuda graph recording for fast test. + "--enforce-eager", + "--tensor_parallel_size", + "2", + + # precision + "--dtype", + "bfloat16", + ]]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [["--enable-chunked-prefill", "False"], + [ + "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4", + "--max-num-seqs", "4" + ]]) +@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) +@pytest.mark.parametrize("model, test_llm_kwargs", + [("JackFram/llama-68m", [ + "--speculative-model", + "JackFram/llama-68m", + "--num_speculative-tokens", + "3", + ]), + ("JackFram/llama-68m", [ + "--speculative-model", + "JackFram/llama-68m", + "--num_speculative-tokens", + "3", + "--speculative-draft-tensor-parallel-size", + "1", + ])]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, seed: int): + """Verify spec decode works well with same and different TP size for + the draft model with chunked prefill. + """ + run_equality_correctness_test_tp(model, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=32, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 5ecc0d4e95719..183ff2f5db274 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -203,7 +203,7 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, @pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}]) @pytest.mark.parametrize("output_len", [64]) @pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("temperature", [0.1, 1.0]) +@pytest.mark.parametrize("temperature", [1.0]) @pytest.mark.parametrize("seed", [1]) def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/spec_decode/test_batch_expansion.py b/tests/spec_decode/test_batch_expansion.py index 0d6aaa449d856..3504fcf43e361 100644 --- a/tests/spec_decode/test_batch_expansion.py +++ b/tests/spec_decode/test_batch_expansion.py @@ -90,6 +90,14 @@ def test_create_single_target_seq_group_metadata(k: int): ) assert output.request_id == input_seq_group_metadata.request_id + assert output.sampling_params.repetition_penalty == \ + input_seq_group_metadata.sampling_params.repetition_penalty + assert output.sampling_params.temperature == \ + input_seq_group_metadata.sampling_params.temperature + assert output.sampling_params.top_p == \ + input_seq_group_metadata.sampling_params.top_p + assert output.sampling_params.top_k == \ + input_seq_group_metadata.sampling_params.top_k assert len(output.seq_data) == 1 assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple( prompt_tokens) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 8df143104c279..caf7a7e625b46 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str): target_worker.init_device.assert_called_once() - metrics_collector.init_gpu_tensors.assert_called_once() - spec_decode_sampler.init_gpu_tensors.assert_called_once() + metrics_collector.init_tensors.assert_called_once() + spec_decode_sampler.init_tensors.assert_called_once() @pytest.mark.parametrize("acceptance_sampler_method", @@ -867,7 +867,8 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str): target_group_metadata_list = prefill + decodes execute_model_req = ExecuteModelRequest( seq_group_metadata_list=target_group_metadata_list, - num_lookahead_slots=k) + # For prefill only batches we expect num_lookahead_slots = 0. + num_lookahead_slots=k if n_decodes > 0 else 0) target_token_ids = torch.randint(low=0, high=vocab_size, diff --git a/tests/test_config.py b/tests/test_config.py index 3cf90297ce177..45b0b938af215 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task): @pytest.mark.parametrize(("model_id", "bad_task"), [ - ("facebook/opt-125m", "embedding"), - ("intfloat/e5-mistral-7b-instruct", "generate"), + ("Qwen/Qwen2.5-Math-RM-72B", "generate"), ]) def test_incorrect_task(model_id, bad_task): with pytest.raises(ValueError, match=r"does not support the .* task"): diff --git a/tests/test_lazy_torch_compile.py b/tests/test_lazy_torch_compile.py new file mode 100644 index 0000000000000..b950877a4337b --- /dev/null +++ b/tests/test_lazy_torch_compile.py @@ -0,0 +1,28 @@ +# Description: Test the lazy import module +# The utility function cannot be placed in `vllm.utils` +# this needs to be a standalone script +import sys +from contextlib import nullcontext + +from vllm_test_utils import BlameResult, blame + +module_name = "torch._inductor.async_compile" + +# In CI, we only check finally if the module is imported. +# If it is indeed imported, we can rerun the test with `use_blame=True`, +# which will trace every function call to find the first import location, +# and help find the root cause. +# We don't run it in CI by default because it is slow. +use_blame = False +context = blame( + lambda: module_name in sys.modules) if use_blame else nullcontext() +with context as result: + import vllm # noqa + +if use_blame: + assert isinstance(result, BlameResult) + print(f"the first import location is:\n{result.trace_stack}") + +assert module_name not in sys.modules, ( + f"Module {module_name} is imported. To see the first" + f" import location, run the test with `use_blame=True`.") diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 941abe17a3378..b7124ebc1b0f3 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -1,24 +1,46 @@ import glob import os -import runpy import tempfile import depyf from vllm.config import CompilationLevel -# disable custom dispatcher, let Dynamo takes over -# all the control -os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS) - temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): - cur_dir = os.path.dirname(__file__) - parent_dir = os.path.dirname(cur_dir) - root_dir = os.path.dirname(parent_dir) - example_file = os.path.join(root_dir, "examples", - "offline_inference_tpu.py") - runpy.run_path(example_file) + from vllm import LLM, SamplingParams + + prompts = [ + "A robot may not injure a human being", + "It is only with the heart that one can see rightly;", + "The greatest glory in living lies not in never falling,", + ] + answers = [ + " or, through inaction, allow a human being to come to harm.", + " what is essential is invisible to the eye.", + " but in rising every time we fall.", + ] + N = 1 + # Currently, top-p sampling is disabled. `top_p` should be 1.0. + sampling_params = SamplingParams(temperature=0.7, + top_p=1.0, + n=N, + max_tokens=16) + + # Set `enforce_eager=True` to avoid ahead-of-time compilation. + # In real workloads, `enforace_eager` should be `False`. + + # disable custom dispatcher, let Dynamo takes over + # all the control + llm = LLM(model="google/gemma-2b", + enforce_eager=True, + compilation_config={"level": CompilationLevel.DYNAMO_AS_IS}) + outputs = llm.generate(prompts, sampling_params) + for output, answer in zip(outputs, answers): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert generated_text.startswith(answer) compiled_code = sorted( glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 53b10c06135a1..bb1379deba3fc 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -13,7 +13,10 @@ def test_custom_dispatcher(): compare_two_settings( "google/gemma-2b", - arg1=["--enforce-eager"], - arg2=["--enforce-eager"], - env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)}, - env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)}) + arg1=[ + "--enforce-eager", + f"-O{CompilationLevel.DYNAMO_ONCE}", + ], + arg2=["--enforce-eager", f"-O{CompilationLevel.DYNAMO_AS_IS}"], + env1={}, + env2={}) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d614d3e67460f..b44d3e5cb0678 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,8 +1,11 @@ """Compare the with and without prefix caching.""" +import pytest + from vllm.inputs import token_inputs from vllm.sampling_params import SamplingParams +from vllm.utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import hash_block_tokens +from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens def make_request(request_id, prompt_token_ids): @@ -20,7 +23,8 @@ def test_prefill(): manager = KVCacheManager( block_size=16, num_gpu_blocks=10, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -31,7 +35,8 @@ def test_prefill(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids) + all_token_ids = common_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids) computed_blocks = manager.get_computed_blocks(req0) assert not computed_blocks blocks = manager.allocate_slots(req0, 55, computed_blocks) @@ -40,24 +45,16 @@ def test_prefill(): # Check full block metadata parent_block_hash = None for block_id in (0, 1, 2): - block_hash = hash_block_tokens(parent_block_hash, - manager.block_pool[block_id].token_ids) + block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) + block_hash = hash_block_tokens(parent_block_hash, block_tokens) assert manager.block_pool[block_id].block_hash == block_hash assert manager.block_pool[block_id].ref_cnt == 1 - assert manager.block_pool[block_id].num_hashed_tokens == 16 * ( - block_id + 1) - assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16) parent_block_hash = block_hash # Check partial/preallocated block metadata for block_id in (3, 4): assert manager.block_pool[block_id].block_hash is None assert manager.block_pool[block_id].ref_cnt == 1 - assert manager.block_pool[block_id].num_hashed_tokens == 0 - if block_id == 3: - assert manager.block_pool[block_id].token_ids == [3] * 7 - else: - assert not manager.block_pool[block_id].token_ids # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) @@ -113,7 +110,7 @@ def test_prefill(): req3 = make_request("3", [99] * (16 * 9)) computed_blocks = manager.get_computed_blocks(req3) assert not computed_blocks - blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks) + blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) # This block ID order also checks the eviction order. assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] assert manager.free_block_queue.num_free_blocks == 0 @@ -125,7 +122,8 @@ def test_decode(): manager = KVCacheManager( block_size=16, num_gpu_blocks=10, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -148,7 +146,7 @@ def test_decode(): req0.append_output_token_ids(8) new_blocks = manager.append_slots(req0, 4) assert new_blocks is not None and len(new_blocks) == 0 - assert len(manager.block_pool[3].token_ids) == 11 + assert manager.req_to_blocks[req0.request_id][-2].block_hash is None # Append slots without allocating a new block, but start using the # preallocated block. @@ -159,8 +157,7 @@ def test_decode(): req0.append_output_token_ids(7) new_blocks = manager.append_slots(req0, 15) assert new_blocks is not None and len(new_blocks) == 0 - assert len(manager.block_pool[3].token_ids) == 16 - assert len(manager.block_pool[4].token_ids) == 10 + assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None # Append slots with allocating a new block. req0.num_computed_tokens = 74 @@ -171,16 +168,14 @@ def test_decode(): new_blocks = manager.append_slots(req0, 17) # Plus one preallocated block. assert new_blocks is not None and len(new_blocks) == 2 - assert len(manager.block_pool[4].token_ids) == 16 - assert len(manager.block_pool[5].token_ids) == 11 - assert len(manager.block_pool[6].token_ids) == 0 def test_evict(): manager = KVCacheManager( block_size=16, num_gpu_blocks=10, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -217,3 +212,203 @@ def test_evict(): blocks = manager.allocate_slots(req2, 3, computed_blocks) assert [b.block_id for b in blocks] == [6, 5] assert manager.free_block_queue.num_free_blocks == 6 + + +def test_hash_block_correct_reuse(): + """ + This tests when a previously cached block is reused as a new block, + its hash metadata should be correctly reset. + """ + block_size = 16 + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=1, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=0, + ) + + # Allocate 1 block and cache it. + num_tokens = block_size * 1 + req = make_request("0", list(range(num_tokens))) + computed_blocks = manager.get_computed_blocks(req) + assert not computed_blocks + blocks = manager.allocate_slots(req, num_tokens, computed_blocks) + assert len(blocks) == 1 + + # Deallocate the block. + manager.free(req) + + # Allocate a new block that's not full, make sure hash info on the + # block is cleared. + req = make_request("1", list(range(num_tokens - 1))) + computed_blocks = manager.get_computed_blocks(req) + assert not computed_blocks + blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) + assert len(blocks) == 1 + + assert manager.block_pool[blocks[0].block_id].block_hash is None + + +def test_computed_blocks_not_evicted(): + """ + Test that the computed blocks are not evicted when getting new blocks + for a request if there are any other free blocks. + """ + block_size = 16 + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=2, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=0, + ) + + # Allocate a block and cache it. + num_tokens = block_size * 1 + req0 = make_request("0", list(range(num_tokens))) + computed_blocks = manager.get_computed_blocks(req0) + assert not computed_blocks + blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) + assert len(blocks) == 1 + assert blocks[0].block_id == 0 + + # Allocate another block. + req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) + computed_blocks = manager.get_computed_blocks(req1) + assert not computed_blocks + blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) + assert len(blocks) == 1 + assert blocks[0].block_id == 1 + + # Free the blocks. + manager.free(req0) + manager.free(req1) + + # Now if we have a cache hit on the first block, we should evict the second + # cached block rather than the first one. + req2 = make_request("2", list(range(num_tokens * 2))) + computed_blocks = manager.get_computed_blocks(req2) + assert len(computed_blocks) == 1 + assert computed_blocks[0].block_id == 0 + + blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, + computed_blocks) + assert len(blocks) == 1 + assert blocks[0].block_id == 1 + + +def test_basic_prefix_caching_disabled(): + """ + This tests that the prefix caching is disabled. + """ + block_size = 4 + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=4, + max_model_len=8192, + sliding_window=None, + enable_caching=False, + num_preallocate_tokens=0, + ) + + req1 = make_request("1", list(range(10))) # 2 blocks and some more + + computed_blocks = manager.get_computed_blocks(req1) + assert not computed_blocks + blocks = manager.allocate_slots(req1, 10, computed_blocks) + assert len(blocks) == 3 + + # Free the blocks. + manager.free(req1) + + # No caching. + req2 = make_request("2", list(range(16))) # shared prefix + computed_blocks = manager.get_computed_blocks(req2) + assert not computed_blocks + blocks = manager.allocate_slots(req2, 16, computed_blocks) + assert len(blocks) == 4 + + # New requests should not have any blocks. + req3 = make_request("3", list(range(4))) + computed_blocks = manager.get_computed_blocks(req3) + assert not computed_blocks + blocks = manager.allocate_slots(req3, 4, computed_blocks) + assert not blocks + + +@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8))) +@pytest.mark.parametrize("block_size", [4]) +def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): + """ + This tests that the preallocated blocks are correctly added. + """ + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=10, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=num_preallocate_tokens, + ) + num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size) + + req = make_request("0", list(range(block_size * 30))) + computed_blocks = manager.get_computed_blocks(req) + assert not computed_blocks + # Just ask for 1 block. + blocks = manager.allocate_slots(req, block_size, computed_blocks) + assert len(blocks) == 1 + num_preallocated_blocks + + # Append slots to the block. + req.num_computed_tokens = block_size * len(blocks) # Assume all used. + blocks = manager.append_slots(req, block_size) # Append 1 block. + assert len(blocks) == 1 + num_preallocated_blocks + + +def test_cache_blocks(): + """ + This is a unit test that tests the correctness of the _cache_full_blocks + function of KVCacheManager. + """ + block_size = 4 + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=5, + max_model_len=8192, + sliding_window=None, + enable_caching=True, + num_preallocate_tokens=0, + ) + # Req: + # Block 0: [0, 1, 2, 3] + # Block 1: [4, 5, 6, 7] + # Block 2: [8, 9, 10, 11] + # Block 3: [12, 13] + req = make_request("0", list(range(14))) + + # Test that blocks are cached correctly for 2 full blocks from the start. + blocks = [KVCacheBlock(block_id=i) for i in range(2)] + + manager._cache_full_blocks( + request=req, + blk_start_idx=0, + full_blocks=blocks, + prev_block=None, + ) + + assert len(manager.cached_block_hash_to_block) == 2 + assert all([block.block_hash is not None for block in blocks]) + + # Test that blocks that don't start from the beginning are cached correctly. + blocks = [KVCacheBlock(block_id=2)] + manager._cache_full_blocks( + request=req, + blk_start_idx=2, + full_blocks=blocks, + prev_block=None, + ) + assert len(manager.cached_block_hash_to_block) == 3 + assert blocks[0].block_hash is not None diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 1f26fe0fc892f..fffb5b8100ec7 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -32,6 +32,9 @@ async def generate(engine: AsyncLLM, request_id: str, @pytest.mark.asyncio async def test_load(monkeypatch): + # TODO(rickyx): Remove monkeypatch once we have a better way to test V1 + # so that in the future when we switch, we don't have to change all the + # tests. with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py new file mode 100644 index 0000000000000..ac5e7dde525a7 --- /dev/null +++ b/tests/v1/engine/test_engine_args.py @@ -0,0 +1,61 @@ +import pytest + +from vllm import envs +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser + +if not envs.VLLM_USE_V1: + pytest.skip( + "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.", + allow_module_level=True, + ) + + +def test_prefix_caching_from_cli(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args = parser.parse_args([]) + engine_args = EngineArgs.from_cli_args(args=args) + assert (engine_args.enable_prefix_caching + ), "V1 turns on prefix caching by default." + + # Turn it off possible with flag. + args = parser.parse_args(["--no-enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert not engine_args.enable_prefix_caching + + # Turn it on with flag. + args = parser.parse_args(["--enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert engine_args.enable_prefix_caching + + +def test_defaults(): + engine_args = EngineArgs(model="facebook/opt-125m") + + # Assert V1 defaults + assert (engine_args.enable_prefix_caching + ), "V1 turns on prefix caching by default" + + +def test_defaults_with_usage_context(): + engine_args = EngineArgs(model="facebook/opt-125m") + vllm_config: VllmConfig = engine_args.create_engine_config( + UsageContext.LLM_CLASS) + + assert vllm_config.scheduler_config.max_num_seqs == 1024 + assert vllm_config.scheduler_config.max_num_batched_tokens == 8192 + + engine_args = EngineArgs(model="facebook/opt-125m") + vllm_config = engine_args.create_engine_config( + UsageContext.OPENAI_API_SERVER) + assert vllm_config.scheduler_config.max_num_seqs == 1024 + assert vllm_config.scheduler_config.max_num_batched_tokens == 2048 + + +def test_prefix_cache_disabled_with_multimodel(): + engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf") + + vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) + assert not vllm_config.cache_config.enable_prefix_caching diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index b3692b594326a..bd11ff1877064 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -43,7 +43,8 @@ def test_engine_core(monkeypatch): m.setenv("VLLM_USE_V1", "1") """Setup the EngineCore.""" engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config() + vllm_config = engine_args.create_engine_config( + usage_context=UsageContext.UNKNOWN_CONTEXT) executor_class = AsyncLLM._get_executor_cls(vllm_config) engine_core = EngineCore(vllm_config=vllm_config, diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 7b241bf836a0e..582192196aaf9 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -81,8 +81,9 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config() + engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3) + vllm_config = engine_args.create_engine_config( + UsageContext.UNKNOWN_CONTEXT) executor_class = AsyncLLM._get_executor_cls(vllm_config) client = EngineCoreClient.make_client( vllm_config, @@ -153,7 +154,8 @@ async def test_engine_core_client_asyncio(monkeypatch): m.setenv("VLLM_USE_V1", "1") engine_args = EngineArgs(model=MODEL_NAME) - vllm_config = engine_args.create_engine_config() + vllm_config = engine_args.create_engine_config( + usage_context=UsageContext.UNKNOWN_CONTEXT) executor_class = AsyncLLM._get_executor_cls(vllm_config) client = EngineCoreClient.make_client( vllm_config, diff --git a/tests/vllm_test_utils/setup.py b/tests/vllm_test_utils/setup.py new file mode 100644 index 0000000000000..790e891ec837d --- /dev/null +++ b/tests/vllm_test_utils/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup + +setup( + name='vllm_test_utils', + version='0.1', + packages=['vllm_test_utils'], +) diff --git a/tests/vllm_test_utils/vllm_test_utils/__init__.py b/tests/vllm_test_utils/vllm_test_utils/__init__.py new file mode 100644 index 0000000000000..bf0b62a5b75e3 --- /dev/null +++ b/tests/vllm_test_utils/vllm_test_utils/__init__.py @@ -0,0 +1,8 @@ +""" +vllm_utils is a package for vLLM testing utilities. +It does not import any vLLM modules. +""" + +from .blame import BlameResult, blame + +__all__ = ["blame", "BlameResult"] diff --git a/tests/vllm_test_utils/vllm_test_utils/blame.py b/tests/vllm_test_utils/vllm_test_utils/blame.py new file mode 100644 index 0000000000000..1ddd3471d357b --- /dev/null +++ b/tests/vllm_test_utils/vllm_test_utils/blame.py @@ -0,0 +1,53 @@ +import contextlib +import dataclasses +import sys +import traceback +from typing import Callable, Generator + + +@dataclasses.dataclass +class BlameResult: + found: bool = False + trace_stack: str = "" + + +@contextlib.contextmanager +def blame(func: Callable) -> Generator[BlameResult, None, None]: + """ + Trace the function calls to find the first function that satisfies the + condition. The trace stack will be stored in the result. + + Usage: + + ```python + with blame(lambda: some_condition()) as result: + # do something + + if result.found: + print(result.trace_stack) + """ + result = BlameResult() + + def _trace_calls(frame, event, arg=None): + nonlocal result + if event in ['call', 'return']: + # for every function call or return + try: + # Temporarily disable the trace function + sys.settrace(None) + # check condition here + if not result.found and func(): + result.found = True + result.trace_stack = "".join(traceback.format_stack()) + # Re-enable the trace function + sys.settrace(_trace_calls) + except NameError: + # modules are deleted during shutdown + pass + return _trace_calls + + try: + sys.settrace(_trace_calls) + yield result + finally: + sys.settrace(None) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index a4ee9538d646b..2afffb5b9d1c8 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -27,4 +27,5 @@ fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main marlin, nm-testing/zephyr-beta-7b-marlin-g128, main marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main qqq, HandH1998/QQQ-Llama-3-8b-g128, main -qqq, HandH1998/QQQ-Llama-3-8b, main \ No newline at end of file +qqq, HandH1998/QQQ-Llama-3-8b, main +hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main \ No newline at end of file diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index b36e8bfe73ff3..309854e6babf3 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -8,10 +8,10 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.model_executor import SamplingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.worker.embedding_model_runner import ( - ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.multi_step_model_runner import StatefulModelInput +from vllm.worker.pooling_model_runner import ( + ModelInputForGPUWithPoolingMetadata) class MockAttentionBackend(AttentionBackend): diff --git a/tools/png-lint.sh b/tools/png-lint.sh new file mode 100755 index 0000000000000..a80fe9837342f --- /dev/null +++ b/tools/png-lint.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Ensure that *.excalidraw.png files have the excalidraw metadata +# embedded in them. This ensures they can be loaded back into +# the tool and edited in the future. + +find . -iname '*.excalidraw.png' | while read -r file; do + if git check-ignore -q "$file"; then + continue + fi + if ! grep -q "excalidraw+json" "$file"; then + echo "$file was not exported from excalidraw with 'Embed Scene' enabled." + exit 1 + fi +done diff --git a/tools/sphinx-lint.sh b/tools/sphinx-lint.sh new file mode 100755 index 0000000000000..04f8075c5527f --- /dev/null +++ b/tools/sphinx-lint.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +sphinx-lint --disable trailing-whitespace,missing-final-newline docs diff --git a/vllm/__init__.py b/vllm/__init__.py index 3c8ad70848fa1..493d5c5649733 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -8,8 +8,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry -from vllm.outputs import (CompletionOutput, EmbeddingOutput, - EmbeddingRequestOutput, RequestOutput) +from vllm.outputs import (CompletionOutput, PoolingOutput, + PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -27,8 +27,8 @@ "SamplingParams", "RequestOutput", "CompletionOutput", - "EmbeddingOutput", - "EmbeddingRequestOutput", + "PoolingOutput", + "PoolingRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", @@ -36,3 +36,26 @@ "initialize_ray_cluster", "PoolingParams", ] + + +def __getattr__(name: str): + import warnings + + if name == "EmbeddingOutput": + msg = ("EmbeddingOutput has been renamed to PoolingOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingOutput + + if name == "EmbeddingRequestOutput": + msg = ("EmbeddingRequestOutput has been renamed to " + "PoolingRequestOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingRequestOutput + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5751e5caba1a7..6bb660e7ca350 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -20,9 +20,6 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) -if current_platform.is_rocm(): - import vllm._rocm_C # noqa: F401 - supports_moe_ops = False with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 @@ -369,34 +366,10 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, size_k: torch.SymInt, is_k_full: bool, has_zp: bool = False, - use_fp32_reduce: bool = False) -> torch.Tensor: + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::ggml_dequantize") - def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, - m: torch.SymInt, - n: torch.SymInt) -> torch.Tensor: - return torch.empty((m, n), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_vec_a8") - def _ggml_mul_mat_vec_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((1, row), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_a8") - def _ggml_mul_mat_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - batch = X.size(0) - return torch.empty((batch, row), dtype=torch.float16, device=W.device) - @register_fake("_C::marlin_qqq_gemm") def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor, @@ -470,18 +443,18 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake("_C::machete_gemm") - def machete_gemm_fake( + @register_fake("_C::machete_mm") + def machete_mm_fake( a: torch.Tensor, - # Should be the tensor returned by machete_prepack_B + # b_q Should be the tensor returned by machete_prepack_B b_q: torch.Tensor, b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, schedule: Optional[str] = None, ) -> torch.Tensor: m = a.size(0) @@ -489,12 +462,41 @@ def machete_gemm_fake( return torch.empty((m, n), device=a.device, dtype=a.dtype) @register_fake("_C::machete_prepack_B") - def machete_prepack_B_fake(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: + def machete_prepack_B_fake( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) +if hasattr(torch.ops._C, "ggml_dequantize"): + + @register_fake("_C::ggml_dequantize") + def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, + m: torch.SymInt, + n: torch.SymInt) -> torch.Tensor: + return torch.empty((m, n), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_vec_a8") + def _ggml_mul_mat_vec_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + return torch.empty((1, row), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_a8") + def _ggml_mul_mat_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + batch = X.size(0) + return torch.empty((batch, row), dtype=torch.float16, device=W.device) + + # cutlass def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) @@ -626,11 +628,12 @@ def gptq_marlin_gemm(a: torch.Tensor, size_k: int, is_k_full: bool, has_zp: bool = False, - use_fp32_reduce: bool = False) -> torch.Tensor: + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace, b_q_type.id, size_m, size_n, size_k, is_k_full, - has_zp, use_fp32_reduce) + has_zp, use_fp32_reduce, is_zp_float) # fp8 marlin @@ -643,29 +646,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # machete -def machete_supported_schedules(b_type: ScalarType) -> List[str]: - return torch.ops._C.machete_supported_schedules(b_type.id) - - -def machete_gemm( - a: torch.Tensor, - b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B - b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, - schedule: Optional[str] = None, -) -> torch.Tensor: - return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros, - b_group_size, c, alpha, beta, schedule) +def machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None) -> List[str]: + return torch.ops._C.machete_supported_schedules( + a_type, b_type.id, group_scales_type, group_zeros_type, + channel_scales_type, token_scales_type, out_type) -def machete_prepack_B(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: - return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id) +def machete_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales, + b_group_zeros, b_group_size, + b_channel_scales, a_token_scales, schedule) + + +def machete_prepack_B( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id, + group_scales_type) if hasattr(torch.ops._C, "permute_cols"): diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a5fcf5ce3d4d0..11e9380e11cee 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass, fields -from enum import Enum, auto from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) @@ -15,13 +14,19 @@ ModelRunnerInputBuilderBase) -class AttentionType(Enum): - DECODER = auto() # Decoder attention between previous layer Q/K/V - ENCODER = auto( - ) # Encoder attention between previous layer Q/K/V for encoder-decoder - ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V - ENCODER_DECODER = auto( - ) # Attention between dec. Q and enc. K/V for encoder-decoder +class AttentionType: + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER = "encoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + # Attention between dec. Q and enc. K/V for encoder-decoder + ENCODER_DECODER = "encoder_decoder" class AttentionBackend(ABC): @@ -241,7 +246,8 @@ def forward( attn_metadata: T, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index b977cc9033879..49a9d70ca2ccc 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -87,6 +87,11 @@ def __post_init__(self): class BlocksparseFlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + # For attention layer compatibility + return "FLASH_ATTN" + @staticmethod def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: return BlocksparseFlashAttentionImpl @@ -354,7 +359,8 @@ def forward( attn_metadata: BlocksparseFlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -444,5 +450,6 @@ def forward( blocksparse_head_sliding_step=self.head_sliding_step, ) + assert output is not None # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 041318fd419f1..9c89088795207 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -16,10 +16,8 @@ compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) -from vllm.forward_context import get_forward_context from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import (async_tensor_h2d, direct_register_custom_op, - make_tensor_with_pad) +from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -639,26 +637,29 @@ def forward( attn_metadata: FlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] + NOTE: It in-place updates the output tensor. """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + assert output is not None, "Output tensor must be provided." + if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " @@ -669,31 +670,162 @@ def forward( "requires setting cross-attention " "metadata attributes.") - output = torch.ops.vllm.unified_flash_attention( - query, - key, - value, - self.num_heads, - self.head_size, - self.num_kv_heads, - kv_cache, - self.kv_cache_dtype, - k_scale, - v_scale, - self.scale, - attn_type.value, - self.sliding_window, - self.alibi_slopes, - self.logits_soft_cap, - ) - + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes + logits_soft_cap: Optional[float] = self.logits_soft_cap + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the + # KV cache is already populated with the cross-attention + # tensor. Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + k_scale, + v_scale, + ) + + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] + decode_output = output[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + prefill_output = output[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=prefill_output, + ) + else: + # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + out=prefill_output, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. + + assert decode_meta.max_decode_query_len is not None + # use only for actual varlen decoding + if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1" + ) + flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + cu_seqlens_k=decode_meta.seq_start_loc, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + out=decode_output, + ) + else: + # Use flash_attn_with_kvcache for normal decoding. + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=decode_output.unsqueeze(1), + ) return output def _get_query_key_seq_metadata( attn_metadata, is_prompt: bool, - attn_type: AttentionType, + attn_type: str, ) -> tuple: """ Returns sequence metadata for key and query based on the specified @@ -755,7 +887,7 @@ def _get_query_key_seq_metadata( raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_causal_option(attn_type: AttentionType) -> bool: +def _get_causal_option(attn_type: str) -> bool: """ Determine whether the given attention type is suitable for causal attention mechanisms. @@ -771,220 +903,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool: return not (attn_type == AttentionType.ENCODER or attn_type == AttentionType.ENCODER_ONLY or attn_type == AttentionType.ENCODER_DECODER) - - -def unified_flash_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - attn_type_int_val: int, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - - # Convert integer attn_type to enum - try: - attn_type = AttentionType(attn_type_int_val) - except ValueError as err: - raise AttributeError( - f"Invalid attention type {str(attn_type_int_val)}") from err - - current_metadata = get_forward_context() - assert current_metadata is not None - assert isinstance(current_metadata, FlashAttentionMetadata) - attn_metadata: FlashAttentionMetadata = current_metadata - - num_tokens, hidden_size = query.shape - - # Reshape the query, key, and value tensors. - query = query.view(-1, num_heads, head_size) - if (key is not None) and (value is not None): - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - # We skip updating the KV cache under two conditions: - # a. When the Attention Type is ENCODER. In this phase, we compute - # only the encoder attention without updating the cache. - # b. When both Key and Value are None. This occurs during - # cross-attention computation in the decoding phase, where the KV - # cache is already populated with the cross-attention tensor. - # Thus, we skip cache updates during this time. - if (attn_type != AttentionType.ENCODER) and (key is not None) and ( - value is not None): - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), # type: ignore[union-attr] - kv_cache_dtype, - k_scale, - v_scale, - ) - - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - decode_query = query[num_prefill_query_tokens:] - # QKV for prefill. - query = query[:num_prefill_query_tokens] - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ - _get_query_key_seq_metadata(prefill_meta, True, attn_type) - - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=q_seq_start_loc, - cu_seqlens_k=k_seq_start_loc, - max_seqlen_q=q_seq_len, - max_seqlen_k=k_seq_len, - softmax_scale=softmax_scale, - causal=_get_causal_option(attn_type), - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - ) - else: - # prefix-enabled attention - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support prefix caching") - assert prefill_meta.seq_lens is not None - max_seq_len = max(prefill_meta.seq_lens) - prefill_output = flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - # Use flash_attn_varlen_func kernel for speculative decoding - # because different queries might have different lengths. - - assert decode_meta.max_decode_query_len is not None - # use only for actual varlen decoding - if decode_meta.max_decode_query_len > 1: - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support max_decode_query_len > 1") - decode_output = flash_attn_varlen_func( - q=decode_query, - k=key_cache, - v=value_cache, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_q=decode_meta.max_decode_query_len, - cu_seqlens_k=decode_meta.seq_start_loc, - max_seqlen_k=decode_meta.max_decode_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - block_table=decode_meta.block_tables, - ) - else: - # Use flash_attn_with_kvcache for normal decoding. - ( - seq_lens_arg, - _, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - decode_output = flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - block_table=block_tables_arg, - cache_seqlens=seq_lens_arg, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - ).squeeze(1) - - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_query_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_query_tokens, hidden_size) - - assert decode_meta is not None - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) - - -def unified_flash_attention_fake( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - attn_type_int_val: int, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - return torch.empty_like(query) - - -direct_register_custom_op( - op_name="unified_flash_attention", - op_func=unified_flash_attention, - mutates_args=["kv_cache"], - fake_impl=unified_flash_attention_fake, -) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 9694d2b53f215..91b2e6dfc070b 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -30,9 +30,8 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.forward_context import get_forward_context -from vllm.utils import (async_tensor_h2d, direct_register_custom_op, - get_kv_cache_torch_dtype, make_tensor_with_pad) +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -757,9 +756,8 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - if sliding_window is not None: - raise ValueError("Sliding window is not supported in FlashInfer.") - self.sliding_window = (-1, -1) + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap @@ -775,179 +773,130 @@ def forward( attn_metadata: FlashInferMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + + # TODO: directly write to output tensor + if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " "FlashInferImpl") - return torch.ops.vllm.unified_flash_infer( - query, - key, - value, - self.num_heads, - self.head_size, - self.num_kv_heads, - kv_cache, - self.kv_cache_dtype, - k_scale, - v_scale, - self.scale, - self.sliding_window, - self.alibi_slopes, - self.logits_soft_cap, - ) - - -def unified_flash_infer( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - - current_metadata = get_forward_context() - assert current_metadata is not None - assert isinstance(current_metadata, FlashInferMetadata) - attn_metadata: FlashInferMetadata = current_metadata - - num_tokens, hidden_size = query.shape - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - if kv_cache.numel() > 0: - # Use the same reshape and cache kernel as flash attention. - ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, - ) - # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache when the kv_cache_dtype is fp8 - if kv_cache_dtype.startswith("fp8"): - torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - kv_cache_dtype) - kv_cache = kv_cache.view(torch_dtype) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - query = query.contiguous() # Flashinfer requires query to be contiguous - # Query for decode. KV is not needed because it is already cached. - # QKV for prefill. - decode_query = query[num_prefill_tokens:] - query = query[:num_prefill_tokens] - - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: - # We will use flash attention for prefill - # when kv_cache is not provided. - # This happens when vllm runs the profiling to - # determine the number of blocks. - if kv_cache.numel() == 0: - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, + num_heads: int = self.num_heads + head_size: int = self.head_size + num_kv_heads: int = self.num_kv_heads + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes = self.alibi_slopes + logits_soft_cap = self.logits_soft_cap + + num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, ) - else: - assert prefill_meta is not None - assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.forward( - query, + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + # Query for decode. KV is not needed because it is already cached. + # QKV for prefill. + decode_query = query[num_prefill_tokens:] + query = query[:num_prefill_tokens] + + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + window_left = window_size[0] if window_size is not None else -1 + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + if prefill_meta := attn_metadata.prefill_metadata: + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. + if kv_cache.numel() == 0: + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ) + else: + assert prefill_meta is not None + assert prefill_meta.prefill_wrapper is not None + prefill_output = prefill_meta.prefill_wrapper.forward( + query, + kv_cache, + logits_soft_cap=logits_soft_cap, + causal=True, + k_scale=k_scale, + v_scale=v_scale, + window_left=window_left) + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + decode_output = decode_meta.decode_wrapper.forward( + decode_query, kv_cache, + sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, - causal=True, k_scale=k_scale, - v_scale=v_scale) - if decode_meta := attn_metadata.decode_metadata: - assert attn_metadata.decode_metadata is not None - assert attn_metadata.decode_metadata.decode_wrapper is not None - decode_output = attn_metadata.decode_metadata.decode_wrapper.forward( - decode_query, - kv_cache, - sm_scale=softmax_scale, - logits_soft_cap=logits_soft_cap, - k_scale=k_scale, - v_scale=v_scale) - - if prefill_output is None and decode_output is not None: - # Decode only batch. - output, num_tokens = decode_output, num_decode_tokens - elif decode_output is None and prefill_output is not None: - # Prefill only batch. - output, num_tokens = prefill_output, num_prefill_tokens - else: - # Chunked prefill batch does not work with speculative decoding in - # FlashInfer backend, so the query length for decode should be 1. - assert prefill_output is not None - assert decode_output is not None - assert decode_meta is not None - assert decode_meta.decode_query_len == 1 - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) - - -def unified_flash_infer_fake( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - return torch.empty_like(query).contiguous() - - -direct_register_custom_op( - op_name="unified_flash_infer", - op_func=unified_flash_infer, - mutates_args=["kv_cache"], - fake_impl=unified_flash_infer_fake, -) + v_scale=v_scale, + window_left=window_left) + + if prefill_output is None and decode_output is not None: + # Decode only batch. + output, num_tokens = decode_output, num_decode_tokens + elif decode_output is None and prefill_output is not None: + # Prefill only batch. + output, num_tokens = prefill_output, num_prefill_tokens + else: + # Chunked prefill batch does not work with speculative decoding in + # FlashInfer backend, so the query length for decode should be 1. + assert prefill_output is not None + assert decode_output is not None + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 7d7967a1c0329..5471eec881d85 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -22,6 +22,10 @@ class HPUAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "HPU_ATTN" + @staticmethod def get_impl_cls() -> Type["HPUAttentionImpl"]: return HPUAttentionImpl @@ -140,7 +144,8 @@ def forward( attn_metadata: HPUAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 0077438adbc03..d02fbcb6ca0ae 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -172,7 +172,8 @@ def forward( attn_metadata: IpexAttnMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 7943af23b7b41..a2612c97ca23b 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -65,6 +65,7 @@ class PallasMetadata(AttentionMetadata): # or all decoding. block_tables: Optional[torch.Tensor] = None context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None @property def prefill_metadata(self) -> Optional["PallasMetadata"]: @@ -72,8 +73,6 @@ def prefill_metadata(self) -> Optional["PallasMetadata"]: return None assert self.num_decode_tokens == 0 - assert self.block_tables is None - assert self.context_lens is None return self @property @@ -151,7 +150,8 @@ def forward( attn_metadata: PallasMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -187,29 +187,50 @@ def forward( query = query * self.scale if attn_metadata.num_prefills > 0: - assert seq_len % 16 == 0, ( - "Pallas FlashAttention kernel requires seq_len to be a " - f"multiple of 16 but got {seq_len}") - - # Handle GQA/MQA. - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) - key = key.view(batch_size, seq_len, self.num_heads, - self.head_size) - value = value.repeat_interleave(self.num_queries_per_kv, + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) - value = value.view(batch_size, seq_len, self.num_heads, + key = key.view(batch_size, seq_len, self.num_heads, self.head_size) - # FlashAttention requires [batch_size, num_heads, seq_len, d_model] - # while the input is [batch_size, seq_len, num_heads, d_model]. - # Permute the input to match the required format. - output = torch.ops.xla.flash_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - True, - ) - output = output.permute(0, 2, 1, 3) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + ) else: # Decoding run. assert kv_cache[0].numel() > 0 diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 77eadb2997689..abec46b2e599c 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -323,7 +323,7 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor, def _get_seq_len_block_table_args( attn_metadata: ROCmFlashAttentionMetadata, - attn_type: AttentionType, + attn_type: str, ) -> tuple: ''' The particular choice of sequence-length @@ -540,7 +540,8 @@ def forward( attn_metadata: ROCmFlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: torch.Tensor = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c9c791dad47fb..e5e020cedd6ce 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,18 +7,14 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.ipex_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttentionMetadata -from vllm.platforms import current_platform - -if current_platform.is_cpu(): - try: - from vllm.attention.ops.ipex_attn import PagedAttention - except ImportError: - from vllm.attention.ops.paged_attn import PagedAttention -else: - from vllm.attention.ops.paged_attn import PagedAttention +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder class TorchSDPABackend(AttentionBackend): @@ -39,6 +35,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_state_cls() -> Type["CommonAttentionState"]: return CommonAttentionState + @staticmethod + def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: + return TorchSDPAMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -71,9 +71,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. - is_prompt: bool - slot_mapping: torch.Tensor - seq_lens: Optional[List[int]] + chunked_prefill: bool + seq_lens: Optional[List[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None # Begin encoder attn & enc/dec cross-attn fields... # Encoder sequence lengths representation @@ -123,25 +129,19 @@ def is_all_cross_attn_metadata_set(self): @property def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: - # Currently chunked prefill is not supported - if self.num_decode_tokens == 0: - assert self.num_prefills > 0 - return self - - return None + if self.num_prefill_tokens == 0: + return None + return self @property def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: - # Currently chunked prefill is not supported - if self.num_prefills > 0: - assert self.num_decode_tokens == 0 + if self.num_decode_tokens == 0: return None - return self def get_seq_lens( self, - attn_type: AttentionType, + attn_type: str, ): ''' Extract appropriate sequence lengths from attention metadata @@ -174,7 +174,7 @@ def get_seq_lens( def get_attn_bias( self, - attn_type: AttentionType, + attn_type: str, ) -> Optional[List[torch.Tensor]]: ''' Extract appropriate attention bias from attention metadata @@ -203,7 +203,7 @@ def get_attn_bias( def set_attn_bias( self, attn_bias: List[torch.Tensor], - attn_type: AttentionType, + attn_type: str, ) -> None: ''' Update appropriate attention bias field of attention metadata, @@ -229,7 +229,7 @@ def set_attn_bias( def get_seq_len_block_table_args( self, - attn_type: AttentionType, + attn_type: str, ) -> tuple: ''' The particular choice of sequence-length- and block-table-related @@ -274,6 +274,109 @@ def get_seq_len_block_table_args( raise AttributeError(f"Invalid attention type {str(attn_type)}") +class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_data = input_builder.input_data + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # For chunked-prefill + if self.chunked_prefill and input_data.num_prefill_tokens != 0: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + else: + prefill_block_tables = None + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + + # For paged attention + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + attn_metadata = TorchSDPAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + ) + + return attn_metadata + + class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( @@ -327,7 +430,8 @@ def forward( attn_metadata: TorchSDPAMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -410,19 +514,35 @@ def forward( assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) if prefill_meta := attn_metadata.prefill_metadata: assert attn_metadata.seq_lens is not None - if (kv_cache.numel() == 0 - or prefill_meta.block_tables.numel() == 0): - output = self._run_sdpa_forward(query, - key, - value, - prefill_meta, - attn_type=attn_type) + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) else: # prefix-enabled attention - raise RuntimeError( - "Torch SDPA backend doesn't support prefix decoding.") + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( @@ -434,8 +554,9 @@ def forward( block_tables_arg, ) = decode_meta.get_seq_len_block_table_args(attn_type) - output = PagedAttention.forward_decode( - query, + PagedAttention.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], key_cache, value_cache, block_tables_arg, @@ -454,12 +575,13 @@ def forward( def _run_sdpa_forward( self, + output: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: TorchSDPAMetadata, - attn_type: AttentionType = AttentionType.DECODER, - ): + attn_type: str = AttentionType.DECODER, + ) -> None: if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -480,7 +602,6 @@ def _run_sdpa_forward( attn_masks = [None] * len(seq_lens) attn_metadata.set_attn_bias(attn_masks, attn_type) - output = torch.empty_like(query) query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) @@ -503,7 +624,6 @@ def _run_sdpa_forward( scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) output[start_q:end_q, :, :] = sub_out start_q, start_kv = end_q, end_kv - return output def _make_alibi_bias( diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 4a0eeda980cb2..03d5d84141d35 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -510,7 +510,7 @@ def is_all_cross_attn_metadata_set(attn_metadata): def get_seq_len_block_table_args( attn_metadata, is_prompt: bool, - attn_type: AttentionType, + attn_type: str, ) -> tuple: ''' The particular choice of sequence-length- and block-table-related @@ -561,7 +561,7 @@ def get_seq_len_block_table_args( def get_num_prefill_decode_query_kv_tokens( attn_metadata, - attn_type: AttentionType, + attn_type: str, ) -> Tuple[int, int, int]: """ Calculate the number of prefill and decode tokens for query, key/value diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index cc343705944a9..5192c07d3e495 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -284,7 +284,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: def _get_attn_bias( attn_metadata: XFormersMetadata, - attn_type: AttentionType, + attn_type: str, ) -> Optional[AttentionBias]: ''' Extract appropriate attention bias from attention metadata @@ -314,7 +314,7 @@ def _get_attn_bias( def _set_attn_bias( attn_metadata: XFormersMetadata, attn_bias: List[Optional[AttentionBias]], - attn_type: AttentionType, + attn_type: str, ) -> None: ''' Update appropriate attention bias field of attention metadata, @@ -416,7 +416,8 @@ def forward( attn_metadata: "XFormersMetadata", k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -618,7 +619,7 @@ def _run_memory_efficient_xformers_forward( key: torch.Tensor, value: torch.Tensor, attn_metadata: XFormersMetadata, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b14ad01f1d670..29a38285b6750 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -5,11 +5,14 @@ import torch.nn as nn from vllm.attention import AttentionMetadata, AttentionType -from vllm.attention.selector import get_attn_backend -from vllm.config import CacheConfig +from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.platforms import _Backend, current_platform +from vllm.utils import direct_register_custom_op class Attention(nn.Module): @@ -35,18 +38,26 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + per_layer_sliding_window: Optional[int] = None, prefix: str = "", ) -> None: super().__init__() + if per_layer_sliding_window is not None: + # per-layer sliding window + sliding_window = per_layer_sliding_window + elif cache_config is not None: + # model-level sliding window + sliding_window = cache_config.sliding_window + else: + sliding_window = None + if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size - sliding_window = cache_config.sliding_window is_attention_free = cache_config.is_attention_free else: kv_cache_dtype = "auto" block_size = 16 - sliding_window = None is_attention_free = False if num_kv_heads is None: num_kv_heads = num_heads @@ -85,6 +96,28 @@ def __init__( self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap) + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.backend = backend_name_to_enum(attn_backend.get_name()) + + # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # torch.compile works by registering the attention as one giant + # opaque custom op. For other platforms, we directly call them + # and let torch.compile handle them. + self.use_direct_call = not current_platform.is_cuda_alike( + ) and not current_platform.is_cpu() + + # For some attention backends, we allocate an output tensor before + # calling the custom op. When piecewise cudagraph is enabled, this + # makes sure the output tensor is allocated inside the cudagraph. + self.use_output = self.backend == _Backend.FLASH_ATTN or \ + self.backend == _Backend.FLASH_ATTN_VLLM_V1 + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix def forward( self, @@ -93,18 +126,39 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.impl.forward(query, - key, - value, - kv_cache, - attn_metadata, - self._k_scale, - self._v_scale, - attn_type=attn_type, - fp8_out_scale=fp8_out_scale) + if self.use_direct_call: + return self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type, + fp8_out_scale=fp8_out_scale) + elif self.use_output: + output = torch.empty_like(query) + hidden_size = query.size(-1) + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, kv_cache, attn_type, + self.layer_name) + return output.view(-1, hidden_size) + else: + return torch.ops.vllm.unified_attention(query, key, value, + kv_cache, attn_type, + self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore @@ -113,3 +167,88 @@ def extra_repr(self) -> str: s += f", scale={self.impl.scale}" # type: ignore s += f", backend={self.impl.__class__.__name__}" return s + + +def unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.dynamic_forward_context + self = forward_context.static_forward_context[layer_name] + return self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type) + + +def unified_attention_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(query).contiguous() + + +direct_register_custom_op( + op_name="unified_attention", + op_func=unified_attention, + mutates_args=["kv_cache"], + fake_impl=unified_attention_fake, + dispatch_key=current_platform.dispatch_key, +) + + +def unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.dynamic_forward_context + self = forward_context.static_forward_context[layer_name] + self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type, + output=output) + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_attention_with_output", + op_func=unified_attention_with_output, + mutates_args=["kv_cache", "output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py index ec1c37c5bcb0e..727a470ba6d0e 100644 --- a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -157,19 +157,22 @@ def _fwd_kernel_inner( k = tl.load( k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < k_seqlen, + other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kt, mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD), + other=0.0, ) else: if EVEN_D: k = tl.load(k_ptrs + start_n * stride_kt) else: k = tl.load(k_ptrs + start_n * stride_kt, - mask=offs_d[:, None] < D_HEAD) + mask=offs_d[:, None] < D_HEAD, + other=0.0) qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -200,19 +203,22 @@ def _fwd_kernel_inner( v = tl.load( v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < k_seqlen, + other=0.0, ) else: v = tl.load( v_ptrs + start_n * stride_vt, mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD), + other=0.0, ) else: if EVEN_D: v = tl.load(v_ptrs + start_n * stride_vt) else: v = tl.load(v_ptrs + start_n * stride_vt, - mask=offs_d[None, :] < D_HEAD) + mask=offs_d[None, :] < D_HEAD, + other=0.0) acc += tl.dot(p, v) @@ -318,12 +324,13 @@ def _fwd_kernel_batch_inference( q = tl.load( Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, mask=offs_m[:, None] < q_seqlen, + other=0.0, ) else: q = tl.load( Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), - other=0, + other=0.0, ) sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 8df6d4ced9dc6..cbc6c74acf09a 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -1,12 +1,17 @@ from typing import Dict, List, Optional, Tuple -import intel_extension_for_pytorch.llm.modules as ipex_modules +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +except ImportError: + _use_ipex = False + import torch from vllm import _custom_ops as ops -class PagedAttention: +class _PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: @@ -22,6 +27,105 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + v_scale: float, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + @staticmethod def split_kv_cache( kv_cache: torch.Tensor, @@ -55,6 +159,7 @@ def write_to_paged_cache( @staticmethod def forward_decode( + output: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -68,8 +173,7 @@ def forward_decode( k_scale: float, v_scale: float, *args, - ) -> torch.Tensor: - output = torch.empty_like(query) + ) -> None: block_size = value_cache.shape[2] head_mapping = torch.arange( 0, @@ -83,41 +187,5 @@ def forward_decode( scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes) - return output - - @staticmethod - def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache_dtype: str, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, - max_subquery_len: int, - alibi_slopes: Optional[torch.Tensor], - *args, - ) -> torch.Tensor: - raise NotImplementedError - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - *args, - ) -> None: - raise NotImplementedError - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], - *args, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) +PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index f14981ebd8268..339631872b43a 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -7,6 +7,13 @@ from vllm.platforms import current_platform +# Static kernels parameters +BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64 +NUM_WARPS = 8 + +# To check compatibility +IS_TURING = current_platform.get_device_capability() == (7, 5) + if triton.__version__ >= "2.1.0": @triton.jit @@ -50,6 +57,7 @@ def _fwd_kernel( stride_v_cache_d, stride_v_cache_bl, num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 @@ -130,7 +138,7 @@ def _fwd_kernel( k = k_load qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] - qk += tl.dot(q, k) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")) qk *= sm_scale @@ -178,7 +186,7 @@ def _fwd_kernel( v = v_load p = p.to(v.dtype) - acc += tl.dot(p, v) + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -204,7 +212,7 @@ def _fwd_kernel( other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk *= sm_scale # apply causal mask qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, @@ -238,7 +246,7 @@ def _fwd_kernel( other=0.0) p = p.to(v.dtype) - acc += tl.dot(p, v) + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -485,6 +493,7 @@ def _fwd_kernel_alibi( stride_v_cache_d, stride_v_cache_bl, num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 @@ -560,7 +569,7 @@ def _fwd_kernel_alibi( k = k_load qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")) qk *= sm_scale @@ -600,7 +609,7 @@ def _fwd_kernel_alibi( v = v_load p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) + acc = tl.dot(p, v, acc=acc, input_precision='ieee') # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -635,7 +644,7 @@ def _fwd_kernel_alibi( other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, allow_tf32=False) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') qk *= sm_scale qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) @@ -673,7 +682,7 @@ def _fwd_kernel_alibi( other=0.0) p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) + acc = tl.dot(p, v, acc=acc, input_precision='ieee') # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -709,13 +718,17 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None): - BLOCK = 128 if current_platform.has_device_capability(80) else 64 - NUM_WARPS = 8 - + q_dtype_is_f32 = q.dtype is torch.float32 # need to reduce num. blocks when using fp32 # due to increased use of GPU shared memory - if q.dtype is torch.float32: - BLOCK = BLOCK // 2 + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None # Conversion of FP8 Tensor from uint8 storage to # appropriate torch.dtype for interpretation by Triton @@ -801,6 +814,7 @@ def context_attention_fwd(q, v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, @@ -852,6 +866,7 @@ def context_attention_fwd(q, v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 664707e9dc65d..d263839705690 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,4 +1,3 @@ -import enum import os from contextlib import contextmanager from functools import lru_cache @@ -9,26 +8,12 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.platforms import current_platform +from vllm.platforms import _Backend, current_platform from vllm.utils import STR_BACKEND_ENV_VAR logger = init_logger(__name__) -class _Backend(enum.Enum): - FLASH_ATTN = enum.auto() - FLASH_ATTN_VLLM_V1 = enum.auto() - XFORMERS = enum.auto() - ROCM_FLASH = enum.auto() - TORCH_SDPA = enum.auto() - OPENVINO = enum.auto() - FLASHINFER = enum.auto() - HPU_ATTN = enum.auto() - PALLAS = enum.auto() - IPEX = enum.auto() - NO_ATTENTION = enum.auto() - - def backend_name_to_enum(backend_name: str) -> _Backend: assert backend_name is not None @@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int, if backend_by_env_var is not None: selected_backend = backend_name_to_enum(backend_by_env_var) - if current_platform.is_cpu(): - if selected_backend != _Backend.TORCH_SDPA: - logger.info("Cannot use %s backend on CPU.", selected_backend) - return _Backend.TORCH_SDPA - - if current_platform.is_openvino(): - if selected_backend != _Backend.OPENVINO: - logger.info("Cannot use %s backend on OpenVINO.", selected_backend) - return _Backend.OPENVINO - - if current_platform.is_xpu(): - if selected_backend != _Backend.IPEX: - logger.info("Cannot use %s backend on XPU.", selected_backend) - return _Backend.IPEX - - if current_platform.is_tpu(): - if selected_backend != _Backend.PALLAS: - logger.info("Cannot use %s backend on TPU.", selected_backend) - return _Backend.PALLAS - - if current_platform.is_rocm(): - # AMD GPUs. - selected_backend = (_Backend.ROCM_FLASH if selected_backend - == _Backend.FLASH_ATTN else selected_backend) - if selected_backend == _Backend.ROCM_FLASH: - if not current_platform.has_device_capability(90): - # not Instinct series GPUs. - logger.info("flash_attn is not supported on NAVI GPUs.") - else: - logger.info("%s is not supported in AMD GPUs.", selected_backend) - return _Backend.ROCM_FLASH - - if current_platform.is_hpu(): - return _Backend.HPU_ATTN + # get device-specific default attn_backend + default_backend = current_platform.get_default_attn_backend( + selected_backend) + if default_backend is not None: + return default_backend if use_v1: return _Backend.FLASH_ATTN_VLLM_V1 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0cf1e3a95fcba..464bc2af8fd6d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,6 +1,5 @@ import copy import dataclasses -import operator from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch @@ -11,205 +10,15 @@ import vllm.envs as envs from vllm.config import CompilationConfig from vllm.logger import init_logger -from vllm.utils import combine_fx_passes, weak_ref_tensors +from vllm.utils import weak_ref_tensors from .counter import compilation_counter -from .fusion import FusionPass -from .reshapes import RedundantReshapesPass +from .inductor_pass import InductorPass +from .pass_manager import PostGradPassManager logger = init_logger(__name__) -def fix_functionalization(graph: fx.Graph): - """ - Rewrite the graph module to replace the pattern involving - torch._higher_order_ops.auto_functionalize.auto_functionalized - with a direct call to the inplace custom op. - - # TODO: check if PyTorch nightly has fixed this issue - """ - - # debug code, if we want to see the graph before the transformation - # with open("before.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) - - nodes_to_remove = [] - - for node in graph.nodes: - # Identify the auto_functionalized node - if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa - if node.args[0] == torch.ops._C.rotary_embedding.default: - # manual replace for rotary_embedding - - # Now, collect the arguments - kwargs = node.kwargs - - query = kwargs['query'] - mm_node = query.args[0].args[0] - - # Create a new call to torch.ops._C.rotary_embedding.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function(torch.ops._C.rotary_embedding.default, - kwargs=kwargs) - - # Remove the auto_functionalized node - # Since the node may have outputs, we need to handle its users - # Replace uses of the outputs (getitem nodes) with mm_node - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - for getitem_user in list(user.users): - if (getitem_user.op == 'call_function' - and getitem_user.target - == torch.ops.aten.slice_scatter.default): - # Replace the uses of slice_scatter node - # with mm_node - getitem_user.replace_all_uses_with(mm_node) - nodes_to_remove.append(getitem_user) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: - # manual replace for fused_add_rms_norm - # this is the most effective optimization for llama - # failing to do this will result in many unnecessary copies - - kwargs = node.kwargs - - input = kwargs['input'] - residual = kwargs['residual'] - - # Create a new call to torch.ops._C.rotary_embedding.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - if user.args[1] == 1: - replace_node = input - elif user.args[1] == 2: - replace_node = residual - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - elif (node.args[0] == - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default): - # manual replace for fused_add_rms_norm_static_fp8_quant - # this is the most effective optimization for llama - # failing to do this will result in many unnecessary copies - - kwargs = node.kwargs - - result = kwargs['result'] - residual = kwargs['residual'] - - # Create a new call to - # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.fused_add_rms_norm_static_fp8_quant. - default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - if user.args[1] == 1: - replace_node = result - elif user.args[1] == 2: - replace_node = residual - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.rms_norm.default: - # manual replace for rms_norm - - kwargs = node.kwargs - - replace_node = kwargs['result'] - # Create a new call to torch.ops._C.rms_norm.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function(torch.ops._C.rms_norm.default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[ - 0] == torch.ops._C.rms_norm_static_fp8_quant.default: # noqa - # manual replace for rms_norm_static_fp8_quant - - kwargs = node.kwargs - - replace_node = kwargs['result'] - # Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default # noqa - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.rms_norm_static_fp8_quant.default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.silu_and_mul.default: - # manual replace for silu_and_mul - - kwargs = node.kwargs - - input = kwargs['input'] - out = kwargs['out'] - - # Create a new call to torch.ops._C.silu_and_mul.default - # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.silu_and_mul.default, - args=(out, input), - ) - replace_node = out - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - # Remove the nodes all at once - for node in nodes_to_remove: - graph.erase_node(node) - - # debug code, if we want to see the graph after the transformation - # with open("after.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) - - def wrap_inductor(graph, example_inputs, additional_inductor_config, @@ -368,12 +177,8 @@ class VllmBackend: The major work of this backend is to split the graph into piecewise graphs, and pass them to the piecewise backend. - This backend also handles custom passes and adds them to Inductor config. - The order of the post-grad post-passes is: - 1. post_grad_passes (constructor parameter) - 2. config["post_grad_custom_post_pass"] - 3. fix_functionalization - This way, all passes operate on a functionalized graph. + This backend also adds the PostGradPassManager to Inductor config, + which handles the post-grad passes. """ compilation_configs: CompilationConfig @@ -402,7 +207,9 @@ def __init__( # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool - self.post_grad_passes = [] + + # Passes to run on the graph post-grad. + self.post_grad_pass_manager = PostGradPassManager() self.sym_tensor_indices = [] self.input_buffers = [] @@ -412,24 +219,19 @@ def __init__( # `torch.compile` is JIT compiled, so we don't need to # do anything here - def add_passes_to_config(self): + def configure_post_pass(self): config = self.compilation_configs - passes = list(self.post_grad_passes) - - passes = passes + [RedundantReshapesPass(config)] - - if config.enable_fusion: - passes = passes + [FusionPass.instance(config)] + self.post_grad_pass_manager.configure(config.pass_config) + # Post-grad custom passes are run using the post_grad_custom_post_pass + # hook. If a pass for that hook exists, add it to the pass manager. inductor_config = config.inductor_compile_config - if "post_grad_custom_post_pass" in inductor_config: - passes = passes + [inductor_config["post_grad_custom_post_pass"]] - - # add the fix_functionalization pass last, so that all other - # passes operate on a functionalized graph - passes = passes + [fix_functionalization] - combined_pass = combine_fx_passes(passes) - inductor_config["post_grad_custom_post_pass"] = combined_pass + PASS_KEY = "post_grad_custom_post_pass" + if PASS_KEY in inductor_config: + # Config should automatically wrap all inductor passes + assert isinstance(inductor_config[PASS_KEY], InductorPass) + self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) + inductor_config[PASS_KEY] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: @@ -444,10 +246,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # we get the sizes to capture for cudagraph # from compilation context self.compilation_configs.init_during_runtime() - self.add_passes_to_config() + self.configure_post_pass() self.split_gm, self.piecewise_graphs = split_graph( - graph, self.compilation_configs.non_cudagraph_ops) + graph, self.compilation_configs.splitting_ops) from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 100a49aba74ac..6385f1c5dbf81 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -5,6 +5,7 @@ @dataclasses.dataclass class CompilationCounter: + num_models_seen: int = 0 num_graphs_seen: int = 0 # including the splitting ops num_piecewise_graphs_seen: int = 0 diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 4b78491bc5a48..8700243c9d904 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,8 +1,10 @@ import inspect -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, TypeVar, Union, overload import torch +import torch.nn as nn +from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import CompilationLevel, VllmConfig from vllm.logger import init_logger @@ -11,10 +13,27 @@ logger = init_logger(__name__) +_T = TypeVar("_T", bound=type[nn.Module]) + + +@overload +def support_torch_compile( + *, + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]], +) -> Callable[[_T], _T]: + ... + + +@overload +def support_torch_compile(cls: _T) -> _T: + ... + def support_torch_compile( - cls: Optional[type] = None, - dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None): + cls: Optional[_T] = None, + *, + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None, +) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -65,7 +84,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): computation graph. """ - def cls_decorator_helper(cls: type): + def cls_decorator_helper(cls: _T) -> _T: # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` # to avoid too much indentation for `_support_torch_compile`` if not hasattr(cls, 'forward'): @@ -104,8 +123,10 @@ def cls_decorator_helper(cls: type): return cls_decorator_helper -def _support_torch_compile(cls: type, - dynamic_arg_dims: Dict[str, Union[int, List[int]]]): +def _support_torch_compile( + cls: _T, + dynamic_arg_dims: Dict[str, Union[int, List[int]]], +) -> _T: """ A decorator to add support for compiling the forward method of a class. """ @@ -118,7 +139,7 @@ def _support_torch_compile(cls: type, # other than TorchCompileWrapperWithCustomDispatcher cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) - old_init = cls.__init__ # type: ignore + old_init = cls.__init__ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) @@ -130,10 +151,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): ] or not supports_dynamo() if self.do_not_compile: return + compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_level=vllm_config.compilation_config.level) - cls.__init__ = __init__ # type: ignore + cls.__init__ = __init__ def __call__(self, *args, **kwargs): # torch.compiler.is_compiling() means we are inside the compilation @@ -178,5 +200,5 @@ def __call__(self, *args, **kwargs): model_output = self.forward(*args, **kwargs) return model_output - cls.__call__ = __call__ # type: ignore + cls.__call__ = __call__ return cls diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py new file mode 100644 index 0000000000000..3584cc3608caf --- /dev/null +++ b/vllm/compilation/fix_functionalization.py @@ -0,0 +1,177 @@ +import operator +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass, is_func + +logger = init_logger(__name__) + + +class FixFunctionalizationPass(VllmInductorPass): + """ + This pass defunctionalizes certain nodes to avoid redundant tensor copies. + After this pass, DCE (dead-code elimination) should never be run, + as de-functionalized nodes may appear as dead code. + + To add new nodes to defunctionalize, add to the if-elif chain in __call__. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_fix_functionalization") + + self.nodes_to_remove: List[torch.fx.Node] = [] + count = 0 + for node in graph.nodes: + if not is_func(node, auto_functionalized): + continue # Avoid deep if-elif nesting + + kwargs = node.kwargs + at_target = node.args[0] + + if at_target == torch.ops._C.rotary_embedding.default: + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # rotary_embedding is a special case: the two mutating inputs + # are query and key, which are slices of mm_node. + # While functionalized, results at[1] and at[2] are scattered + # back into mm_node. After de-functionalization, we can just + # use mm_node directly. + for idx, user in self.getitem_users(node).items(): + for user_of_getitem in user.users: + if is_func(user_of_getitem, + torch.ops.aten.slice_scatter.default): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + # These 2 replacements avoid the most copies for LLaMa. + elif at_target == torch.ops._C.fused_add_rms_norm.default: + mutated_args = {1: 'input', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + + elif at_target in [ + torch.ops._C.rms_norm.default, + torch.ops._C.rms_norm_static_fp8_quant.default + ]: + mutated_args = {1: 'result'} + self.defunctionalize(graph, node, mutated_args) + + elif at_target == torch.ops._C.silu_and_mul.default: + mutated_args = {1: 'out'} + # Because we have an 'out', need to specify args directly + self.defunctionalize(graph, + node, + mutated_args, + args=('out', 'input')) + else: + continue # skip the count + + count += 1 + + self.dump_graph(graph, "before_fix_functionalization_cleanup") + + # Remove the nodes all at once + count_removed = len(self.nodes_to_remove) + for node in self.nodes_to_remove: + graph.erase_node(node) + + logger.debug("De-functionalized %s nodes, removed %s nodes", count, + count_removed) + self.dump_graph(graph, "after_fix_functionalization") + self.end_and_log() + + def _remove(self, node_or_nodes: Union[torch.fx.Node, + Iterable[torch.fx.Node]]): + """ + Stage a node (or nodes) for removal at the end of the pass. + """ + if isinstance(node_or_nodes, torch.fx.Node): + self.nodes_to_remove.append(node_or_nodes) + else: + self.nodes_to_remove.extend(node_or_nodes) + + def defunctionalize(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: Dict[int, Union[torch.fx.Node, str]], + args: Optional[Tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + De-functionalize a node by replacing it with a call to the original. + It also replaces the getitem users with the mutated arguments. + See replace_users_with_mutated_args and insert_defunctionalized. + """ + self.replace_users_with_mutated_args(node, mutated_args) + self.insert_defunctionalized(graph, node, args=args) + self._remove(node) + + def replace_users_with_mutated_args(self, node: torch.fx.Node, + mutated_args: Dict[int, + Union[torch.fx.Node, + str]]): + """ + Replace all getitem users of the auto-functionalized node with the + mutated arguments. + :param node: The auto-functionalized node + :param mutated_args: The mutated arguments, indexed by getitem index. + If the value of an arg is a string, `node.kwargs[arg]` is used. + """ + for idx, user in self.getitem_users(node).items(): + arg = mutated_args[idx] + arg = node.kwargs[arg] if isinstance(arg, str) else arg + user.replace_all_uses_with(arg) + self._remove(user) + + def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]: + """ + Returns the operator.getitem users of the auto-functionalized node, + indexed by the index they are getting. + """ + users = {} + for user in node.users: + if is_func(user, operator.getitem): + idx = user.args[1] + users[idx] = user + return users + + def insert_defunctionalized(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[Tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + Insert a new defunctionalized node into the graph before node. + If one of the kwargs is 'out', provide args directly, + as node.kwargs cannot be used. + See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 + + :param graph: Graph to insert the defunctionalized node into + :param node: The auto-functionalized node to defunctionalize + :param args: If we cannot use kwargs, specify args directly. + If an arg is a string, `node.kwargs[arg]` is used. + """ # noqa: E501 + assert is_func(node, auto_functionalized), \ + f"node must be auto-functionalized, is {node} instead" + + # Create a new call to the original function + with graph.inserting_before(node): + function = node.args[0] + if args is None: + graph.call_function(function, kwargs=node.kwargs) + else: + # Args passed as strings refer to items in node.kwargs + args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg + for arg in args) + graph.call_function(function, args=args) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e6a3afef85e1b..5efa410fab6a0 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,10 +6,11 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm.compilation.inductor_pass import InductorPass from vllm.config import CompilationConfig from vllm.logger import init_logger +from .vllm_inductor_pass import VllmInductorPass, is_func + logger = init_logger(__name__) @@ -90,8 +91,6 @@ def empty_fp32(*args, **kwargs): # Utilities for post-processing multi-output matches -def is_func(node: torch.fx.Node, target) -> bool: - return node.op == "call_function" and node.target == target # Returns the first auto_functionalized node with the given op (if it exists) @@ -127,7 +126,7 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: return ret -class FusionPass(InductorPass): +class FusionPass(VllmInductorPass): """ This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them. @@ -142,7 +141,7 @@ class FusionPass(InductorPass): _instance: 'Optional[FusionPass]' = None @classmethod - def instance(cls, config: CompilationConfig): + def instance(cls, config: CompilationConfig.PassConfig): """ Get the singleton instance of the FusionPass. If the instance exists, the config is updated but @@ -154,7 +153,7 @@ def instance(cls, config: CompilationConfig): cls._instance.config = config return cls._instance - def __init__(self, config: CompilationConfig): + def __init__(self, config: CompilationConfig.PassConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" super().__init__(config) @@ -278,6 +277,7 @@ def process_matches(self, graph: torch.fx.Graph): for node in match.nodes) def __call__(self, graph: torch.fx.Graph): + self.begin() self.dump_graph(graph, "before_fusion") count = self.patterns.apply(graph) @@ -289,3 +289,4 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Post-processed %s matches", len(self.matches)) self.dump_graph(graph, "after_fusion") self.matches.clear() + self.end_and_log() diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 8082a08b40019..f6846c08ac841 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,38 +1,84 @@ +import hashlib +import inspect +import types from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union import torch - -from vllm.config import CompilationConfig -# yapf: disable -from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank -from vllm.distributed import ( - get_tensor_model_parallel_world_size as get_tp_world_size) -from vllm.distributed import model_parallel_is_initialized as p_is_init -# yapf: enable -from vllm.logger import init_logger - -logger = init_logger(__name__) +from torch import fx class InductorPass(ABC): + """ + General custom inductor pass interface. + TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass + """ @abstractmethod def __call__(self, graph: torch.fx.Graph): + """ + Execute the pass on the given graph. + """ raise NotImplementedError - def __init__(self, config: CompilationConfig): - self.config = config - - def dump_graph(self, graph: torch.fx.Graph, stage: str): - if stage in self.config.dump_graph_stages: - # Make sure filename includes rank in the distributed setting - parallel = p_is_init() and get_tp_world_size() > 1 - rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" - - logger.info("Printing graph to %s", filepath) - with open(filepath, "w") as f: - src = graph.python_code(root_module="self", verbose=True).src - # Add imports so it's not full of errors - print("import torch; from torch import device", file=f) - print(src, file=f) + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. + """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, types.FunctionType): + src_str = inspect.getsource(src) + else: + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.digest() + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. + """ + + def __init__(self, + callable: Callable[[fx.Graph], None], + uuid: Optional[Any] = None): + self.callable = callable + if uuid is None: + uuid = InductorPass.hash_source(callable) + self._uuid = uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid + + def __getstate__(self): + """ + Pickling occurs in the Inductor code cache if a pass is not given to + the pass manager but is instead directly added to config as a pass. + See PostGradPassManager for more. + + TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. + """ + return self._uuid + + def __setstate__(self, state): + raise ValueError("Cannot unpickle CallableInductorPass") diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py new file mode 100644 index 0000000000000..fb522ae053e97 --- /dev/null +++ b/vllm/compilation/pass_manager.py @@ -0,0 +1,77 @@ +from typing import List + +from torch import fx as fx + +from vllm.config import CompilationConfig +from vllm.logger import init_logger + +from .fix_functionalization import FixFunctionalizationPass +from .fusion import FusionPass +from .inductor_pass import InductorPass +from .reshapes import RedundantReshapesPass + +logger = init_logger(__name__) + + +class PostGradPassManager: + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It also supports pickling, which is used by the Inductor code cache. + TODO(torch==2.6), use CustomGraphPass + (torch._inductor.custom_graph_pass.CustomGraphPass) + + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (RedundantReshapesPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. + """ + + def __init__(self): + self.passes: List[InductorPass] = [] + + def __call__(self, graph: fx.Graph): + for pass_ in self.passes: + pass_(graph) + + # always run fix_functionalization last + self.fix_functionalization(graph) + + def configure(self, pass_config: CompilationConfig.PassConfig): + self.pass_config = pass_config + if pass_config.enable_reshape: + self.passes += [RedundantReshapesPass(pass_config)] + + if pass_config.enable_fusion: + self.passes += [FusionPass.instance(pass_config)] + + self.fix_functionalization = FixFunctionalizationPass(pass_config) + + def add(self, pass_: InductorPass): + assert isinstance(pass_, InductorPass) + self.passes.append(pass_) + + def __getstate__(self): + """ + Custom pickling for the pass manager, as some passes cannot be pickled. + Pickling occurs because the pass manager is set as the value of + `config["post_grad_custom_post_pass"]` in the Inductor config. + The config is pickled to act as a key in the Inductor code cache. + Any other passes in the config are pickled as well. + + TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. + """ + state = {"pass_config": self.pass_config.uuid(), "passes": []} + for pass_ in self.passes: + state["passes"].append(pass_.uuid()) + state["passes"].append(self.fix_functionalization.uuid()) + return state + + def __setstate__(self, state): + """ + Do not allow unpickling of the pass manager. + If this is needed in the future, it should properly pickle the passes. + """ + raise ValueError("Cannot unpickle PostGradPassManager") diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py index 36597e119d2e1..63a369fe8d966 100644 --- a/vllm/compilation/reshapes.py +++ b/vllm/compilation/reshapes.py @@ -3,14 +3,14 @@ import torch.fx from torch import SymInt -from vllm.compilation.fusion import is_func -from vllm.compilation.inductor_pass import InductorPass from vllm.logger import init_logger +from .vllm_inductor_pass import VllmInductorPass, is_func + logger = init_logger(__name__) -class RedundantReshapesPass(InductorPass): +class RedundantReshapesPass(VllmInductorPass): """ This is an inductor pass that removes redundant reshape operations. It is required for RMSNorm-quant fusion to work properly. @@ -31,6 +31,7 @@ class RedundantReshapesPass(InductorPass): """ def __call__(self, graph: torch.fx.Graph): + self.begin() self.dump_graph(graph, "before_reshapes") count = 0 # Remove no-op reshapes/views: @@ -56,6 +57,7 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Removed %s no-op reshapes", count) self.dump_graph(graph, "after_reshapes") + self.end_and_log() def dims_equivalent(self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt]) -> bool: diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py new file mode 100644 index 0000000000000..dbf6b8f7789e1 --- /dev/null +++ b/vllm/compilation/vllm_inductor_pass.py @@ -0,0 +1,53 @@ +import time + +import torch + +from vllm.config import CompilationConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +# yapf: enable +from vllm.logger import init_logger + +from .inductor_pass import InductorPass + +logger = init_logger(__name__) + + +def is_func(node: torch.fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +class VllmInductorPass(InductorPass): + """ + An inductor pass with access to vLLM PassConfig. + It provides timing, logging, and dumping utilities. + """ + + def __init__(self, config: CompilationConfig.PassConfig): + self.config = config + self.pass_name = self.__class__.__name__ + + def dump_graph(self, graph: torch.fx.Graph, stage: str): + if stage in self.config.dump_graph_stages: + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + + logger.info("%s printing graph to %s", self.pass_name, filepath) + with open(filepath, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 0143d0301ca1a..bc4d292fef402 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -8,7 +8,7 @@ import torch import vllm.envs as envs -from vllm.config import CompilationLevel +from vllm.config import CompilationLevel, get_current_vllm_config class TorchCompileWrapperWithCustomDispatcher: @@ -32,7 +32,6 @@ def __init__(self, # default compilation settings # compiling the forward method - from vllm.plugins import get_current_vllm_config backend = get_current_vllm_config( ).compilation_config.init_backend() diff --git a/vllm/config.py b/vllm/config.py index 6c003bdf2cfc0..fe25ea9f76038 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,19 +1,24 @@ import copy import enum +import hashlib import json import warnings +from contextlib import contextmanager from dataclasses import dataclass, field, replace from pathlib import Path -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, - Literal, Mapping, Optional, Set, Tuple, Type, Union) +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, + Final, List, Literal, Mapping, Optional, Set, Tuple, Type, + Union) import torch from pydantic import BaseModel, Field, PrivateAttr from transformers import PretrainedConfig import vllm.envs as envs +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + get_quantization_config) from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback @@ -22,7 +27,7 @@ get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - identity, is_mi250, is_navi, print_warning_once, + is_mi250, is_navi, print_warning_once, resolve_obj_by_qualname) if TYPE_CHECKING: @@ -87,6 +92,8 @@ class ModelConfig: the default version. max_model_len: Maximum length of a sequence (including prompt and output). If None, will be derived from the model. + spec_target_max_model_len: Specify the the maximum length for spec + decoding draft models. quantization: Quantization method that was used to quantize the model weights. If None, we assume the model weights are not quantized. quantization_param_path: Path to JSON file containing scaling factors. @@ -103,6 +110,7 @@ class ModelConfig: to eager mode. Additionally for encoder-decoder models, if the sequence length of the encoder input is larger than this, we fall back to the eager mode. + max_logprobs: Maximum number of log probabilities. Defaults to 20. disable_sliding_window: Whether to disable sliding window. If True, we will disable the sliding window functionality of the model. If the model does not support sliding window, this argument is @@ -115,6 +123,8 @@ class ModelConfig: the model name will be the same as `model`. limit_mm_per_prompt: Maximum number of data items per modality per prompt. Only applicable for multimodal models. + use_async_output_proc: Whether to use async output processor. + Defaults to True. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. hf_overrides: If a dictionary, contains arguments to be forwarded to the @@ -126,7 +136,7 @@ class ModelConfig: override default neuron config that are specific to Neuron devices, this argument will be used to configure the neuron config that can not be gathered from the vllm arguments. - override_pooling_config: Initialize non default pooling config or + override_pooler_config: Initialize non default pooling config or override default pooling config for the embedding model. """ @@ -179,7 +189,7 @@ def __init__( hf_overrides_fn = hf_overrides else: hf_overrides_kw = hf_overrides - hf_overrides_fn = identity + hf_overrides_fn = None if rope_scaling is not None: hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling} @@ -208,8 +218,15 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, config_format, **hf_overrides_kw) - hf_config = hf_overrides_fn(hf_config) + code_revision, config_format) + + if hf_overrides_kw: + logger.info("Overriding HF config with %s", hf_overrides_kw) + hf_config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.info("Overriding HF config with %s", hf_overrides_fn) + hf_config = hf_overrides_fn(hf_config) + self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) @@ -230,15 +247,26 @@ def __init__( (self.hf_text_config.model_type in ["gemma2"])) if (not self.disable_sliding_window and has_interleaved_attention): - sliding_window_len_min = get_min_sliding_window( - self.hf_text_config.sliding_window) - - print_warning_once( - f"{self.hf_text_config.model_type} has interleaved attention, " - "which is currently not supported by vLLM. Disabling sliding " - "window and capping the max length to the sliding window size " - f"({sliding_window_len_min}).") - self.disable_sliding_window = True + if envs.VLLM_ATTENTION_BACKEND == "XFORMERS": + sliding_window_len_min = get_min_sliding_window( + self.hf_text_config.sliding_window) + + print_warning_once( + f"{self.hf_text_config.model_type} has interleaved " + "attention, which is currently not supported by the " + "XFORMERS backend. Disabling sliding window and capping " + "the max length to the sliding window size " + f"({sliding_window_len_min}).") + self.disable_sliding_window = True + else: + # for a model with interleaved attention, + # the scheduler and the model treat it as full attention + # (i.e., not dropping any tokens outside the window). + # only the attention layer itself is aware of the sliding + # window, and use the window size to compute the attention. + self.hf_text_config.interleaved_sliding_window = sliding_window + delattr(self.hf_text_config, "sliding_window") + sliding_window = None self.max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, @@ -337,7 +365,7 @@ def _resolve_task( # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them "generate": ModelRegistry.is_text_generation_model(architectures), - "embedding": ModelRegistry.is_embedding_model(architectures), + "embedding": ModelRegistry.is_pooling_model(architectures), } supported_tasks_lst: List[_Task] = [ task for task, is_supported in task_support.items() if is_supported @@ -348,6 +376,31 @@ def _resolve_task( selected_task = next(iter(supported_tasks_lst)) if len(supported_tasks) > 1: + suffix_to_preferred_task: List[Tuple[str, _Task]] = [ + # Hardcode the models that are exceptions + ("AquilaModel", "generate"), + ("ChatGLMModel", "generate"), + # Other models follow this pattern + ("ForCausalLM", "generate"), + ("ForConditionalGeneration", "generate"), + ("ChatModel", "generate"), + ("LMHeadModel", "generate"), + ("EmbeddingModel", "embedding"), + ("RewardModel", "embedding"), + ("ForSequenceClassification", "embedding"), + ] + info, arch = ModelRegistry.inspect_model_cls(architectures) + + for suffix, pref_task in suffix_to_preferred_task: + if arch.endswith(suffix) and pref_task in supported_tasks: + selected_task = pref_task + break + else: + if (arch.endswith("Model") + and info.architecture.endswith("ForCausalLM") + and "embedding" in supported_tasks): + selected_task = "embedding" + logger.info( "This model supports multiple tasks: %s. " "Defaulting to '%s'.", supported_tasks, selected_task) @@ -370,18 +423,12 @@ def _parse_quant_hf_config(self): return quant_cfg def _verify_quantization(self) -> None: - supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = [ - "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8" - ] + supported_quantization = QUANTIZATION_METHODS optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", "experts_int8" ] - tpu_supported_quantization = ["tpu_int8"] - neuron_supported_quantization = ["neuron_quant"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -392,19 +439,15 @@ def _verify_quantization(self) -> None: quant_method = quant_cfg.get("quant_method", "").lower() # Detect which checkpoint is it - for _, method in QUANTIZATION_METHODS.items(): + for name in QUANTIZATION_METHODS: + method = get_quantization_config(name) quantization_override = method.override_quantization_method( quant_cfg, self.quantization) - if quantization_override: - if current_platform.is_rocm(): - if quantization_override in rocm_supported_quantization: - quant_method = quantization_override - self.quantization = quantization_override - break - else: - quant_method = quantization_override - self.quantization = quantization_override - break + if (quantization_override and quantization_override + in current_platform.supported_quantization): + quant_method = quantization_override + self.quantization = quantization_override + break # Verify quantization configurations. if self.quantization is None: @@ -421,32 +464,12 @@ def _verify_quantization(self) -> None: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") - if current_platform.is_rocm( - ) and self.quantization not in rocm_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in ROCm.") - if current_platform.is_tpu( - ) and self.quantization not in tpu_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in TPU Backend.") + current_platform.verify_quantization(self.quantization) if self.quantization not in optimized_quantization_methods: logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " "non-quantized models.", self.quantization) - if (self.quantization == "awq" and current_platform.is_rocm() - and not envs.VLLM_USE_TRITON_AWQ): - logger.warning( - "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" - " is not set, enabling VLLM_USE_TRITON_AWQ.") - envs.VLLM_USE_TRITON_AWQ = True - if current_platform.is_neuron( - ) and self.quantization not in neuron_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in Neuron Backend.") def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: @@ -703,6 +726,11 @@ def uses_mrope(self) -> bool: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def is_cross_encoder(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_cross_encoder_model(architectures) + class CacheConfig: """Configuration for the KV cache. @@ -713,8 +741,13 @@ class CacheConfig: vLLM execution. swap_space: Size of the CPU swap space per GPU (in GiB). cache_dtype: Data type for kv cache storage. + is_attention_free: Whether the model is attention-free. num_gpu_blocks_override: Number of GPU blocks to use. This overrides the profiled num_gpu_blocks if specified. Does nothing if None. + sliding_window: Sliding window size for the KV cache. Can not work with + prefix caching enabled. + enable_prefix_caching: Whether to enable prefix caching. + cpu_offload_gb: Size of the CPU offload buffer in GiB. """ def __init__( @@ -883,6 +916,7 @@ class LoadConfig: "tensorizer" will use CoreWeave's tensorizer library for fast weight loading. "bitsandbytes" will load nf4 type weights. + model_loader_extra_config: The extra config for the model loader. ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. @@ -928,76 +962,72 @@ def _verify_load_format(self) -> None: f"{rocm_supported_load_format}") +@dataclass class ParallelConfig: - """Configuration for the distributed execution. + """Configuration for the distributed execution.""" - Args: - pipeline_parallel_size: Number of pipeline parallel groups. - tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Deprecated, use distributed_executor_backend instead. - max_parallel_loading_workers: Maximum number of multiple batches - when load model sequentially. To avoid RAM OOM when using tensor - parallel and large models. - disable_custom_all_reduce: Disable the custom all-reduce kernel and - fall back to NCCL. - tokenizer_pool_config: Config for the tokenizer pool. - If None, will use synchronous tokenization. - ray_workers_use_nsight: Whether to profile Ray workers with nsight, see - https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. - placement_group: ray distributed model workers placement group. - distributed_executor_backend: Backend to use for distributed model - workers, either "ray" or "mp" (multiprocessing). If the product - of pipeline_parallel_size and tensor_parallel_size is less than - or equal to the number of GPUs available, "mp" will be used to - keep processing on a single host. Otherwise, this will default - to "ray" if Ray is installed and fail otherwise. Note that tpu - and hpu only support Ray for distributed inference. - """ + pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. + tensor_parallel_size: int = 1 # Number of tensor parallel groups. - def __init__( - self, - pipeline_parallel_size: int, - tensor_parallel_size: int, - worker_use_ray: Optional[bool] = None, - max_parallel_loading_workers: Optional[int] = None, - disable_custom_all_reduce: bool = False, - tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, - ray_workers_use_nsight: bool = False, - placement_group: Optional["PlacementGroup"] = None, - distributed_executor_backend: Optional[Union[ - str, Type["ExecutorBase"]]] = None, - ) -> None: - self.pipeline_parallel_size = pipeline_parallel_size - self.tensor_parallel_size = tensor_parallel_size - self.distributed_executor_backend = distributed_executor_backend - self.max_parallel_loading_workers = max_parallel_loading_workers - self.disable_custom_all_reduce = disable_custom_all_reduce - self.tokenizer_pool_config = tokenizer_pool_config - self.ray_workers_use_nsight = ray_workers_use_nsight - self.placement_group = placement_group - self.world_size = pipeline_parallel_size * self.tensor_parallel_size - - if worker_use_ray: + # Deprecated, use distributed_executor_backend instead. + worker_use_ray: Optional[bool] = None + + # Maximum number of multiple batches + # when load model sequentially. To avoid RAM OOM when using tensor + # parallel and large models. + max_parallel_loading_workers: Optional[int] = None + + # Disable the custom all-reduce kernel and fall back to NCCL. + disable_custom_all_reduce: bool = False + + # Config for the tokenizer pool. If None, will use synchronous tokenization. + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None + + # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + ray_workers_use_nsight: bool = False + + # ray distributed model workers placement group. + placement_group: Optional["PlacementGroup"] = None + + # Backend to use for distributed model + # workers, either "ray" or "mp" (multiprocessing). If the product + # of pipeline_parallel_size and tensor_parallel_size is less than + # or equal to the number of GPUs available, "mp" will be used to + # keep processing on a single host. Otherwise, this will default + # to "ray" if Ray is installed and fail otherwise. Note that tpu + # and hpu only support Ray for distributed inference. + distributed_executor_backend: Optional[Union[str, + Type["ExecutorBase"]]] = None + + # the full name of the worker class to use. If "auto", the worker class + # will be determined based on the platform. + worker_cls: str = "auto" + sd_worker_cls: str = "auto" + + world_size: int = field(init=False) + + rank: int = 0 + + def __post_init__(self) -> None: + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + + if self.worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" elif not self.use_ray: raise ValueError(f"worker-use-ray can't be used with " f"distributed executor backend " f"'{self.distributed_executor_backend}'.") - - if current_platform.is_tpu() and self.world_size > 1: - if self.distributed_executor_backend is None: - self.distributed_executor_backend = "ray" - if self.distributed_executor_backend != "ray": - raise ValueError( - "TPU backend only supports Ray for distributed inference.") - - if current_platform.is_hpu() and self.world_size > 1: + ray_only_devices = ["tpu", "hpu"] + if (current_platform.device_type in ray_only_devices + and self.world_size > 1): if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" if self.distributed_executor_backend != "ray": raise ValueError( - "HPU backend only supports Ray for distributed inference.") + f"{current_platform.device_type.upper()} backend only " + "supports Ray for distributed inference.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the @@ -1028,7 +1058,18 @@ def __init__( backend) self._verify_args() - self.rank: int = 0 + + if is_mi250() and self.tensor_parallel_size > 1: + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "working correctly on multi AMD MI250.") + + if is_navi() and self.tensor_parallel_size <= 2: + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "working correctly when using two AMD Navi GPUs.") if is_mi250() and self.tensor_parallel_size > 1: self.disable_custom_all_reduce = True @@ -1074,100 +1115,97 @@ def _verify_args(self) -> None: "run with Ray.") +@dataclass class SchedulerConfig: - """Scheduler configuration. + """Scheduler configuration.""" - Args: - task: The task to use the model for. - max_num_batched_tokens: Maximum number of tokens to be processed in - a single iteration. - max_num_seqs: Maximum number of sequences to be processed in a single - iteration. - max_model_len: Maximum length of a sequence (including prompt - and generated text). - num_lookahead_slots: The number of slots to allocate per sequence per - step, beyond the known token ids. This is used in speculative - decoding to store KV activations of tokens which may or may not be - accepted. - delay_factor: Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt. - enable_chunked_prefill: If True, prefill requests can be chunked based - on the remaining max_num_batched_tokens. - preemption_mode: Whether to perform preemption by swapping or - recomputation. If not specified, we determine the mode as follows: - We use recomputation by default since it incurs lower overhead than - swapping. However, when the sequence group has multiple sequences - (e.g., beam search), recomputation is not currently supported. In - such a case, we use swapping instead. - send_delta_data: Private API. If used, scheduler sends delta data to - workers instead of an entire data. It should be enabled only - when SPMD worker architecture is enabled. I.e., - VLLM_USE_RAY_SPMD_WORKER=1 - policy: The scheduling policy to use. "fcfs" (default) or "priority". - """ + task: str = "generate" # The task to use the model for. + + # Maximum number of tokens to be processed in a single iteration. + max_num_batched_tokens: int = field(default=None) # type: ignore + + # Maximum number of sequences to be processed in a single iteration. + max_num_seqs: int = 128 + + # Maximum length of a sequence (including prompt and generated text). + max_model_len: int = 8192 + + # The number of slots to allocate per sequence per + # step, beyond the known token ids. This is used in speculative + # decoding to store KV activations of tokens which may or may not be + # accepted. + num_lookahead_slots: int = 0 + + # Apply a delay (of delay factor multiplied by previous + # prompt latency) before scheduling next prompt. + delay_factor: float = 0.0 + + # If True, prefill requests can be chunked based + # on the remaining max_num_batched_tokens. + enable_chunked_prefill: bool = False + + is_multimodal_model: bool = False + + # Whether to perform preemption by swapping or + # recomputation. If not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + preemption_mode: Optional[str] = None - def __init__(self, - task: _Task, - max_num_batched_tokens: Optional[int], - max_num_seqs: int, - max_model_len: int, - num_lookahead_slots: int = 0, - delay_factor: float = 0.0, - enable_chunked_prefill: bool = False, - is_multimodal_model: bool = False, - preemption_mode: Optional[str] = None, - num_scheduler_steps: int = 1, - multi_step_stream_outputs: bool = False, - send_delta_data: bool = False, - policy: str = "fcfs") -> None: - if max_num_batched_tokens is None: - if enable_chunked_prefill: - if num_scheduler_steps > 1: + num_scheduler_steps: int = 1 + + multi_step_stream_outputs: bool = False + + # Private API. If used, scheduler sends delta data to + # workers instead of an entire data. It should be enabled only + # when SPMD worker architecture is enabled. I.e., + # VLLM_USE_RAY_SPMD_WORKER=1 + send_delta_data: bool = False + + # The scheduling policy to use. "fcfs" (default) or "priority". + policy: str = "fcfs" + + chunked_prefill_enabled: bool = field(init=False) + + def __post_init__(self) -> None: + if self.max_num_batched_tokens is None: + if self.enable_chunked_prefill: + if self.num_scheduler_steps > 1: # Multi-step Chunked-Prefill doesn't allow prompt-chunking # for now. Have max_num_batched_tokens set to max_model_len # so we don't reject sequences on account of a short # max_num_batched_tokens. - max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_batched_tokens = max(self.max_model_len, 2048) else: - # It is the values that have the best balance between ITL - # and TTFT on A100. Note it is not optimized for throughput. - max_num_batched_tokens = 512 + # This value is chosen to have a balance between ITL + # and TTFT. Note it is not optimized for throughput. + self.max_num_batched_tokens = 2048 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_batched_tokens = max(self.max_model_len, 2048) - if task == "embedding": + if self.task == "embedding": # For embedding, choose specific value for higher throughput - max_num_batched_tokens = max( - max_num_batched_tokens, + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, ) - if is_multimodal_model: + if self.is_multimodal_model: # The value needs to be at least the number of multimodal tokens - max_num_batched_tokens = max( - max_num_batched_tokens, + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, ) - self.max_num_batched_tokens = max_num_batched_tokens - - if enable_chunked_prefill: + if self.enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", self.max_num_batched_tokens) - self.task: Final = task - self.max_num_seqs = max_num_seqs - self.max_model_len = max_model_len - self.num_lookahead_slots = num_lookahead_slots - self.delay_factor = delay_factor - self.chunked_prefill_enabled = enable_chunked_prefill - self.preemption_mode = preemption_mode - self.num_scheduler_steps = num_scheduler_steps - self.multi_step_stream_outputs = multi_step_stream_outputs - self.send_delta_data = send_delta_data - self.policy = policy + self.chunked_prefill_enabled = self.enable_chunked_prefill self._verify_args() def _verify_args(self) -> None: @@ -1206,25 +1244,13 @@ def is_multi_step(self) -> bool: class DeviceConfig: device: Optional[torch.device] + device_type: str def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection - if current_platform.is_cuda_alike(): - self.device_type = "cuda" - elif current_platform.is_neuron(): - self.device_type = "neuron" - elif current_platform.is_hpu(): - self.device_type = "hpu" - elif current_platform.is_openvino(): - self.device_type = "openvino" - elif current_platform.is_tpu(): - self.device_type = "tpu" - elif current_platform.is_cpu(): - self.device_type = "cpu" - elif current_platform.is_xpu(): - self.device_type = "xpu" - else: + self.device_type = current_platform.device_type + if not self.device_type: raise RuntimeError("Failed to infer device type") else: # Device type is assigned explicitly @@ -1421,16 +1447,6 @@ def maybe_create_spec_config( draft_hf_config ) - if (enable_chunked_prefill and \ - speculative_draft_tensor_parallel_size != 1): - # TODO - Investigate why the error reported in - # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258 - # is happening and re-enable it. - raise ValueError( - "Chunked prefill and speculative decoding can be enabled " - "simultaneously only for draft models with tensor " - "parallel size 1.") - draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( speculative_max_model_len, @@ -2074,6 +2090,88 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") +class KVTransferConfig(BaseModel): + """Configuration for distributed KV cache transfer.""" + + # The KV connector for vLLM to transmit KV caches between vLLM instances. + kv_connector: Optional[str] = None + + # The device used by kv connector to buffer the KV cache. + # Currently only support 'cuda'. + kv_buffer_device: Optional[str] = "cuda" + + # The buffer size for TorchDistributedConnector. Measured in number of + # bytes. Recommended value: 1e9 (about 1GB). + kv_buffer_size: float = 1e9 + + # Whether this vLLM instance produces, consumes KV cache, or both. Choices + # are 'kv_producer', 'kv_consumer', and 'both'. + kv_role: Optional[str] = None + + # The rank of this vLLM instance in the KV cache transfer. Typical value: + # 0 for prefill instance, 1 for decode instance. + # Currently only 1P1D is supported. + kv_rank: Optional[int] = None + + # The number of parallel instances for KV cache transfer. For + # PyNcclConnector, this should be 2. + kv_parallel_size: int = 1 + + # The KV connector ip, used to build distributed connection + kv_ip: str = "127.0.0.1" + + # The KV connector port, used to build distributed connection + kv_port: int = 14579 + + @classmethod + def from_cli(cls, cli_value: str) -> "KVTransferConfig": + """Parse the CLI value for the compilation config.""" + return KVTransferConfig.model_validate_json(cli_value) + + def model_post_init(self, __context: Any) -> None: + if all([ + self.kv_connector is not None, + self.kv_connector != "PyNcclConnector" + ]): + raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " + f"Supported connectors are " + f"`PyNcclConnector`.") + + if self.kv_role is not None and self.kv_role not in [ + "kv_producer", "kv_consumer", "kv_both" + ]: + raise ValueError( + f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are `kv_producer`, `kv_consumer`, " + f"and `kv_both`") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + "is set, supported roles are `kv_producer`, " + "`kv_consumer`, and `kv_both`") + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + + @property + def need_kv_parallel_group(self) -> bool: + # for those database-based connector, vLLM does not need to create + # parallel group, and in that case the kv parallel size will be 1. + return self.kv_connector is not None and self.kv_parallel_size > 1 + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_both"] + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_consumer", "kv_both"] + + class CompilationLevel: # constants for the levels of the compilation process NO_COMPILATION = 0 @@ -2108,13 +2206,15 @@ class CompilationConfig(BaseModel): - 'none,+op1,+op2' to enable only op1 and op2 By default, all custom ops are enabled when running without Inductor and disabled when running with Inductor (compile_level >= Inductor). + - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation. - CudaGraph capture: - use_cudagraph: whether to use cudagraph inside compilation. - False: cudagraph inside compilation is not used. - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses. - Note that this is orthogonal to the cudagraph capture out - side of compilation. + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. TODO: move outside cudagraph logic into compilation. torch.compile will handle cudagraph capture logic in the future. - cudagraph_capture_sizes: sizes to capture cudagraph. @@ -2148,12 +2248,7 @@ class CompilationConfig(BaseModel): name because the config uses json format. If we pass the config from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. + - custom inductor passes: see PassConfig for more details Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -2168,29 +2263,81 @@ class CompilationConfig(BaseModel): level: int = 0 backend: str = "" custom_ops: List[str] = Field(default_factory=list) + splitting_ops: List[str] = Field(default_factory=lambda: [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ]) use_inductor: bool = True inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) + inductor_compile_sizes: Optional[List[int]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) use_cudagraph: bool = False - non_cudagraph_ops: List[str] = Field(default_factory=list) cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None cudagraph_copy_inputs: bool = False - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True + class PassConfig(BaseModel): + """ + Configuration for custom Inductor passes. + This is separate from general CompilationConfig so that inductor passes + don't all have access to full configuration - that would create a cycle + as the PassManager is set as a property of config. + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graphs. Default is . + - enable_fusion: whether to enable the custom fusion pass. + - enable_reshape: whether to enable the custom reshape elimination pass. + TODO better pass enabling system. + """ + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + enable_reshape: bool = True + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + dict_ = self.model_dump( + include={"enable_fusion", "enable_reshape"}) + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).digest() + + def model_post_init(self, __context: Any) -> None: + if not self.enable_reshape and self.enable_fusion: + print_warning_once( + "Fusion enabled but reshape elimination disabled." + "RMSNorm + quant (fp8) fusion might not work") + + pass_config: PassConfig = Field(default_factory=PassConfig) # not configurable, computed after init compile_sizes: List[int] = PrivateAttr capture_sizes: List[int] = PrivateAttr + # keep track of enabled and disabled custom ops + enabled_custom_ops: Counter[str] = PrivateAttr + disabled_custom_ops: Counter[str] = PrivateAttr + + # Per-model forward context + # Mainly used to store attention cls + # Map from layer name to the attention cls + static_forward_context: Dict[str, Any] = PrivateAttr + + @classmethod + def from_cli(cls, cli_value: str) -> "CompilationConfig": + """Parse the CLI value for the compilation config.""" + if cli_value in ["0", "1", "2", "3"]: + return cls(level=int(cli_value)) + return CompilationConfig.model_validate_json(cli_value) + def model_post_init(self, __context: Any) -> None: - self.level = envs.VLLM_TORCH_COMPILE_LEVEL count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") @@ -2199,8 +2346,9 @@ def model_post_init(self, __context: Any) -> None: for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v + f"pass {k} should be callable or a qualified name") + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) continue # resolve function from qualified name @@ -2208,7 +2356,12 @@ def model_post_init(self, __context: Any) -> None: module = ".".join(names[:-1]) func_name = names[-1] func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) + + self.enabled_custom_ops = Counter() + self.disabled_custom_ops = Counter() + self.static_forward_context = {} def init_backend(self) -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: @@ -2256,31 +2409,10 @@ def init_during_runtime(self): if x <= self.inductor_specialize_for_cudagraph_no_more_than ] else: - assert self.inductor_compile_sizes is not None, ( - "inductor_compile_sizes should not be None when " - "inductor_specialize_for_cudagraph_no_more_than is None") + if self.inductor_compile_sizes is None: + self.inductor_compile_sizes = [] self.compile_sizes = self.inductor_compile_sizes - @staticmethod - def select_and_init_config() -> "CompilationConfig": - """The order of selecting config is: - 1. Use the config specified in environment variable. - 2. Use the config specified in plugins. - 3. Use the default config. - """ - config_path = envs.VLLM_TORCH_COMPILE_CONFIG - if config_path is not None: - with open(config_path) as json_file: - config = CompilationConfig.model_validate_json( - json_file.read()) - else: - from vllm.plugins import get_compilation_config - predefined_config = get_compilation_config() - config = predefined_config if predefined_config is not None else ( - CompilationConfig()) - - return config - @dataclass class VllmConfig: @@ -2290,10 +2422,10 @@ class VllmConfig: model_config: ModelConfig = field(default=None, init=True) # type: ignore cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore - scheduler_config: SchedulerConfig = field(default=None, - init=True) # type: ignore + parallel_config: ParallelConfig = field(default_factory=ParallelConfig, + init=True) + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig, + init=True) device_config: DeviceConfig = field(default=None, init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore @@ -2305,6 +2437,8 @@ class VllmConfig: quant_config: Optional[QuantizationConfig] = None compilation_config: CompilationConfig = field(default=None, init=True) # type: ignore + kv_transfer_config: KVTransferConfig = field(default=None, + init=True) # type: ignore @staticmethod def _get_quantization_config( @@ -2365,9 +2499,43 @@ def __post_init__(self): self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + if self.scheduler_config is not None and \ + self.model_config is not None and \ + self.scheduler_config.chunked_prefill_enabled and \ + self.model_config.dtype == torch.float32 and \ + current_platform.get_device_capability() == (7, 5): + print_warning_once( + "Turing devices tensor cores do not support float32 matmul. " + "To workaround this limitation, vLLM will set 'ieee' input " + "precision for chunked prefill triton kernels.") + if self.compilation_config is None: - self.compilation_config = CompilationConfig.select_and_init_config( - ) + self.compilation_config = CompilationConfig() + if envs.VLLM_USE_V1 and not self.model_config.enforce_eager: + # NOTE(woosuk): Currently, we use inductor because the piecewise + # CUDA graphs do not work properly with the custom CUDA kernels. + # FIXME(woosuk): Disable inductor to reduce the compilation time + # and avoid any potential issues with the inductor. + self.compilation_config.custom_ops = ["none"] + self.compilation_config.use_cudagraph = True + self.compilation_config.use_inductor = True + self.compilation_config.pass_config.enable_fusion = False + self.compilation_config.pass_config.enable_reshape = False + self.compilation_config.level = CompilationLevel.PIECEWISE + + if self.cache_config is not None and \ + self.cache_config.cpu_offload_gb > 0 and \ + self.compilation_config.level != CompilationLevel.NO_COMPILATION: + logger.warning( + "CPU offload is not supported with `torch.compile` yet." + " Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if self.lora_config is not None and self.compilation_config.level !=\ + CompilationLevel.NO_COMPILATION: + logger.warning("LoRA is not supported with `torch.compile` yet. " + "Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION current_platform.check_and_update_config(self) @@ -2411,3 +2579,53 @@ def __str__(self): self.cache_config.enable_prefix_caching, self.model_config.use_async_output_proc, self.model_config.mm_processor_kwargs) + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig): + """ + Temporarily set the current VLLM config. + Used during model initialization. + We save the current VLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the VLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen + try: + _current_vllm_config = vllm_config + yield + finally: + logger.debug("enabled custom ops: %s", + vllm_config.compilation_config.enabled_custom_ops) + logger.debug("disabled custom ops: %s", + vllm_config.compilation_config.disabled_custom_ops) + if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + and compilation_counter.num_models_seen == num_models_seen: + # If the model supports compilation, + # compilation_counter.num_models_seen should be increased + # by at least 1. + # If it is not increased, it means the model does not support + # compilation (does not have @support_torch_compile decorator). + logger.warning( + "`torch.compile` is turned on, but the model %s" + " does not support it. Please open an issue on GitHub" + "if you want it to be supported.", + vllm_config.model_config.model) + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current VLLM config is not set.") + from vllm.config import VllmConfig + return VllmConfig() + return _current_vllm_config diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 9727f6e19b84e..3197af3c2b7a4 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -306,14 +306,6 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: device = Device.GPU return self._allocators[device].mark_blocks_as_computed(block_ids) - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].get_computed_block_ids( - prev_computed_block_ids, block_ids, skip_last_block_id) - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: # Prefix caching only supported on GPU. @@ -342,6 +334,13 @@ def get_and_reset_swaps(self) -> List[Tuple[int, int]]: self._swap_mapping.clear() return list(mapping.items()) + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + return self._allocators[device].find_cached_blocks_prefix(block_hashes) + class NullBlock(Block): """ diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 72bbab1dcea5d..06f4851af3466 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -159,12 +159,6 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass - @abstractmethod - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - pass - @abstractmethod def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: @@ -192,6 +186,13 @@ def get_prefix_cache_hit_rate(self) -> float: class NoFreeBlocksError(ValueError): pass + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + ) -> List[int]: + pass + class DeviceAwareBlockAllocator(ABC): @@ -207,9 +208,12 @@ def allocate_immutable_block(self, prev_block: Optional[Block], pass @abstractmethod - def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device) -> List[Block]: + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + ) -> List[Block]: pass @abstractmethod @@ -246,12 +250,6 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass - @abstractmethod - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - pass - @abstractmethod def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: @@ -284,3 +282,11 @@ def allocate_or_get_null_block(self) -> Block: def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 9341a518d11c6..a2af5ad6362c1 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -262,13 +262,6 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """ pass - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - """No prefix caching here => return empty list - """ - return [] - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: """Determine blocks that can be skipped in prefill. @@ -329,6 +322,10 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + # Not applicable for naive block allocator. + return [] + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 57527e39b9bdd..b736167f6ceb4 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,13 +1,18 @@ """Token blocks.""" +import sys +from bisect import bisect_left from os.path import commonprefix -from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple +from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, + Tuple) from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device, + DeviceAwareBlockAllocator) from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.sequence import Sequence PrefixHash = int @@ -534,26 +539,6 @@ def block_is_computed(self, block_id: int) -> bool: else: return block_id in self.evictor - def get_computed_block_ids(self, - prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool = True) -> List[int]: - prev_prefix_size = len(prev_computed_block_ids) - cur_size = len(block_ids) - if skip_last_block_id: - cur_size -= 1 - - # Sanity checks - assert cur_size >= 0 - assert prev_prefix_size <= cur_size - - ret = prev_computed_block_ids - for i in range(prev_prefix_size, cur_size): - block_id = block_ids[i] - if self.block_is_computed(block_id): - ret.append(block_id) - return ret - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: """Return the block ids that are common for a given sequence group. @@ -634,6 +619,47 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + """ + Given a list of block hashes, return the prefix of the block hashes that + are all cached. + + Since a block's block hash includes the hashes of all previous blocks, + and we only allocate/deallocate blocks in the entire sequence, so if a + block is cached, then all previous blocks are also cached. With this + property, we can use binary search to find the prefix of cached blocks. + + Args: + block_hashes (List[int]): The list of block hashes. + + Returns: + List[int]: The prefix of the `block_hashes` that are cached. + """ + + def _block_is_cached(block_hash: PrefixHash) -> bool: + if block_hash not in self._cached_blocks: + return False + + cached_block_id = self._cached_blocks[block_hash] + # We only consider the blocks that are marked as computed. + return self.block_is_computed(cached_block_id) + + def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int: + + # python <= 3.10 don't have the key argument + if sys.version_info < (3, 10): + a = [key(e) for e in a] + return bisect_left(a, x) + else: + return bisect_left(a, x, key=key) + + # Look for the first block that's not cached, and returns the prefix + # i.e. blocks that are cached. + idx = _bisect_left(block_hashes, + True, + key=lambda x: not _block_is_cached(x)) + return block_hashes[:idx] + class PrefixCachingBlock(Block): """A block implementation that supports prefix caching. @@ -843,86 +869,126 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], class ComputedBlocksTracker: - """Handles caching of per-sequence computed block ids. - When a sequence appears for the first time, it traverses all of the - blocks and detects the prefix of blocks that is computed. On the - subsequent times, it only traverses the new blocks that were added - and updates the already recorded prefix of blocks with the newly - computed blocks. - - To avoid redundant traversals, the algorithm also detects when there - is a "gap" in the computed prefix. For example, if we have blocks = - [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then - we won't try to add more computed blocks to [1,2,3] in this sequence - iteration, and will add more computed blocks only after the sequence is - freed and reused again. - - Note that currently, for a given sequence, we also skip the last - block id for caching purposes, to avoid caching of a full sequence """ + Tracks the computed blocks for each sequence. - def __init__(self, allocator): - self._allocator = allocator - self._cached_computed_seq_blocks: Dict[int, Tuple[List[int], - bool]] = {} + Internally, it maintains a map from sequence id to the list of block hashes + for the sequence. We cache the hashes of the full blocks for each sequence, + and make sure the hash is calculated in the same way as the allocator. + When a sequence is being decoded, we also update the sequence's hash + accordingly and incrementally. - def add_seq(self, seq_id: int) -> None: - """Start tracking seq_id - """ - assert seq_id not in self._cached_computed_seq_blocks - self._cached_computed_seq_blocks[seq_id] = ([], False) - - def remove_seq(self, seq_id: int) -> None: - """Stop tracking seq_id - """ - assert seq_id in self._cached_computed_seq_blocks - del self._cached_computed_seq_blocks[seq_id] - - def get_cached_computed_blocks_and_update( - self, seq_id: int, block_ids: List[int]) -> List[int]: - """ Look at the class documentation for details - """ - # Ensure seq_id is already tracked - assert seq_id in self._cached_computed_seq_blocks - - # Get cached data (may be empty on the first time) - prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[ - seq_id] - - if has_gap: - # When gap is detected, we do not add more computed blocks at this - # sequence iteration - return prev_computed_block_ids - - # We do not consider the last block id for caching purposes. - num_cur_blocks = len(block_ids) - 1 - assert num_cur_blocks >= 0 - - if len(prev_computed_block_ids) >= num_cur_blocks: - # Cache HIT - assert len(prev_computed_block_ids) == num_cur_blocks - return prev_computed_block_ids - - # If here, then we may possibly add more computed blocks. As a result, - # traverse the additional blocks after prev_computed_block_ids to - # detect more computed blocks and add them. - - # Incremental init for seq_id => Look only at the new blocks - computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501 - prev_computed_block_ids, - block_ids, - skip_last_block_id= - True, # We skip last block id to avoid caching of full seq - ) + From the sequence hash, with prefix caching enabled, we could also calculate + the number of cached tokens for the sequence by looking up the number of + cached block hashes in the allocator. + """ - # Detect if there is a "gap" - has_gap = len(computed_block_ids) < num_cur_blocks + def __init__( + self, + allocator: DeviceAwareBlockAllocator, + block_size: int, + enable_caching: bool, + ): + self._allocator = allocator + self._block_size = block_size + self._enable_caching = enable_caching + + # A map from seq_id to the list of block hashes for the + # sequence. This is so that we don't have to recompute the block hashes + # for the sequence when we need to check if the sequence is cached. + # Note a block that's not full will not have its hash calculated and + # recorded. + self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {} + + # A map from seq_id to the number of tokens that are cached for the + # sequence. + # We need this so that a sequence in continuous prefill doesn't + # accidentally see its cached token count change. See comments in + # `get_num_cached_tokens` for more details. + self._seq_id_to_num_tokens_computed: Dict[int, int] = {} + + def _update_seq_hashes(self, seq: Sequence) -> None: + """Incrementally update the sequence's block hashes and record them.""" + assert self._enable_caching + + block_hashes_recorded = self._seq_id_to_blocks_hashes.get( + seq.seq_id, []) + cur_num_blocks_recorded = len(block_hashes_recorded) + token_ids = seq.get_token_ids() + assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( + f"The sequence has {len(token_ids)} tokens, but" + f" already recorded {cur_num_blocks_recorded} blocks. " + "This should not happen since we assume blocks are " + "only appended other than recomputation. When the sequence is " + "recomputed, we should have removed the info of the old blocks.") + # Update the computed block hashes for the sequence. Since only full + # blocks are considered as "computed", we take floor here. + num_computed_blocks = len(token_ids) // self._block_size + + # We need to know the hash of the previous block to compute the hash of + # the current block so that blocks could be uniquely identified across + # sequences of prefixes. + prev_block_hash = (None if cur_num_blocks_recorded == 0 else + block_hashes_recorded[-1]) + # Only update the computed block hashes for the new blocks + for i in range(cur_num_blocks_recorded, num_computed_blocks): + assert len(token_ids) >= (i + 1) * self._block_size + block_token_ids = token_ids[i * self._block_size:(i + 1) * + self._block_size] + # This has to be kept in sync with the allocator's hash + # calculation. + block_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block=prev_block_hash is None, + prev_block_hash=prev_block_hash, + cur_block_token_ids=block_token_ids, + ) + block_hashes_recorded.append(block_hash) + prev_block_hash = block_hash + + self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded + + def get_num_cached_tokens(self, seq: Sequence) -> int: + if not self._enable_caching: + return 0 + + # We always try to update the sequence hashes on the fly. + # This is to ensure that we don't miss any cached tokens for the + # sequence during decode. + # This routine should only update hash for any new blocks too. + self._update_seq_hashes(seq) + + num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get( + seq.seq_id, None) + + # TODO(rickyx): This hack could be removed once we mark blocks as + # computed correctly with chunked prefills. + if num_computed_tokens_prev is not None and seq.is_prefill(): + # For a sequence that is still in prefill, we don't + # recompute the number of cached tokens. + # This also handles correctly chunked prefill since currently + # we mark blocks as computed even if the sequence is still partially + # prefilled. So a continuously prefilled sequence should not + # see its cached token count change while running. + return num_computed_tokens_prev + + block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id] + + # This is O(logN), where N is the number of blocks. + num_cached_blocks = len( + self._allocator.find_cached_blocks_prefix(block_hashes)) + num_cached_tokens = num_cached_blocks * self._block_size + self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens + return num_cached_tokens - # Record - self._cached_computed_seq_blocks[seq_id] = (computed_block_ids, - has_gap) + def remove_seq(self, seq_id: int) -> None: + """Stop tracking the sequence.""" + if not self._enable_caching: + return + assert seq_id in self._seq_id_to_blocks_hashes + del self._seq_id_to_blocks_hashes[seq_id] - return computed_block_ids + assert seq_id in self._seq_id_to_num_tokens_computed + del self._seq_id_to_num_tokens_computed[seq_id] class LastAccessBlocksTracker: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 21f4c63b6572d..209487c6b4f9e 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -101,7 +101,7 @@ def __init__( self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator) + self.block_allocator, self.block_size, self.enable_caching) self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) @@ -170,7 +170,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id) # Assign the block table for each sequence. @@ -178,7 +177,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table.fork() # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id) # Allocate cross-attention block table for encoder sequence @@ -314,11 +312,13 @@ def get_common_computed_block_ids( """ computed_seq_block_ids = [] for seq in seqs: - computed_seq_block_ids.append( - self._computed_blocks_tracker. - get_cached_computed_blocks_and_update( - seq.seq_id, - self.block_tables[seq.seq_id].physical_block_ids)) + all_blocks = self.block_tables[seq.seq_id].physical_block_ids + num_cached_tokens = ( + self._computed_blocks_tracker.get_num_cached_tokens(seq)) + assert num_cached_tokens % self.block_size == 0 + num_cached_blocks = num_cached_tokens // self.block_size + computed_block_ids = all_blocks[:num_cached_blocks] + computed_seq_block_ids.append(computed_block_ids) # NOTE(sang): This assumes seq_block_ids doesn't contain any None. return self.block_allocator.get_common_computed_block_ids( @@ -332,7 +332,6 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_tables[child_seq.seq_id] = src_block_table.fork() # Track child seq - self._computed_blocks_tracker.add_seq(child_seq.seq_id) self._last_access_blocks_tracker.add_seq(child_seq.seq_id) def can_swap_in(self, seq_group: SequenceGroup, @@ -503,3 +502,9 @@ def _can_swap(self, return AllocStatus.OK else: return AllocStatus.LATER + + def get_num_cached_tokens(self, seq: Sequence) -> int: + """Get the number of tokens in blocks that are already computed and + cached in the block manager for the sequence. + """ + return self._computed_blocks_tracker.get_num_cached_tokens(seq) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 9501a516bf020..b10b8d3f4a5bf 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -121,3 +121,7 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + + @abstractmethod + def get_num_cached_tokens(self, seq: Sequence) -> int: + pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index a337392bbed53..26d42b7f1790e 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -89,3 +89,6 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def get_prefix_cache_hit_rate(self, device: Device) -> float: return -1 + + def get_num_cached_tokens(self, seq: Sequence) -> int: + return 0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index af4671ec29be9..d23009dae01ee 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -56,11 +56,16 @@ class SchedulingBudget: max_num_seqs: int _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + # Number of cached tokens in the batch. + _num_cached_tokens: int = 0 + # Number of actual non-cached tokens in the batch. _num_batched_tokens: int = 0 _num_curr_seqs: int = 0 def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): - assert num_new_tokens != 0 + # We allow num_new_tokens to be 0 when the entire sequence has + # been cached. + assert num_new_tokens >= 0 assert num_new_seqs != 0 return (self.num_batched_tokens + num_new_tokens <= self.token_budget and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) @@ -68,12 +73,18 @@ def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): def remaining_token_budget(self): return self.token_budget - self.num_batched_tokens - def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + def add_num_batched_tokens(self, + req_id: str, + num_batched_tokens: int, + num_cached_tokens: int = 0): if req_id in self._request_ids_num_batched_tokens: return + assert num_cached_tokens >= 0 + assert num_batched_tokens >= 0 self._request_ids_num_batched_tokens.add(req_id) self._num_batched_tokens += num_batched_tokens + self._num_cached_tokens += num_cached_tokens def subtract_num_batched_tokens(self, req_id: str, num_batched_tokens: int): @@ -101,6 +112,10 @@ def num_batched_tokens(self): def num_curr_seqs(self): return self._num_curr_seqs + @property + def num_cached_tokens(self): + return self._num_cached_tokens + @dataclass class ScheduledSequenceGroup: @@ -541,9 +556,19 @@ def _schedule_running( assert len(self._async_stopped) == 0 while running_queue: seq_group = running_queue[0] - num_running_tokens = self._get_num_new_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget) - + # We discard the cached tokens info here because we don't need it + # for running sequence: + # 1. If a sequence is running with chunked prefill, the cached + # tokens info was already used for the first prefill. + # 2. If a sequence is running with non-chunked prefill, then + # there it's a decoding sequence, and the cached tokens info is + # irrelevant. + num_uncached_new_tokens, _ = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.RUNNING, enable_chunking, + budget)) + + num_running_tokens = num_uncached_new_tokens if num_running_tokens == 0: # No budget => Stop break @@ -715,13 +740,15 @@ def _schedule_swapped( # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.SWAPPED, - enable_chunking, budget) - - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.SWAPPED, enable_chunking, + budget)) + + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): break if lora_int_id > 0 and curr_loras is not None: @@ -732,12 +759,19 @@ def _schedule_swapped( is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( - ScheduledSequenceGroup(seq_group, - token_chunk_size=num_new_tokens)) + ScheduledSequenceGroup( + seq_group, + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) else: decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) swapped_queue.extendleft(leftover_swapped) @@ -803,26 +837,30 @@ def _schedule_priority_preemption( if waiting_queue: seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - False, budget) + num_new_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, False, budget)) #Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): #Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) - if (num_new_tokens and can_allocate == AllocStatus.OK - and budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if (num_new_tokens_uncached > 0 + and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): break #Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() - num_running_tokens = self._get_num_new_tokens( - vseq_group, SequenceStatus.RUNNING, False, budget) - budget.subtract_num_batched_tokens(vseq_group.request_id, - num_running_tokens) + num_running_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget)) + budget.subtract_num_batched_tokens( + vseq_group.request_id, num_running_tokens_uncached) num_running_seqs = vseq_group.get_max_num_running_seqs() budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) @@ -882,9 +920,12 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - enable_chunking, budget) + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, enable_chunking, + budget)) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + if not enable_chunking: num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens @@ -935,10 +976,18 @@ def _schedule_prefills( waiting_queue.popleft() continue + if (budget.num_batched_tokens >= + self.scheduler_config.max_num_batched_tokens): + # We've reached the budget limit - since there might be + # continuous prefills in the running queue, we should break + # to avoid scheduling any new prefills. + break + num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): break # Can schedule this request. @@ -967,7 +1016,11 @@ def _schedule_prefills( seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) # Queue requests that couldn't be scheduled. @@ -1075,7 +1128,8 @@ def _schedule_default(self) -> SchedulerOutputs: return SchedulerOutputs( scheduled_seq_groups=scheduled_seq_groups, num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=blocks_to_copy, @@ -1119,7 +1173,6 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: running_scheduled.swapped_out) == 0: swapped_in = self._schedule_swapped(budget, curr_loras) - # Schedule new prefills. prefills = self._schedule_prefills(budget, curr_loras, enable_chunking=True) @@ -1148,23 +1201,34 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) # Put prefills first due to Attention backend ordering assumption. + scheduled_seq_groups = (prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups) + num_prefill_groups = (len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)) + # If all prompts, then we set num_lookahead_slots to 0 + # this allows us to go through the `no_spec` path in + # `spec_decode_worker.py` + all_prefills = (len(scheduled_seq_groups) == num_prefill_groups) + num_lookahead_slots = (0 if + (all_prefills + and not self.scheduler_config.is_multi_step) + else running_scheduled.num_lookahead_slots) return SchedulerOutputs( - scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups), - num_prefill_groups=(len(prefills.seq_groups) + - len(swapped_in.prefill_seq_groups) + - len(running_scheduled.prefill_seq_groups)), - num_batched_tokens=budget.num_batched_tokens, + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=running_scheduled.blocks_to_copy + swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, - num_lookahead_slots=running_scheduled.num_lookahead_slots, + num_lookahead_slots=num_lookahead_slots, running_queue_size=len(self.running), preempted=(len(running_scheduled.preempted) + len(running_scheduled.swapped_out)), @@ -1303,6 +1367,7 @@ def schedule( encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, state=seq_group.state, + token_type_ids=seq_group.token_type_ids, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but @@ -1584,64 +1649,178 @@ def _get_num_lookahead_slots(self, is_prefill: bool, return self.scheduler_config.num_lookahead_slots - def _get_num_new_tokens(self, seq_group: SequenceGroup, - status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> int: - """Get the next new tokens to compute for a given sequence group - that's in a given `status`. + def _get_num_new_uncached_and_cached_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + ) -> Tuple[int, int]: + """ + Returns the number of new uncached and cached tokens to schedule for a + given sequence group that's in a given `status`. The API could chunk the number of tokens to compute based on `budget` if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. - Returns 0 if the new token cannot be computed due to token budget. + Returns (0, 0) if the new token cannot be computed due to token budget. + + The cached tokens's blocks are already computed, and the attention + backend will reuse the cached blocks rather than recomputing them. So + the scheduler could schedule these cached tokens "for free". + + Args: + seq_group: The sequence group to get the number of new tokens to + schedule. + status: The status of the sequences to get the number of new tokens + to schedule. + enable_chunking: Whether to chunk the number of tokens to compute. + budget: The budget to chunk the number of tokens to compute. + + + Returns: + A tuple of two ints. The first int is the number of new uncached + tokens to schedule. The second int is the number of cached tokens. + If no more new tokens can be scheduled, returns (0, 0). """ - num_new_tokens = 0 + num_cached_new_tokens = 0 + num_uncached_new_tokens = 0 + seqs = seq_group.get_seqs(status=status) + # Compute the number of new uncached and cached tokens for + # each sequence. for seq in seqs: - num_new_tokens += seq.get_num_new_tokens() - assert num_new_tokens > 0 - # Chunk if a running request cannot fit in the given budget. - # If number of seq > 1, it means it is doing beam search - # in a decode phase. Do not chunk. + if not seq.is_prefill(): + # Decode sequences should always just have 1 uncached token + # TODO(rickyx): Actually is this still correct for multi-step? + num_uncached_new_tokens += 1 + continue + + num_computed_tokens_seq = seq.get_num_computed_tokens() + all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq + if not self.cache_config.enable_prefix_caching: + # If prefix caching is not enabled, all new tokens are uncached. + num_uncached_new_tokens += all_num_new_tokens_seq + continue + + # NOTE: the cache token might be currently in a block that's in an + # evictor meaning that it's not yet allocated. However, we don't + # exclude such tokens in the cache count because it will be + # guaranteed to be allocated later if the sequence can be allocated. + num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( + seq) + + # Sanity check. + if num_cached_tokens_seq < num_computed_tokens_seq: + # This should only happen with chunked prefill, and + # the seq is still in prefill. The `num_cached_tokens_seq` + # is the value we calculated on scheduling the first prefill. + # For subsequent continuous prefill steps, we cached the + # number of cache tokens for the sequence so the cached token + # count could be less than the number of computed tokens. + # See comments on `ComputedBlocksTracker` for more details. + assert ( + seq.is_prefill() and seq.status == SequenceStatus.RUNNING + and self.scheduler_config.chunked_prefill_enabled + ), ("Number of cached tokens should not be less than the " + "number of computed tokens for a sequence that's still " + f"in prefill. But there are {num_cached_tokens_seq} cached " + f"tokens and {num_computed_tokens_seq} computed tokens " + f"for sequence {seq.seq_id}.") + + num_cached_new_tokens_seq = max( + 0, num_cached_tokens_seq - num_computed_tokens_seq) + num_uncached_new_tokens_seq = (all_num_new_tokens_seq - + num_cached_new_tokens_seq) + + num_uncached_new_tokens += num_uncached_new_tokens_seq + num_cached_new_tokens += num_cached_new_tokens_seq + + if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: + # For a fully cached hit sequence, we actually need to recompute the + # last token. So we need at least 1 uncached token to schedule. + # See ModelRunner._compute_for_prefix_cache_hit for more details. + num_uncached_new_tokens = 1 + num_cached_new_tokens -= 1 + if enable_chunking and len(seqs) == 1: - remaining_token_budget = budget.remaining_token_budget() - if self.scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens - elif self.cache_config.enable_prefix_caching: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block - # size to avoid partial block matching. - block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size - else: - num_new_tokens = min(num_new_tokens, remaining_token_budget) + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( + self.scheduler_config, + self.cache_config, + budget, + self._get_prompt_limit(seq_group), + num_uncached_new_tokens, + ) + + return num_uncached_new_tokens, num_cached_new_tokens + + @staticmethod + def _chunk_new_tokens_to_schedule( + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + budget: SchedulingBudget, + prompt_limit: int, + num_new_tokens: int, + ) -> int: + """ + Chunks the number of new tokens to schedule based on the budget when + chunked prefill is enabled. + + Args: + scheduler_config: The scheduler config. + cache_config: The cache config. + budget: The budget to chunk the number of tokens to compute. + prompt_limit: The maximum number of tokens allowed in a prompt. + num_new_tokens: The number of new tokens to schedule. + + Returns: + The number of new tokens to schedule after chunking. + """ + remaining_token_budget = budget.remaining_token_budget() + if scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > prompt_limit: + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + return num_new_tokens + + return (0 if num_new_tokens > remaining_token_budget else + num_new_tokens) + + if cache_config.enable_prefix_caching: + # Adjust the remaining token budget to be divisible by the block + # size when prefix caching is enabled. + + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. + block_size = cache_config.block_size + remainder = budget.token_budget % block_size + if remainder != 0: + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") + # Round down to block size. + remaining_token_budget = (remaining_token_budget // block_size * + block_size) + + num_new_tokens = min(num_new_tokens, remaining_token_budget) + return num_new_tokens diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 7c6f48e88637b..d4e3f81747038 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -106,30 +106,72 @@ def __init__( self.stream.synchronize() del data - # by default it is disabled, e.g. in profiling models and prefill phase. - # to use it, use under `with obj.change_state(enable=True)`, usually - # when we are using CUDA graph. - self.disabled = True - def all_reduce(self, - tensor: torch.Tensor, + in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, - stream=None): + stream=None) -> torch.Tensor: if self.disabled: - return + return None # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" - assert tensor.device == self.device, ( + assert in_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + if stream is None: stream = self.stream - self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), - buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + return out_tensor + + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def reduce_scatter(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 7619c98f22148..ff88f72470b27 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -151,6 +151,28 @@ class NCCLLibrary: ncclRedOp_t, ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); @@ -258,6 +280,28 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, datatype, op, comm, stream)) + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md new file mode 100644 index 0000000000000..dab2d10c4c9d0 --- /dev/null +++ b/vllm/distributed/kv_transfer/README.md @@ -0,0 +1,30 @@ + +# Distributed KV cache transfer + +This folder implements distributed KV cache transfer across vLLM instances. +Currently the main usecase is for disaggregated prefilling. + +## Abstractions + +The KV cache transfer contains three layer of abstractions: + +- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`. +- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics). +- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`. + +Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. + +NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +communication service already supports key-value-based lookup (like redis or +RDMA database). + +NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. + +## Disaggregated prefilling + +The example usage is in [this file](../../../examples/disaggregated_prefill.sh). + +Here is the diagram of how we run disaggretgated prefilling. + +![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) + diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg new file mode 100644 index 0000000000000..a25ec5ef52491 Binary files /dev/null and b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg differ diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py new file mode 100644 index 0000000000000..6089e3babac3e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -0,0 +1,122 @@ +""" +KVConnectorBase Class for Distributed KV Cache & Hidden State communication + +The class provides two primary abstract methods: +1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states +2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class KVConnectorBase(ABC): + """ + Abstract base class for a KV connector. + + The class provides two primary abstract methods: + 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states + 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + """ + Send KV caches and hidden states to the connector. + + This method processes the input tokens, KV caches, and + hidden/intermediate states for a given model and sends the data to the + decode instance. + + Args: + model_executable (torch.nn.Module): The model executable containing + start and end layer information. + model_input (ModelInputForGPUWithSamplingMetadata): The input + metadata from vLLM. + kv_caches (List[torch.Tensor]): List of KV caches (keys and values) + for each layer. + hidden_or_intermediate_states (Union[torch.Tensor, + IntermediateTensors]): + The hidden or intermediate states associated with the tokens. + + Returns: + None + + """ + + raise NotImplementedError + + @abstractmethod + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + """ + Receive KV caches and hidden states from the connector. + + This method attempts to retrieve KV caches and hidden states for input + tokens. If all required KV caches and hidden states are received, it + will bypass model input, else it will fall back to normal vLLM model + forwarding. + + Args: + model_executable (torch.nn.Module): + The model executable from vLLM modelrunner. + model_input (ModelInputForGPUWithSamplingMetadata): + The model input from vLLM modelrunner. + kv_caches (List[torch.Tensor]): + List of KV caches for each layer. + + Returns: + - hidden_or_intermediate_states (torch.Tensor or + IntermediateTensors): + Concatenated hidden states if all required data is retrieved, + otherwise `None`. + - bypass_model_exec (bool): + Indicates whether the model execution can be skipped (True) or + needs to be redone (False). + - model_input (ModelInputForGPUWithSamplingMetadata): + Optionally adjusted input metadata for re-execution when + `bypass_model_exec=False`. + + """ + + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 0000000000000..015f892cec933 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from .base import KVConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class KVConnectorFactory: + + @staticmethod + def create_connector(rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + if config.kv_transfer_config.kv_connector == 'PyNcclConnector': + from .simple_connector import SimpleConnector + return SimpleConnector(rank, local_rank, config) + else: + raise ValueError(f"Unsupported connector type: " + f"{config.kv_connector}") diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py new file mode 100644 index 0000000000000..5870070a54c75 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -0,0 +1,261 @@ +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe. + +But the logic can be extended to support other pipe and lookup buffer. +""" +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class SimpleConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.config = config.kv_transfer_config + + logger.info("Initializing PyNcclConfig under kv_transfer_config %s", + self.config) + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_buffer: Optional[SimpleBuffer] = None + self.consumer_buffer: Optional[SimpleBuffer] = None + + # 2 pipes for every rank in the world + port_offset_base = 2 * rank + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + if self.config.is_kv_producer: + + self.producer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.producer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + self.config.kv_buffer_size) + + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + self.consumer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.consumer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + self.consumer_buffer = SimpleBuffer( + self.consumer_signal_pipe, + self.consumer_data_pipe, + self.config.kv_buffer_size, + ) + + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.consumer_buffer is not None, "Please initialize the "\ + "consumer buffer before calling select." + return self.consumer_buffer.drop_select(input_tokens, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + assert self.producer_buffer is not None, "Please initialize the "\ + "producer buffer before calling insert." + + self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + _, _, num_heads, head_size = kv_cache[0].shape + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self.select(current_tokens, + torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. + bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.producer_data_pipe.close() + self.producer_signal_pipe.close() + self.consumer_data_pipe.close() + self.consumer_signal_pipe.close() diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py new file mode 100644 index 0000000000000..bad119a1aa929 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -0,0 +1,108 @@ +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +All distributed communications are abstracted behind this class. +""" + +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch + + +class KVLookupBufferBase(ABC): + """ + Abstract base class for a lookup buffer. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ + + @abstractmethod + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + """Insert into the lookup buffer. + + The functionality is similar to the following python statement + ``` + buffer[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + """Select and *drop* KV cache entries from the lookup buffer. + + The functionality is similar to the following python statements + ``` + ret = buffer.pop(input_tokens, roi) + return ret + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + List[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + lookup buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py new file mode 100644 index 0000000000000..fe8d8d7375f36 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -0,0 +1,242 @@ +""" + Implements a distributed key-value (KV) cache transfer mechanism. + + Key Features: + - Distributed KV cache transmission using PyNccl pipes. + - Non-blocking `insert`, blocking `drop_select`. + - Use CPU signal pipe to avoid racing condition + - Handles buffer size constraints and provide backpressure mechanism to + stop the prefill instance when the decode instance is slow. +""" +import threading +import time +from collections import deque +from typing import Deque, List, Optional, Union + +import torch + +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SimpleBuffer(KVLookupBufferBase): + + def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, + buffer_size_thresh: float): + """ + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. GPU) + """ + + self.buffer: Deque[List[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_lock = threading.Lock() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0], device="cpu") + self.end_signal = None + + def _matches(self, tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + if tensor.dtype == torch.bool: + tensor = tensor.float() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError(f"Unknown data type {type(data)}") + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + + with self.buffer_lock: + for data in buffer_item: + self.buffer_size += self._get_element_size(data) + self.buffer.append(buffer_item) + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [input_tokens, roi] + + matched_length = 0 + + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + with self.buffer_lock: + + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None) + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone().float() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + if roi is not None: + # convert from float tensor to bool tensor + # as PyNccl does not support sending bool tensor + roi = (roi > 0.5) + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def full_handler(self): + time.sleep(0.001) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + if self.buffer_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size > self.buffer_size_threshold: + self.full_handler() + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. + if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + def close(self): + + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester + self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py new file mode 100644 index 0000000000000..4b0cb44cc5b81 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -0,0 +1,65 @@ +""" +This file defines an interface `KVPipeBase` +that provides an abstraction for sending and receiving tensors, or None, via +distributed communications. + +All classes instantiated from this interface are assumed to be a FIFO pipe. + +If your distributed communication platform already supports key-value lookup, +you can bypass this interface and directly start from `kv_lookup_buffer`. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class KVPipeBase(ABC): + """ + This class provides an interface for sending and receiving tensors, or + None, by distributed communications. + """ + + @abstractmethod + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send a tensor, or None, via the pipe. + + Need to support sending None -- important for error handling. + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind + the pipe. + + Args: + tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive a tensor (can be None) from the pipeline. + + Returns: + Optional[torch.Tensor]: The tensor received from the pipeline. Can + be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the pipeline and release resources. + + This method is responsible for closing the communication pipeline + and releasing any resources associated with it. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py new file mode 100644 index 0000000000000..98222fa67e492 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -0,0 +1,276 @@ +""" + This module implements a PyNccl pipe for sending and receiving + Optional[torch.Tensor] between distributed ranks with advanced + communication features. + + Key Features: + - Supports sending and receiving tensors with metadata + - Handles both CUDA and CPU device communications + - Implements a non-blocking tensor transfer mechanism + - Manages buffer size and provides backpressure control + - Supports distributed process groups with configurable parameters +""" + +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Dict, Optional, Tuple + +import torch + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +Metadata = Dict[str, Optional[torch.Tensor]] + + +class PyNcclPipe(KVPipeBase): + + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None, + port_offset: int = 0): + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + self.kv_parallel_size = self.config.kv_parallel_size + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + # build distributed connection and send/recv implementation + self.group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port + port_offset, + rank=self.kv_rank, + world_size=self.kv_parallel_size, + ) + # add a barrier to make sure the connection is initiated properly + self.group.barrier() + impl = self._get_device_send_recv_impl(self.group) + self.device_send_func, self.device_recv_func = impl + # set target rank + self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size + self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size + + # transportation-related variables + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + self.buffer_size_thresh = self.config.kv_buffer_size + + def _get_device_send_recv_impl( + self, group: StatelessProcessGroup + ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[ + [torch.Tensor, int], None]]: + + send: Callable[[torch.Tensor, int], None] + recv: Callable[[torch.Tensor, int], None] + if self.device.type == "cuda": + # use PyNCCL for send / recv + comm = PyNcclCommunicator(group, device=self.local_rank) + comm.disabled = False + send, recv = comm.send, comm.recv # type: ignore + else: + # This send / recv implementation here is NOT intended to transfer + # KV caches (and should NOT be repurposed to transfer KV caches). + # Currently it is only used to transmit control-plane messages + # for PyNcclBuffer. + send = group.send_obj + + def my_recv(x, src): + x[...] = group.recv_obj(src) + + recv = my_recv + + return send, recv + + def _select_device(self, device: str): + logger.info("Selecting device: %s", device) + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: + """ + Create the metadata as a dictionary based on the input tensor. + + Parameters: + - tensor: The input tensor or None if no tensor is provided. + + Returns: + - metadata: A dictionary with the following keys: + - "dtype": The data type of the tensor or None. + - "shape": The shape of the tensor or None. + """ + if tensor is None: + return {"dtype": None, "shape": None} + else: + return {"dtype": tensor.dtype, "shape": tensor.shape} + + def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the provided metadata. + + Parameters: + - metadata: A dictionary with keys "dtype" and "shape", describing + the tensor's data type and shape. + + Returns: + - buffer: A tensor of the specified type and shape, allocated on + self.device. + """ + return torch.empty(metadata["shape"], + dtype=metadata["dtype"], + device=self.device) + + def _send_metadata(self, metadata: Metadata): + """ + Send the metadata dictionary to the target rank. + + Parameters: + - metadata: A dictionary with keys "dtype" and "shape". + """ + self.group.send_obj(metadata, self.target_rank_for_send) + + def _recv_metadata(self) -> Metadata: + """ + Receive the metadata dictionary from the target rank. + + Returns: + - metadata: A dictionary with keys "dtype" and "shape" describing + the tensor. + """ + return self.group.recv_obj(self.target_rank_for_recv) + + def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + """ + The actual implementation of sending the tensor and its metadata to the + target rank. + + Parameters: + - tensor: The input tensor to be sent, or None if no tensor is + being sent. + """ + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + if tensor is not None: + self.device_send_func(tensor.to(self.device), + self.target_rank_for_send) + + def _recv_impl(self) -> Optional[torch.Tensor]: + """ + The actual implementation of receiving a tensor and its metadata from + the target rank. + + Returns: + - buffer: The received tensor, or None if no tensor is received. + """ + metadata = self._recv_metadata() + if metadata["dtype"] is None: + return None + buffer = self._prepare_recv_buffer(metadata) + self.device_recv_func(buffer, self.target_rank_for_recv) + + return buffer + + def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], + tensor_size: int) -> None: + """ + Wrapper for _send_impl to handle exceptions and update buffer size. + """ + try: + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size -= tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """ + Block the current thread if the buffer size is larger than the + threshold. + """ + while self.buffer_size > self.buffer_size_thresh: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """ + Sends a tensor and its metadata to the destination rank in a + non-blocking way. + + Parameters: + - tensor: The tensor to send, or None if no tensor is being sent. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is not None: + tensor_size = tensor.element_size() * tensor.numel() + else: + tensor_size = 0 + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size += tensor_size + + self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """ + Receives a tensor and its metadata from the source rank. Blocking call. + + Returns: + - tensor: The received tensor, or None if no tensor is received. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + logger.error("My device: %s", self.device) + import traceback + traceback.print_exc() + raise e + + return tensor + + def close(self): + """ + Close the pipe and release associated resources. + """ + if hasattr(self, + "transport_thread") and self.transport_thread is not None: + self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py new file mode 100644 index 0000000000000..9ce97851dc849 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -0,0 +1,75 @@ +"""A centralized entrypoint to perform distributed KV cache transfer. + +This implementation is a shim wrapper on two APIs exposed by `kv_connector`: +1. `send_kv_caches_and_hidden_states` +2. `recv_kv_caches_and_hidden_states +""" +from typing import TYPE_CHECKING, List, Tuple, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.config import VllmConfig + +import torch + +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +class KVTransferAgent: + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + + self.config = config + + if config.kv_transfer_config is None: + raise ValueError("KVTransferConfig is not set in the VllmConfig," + " cannot initialize KVConnector.") + + assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ + "TransferAgent should only be used when kv_connector is set." + + self.connector = KVConnectorFactory.create_connector( + rank, local_rank, config) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + self.connector.send_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches, + hidden_or_intermediate_states) + + def close(self) -> None: + self.connector.close() + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + return self.connector.recv_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 87ade377266a2..34815d7f0aa78 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -27,18 +27,23 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op +if TYPE_CHECKING: + from vllm.config import VllmConfig + @dataclass class GraphCaptureContext: @@ -96,42 +101,24 @@ def _register_group(group: "GroupCoordinator") -> None: _groups[group.unique_name] = weakref.ref(group) -if supports_custom_op(): - - def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - group._all_reduce_in_place(tensor) +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) - def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: - return - - direct_register_custom_op( - op_name="inplace_all_reduce", - op_func=inplace_all_reduce, - mutates_args=["tensor"], - fake_impl=inplace_all_reduce_fake, - ) - def outplace_all_reduce(tensor: torch.Tensor, - group_name: str) -> torch.Tensor: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce_out_place(tensor) +def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) - def outplace_all_reduce_fake(tensor: torch.Tensor, - group_name: str) -> torch.Tensor: - return torch.empty_like(tensor) +if supports_custom_op(): direct_register_custom_op( - op_name="outplace_all_reduce", - op_func=outplace_all_reduce, + op_name="all_reduce", + op_func=all_reduce, mutates_args=[], - fake_impl=outplace_all_reduce_fake, + fake_impl=all_reduce_fake, ) @@ -317,30 +304,13 @@ def graph_capture( stream.wait_stream(curr_stream) with torch.cuda.stream(stream), maybe_ca_context: - # In graph mode, we have to be very careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # torch.distributed | enabled | disabled| - # - # Note that custom allreduce will have a runtime check, if the - # tensor size is too large, it will fallback to the next - # available option. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using - # CUDA graph, we use either custom all-reduce kernel or - # PyTorch NCCL. We always prioritize using custom all-reduce - # kernel but fall back to PyTorch or pynccl if it is - # disabled or not supported. pynccl_comm = self.pynccl_comm maybe_pynccl_context: Any if not pynccl_comm: maybe_pynccl_context = nullcontext() else: maybe_pynccl_context = pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream()) + stream=torch.cuda.current_stream()) with maybe_pynccl_context: yield graph_capture_context @@ -356,8 +326,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: coordinator. In addition, PyTorch custom ops do not support mutation or returning - a new tensor in the same op. So we need to figure out if the op is - in-place or out-of-place ahead of time. + a new tensor in the same op. So we always make the all-reduce operation + out-of-place. """ # Bypass the function if we are using only 1 GPU. if self.world_size == 1: @@ -368,10 +338,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: ipex.distributed.all_reduce(input_, group=self.device_group) return input_ - if not supports_custom_op(): - self._all_reduce_in_place(input_) - return input_ - if self.tpu_communicator is not None and \ not self.tpu_communicator.disabled: # TPU handles Dynamo with its own logic. @@ -385,30 +351,31 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: not self.xpu_communicator.disabled: return self.xpu_communicator.all_reduce(input_) - if self.ca_comm is not None and \ - not self.ca_comm.disabled and \ - self.ca_comm.should_custom_ar(input_): - return torch.ops.vllm.outplace_all_reduce( - input_, group_name=self.unique_name) - else: - torch.ops.vllm.inplace_all_reduce(input_, - group_name=self.unique_name) - return input_ + return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + # always try custom allreduce first, + # and then pynccl. ca_comm = self.ca_comm - assert ca_comm is not None - assert not ca_comm.disabled - out = ca_comm.custom_all_reduce(input_) - assert out is not None - return out - - def _all_reduce_in_place(self, input_: torch.Tensor) -> None: + if ca_comm is not None and not ca_comm.disabled and \ + ca_comm.should_custom_ar(input_): + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm - if (pynccl_comm is not None and not pynccl_comm.disabled): - pynccl_comm.all_reduce(input_) - else: - torch.distributed.all_reduce(input_, group=self.device_group) + assert pynccl_comm is not None + # TODO: pynccl should not use `stream=` + # it can just always use the current stream. + out = pynccl_comm.all_reduce(input_, + stream=torch.cuda.current_stream()) + if out is None: + # fall back to the default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. + out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group) + return out def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size @@ -942,6 +909,14 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group +_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None + + +def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: + assert _KV_TRANSFER is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_TRANSFER + @contextmanager def graph_capture(): @@ -1090,6 +1065,26 @@ def initialize_model_parallel( group_name="pp") +def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize KV cache transfer parallel group. + """ + + global _KV_TRANSFER + + if vllm_config.kv_transfer_config is None: + return + + if all([ + vllm_config.kv_transfer_config.need_kv_parallel_group, + _KV_TRANSFER is None + ]): + _KV_TRANSFER = kv_transfer.KVTransferAgent( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=vllm_config) + + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 92fa87c7fa45b..4aa0eebd976c9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -8,17 +8,19 @@ import torch import vllm.envs as envs -from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, - DeviceConfig, HfOverrides, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, PoolerConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig, TaskOption, - TokenizerPoolConfig, VllmConfig) +from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, + DecodingConfig, DeviceConfig, HfOverrides, + KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PoolerConfig, PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TaskOption, TokenizerPoolConfig, + VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.platforms import current_platform from vllm.transformers_utils.utils import check_gguf_file +from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, StoreBoolean if TYPE_CHECKING: @@ -106,13 +108,14 @@ class EngineArgs: # notice. distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None + # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None # NOTE(kzawora): default block size for Gaudi should be 128 # smaller sizes still work, but very inefficiently block_size: int = 16 if not current_platform.is_hpu() else 128 - enable_prefix_caching: bool = False + enable_prefix_caching: Optional[bool] = None disable_sliding_window: bool = False use_v2_block_manager: bool = True swap_space: float = 4 # GiB @@ -189,11 +192,30 @@ class EngineArgs: override_neuron_config: Optional[Dict[str, Any]] = None override_pooler_config: Optional[PoolerConfig] = None + compilation_config: Optional[CompilationConfig] = None + worker_cls: str = "auto" + + kv_transfer_config: Optional[KVTransferConfig] = None def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model + # Override the default value of enable_prefix_caching if it's not set + # by user. + if self.enable_prefix_caching is None: + self.enable_prefix_caching = bool(envs.VLLM_USE_V1) + + # support `EngineArgs(compilation_config={...})` + # without having to manually construct a + # CompilationConfig object + if isinstance(self.compilation_config, (int)): + self.compilation_config = CompilationConfig.from_cli( + str(self.compilation_config)) + elif isinstance(self.compilation_config, (dict)): + self.compilation_config = CompilationConfig.from_cli( + json.dumps(self.compilation_config)) + # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() @@ -397,9 +419,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'tokens. This is ignored on neuron devices and ' 'set to max-model-len') - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='Enables automatic prefix caching.') + parser.add_argument( + "--enable-prefix-caching", + action=argparse.BooleanOptionalAction, + default=EngineArgs.enable_prefix_caching, + help="Enables automatic prefix caching. " + "Use --no-enable-prefix-caching to disable explicitly.", + ) parser.add_argument('--disable-sliding-window', action='store_true', help='Disables sliding window, ' @@ -793,7 +819,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default=[], help="The pattern(s) to ignore when loading the model." - "Default to 'original/**/*' to avoid repeated loading of llama's " + "Default to `original/**/*` to avoid repeated loading of llama's " "checkpoints.") parser.add_argument( '--preemption-mode', @@ -868,6 +894,35 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Override or set the pooling method in the embedding model. " "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'") + parser.add_argument('--compilation-config', + '-O', + type=CompilationConfig.from_cli, + default=None, + help='torch.compile configuration for the model.' + 'When it is a number (0, 1, 2, 3), it will be ' + 'interpreted as the optimization level.\n' + 'NOTE: level 0 is the default level without ' + 'any optimization. level 1 and 2 are for internal ' + 'testing only. level 3 is the recommended level ' + 'for production.\n' + 'To specify the full compilation config, ' + 'use a JSON string.\n' + 'Following the convention of traditional ' + 'compilers, using -O without space is also ' + 'supported. -O3 is equivalent to -O 3.') + + parser.add_argument('--kv-transfer-config', + type=KVTransferConfig.from_cli, + default=None, + help='The configurations for distributed KV cache ' + 'transfer. Should be a JSON string.') + + parser.add_argument( + '--worker-cls', + type=str, + default="auto", + help='The worker class to use for distributed execution.') + return parser @classmethod @@ -920,7 +975,12 @@ def create_load_config(self) -> LoadConfig: ignore_patterns=self.ignore_patterns, ) - def create_engine_config(self) -> VllmConfig: + def create_engine_config(self, + usage_context: Optional[UsageContext] = None + ) -> VllmConfig: + if envs.VLLM_USE_V1: + self._override_v1_engine_args(usage_context) + # gguf file needs a specific model loader and doesn't use hf_repo if check_gguf_file(self.model): self.quantization = self.load_format = "gguf" @@ -980,7 +1040,9 @@ def create_engine_config(self) -> VllmConfig: self.tokenizer_pool_extra_config, ), ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend) + distributed_executor_backend=self.distributed_executor_backend, + worker_cls=self.worker_cls, + ) max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 @@ -998,7 +1060,8 @@ def create_engine_config(self) -> VllmConfig: use_spec_decode = self.speculative_model is not None if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora - and not self.enable_prompt_adapter): + and not self.enable_prompt_adapter + and model_config.task != "embedding"): self.enable_chunked_prefill = True logger.warning( "Chunked prefill is enabled by default for models with " @@ -1015,6 +1078,10 @@ def create_engine_config(self) -> VllmConfig: "errors during the initial memory profiling phase, or result " "in low performance due to small KV cache space. Consider " "setting --max-model-len to a smaller value.", max_model_len) + elif self.enable_chunked_prefill and model_config.task == "embedding": + msg = "Chunked prefill is not supported for embedding models" + raise ValueError(msg) + speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, @@ -1130,7 +1197,7 @@ def create_engine_config(self) -> VllmConfig: or "all" in detailed_trace_modules, ) - return VllmConfig( + config = VllmConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -1142,8 +1209,46 @@ def create_engine_config(self) -> VllmConfig: decoding_config=decoding_config, observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, + compilation_config=self.compilation_config, + kv_transfer_config=self.kv_transfer_config, ) + if envs.VLLM_USE_V1: + self._override_v1_engine_config(config) + return config + + def _override_v1_engine_args(self, usage_context: UsageContext) -> None: + """ + Override the EngineArgs's args based on the usage context for V1. + """ + assert envs.VLLM_USE_V1, "V1 is not enabled" + + if self.max_num_batched_tokens is None: + # When no user override, set the default values based on the + # usage context. + if usage_context == UsageContext.LLM_CLASS: + logger.warning("Setting max_num_batched_tokens to 8192 " + "for LLM_CLASS usage context.") + self.max_num_seqs = 1024 + self.max_num_batched_tokens = 8192 + elif usage_context == UsageContext.OPENAI_API_SERVER: + logger.warning("Setting max_num_batched_tokens to 2048 " + "for OPENAI_API_SERVER usage context.") + self.max_num_seqs = 1024 + self.max_num_batched_tokens = 2048 + + def _override_v1_engine_config(self, engine_config: VllmConfig) -> None: + """ + Override the EngineConfig's configs based on the usage context for V1. + """ + assert envs.VLLM_USE_V1, "V1 is not enabled" + # TODO (ywang96): Enable APC by default when VLM supports it. + if engine_config.model_config.is_multimodal_model: + logger.warning( + "Prefix caching is currently not supported for multimodal " + "models and has been disabled.") + engine_config.cache_config.enable_prefix_caching = False + @dataclass class AsyncEngineArgs(EngineArgs): diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5a5388708b1c6..7b1bb7b05708d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -25,7 +25,7 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -74,7 +74,7 @@ def _log_task_completion(task: asyncio.Task, class AsyncStream: - """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + """A stream of RequestOutputs or PoolingRequestOutputs for a request that can be iterated over asynchronously via an async generator.""" def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: @@ -83,7 +83,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + def put(self, item: Union[RequestOutput, PoolingRequestOutput, Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -103,7 +103,7 @@ def finished(self) -> bool: async def generator( self - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: try: while True: result = await self._queue.get() @@ -154,7 +154,7 @@ def propagate_exception(self, def process_request_output(self, request_output: Union[RequestOutput, - EmbeddingRequestOutput], + PoolingRequestOutput], *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -265,7 +265,7 @@ def __init__(self, *args, **kwargs): async def step_async( self, virtual_engine: int - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -300,6 +300,9 @@ async def step_async( ctx.seq_group_metadata_list = seq_group_metadata_list ctx.scheduler_outputs = scheduler_outputs + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) @@ -311,13 +314,13 @@ async def step_async( self._cache_scheduler_outputs_for_multi_step( virtual_engine, seq_group_metadata_list, scheduler_outputs, allow_async_output_proc) + else: + finished_requests_ids = list() assert seq_group_metadata_list is not None assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() # Check if we have a cached last_output from the previous iteration. # For supporting PP this is probably the best way to pass the @@ -680,7 +683,7 @@ def from_engine_args( """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. if engine_config is None: - engine_config = engine_args.create_engine_config() + engine_config = engine_args.create_engine_config(usage_context) executor_class = cls._get_executor_cls(engine_config) @@ -904,7 +907,7 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Coroutine[None, None, AsyncGenerator[Union[ - RequestOutput, EmbeddingRequestOutput], None]]: + RequestOutput, PoolingRequestOutput], None]]: ... @overload @@ -919,7 +922,7 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Coroutine[None, None, AsyncGenerator[Union[ - RequestOutput, EmbeddingRequestOutput], None]]: + RequestOutput, PoolingRequestOutput], None]]: ... @deprecate_kwargs( @@ -938,7 +941,7 @@ async def add_request( priority: int = 0, *, inputs: Optional[PromptType] = None, # DEPRECATED - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: if inputs is not None: prompt = inputs assert prompt is not None and params is not None @@ -1067,7 +1070,7 @@ async def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model. Generate outputs for a request. This method is a coroutine. It adds the @@ -1085,7 +1088,7 @@ async def encode( Only applicable with priority scheduling. Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine + The output `PoolingRequestOutput` objects from the LLMEngine for the request. Details: @@ -1138,7 +1141,7 @@ async def encode( trace_headers=trace_headers, priority=priority, ): - yield LLMEngine.validate_output(output, EmbeddingRequestOutput) + yield LLMEngine.validate_output(output, PoolingRequestOutput) async def abort(self, request_id: str) -> None: """Abort a request. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e41e9da66350c..d2d50b3238daa 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -40,7 +40,7 @@ get_local_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, +from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -80,7 +80,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) +_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) @dataclass @@ -112,7 +112,7 @@ class SchedulerContext: def __init__(self, multi_step_stream_outputs: bool = False): self.output_queue: Deque[OutputData] = deque() self.request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] + PoolingRequestOutput]] = [] self.seq_group_metadata_list: Optional[ List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None @@ -231,19 +231,18 @@ def __init__( use_cached_outputs: bool = False, ) -> None: - # TODO: remove the local variables and use self.* throughout the class. - model_config = self.model_config = vllm_config.model_config - cache_config = self.cache_config = vllm_config.cache_config - lora_config = self.lora_config = vllm_config.lora_config - parallel_config = self.parallel_config = vllm_config.parallel_config - scheduler_config = self.scheduler_config = vllm_config.scheduler_config - device_config = self.device_config = vllm_config.device_config - speculative_config = self.speculative_config = vllm_config.speculative_config # noqa - load_config = self.load_config = vllm_config.load_config - decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config # noqa + self.load_config = vllm_config.load_config + self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa ) - prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa - observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa ) logger.info( @@ -262,55 +261,46 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s, pooler_config=%r)", + "mm_processor_kwargs=%s, pooler_config=%r," + "compilation_config=%r", VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, - scheduler_config.num_scheduler_steps, - scheduler_config.chunked_prefill_enabled, - scheduler_config.multi_step_stream_outputs, - cache_config.enable_prefix_caching, - model_config.use_async_output_proc, + self.model_config.model, + self.speculative_config, + self.model_config.tokenizer, + self.model_config.skip_tokenizer_init, + self.model_config.tokenizer_mode, + self.model_config.revision, + self.model_config.override_neuron_config, + self.model_config.tokenizer_revision, + self.model_config.trust_remote_code, + self.model_config.dtype, + self.model_config.max_model_len, + self.load_config.download_dir, + self.load_config.load_format, + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size, + self.parallel_config.disable_custom_all_reduce, + self.model_config.quantization, + self.model_config.enforce_eager, + self.cache_config.cache_dtype, + self.model_config.quantization_param_path, + self.device_config.device, + self.decoding_config, + self.observability_config, + self.model_config.seed, + self.model_config.served_model_name, + self.scheduler_config.num_scheduler_steps, + self.scheduler_config.chunked_prefill_enabled, + self.scheduler_config.multi_step_stream_outputs, + self.cache_config.enable_prefix_caching, + self.model_config.use_async_output_proc, use_cached_outputs, - model_config.mm_processor_kwargs, - model_config.pooler_config, + self.model_config.mm_processor_kwargs, + self.model_config.pooler_config, + vllm_config.compilation_config, ) # TODO(woosuk): Print more configs in debug mode. - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.load_config = load_config - self.decoding_config = decoding_config or DecodingConfig() - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config or ObservabilityConfig( - ) + self.log_stats = log_stats self.use_cached_outputs = use_cached_outputs @@ -332,15 +322,15 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( - model_config) + self.model_config) - self.input_preprocessor = InputPreprocessor(model_config, + self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer, mm_registry) self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( - model_config) + self.model_config) self.model_executor = executor_class(vllm_config=vllm_config, ) @@ -352,36 +342,36 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: from vllm.model_executor.model_loader import ( get_architecture_class_name) usage_message.report_usage( - get_architecture_class_name(model_config), + get_architecture_class_name(self.model_config), usage_context, extra_kvs={ # Common configuration "dtype": - str(model_config.dtype), + str(self.model_config.dtype), "tensor_parallel_size": - parallel_config.tensor_parallel_size, + self.parallel_config.tensor_parallel_size, "block_size": - cache_config.block_size, + self.cache_config.block_size, "gpu_memory_utilization": - cache_config.gpu_memory_utilization, + self.cache_config.gpu_memory_utilization, # Quantization "quantization": - model_config.quantization, + self.model_config.quantization, "kv_cache_dtype": - str(cache_config.cache_dtype), + str(self.cache_config.cache_dtype), # Feature flags "enable_lora": - bool(lora_config), + bool(self.lora_config), "enable_prompt_adapter": - bool(prompt_adapter_config), + bool(self.prompt_adapter_config), "enable_prefix_caching": - cache_config.enable_prefix_caching, + self.cache_config.enable_prefix_caching, "enforce_eager": - model_config.enforce_eager, + self.model_config.enforce_eager, "disable_custom_all_reduce": - parallel_config.disable_custom_all_reduce, + self.parallel_config.disable_custom_all_reduce, }) if self.tokenizer: @@ -400,7 +390,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for _ in range(self.parallel_config.pipeline_parallel_size) ] - if model_config.use_async_output_proc: + if self.model_config.use_async_output_proc: process_model_outputs = weak_bind(self._process_model_outputs) self.async_callbacks = [ @@ -420,11 +410,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ Scheduler( - scheduler_config, cache_config, lora_config, - parallel_config.pipeline_parallel_size, + self.scheduler_config, self.cache_config, self.lora_config, + self.parallel_config.pipeline_parallel_size, self.async_callbacks[v_id] - if model_config.use_async_output_proc else None) - for v_id in range(parallel_config.pipeline_parallel_size) + if self.model_config.use_async_output_proc else None) + for v_id in range(self.parallel_config.pipeline_parallel_size) ] # Metric Logging. @@ -446,7 +436,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: "prometheus": PrometheusStatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), + labels=dict( + model_name=self.model_config.served_model_name), max_model_len=self.model_config.max_model_len), } self.stat_loggers["prometheus"].info("cache_config", @@ -577,7 +568,7 @@ def from_engine_args( ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. - engine_config = engine_args.create_engine_config() + engine_config = engine_args.create_engine_config(usage_context) executor_class = cls._get_executor_cls(engine_config) # Create the LLM engine. engine = cls( @@ -1323,7 +1314,7 @@ def _advance_to_next_step( else: seq.append_token_id(sample.output_token, sample.logprobs) - def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. .. figure:: https://i.imgur.com/sv2HssD.png @@ -1407,6 +1398,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ctx.seq_group_metadata_list = seq_group_metadata_list ctx.scheduler_outputs = scheduler_outputs + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) @@ -1418,13 +1412,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self._cache_scheduler_outputs_for_multi_step( virtual_engine, seq_group_metadata_list, scheduler_outputs, allow_async_output_proc) + else: + finished_requests_ids = list() assert seq_group_metadata_list is not None assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): - finished_requests_ids = self.scheduler[ - virtual_engine].get_and_reset_finished_requests_ids() # Check if we have a cached last_output from the previous iteration. # For supporting PP this is probably the best way to pass the @@ -1716,7 +1710,7 @@ def _get_stats(self, # not counted (to avoid double counting) actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore - num_generation_tokens_from_prefill_groups = 0. + num_generation_tokens_from_prefill_groups = 0 # NOTE: if scheduler_outputs.num_prefill_groups > 0 and # the len of scheduler_outputs.scheduled_seq_groups is != # scheduler_outputs.num_prefill_groups, this means that diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index e896bcdded2d1..4869557ba9b44 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -421,6 +421,11 @@ def get_throughput(tracked_stats: List[int], now: float, class LoggingStatLogger(StatLoggerBase): """LoggingStatLogger is used in LLMEngine to log to Stdout.""" + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.last_prompt_throughput: Optional[float] = None + self.last_generation_throughput: Optional[float] = None + def log(self, stats: Stats) -> None: """Called by LLMEngine. Logs to Stdout every self.local_interval seconds.""" @@ -445,8 +450,14 @@ def log(self, stats: Stats) -> None: now=stats.now, last_log=self.last_local_log) - # Log to stdout. - logger.info( + log_fn = logger.info + if not any((prompt_throughput, generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput)): + # Avoid log noise on an idle production system + log_fn = logger.debug + + log_fn( "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " "Running: %d reqs, Swapped: %d reqs, " @@ -462,21 +473,26 @@ def log(self, stats: Stats) -> None: ) if (stats.cpu_prefix_cache_hit_rate >= 0 or stats.gpu_prefix_cache_hit_rate >= 0): - logger.info( + log_fn( "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", stats.gpu_prefix_cache_hit_rate * 100, stats.cpu_prefix_cache_hit_rate * 100, ) if self.spec_decode_metrics is not None: - logger.info( + log_fn( self._format_spec_decode_metrics_str( self.spec_decode_metrics)) - # Reset tracked stats for next interval. - self.num_prompt_tokens = [] - self.num_generation_tokens = [] - self.last_local_log = stats.now - self.spec_decode_metrics = None + self._reset(stats, prompt_throughput, generation_throughput) + + def _reset(self, stats, prompt_throughput, generation_throughput) -> None: + # Reset tracked stats for next interval. + self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + self.spec_decode_metrics = None + self.last_prompt_throughput = prompt_throughput + self.last_generation_throughput = generation_throughput def _format_spec_decode_metrics_str( self, metrics: "SpecDecodeWorkerMetrics") -> str: @@ -512,6 +528,11 @@ def _log_gauge(self, gauge, data: Union[int, float]) -> None: def _log_counter(self, counter, data: Union[int, float]) -> None: # Convenience function for logging to counter. + # Prevent ValueError from negative increment + if data < 0: + logger.warning("Skipping negative increment of %g to %s", data, + counter) + return counter.labels(**self.labels).inc(data) def _log_counter_labels(self, counter, data: CollectionsCounter, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index fe21c58c775fe..d26728e8c6e67 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -35,7 +35,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs @@ -495,7 +495,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: ... @overload @@ -507,7 +507,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: ... @deprecate_kwargs( @@ -524,7 +524,7 @@ def encode( priority: int = 0, *, inputs: Optional[PromptType] = None # DEPRECATED - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model. Generate outputs for a request. This method is a coroutine. It adds the @@ -540,7 +540,7 @@ def encode( trace_headers: OpenTelemetry trace headers. Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine + The output `PoolingRequestOutput` objects from the LLMEngine for the request. """ if inputs is not None: @@ -549,7 +549,7 @@ def encode( and request_id is not None) return cast( - AsyncGenerator[EmbeddingRequestOutput, None], + AsyncGenerator[PoolingRequestOutput, None], self._process_request(prompt, pooling_params, request_id, @@ -567,7 +567,7 @@ async def _process_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ - EmbeddingRequestOutput, None]]: + PoolingRequestOutput, None]]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # If already dead, error out. diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 7de23643a2e1c..49a90b321dac4 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -111,7 +111,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, from vllm.plugins import load_general_plugins load_general_plugins() - engine_config = engine_args.create_engine_config() + engine_config = engine_args.create_engine_config(usage_context) executor_class = LLMEngine._get_executor_cls(engine_config) use_async_sockets = engine_config.model_config.use_async_output_proc diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e15395d75c91f..4079de7d36793 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -11,8 +11,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, - RequestOutput) +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -209,7 +208,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model.""" ... diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index abee5ac46391c..c2054dcbfce0e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -412,6 +412,8 @@ def _placeholder_str(self, modality: ModalityStr, return "" if model_type == "idefics3": return "" + if model_type == "aria": + return "<|fim_prefix|><|img|><|fim_suffix|>" raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 86b0b6893f1d9..a25c401b4ea10 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,4 +1,5 @@ import itertools +import json import warnings from contextlib import contextmanager from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, @@ -9,6 +10,7 @@ from vllm import envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) +from vllm.config import CompilationConfig from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, TaskOption) from vllm.engine.llm_engine import LLMEngine @@ -18,13 +20,13 @@ apply_mistral_chat_template, parse_chat_messages, resolve_chat_template_content_format) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest, LLMGuidedOptions) -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, @@ -107,6 +109,9 @@ class LLM: hf_overrides: If a dictionary, contains arguments to be forwarded to the HuggingFace config. If a callable, it is called to update the HuggingFace config. + compilation_config: Either an integer or a dictionary. If it is an + integer, it is used as the level of compilation optimization. If it + is a dictionary, it can specify the full compilation configuration. **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) @@ -166,6 +171,7 @@ def __init__( # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, + compilation_config: Optional[Union[int, Dict[str, Any]]] = None, **kwargs, ) -> None: ''' @@ -178,6 +184,18 @@ def __init__( if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True + if compilation_config is not None: + if isinstance(compilation_config, (int)): + compilation_config_instance = CompilationConfig.from_cli( + str(compilation_config)) + elif isinstance(compilation_config, (dict)): + compilation_config_instance = CompilationConfig.from_cli( + json.dumps(compilation_config)) + else: + compilation_config_instance = compilation_config + else: + compilation_config_instance = None + engine_args = EngineArgs( model=model, task=task, @@ -202,6 +220,7 @@ def __init__( hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, **kwargs, ) # Logic to switch between engines is done at runtime instead of import @@ -660,7 +679,7 @@ def encode( prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: multi (prompt + optional token ids) @@ -672,7 +691,7 @@ def encode( prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: single (token ids + optional prompt) @@ -685,7 +704,7 @@ def encode( prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: multi (token ids + optional prompt) @@ -698,7 +717,7 @@ def encode( prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: single or multi token ids [pos-only] @@ -709,7 +728,7 @@ def encode( prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload @@ -722,7 +741,7 @@ def encode( Sequence[PoolingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @deprecate_kwargs( @@ -740,7 +759,7 @@ def encode( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: """Generates the completions for the input prompts. This class automatically batches the given prompts, considering @@ -759,7 +778,7 @@ def encode( generation, if any. Returns: - A list of ``EmbeddingRequestOutput`` objects containing the + A list of ``PoolingRequestOutput`` objects containing the generated embeddings in the same order as the input prompts. Note: @@ -802,7 +821,129 @@ def encode( outputs = self._run_engine(use_tqdm=use_tqdm) return self.engine_class.validate_outputs(outputs, - EmbeddingRequestOutput) + PoolingRequestOutput) + + def score( + self, + text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], + text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]], + /, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[PoolingRequestOutput]: + """Generates similarity scores for all pairs . + + The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case + the text_1 sentence will be replicated N times to pair with the text_2 + sentences. The input pairs are used to build a list of prompts for the + cross encoder model. This class automatically batches the prompts, + considering the memory constraint. For the best performance, put all + of your texts into a single list and pass it to this method. + + Args: + text_1: can be a single prompt or a list of prompts, in which + case it has to have the same length as the text_2 list + text_2: The texts to pair with the query to form the input + to the LLM. See :class:`~vllm.inputs.PromptType` for + more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``PoolingRequestOutput`` objects containing the + generated scores in the same order as the input prompts. + """ + task = self.llm_engine.model_config.task + if task != "embedding": + messages = ["LLM.score() is only supported for embedding models."] + + supported_tasks = self.llm_engine.model_config.supported_tasks + if "embedding" in supported_tasks: + messages.append( + "Your model supports the 'embedding' task, but is " + f"currently initialized for the '{task}' task. Please " + "initialize the model using `--task embedding`.") + + raise ValueError(" ".join(messages)) + + if not self.llm_engine.model_config.is_cross_encoder: + raise ValueError("Your model does not support the cross encoding") + + tokenizer = self.llm_engine.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + # the tokenizer for models such as + # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing + # lists of tokens to the `text` and `text_pair` kwargs + def ensure_str(prompt: SingletonPrompt): + if isinstance(prompt, dict): + if "multi_modal_data" in prompt: + raise ValueError("Multi-modal prompt is not " + "supported for cross encoding") + elif "prompt_token_ids" in prompt: + prompt = tokenizer.decode( + cast(TokensPrompt, prompt)["prompt_token_ids"]) + elif "prompt" in prompt: + prompt = cast(TextPrompt, prompt)["prompt"] + assert type(prompt) is str + return prompt + + if isinstance(text_1, (str, dict)): + # Convert a single prompt to a list. + text_1 = [text_1] + text_1 = [ensure_str(t) for t in text_1] + + if isinstance(text_2, (str, dict)): + # Convert a single prompt to a list. + text_2 = [text_2] + text_2 = [ensure_str(t) for t in text_2] + + if len(text_1) > 1 and len(text_1) != len(text_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(text_1) == 0: + raise ValueError("At least one text element must be given") + if len(text_2) == 0: + raise ValueError("At least one text_pair element must be given") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] + pooling_params = PoolingParams() + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + parsed_prompts = [] + + for q, t in input_pairs: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, + PoolingRequestOutput) def start_profile(self) -> None: self.llm_engine.start_profile() @@ -944,7 +1085,7 @@ def _add_guided_params( def _run_engine( self, *, use_tqdm: bool - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -957,7 +1098,7 @@ def _run_engine( ) # Run the engine. - outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] + outputs: List[Union[RequestOutput, PoolingRequestOutput]] = [] total_in_toks = 0 total_out_toks = 0 while self.llm_engine.has_unfinished_requests(): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b0fe061f5db4a..6bc31ef83ded4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -45,6 +45,7 @@ EmbeddingRequest, EmbeddingResponse, ErrorResponse, LoadLoraAdapterRequest, + ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) @@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -133,8 +135,8 @@ async def build_async_engine_client_from_engine_args( # TODO: fill out feature matrix. if (MQLLMEngineClient.is_unsupported_config(engine_args) or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): - - engine_config = engine_args.create_engine_config() + engine_config = engine_args.create_engine_config( + UsageContext.OPENAI_API_SERVER) uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), "uses_ray", False) @@ -280,6 +282,10 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: return request.app.state.openai_serving_embedding +def score(request: Request) -> Optional[OpenAIServingScores]: + return request.app.state.openai_serving_scores + + def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization @@ -391,6 +397,23 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) +@router.post("/v1/score") +async def create_score(request: ScoreRequest, raw_request: Request): + handler = score(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Score API") + + generator = await handler.create_score(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ScoreResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + if envs.VLLM_TORCH_PROFILER_DIR: logger.warning( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -466,8 +489,9 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - chat = app.state.openai_serving_chat - err = chat.create_error_response(message=str(exc)) + err = ErrorResponse(message=str(exc), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -475,10 +499,12 @@ async def validation_exception_handler(_, exc): @app.middleware("http") async def authentication(request: Request, call_next): - root_path = "" if args.root_path is None else args.root_path if request.method == "OPTIONS": return await call_next(request) - if not request.url.path.startswith(f"{root_path}/v1"): + url_path = request.url.path + if app.root_path and url_path.startswith(app.root_path): + url_path = url_path[len(app.root_path):] + if not url_path.startswith("/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: return JSONResponse(content={"error": "Unauthorized"}, @@ -565,6 +591,13 @@ def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) if model_config.task == "embedding" else None + state.openai_serving_scores = OpenAIServingScores( + engine_client, + model_config, + base_model_paths, + request_logger=request_logger + ) if (model_config.task == "embedding" \ + and model_config.is_cross_encoder) else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b7b064ae01f05..ee94a9413f098 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -9,12 +9,15 @@ from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) from vllm.sequence import Logprob from vllm.utils import random_uuid +logger = init_logger(__name__) + # torch is mocked during docs generation, # so we have to provide the values as literals _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) @@ -35,8 +38,19 @@ class OpenAIBaseModel(BaseModel): - # OpenAI API does not allow extra fields - model_config = ConfigDict(extra="forbid") + # OpenAI API does allow extra fields + model_config = ConfigDict(extra="allow") + + @model_validator(mode="before") + @classmethod + def __log_extra_fields__(cls, data): + if isinstance(data, dict): + extra_fields = data.keys() - cls.model_fields.keys() + if extra_fields: + logger.warning( + "The following fields were present in the request " + "but ignored: %s", extra_fields) + return data class ErrorResponse(OpenAIBaseModel): @@ -464,17 +478,17 @@ def check_tool_usage(cls, data): # it matches a valid tool if isinstance(data["tool_choice"], dict): valid_tool = False - specified_function = data["tool_choice"]["function"] + specified_function = data["tool_choice"].get("function") if not specified_function: raise ValueError( - "Incorrectly formatted `tool_choice`. Should be like " - "`{\"type\": \"function\"," + "Expected field `function` in `tool_choice`." + " Correct usage: `{\"type\": \"function\"," " \"function\": {\"name\": \"my_function\"}}`") - specified_function_name = specified_function["name"] + specified_function_name = specified_function.get("name") if not specified_function_name: raise ValueError( - "Incorrectly formatted `tool_choice`. Should be like " - "`{\"type\": \"function\", " + "Expected field `name` in `function` in `tool_choice`." + "Correct usage: `{\"type\": \"function\", " "\"function\": {\"name\": \"my_function\"}}`") for tool in data["tools"]: if tool["function"]["name"] == specified_function_name: @@ -746,22 +760,6 @@ class EmbeddingChatRequest(OpenAIBaseModel): # doc: end-chat-embedding-pooling-params # doc: begin-chat-embedding-extra-params - add_generation_prompt: bool = Field( - default=True, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), - ) - continue_final_message: bool = Field( - default=False, - description= - ("If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " - "Cannot be used at the same time as `add_generation_prompt`."), - ) add_special_tokens: bool = Field( default=False, description=( @@ -808,6 +806,27 @@ def to_pooling_params(self): EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] +class ScoreRequest(OpenAIBaseModel): + model: str + text_1: Union[List[str], str] + text_2: Union[List[str], str] + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-chat-embedding-pooling-params + additional_data: Optional[Any] = None + # doc: end-chat-embedding-pooling-params + + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) @@ -878,6 +897,21 @@ class EmbeddingResponse(OpenAIBaseModel): usage: UsageInfo +class ScoreResponseData(OpenAIBaseModel): + index: int + object: str = "score" + score: Union[List[float], str] + + +class ScoreResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[ScoreResponseData] + usage: UsageInfo + + class FunctionCall(OpenAIBaseModel): name: str arguments: str diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2eef909eb9319..54ca0463bcab1 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -361,7 +361,7 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message - if request.echo or request.continue_final_message: + if request.echo: last_msg_content: Union[str, List[Dict[str, str]]] = "" if conversation and "content" in conversation[ -1] and conversation[-1].get("role") == role: @@ -706,7 +706,7 @@ async def chat_completion_full_generator( stop_reason=output.stop_reason) choices.append(choice_data) - if request.echo or request.continue_final_message: + if request.echo: last_msg_content: Union[str, List[Dict[str, str]]] = "" if conversation and "content" in conversation[-1] and conversation[ -1].get("role") == role: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 936aae8f1c267..fc1c4908d6650 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -101,7 +101,7 @@ async def create_completion( tokenizer = await self.engine_client.get_tokenizer(lora_request) - request_prompts, engine_prompts = self._preprocess_completion( + request_prompts, engine_prompts = await self._preprocess_completion( request, tokenizer, request.prompt, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 74ad7389784fc..2cbb252610e39 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -18,14 +18,14 @@ ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.logger import init_logger -from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput +from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) def _get_embedding( - output: EmbeddingOutput, + output: PoolingOutput, encoding_format: Literal["float", "base64"], ) -> Union[List[float], str]: if encoding_format == "float": @@ -40,7 +40,7 @@ def _get_embedding( def request_output_to_embedding_response( - final_res_batch: List[EmbeddingRequestOutput], request_id: str, + final_res_batch: List[PoolingRequestOutput], request_id: str, created_time: int, model_name: str, encoding_format: Literal["float", "base64"]) -> EmbeddingResponse: data: List[EmbeddingResponseData] = [] @@ -148,25 +148,28 @@ async def create_embedding( chat_template=request.chat_template or self.chat_template, chat_template_content_format=self. chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, + # In embedding requests, we are not generating tokens, + # so there is no need to append extra tokens to the input + add_generation_prompt=False, + continue_final_message=False, truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) else: - request_prompts, engine_prompts = self._preprocess_completion( - request, - tokenizer, - request.input, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] try: pooling_params = request.to_pooling_params() @@ -204,7 +207,7 @@ async def create_embedding( num_prompts = len(engine_prompts) # Non-streaming response - final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch: List[Optional[PoolingRequestOutput]] final_res_batch = [None] * num_prompts try: async for i, res in result_generator: @@ -212,7 +215,7 @@ async def create_embedding( assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch_checked = cast(List[PoolingRequestOutput], final_res_batch) response = request_output_to_embedding_response( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index cae2877ea7e99..8232c6116c1bd 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,5 +1,6 @@ import json import pathlib +from concurrent.futures.thread import ThreadPoolExecutor from dataclasses import dataclass from http import HTTPStatus from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping, @@ -46,7 +47,7 @@ from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import AtomicCounter, is_list_of +from vllm.utils import AtomicCounter, is_list_of, make_async logger = init_logger(__name__) @@ -140,6 +141,14 @@ def __init__( self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + + self._tokenize_prompt_input_async = make_async( + self._tokenize_prompt_input, executor=self._tokenizer_executor) + self._tokenize_prompt_input_or_inputs_async = make_async( + self._tokenize_prompt_input_or_inputs, + executor=self._tokenizer_executor) + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -368,7 +377,7 @@ def _tokenize_prompt_input_or_inputs( input_or_inputs: Union[str, List[str], List[int], List[List[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> Iterator[TextTokensPrompt]: + ) -> List[TextTokensPrompt]: """ Tokenize/detokenize depending on the input format. @@ -376,45 +385,41 @@ def _tokenize_prompt_input_or_inputs( , each input can be a string or array of tokens. Note that each request can pass one or more inputs. """ - for prompt_input in parse_and_batch_prompt(input_or_inputs): - # Although our type checking is based on mypy, - # VSCode Pyright extension should still work properly - # "is True" is required for Pyright to perform type narrowing - # See: https://github.com/microsoft/pyright/issues/7672 - if prompt_input["is_tokens"] is False: - yield self._normalize_prompt_text_to_input( - request, - tokenizer, - prompt=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) - else: - yield self._normalize_prompt_tokens_to_input( - request, - tokenizer, - prompt_ids=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - ) + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is True" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + return [ + self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens) + if prompt_input["is_tokens"] is False else + self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens) + for prompt_input in parse_and_batch_prompt(input_or_inputs) + ] - def _preprocess_completion( + async def _preprocess_completion( self, request: CompletionLikeRequest, tokenizer: AnyTokenizer, input_or_inputs: Union[str, List[str], List[int], List[List[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]: - request_prompts = [ - request_prompt - for request_prompt in self._tokenize_prompt_input_or_inputs( - request, - tokenizer, - input_or_inputs, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) - ] + ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]: + request_prompts = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) engine_prompts = [ TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) @@ -493,7 +498,7 @@ async def _preprocess_chat( request=request) if isinstance(request_prompt, str): - prompt_inputs = self._tokenize_prompt_input( + prompt_inputs = await self._tokenize_prompt_input_async( request, tokenizer, request_prompt, diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py new file mode 100644 index 0000000000000..a1f14449ba9c3 --- /dev/null +++ b/vllm/entrypoints/openai/serving_score.py @@ -0,0 +1,217 @@ +import asyncio +import time +from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, + ScoreResponse, ScoreResponseData, + UsageInfo) +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.inputs.data import TokensPrompt +from vllm.logger import init_logger +from vllm.outputs import PoolingRequestOutput +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.utils import make_async, merge_async_iterators, random_uuid + +logger = init_logger(__name__) + + +def request_output_to_score_response( + final_res_batch: List[PoolingRequestOutput], request_id: str, + created_time: int, model_name: str) -> ScoreResponse: + data: List[ScoreResponseData] = [] + score = None + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + if final_res is not None: + score = final_res.outputs.embedding + score_data = ScoreResponseData(index=idx, score=score) + data.append(score_data) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ScoreResponse( + id=request_id, + created=created_time, + model=model_name, + data=data, + usage=usage, + ) + + +def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str], + str]) -> List: + if isinstance(text_1, (str, dict)): + # Convert a single prompt to a list. + text_1 = [text_1] + text_1 = [t for t in text_1] + + if isinstance(text_2, (str, dict)): + # Convert a single prompt to a list. + text_2 = [text_2] + text_2 = [t for t in text_2] + if len(text_1) > 1 and len(text_1) != len(text_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(text_1) == 0: + raise ValueError("At least one text element must be given") + if len(text_2) == 0: + raise ValueError("At least one text_pair element must be given") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + return [(t1, t2) for t1, t2 in zip(text_1, text_2)] + + +class OpenAIServingScores(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=None, + prompt_adapters=None, + request_logger=request_logger) + + async def create_score( + self, + request: ScoreRequest, + raw_request: Optional[Request] = None, + ) -> Union[ScoreResponse, ErrorResponse]: + """ + Score API similar to Sentence Transformers cross encoder + + See https://sbert.net/docs/package_reference/cross_encoder + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"score-{random_uuid()}" + created_time = int(time.monotonic()) + truncate_prompt_tokens = request.truncate_prompt_tokens + + request_prompts = [] + engine_prompts = [] + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for embedding models") + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + if not self.model_config.is_cross_encoder: + raise ValueError("Model is not cross encoder.") + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] + + input_pairs = make_pairs(request.text_1, request.text_2) + + for q, t in input_pairs: + request_prompt = f"{q}{tokenizer.sep_token}{t}" + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + prompt_inputs = await tokenize_async(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators( + *generators, + is_cancelled=raw_request.is_disconnected if raw_request else None, + ) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: List[Optional[PoolingRequestOutput]] + final_res_batch = [None] * num_prompts + + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(List[PoolingRequestOutput], + final_res_batch) + + response = request_output_to_score_response( + final_res_batch_checked, request_id, created_time, model_name) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 59b3b1311f881..9c3dc2c98b2dd 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -81,12 +81,13 @@ async def create_tokenize( add_special_tokens=request.add_special_tokens, ) else: - request_prompts, engine_prompts = self._preprocess_completion( - request, - tokenizer, - request.prompt, - add_special_tokens=request.add_special_tokens, - ) + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + add_special_tokens=request.add_special_tokens, + ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) @@ -134,7 +135,7 @@ async def create_detokenize( # Silently ignore prompt adapter since it does not affect tokenization # (Unlike in Embeddings API where an error is raised) - prompt_input = self._tokenize_prompt_input( + prompt_input = await self._tokenize_prompt_input_async( request, tokenizer, request.tokens, diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index faa6f653b835c..18816cd665b3e 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -12,8 +12,6 @@ FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import random_uuid @@ -190,8 +188,11 @@ def extract_tool_calls_streaming( diff = self.prev_tool_call_arr[self.current_tool_id].get( "arguments") if diff: - diff = json.dumps(diff).replace( - self.streamed_args_for_tool[self.current_tool_id], "") + diff = diff.encode('utf-8').decode( + 'unicode_escape') if diff is str else diff + diff = json.dumps( + diff, ensure_ascii=False + )[len(self.streamed_args_for_tool[self.current_tool_id]):] logger.debug( "Finishing tool and found diff that had not " "been streamed yet: %s", diff) @@ -307,22 +308,20 @@ def extract_tool_calls_streaming( # last case -- we have an update to existing arguments. elif cur_arguments and prev_arguments: + if isinstance(delta_text, str) and len(delta_text.rstrip( + )) >= 1 and delta_text.rstrip()[-1] == '}': + delta_text = delta_text.rstrip()[:-1] + + logger.debug("got diff %s", delta_text) - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) - logger.debug("Searching for diff between\n%s", cur_args_json) - logger.debug("and\n%s", prev_args_json) - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug("got argument diff %s", argument_diff) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=argument_diff).model_dump( + arguments=delta_text).model_dump( exclude_none=True)) ]) self.streamed_args_for_tool[self.current_tool_id] \ - += argument_diff + += delta_text # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index a5f44d69e5fd2..1856308b88cfa 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -29,7 +29,8 @@ class Llama3JsonToolParser(ToolParser): Tool call parser for Llama 3.1 models intended for use with the examples/tool_chat_template_llama.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + Used when --enable-auto-tool-choice --tool-call-parser llama3_json + are all set """ def __init__(self, tokenizer: PreTrainedTokenizerBase): diff --git a/vllm/envs.py b/vllm/envs.py index 0b5ac4f8a4e24..3e458f9b6f25f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -56,7 +56,7 @@ VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 - VLLM_VIDEO_FETCH_TIMEOUT: int = 15 + VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None @@ -75,8 +75,6 @@ VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False - VLLM_TORCH_COMPILE_LEVEL: int = 0 - VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1 @@ -167,7 +165,7 @@ def get_default_config_root(): # If you are using multi-node inference, you should set this differently # on each node. 'VLLM_HOST_IP': - lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), + lambda: os.getenv('VLLM_HOST_IP', ""), # used in distributed environment to manually set the communication port # Note: if VLLM_PORT is set, and some code asks for multiple ports, the @@ -236,12 +234,6 @@ def get_default_config_root(): "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), - "VLLM_TORCH_COMPILE_LEVEL": - lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), - - # Path to the config file for torch compile - "VLLM_TORCH_COMPILE_CONFIG": - lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None), # small gemms custom implementation for MI3* cards "VLLM_USE_ROCM_SKINNY_GEMM": diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 1542a2ae367eb..336f9bc8efb20 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -115,13 +115,8 @@ def _create_worker( local_rank: int = 0, rank: int = 0, ): - worker_module_name = "vllm.worker.cpu_worker" - worker_class_name = "CPUWorker" - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - ) + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) assert self.distributed_init_method is not None diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index c65d0836e5ff7..7fa34456028dd 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -8,19 +8,14 @@ from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) -from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) -def create_worker(worker_module_name: str, worker_class_name: str, - worker_class_fn: Optional[Callable[[], Type[WorkerBase]]], - **kwargs): - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - ) +def create_worker(**kwargs): + vllm_config = kwargs.get("vllm_config") + wrapper = WorkerWrapperBase(vllm_config=vllm_config) wrapper.init_worker(**kwargs) return wrapper.worker @@ -57,43 +52,11 @@ def _get_worker_kwargs( or (rank % self.parallel_config.tensor_parallel_size == 0), ) - def _get_worker_module_and_class( - self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: - worker_class_fn = None - if self.scheduler_config.is_multi_step: - worker_module_name = "vllm.worker.multi_step_worker" - worker_class_name = "MultiStepWorker" - elif self.speculative_config: - worker_module_name = "vllm.spec_decode.spec_decode_worker" - worker_class_name = "create_spec_worker" - else: - worker_module_name = "vllm.worker.worker" - worker_class_name = "Worker" - return (worker_module_name, worker_class_name, worker_class_fn) - - def _get_create_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None) -> Dict: - worker_kwargs = self._get_worker_kwargs(local_rank, rank, - distributed_init_method) - - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - worker_kwargs.update( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - ) - - return worker_kwargs - def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): - return create_worker(**self._get_create_worker_kwargs( + return create_worker(**self._get_worker_kwargs( local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method)) diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py index 220e9eee87bb3..c9b7bfa71edfa 100644 --- a/vllm/executor/hpu_executor.py +++ b/vllm/executor/hpu_executor.py @@ -48,10 +48,7 @@ def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): - wrapper = WorkerWrapperBase( - worker_module_name="vllm.worker.hpu_worker", - worker_class_name="HPUWorker", - ) + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, distributed_init_method)) return wrapper.worker diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 3eb14fb931925..a6c05a71d2b6f 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -90,7 +90,7 @@ def _init_executor(self) -> None: result_handler, partial( create_worker, - **self._get_create_worker_kwargs( + **self._get_worker_kwargs( rank=rank, local_rank=rank, distributed_init_method=distributed_init_method, diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 02d37cd7fbf23..a9efc4f9a801c 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -7,6 +7,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -25,14 +26,16 @@ def _init_executor(self) -> None: self._init_worker() def _init_worker(self): - from vllm.worker.neuron_worker import NeuronWorker + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = NeuronWorker( + wrapper.init_worker( vllm_config=self.vllm_config, local_rank=0, rank=0, - distributed_init_method=distributed_init_method) + distributed_init_method=distributed_init_method, + ) + self.driver_worker = wrapper.worker self.driver_worker.init_device() self.driver_worker.load_model() diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index d06b0ccb7906e..057a32364e512 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -1,19 +1,17 @@ from typing import List, Set, Tuple import openvino as ov -import openvino.properties.hint as hints -import torch import vllm.envs as envs -from vllm.config import CacheConfig, ModelConfig from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest -from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, - get_open_port, make_async) +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -29,25 +27,17 @@ def _init_executor(self) -> None: current_platform.is_openvino_gpu(), \ "OpenVINO backend supports only CPU and GPU devices" - self.ov_core = ov.Core() - self.model_config = _verify_and_get_model_config(self.model_config) - self.cache_config = _verify_and_get_cache_config( - self.ov_core, self.cache_config) - # Instantiate the worker and load the model to CPU. self._init_worker() def _init_worker(self): - from vllm.worker.openvino_worker import OpenVINOWorker - assert ( - self.parallel_config.world_size == 1 - ), "OpenVINOExecutor only supports single CPU socket currently." + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = OpenVINOWorker( - ov_core=self.ov_core, + wrapper.init_worker( + ov_core=ov.Core(), vllm_config=self.vllm_config, local_rank=0, rank=0, @@ -55,6 +45,7 @@ def _init_worker(self): kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) + self.driver_worker = wrapper.worker self.driver_worker.init_device() self.driver_worker.load_model() @@ -132,70 +123,3 @@ async def check_health_async(self) -> None: # OpenVINOExecutor will always be healthy as long as # it's running. return - - -def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: - if config.dtype != torch.float32: - logger.warning( - f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." # noqa: G004, E501 - ) - config.dtype = torch.float32 - if not config.enforce_eager: - logger.warning( - "CUDA graph is not supported on OpenVINO backend, fallback to the " - "eager mode.") - config.enforce_eager = True - return config - - -def _verify_and_get_cache_config(ov_core: ov.Core, - config: CacheConfig) -> CacheConfig: - if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": - if not current_platform.is_openvino_cpu(): - logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is" - "ignored for GPU, f16 data type will be used.") - config.cache_dtype = ov.Type.f16 - else: - logger.info("KV cache type is overridden to u8 via " - "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") - config.cache_dtype = ov.Type.u8 - else: - if current_platform.is_openvino_cpu(): - ov_device = envs.VLLM_OPENVINO_DEVICE - inference_precision = ov_core.get_property( - ov_device, hints.inference_precision) - if inference_precision == ov.Type.bf16: - config.cache_dtype = ov.Type.bf16 - else: - config.cache_dtype = ov.Type.f16 - else: - config.cache_dtype = ov.Type.f16 - - if current_platform.is_openvino_cpu(): - if config.block_size != 32: - logger.info( - f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 - ) - config.block_size = 32 - else: - if config.block_size != 16: - logger.info( - f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501 - ) - config.block_size = 16 - - kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE - if kv_cache_space >= 0: - if kv_cache_space == 0 and current_platform.is_openvino_cpu(): - config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore - logger.warning( - "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " - "for OpenVINO backend is not set, using 4 by default.") - else: - config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore - else: - raise RuntimeError( - "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" - f" {kv_cache_space}, expect a positive integer value.") - - return config diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 66bab2c686c67..6542b18ae70b1 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -91,17 +91,6 @@ def _configure_ray_workers_use_nsight(self, return ray_remote_kwargs - def _get_worker_wrapper_args(self) -> Dict[str, Any]: - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - - return dict( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - trust_remote_code=self.model_config.trust_remote_code, - ) - # child class could overwrite this to return actual env vars. def _get_env_vars_to_be_updated(self): return self._env_vars_for_all_workers @@ -135,7 +124,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. driver_ip = get_ip() - worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): continue @@ -150,7 +138,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if self.use_ray_spmd_worker: self.workers.append(worker) @@ -161,7 +149,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - **worker_wrapper_kwargs) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. self.workers.append(worker) @@ -228,8 +216,8 @@ def sort_by_driver_then_worker_ip(worker): f"Every node should have a unique IP address. Got {n_nodes}" f" nodes with node ids {list(node_workers.keys())} and " f"{n_ips} unique IP addresses {all_ips}. Please check your" - " network configuration. If you set `VLLM_HOST_IP` or " - "`HOST_IP` environment variable, make sure it is unique for" + " network configuration. If you set `VLLM_HOST_IP`" + " environment variable, make sure it is unique for" " each node.") VLLM_INSTANCE_ID = get_vllm_instance_id() diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index a24bab6df370e..a74328e5aa272 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -2,8 +2,7 @@ import os from collections import defaultdict from itertools import islice, repeat -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Type) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import msgspec @@ -18,7 +17,6 @@ from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) -from vllm.worker.worker_base import WorkerBase if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -81,33 +79,6 @@ def shutdown(self) -> None: def finish_measurements(self): self._run_workers("finish_measurements") - def _get_worker_module_and_class( - self - ) -> Tuple[str, str, Optional[Callable[[], - Type[WorkerBase]]]]: # noqa: F821 - worker_class_fn = None - if self.scheduler_config.is_multi_step: - raise NotImplementedError( - "Multi-step execution is not implemented for HPU") - elif self.speculative_config: - raise NotImplementedError( - "Speculative decoding is not implemented for HPU") - else: - worker_module_name = "vllm.worker.hpu_worker" - worker_class_name = "HPUWorker" - return (worker_module_name, worker_class_name, worker_class_fn) - - def _get_worker_wrapper_args(self) -> Dict[str, Any]: - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - - return dict( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - trust_remote_code=self.model_config.trust_remote_code, - ) - def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Otherwise, the ray workers are allocated with a full GPU. @@ -128,7 +99,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. driver_ip = get_ip() - worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("HPU", 0): continue @@ -144,7 +114,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", resources={'HPU': num_gpus}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if self.use_ray_spmd_worker: self.workers.append(worker) @@ -155,7 +125,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - **worker_wrapper_kwargs) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. self.workers.append(worker) @@ -222,8 +192,8 @@ def sort_by_driver_then_worker_ip(worker): f"Every node should have a unique IP address. Got {n_nodes}" f" nodes with node ids {list(node_workers.keys())} and " f"{n_ips} unique IP addresses {all_ips}. Please check your" - " network configuration. If you set `VLLM_HOST_IP` or " - "`HOST_IP` environment variable, make sure it is unique for" + " network configuration. If you set `VLLM_HOST_IP` " + "environment variable, make sure it is unique for" " each node.") VLLM_INSTANCE_ID = get_vllm_instance_id() diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index d02fecb46f007..c227b5e283c68 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -69,14 +69,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_bundle_index=bundle_id, ) - assert self.speculative_config is None - if self.scheduler_config.is_multi_step: - worker_module_name = "vllm.worker.multi_step_tpu_worker" - worker_class_name = "MultiStepTPUWorker" - else: - worker_module_name = "vllm.worker.tpu_worker" - worker_class_name = "TPUWorker" - # GKE does not fetch environment information from metadata server # and instead sets these from within the Ray process. Therefore we # need to override the Ray environment variables manually. @@ -95,11 +87,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", resources={"TPU": 1}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if override_env: worker.override_env_vars.remote(override_env) @@ -109,10 +97,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. self.workers.append(worker) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 36b7e2265efab..722b86a95ff8a 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -1,15 +1,11 @@ -from typing import Callable, List, Optional, Tuple, Type, Union +from typing import List, Optional, Union -import torch - -from vllm.config import ModelConfig, ParallelConfig from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async -from vllm.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -23,20 +19,8 @@ def _init_executor(self) -> None: assert self.speculative_config is None, ( "Speculative decoding not yet supported for XPU backend") - self.model_config = _verify_and_get_model_config(self.model_config) GPUExecutor._init_executor(self) - def _get_worker_module_and_class( - self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: - worker_class_fn = None - if self.speculative_config is not None: - raise NotImplementedError( - "XPU does not support speculative decoding") - else: - worker_module_name = "vllm.worker.xpu_worker" - worker_class_name = "XPUWorker" - return (worker_module_name, worker_class_name, worker_class_fn) - def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: @@ -53,26 +37,3 @@ async def execute_model_async( output = await make_async(self.driver_worker.execute_model )(execute_model_req=execute_model_req) return output - - -def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: - if config.dtype == torch.bfloat16: - logger.warning( - "bfloat16 is not fully supported on XPU, casting to float16.") - config.dtype = torch.float16 - if not config.enforce_eager: - logger.warning( - "CUDA graph is not supported on XPU, fallback to the eager " - "mode.") - config.enforce_eager = True - return config - - -def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig: - if (config.distributed_executor_backend is not None - and config.distributed_executor_backend != "ray"): - logger.warning( - "%s is not supported on XPU, fallback to ray distributed executor " - "backend.", config.distributed_executor_backend) - config.distributed_executor_backend = "ray" - return config diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 777747505e14a..aaa3e4bb3a1e8 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -1,21 +1,38 @@ from contextlib import contextmanager -from typing import Any +from dataclasses import dataclass +from typing import Any, Dict, Optional -_forward_context: Any = None +from vllm.config import VllmConfig -def get_forward_context() -> Any: +@dataclass +class ForwardContext: + static_forward_context: Dict[str, Any] + # TODO: extend to support per-layer dynamic forward context + dynamic_forward_context: Any + + +_forward_context: Optional[ForwardContext] = None + + +def get_forward_context() -> ForwardContext: """Get the current forward context.""" + assert _forward_context is not None, ( + "Forward context is not set. " + "Please use `set_forward_context` to set the forward context.") return _forward_context @contextmanager -def set_forward_context(context: Any): +def set_forward_context(context: Any, vllm_config: VllmConfig): """A context manager that stores the current forward context, can be attention metadata, etc.""" global _forward_context prev_context = _forward_context - _forward_context = context + _forward_context = ForwardContext( + static_forward_context=vllm_config.compilation_config. + static_forward_context, + dynamic_forward_context=context) try: yield finally: diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 07ff9faa50f13..fb7dbbebd7b90 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -38,6 +38,9 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" + token_type_ids: NotRequired[List[int]] + """A list of token type IDs to pass to the cross encoder model.""" + multi_modal_data: NotRequired["MultiModalDataDict"] """ DEPRECATED: Optional multi-modal data to pass to the model, @@ -133,6 +136,9 @@ class TokenInputs(TypedDict): prompt_token_ids: List[int] """The token IDs of the prompt.""" + token_type_ids: NotRequired[List[int]] + """The token type IDs of the prompt.""" + prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. @@ -160,6 +166,7 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: List[int], + token_type_ids: Optional[List[int]] = None, prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, @@ -170,6 +177,8 @@ def token_inputs( if prompt is not None: inputs["prompt"] = prompt + if token_type_ids is not None: + inputs["token_type_ids"] = token_type_ids if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data if multi_modal_placeholders is not None: @@ -234,6 +243,15 @@ def prompt_token_ids(self) -> List[int]: assert_never(inputs) + @cached_property + def token_type_ids(self) -> List[int]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("token_type_ids", []) + + assert_never(inputs) + @cached_property def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index aacff87df6d79..3d606817e90aa 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -10,7 +10,7 @@ from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2 from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from vllm.utils import print_warning_once +from vllm.utils import print_info_once, print_warning_once from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) @@ -212,7 +212,7 @@ def _can_process_multimodal(self) -> bool: # updated to use the new multi-modal processor can_process_multimodal = self.mm_registry.has_processor(model_config) if not can_process_multimodal: - logger.info( + print_info_once( "Your model uses the legacy input pipeline instead of the new " "multi-modal processor. Please note that the legacy pipeline " "will be removed in a future release. For more details, see: " @@ -305,6 +305,7 @@ def _prompt_to_llm_inputs( tokens_content = parsed["content"] prompt_token_ids = tokens_content["prompt_token_ids"] + token_type_ids = tokens_content.get("token_type_ids") multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") @@ -318,6 +319,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 68b4756331e6d..85ab4355cc2e4 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -11,8 +11,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, - resolve_mm_processor_kwargs) +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, + print_warning_once, resolve_mm_processor_kwargs) from .data import ProcessorInputs, SingletonInputs from .parse import is_encoder_decoder_inputs @@ -136,12 +136,12 @@ class InputRegistry: """ def __init__(self) -> None: - self._dummy_factories_by_model_type: Dict[Type[nn.Module], - DummyDataFactory] = {} - self._dummy_encoder_factories_by_model_type: Dict[ - Type[nn.Module], DummyDataFactory] = {} - self._input_processors_by_model_type: Dict[Type[nn.Module], - InputProcessor] = {} + self._dummy_factories_by_model_type = \ + ClassRegistry[nn.Module, DummyDataFactory]() + self._dummy_encoder_factories_by_model_type = \ + ClassRegistry[nn.Module, DummyDataFactory]() + self._input_processors_by_model_type = \ + ClassRegistry[nn.Module, InputProcessor]() def _default_dummy_data_factory( self, diff --git a/vllm/logger.py b/vllm/logger.py index 9e16e591315ba..538db0dcf19aa 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -50,7 +50,7 @@ def _configure_vllm_root_logger() -> None: - logging_config: Optional[Dict] = None + logging_config: Dict = {} if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: raise RuntimeError( @@ -75,6 +75,11 @@ def _configure_vllm_root_logger() -> None: type(custom_config).__name__) logging_config = custom_config + for formatter in logging_config.get("formatters", {}).values(): + # This provides backwards compatibility after #10134. + if formatter.get("class") == "vllm.logging.NewLineFormatter": + formatter["class"] = "vllm.logging_utils.NewLineFormatter" + if logging_config: dictConfig(logging_config) diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 3443c3feb4d2a..f5c2eced9d2bb 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -44,6 +44,11 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): Based on S-LoRA, slicing happens along the rank dim. """ + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked.shape[2] diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 6afe80219fe07..3701988ff692f 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -451,6 +451,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() self.input_size = self.base_layer.input_size @@ -508,14 +514,30 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: return lora_a def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. if bias is None: return bias tensor_model_parallel_rank = get_tensor_model_parallel_rank() @@ -779,7 +801,7 @@ def can_replace_layer( class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): """ ColumnParallelLinear layer that is specifically designed for - qkv_proj. Certain models, such as chtglm3 and baichuan-7b, + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, only contains a single LoRA within their qkv_proj layer. During inference with Tensor Parallel, the weights of lora_b diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 6a32387a6f36c..42adb191b8ead 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -75,7 +77,9 @@ def _bgmv_expand_kernel( other=0.0, ) # [BLOCK_N,BLOCK_K] if ADD_INPUTS: - tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + tiled_out = tl.load(c_ptr + current_n * cn_stride, + mask=c_mask, + other=0.0) accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out else: accumulator = tl.sum(tiled_a * tiled_b, 1) @@ -160,9 +164,24 @@ def _bgmv_expand( return +def bgmv_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> None: + return + + try: - bgmv_expand = torch.library.custom_op("lora::bgmv_expand", - _bgmv_expand, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_expand", + op_func=_bgmv_expand, + mutates_args=["output_tensor"], + fake_impl=bgmv_expand_fake, + ) + bgmv_expand = torch.ops.vllm.bgmv_expand + except AttributeError: bgmv_expand = _bgmv_expand diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index 73628fd20d327..f397d752a3ea9 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -78,7 +80,13 @@ def _bgmv_expand_slice_kernel( ) # [BLOCK_N,BLOCK_K] if ADD_INPUTS: - tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store + # operation uses the same mask, eliminating the risk of garbage + # values propagating + tiled_out = tl.load(c_ptr + current_n * cn_stride, + mask=c_mask, + other=None) accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out else: accumulator = tl.sum(tiled_a * tiled_b, 1) @@ -173,9 +181,26 @@ def _bgmv_expand_slice( return +def bgmv_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> None: + return + + try: - bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", - _bgmv_expand_slice, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_expand_slice", + op_func=_bgmv_expand_slice, + mutates_args=["output_tensor"], + fake_impl=bgmv_expand_slice_fake, + ) + bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice + except AttributeError: bgmv_expand_slice = _bgmv_expand_slice diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py index 0846ff36b1692..f3ef01d39e776 100644 --- a/vllm/lora/ops/bgmv_shrink.py +++ b/vllm/lora/ops/bgmv_shrink.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -142,9 +144,24 @@ def _bgmv_shrink( return +def bgmv_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> None: + return + + try: - bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", - _bgmv_shrink, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_shrink", + op_func=_bgmv_shrink, + mutates_args=["output_tensor"], + fake_impl=bgmv_shrink_fake, + ) + bgmv_shrink = torch.ops.vllm.bgmv_shrink + except AttributeError: bgmv_shrink = _bgmv_shrink diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 4910cb4061298..77c5178493c44 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_expand_kernel( @@ -88,7 +90,10 @@ def _sgmv_expand_kernel( c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store operation + # uses the same mask, eliminating the risk of garbage values propagating + tiled_out = tl.load(c_ptr, mask=c_mask, other=None) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) @@ -193,9 +198,30 @@ def _sgmv_expand( return +def sgmv_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +) -> None: + return + + try: - sgmv_expand = torch.library.custom_op("lora::sgmv_expand", - _sgmv_expand, - mutates_args=["output_tensor"]) + + direct_register_custom_op( + op_name="sgmv_expand", + op_func=_sgmv_expand, + mutates_args=["output_tensor"], + fake_impl=sgmv_expand_fake, + ) + sgmv_expand = torch.ops.vllm.sgmv_expand + except AttributeError: sgmv_expand = _sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 844f5cec39e93..55c4fb68ed128 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_expand_slice_kernel( @@ -94,7 +96,10 @@ def _sgmv_expand_slice_kernel( c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < (slice_offset + N)) if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store operation + # uses the same mask, eliminating the risk of garbage values propagating + tiled_out = tl.load(c_ptr, mask=c_mask, other=None) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) @@ -206,9 +211,31 @@ def _sgmv_expand_slice( return +def sgmv_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + return + + try: - sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", - _sgmv_expand_slice, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="sgmv_expand_slice", + op_func=_sgmv_expand_slice, + mutates_args=["output_tensor"], + fake_impl=sgmv_expand_slice_fake, + ) + sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice + except AttributeError: sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index b4d893047b06b..37d1dc84eebca 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_shrink_kernel( @@ -190,9 +192,29 @@ def _sgmv_shrink( return +def sgmv_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> None: + return + + try: - sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", - _sgmv_shrink, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="sgmv_shrink", + op_func=_sgmv_shrink, + mutates_args=["output_tensor"], + fake_impl=sgmv_shrink_fake, + ) + sgmv_shrink = torch.ops.vllm.sgmv_shrink + except AttributeError: sgmv_shrink = _sgmv_shrink diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 6ae7d7cf6964f..fddc8bad09ef5 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -2,9 +2,9 @@ import torch.nn as nn +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.plugins import get_current_vllm_config from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -61,10 +61,13 @@ def forward_hpu(self, *args, **kwargs): def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - + compilation_config = get_current_vllm_config().compilation_config enabled = self.enabled() - logger.debug("custom op %s %s", self.__class__.name, - "enabled" if enabled else "disabled") + if enabled: + compilation_config.enabled_custom_ops.update([self.__class__.name]) + else: + compilation_config.disabled_custom_ops.update( + [self.__class__.name]) if not enabled: return self.forward_native diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 954e51f216325..91aefbecee5f5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -257,7 +257,7 @@ def _load_per_tensor_weight_scale(self, shard_id: str, def _load_model_weight_or_group_weight_scale(self, shard_dim: int, expert_data: torch.Tensor, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int): # Load grouped weight scales for group quantization # or model weights @@ -276,7 +276,7 @@ def _load_model_weight_or_group_weight_scale(self, shard_dim: int, def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int): # for per channel weight quantization if shard_id == "w2": @@ -289,7 +289,7 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, tp_rank=tp_rank) def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim @@ -307,7 +307,7 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, expert_data.copy_(loaded_weight) def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim @@ -326,7 +326,7 @@ def _load_single_value(self, param: torch.nn.Parameter, param_data[expert_id] = loaded_weight def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, - shard_dim: int, loaded_weight: torch.tensor, tp_rank: int): + shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): if shard_id == "w2": self._load_w2(shard_id=shard_id, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d481dc4fb413f..8c5db5adaa77c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,3 +1,4 @@ +import itertools from abc import abstractmethod from typing import Dict, List, Optional, Tuple @@ -27,7 +28,8 @@ "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod", - "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod" + "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod", + "HQQMarlinMethod" ] @@ -40,12 +42,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset): def adjust_bitsandbytes_4bit_shard(param: Parameter, - qkv_offsets: Dict[str, Tuple[int, int]], + shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str) -> Tuple[int, int]: """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" - total, _ = qkv_offsets["total"] - orig_offset, orig_size = qkv_offsets[loaded_shard_id] + total, _ = shard_offsets["total"] + orig_offset, orig_size = shard_offsets[loaded_shard_id] quantized_total = param.data.shape[0] quantized_offset = orig_offset * quantized_total // total @@ -469,7 +471,8 @@ def weight_loader(self, needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) if loaded_shard_id is None: - # Loaded weight is already fused on disk (qkv/mlp). + # Loaded weight is already fused on disk (mlp). + # (e.g., Phi-3's gate_up_proj). if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( @@ -479,6 +482,8 @@ def weight_loader(self, param_data.copy_(loaded_weight) return current_shard_offset = 0 + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) shard_offsets: List[Tuple[int, int, int]] = [] for i, output_size in enumerate(self.output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) @@ -495,6 +500,16 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + if use_bitsandbytes_4bit: + index = list(itertools.accumulate([0] + self.output_sizes)) + orig_offsets = { + str(i): (index[i], size) + for i, size in enumerate(self.output_sizes) + } + orig_offsets["total"] = (self.output_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_offsets, str(shard_id)) + loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -807,7 +822,8 @@ def weight_loader(self, needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) if loaded_shard_id is None: - # Loaded weight is already fused on disk (qkv/mlp). + # Loaded weight is already fused on disk (qkv). + # (e.g., Phi-3's qkv_proj). if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index bfe2d7d0f382e..e0d42e30ebef3 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -3,11 +3,14 @@ import torch import torch.nn as nn +from transformers import PretrainedConfig from vllm.config import PoolerConfig from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) class PoolingType(IntEnum): @@ -57,9 +60,7 @@ def from_config_with_defaults( softmax: bool, step_tag_id: Optional[int] = None, returned_token_ids: Optional[List[int]] = None, - ) -> Optional["Pooler"]: - if pooler_config is None: - return None + ) -> "Pooler": return cls( pooling_type=PoolingType[pooler_config.pooling_type] if pooler_config.pooling_type is not None else pooling_type, @@ -94,14 +95,10 @@ def forward( pooled_data = hidden_states[last_token_flat_indices] elif self.pooling_type == PoolingType.ALL: offset = 0 - pooled_data_lst = [] + pooled_data = [] for prompt_len in prompt_lens: - pooled_data_i = hidden_states[offset:offset + prompt_len] - - pooled_data_lst.append(pooled_data_i) + pooled_data.append(hidden_states[offset:offset + prompt_len]) offset += prompt_len - - pooled_data = torch.stack(pooled_data_lst) elif self.pooling_type == PoolingType.MEAN: # Calculate mean pooling cumsum = torch.cumsum(hidden_states, dim=0) @@ -121,7 +118,7 @@ def forward( step_tag_id = self.step_tag_id offset = 0 - pooled_data_lst = [] + pooled_data = [] for prompt_len, seq_data_i in zip( prompt_lens, pooling_metadata.seq_data.values()): pooled_data_i = hidden_states[offset:offset + prompt_len] @@ -130,20 +127,90 @@ def forward( pooled_data_i = pooled_data_i[token_ids == step_tag_id] offset += prompt_len - pooled_data_lst.append(pooled_data_i) - - pooled_data = torch.stack(pooled_data_lst) + pooled_data.append(pooled_data_i) else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") if self.normalize: - pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + if isinstance(pooled_data, list): + pooled_data = [ + nn.functional.normalize(data, p=2, dim=1) + for data in pooled_data + ] + else: + pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) if self.softmax: - pooled_data = nn.functional.softmax(pooled_data, dim=-1) + if isinstance(pooled_data, list): + pooled_data = [ + nn.functional.softmax(data, dim=-1) for data in pooled_data + ] + else: + pooled_data = nn.functional.softmax(pooled_data, dim=-1) pooled_outputs = [ EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data ] return PoolerOutput(outputs=pooled_outputs) + + +class CrossEncodingPooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use. + normalize: Whether to normalize the pooled data. + """ + + def __init__( + self, + config: PretrainedConfig, + classifier: nn.Module, + pooler: Optional[nn.Module] = None, + ): + super().__init__() + self.classifier = classifier + self.pooler = pooler + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools sentence pair scores from the hidden_states.""" + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + offset = 0 + pooled_data_lst = [] + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + + if self.pooler is not None: + final_shape_tensor = self.pooler(pooled_data_i) + else: + final_shape_tensor = self.classifier(pooled_data_i) + + pooled_data_lst.append(final_shape_tensor) + offset += prompt_len + + pooled_output = torch.stack(pooled_data_lst) + + if self.pooler is not None: + # apply classifier once on the full batch if possible + pooled_output = self.classifier(pooled_output) + logits = self.default_activation_function(pooled_output) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in logits + ] + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index da841d052d728..dd10c434f0752 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,63 +1,87 @@ -from typing import Dict, Type +from typing import Dict, List, Type -from vllm.model_executor.layers.quantization.aqlm import AQLMConfig -from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.bitsandbytes import ( - BitsAndBytesConfig) -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsConfig) -from vllm.model_executor.layers.quantization.deepspeedfp import ( - DeepSpeedFPConfig) -from vllm.model_executor.layers.quantization.experts_int8 import ( - ExpertsInt8Config) -from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config -from vllm.model_executor.layers.quantization.fp8 import Fp8Config -from vllm.model_executor.layers.quantization.gguf import GGUFConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) -from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQMarlin24Config) -from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig -from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config -from vllm.model_executor.layers.quantization.neuron_quant import ( - NeuronQuantConfig) -from vllm.model_executor.layers.quantization.qqq import QQQConfig -from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig -QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { - "aqlm": AQLMConfig, - "awq": AWQConfig, - "deepspeedfp": DeepSpeedFPConfig, - "tpu_int8": Int8TpuConfig, - "fp8": Fp8Config, - "fbgemm_fp8": FBGEMMFp8Config, - "modelopt": ModelOptFp8Config, +QUANTIZATION_METHODS: List[str] = [ + "aqlm", + "awq", + "deepspeedfp", + "tpu_int8", + "fp8", + "fbgemm_fp8", + "modelopt", # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) - "marlin": MarlinConfig, - "gguf": GGUFConfig, - "gptq_marlin_24": GPTQMarlin24Config, - "gptq_marlin": GPTQMarlinConfig, - "awq_marlin": AWQMarlinConfig, - "gptq": GPTQConfig, - "compressed-tensors": CompressedTensorsConfig, - "bitsandbytes": BitsAndBytesConfig, - "qqq": QQQConfig, - "experts_int8": ExpertsInt8Config, - "neuron_quant": NeuronQuantConfig, - "ipex": IPEXConfig, -} + "marlin", + "gguf", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "gptq", + "compressed-tensors", + "bitsandbytes", + "qqq", + "hqq", + "experts_int8", + "neuron_quant", + "ipex", +] def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: if quantization not in QUANTIZATION_METHODS: raise ValueError(f"Invalid quantization method: {quantization}") - return QUANTIZATION_METHODS[quantization] + + # lazy import to avoid triggering `torch.compile` too early + from .aqlm import AQLMConfig + from .awq import AWQConfig + from .awq_marlin import AWQMarlinConfig + from .bitsandbytes import BitsAndBytesConfig + from .compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) + from .deepspeedfp import DeepSpeedFPConfig + from .experts_int8 import ExpertsInt8Config + from .fbgemm_fp8 import FBGEMMFp8Config + from .fp8 import Fp8Config + from .gguf import GGUFConfig + from .gptq import GPTQConfig + from .gptq_marlin import GPTQMarlinConfig + from .gptq_marlin_24 import GPTQMarlin24Config + from .hqq_marlin import HQQMarlinConfig + from .ipex_quant import IPEXConfig + from .marlin import MarlinConfig + from .modelopt import ModelOptFp8Config + from .neuron_quant import NeuronQuantConfig + from .qqq import QQQConfig + from .tpu_int8 import Int8TpuConfig + + method_to_config: Dict[str, Type[QuantizationConfig]] = { + "aqlm": AQLMConfig, + "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "tpu_int8": Int8TpuConfig, + "fp8": Fp8Config, + "fbgemm_fp8": FBGEMMFp8Config, + "modelopt": ModelOptFp8Config, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) + "marlin": MarlinConfig, + "gguf": GGUFConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, + "awq_marlin": AWQMarlinConfig, + "gptq": GPTQConfig, + "compressed-tensors": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, + "qqq": QQQConfig, + "hqq": HQQMarlinConfig, + "experts_int8": ExpertsInt8Config, + "neuron_quant": NeuronQuantConfig, + "ipex": IPEXConfig, + } + + return method_to_config[quantization] __all__ = [ diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index bbb7fc8ad5087..ace8f4a348812 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -42,7 +42,7 @@ def awq_dequantize_kernel( result_masks = result_masks_y[:, None] & result_masks_x[None, :] # Load the weights. - iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.load(qweight_ptr + offsets, masks, 0.0) iweights = tl.interleave(iweights, iweights) iweights = tl.interleave(iweights, iweights) iweights = tl.interleave(iweights, iweights) @@ -71,7 +71,7 @@ def awq_dequantize_kernel( zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] # Load the zeros. - zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) @@ -91,7 +91,7 @@ def awq_dequantize_kernel( scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] # Load the scales. - scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0) scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) # Dequantize. @@ -165,10 +165,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): masks_k = offsets_k < K masks_a = masks_am[:, None] & masks_k[None, :] - a = tl.load(a_ptrs, mask=masks_a) + a = tl.load(a_ptrs, mask=masks_a, other=0.0) masks_b = masks_k[:, None] & masks_bn[None, :] - b = tl.load(b_ptrs, mask=masks_b) + b = tl.load(b_ptrs, mask=masks_b, other=0.0) b = tl.interleave(b, b) b = tl.interleave(b, b) b = tl.interleave(b, b) @@ -181,7 +181,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, masks_zk = offsets_szk < K // group_size masks_z = masks_zk[:, None] & masks_zn[None, :] zeros_ptrs = zeros_ptr + offsets_z - zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) @@ -191,7 +191,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, masks_sk = offsets_szk < K // group_size masks_s = masks_sk[:, None] & masks_sn[None, :] scales_ptrs = scales_ptr + offsets_s - scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.load(scales_ptrs, mask=masks_s, other=0.0) scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) b = (b >> shifts) & 0xF diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 39965ac9115c2..e01c713dd14db 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -20,17 +20,19 @@ def __init__( load_in_8bit: bool = False, load_in_4bit: bool = True, bnb_4bit_compute_dtype: str = "float32", + bnb_4bit_quant_storage: str = "uint8", bnb_4bit_quant_type: str = "fp4", bnb_4bit_use_double_quant: bool = False, llm_int8_enable_fp32_cpu_offload: bool = False, llm_int8_has_fp16_weight: bool = False, llm_int8_skip_modules: Optional[List[str]] = None, - llm_int8_threshold: float = 0.0, + llm_int8_threshold: float = 6.0, ) -> None: self.load_in_8bit = load_in_8bit self.load_in_4bit = load_in_4bit self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + self.bnb_4bit_quant_storage = bnb_4bit_quant_storage self.bnb_4bit_quant_type = bnb_4bit_quant_type self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload @@ -38,10 +40,15 @@ def __init__( self.llm_int8_skip_modules = llm_int8_skip_modules or [] self.llm_int8_threshold = llm_int8_threshold + if self.bnb_4bit_quant_storage not in ["uint8"]: + raise ValueError("Unsupported bnb_4bit_quant_storage: " + f"{self.bnb_4bit_quant_storage}") + def __repr__(self) -> str: return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, " f"load_in_4bit={self.load_in_4bit}, " f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, " + f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, " f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, " f"llm_int8_skip_modules={self.llm_int8_skip_modules})") @@ -80,6 +87,9 @@ def get_safe_value(config, keys, default_value=None): bnb_4bit_compute_dtype = get_safe_value(config, ["bnb_4bit_compute_dtype"], default_value="float32") + bnb_4bit_quant_storage = get_safe_value(config, + ["bnb_4bit_quant_storage"], + default_value="uint8") bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], default_value="fp4") bnb_4bit_use_double_quant = get_safe_value( @@ -93,12 +103,13 @@ def get_safe_value(config, keys, default_value=None): ["llm_int8_skip_modules"], default_value=[]) llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], - default_value=0.0) + default_value=6.0) return cls( load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit, bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_quant_storage=bnb_4bit_quant_storage, bnb_4bit_quant_type=bnb_4bit_quant_type, bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 24138662eb25c..f0943efa0039d 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -2,6 +2,7 @@ import gguf import torch +from gguf import GGMLQuantizationType as WeightType from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops @@ -49,19 +50,65 @@ def get_quant_method(self, layer: torch.nn.Module, return None +UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} +STANDARD_QUANT_TYPES = { + WeightType.Q4_0, + WeightType.Q4_1, + WeightType.Q5_0, + WeightType.Q5_1, + WeightType.Q8_0, + WeightType.Q8_1, +} +KQUANT_TYPES = { + WeightType.Q2_K, + WeightType.Q3_K, + WeightType.Q4_K, + WeightType.Q5_K, + WeightType.Q6_K, +} +IMATRIX_QUANT_TYPES = { + WeightType.IQ1_M, + WeightType.IQ1_S, + WeightType.IQ2_XXS, + WeightType.IQ2_XS, + WeightType.IQ2_S, + WeightType.IQ3_XXS, + WeightType.IQ3_S, + WeightType.IQ4_XS, + WeightType.IQ4_NL, +} +# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization. +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add +# MMQ kernel for I-Matrix quantization. +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES + + def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: - # use dequantize mulmat for IQmatrix, mmq for k-quants - if x.shape[0] == 1: - # enable mmvq in contiguous batching + # there is no need to call any kernel for fp16/bf16 + if qweight_type in UNQUANTIZED_TYPES: + return x @ qweight.T + # enable MMVQ in contiguous batching with batch_size=1 + if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES: y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) - elif qweight_type >= 16: + # Use MMQ Kernel if it's available (standard + k-quants) + elif qweight_type in MMQ_QUANT_TYPES: + y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + # If there is no available MMQ kernel, fallback to dequantize + elif qweight_type in DEQUANT_TYPES: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) weight = ops.ggml_dequantize(qweight, qweight_type, *shape) y = x @ weight.T else: - y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + # Raise an error if the quantization type is not supported. + # Might be useful if llama.cpp adds a new quantization type. + # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. + qweight_type = WeightType(qweight_type) + raise NotImplementedError( + f"Unsupported GGUF quantization type: {qweight_type}") return y @@ -121,9 +168,9 @@ def apply(self, shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id qweight = layer.qweight.unbind(0) result = [] - for id in shard_id: - q_idx = layer.qweight.shard_id_map[id] - qweight_type = layer.qweight_type.shard_weight_type[id] + for idx in shard_id: + q_idx = layer.qweight.shard_id_map[idx] + qweight_type = layer.qweight_type.shard_weight_type[idx] result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type)) out = torch.cat(result, axis=1) else: @@ -163,9 +210,13 @@ class GGUFUninitializedParameter(UninitializedParameter): data_container: List[torch.Tensor] def materialize_nested(self) -> Parameter: + dtype = {data.dtype for data in self.data_container} + assert len(dtype) == 1, ValueError( + f"Data container has mixed dtypes: {dtype}") + dtype = next(iter(dtype)) nested_data = torch.nested.nested_tensor(self.data_container, device=self.device, - dtype=torch.uint8) + dtype=dtype) self.data_container.clear() param = torch.Tensor._make_subclass(self.cls_to_become, nested_data, diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 0aa605e62454e..abafad0f1047e 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -210,7 +210,6 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # for torch.compile - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 1f72e3afbbce5..a3e58bf1b2a4c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -23,6 +23,7 @@ PackedColumnParameter, PackedvLLMParameter, RowvLLMParameter) +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -134,6 +135,9 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): sym = quant_config.get("sym") desc_act = quant_config.get("desc_act") + if not current_platform.is_cuda(): + return False + if quant_method != "gptq": return False diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py new file mode 100644 index 0000000000000..28538d2993355 --- /dev/null +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -0,0 +1,325 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, + marlin_make_empty_g_idx, marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace) +from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack +from vllm.model_executor.parameter import (BasevLLMParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class HQQMarlinConfig(QuantizationConfig): + """Config class for HQQ Marlin""" + + def __init__( + self, + weight_bits: int, + group_size: int, + skip_modules: Optional[List[str]] = None, + ) -> None: + assert group_size == 64, ("The only supported HQQ group size is " + "currently 64.") + assert weight_bits == 4, ("The only supported HQQ quantization " + "bitsize is currently 4.") + + self.weight_bits = weight_bits + self.group_size = group_size + self.pack_factor = 32 // weight_bits # packed into int32 in GPTQ format + self.quant_type = scalar_types.uint4 + self.skip_modules = skip_modules + + def __repr__(self) -> str: + return (f"HQQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size})") + + @classmethod + def get_name(cls) -> str: + return "hqq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig": + wq_params = (config["quant_config"]["weight_quant_params"]) + weight_bits = cls.get_from_keys(wq_params, ["nbits"]) + group_size = cls.get_from_keys(wq_params, ["group_size"]) + skip_modules = config["skip_modules"] + return cls(weight_bits, group_size, skip_modules) + + def is_layer_skipped(self, prefix: str) -> bool: + # Split the prefix into its dot-separated components + components = prefix.split('.') + + # Check if any of the skip modules exactly matches any component + return self.skip_modules is not None and any( + module_name in components for module_name in self.skip_modules) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if self.is_layer_skipped(prefix): + return UnquantizedLinearMethod() + return HQQMarlinMethod(self) + return None + + +# Empty HQQ parameter, will be ignored during loading +class HQQEmptyParameter(BasevLLMParameter): + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + pass + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + pass + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + pass + + +def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + raise ValueError("No loader provided for HQQ parameter!") + + +# HQQ packing creates issues with sharding - therefore, prior to loading, we +# repack to GPTQ. We also reshape the weights to their proper GPTQ shape. +class HQQweightParameter(PackedvLLMParameter): + + # unpack function from https://github.com/mobiusml/hqq + def unpack_4bit_u8(self, + W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8 + assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)" + + dtype = torch.uint8 + step = W_q.shape[0] + tmp = torch.empty([2 * step, W_q.shape[1]], + dtype=dtype, + device=W_q.device) + tmp[:step] = (W_q & 0b11110000) >> 4 + tmp[step:] = W_q & 0b00001111 + return tmp + + def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int, + **kwargs): + super().__init__(packed_factor, packed_dim, None, **kwargs) + self.weight_bits = weight_bits + self.input_shape = self.shape[self.input_dim] * self.packed_factor + self.output_shape = self.shape[self.output_dim] + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( + 1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_merged_column_weight(loaded_weight, **kwargs) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(self.output_shape, + -1).transpose(1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_row_parallel_weight(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = self.unpack_4bit_u8(loaded_weight) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( + 1, 0) + loaded_weight = gptq_pack(loaded_weight, self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1]) + super().load_qkv_weight(loaded_weight, **kwargs) + + +# Zero points and scales in HQQ must also be reshaped to correspond to W_q's +# GPTQ shape (transposed - we transpose them too when processing weights). +class HQQZeroScaleParameter(GroupQuantScaleParameter): + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = loaded_weight.reshape(-1, self.shape[1]) + super().load_merged_column_weight(loaded_weight, **kwargs) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + loaded_weight = loaded_weight.reshape(self.shape[0], -1) + super().load_row_parallel_weight(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + loaded_weight = loaded_weight.reshape(-1, self.shape[1]) + super().load_qkv_weight(loaded_weight, **kwargs) + + +class HQQMarlinMethod(LinearMethodBase): + """Linear method for HQQ Marlin. + """ + + def __init__( + self, + quant_config: HQQMarlinConfig, + ): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + self.output_size_per_partition = sum(output_partition_sizes) + self.input_size_per_partition = input_size_per_partition + + weight_loader = extra_weight_attrs.get("weight_loader", error_loader) + + self.scales_and_zp_size = (input_size_per_partition // + self.quant_config.group_size) + + qweight = HQQweightParameter( + data=torch.empty( + self.input_size_per_partition // self.quant_config.pack_factor, + self.output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_bits=self.quant_config.weight_bits, + weight_loader=weight_loader) + + zeros = HQQZeroScaleParameter(data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + scales = HQQZeroScaleParameter(data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("W_q", qweight) + layer.register_parameter("zero", zeros) + layer.register_parameter("scale", scales) + + # Ignore extra parameters in the HQQ model. + # To be added as needed. + ignore_parameters = ("axis", "channel_wise", "compute_dtype", + "encoded_state_dict", "group_size", "nbits", + "offload_meta", "optimize", "packing", + "quant_scale", "quant_zero", "round_zero", + "shape", "stores_quant_config", + "unpack_view_dtype", "view_as_float") + for name in ignore_parameters: + layer.register_parameter( + name, + HQQEmptyParameter(data=torch.empty(0), + weight_loader=weight_loader)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + dev = layer.W_q.device + + # Repack to Marlin + sort_indices = torch.empty(0, dtype=torch.int, device=dev) + marlin_w_q = ops.gptq_marlin_repack( + layer.W_q, + sort_indices, + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.weight_bits, + ).to(dev) + marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size).to(dev) + marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size).to(dev) + + layer.g_idx = marlin_make_empty_g_idx(dev) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) + + layer.marlin_qweight = marlin_w_q + layer.marlin_zeros = marlin_zp + layer.marlin_scales = marlin_s + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + workspace = MarlinWorkspace(self.output_size_per_partition, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + scales = layer.marlin_scales + zeros = layer.marlin_zeros + orig_type = x.dtype + + if orig_type != torch.float16: + x = x.to(torch.float16) + scales = scales.to(torch.float16) + zeros = zeros.to(torch.float16) + + marlin_out = ops.gptq_marlin_gemm( + x, + layer.marlin_qweight, + scales, + zeros, + layer.g_idx, + layer.g_idx_sort_indices, + workspace.scratch, + scalar_types.uint4, + x.shape[0], + self.output_size_per_partition, + self.input_size_per_partition, + True, # is_k_full + True, # has_zp + True, # use 32-bit reduce + True, # use float zp + ) + + if orig_type != torch.float16: + marlin_out = marlin_out.to(orig_type) + + if bias is not None: + marlin_out.add_(bias) + + return marlin_out diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 330c2ad195d78..c16a962134d06 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -2,21 +2,26 @@ import torch -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization.awq import AWQLinearMethod +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, + is_layer_skipped_awq) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.platforms import current_platform +MIN_IPEX_VERSION = "2.5.0" + class IPEXConfig(QuantizationConfig): - """INT8 quantization config class using IPEX for the CPU backend, - including AWQ. + """INT8 quantization config class using IPEX for the CPU/XPU backend, + including AWQ, GPTQ. """ IPEX_QUANT_METHOD_MAP = { "awq": 1, - "gptq": 2, + "gptq": 0, } def __init__( @@ -24,29 +29,30 @@ def __init__( method: str, weight_bits: int, group_size: int, + modules_to_not_convert: Optional[List[str]] = None, + desc_act: Optional[bool] = None, + lm_head_quantized: Optional[bool] = None, ) -> None: self.method = method self.weight_bits = weight_bits self.group_size = group_size + self.modules_to_not_convert = modules_to_not_convert or [] + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized self.pack_factor = 32 // self.weight_bits if self.weight_bits not in [4]: raise ValueError(f"IPEX quantization supports weight bits [4], " f"but got {self.weight_bits}.") - if self.method == "awq": - self.quant_method = IPEXAWQLinearMethod - else: - raise ValueError(f"IPEX quantization supports [awq], " + if self.method not in ["awq", "gptq"]: + raise ValueError(f"IPEX quantization supports [awq, gptq], " f"but got {self.method}.") def __repr__(self) -> str: - return (f"IPEXConfig(method={self.method}" + return (f"IPEXConfig(method={self.method}," f"weight_bits={self.weight_bits}, " - f"group_size={self.group_size}") - - def get_ipex_quant_method_id(self) -> int: - return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method] + f"group_size={self.group_size})") @classmethod def get_name(cls) -> str: @@ -70,19 +76,32 @@ def get_config_filenames() -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig": method = cls.get_from_keys(config, ["quant_method"]).lower() - weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) - return cls(method, weight_bits, group_size) + if method == "awq": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, + ["q_group_size", "group_size"]) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(method, weight_bits, group_size, modules_to_not_convert, + False, False) + # otherwise for gptq + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) + return cls(method, weight_bits, group_size, [], desc_act, + lm_head_quantized) @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - if not current_platform.is_cpu(): + if not current_platform.is_cpu() and not current_platform.is_xpu(): return None quant_method = hf_quant_cfg.get("quant_method", "").lower() - if quant_method in ["awq"]: + if quant_method in ["awq", "gptq"]: return cls.get_name() return None @@ -90,12 +109,81 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["LinearMethodBase"]: if isinstance(layer, LinearBase): - return self.quant_method(self) + if self.method == "awq": + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return IPEXAWQLinearMethod(self) + if self.method == "gptq": + return IPEXGPTQLinearMethod(self) return None +class IPEXGPTQLinearMethod(GPTQLinearMethod): + """GPTQ linear method using IPEX for the CPU/XPU backend. + """ + + def __init__(self, quant_config: IPEXConfig): + self.quant_config = quant_config # type: ignore + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + bias = layer.bias if not layer.skip_bias_add else None + + try: + import intel_extension_for_pytorch as ipex + if ipex.__version__ < MIN_IPEX_VERSION: + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + except ImportError as err: + raise ImportError( + "Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" + " to use IPEX-AWQ linear method.") from err + # Using the compute dtype (lowp_mode) as INT8 to leverage instructions + # with better performance. + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + # The weight will be de-packed from INT4 to INT8. + weight_dtype = ipex.quantization.WoqWeightDtype.INT4 + # The float activation will be quantized (dynamic, per-token) to INT8. + act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.quant_config.group_size, + ) + layer.ipex_output_size = layer.qweight.shape[-1] + g_idx = layer.g_idx if self.quant_config.desc_act else None + layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ + IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + g_idx=g_idx, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"] + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + if bias is not None: + out.add_(bias) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + + class IPEXAWQLinearMethod(AWQLinearMethod): - """AWQ linear method using IPEX for the CPU backend. + """AWQ linear method using IPEX for the CPU/XPU backend. """ def __init__(self, quant_config: IPEXConfig): @@ -108,15 +196,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: try: import intel_extension_for_pytorch as ipex - if ipex.__version__ < "2.4.0": - raise ImportError("intel_extension_for_pytorch version is " - "wrong. Please install " - "intel_extension_for_pytorch>=2.4.0.") + if ipex.__version__ < MIN_IPEX_VERSION: + raise ImportError( + "intel_extension_for_pytorch version is " + "wrong. Please install " + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") except ImportError as err: raise ImportError( "Please install " - "intel_extension_for_pytorch>=2.4.0 via " - "`pip install intel_extension_for_pytorch>=2.4.0`" + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " + f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" " to use IPEX-AWQ linear method.") from err # Using the compute dtype (lowp_mode) as INT8 to leverage instructions @@ -136,19 +225,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.ipex_output_size = layer.qweight.size( 1) * self.quant_config.pack_factor - layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\ - WeightOnlyQuantizedLinear.from_weight( - layer.qweight, - layer.scales, - layer.qzeros, - layer.qweight.size(0), - layer.ipex_output_size, - qconfig=qconfig, - bias=bias, - group_size=self.quant_config.group_size, - quant_method= - self.quant_config.get_ipex_quant_method_id() # type: ignore - ) + layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ + IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore + ) def apply(self, layer: torch.nn.Module, @@ -156,5 +244,4 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) - return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py index e5696d08f30f5..15df0200f30b5 100644 --- a/vllm/model_executor/layers/quantization/kernels/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -79,7 +79,9 @@ def transform_w_q(x): c.weight_type, packed_dim=0) x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), - self.config.weight_type) + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type) return x def transform_w_s(x): @@ -105,12 +107,12 @@ def apply_weights(self, if c.has_g_idx: x_2d = self.act_perm(x_2d) - output = ops.machete_gemm(a=x_2d, - b_q=w_q, - b_type=c.weight_type, - b_zeros=None, - b_scales=w_s, - b_group_size=c.group_size) + output = ops.machete_mm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=None, + b_group_scales=w_s, + b_group_size=c.group_size) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 9a1defa409714..c9366ca97d149 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -303,7 +303,8 @@ def apply_gptq_marlin_linear( size_k=input_size_per_partition, is_k_full=is_k_full, has_zp=False, - use_fp32_reduce=use_fp32_reduce) + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -340,7 +341,8 @@ def apply_awq_marlin_linear( size_k=input_size_per_partition, is_k_full=True, has_zp=True, - use_fp32_reduce=use_fp32_reduce) + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index c217f5ca620a1..83055d6000d83 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor, def quantize_weights(w: torch.Tensor, quant_type: ScalarType, - group_size: int, + group_size: Optional[int], zero_points: bool = False, ref_zero_points_after_scales: bool = False): assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, \ + "to have group zero points, group_size must be provided "\ + "(-1 group_size is channelwise)" orig_device = w.device orig_type = w.dtype @@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor, if group_size == -1: group_size = size_k - assert group_size <= size_k # Reshape to [groupsize, -1] - if group_size < size_k: + if group_size is not None and group_size < size_k: w = w.reshape((-1, group_size, size_n)) w = w.permute(1, 0, 2) w = w.reshape((group_size, -1)) @@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor, max_q_val = quant_type.max() min_q_val = quant_type.min() - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ - .clamp(min_q_val, max_q_val).int() - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) - maybe_w_zp = None + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) # Quantize w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) @@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor, # For some kernels (namely Machete) the zero-points are applied after the # scales are applied, for this case computing the reference in similar way # allows us to use tighter error tolerances in our unit tests. - if ref_zero_points_after_scales and zero_points: + if ref_zero_points_after_scales and maybe_w_zp is not None: w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s else: w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s @@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor, w_q += quant_type.bias # Restore original shapes - if group_size < size_k: + if group_size is not None and group_size < size_k: def reshape_w(w): w = w.reshape((group_size, -1, size_n)) @@ -195,17 +199,16 @@ def reshape_w(w): w_q = reshape_w(w_q) w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() - w_s = w_s.reshape((-1, size_n)).contiguous() - - if zero_points: + if maybe_w_zp is not None: maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() maybe_w_zp = maybe_w_zp.to(device=orig_device) return ( w_ref.to(device=orig_device), w_q.to(device=orig_device), - w_s.to(device=orig_device), + w_s if group_size is not None else None, maybe_w_zp, ) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 7e750a744e25f..6aa4b8bd34cde 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -43,6 +43,21 @@ def init_gpu_tensors(self, device: Union[int, str]) -> None: dtype=torch.long, device=device) + def init_tensors(self, + device: Union[int, str], + device_type: Union[torch.device, str] = 'cuda') -> None: + assert self.num_accepted_tokens is None + if isinstance(device_type, torch.device): + device_type = device_type.type + if isinstance(device, int): + device = f"{device_type}:{device}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + @property def probs_dtype(self): return torch.float32 @@ -77,7 +92,7 @@ def _create_output( tensor is [batch_size, k + num_bonus_tokens] """ batch_size, k = substitute_token_ids.shape - bonus_token_ids = bonus_token_ids.squeeze() + bonus_token_ids = bonus_token_ids.squeeze(-1) # Determine the index of the first False value for each row. limits = (accepted == 0).max(1).indices limits[~(accepted == 0).any(1)] = k diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index d9ce85949e4ee..0e12bc5691538 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -5,9 +5,11 @@ import fnmatch import glob import inspect +import itertools import json import math import os +import warnings from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast @@ -22,13 +24,18 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, - VllmConfig) + VllmConfig, set_current_vllm_config) from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ReplicatedLinear, +from vllm.model_executor.layers.linear import (LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizeMethodBase) from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, serialize_vllm_model, tensorizer_weights_iterator) @@ -42,7 +49,6 @@ safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.plugins import set_current_vllm_config from vllm.utils import is_pin_memory_available @@ -74,12 +80,14 @@ def device_loading_context(module: torch.nn.Module, original_device: torch.device = original_device_states[name] if original_device.type == "cpu": # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty_strided(size=p.data.size(), - stride=p.data.stride(), - dtype=p.data.dtype, - layout=p.data.layout, - device="cpu", - pin_memory=pin_memory) + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) cpu_data.copy_(p.data) p.data = cpu_data else: @@ -90,25 +98,35 @@ def device_loading_context(module: torch.nn.Module, logger = init_logger(__name__) -def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: +def _initialize_model( + vllm_config: VllmConfig, + *, + prefix: str = "", + architectures: Optional[list[str]] = None, +) -> nn.Module: """Initialize a model with the given configurations.""" model_config = vllm_config.model_config - model_class, _ = get_model_architecture(model_config) + model_class, _ = get_model_architecture(model_config, + architectures=architectures) + signatures = inspect.signature(model_class.__init__) all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class with set_current_vllm_config(vllm_config): return model_class(vllm_config=vllm_config, prefix=prefix) + msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. " - "Check https://docs.vllm.ai/en/latest/design/class_hierarchy.html " + "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " "for the design and update the model class accordingly.") - logger.warning(msg) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + logger.warning( "Trying to guess the arguments for old-style model class %s", - model_class) + model_class, + ) # try to be compatible with old-style model class kwargs = {} if "prefix" in all_params: @@ -194,14 +212,17 @@ def _maybe_download_from_modelscope( return model_path return None - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str], - fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + def _prepare_weights( + self, + model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool, + ) -> Tuple[str, List[str], bool]: """Prepare weights for the model. If the model is not local, it will be downloaded.""" - model_name_or_path = self._maybe_download_from_modelscope( - model_name_or_path, revision) or model_name_or_path + model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path) is_local = os.path.isdir(model_name_or_path) load_format = self.load_config.load_format @@ -254,8 +275,11 @@ def _prepare_weights(self, model_name_or_path: str, # any files not found in the index. if not is_local: download_safetensors_index_file_from_hf( - model_name_or_path, index_file, - self.load_config.download_dir, revision) + model_name_or_path, + index_file, + self.load_config.download_dir, + revision, + ) hf_weights_files = filter_duplicate_safetensors_files( hf_weights_files, hf_folder, index_file) else: @@ -278,8 +302,11 @@ def _get_weights_iterator( # Currently np_cache only support *.bin checkpoints assert use_safetensors is False weights_iterator = np_cache_weights_iterator( - source.model_or_path, self.load_config.download_dir, hf_folder, - hf_weights_files) + source.model_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + ) elif use_safetensors: weights_iterator = safetensors_weights_iterator(hf_weights_files) else: @@ -306,17 +333,19 @@ def _get_all_weights( model_config: ModelConfig, model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: - primary_weights = DefaultModelLoader.Source( model_config.model, model_config.revision, prefix="", fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", - True)) + True), + ) yield from self._get_weights_iterator(primary_weights) - secondary_weights = cast(Iterable[DefaultModelLoader.Source], - getattr(model, "secondary_weights", ())) + secondary_weights = cast( + Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ()), + ) for source in secondary_weights: yield from self._get_weights_iterator(source) @@ -337,7 +366,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( self._get_all_weights(model_config, model)) - # We only enable strict check for non-quantiized models + # We only enable strict check for non-quantized models # that have loaded weights tracking currently. if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights @@ -348,7 +377,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) - if quant_method is not None: + if isinstance(quant_method, QuantizeMethodBase): # When quant methods need to process weights after loading # (for repacking, quantizing, etc), they expect parameters # to be on the global target device. This scope is for the @@ -412,7 +441,7 @@ def _verify_config(self, model_config: ModelConfig, self.tensorizer_config.verify_with_parallel_config(parallel_config) def _get_weights_iterator( - self) -> Generator[Tuple[str, torch.Tensor], None, None]: + self, ) -> Generator[Tuple[str, torch.Tensor], None, None]: tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) @@ -475,9 +504,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: if parallel_config.tensor_parallel_size > 1: from vllm.distributed import get_tensor_model_parallel_rank - self.tensorizer_config.tensorizer_uri = \ - self.tensorizer_config.tensorizer_uri \ - % get_tensor_model_parallel_rank() + + self.tensorizer_config.tensorizer_uri = ( + self.tensorizer_config.tensorizer_uri % + get_tensor_model_parallel_rank()) if is_vllm_tensorized(self.tensorizer_config): return self._load_model_serialized(vllm_config=vllm_config) @@ -516,13 +546,13 @@ def __init__(self, load_config: LoadConfig): @staticmethod def _filter_subtensors( - tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: """ Filter out all tensors that share the same memory or a subset of the memory of another tensor. """ - same_storage_groups: Dict[Any, List[Tuple[ - str, torch.Tensor]]] = collections.defaultdict(list) + same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = ( + collections.defaultdict(list)) for key, tensor in tensors.items(): if tensor.numel(): ptr = tensor.untyped_storage().data_ptr() @@ -611,8 +641,11 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: if tensor.shape != param_shape: logger.warning( "loading tensor of shape %s into " - "parameter '%s' of shape %s", tensor.shape, - key, param_shape) + "parameter '%s' of shape %s", + tensor.shape, + key, + param_shape, + ) param_data.copy_(tensor) state_dict.pop(key) if state_dict: @@ -630,6 +663,7 @@ def save_model( from safetensors.torch import save_file from vllm.distributed import get_tensor_model_parallel_rank + if pattern is None: pattern = ShardedStateLoader.DEFAULT_PATTERN rank = get_tensor_model_parallel_rank() @@ -663,24 +697,6 @@ class BitsAndBytesModelLoader(BaseModelLoader): possible_config_file_names = ["adapter_config.json"] - default_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - '.fc1.', - '.fc2.', - '.dense.', - '.query_key_value.', - '.qkv_proj.', - '.dense_h_to_4h.', - '.dense_4h_to_h.', - '.out_proj.', - ] - def __init__(self, load_config: LoadConfig): super().__init__(load_config) @@ -705,6 +721,11 @@ def __init__(self, load_config: LoadConfig): with open(config_file_path) as f: config = json.load(f) self.target_modules = config["target_modules"] + # TODO: target_modules could be either a list or a regex string. + # We need to handle both cases. + assert isinstance(self.target_modules, + list), "Unsupported target_modules: " + f"{self.target_modules}" def _get_config_file(self, qlora_adapter: str) -> str: is_local = os.path.isdir(qlora_adapter) @@ -730,12 +751,13 @@ def _get_config_file(self, qlora_adapter: str) -> str: return config_file_path def _get_weight_files( - self, - model_name_or_path: str, - allowed_patterns: List[str], - revision: Optional[str] = None) -> Tuple[List[str], str]: - """Retrieve weight files. Download the files if necessary. - + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None, + ) -> Tuple[List[str], str]: + """Retrieve weight files. Download the files if necessary. + Return the weight files and the file pattern.""" is_local = os.path.isdir(model_name_or_path) @@ -802,6 +824,7 @@ def _get_quantized_weights_iterator( # only load the bitsandbytes module when needed try: import bitsandbytes + if bitsandbytes.__version__ < "0.44.0": raise ImportError("bitsandbytes version is wrong. Please " "install bitsandbytes>=0.44.0.") @@ -835,8 +858,11 @@ def _is_8bit_weight_name(self, weight_name: str): def _is_4bit_weight_name(self, weight_name: str): quantized_suffix = { - "absmax", "quant_map", "nested_absmax", "nested_quant_map", - "bitsandbytes" + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "bitsandbytes", } suffix = weight_name.split(".")[-1] return any(q_suffix in suffix for q_suffix in quantized_suffix) @@ -853,7 +879,6 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): - if self._is_8bit_weight_name(weight_name): continue @@ -895,14 +920,13 @@ def _parse_quant_state(param_name: str, # pre quantized weights would have a quant_state for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): - if self._is_4bit_weight_name(weight_name): continue - if (f"{weight_name}.quant_state.bitsandbytes__nf4" \ - in temp_state_dict) or \ - (f"{weight_name}.quant_state.bitsandbytes__fp4" \ - in temp_state_dict): + if (f"{weight_name}.quant_state.bitsandbytes__nf4" + in temp_state_dict) or ( + f"{weight_name}.quant_state.bitsandbytes__fp4" + in temp_state_dict): quant_state = _parse_quant_state(weight_name, temp_state_dict) quant_state_dict[weight_name] = quant_state yield weight_name, weight_tensor @@ -912,12 +936,12 @@ def _parse_quant_state(param_name: str, def _unquantized_generator(self, hf_weights_files, use_safetensors, quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit + tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): - if any(target_module in weight_name for target_module in self.target_modules) and weight_name.endswith(".weight"): # Without sharding @@ -934,6 +958,33 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, end_index = total_size // tp_size * (tp_rank + 1) weight_sub_tensor = weight_tensor[..., start_index:end_index] + # Weights have fused on disk. In this case, we assume that the + # weight and module use same name. + elif any( + weight_name.startswith(module) + for module in self.maybe_fused_weights_modules): + # special case for fused weights + # get the size of each shard weight tensor + total_shard_sizes = next( + (sizes for module, sizes in + self.maybe_fused_weights_modules.items() + if weight_name.startswith(module))) + total_size = weight_tensor.size(0) + assert total_size == sum(total_shard_sizes) + # get the start/end index of each shard weight tensor + total_start_index = list( + itertools.accumulate([0] + total_shard_sizes))[:-1] + shard_weights_index = [( + idx + size // tp_size * tp_rank, + idx + size // tp_size * (tp_rank + 1), + ) for idx, size in zip(total_start_index, + total_shard_sizes)] + # slice and reorder the weight tensor + weight_tensor = [ + weight_tensor[start_index:end_index, ...] + for start_index, end_index in shard_weights_index + ] + weight_sub_tensor = torch.cat(weight_tensor, dim=0) # Shard by row else: total_size = weight_tensor.size(0) @@ -957,7 +1008,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, processed_weight, quant_state = quantize_4bit( loaded_weight, compress_statistics=True, - quant_type="nf4") + quant_type="nf4", + ) quant_state_dict[weight_name] = quant_state else: @@ -965,30 +1017,70 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, yield weight_name, processed_weight + def _get_bnb_target_modules(self, model: nn.Module) -> None: + + # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with + # packed_modules_mapping. + inverse_stacked_mapping: Dict[str, List[str]] = {} + for orig, ( + packed, + idx, + ) in model.bitsandbytes_stacked_params_mapping.items(): + if packed not in inverse_stacked_mapping: + inverse_stacked_mapping[packed] = [] + inverse_stacked_mapping[packed].insert(idx, orig) + + linear_module_lst = [] + for name, module in model.named_modules(): + if isinstance(module, (LinearBase, )): + last_name = name.split(".")[-1] + if sub_modules := inverse_stacked_mapping.get(last_name, []): + # Map vllm's names to transformers' names. + for sub_name in sub_modules: + linear_module_lst.append( + name.replace(last_name, sub_name)) + else: + linear_module_lst.append(name) + if self.target_modules: + # Update self.target_modules + self.target_modules = [ + qual_name for qual_name in linear_module_lst + if any(t in qual_name for t in self.target_modules) + ] + else: + self.target_modules = linear_module_lst + assert (self.target_modules + ), "vllm currently does not support BNB quantization for" + f" {type(model).__name__}" + def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None: - if not hasattr(model, 'load_weights'): + if not hasattr(model, "load_weights"): raise AttributeError( "The required method 'load_weights' is not defined in class" f" {type(model).__name__}.") - if not hasattr(model, 'bitsandbytes_stacked_params_mapping'): + if not hasattr(model, "bitsandbytes_stacked_params_mapping"): raise AttributeError( f"Model {type(model).__name__} does not support BitsAndBytes " "quantization yet.") - if len(self.target_modules) == 0: - if hasattr(model, 'default_bitsandbytes_target_modules'): - self.target_modules = model.default_bitsandbytes_target_modules - else: - self.target_modules = self.default_target_modules - + # Modules whose weights might have fused on disk + # we need their output_sizes to make shard in flight correctly with TP + self.maybe_fused_weights_modules: Dict[str, List[int]] = {} + self._get_bnb_target_modules(model) for name, module in model.named_modules(): # Some modules like `ReplicatedLinear` should not have their weights # sharded. The reason for implementing it this way is to avoid new # static variable in the model implementation. if isinstance(module, (ReplicatedLinear, )): self.unsharded_weights_modules.append(name) + # `QKVParallelLinear` and `MergedColumnParallelLinear` might have + # fused weights on disk. We need to use the output sizes of these + # modules to shard the weights correctly. + elif isinstance(module, + (QKVParallelLinear, MergedColumnParallelLinear)): + self.maybe_fused_weights_modules[name] = module.output_sizes # In TP, these weights are partitioned along the column # dimension (dim=-1) elif isinstance(module, (RowParallelLinear, )): @@ -1004,7 +1096,7 @@ def _load_weights(self, model_config: ModelConfig, pre_quant = False if quant_config is not None: - quant_method = quant_config.get('quant_method') + quant_method = quant_config.get("quant_method") if quant_method == "bitsandbytes": pre_quant = True else: @@ -1021,11 +1113,12 @@ def _load_weights(self, model_config: ModelConfig, load_8bit = False if pre_quant: - load_8bit = quant_config.get('load_in_8bit', False) + load_8bit = quant_config.get("load_in_8bit", False) - qweight_iterator, quant_state_dict = \ - self._get_quantized_weights_iterator( - model_config.model, model_config.revision, pre_quant, load_8bit) + qweight_iterator, quant_state_dict = ( + self._get_quantized_weights_iterator(model_config.model, + model_config.revision, + pre_quant, load_8bit)) model.load_weights(qweight_iterator) @@ -1036,6 +1129,7 @@ def _load_weights(self, model_config: ModelConfig, # TODO: Change this lazy import to normal import # after the checks are updated to run on a new version from vllm.model_executor.models.utils import is_pp_missing_parameter + for quant_param_name in quant_state_dict: if is_pp_missing_parameter(quant_param_name, model): continue @@ -1044,9 +1138,9 @@ def _load_weights(self, model_config: ModelConfig, shard_index = 0 for shard_name, ( - weight_name, index + weight_name, + index, ) in model.bitsandbytes_stacked_params_mapping.items(): - shard_pos = quant_param_name.find(shard_name) # Some models, such as MiniCPM V2.5/2.6, contain both # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' @@ -1081,8 +1175,8 @@ def _load_weights(self, model_config: ModelConfig, num_elements = [0] * len(quant_states) for seq, quant_state in quant_states.items(): - num_elements[seq] = math.prod( - quant_state.shape) // pack_ratio + num_elements[seq] = (math.prod(quant_state.shape) // + pack_ratio) offsets = np.concatenate(([0], np.cumsum(num_elements))) set_weight_attrs(param, {"bnb_shard_offsets": offsets}) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index c48b287ed181a..87f3fcb5cae00 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -13,7 +13,7 @@ from transformers import PretrainedConfig import vllm.envs as envs -from vllm.config import ModelConfig, ParallelConfig +from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger @@ -284,7 +284,8 @@ def _init_model(self): model_args = self.tensorizer_config.hf_config model_args.torch_dtype = self.tensorizer_config.dtype assert self.tensorizer_config.model_class is not None - with no_init_or_tensor(): + # TODO: Do we need to consider old-style model class? + with no_init_or_tensor(), set_current_vllm_config(self.vllm_config): return self.tensorizer_config.model_class( vllm_config=self.vllm_config, ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index b95c0b7cd0612..864dd04e79921 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -1,12 +1,13 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Tuple, Type +from typing import Optional, Tuple, Type import torch from torch import nn from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.adapters import as_embedding_model @contextlib.contextmanager @@ -19,8 +20,13 @@ def set_default_torch_dtype(dtype: torch.dtype): def get_model_architecture( - model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: - architectures = getattr(model_config.hf_config, "architectures", []) + model_config: ModelConfig, + *, + architectures: Optional[list[str]] = None, +) -> Tuple[Type[nn.Module], str]: + if architectures is None: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. mixtral_supported = [ @@ -32,7 +38,11 @@ def get_model_architecture( and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - return ModelRegistry.resolve_model_cls(architectures) + model_cls, arch = ModelRegistry.resolve_model_cls(architectures) + if model_config.task == "embedding": + model_cls = as_embedding_model(model_cls) + + return model_cls, arch def get_architecture_class_name(model_config: ModelConfig) -> str: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d66373512b95e..a3ef9adad16d9 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,15 +1,14 @@ from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, SupportsPP, has_inner_state, supports_lora, supports_multimodal, supports_pp) -from .interfaces_base import (VllmModelForEmbedding, - VllmModelForTextGeneration, is_embedding_model, - is_text_generation_model) +from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, + is_pooling_model, is_text_generation_model) from .registry import ModelRegistry __all__ = [ "ModelRegistry", - "VllmModelForEmbedding", - "is_embedding_model", + "VllmModelForPooling", + "is_pooling_model", "VllmModelForTextGeneration", "is_text_generation_model", "HasInnerState", @@ -20,4 +19,4 @@ "supports_multimodal", "SupportsPP", "supports_pp", -] \ No newline at end of file +] diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py new file mode 100644 index 0000000000000..9cc43ae9181b9 --- /dev/null +++ b/vllm/model_executor/models/adapters.py @@ -0,0 +1,98 @@ +from collections.abc import Iterable +from typing import Any, TypeVar + +import torch +import torch.nn as nn + +from .interfaces_base import VllmModelForPooling, is_pooling_model + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def as_embedding_model(cls: _T) -> _T: + """Subclass an existing vLLM model to support embeddings.""" + # Avoid modifying existing embedding models + if is_pooling_model(cls): + return cls + + # Lazy import + from vllm.config import VllmConfig + from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput, + PoolingType) + from vllm.model_executor.pooling_metadata import PoolingMetadata + + from .utils import AutoWeightsLoader, WeightsMapper + + class ModelForEmbedding(cls, VllmModelForPooling): + + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + # These are not used in embedding models + for attr in ("lm_head", "logits_processor"): + if hasattr(self, attr): + delattr(self, attr) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # If the model already defines a pooler instance, don't overwrite it + if not getattr(self, "_pooler", None): + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False, + ) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # TODO: Support uninitialized params tracking + + # We have deleted this attribute, so don't load it + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + + # If `*ForCausalLM` defines `load_weights` on the inner model + # and there are no other inner modules with parameters, + # we support loading from both `*Model` and `*ForCausalLM` + if hasattr(self, "model") and hasattr(self.model, "load_weights"): + # Whether only `self.model` contains parameters + model_is_only_param = all( + name == "model" or next(child.parameters(), None) is None + for name, child in self.named_children()) + + if model_is_only_param: + mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = mapper.apply(weights) + + self.model.load_weights(weights) + return + + # For most other models + if hasattr(cls, "load_weights"): + cls.load_weights(self, weights) # type: ignore + # Fallback + else: + loader = AutoWeightsLoader(self) + loader.load_weights(weights) + + ModelForEmbedding.__name__ = cls.__name__ \ + .removesuffix("ForCausalLM") \ + .removesuffix("ForConditionalGeneration") \ + .removesuffix("ChatModel") \ + .removesuffix("LMHeadModel") + "ForEmbedding" + + return ModelForEmbedding # type: ignore diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index e58ad19cab54c..fd6b5659df5d1 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -33,7 +33,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -44,15 +44,14 @@ class ArcticMLP(nn.Module): def __init__(self, config: ArcticConfig, - layer_id: int, expert_id: int = -1, is_residual_mlp: bool = False, quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True): + reduce_results: bool = True, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.expert_id = expert_id - self.layer_id = layer_id self.ffn_dim = config.intermediate_size if not is_residual_mlp \ else self.hidden_size @@ -85,13 +84,14 @@ class ArcticMoE(nn.Module): def __init__(self, config: ArcticConfig, - layer_id: int, tp_size: Optional[int] = None, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True): + reduce_results: bool = True, + prefix: str = ""): super().__init__() + layer_id = extract_layer_index(prefix) self.tp_size = tp_size or get_tensor_model_parallel_world_size() self.hidden_size = config.hidden_size self.num_experts = config.num_local_experts @@ -109,15 +109,16 @@ def __init__(self, if not self.is_moe_layer: self.mlp = ArcticMLP(config, - layer_id=layer_id, quant_config=quant_config, - reduce_results=reduce_results) + reduce_results=reduce_results, + prefix=f"{prefix}.mlp") else: self.gate = ReplicatedLinear(self.hidden_size, self.num_experts, bias=False, params_dtype=self.params_dtype, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.gate") if self.is_quant: self.ws = DeepSpeedFPParameter( torch.Size((self.num_experts, 2 * self.intermediate_size, @@ -220,13 +221,12 @@ class ArcticAttention(nn.Module): def __init__( self, config: ArcticConfig, - layer_idx: Optional[int] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config - self.layer_idx = layer_idx self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -274,7 +274,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -296,24 +297,25 @@ class ArcticDecoderLayer(nn.Module): def __init__( self, config: ArcticConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() - self.layer_idx = layer_idx self.hidden_size = config.hidden_size + layer_idx = extract_layer_index(prefix) is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 self.use_residual = config.use_residual and is_moe_layer self.self_attn = ArcticAttention(config, - layer_idx, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.block_sparse_moe = ArcticMoE( config, - layer_id=layer_idx, quant_config=quant_config, - reduce_results=(not self.use_residual)) + reduce_results=(not self.use_residual), + prefix=f"{prefix}.block_sparse_moe", + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -324,9 +326,9 @@ def __init__( self.residual_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.residual_mlp = ArcticMLP(config, - layer_id=layer_idx, is_residual_mlp=True, - reduce_results=False) + reduce_results=False, + prefix=f"{prefix}.residual_mlp") def forward( self, @@ -380,8 +382,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=self.vocab_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: ArcticDecoderLayer(config, int( - prefix.split(".")[-1]), cache_config, quant_config), + lambda prefix: ArcticDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self._attn_implementation = config._attn_implementation self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py new file mode 100644 index 0000000000000..fa6b95f5481ad --- /dev/null +++ b/vllm/model_executor/models/aria.py @@ -0,0 +1,677 @@ +import math +from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from transformers import LlamaConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, QuantizationConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.inputs import INPUT_REGISTRY, token_inputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, + SamplingMetadata) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.idefics2_vision_model import ( + Idefics2VisionTransformer) +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP, + LlamaModel) +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + is_pp_missing_parameter, + maybe_prefix, + merge_multimodal_embeddings) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, + AriaVisionConfig) + +from .utils import flatten_bn + + +class AriaImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + pixel_mask: Optional[torch.Tensor] + """ + Shape: + pixel_values: `(batch_size * num_images, num_channels, height, width)` + pixel_mask: `(batch_size * num_images, height, width)` + """ + + +class AriaVisionTransformer(Idefics2VisionTransformer): + """ + AriaVisionTransformer is a modified version of Idefics2VisionTransformer + that replaces the post-layernorm with an identity layer. + """ + + def __init__( + self, + config: AriaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config, prefix) + self.post_layernorm = nn.Identity() + + +class AriaVisionModel(nn.Module): + config_class = AriaVisionConfig + + def __init__( + self, + config: AriaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.vision_model = AriaVisionTransformer( + config, + quant_config, + prefix=f"{prefix}.vision_model", + ) + + def forward( + self, + pixel_values: torch.Tensor, + pixel_mask: Optional[torch.BoolTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]: + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + + vit_oup = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + image_atts = self._create_image_attention_mask(patch_attention_mask) + + return vit_oup, image_atts + + def _create_patch_attention_mask(self, pixel_mask): + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_model.config.patch_size, + step=self.vision_model.config.patch_size, + ).unfold( + dimension=2, + size=self.vision_model.config.patch_size, + step=self.vision_model.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def _create_image_attention_mask(self, patch_attention_mask): + if patch_attention_mask is None: + return None + + flattened_mask = patch_attention_mask.flatten(1) + return torch.logical_not(flattened_mask) + + +class FFN(nn.Module): + + def __init__(self, embed_dim, ff_dim, output_dim): + super().__init__() + self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False) + self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False) + self.act = get_act_fn("gelu_new") + + def forward(self, hidden_states): + hidden_states, _ = self.linear_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_out(hidden_states) + return hidden_states + + +class CrossAttention(nn.Module): + + def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): + super().__init__() + self.num_heads = num_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False) + + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + self.linear = nn.Linear(embed_dim, embed_dim) + self.dropout = nn.Dropout(drop_out_rate) + + self.layer_norm = nn.LayerNorm(embed_dim) + self.ln_kv = nn.LayerNorm(kv_dim) + + def forward(self, x, hidden_states, attn_mask=None, add_residual=False): + normed_hidden_states = self.layer_norm(hidden_states) + query = self.q_proj(normed_hidden_states).permute(1, 0, 2) + + x = self.ln_kv(x) + key = self.k_proj(x).permute(1, 0, 2) + value = self.v_proj(x).permute(1, 0, 2) + + attn_output, _ = self.multihead_attn(query, + key, + value, + attn_mask=attn_mask) + + attn_output = attn_output.permute(1, 0, 2) + + if add_residual: + attn_output = hidden_states + self.dropout( + self.linear(attn_output)) + else: + attn_output = self.dropout(self.linear(attn_output)) + + return attn_output + + +class AriaProjector(nn.Module): + """ + A projection module with one cross attention layer and one FFN layer, which + projects ViT's outputs into MoE's inputs. + + Args: + patch_to_query_dict (dict): Maps patch numbers to their corresponding + query numbers, + e.g., {1225: 128, 4900: 256}. This allows for different query sizes + based on image resolution. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + kv_dim (int): Dimension of key and value. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. + + Outputs: + A tensor with the shape of (batch_size, query_number, output_dim) + """ + + def __init__( + self, + patch_to_query_dict, + embed_dim, + num_heads, + kv_dim, + ff_dim, + output_dim, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.patch_to_query_dict = patch_to_query_dict + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter( + torch.zeros(max(patch_to_query_dict.values()), self.embed_dim)) + + trunc_normal_(self.query, std=0.02) + + self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads) + + self.ln_ffn = norm_layer(embed_dim) + self.ffn = FFN(embed_dim, ff_dim, output_dim) + + def forward(self, x, attn_mask=None): + bs = x.shape[0] + queries = self.query.unsqueeze(0).repeat(bs, 1, 1) + + query_num = self.patch_to_query_dict.get(x.shape[1], None) + assert (query_num is not None + ), f"Query number for {x.shape[1]} patches is not provided" + + queries = queries[:, :query_num, :] + + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) + attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) + + attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) + + out = self.ffn(self.ln_ffn(attention_out)) + + return out + + +class AriaFusedMoE(FusedMoE): + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + shard_id: str) -> Set[str]: + # Override the weight_loader to handle the expert weights in the Aria + # model, which are already packed with experts, and merge the gate and + # up weights for each expert. + # Note: Loading expert weights with quantization is not supported + tp_rank = get_tensor_model_parallel_rank() + if shard_id == 'w13': + # the shape of loaded_weight is + # (num_experts, hidden_size, 2 * moe_intermediate_size) + if self.tp_size > 1: + up, gate = loaded_weight.chunk(2, dim=-1) + up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank] + gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank] + up_and_gate = torch.cat([up_current_rank, gate_current_rank], + dim=-1).transpose(1, 2) + param.data.copy_(up_and_gate) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + elif shard_id == 'w2': + # the shape of loaded_weight is + # (num_experts, moe_intermediate_size, hidden_size) + if self.tp_size > 1: + down_current_rank = loaded_weight.chunk(self.tp_size, + dim=1)[tp_rank] + param.data.copy_(down_current_rank.transpose(1, 2)) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + + +class MoELayer(nn.Module): + """ + Mixture of Experts (MoE) Layer for the AriaMoE model. + + This layer implements the MoE mechanism, which routes input tokens to + different experts based on a routing algorithm, processes them through the + experts, and then combines the outputs. + """ + + def __init__( + self, + config: AriaMoELMConfig, + quant_config: Optional[QuantizationConfig], + ) -> None: + super().__init__() + self.config = config + + self.router_weight = nn.Parameter( + torch.empty( + (self.config.moe_num_experts, self.config.hidden_size))) + + self.experts = AriaFusedMoE( + num_experts=config.moe_num_experts, + top_k=config.moe_topk, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + quant_config=quant_config, + reduce_results=True, + ) + self.shared_experts = LlamaMLP( + config.hidden_size, + config.moe_intermediate_size * config.moe_num_shared_experts, + "silu", + quant_config=quant_config, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MoE Layer. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch_size, + sequence_length, hidden_size). + + Returns: + torch.Tensor: Output tensor after passing through the MoE layer. + """ + + router_output = torch.nn.functional.linear(hidden_states, + self.router_weight) + + shared_expert_output = self.shared_experts(hidden_states) + sparse_expert_output = self.experts(hidden_states, router_output) + + return sparse_expert_output + shared_expert_output + + +class MoEDecoderLayer(LlamaDecoderLayer): + """ + Custom Decoder Layer for the AriaMoE model which modifies the standard + `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of + Experts (MoE) Layer. + """ + + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, cache_config, quant_config, prefix) + self.mlp = MoELayer(config, quant_config=quant_config) + + +class AriaMoELMModel(LlamaModel): + """ + Custom LlamaModel for the AriaMoE model which modifies the standard + LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=MoEDecoderLayer) + + # Adapted from LlamaModel.load_weights with the modification of adding + # the expert weights mapping to `stacked_params_mapping` + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ("experts.w13_weight", "experts.fc1.weight", 'w13'), + ("experts.w2_weight", "experts.fc2.weight", 'w2'), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def build_mm_projector(config): + return AriaProjector( + patch_to_query_dict=config.projector_patch_to_query_dict, + embed_dim=config.vision_config.hidden_size, + num_heads=config.vision_config.num_attention_heads, + kv_dim=config.vision_config.hidden_size, + ff_dim=config.text_config.hidden_size, + output_dim=config.text_config.hidden_size, + ) + + +def get_max_multimodal_tokens(ctx): + return max(ctx.model_config.hf_config.image_size2tokens.values()) + + +def input_mapper_for_aria(ctx, data): + return MultiModalInputs(data) + + +def input_processor(ctx, llm_inputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + # if it is pure text input, use it as is + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + hf_config = model_config.hf_config + + # prepare image tokens, the max_image_size is used to determine the number + # of patch_size for every image + max_image_size = multi_modal_data.pop("max_image_size", 980) + _split_image = multi_modal_data.pop("split_image", False) + + assert isinstance(max_image_size, + (int, float)), "max_image_size should be float or int" + images = (multi_modal_data["image"] if isinstance( + multi_modal_data["image"], list) else [multi_modal_data["image"]]) + + image_inputs = image_processor.preprocess(images, + max_image_size=max_image_size, + split_image=_split_image, + return_tensors="pt").data + image_inputs['pixel_values'] = image_inputs['pixel_values'].to( + ctx.model_config.dtype) + num_crops = image_inputs.pop("num_crops") + + prompt_token_ids = llm_inputs["prompt_token_ids"] + if num_crops.sum().item() > 0: + _, prompt_token_ids, _ = repeat_and_pad_placeholder_tokens( + tokenizer, + None, + prompt_token_ids, + placeholder_token_id=hf_config.image_token_index, + repeat_count=num_crops, + ) + + repeat_count = [hf_config.image_size2tokens[max_image_size] + ] * sum(num_crops).item() + new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens( + tokenizer, + None, + prompt_token_ids, + placeholder_token_id=hf_config.image_token_index, + repeat_count=repeat_count, + ) + + return token_inputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data={"image": image_inputs}, + ) + + +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens) +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria) +@INPUT_REGISTRY.register_input_processor(input_processor) +class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): + """ + Aria model for conditional generation tasks. + + This model combines a vision tower, a multi-modal projector, and a language + model to perform tasks that involve both image and text inputs. + """ + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + # prepare the image_size to tokens mapping for the image preprocess, see + # input_processor + config.image_size2tokens = { + int(math.sqrt(k) * config.vision_config.patch_size): v + for k, v in config.projector_patch_to_query_dict.items() + } + self.config = config + self.vision_tower = AriaVisionModel(config.vision_config) + self.multi_modal_projector = build_mm_projector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AriaMoELMModel( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model.model"), + ) + self.pad_token_id = (self.config.pad_token_id + if self.config.pad_token_id is not None else -1) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.language_model.org_vocab_size, + quant_config=quant_config, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.vocab_size, logit_scale) + self.sampler = Sampler() + + def _validate_image_sizes( + self, images: List[torch.Tensor]) -> List[torch.Tensor]: + if not all(img.shape == images[0].shape for img in images): + raise ValueError("All images must be the same size") + return images + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[AriaImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + pixel_mask = kwargs.pop("pixel_mask", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = self._validate_image_sizes(pixel_values) + pixel_values = flatten_bn(pixel_values, concat=True) + if pixel_mask is not None: + pixel_mask = flatten_bn(pixel_mask, concat=True) + + return AriaImagePixelInputs( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + ) + + def _process_image_input( + self, image_input: AriaImagePixelInputs + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.vision_tower is not None + + pixel_values = image_input['pixel_values'] + pixel_mask = image_input['pixel_mask'] + + image_feature, image_attn_mask = self.vision_tower( + pixel_values, pixel_mask=pixel_mask) + return self.multi_modal_projector(image_feature, image_attn_mask) + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + multimodal_embeddings = self._process_image_input(image_input) + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model": "language_model", + "language_model.lm_head": "lm_head", + }, + orig_to_new_suffix={ + "router.weight": "router_weight", + }, + ) + + loader = AutoWeightsLoader(self) + loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 3749a16a38994..5e68b7f165bf4 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -116,6 +116,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -158,7 +159,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") else: self.rotary_emb = get_rope( self.head_dim, @@ -171,7 +173,8 @@ def __init__( self.head_dim, self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -195,7 +198,8 @@ def __init__(self, config: PretrainedConfig, position_embedding: str, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -209,6 +213,7 @@ def __init__(self, max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, @@ -275,8 +280,11 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: BaiChuanDecoderLayer(config, position_embedding, - cache_config, quant_config), + lambda prefix: BaiChuanDecoderLayer(config, + position_embedding, + cache_config, + quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -342,6 +350,13 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__( self, *, diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index a50a5a5b018e1..3776490cb3465 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -126,6 +126,7 @@ def __init__( config: Optional[BartConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -178,7 +179,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -208,6 +210,7 @@ def __init__( config: Optional[BartConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -260,7 +263,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -290,6 +294,7 @@ def __init__( config: Optional[BartConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -342,7 +347,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -384,6 +390,7 @@ def __init__( config: BartConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.embed_dim = config.d_model @@ -393,7 +400,9 @@ def __init__( num_heads=config.encoder_attention_heads, config=config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.activation_fn = get_act_fn(config.activation_function) @@ -464,6 +473,7 @@ def __init__( config: BartConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.embed_dim = config.d_model @@ -473,7 +483,9 @@ def __init__( num_heads=config.decoder_attention_heads, config=config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.activation_fn = get_act_fn(config.activation_function) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -486,6 +498,7 @@ def __init__( self.embed_dim, config.decoder_attention_heads, config=config, + prefix=f"{prefix}.encoder_attn", ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -578,7 +591,8 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None): + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = ""): super().__init__() self.cache_config = cache_config @@ -599,9 +613,13 @@ def __init__(self, config.max_position_embeddings, embed_dim, ) - self.layers = nn.ModuleList( - [BartEncoderLayer(config,cache_config,quant_config) \ - for _ in range(config.encoder_layers)]) + self.layers = nn.ModuleList([ + BartEncoderLayer(config, + cache_config, + quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.encoder_layers) + ]) self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -661,6 +679,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, embed_tokens: Optional[nn.Embedding] = None, + prefix: str = "", ): super().__init__() self.cache_config = cache_config @@ -683,8 +702,9 @@ def __init__( ) self.layers = nn.ModuleList( - [BartDecoderLayer(config,cache_config,quant_config) \ - for _ in range(config.decoder_layers)]) + [BartDecoderLayer(config,cache_config,quant_config, + prefix=f"{prefix}.layers.{layer_idx}") \ + for layer_idx in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -759,10 +779,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.encoder = BartEncoder(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.encoder") self.decoder = BartDecoder(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.decoder") def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index d8301a36acb01..053d838432885 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -5,22 +5,26 @@ from transformers import BertConfig from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, + PoolingType) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) -from .utils import maybe_prefix +from .interfaces import SupportsCrossEncoding +from .utils import WeightsMapper, maybe_prefix class BertEmbedding(nn.Module): @@ -48,7 +52,9 @@ def __init__(self, config: BertConfig): def forward( self, input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() @@ -58,25 +64,42 @@ def forward( # Position embeddings. position_embeddings = self.position_embeddings(position_ids) - # Token type embeddings. (TODO: move off hotpath?) - token_type_embeddings = self.token_type_embeddings( - torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device)) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings +class BertPooler(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[0, :] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@support_torch_compile class BertEncoder(nn.Module): - def __init__(self, - config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.layer = nn.ModuleList([ BertLayer(config=config, cache_config=cache_config, @@ -309,16 +332,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", - embedding_class: type = BertEmbedding): + embedding_class: type = BertEmbedding, + add_pooling_layer: bool = False): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.embeddings = embedding_class(config) - self.encoder = BertEncoder(config, - cache_config, - quant_config, + self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") + self.pooler = BertPooler(config) if add_pooling_layer else None def forward( self, @@ -328,13 +349,17 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embeddings(input_ids=input_ids, - position_ids=position_ids) - + assert hasattr(attn_metadata, "seq_lens_tensor") + hidden_states = self.embeddings( + input_ids=input_ids, + seq_lens=attn_metadata.seq_lens_tensor, + position_ids=position_ids, + token_type_ids=token_type_ids) return self.encoder(hidden_states, kv_caches, attn_metadata) def load_weights(self, weights: Iterable[Tuple[str, @@ -349,7 +374,7 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if self.pooler is None and "pooler" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -416,6 +441,10 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) self.model.load_weights(weights) def _build_model(self, @@ -430,3 +459,78 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: pooling_type=PoolingType.CLS, normalize=True, softmax=False) + + +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + self.num_labels = config.num_labels + self.bert = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + add_pooling_layer=True) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self._pooler = CrossEncodingPooler(config, self.classifier, + self.bert.pooler) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("bert."): + yield (name[len("bert."):], weight) + else: + self_weights.append((name, weight)) + + self.bert.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.bert(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 7d7639b4a92ce..76b8505ee1c2a 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData @@ -511,9 +512,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -609,6 +611,25 @@ def _process_image_input(self, return self.language_projection(query_output) + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + BLIP2_IMAGE_TOKEN_ID) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -616,6 +637,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[SamplerOutput, IntermediateTensors]: """Run forward pass for BLIP-2. @@ -648,32 +670,24 @@ def forward( See also: :class:`Blip2ImageInputs` """ + if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - BLIP2_IMAGE_TOKEN_ID) - - input_ids = None - else: - inputs_embeds = None - - hidden_states = self.language_model.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 1060d418474ef..fee74f491acc1 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -78,6 +78,7 @@ def __init__( config: BloomConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -116,7 +117,8 @@ def __init__( scaling, alibi_slopes=alibi_slopes, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -168,14 +170,17 @@ def __init__( config: BloomConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, cache_config, - quant_config) + self.self_attention = BloomAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config, quant_config) @@ -242,7 +247,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Transformer blocks self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: BloomBlock(config, cache_config, quant_config), + lambda prefix: BloomBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h") # Final Layer Norm diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 8f91abffaea90..a40c321ce0a58 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -29,6 +29,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) @@ -38,7 +39,7 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, merge_multimodal_embeddings) # These configs are not part of the model config but the preprocessor # and processor files, so we hardcode them in the model file for now. @@ -223,6 +224,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, cache_config: Optional[CacheConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -276,7 +278,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -313,6 +316,7 @@ def __init__( config: ChameleonConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -336,6 +340,7 @@ def __init__( quant_config=quant_config, bias=False, cache_config=cache_config, + prefix=f"{prefix}.self_attn", ) self.mlp = ChameleonMLP( hidden_size=self.hidden_size, @@ -386,6 +391,7 @@ def __init__( config: ChameleonConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -409,6 +415,7 @@ def __init__( quant_config=quant_config, bias=False, cache_config=cache_config, + prefix=f"{prefix}.self_attn", ) self.mlp = ChameleonMLP( hidden_size=self.hidden_size, @@ -855,7 +862,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.num_hidden_layers, lambda prefix: decoder_layer(config=config, cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) @@ -980,6 +988,29 @@ def _parse_and_validate_image_input( data=self._validate_pixel_values(pixel_values), ) + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + assert self.model.vqmodel is not None + image_tokens = self.model.get_image_tokens(image_input["data"].to( + self.config.torch_dtype)) + vision_embeddings = self.model.get_input_embeddings(image_tokens) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.model.vocabulary_mapping.image_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -987,27 +1018,27 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) input_ids = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - assert self.model.vqmodel is not None - image_tokens = self.model.get_image_tokens( - image_input["data"].to(self.config.torch_dtype)) - image_token_id = self.model.vocabulary_mapping.image_token_id - special_image_mask = input_ids == image_token_id - image_tokens = image_tokens.to(input_ids.device, - input_ids.dtype) - input_ids = input_ids.masked_scatter(special_image_mask, - image_tokens) - - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + + hidden_states = self.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 625e31bb0d368..6c50882d83c3b 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -33,7 +33,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs, + NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) @@ -230,6 +231,7 @@ def __init__( config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -285,7 +287,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -364,6 +367,7 @@ def __init__( config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.apply_residual_connection_post_layernorm = ( @@ -377,7 +381,10 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config, cache_config, quant_config) + self.self_attention = GLMAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -446,7 +453,8 @@ def __init__( # Transformer layers. self.start_layer, self.end_layer, self.layers = make_layers( self.num_layers, - lambda prefix: GLMBlock(config, cache_config, quant_config), + lambda prefix: GLMBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) @@ -500,16 +508,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, cache_config, quant_config) + self.encoder = GLMTransformer(config, + cache_config, + quant_config, + prefix=f"{prefix}.encoder") self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.output_layer") vision_config_flag = getattr(config, 'vision_config', None) if vision_config_flag is not None: self.vision_config = Namespace(**config.vision_config) - self.vision = EVA2CLIPModel(self.config, quant_config) + self.vision = EVA2CLIPModel(self.config, + quant_config, + prefix=f"{prefix}.vision") else: self.vision = None @@ -532,6 +546,30 @@ def _parse_and_validate_image_input( """) return GLMImagePixelInputs(pixel_values=pixel_values) + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input["pixel_values"] is None: + return None + pixel_values = image_input["pixel_values"].to( + dtype=self.config.torch_dtype) + vision_embeddings = self.vision(pixel_values) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.embedding(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_glm_vision_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + vision_embeddings=multimodal_embeddings, + boi_token_id=self.config.boi_token_id, + eoi_token_id=self.config.eoi_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -539,26 +577,17 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> torch.Tensor: - if intermediate_tensors is None: - inputs_embeds = self.embedding(input_ids) - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input["pixel_values"] is not None: - pixel_values = image_input["pixel_values"].to( - dtype=inputs_embeds.dtype) - image_embeds = self.vision(pixel_values) - - boi_token_id = self.config.boi_token_id - eoi_token_id = self.config.eoi_token_id - - inputs_embeds = merge_glm_vision_embeddings( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - vision_embeddings=image_embeds, - boi_token_id=boi_token_id, - eoi_token_id=eoi_token_id) + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + if intermediate_tensors is None and inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None else: inputs_embeds = intermediate_tensors["hidden_states"] @@ -575,8 +604,7 @@ def forward( return hidden_states -class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP, - SupportsMultiModal): +class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -695,7 +723,7 @@ class ChatGLM(ChatGLMBaseModel): embedding_padding_modules = [] -class ChatGLMV(ChatGLMBaseModel): +class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal): packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"], @@ -748,7 +776,7 @@ def __new__( config = vllm_config.model_config.hf_config # Initialize VL if hasattr(config, "visual"): - return ChatGLM(vllm_config=vllm_config, prefix=prefix) + return ChatGLMV(vllm_config=vllm_config, prefix=prefix) # Initialize LLM else: - return ChatGLMV(vllm_config=vllm_config, prefix=prefix) + return ChatGLM(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 7f638506f9fb2..cd89519e95986 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -21,7 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) + repeat_and_pad_placeholder_tokens, + resolve_visual_encoder_outputs) from vllm.sequence import SequenceData from .utils import get_vit_attn_backend @@ -389,12 +390,20 @@ def __init__( for layer_idx in range(num_hidden_layers) ]) - def forward(self, inputs_embeds: torch.Tensor): - + def forward( + self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool + ) -> Union[torch.Tensor, list[torch.Tensor]]: + hidden_states_pool = [] hidden_states = inputs_embeds + for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) - + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool return hidden_states @@ -419,6 +428,7 @@ def __init__( # NOTE: This typo of "layrnorm" is not fixed on purpose to match # the original transformers code and name of the model weights. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder( config=config, quant_config=quant_config, @@ -446,16 +456,26 @@ def __init__( def forward( self, pixel_values: torch.Tensor, + feature_sample_layers: Optional[list[int]] = None, ) -> torch.Tensor: hidden_states = self.embeddings(pixel_values) hidden_states = self.pre_layrnorm(hidden_states) - hidden_states = self.encoder(inputs_embeds=hidden_states) - if self.post_layernorm is None: - return hidden_states + return_all_hidden_states = feature_sample_layers is not None + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states) + + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, feature_sample_layers, self.post_layernorm, + self.config.num_hidden_layers) - return self.post_layernorm(hidden_states) + return encoder_outputs class CLIPVisionModel(nn.Module): @@ -478,11 +498,14 @@ def __init__( quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, require_post_norm=require_post_norm, - prefix=f"{prefix}.vision_model", - ) + prefix=f"{prefix}.vision_model") - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - return self.vision_model(pixel_values) + def forward( + self, + pixel_values: torch.Tensor, + feature_sample_layers: Optional[list[int]] = None, + ) -> torch.Tensor: + return self.vision_model(pixel_values, feature_sample_layers) @property def device(self): diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index f8ba7e5232418..375e51f5bd346 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -119,6 +119,7 @@ def __init__( config: CohereConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -174,7 +175,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, self.head_dim), @@ -214,13 +216,15 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.self_attn = CohereAttention(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), @@ -270,8 +274,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: CohereDecoderLayer(config, cache_config, - quant_config), + lambda prefix: CohereDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 5d952ddd2f84f..2c64ba67d877b 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -166,6 +166,7 @@ def __init__( config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -220,7 +221,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -246,10 +248,14 @@ def __init__( config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, cache_config, quant_config) + self.attn = DbrxAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -281,10 +287,14 @@ def __init__( config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, - quant_config) + self.norm_attn_norm = DbrxFusedNormAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.norm_attn_norm") self.ffn = DbrxMoE(config, quant_config) def forward( @@ -320,7 +330,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: DbrxBlock(config, cache_config, quant_config), + lambda prefix: DbrxBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.blocks", ) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 8c5ad9904e925..74b6bfdf21909 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -63,6 +63,7 @@ def __init__( hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -92,6 +93,7 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -184,6 +186,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -236,7 +239,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -258,11 +262,12 @@ class DeepseekDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() + layer_idx = extract_layer_index(prefix) self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -277,17 +282,21 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, quant_config=quant_config) + self.mlp = DeepseekMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -343,10 +352,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekDecoderLayer(config, - int(prefix.split(".")[-1]), - cache_config, - quant_config=quant_config), + lambda prefix: DeepseekDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index d2c4ca0bf85e9..4cf4e6c358bf2 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -268,7 +268,8 @@ def __init__( self.scaling, num_kv_heads=self.num_local_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 9d739d0479548..5ca26d53a17e7 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -174,6 +174,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.attn", ) def forward( @@ -219,7 +220,7 @@ def __init__( quant_config=quant_config, bias=bias, cache_config=cache_config, - prefix=prefix, + prefix=f"{prefix}.attention", ) def forward( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 2aa4b67d99894..8660cf79b9cdb 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -84,6 +84,7 @@ def __init__( config: FalconConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -158,7 +159,8 @@ def __init__( self.head_dim, self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -171,14 +173,16 @@ def __init__( self.inv_norm_factor, num_kv_heads=self.num_kv_heads, alibi_slopes=alibi_slopes, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") else: self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -241,12 +245,16 @@ def __init__( config: FalconConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, cache_config, - quant_config) + self.self_attention = FalconAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") self.mlp = FalconMLP(config, quant_config) self.config = config @@ -357,8 +365,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Transformer blocks self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: FalconDecoderLayer(config, cache_config, - quant_config), + lambda prefix: FalconDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h") # Final Layer Norm @@ -404,12 +412,6 @@ class FalconForCausalLM(nn.Module, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = {} - default_bitsandbytes_target_modules = [ - ".query_key_value.", - ".dense.", - ".dense_h_to_4h.", - ".dense_4h_to_h.", - ] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index d3a9ff6915b84..3a5fe8e1f4144 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -35,10 +35,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) self.encoder = BartEncoder(config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.encoder") self.decoder = BartDecoder(config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.decoder") if self.config.tie_word_embeddings: self.encoder.embed_tokens.weight = self.shared.weight @@ -99,7 +101,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.model = Florence2LanguageModel(vllm_config=vllm_config, - prefix=prefix) + prefix=f"{prefix}.model") embed_scale = math.sqrt( config.d_model) if config.scale_embedding else 1.0 @@ -198,7 +200,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # TODO(Isotr0py): Add vision backbone self.language_model = Florence2LanguageForConditionalGeneration( vllm_config=vllm_config.with_hf_config(config.text_config), - prefix=prefix, + prefix=f"{prefix}.language_model", ) @property diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7b46907ac83ab..6e86900326c4b 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -35,6 +35,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -302,6 +303,25 @@ def _process_image_input( vision_embeddings, _ = self.vision_embed_tokens(image_input["data"]) return vision_embeddings + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + _IMAGE_TOKEN_ID) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -309,24 +329,19 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.embed_tokens( - input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) - - else: - inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model( input_ids=input_ids, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 64e03b30bf2f1..b28715c48adfb 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -174,7 +174,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -349,15 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "down_proj", ] # BitandBytes specific attributes - default_bitsandbytes_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 4ba39223cc07f..4664aa53ea092 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -30,19 +30,18 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -85,7 +84,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Gemma2Attention(nn.Module): def __init__(self, - layer_idx: int, config: Gemma2Config, hidden_size: int, num_heads: int, @@ -95,9 +93,9 @@ def __init__(self, rope_theta: float, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - attn_logits_soft_cap: Optional[float] = None) -> None: + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: super().__init__() - self.layer_idx = layer_idx self.config = config self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -142,19 +140,22 @@ def __init__(self, is_neox_style=True, ) - # FIXME(woosuk): While Gemma 2 uses sliding window attention for every - # odd layer, vLLM currently ignores it and uses global attention for - # all layers. - use_sliding_window = (layer_idx % 2 == 1 - and config.sliding_window is not None) - del use_sliding_window # Unused. + # reference: + # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa + layer_idx = extract_layer_index(prefix) + use_sliding_window = (layer_idx % 2 == 0 and + config.interleaved_sliding_window is not None) + sliding_window = config.interleaved_sliding_window if \ + use_sliding_window else None self.attn = Attention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap) + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn") def forward( self, @@ -175,15 +176,14 @@ class Gemma2DecoderLayer(nn.Module): def __init__( self, - layer_idx: int, config: Gemma2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size self.self_attn = Gemma2Attention( - layer_idx=layer_idx, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -194,6 +194,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, attn_logits_soft_cap=config.attn_logit_softcapping, + prefix=f"{prefix}.self_attn", ) self.hidden_size = config.hidden_size self.mlp = Gemma2MLP( @@ -257,8 +258,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[ - -1]), config, cache_config, quant_config), + lambda prefix: Gemma2DecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -383,15 +384,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_padding_modules = [] # BitandBytes specific attributes - default_bitsandbytes_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -461,51 +453,3 @@ def load_weights(self, weights: Iterable[Tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -class Gemma2EmbeddingModel(nn.Module, SupportsPP): - """ - A model that uses Gemma2 with additional embedding functionalities. - - This class encapsulates the Gemma2Model and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of Gemma2Model used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - self.model = Gemma2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - vllm_config.model_config.pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - self.model.load_weights(weights) diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py new file mode 100644 index 0000000000000..942d1e14baed1 --- /dev/null +++ b/vllm/model_executor/models/glm.py @@ -0,0 +1,21 @@ +"""Inference-only HF format GLM-4 model compatible with THUDM weights.""" +from vllm.config import VllmConfig +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .utils import PPMissingLayer + + +class GlmForCausalLM(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + # Hack Llama model to fit HF format GLM implementation + # Attention difference between GLM and Llama: + # 1. Half partial rotary_dim and no Neox style. + # 2. There is no bias for o_proj in attention + for layer in self.model.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.rotary_emb.rotary_dim //= 2 + layer.self_attn.rotary_emb.is_neox_style = False + layer.self_attn.o_proj.bias = None + layer.self_attn.o_proj.skip_bias_add = True diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py index 64bd7b2f0aeb8..b2557f67edbda 100644 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -57,6 +57,7 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() self.hidden_size = config.hidden_size @@ -160,11 +161,14 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = Attention(config, quant_config=quant_config) + self.attention = Attention(config, + quant_config=quant_config, + prefix=f"{prefix}.attention") self.mlp = MLP(config, quant_config=quant_config) self.post_attention_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -186,11 +190,14 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() self.layers = nn.ModuleList([ - TransformerLayer(config, quant_config=quant_config) - for _ in range(config.num_hidden_layers) + TransformerLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layer.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) ]) def forward(self, hidden_states): @@ -277,12 +284,14 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() vision_config = Namespace(**config.vision_config) self.patch_embedding = PatchEmbedding(vision_config) self.transformer = Transformer(vision_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.transformer") self.linear_proj = GLU(config, in_features=config.hidden_size, quant_config=quant_config) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 1c61408ae1dd9..fd926ff0254d4 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -84,7 +84,8 @@ def __init__( self.head_dim, scale=self.scale, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 50a143cb1b600..c64bc70688806 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -52,6 +52,7 @@ def __init__( config: GPTBigCodeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -92,7 +93,8 @@ def __init__( scale=self.scale, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -151,6 +153,7 @@ def __init__( config: GPTBigCodeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -158,7 +161,10 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, cache_config, quant_config) + self.attn = GPTBigCodeAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigMLP(inner_dim, config, quant_config) @@ -210,7 +216,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config), + lambda prefix: GPTBigCodeBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index d5defc60764e6..4829578a56959 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -53,6 +53,7 @@ def __init__( config: GPTJConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -94,7 +95,8 @@ def __init__( self.head_size, scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -147,12 +149,16 @@ def __init__( config: GPTJConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, cache_config, quant_config) + self.attn = GPTJAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( @@ -193,7 +199,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.h = make_layers( config.n_layer, - lambda prefix: GPTJBlock(config, cache_config, quant_config), + lambda prefix: GPTJBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 0bb5e2f9b95f9..731642772011c 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -52,6 +52,7 @@ def __init__( config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -94,7 +95,8 @@ def __init__( self.head_size, scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -145,6 +147,7 @@ def __init__( config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -152,7 +155,10 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, cache_config, quant_config) + self.attention = GPTNeoXAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") self.mlp = GPTNeoXMLP(config, quant_config) def forward( @@ -205,7 +211,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: GPTNeoXLayer(config, cache_config, quant_config), + lambda prefix: GPTNeoXLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.final_layer_norm = nn.LayerNorm(config.hidden_size, diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index c1e2e87f08ec3..bd2394e71c973 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -161,7 +161,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index a91a18816995f..51296ef0cc08e 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -164,7 +164,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 5d176b2a4e416..e5d2edbd81eb1 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -39,6 +39,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import NestedTensors from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_list_of @@ -266,54 +267,56 @@ def input_processor_for_idefics3(ctx: InputContext, n_images_in_text = [] text = inputs.get("prompt") - if text is not None: - if isinstance(text, str): - text = [text] - elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string, " - "or a list of strings") - - fake_image_token = processor.fake_image_token.content - image_token = processor.image_token.content - global_img_token = processor.global_image_tag - - prompt_strings = [] - for sample, sample_rows, sample_cols in zip(text, image_rows, - image_cols): - n_images_in_text.append(sample.count(image_token)) - - # Replace the image token with fake tokens around the expanded - # image token sequence of length `image_seq_len` - image_prompt_strings = [] - for n_rows, n_cols in zip(sample_rows, sample_cols): - image_prompt_string = _get_image_prompt_string( - n_rows, - n_cols, - processor.image_seq_len, - image_token=image_token, - fake_token_around_image=fake_image_token, - global_img_token=global_img_token, - ) - image_prompt_strings.append(image_prompt_string) - - split_sample = sample.split(image_token) - if len(split_sample) == 0: - raise ValueError( - "The image token should be present in the text.") + if text is None: + prompt_token_ids = inputs.get("prompt_token_ids", []) + assert prompt_token_ids + text = tokenizer.decode(prompt_token_ids) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, " + "or a list of strings") + + fake_image_token = processor.fake_image_token.content + image_token = processor.image_token.content + global_img_token = processor.global_image_tag + + prompt_strings = [] + for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): + n_images_in_text.append(sample.count(image_token)) + + # Replace the image token with fake tokens around the expanded + # image token sequence of length `image_seq_len` + image_prompt_strings = [] + for n_rows, n_cols in zip(sample_rows, sample_cols): + image_prompt_string = _get_image_prompt_string( + n_rows, + n_cols, + processor.image_seq_len, + image_token=image_token, + fake_token_around_image=fake_image_token, + global_img_token=global_img_token, + ) + image_prompt_strings.append(image_prompt_string) - # Place in the image prompt strings where the image tokens are - sample = split_sample[0] - for i, image_prompt_string in enumerate(image_prompt_strings): - sample += image_prompt_string + split_sample[i + 1] - prompt_strings.append(sample) + split_sample = sample.split(image_token) + if len(split_sample) == 0: + raise ValueError("The image token should be present in the text.") - prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + # Place in the image prompt strings where the image tokens are + sample = split_sample[0] + for i, image_prompt_string in enumerate(image_prompt_strings): + sample += image_prompt_string + split_sample[i + 1] + prompt_strings.append(sample) - return token_inputs( - prompt_token_ids=prompt_token_ids, - prompt=prompt_strings[0], - multi_modal_data=multi_modal_data, - ) + prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + + return token_inputs( + prompt_token_ids=prompt_token_ids, + prompt=prompt_strings[0], + multi_modal_data=multi_modal_data, + ) def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int: @@ -597,6 +600,12 @@ def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor: image_features = self._process_image_pixels(image_input) return self.connector(image_features) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.text_model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -604,26 +613,8 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - if intermediate_tensors is not None: - input_ids = None - inputs_embeds = None - else: - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.text_model.get_input_embeddings(input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) - else: - inputs_embeds = self.text_model.get_input_embeddings(input_ids) - input_ids = None hidden_states = self.text_model( input_ids, @@ -667,21 +658,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ] # BitandBytes specific attributes - default_bitsandbytes_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - # vision_model - ".fc1.", - ".fc2.", - ".out_proj.", - # connector - ".proj.", - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -718,6 +694,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.text_config.vocab_size) self.sampler = Sampler() + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self.model._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self.model._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -725,16 +720,27 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors, - **kwargs, - ) + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.model.text_model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index dcead65115132..01a381381ccec 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,17 +2,22 @@ Protocol, Type, Union, overload, runtime_checkable) import torch -from typing_extensions import TypeIs +from typing_extensions import TypeIs, TypeVar from vllm.logger import init_logger from vllm.utils import supports_kw +from .interfaces_base import is_pooling_model + if TYPE_CHECKING: - from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig + from vllm.attention import AttentionMetadata + from vllm.multimodal.inputs import NestedTensors # noqa: F401 from vllm.sequence import IntermediateTensors logger = init_logger(__name__) +T = TypeVar("T", default="NestedTensors") + @runtime_checkable class SupportsMultiModal(Protocol): @@ -27,7 +32,34 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ - def __init__(self, *, multimodal_config: "MultiModalConfig") -> None: + def get_multimodal_embeddings(self, **kwargs) -> Optional[T]: + """ + Returns multimodal embeddings generated from multimodal kwargs + to be merged with text embeddings. + """ + ... + + # Only for models that support v0 chunked prefill + # TODO(ywang96): Remove this overload once v0 is deprecated + @overload + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[T] = None, + attn_metadata: Optional["AttentionMetadata"] = None, + ) -> torch.Tensor: + ... + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[T] = None, + ) -> torch.Tensor: + """ + Returns the input embeddings merged from the text embeddings from + input_ids and the multimodal embeddings generated from multimodal + kwargs. + """ ... @@ -37,9 +69,6 @@ def __init__(self, *, multimodal_config: "MultiModalConfig") -> None: class _SupportsMultiModalType(Protocol): supports_multimodal: Literal[True] - def __call__(self, *, multimodal_config: "MultiModalConfig") -> None: - ... - @overload def supports_multimodal( @@ -79,10 +108,6 @@ class SupportsLoRA(Protocol): embedding_modules: ClassVar[Dict[str, str]] embedding_padding_modules: ClassVar[List[str]] - # lora_config is None when LoRA is not enabled - def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: - ... - # We can't use runtime_checkable with ClassVar for issubclass checks # so we need to treat the class as an instance and use isinstance instead @@ -95,9 +120,6 @@ class _SupportsLoRAType(Protocol): embedding_modules: Dict[str, str] embedding_padding_modules: List[str] - def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: - ... - @overload def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]: @@ -274,21 +296,11 @@ class HasInnerState(Protocol): for max_num_seqs, etc. True for e.g. both Mamba and Jamba. """ - def __init__(self, - *, - scheduler_config: Optional["SchedulerConfig"] = None) -> None: - ... - @runtime_checkable class _HasInnerStateType(Protocol): has_inner_state: ClassVar[Literal[True]] - def __init__(self, - *, - scheduler_config: Optional["SchedulerConfig"] = None) -> None: - ... - @overload def has_inner_state(model: object) -> TypeIs[HasInnerState]: @@ -321,17 +333,11 @@ class IsAttentionFree(Protocol): True for Mamba but not Jamba. """ - def __init__(self) -> None: - ... - @runtime_checkable class _IsAttentionFreeType(Protocol): is_attention_free: ClassVar[Literal[True]] - def __init__(self) -> None: - ... - @overload def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: @@ -350,3 +356,37 @@ def is_attention_free( return isinstance(model, _IsAttentionFreeType) return isinstance(model, IsAttentionFree) + + +@runtime_checkable +class SupportsCrossEncoding(Protocol): + """The interface required for all models that support cross encoding.""" + + supports_cross_encoding: ClassVar[Literal[True]] = True + + +@overload +def supports_cross_encoding( + model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]: + ... + + +@overload +def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: + ... + + +def _supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + + if isinstance(model, type): + return isinstance(model, SupportsCrossEncoding) + + return isinstance(model, SupportsCrossEncoding) + + +def supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + return is_pooling_model(model) and _supports_cross_encoding(model) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 7bb43beff255c..de733b6d49a53 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -71,7 +71,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: and issubclass(model, nn.Module)): logger.warning( "The model (%s) is missing " - "vLLM-specific keywords from its initializer: %s", + "vLLM-specific keywords from its `forward` method: %s", model, missing_kws, ) @@ -141,7 +141,7 @@ def is_text_generation_model( @runtime_checkable -class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]): +class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]): def pooler( self, @@ -153,23 +153,22 @@ def pooler( @overload -def is_embedding_model( - model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]: +def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]: ... @overload -def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]: +def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ... -def is_embedding_model( +def is_pooling_model( model: Union[Type[object], object], -) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]: +) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: if not is_vllm_model(model): return False if isinstance(model, type): - return isinstance(model, VllmModelForEmbedding) + return isinstance(model, VllmModelForPooling) - return isinstance(model, VllmModelForEmbedding) + return isinstance(model, VllmModelForPooling) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 94b819b5d9366..41b9f110d771f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import torch from torch import nn @@ -27,7 +27,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -250,7 +250,12 @@ def forward( @support_torch_compile class InternLM2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer): super().__init__() config = vllm_config.model_config.hf_config @@ -266,7 +271,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: InternLMDecoderLayer( + lambda prefix: layer_type( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -314,16 +319,38 @@ def forward( return hidden_states -class InternLM2ForCausalLM(nn.Module, SupportsPP): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): +class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "wqkv": ["wqkv"], + "gate_up_proj": ["w1", "w3"], + } + + # LoRA specific attributes + supported_lora_modules = [ + "wqkv", + "wo", + "gate_up_proj", + "w2", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + model_type: Type[InternLM2Model] = InternLM2Model): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.quant_config = quant_config - self.model = InternLM2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.lora_config = lora_config + + self.model = model_type(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.output = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config, diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index f1b7c896cadfe..93ac2dcf8d587 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -14,8 +14,6 @@ InternLM2MLP, InternLM2Model) from vllm.sequence import IntermediateTensors -from .utils import make_layers, maybe_prefix - class InternLM2VEDecoderLayer(nn.Module): @@ -105,17 +103,9 @@ def forward( class InternLM2VEModel(InternLM2Model): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: InternLM2VEDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=InternLM2VEDecoderLayer) def forward( self, @@ -159,7 +149,6 @@ def forward( class InternLM2VEForCausalLM(InternLM2ForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - self.model = InternLM2VEModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + super().__init__(vllm_config=vllm_config, + prefix=prefix, + model_type=InternLM2VEModel) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 7ea2f9be2191d..86aab38032450 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -19,13 +19,14 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) -from vllm.model_executor.layers.quantization import (AWQConfig, - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -123,8 +124,15 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, return blocks, target_width, target_height -def calculate_num_blocks_wrapper(hf_config: PretrainedConfig, - max_dynamic_patch: Optional[int] = None): +def calculate_num_blocks_wrapper( + hf_config: PretrainedConfig, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, +): + if dynamic_image_size is None: + dynamic_image_size = hf_config.dynamic_image_size + + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch min_num = hf_config.min_dynamic_patch @@ -183,10 +191,17 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int, return pixel_values -def image_to_pixel_values_wrapper(hf_config: PretrainedConfig, - max_dynamic_patch: Optional[int] = None): +def image_to_pixel_values_wrapper( + hf_config: PretrainedConfig, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, +): image_size = hf_config.vision_config.image_size min_num = hf_config.min_dynamic_patch + if dynamic_image_size is None: + dynamic_image_size = hf_config.dynamic_image_size + + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch use_thumbnail = hf_config.use_thumbnail @@ -207,11 +222,17 @@ def get_internvl_num_patches(hf_config: PretrainedConfig): (downsample_ratio**2)) -def get_max_internvl_image_tokens(ctx: InputContext, - *, - max_dynamic_patch: Optional[int] = None): +def get_max_internvl_image_tokens( + ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, +): hf_config = ctx.get_hf_config() + if dynamic_image_size is None: + dynamic_image_size = hf_config.dynamic_image_size + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch use_thumbnail = hf_config.use_thumbnail @@ -222,12 +243,18 @@ def get_max_internvl_image_tokens(ctx: InputContext, return num_patches * max_dynamic_patch -def get_max_internvl_image_size(ctx: InputContext, - *, - max_dynamic_patch: Optional[int] = None): +def get_max_internvl_image_size( + ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, +): hf_config = ctx.get_hf_config() image_size = hf_config.vision_config.image_size + if dynamic_image_size is None: + dynamic_image_size = hf_config.dynamic_image_size + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch use_thumbnail = hf_config.use_thumbnail @@ -281,6 +308,7 @@ def input_processor( inputs: DecoderOnlyInputs, *, max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, ) -> DecoderOnlyInputs: multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: @@ -292,7 +320,7 @@ def input_processor( image_data = multi_modal_data["image"] num_patches = get_internvl_num_patches(hf_config) num_blocks_calculator = calculate_num_blocks_wrapper( - hf_config, max_dynamic_patch) + hf_config, max_dynamic_patch, dynamic_image_size) if isinstance(image_data, Image.Image): width, height = image_data.size num_blocks, _, _ = num_blocks_calculator(width, height) @@ -332,11 +360,12 @@ def input_mapper( data: object, *, max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, ): hf_config = ctx.get_hf_config() image_pixel_values_mapper = image_to_pixel_values_wrapper( - hf_config, max_dynamic_patch) + hf_config, max_dynamic_patch, dynamic_image_size) if isinstance(data, Image.Image): data = image_pixel_values_mapper(data) # Add an N dimension for number of images per prompt (currently 1). @@ -366,13 +395,17 @@ def dummy_data( mm_counts: Mapping[str, int], *, max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, ): num_images = mm_counts["image"] hf_config = ctx.get_hf_config() image_feature_size = get_max_internvl_image_tokens( - ctx, max_dynamic_patch=max_dynamic_patch) + ctx, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, @@ -388,7 +421,10 @@ def dummy_data( ) max_image_width, max_image_height = get_max_internvl_image_size( - ctx, max_dynamic_patch=max_dynamic_patch) + ctx, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) mm_data = dummy_image_for_clip( hf_config.vision_config, @@ -438,9 +474,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.mlp1 = self._init_mlp1(config) @@ -606,6 +643,26 @@ def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: visual_token_mask = None return visual_token_mask + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + assert self.img_context_token_id is not None + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.img_context_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -613,26 +670,22 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[SamplerOutput, IntermediateTensors]: + + visual_token_mask = None if intermediate_tensors is not None: input_ids = None inputs_embeds = None - visual_token_mask = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.img_context_token_id) - visual_token_mask = self._get_visual_token_mask(input_ids) - input_ids = None - else: - inputs_embeds = None - visual_token_mask = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None forward_kwargs = { "input_ids": input_ids, @@ -642,6 +695,13 @@ def forward( "intermediate_tensors": intermediate_tensors, "inputs_embeds": inputs_embeds, } + if self.img_context_token_id is not None: + visual_token_mask = self._get_visual_token_mask(input_ids) + + # We always overwrite it back to None after computing visual token + # mask so that this doesn't need to depend on encoder output + self.img_context_token_id = None + if self.is_mono: forward_kwargs.update({"visual_token_mask": visual_token_mask}) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 41db85b678456..8c81dff6b5768 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -76,6 +76,7 @@ def __init__( config: JAISConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -114,7 +115,8 @@ def __init__( scale=self.scale, alibi_slopes=alibi_slopes, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -178,6 +180,7 @@ def __init__( config: JAISConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -185,7 +188,10 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, cache_config, quant_config) + self.attn = JAISAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -241,7 +247,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.num_hidden_layers, lambda prefix: JAISBlock(config=config, cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.h", ) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index f83f0fce7275f..099ca7e12b288 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -102,7 +102,8 @@ def __init__(self, config: JambaConfig, layer_idx: int, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.config = config self.mamba = MambaMixer(hidden_size= config.hidden_size, @@ -157,6 +158,7 @@ def __init__( layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -198,6 +200,7 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, + prefix=f"{prefix}.attn", ) num_experts = config.layers_num_experts[layer_idx] @@ -287,7 +290,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): layer_class(config, layer_idx=i, cache_config=cache_config, - quant_config=quant_config)) + quant_config=quant_config, + prefix=f"{prefix}.layers.{i}")) self.layers = nn.ModuleList(decoder_layers) self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f84bc93085750..bc24a7d306afa 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import torch from torch import nn @@ -39,7 +39,6 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) @@ -50,14 +49,14 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_navi from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -130,6 +129,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + layer_idx = extract_layer_index(prefix) self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads @@ -184,6 +184,18 @@ def __init__( rope_scaling=rope_scaling, is_neox_style=is_neox_style, ) + + if hasattr(config, "interleaved_sliding_window"): + if isinstance(config.interleaved_sliding_window, int): + sliding_window = config.interleaved_sliding_window + elif isinstance(config.interleaved_sliding_window, list): + sw_idx = layer_idx % len(config.interleaved_sliding_window) + sliding_window = config.interleaved_sliding_window[sw_idx] + else: + raise ValueError(f"{type(sliding_window)} is not supported.") + else: + sliding_window = None + self.attn = Attention( self.num_heads, self.head_dim, @@ -191,6 +203,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", ) # For CUDA devices and Navi4x, attn_fp8_out will be set to false. self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ @@ -305,7 +319,11 @@ def forward( @support_torch_compile class LlamaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): super().__init__() config = vllm_config.model_config.hf_config @@ -331,10 +349,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: LlamaDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: @@ -492,15 +510,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_padding_modules = ["lm_head"] # BitandBytes specific attributes - default_bitsandbytes_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -535,12 +544,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -569,13 +578,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.STEP, - normalize=False, - softmax=False) + + def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): + return LlamaModel(vllm_config=vllm_config, prefix=prefix) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -603,14 +611,6 @@ def compute_logits( sampling_metadata) return logits - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - logits = self.compute_logits(hidden_states, None) - return self._pooler(logits, pooling_metadata) - def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) @@ -661,74 +661,3 @@ def permute(w: torch.Tensor, n_heads: int): name = name.replace(item, mapping[item]) return name, loaded_weight - - -class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): - """ - A model that uses Llama with additional embedding functionalities. - - This class encapsulates the LlamaModel and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of LlamaModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens" - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - } - embedding_padding_modules = [] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - pooler_config = vllm_config.model_config.pooler_config - - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - self.model.load_weights(weights) - - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - self.model.load_kv_cache_scales(quantization_param_path) - - # LRUCacheWorkerLoRAManager instantiation requires model config. - @property - def config(self): - return self.model.config diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index e7d3161a7cb2d..db7fa82ceb9b7 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -204,7 +204,41 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): class LlavaLikeConfig(Protocol): vision_config: PretrainedConfig - vision_feature_layer: int + vision_feature_layer: Union[int, List[int]] + + +def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: + """Determine the number of hidden layers to initialize up to in the + visual encoder. + + Args: + hf_config: Model config with vision feature layer(s). + """ + feature_layers = hf_config.vision_feature_layer + num_hidden_layers = hf_config.vision_config.num_hidden_layers + # If we have one feature layer, initialize up to that layer + if isinstance(feature_layers, int): + return _get_layer_index(feature_layers, num_hidden_layers) + # If we have multiple feature layers, initialize up to the deepest one + elif isinstance(feature_layers, (list, tuple)): + return max( + _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" + " is not supported") + + +def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: + """Given an signed vision feature layer, get the number of hidden layers + needed to leverage it. + + Args: + feature_layer_index: Index of a required layer in the visual encoder. + num_hidden_layers: The total number of hidden layers in the visual + encoder. + """ + if feature_layer_index < 0: + return num_hidden_layers + feature_layer_index + 1 + return feature_layer_index + 1 def init_vision_tower_for_llava( @@ -216,13 +250,8 @@ def init_vision_tower_for_llava( ): vision_config = hf_config.vision_config - # Initialize the vision tower only up to the required feature layer - vision_feature_layer = hf_config.vision_feature_layer - if vision_feature_layer < 0: - num_hidden_layers = hf_config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 - else: - num_hidden_layers = vision_feature_layer + 1 + # Initialize the vision tower only up to the deepest required feature layer + num_hidden_layers = _get_num_hidden_layers(hf_config) if isinstance(vision_config, CLIPVisionConfig): return CLIPVisionModel( @@ -258,6 +287,15 @@ def init_vision_tower_for_llava( @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() @@ -290,9 +328,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -449,7 +488,7 @@ def _process_image_input(self, image_features = self._process_image_pixels(image_input) return self.multi_modal_projector(image_features) - def process_mm_inputs(self, **kwargs): + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None @@ -459,12 +498,12 @@ def process_mm_inputs(self, **kwargs): def get_input_embeddings( self, input_ids: torch.Tensor, - vision_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if vision_embeddings is not None: + if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, + input_ids, inputs_embeds, multimodal_embeddings, self.config.image_token_index) return inputs_embeds @@ -515,10 +554,11 @@ def forward( """ if intermediate_tensors is not None: inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. elif inputs_embeds is None: - vision_embeddings = self.process_mm_inputs(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent + vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 37e2227a52dcd..a39f2f4124d05 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -14,12 +14,11 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.multimodal.inputs import NestedTensors +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -285,9 +284,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config + vision_feature_layer = config.vision_feature_layer + # Determine the layer up to which we will initialize the vision tower + if isinstance(vision_feature_layer, int): + vision_hidden_size = config.vision_config.hidden_size + self.feature_sample_layers = None + # Used for multimodal granite models to control encoder outputs + elif isinstance(vision_feature_layer, (list, tuple)): + vision_hidden_size = config.vision_config.hidden_size * len( + vision_feature_layer) + self.feature_sample_layers = vision_feature_layer + else: + raise TypeError( + f"vision_layer_feature type: {type(vision_feature_layer)}" + " is not supported") + self.config = config self.multimodal_config = multimodal_config @@ -300,22 +313,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) self.multi_modal_projector = LlavaMultiModalProjector( - vision_hidden_size=config.vision_config.hidden_size, + vision_hidden_size=vision_hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) - - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -419,7 +426,8 @@ def _image_pixels_to_features( # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) + image_features = vision_tower( + pixel_values, feature_sample_layers=self.feature_sample_layers) return self._select_image_features( image_features, @@ -549,6 +557,30 @@ def _process_image_input( for i, patch_features_batch in enumerate(patch_embeddings) ] + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + + if multimodal_embeddings is None: + return self.language_model.get_input_embeddings(input_ids) + + inputs_embeds = embed_multimodal( + input_ids, + self.config.image_token_index, + self.language_model.model.get_input_embeddings, + multimodal_embeddings, + ) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -556,6 +588,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-NeXT. @@ -604,24 +637,14 @@ def forward( """ if intermediate_tensors is not None: inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - inputs_embeds = embed_multimodal( - input_ids, - self.config.image_token_index, - self.language_model.model.get_input_embeddings, - lambda _: self._process_image_input(image_input), - ) - else: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, @@ -629,7 +652,6 @@ def forward( attn_metadata, intermediate_tensors, inputs_embeds=inputs_embeds) - return hidden_states def compute_logits( @@ -647,13 +669,6 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index e2880c76cf43d..0de9d8c5ea572 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -18,6 +18,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors @@ -274,9 +275,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) @@ -388,6 +390,25 @@ def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs): raise ValueError( f"Unsupported type of video input {type(video_pixels)}") + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + video_input = self._parse_and_validate_video_input(**kwargs) + if video_input is None: + return None + vision_embeddings = self._process_video_pixels(video_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.video_token_index) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -395,6 +416,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-NeXT-Video. @@ -404,22 +426,15 @@ def forward( pixel_values_videos: Pixels in each frames for each input videos. """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - video_input = self._parse_and_validate_video_input(**kwargs) - if video_input is not None: - video_embeddings = self._process_video_pixels(video_input) - inputs_embeds = self.language_model \ - .model.get_input_embeddings(input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, video_embeddings, - self.config.video_token_index) - - input_ids = None - else: - inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 705ca1e4ab6e6..0bebc1c745e2b 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors @@ -421,9 +422,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: prefix=maybe_prefix(prefix, "vision_tower")) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) @@ -824,6 +826,49 @@ def apply_pooling(self, image_features, stride=2): image_feature = image_feature.view(batch_frames, -1, dim) return image_feature + def get_multimodal_embeddings( + self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # We make a tuple of each embedding with its modality string. This is a + # temporary workaround for models to handle mixed modalities when + # get_multimodal_embeddings and get_input_embeddings are called + # separately. + # TODO(ywang96): Add support for mixed-modality inference for v1. + multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] + + if "images" in modalities: + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings.append((vision_embeddings, "image")) + if "videos" in modalities: + video_input = modalities["videos"] + video_embeddings = self._process_video_pixels(video_input) + multimodal_embeddings.append((video_embeddings, "video")) + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[List[Tuple[NestedTensors, + str]]] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + for embeddings, modality in multimodal_embeddings: + if modality == "image": + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, embeddings, + self.config.image_token_index) + if modality == "video": + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, embeddings, + self.config.video_token_index) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -831,6 +876,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-Onevision. @@ -840,28 +886,15 @@ def forward( pixel_values_videos: Pixels in each frames for each input videos. """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if modalities: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - if "images" in modalities: - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - if "videos" in modalities: - video_input = modalities["videos"] - video_embeddings = self._process_video_pixels(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, video_embeddings, - self.config.video_token_index) - input_ids = None - else: - inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 405b8f7787ba8..ac0d265a961f0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -243,10 +243,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -258,5 +256,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index b4ed6538bddac..66bdcb89a0213 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -61,14 +61,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size - self.lm_heads = nn.ModuleList([ - ParallelLMHead( + if getattr(config, "original_lm_head", False): + self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=self.truncated_vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, - ) for _ in range(self.config.num_heads) - ]) + ) + self.lm_heads = [ + self.lm_head for _ in range(self.config.num_heads) + ] + else: + self.lm_heads = nn.ModuleList([ + ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) for _ in range(self.config.num_heads) + ]) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, @@ -172,6 +183,9 @@ def load_weights(self, weights: Iterable[Tuple[str, requires_grad=False) elif name in params_dict: weights_map[name] = loaded_weight + elif (getattr(self.config, "original_lm_head", False) + and name == "lm_heads.0.weight"): + weights_map["lm_head.weight"] = loaded_weight for name, loaded_weight in weights_map.items(): if "lm_head" in name and self.token_map is not None and\ diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index b92bff4d7c28c..6254d26c7060d 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -52,7 +52,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -192,6 +192,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -246,7 +247,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -273,6 +275,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -283,6 +286,7 @@ def __init__( self.rope_scaling = getattr(config, "rope_scaling", None) self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.prefix = prefix self._init_attn_block() self._init_ffn_block() @@ -298,6 +302,7 @@ def _init_attn_block(self): max_position_embeddings=self.max_position_embeddings, cache_config=self.cache_config, quant_config=self.quant_config, + prefix=f"{self.prefix}.self_attn", ) def _init_ffn_block(self): @@ -373,6 +378,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size, org_num_embeddings=config.vocab_size, ) + self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( @@ -388,8 +394,8 @@ def _init_layers( ): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MiniCPMDecoderLayer(config, cache_config, - quant_config), + lambda prefix: MiniCPMDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -432,6 +438,73 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -475,8 +548,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.cache_config = cache_config self.quant_config = quant_config - self.num_experts = getattr(self.config, "num_experts", 0) - self._init_model(vllm_config=vllm_config, prefix=prefix) + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -501,8 +575,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model.make_empty_intermediate_tensors) def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): - self.model = MiniCPMModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -541,72 +614,9 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - expert_params_mapping = [ - # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.num_experts) - for weight_name in ["w1", "w2", "w3"] - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for param_name, weight_name, expert_id in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 278c4bbe6e563..e9d7eada1d16c 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -40,7 +40,7 @@ MiniCPMForCausalLM, MiniCPMModel) -from .utils import make_layers, maybe_prefix +from .utils import make_layers class MiniCPM3Attention(nn.Module): @@ -60,6 +60,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -119,7 +120,8 @@ def __init__( self.scaling, num_kv_heads=self.num_local_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -195,6 +197,7 @@ def _init_attn_block(self): max_position_embeddings=self.max_position_embeddings, cache_config=self.cache_config, quant_config=self.quant_config, + prefix=f"{self.prefix}.self_attn", ) @@ -209,8 +212,8 @@ def _init_layers( ): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MiniCPM3DecoderLayer(config, cache_config, - quant_config), + lambda prefix: MiniCPM3DecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") @@ -238,6 +241,11 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): # `embedding_modules` and `embedding_padding_modules` # are inherited from MiniCPMForCausalLM + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): - self.model = MiniCPM3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 99bf1d42d0355..1e8f9bd4cf418 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -22,7 +22,7 @@ """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re -from functools import partial +from functools import cached_property, partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -37,19 +37,15 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, get_2d_sincos_pos_embed) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import LlamaModel -from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.models.utils import LLMWrapper +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor @@ -58,11 +54,7 @@ from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import is_pp_missing_parameter, maybe_prefix - -_KEYS_TO_MODIFY_MAPPING = { - "llm.lm_head": "lm_head", -} +from .utils import AutoWeightsLoader, maybe_prefix RawImageType = Union[Image.Image, torch.Tensor] @@ -297,10 +289,9 @@ def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): def get_placeholder(image_size: Tuple[int, int], num_image: int): if version == (2, 0) or version == (2, 5): - return image_processor. \ - get_slice_image_placeholder(image_size) - return image_processor. \ - get_slice_image_placeholder(image_size, num_image) + return image_processor.get_slice_image_placeholder(image_size) + return image_processor.get_slice_image_placeholder( + image_size, num_image) prompt = inputs.get("prompt") token_ids = inputs.get("prompt_token_ids") @@ -400,37 +391,32 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vpm = self.init_vision_module(config, quant_config, prefix=maybe_prefix(prefix, "vpm")) - param_dtype = torch.get_default_dtype() - self.vpm.to(dtype=param_dtype) self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else self.vpm.embeddings.embed_dim) self.embed_dim = self.config.hidden_size + self.resampler = self.init_resampler(self.embed_dim, self.vision_dim, quant_config=quant_config, prefix=maybe_prefix( prefix, "resampler")) - self.resampler.to(device="cuda", dtype=param_dtype) - # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "llm.lm_head")) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.llm.make_empty_intermediate_tensors) + @cached_property + def sampler(self): + if hasattr(self.llm, "sampler"): + return self.llm.sampler + + return get_sampler() + def get_embedding( self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], ) -> Tuple[torch.Tensor, torch.Tensor]: - vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) - if hasattr(self.config, "scale_emb"): - vlm_embedding *= self.config.scale_emb + vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) if image_inputs is None: # No image vision_hidden_states = torch.tensor([], device=input_ids.device) @@ -575,7 +561,7 @@ def forward( # for `torch.compile` integration input_ids = None - output = self.llm( + output = self.llm.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, @@ -590,9 +576,7 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.llm.compute_logits(hidden_states, sampling_metadata) def sample( self, @@ -604,52 +588,8 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - use_default_weight_loading = False - if self.is_default_weight_loading(name): - use_default_weight_loading = True - else: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading: - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) def get_mm_mapping(self) -> MultiModelKeys: """ @@ -693,9 +633,6 @@ def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError - def is_default_weight_loading(self, name: str) -> bool: - raise NotImplementedError - class MiniCPMV2_0(MiniCPMVBaseModel): @@ -708,8 +645,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix), - name="model") + return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -717,11 +653,12 @@ def init_vision_module( quant_config: Optional[QuantizationConfig], prefix: str = "", ) -> nn.Module: - # TODO :refactor this vision model + # TODO: refactor this vision model try: import timm except ImportError: raise ImportError("Please install timm==0.9.10") from ImportError + with set_default_torch_dtype(torch.float16): model = timm.create_model( "vit_so400m_patch14_siglip_384.webli", @@ -731,6 +668,8 @@ def init_vision_module( dynamic_img_pad=True, ) + model = model.to(dtype=torch.get_default_dtype()) + if (isinstance(model, timm.models.VisionTransformer) and model.attn_pool is not None): model.attn_pool = torch.nn.Identity() @@ -759,7 +698,7 @@ def init_resampler(self, quant_config=quant_config, prefix=prefix) - return resampler + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -790,9 +729,6 @@ def get_vision_hidden_states(self, return self.get_vision_embedding(pixel_values) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name or "vpm" in name - class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): packed_modules_mapping = { @@ -822,25 +758,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ] # BitandBytes specific attributes - default_bitsandbytes_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - # vision encoder - ".fc1.", - ".fc2.", - # Currently, vllm does not support BNB quantization for the `out_proj` - # of the resampler, so it's necessary to distinguish between the - # vision encoder and the resampler's out_proj. The same applies to - # MiniCPMV2_6. - ".self_attn.out_proj.", # vision encoder out_proj - # resampler - ".kv_proj.", - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -862,8 +779,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix), - name="model") + return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -890,7 +806,8 @@ def init_resampler(self, kv_dim=vision_dim, quant_config=quant_config, prefix=prefix) - return resampler + + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -932,9 +849,6 @@ def get_vision_hidden_states(self, return self.get_vision_embedding(all_pixel_values.type(dtype), patch_attn_mask, tgt_sizes) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name - class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): packed_modules_mapping = { @@ -964,21 +878,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ] # BitandBytes specific attributes - default_bitsandbytes_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - # vision encoder - ".fc1.", - ".fc2.", - ".self_attn.out_proj.", - # resampler - ".kv_proj.", - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -1000,8 +899,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix), - name="model") + return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -1029,7 +927,8 @@ def init_resampler(self, kv_dim=vision_dim, quant_config=quant_config, prefix=prefix) - return resampler + + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -1077,9 +976,6 @@ def get_vision_hidden_states(self, return self.resampler(vision_embedding, tgt_sizes) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name - _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0faffb4f1b00c..a5b364fe5ec85 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -166,7 +166,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index ddd6afcf6a1b6..7a9b8cd88cfd0 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -170,6 +170,7 @@ def __init__( rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -219,7 +220,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -243,6 +245,7 @@ def __init__( config: MixtralConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -255,7 +258,9 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, @@ -311,7 +316,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config), + config, cache_config, quant_config=quant_config, prefix=prefix + ), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index bd9bd03c8c2db..7f9b5c3cdb5ee 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -33,6 +33,7 @@ import vllm.distributed.parallel_state as ps from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.selector import _Backend from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, @@ -837,10 +838,8 @@ def _attention_with_mask( cached_k, cached_v, key_cache, value_cache, attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) else: - from vllm.attention.backends.flash_attn import ( - FlashAttentionMetadata) - from vllm.attention.backends.xformers import XFormersMetadata - if isinstance(attn_metadata, FlashAttentionMetadata): + if self.attn.backend in (_Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1): cached_k = torch.cat( [k[s:e] for s, e in kv_range_for_decode]) cached_v = torch.cat( @@ -856,7 +855,8 @@ def _attention_with_mask( 1.0, 1.0, ) - elif isinstance(attn_metadata, XFormersMetadata): + elif self.attn.backend in (_Backend.XFORMERS, + _Backend.TORCH_SDPA): key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_local_key_value_heads, self.head_dim) @@ -869,10 +869,10 @@ def _attention_with_mask( attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) else: raise ValueError( - f"Unsupported AttentionMetadata {type(attn_metadata)} " - f"class found. Expected the AttentionMetadata to " - f"be either XFormersMetadata or FlashAttentionMetadata." - ) + f"Unsupported Attention backend {self.attn.backend} " + "enum found. Expected the Attention backend to be " + "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS " + "or TORCH_SDPA.") # We have to call torch.sdpa for prefill when using a # custom cross-attention mask. Because the mask is not a @@ -1122,20 +1122,6 @@ def forward( @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): # BitandBytes specific attributes - default_bitsandbytes_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - ".fc1.", - ".fc2.", - # The `multi_modal_projector` is at the top level of the model, - # so we can't add a dot in front of it. - "multi_modal_projector." - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index f2aa2653c4f5c..d49da5f29aa14 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -193,7 +193,8 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - param = params_dict.get(name.replace("speculator.", "")) + name = name.replace("speculator.", "") + param = params_dict.get(name) if param is not None: weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index a7c90a3f5031b..98caa6857e211 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -3,7 +3,7 @@ from array import array from dataclasses import dataclass from functools import lru_cache, partial -from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union +from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict import torch from einops import rearrange @@ -13,7 +13,6 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.attention.selector import _Backend from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -37,13 +36,16 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import cached_get_tokenizer +from vllm.platforms import _Backend from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) from vllm.transformers_utils.processor import get_processor from .interfaces import SupportsMultiModal, SupportsPP -from .utils import (get_vit_attn_backend, +from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -370,6 +372,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -427,7 +430,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") # Attention output projection. self.o_proj = RowParallelLinear( @@ -517,10 +521,14 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() # Attention block. - self.self_attn = MolmoAttention(config, cache_config, quant_config) + self.self_attn = MolmoAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") # MLP block. self.mlp = MolmoMLP(config, quant_config=quant_config) @@ -713,6 +721,42 @@ def forward( # image_features: (batch_size, num_image, num_patch, d_model) return image_features + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + @support_torch_compile class MolmoModel(nn.Module): @@ -738,7 +782,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else MolmoDecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: decoder_layer(config, cache_config, quant_config), + lambda prefix: decoder_layer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) @@ -749,6 +794,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -790,6 +841,28 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "gate_up_proj" in name: + up_proj, gate_proj = loaded_weight.chunk(2, dim=0) + loaded_weight = torch.cat([gate_proj, up_proj], dim=0) + + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + cached_get_processor = lru_cache(get_processor) @@ -1091,19 +1164,16 @@ def _process_image_input( return image_features - def _merge_multimodal_embeddings( - self, - inputs_embeds: torch.Tensor, - image_features: torch.Tensor, - image_input_idx: torch.Tensor, - seq_len: Union[torch.Tensor, List[torch.Tensor]], - ) -> torch.Tensor: + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + image_features = self._process_image_input(image_input) + image_input_idx = image_input["image_input_idx"] + seq_len = image_input["seq_len"] batch_size, num_image, num_patch = image_features.shape[:3] assert image_input_idx.shape == (batch_size, num_image, num_patch) - image_features = image_features.to(inputs_embeds.device) - seq_len = seq_len.to(inputs_embeds.device) - # insert the image feature into the embedding. image_features = image_features.view(batch_size, num_image * num_patch, -1) @@ -1123,12 +1193,24 @@ def _merge_multimodal_embeddings( image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) image_input_idx = image_input_idx.flatten()[:, None] mat = image_input_idx == torch.arange( - seq_len.sum().item(), device=inputs_embeds.device)[None, :] + seq_len.sum().item(), device=image_features.device)[None, :] mat = mat.to(image_features.dtype) - inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md', - image_features, mat) + # Note: In this original implementation from AI2, the final + # vision_embeddings will be always be the same length + # of input embedddings, which is not very efficient. + # TODO(ywang96): see if this can be optimized. + vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) + return vision_embeddings + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = inputs_embeds + multimodal_embeddings return inputs_embeds def forward( @@ -1138,39 +1220,27 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> SamplerOutput: + if intermediate_tensors is not None: inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - inputs_embeds = self.model.embed_tokens(input_ids) - image_features = self._process_image_input(image_input) - - inputs_embeds = self._merge_multimodal_embeddings( - inputs_embeds, - image_features, - image_input["image_input_idx"], - image_input["seq_len"], - ) - else: - inputs_embeds = self.model.embed_tokens(input_ids) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None - - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states @@ -1189,103 +1259,53 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - - params_mapping = [ - ("model.transformer.ln_f.weight", "model.norm.weight"), - ("attn_out", "self_attn.o_proj"), - ("att_proj", "self_attn.qkv_proj"), - ("q_norm", "self_attn.q_norm"), - ("k_norm", "self_attn.k_norm"), - ("attn_norm", "input_layernorm"), - ("ff_norm", "post_attention_layernorm"), - ] - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - - embedding_weight = dict() - projector_weight = dict() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - - if "wte.embedding" in name: - embedding_weight["embedding"] = loaded_weight - continue - - if "wte.new_embedding" in name: - embedding_weight["new_embedding"] = loaded_weight - continue - - if "vision_backbone" in name: - if name.startswith("model"): - name = name[len("model."):] - if 'image_projector' in name: - if 'w1' in name: - projector_weight['gate_proj'] = loaded_weight - elif 'w3' in name: - projector_weight['up_proj'] = loaded_weight - elif 'w2' in name: - projector_weight['down_proj'] = loaded_weight - else: - raise ValueError( - f"Unexpected projector weight: {name}") - continue - else: - if "transformer.blocks" in name: - name = name.replace("transformer.blocks", "layers") - - if "ff_proj" in name: - name = name.replace("ff_proj", "mlp.gate_up_proj") - assert 'weight' in name - up_weight, gate_weight = loaded_weight.chunk(2, dim=0) - loaded_weight = torch.cat([gate_weight, up_weight], dim=0) - - elif "ff_out" in name: - if "layers" in name: - name = name.replace("ff_out", "mlp.down_proj") - else: - # lm head - name = name.replace("model.transformer.ff_out", - "lm_head") - - else: - for (param_name, weight_name) in params_mapping: - if param_name in name: - name = name.replace(param_name, weight_name) - break - - try: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - except KeyError: - raise ValueError(f"Unexpected weight: {name}") from None - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - gate_up_proj_weight = torch.cat( - [projector_weight["gate_proj"], projector_weight["up_proj"]], - dim=0) - name = "vision_backbone.image_projector.gate_up_proj.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, gate_up_proj_weight) - - down_proj_weight = projector_weight["down_proj"] - name = "vision_backbone.image_projector.down_proj.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, down_proj_weight) - - embedding_weight = torch.cat( - [embedding_weight["embedding"], embedding_weight["new_embedding"]], - dim=0) - name = "model.embed_tokens.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, embedding_weight) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + # vision backbone mapping + "image_projector.w1.": "image_projector.gate_proj.", + "image_projector.w3.": "image_projector.up_proj.", + "image_projector.w2.": "image_projector.down_proj.", + # language backbone mapping + "att_proj": "self_attn.qkv_proj", + "attn_out": "self_attn.o_proj", + "q_norm": "self_attn.q_norm", + "k_norm": "self_attn.k_norm", + "ff_proj": "mlp.gate_up_proj", + "ff_out": "mlp.down_proj", + "attn_norm": "input_layernorm", + "ff_norm": "post_attention_layernorm", + }, + orig_to_new_prefix={ + # vision backbone mapping + "model.vision_backbone.": "vision_backbone.", + # language backbone mapping + "model.transformer.blocks.": "model.layers.", + "model.transformer.ln_f.": "model.norm.", + # lm_head is renamed to model.transformer.mlp.down_proj firstly, + # we need to run a second renaming for it + "model.transformer.mlp.down_proj.": "lm_head.", + }, + ) + loader = AutoWeightsLoader(self) + weights = _get_weights_with_merged_embedding(weights) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) + + +def _get_weights_with_merged_embedding( + weights: Iterable[Tuple[str, torch.Tensor]] +) -> Iterable[Tuple[str, torch.Tensor]]: + embedding_weights = {} + for name, weight in weights: + if "wte.embedding" in name: + embedding_weights["embedding"] = weight + elif "wte.new_embedding" in name: + embedding_weights["new_embedding"] = weight + else: + yield (name, weight) + # this is compatible with most of quantization, + # because they won't quantize embed_tokens + embedding_weights = torch.cat( + [embedding_weights["embedding"], embedding_weights["new_embedding"]], + dim=0, + ) + yield ("model.embed_tokens.weight", embedding_weights) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 8716e92b0f1c2..1235816413a44 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -50,6 +50,7 @@ def __init__( config: MPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -115,7 +116,8 @@ def __init__( alibi_slopes=alibi_slopes, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -176,11 +178,15 @@ def __init__( config: MPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, cache_config, quant_config) + self.attn = MPTAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -224,7 +230,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: MPTBlock(config, cache_config, quant_config), + lambda prefix: MPTBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.blocks") self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index ceab299a7950a..c7b4c22b6896b 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -195,7 +195,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index dc138e2e636ad..538e31ec91699 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -62,6 +62,7 @@ def __init__( config: OlmoConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -101,7 +102,8 @@ def __init__( self.head_dim, scale=self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") # Attention output projection. self.o_proj = RowParallelLinear( @@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, cache_config, quant_config) + self.self_attn = OlmoAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") # MLP block. self.mlp = OlmoMLP(config, quant_config) @@ -238,8 +244,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config - ), + lambda prefix: OlmoDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py new file mode 100644 index 0000000000000..a35c911f90d96 --- /dev/null +++ b/vllm/model_executor/models/olmo2.py @@ -0,0 +1,432 @@ +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMo2 model compatible with HuggingFace weights.""" + +from functools import partial +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.distributed.utils import split_tensor_along_last_dim +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP +from vllm.model_executor.models.utils import ( + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, + make_layers, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.olmo2 import Olmo2Config + + +class Olmo2Attention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + assert isinstance(self.config, Olmo2Config) + + hidden_size = self.config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = self.config.num_attention_heads + + assert hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % self.tp_size == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = (self.config.num_key_value_heads + or self.total_num_heads) + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = self.config.max_position_embeddings + self.rope_theta = self.config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.tp_rank = get_tensor_model_parallel_rank() + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, + eps=self.config.rms_norm_eps, + ) + self.q_norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, # type: ignore + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=prefix, + ) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.o_proj", + ) + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Olmo2MLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))`` + (plus another skip connection). + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.down_proj", + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Olmo2DecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + # Attention block. + self.self_attn = Olmo2Attention(vllm_config=vllm_config, + prefix=f"{prefix}.self_attn") + + # MLP block. + self.mlp = Olmo2MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp") + + # LayerNorm + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # Attention block. + residual = hidden_states + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Olmo2Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + assert isinstance(self.config, Olmo2Config) + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=f"{prefix}.embed_tokens", + ) + self.start_layer, self.end_layer, self.layers = make_layers( + self.config.num_hidden_layers, + lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm( + self.config.hidden_size, + eps=self.config.rms_norm_eps, + ) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + self.config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + if get_pp_group().is_first_rank: + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + assert isinstance(hidden_states, torch.Tensor) + + # Apply blocks one-by-one. + for i in range(self.start_layer, self.end_layer): + # shape: (batch_size, seq_len, d_model) + hidden_states = self.layers[i]( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Olmo2ForCausalLM(nn.Module, SupportsPP): + """ + Extremely barebones HF model wrapper. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + assert isinstance(config, Olmo2Config) + self.config = config + self.model = Olmo2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=vllm_config.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if is_pp_missing_parameter(name, self): + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader # type: ignore + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index ab87695d8e650..5d9091cfb9311 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -102,6 +102,7 @@ def __init__( max_position_embeddings: int = 4096, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -156,7 +157,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -179,9 +181,9 @@ class OlmoeDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -199,6 +201,7 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.mlp = OlmoeMoE( @@ -260,8 +263,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: OlmoeDecoderLayer(config, int( - prefix.split(".")[-1]), cache_config, quant_config), + lambda prefix: OlmoeDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=1e-5) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index db85a494980a7..7edafcd20b5db 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -337,9 +337,6 @@ class OPTForCausalLM(nn.Module, SupportsPP): "k_proj": ("qkv_proj", 1), "v_proj": ("qkv_proj", 2), } - default_bitsandbytes_target_modules = [ - ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2." - ] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index b01734af8ddd8..a3757b5c8808e 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -75,6 +75,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -126,7 +127,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -150,6 +152,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -166,6 +169,7 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.mlp = OrionMLP( hidden_size=self.hidden_size, @@ -226,10 +230,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OrionDecoderLayer( - config, - cache_config, - quant_config, - ), + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index dd5256eb87ab3..253e689e50a3b 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors @@ -150,9 +151,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config config.text_config.architectures = ["GemmaForCausalLM"] self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) logit_scale = getattr(config, "logit_scale", 1.0) self.language_model.logits_processor.scale *= logit_scale @@ -240,36 +242,45 @@ def _process_image_input( return self.multi_modal_projector(image_features) + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa + vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - parsed_image_input = self._parse_and_validate_image_input(**kwargs) - - if parsed_image_input is not None: - vision_embeddings = self._process_image_input( - parsed_image_input) - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa - vision_embeddings = vision_embeddings * ( - self.config.hidden_size**-0.5) - - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - - input_ids = None - else: - inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 3b8199f4f1661..14dd4b5b1b4da 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module): def __init__(self, config: PersimmonConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config tensor_parallel_world_size = get_tensor_model_parallel_world_size() @@ -122,7 +123,8 @@ def __init__(self, self.head_dim, scale=self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def _split_heads(self, x: torch.Tensor) -> torch.Tensor: # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim] @@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module): def __init__(self, config: PersimmonConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.self_attn = PersimmonAttention(config=config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.mlp = PersimmonMLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -226,8 +230,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: PersimmonDecoderLayer(config, cache_config, - quant_config), + lambda prefix: PersimmonDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 0a117bf16c9b3..f9e972688ddd1 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -69,7 +69,8 @@ class PhiAttention(nn.Module): def __init__(self, config: PhiConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -116,7 +117,8 @@ def __init__(self, self.head_size, scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -167,11 +169,15 @@ class PhiLayer(nn.Module): def __init__(self, config: PhiConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, cache_config, quant_config) + self.self_attn = PhiAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") self.mlp = PhiMLP(config, quant_config) def forward( @@ -210,7 +216,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: PhiLayer(config, cache_config, quant_config), + lambda prefix: PhiLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -279,9 +286,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "k_proj": ("qkv_proj", 1), "v_proj": ("qkv_proj", 2), } - default_bitsandbytes_target_modules = [ - ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense." - ] embedding_modules = {} embedding_padding_modules = [] diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py index 34141511ea791..937858ee3b8c2 100644 --- a/vllm/model_executor/models/phi3.py +++ b/vllm/model_executor/models/phi3.py @@ -14,3 +14,7 @@ class Phi3ForCausalLM(LlamaForCausalLM): "gate_up_proj", ], } + + # BitandBytes specific attributes + # Initialize an empty dict when there is no stacked parameter mapping. + bitsandbytes_stacked_params_mapping = {} diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index f71cbd1264c45..da7e4cdbc6940 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -117,6 +117,7 @@ def __init__( layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.layer_idx = layer_idx @@ -214,15 +215,14 @@ def __init__( "homo_head": self.homo_heads } - self.attn = Attention( - self.num_heads_per_partition, - self.head_dim, - self.scale, - num_kv_heads=self.num_kv_heads_per_partion, - cache_config=cache_config, - quant_config=quant_config, - blocksparse_params=bs_params, - ) + self.attn = Attention(self.num_heads_per_partition, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads_per_partion, + cache_config=cache_config, + quant_config=quant_config, + blocksparse_params=bs_params, + prefix=f"{prefix}.attn") def forward( self, @@ -259,13 +259,15 @@ def __init__( layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Phi3SmallSelfAttention(config, layer_idx, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.mlp = Phi3SmallMLP(config, quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -315,7 +317,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.num_hidden_layers, lambda prefix: Phi3SmallDecoderLayer(config, int(prefix.split('.')[-1]), - cache_config, quant_config), + cache_config, + quant_config, + prefix=prefix), prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 2e583bb08e87a..eef23029a2aca 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -29,24 +29,22 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -536,7 +534,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -556,18 +553,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config, prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) - # The prefix is empty intentionally because default prefix of - # LlamaForCausalLM is "model" - self.language_model = LlamaForCausalLM(vllm_config=vllm_config, - prefix="") - - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + # The prefix is empty intentionally because default prefix of + # LlamaForCausalLM is "model" + prefix="", + # We don't directly initialize vLLM's LlamaForCausalLM so we + # can automatically apply embedding wrapper if this model is + # initialized as an embedding model + architectures=["LlamaForCausalLM"], + ) + self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -676,7 +672,7 @@ def _process_image_input( return image_embeds - def process_mm_inputs(self, **kwargs): + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None @@ -686,12 +682,12 @@ def process_mm_inputs(self, **kwargs): def get_input_embeddings( self, input_ids: torch.Tensor, - vision_embeddings: Optional[NestedTensors] = None, + multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) - if vision_embeddings is not None: + if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, + input_ids, inputs_embeds, multimodal_embeddings, self.image_token_id) return inputs_embeds @@ -703,12 +699,14 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object): + if intermediate_tensors is not None: inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility elif inputs_embeds is None: - vision_embeddings = self.process_mm_inputs(**kwargs) - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent + vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) input_ids = None @@ -737,13 +735,6 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index e475d286bd7ea..1febd62f2f705 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -294,6 +294,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, rope_scaling: Optional[dict] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -347,6 +348,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.attn", ) def forward( @@ -371,6 +373,7 @@ def __init__( config: PhiMoEConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -385,6 +388,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, rope_scaling=config.rope_scaling, + prefix=f"{prefix}.self_attn", ) self.block_sparse_moe = PhiMoE( num_experts=config.num_local_experts, @@ -454,8 +458,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: PhiMoEDecoderLayer(config, cache_config, - quant_config), + lambda prefix: PhiMoEDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d44a538d56b8c..215727cadd954 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -17,6 +17,7 @@ from vllm.attention import AttentionMetadata from vllm.config import ModelConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_and_mul_fn @@ -30,9 +31,10 @@ from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import PlaceholderRange +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges) + consecutive_placeholder_ranges, + resolve_visual_encoder_outputs) from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_list_of @@ -170,9 +172,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # init MistralForCausalLM self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.vision_encoder = VisionTransformer(self.vision_args) self.vision_language_adapter = VisionLanguageAdapter( @@ -188,6 +191,25 @@ def sampler(self): return get_sampler() + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.vision_args.image_token_id) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -195,31 +217,21 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for pixtral. - - TODO - """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.vision_args.image_token_id) - - input_ids = None - else: - inputs_embeds = None + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, @@ -330,6 +342,7 @@ class VisionEncoderArgs: num_attention_heads: int rope_theta: float # for rope-2D image_token_id: int + adapter_bias: bool = True def _reshape_for_broadcast(freqs_cis: torch.Tensor, @@ -594,10 +607,10 @@ def __init__(self, args: VisionEncoderArgs, dim: int): self.w_in = nn.Linear( args.hidden_size, dim, - bias=True, + bias=args.adapter_bias, ) self.gelu = nn.GELU() - self.w_out = nn.Linear(dim, dim, bias=True) + self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_out(self.gelu(self.w_in(x))) @@ -843,17 +856,20 @@ def __init__( self.config = config assert not config.hidden_size % config.num_attention_heads - self.n_heads = config.num_attention_heads + self.total_num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + self.n_heads = divide(config.num_attention_heads, tp_size) self.head_dim = config.hidden_size // config.num_attention_heads self.qkv_proj = QKVParallelLinear( hidden_size=config.hidden_size, head_size=self.head_dim, - total_num_heads=self.n_heads, + total_num_heads=self.total_num_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) + assert self.total_num_heads * self.head_dim == config.hidden_size self.o_proj = RowParallelLinear( input_size=config.hidden_size, output_size=config.hidden_size, @@ -965,9 +981,18 @@ def forward( x: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, + return_all_hidden_states: bool, ) -> torch.Tensor: + hidden_states_pool = [] + for layer in self.layers: x = layer(x, attention_mask, position_embeddings) + if return_all_hidden_states: + hidden_states_pool.append(x) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool return x @@ -985,6 +1010,7 @@ def __init__( super().__init__() self.config = config + self.patch_conv = nn.Conv2d( in_channels=config.num_channels, out_channels=config.hidden_size, @@ -1019,6 +1045,7 @@ def __init__( def forward( self, pixel_values: List[torch.Tensor], + feature_sample_layers: Optional[list[int]] = None, ) -> torch.Tensor: """ Args: @@ -1026,6 +1053,9 @@ def forward( in pixel_values. This means it will be a list of tensors because multiple requests batched can have multiple images, each with their own shape potentially + feature_sample_layers: Layer indices whose features should be + concatenated and used as the visual encoder output. If none + are provided, the last layer is used. Returns: image_features: tensor of token features for @@ -1060,8 +1090,15 @@ def forward( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds) - out = self.transformer(patch_embeds, attention_mask, - position_embedding) + return_all_hidden_states = feature_sample_layers is not None + out = self.transformer( + patch_embeds, + attention_mask, + position_embedding, + return_all_hidden_states=return_all_hidden_states) + + out = resolve_visual_encoder_outputs(out, feature_sample_layers, None, + self.config.num_hidden_layers) return out diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 3978c176a2144..63d1374ab4092 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -442,6 +442,7 @@ def __init__( rope_scaling: Optional[Dict[str, Any]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -478,7 +479,8 @@ def __init__( self.head_dim, self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -502,6 +504,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -514,7 +517,8 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -568,7 +572,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: QWenBlock(config, cache_config, quant_config), + lambda prefix: QWenBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.make_empty_intermediate_tensors = ( @@ -870,7 +875,7 @@ def dummy_data_for_qwen( return DummyData(seq_data, mm_data) -class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): +class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -1023,8 +1028,15 @@ class QWenLLM(QWenBaseModel): embedding_modules = {} embedding_padding_modules = [] + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "w2": ("gate_up_proj", 0), + "w1": ("gate_up_proj", 1), + } + -class QWenVL(QWenBaseModel): +class QWenVL(QWenBaseModel, SupportsMultiModal): packed_modules_mapping = { "c_attn": ["c_attn"], "gate_up_proj": [ @@ -1062,7 +1074,7 @@ def get_mm_mapping(self) -> MultiModelKeys: @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(QWenBaseModel, SupportsLoRA): +class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA): """ QWenLMHeadModel is not only applicable to LLM but also to VL, which is not conducive to the current integration logic of LoRA in vLLM. Therefore, it @@ -1083,7 +1095,7 @@ def __new__( config = vllm_config.model_config.hf_config # Initialize VL if hasattr(config, "visual"): - return QWenVL(vllm_config=vllm_config) + return QWenVL(vllm_config=vllm_config, prefix=prefix) # Initialize LLM else: - return QWenLLM(vllm_config=vllm_config) + return QWenLLM(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 370cff5fa153f..7d4cc4b69e614 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -27,10 +27,11 @@ from torch import nn from transformers import Qwen2Config -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -50,10 +51,13 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +logger = init_logger(__name__) + class Qwen2MLP(nn.Module): @@ -164,11 +168,17 @@ def forward( hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=attn_type) output, _ = self.o_proj(attn_output) return output @@ -210,6 +220,15 @@ def __init__( self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + self._attn_type = AttentionType.DECODER + else: + self._attn_type = AttentionType.ENCODER_ONLY + def forward( self, positions: torch.Tensor, @@ -230,6 +249,7 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, + attn_type=self._attn_type, ) # Fully Connected @@ -402,15 +422,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_padding_modules = [] # BitandBytes specific attributes - default_bitsandbytes_target_modules = [ - ".gate_proj.", - ".down_proj.", - ".up_proj.", - ".q_proj.", - ".k_proj.", - ".v_proj.", - ".o_proj.", - ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -425,7 +436,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config @@ -446,14 +456,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -491,13 +493,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( @@ -545,6 +540,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + # TODO: Replace this model class with for_embedding(Qwen2ForCausalLM), + # after changing the default pooling method + if pooler_config.pooling_type is None: + logger.warning( + "This embedding model will default to last-token pooling in " + "an upcoming version. To avoid breaking changes, you should " + "pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`" + " explicitly.") + self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.MEAN, @@ -569,8 +573,9 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader(self, - ignore_unexpected_prefixes=["lm_head."]) - return loader.load_weights(weights) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + self.model.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index a4965f34b1ca8..a0605fee82aca 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -42,10 +42,12 @@ from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal, SupportsPP +from .utils import merge_multimodal_embeddings logger = init_logger(__name__) @@ -212,7 +214,7 @@ def input_processor_for_qwen2_audio( return token_inputs( prompt_token_ids=new_input_ids, - prompt=inputs['prompt'], + prompt=inputs.get("prompt"), multi_modal_data=multi_modal_data, ) @@ -371,6 +373,25 @@ def _process_audio_input(self, return masked_audio_features + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return None + masked_audio_features = self._process_audio_input(audio_input) + return masked_audio_features + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.audio_token_index) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -378,33 +399,27 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - audio_input = self._parse_and_validate_audio_input(**kwargs) - if audio_input is None: - inputs_embeds = None - else: - inputs_embeds = self.language_model.embed_tokens(input_ids) - masked_audio_features = self._process_audio_input(audio_input) - # merge llm embeddings and audio features - mask = (input_ids == self.config.audio_token_index) - inputs_embeds[mask, :] = masked_audio_features - - input_ids = None - - hidden_states = self.language_model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 96a9bc451f4df..ba70243c6533d 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -53,7 +53,7 @@ from vllm.utils import print_warning_once from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -168,6 +168,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -220,7 +221,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -242,9 +244,9 @@ class Qwen2MoeDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -261,10 +263,12 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) # Note: Qwen/Qwen2-57B-A14B-Instruct does not have # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers) if (layer_idx not in mlp_only_layers) and ( @@ -333,10 +337,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Qwen2MoeDecoderLayer(config=config, - layer_idx=int( - prefix.split(".")[-1]), cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a929b9323b245..27175dbae7483 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -39,7 +39,6 @@ make_batched_images, make_batched_videos, smart_resize) from vllm.attention import AttentionMetadata -from vllm.attention.selector import _Backend from vllm.config import VllmConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils @@ -51,21 +50,21 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.layers.quantization import (GPTQConfig, - GPTQMarlinConfig, - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, - MultiModalKwargs) + MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData +from vllm.platforms import _Backend +from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import cached_get_processor @@ -1069,7 +1068,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config assert not cache_config.enable_prefix_caching, \ "Qwen2-VL currently does not support prefix caching" @@ -1101,11 +1099,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -1237,6 +1231,55 @@ def _merge_multimodal_embeddings( inputs_embeds[mask, :] = multimodal_embeddings return inputs_embeds + def get_multimodal_embeddings( + self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]: + + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + if image_input is None and video_input is None: + return None + + # We make a tuple of each embedding with its modality string. This is a + # temporary workaround for models to handle mixed modalities when + # get_multimodal_embeddings and get_input_embeddings are called + # separately. + # TODO(ywang96): Add support for mixed-modality inference for v1. + multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] + + if image_input is not None: + image_embeds = self._process_image_input(image_input) + multimodal_embeddings.append((image_embeds, "image")) + if video_input is not None: + video_embeds = self._process_video_input(video_input) + multimodal_embeddings.append((video_embeds, "video")) + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[List[Tuple[NestedTensors, + str]]] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + for embeddings, modality in multimodal_embeddings: + if modality == "image": + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + embeddings, + placeholder_token_id=self.config.image_token_id, + ) + if modality == "video": + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + embeddings, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -1244,6 +1287,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Qwen2-VL. @@ -1265,42 +1309,26 @@ def forward( video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. `None` if no videos are passed. """ + if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - - inputs_embeds = self.model.embed_tokens(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - - input_ids = None + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + + # We need to check for usage of mrope here in case there is + # multimodal data. + # TODO (ywang96): move this to model runner in V1. + if multimodal_embeddings is not None and uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None hidden_states = self.model( input_ids=input_ids, @@ -1326,13 +1354,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 45fb240f0f0c2..d8bd6d3603e84 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -20,9 +20,11 @@ from vllm.logger import init_logger from vllm.platforms import current_platform +from .adapters import as_embedding_model from .interfaces import (has_inner_state, is_attention_free, - supports_multimodal, supports_pp) -from .interfaces_base import is_embedding_model, is_text_generation_model + supports_cross_encoding, supports_multimodal, + supports_pp) +from .interfaces_base import is_pooling_model, is_text_generation_model logger = init_logger(__name__) @@ -47,6 +49,7 @@ "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), @@ -74,6 +77,7 @@ "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), + "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"), "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), @@ -90,7 +94,8 @@ "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), - "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), + "XverseForCausalLM": ("llama", "LlamaForCausalLM"), # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), @@ -101,29 +106,41 @@ # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), "RobertaModel": ("roberta", "RobertaEmbeddingModel"), + "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), - "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), - "LlamaModel": ("llama", "LlamaEmbeddingModel"), + "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), + "GlmForCausalLM": ("glm", "GlmForCausalLM"), + "LlamaModel": ("llama", "LlamaForCausalLM"), **{ # Multiple models share the same architecture, so we include them all k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() if arch == "LlamaForCausalLM" }, - "MistralModel": ("llama", "LlamaEmbeddingModel"), + "MistralModel": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501 + "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), - "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501, + "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 +} + +_CROSS_ENCODER_MODELS = { + "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "RobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), + "XLMRobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), } _MULTIMODAL_MODELS = { # [Decoder-only] + "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), @@ -160,6 +177,7 @@ _VLLM_MODELS = { **_TEXT_GENERATION_MODELS, **_EMBEDDING_MODELS, + **_CROSS_ENCODER_MODELS, **_MULTIMODAL_MODELS, **_SPECULATIVE_DECODING_MODELS, } @@ -192,8 +210,10 @@ @dataclass(frozen=True) class _ModelInfo: + architecture: str is_text_generation_model: bool - is_embedding_model: bool + is_pooling_model: bool + supports_cross_encoding: bool supports_multimodal: bool supports_pp: bool has_inner_state: bool @@ -201,9 +221,20 @@ class _ModelInfo: @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": + is_pooling_model_ = is_pooling_model(model) + if not is_pooling_model_: + try: + as_embedding_model(model) + except Exception: + pass + else: + is_pooling_model_ = True + return _ModelInfo( + architecture=model.__name__, is_text_generation_model=is_text_generation_model(model), - is_embedding_model=is_embedding_model(model), + is_pooling_model=is_pooling_model_, + supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), @@ -381,13 +412,13 @@ def _normalize_archs( def inspect_model_cls( self, architectures: Union[str, List[str]], - ) -> _ModelInfo: + ) -> Tuple[_ModelInfo, str]: architectures = self._normalize_archs(architectures) for arch in architectures: model_info = self._try_inspect_model_cls(arch) if model_info is not None: - return model_info + return (model_info, arch) return self._raise_for_unsupported(architectures) @@ -408,33 +439,50 @@ def is_text_generation_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).is_text_generation_model + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_text_generation_model + + def is_pooling_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_pooling_model - def is_embedding_model( + def is_cross_encoder_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).is_embedding_model + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_cross_encoding def is_multimodal_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).supports_multimodal + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal def is_pp_supported_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).supports_pp + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_pp - def model_has_inner_state(self, architectures: Union[str, - List[str]]) -> bool: - return self.inspect_model_cls(architectures).has_inner_state + def model_has_inner_state( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.has_inner_state - def is_attention_free_model(self, architectures: Union[str, - List[str]]) -> bool: - return self.inspect_model_cls(architectures).is_attention_free + def is_attention_free_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_attention_free ModelRegistry = _ModelRegistry({ @@ -490,4 +538,4 @@ def _run() -> None: if __name__ == "__main__": - _run() \ No newline at end of file + _run() diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index c1dcdd36ec3de..ba1a78ac640fd 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -6,10 +6,18 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel -from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) + +from .interfaces import SupportsCrossEncoding class RobertaEmbedding(nn.Module): @@ -39,34 +47,93 @@ def __init__(self, config: RobertaConfig): def forward( self, input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() - - # Input embeddings. inputs_embeds = self.word_embeddings(input_ids) - # TODO: figure out if there is a better way - # to make to make position ids start at padding_idx + 1 + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - position_ids += self.padding_idx + 1 + pos_list = [] + token_list = [] + offset = 0 + for seq_len in seq_lens: + pos_list.append(position_ids[offset:offset + seq_len]) + token_list.append(input_ids[offset:offset + seq_len]) + offset += seq_len + + new_pos_list = [] + for positions, tokens in zip(pos_list, token_list): + # Verify assumption that incoming position are + # always a sequence from 0 to N. + expected_pos = torch.arange(positions.size()[0], + dtype=torch.long, + device=inputs_embeds.device) + assert torch.equal(positions, expected_pos) + new_pos_list.append( + create_position_ids_from_input_ids(tokens, self.padding_idx)) + position_ids = torch.cat(new_pos_list) # Position embeddings. position_embeddings = self.position_embeddings(position_ids) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) - # Token type embeddings. (TODO: move off hotpath?) - token_type_embeddings = self.token_type_embeddings( - torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device)) - + token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings +# Adapted from transformers +def create_position_ids_from_input_ids(input_ids, + padding_idx, + past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. + Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + + incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + + past_key_values_length) * mask + + return incremental_indices.long() + padding_idx + + +# Adapted from transformers +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: RobertaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[0, :] # take token (equiv. to [CLS]) + x = self.dense(x) + x = torch.tanh(x) + x = self.out_proj(x) + return x + + class RobertaEmbeddingModel(BertEmbeddingModel): """A model that uses Roberta to provide embedding functionalities. @@ -85,6 +152,62 @@ def _build_model(self, prefix=prefix, embedding_class=RobertaEmbedding) + +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + roberta: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + self.num_labels = config.num_labels + self.roberta = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=RobertaEmbedding, + add_pooling_layer=False) + self.classifier = RobertaClassificationHead(config) + self._pooler = CrossEncodingPooler(config, self.classifier) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("roberta."): + yield (name[len("roberta."):], weight) + else: + self_weights.append((name, weight)) + + self.roberta.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + def forward( self, input_ids: Optional[torch.Tensor], @@ -93,25 +216,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # Verify assumption that position are always a sequence from - # 0 to N. (Actually here we just check 0 and N to simplify). - # This is important to fix the position which are assumed to - # start from padding_idx + 1 instead of 0 in the Roberta models. - assert hasattr(attn_metadata, "seq_lens_tensor") - cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0) - start_pos = torch.cat( - (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device), - cumulative[:-1])) - assert len(torch.nonzero(positions[start_pos])) == 0 - end_pos = cumulative - 1 - last_tokens = attn_metadata.seq_lens_tensor - 1 - assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0 - - return super().forward(input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + return self.roberta(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index c58ad99692900..deaed0ba7e4ce 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -25,7 +25,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) + repeat_and_pad_placeholder_tokens, + resolve_visual_encoder_outputs) from vllm.sequence import SequenceData from .utils import get_vit_attn_backend @@ -450,11 +451,19 @@ def __init__( def forward( self, inputs_embeds: torch.Tensor, - ) -> torch.Tensor: + return_all_hidden_states: bool, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + hidden_states_pool = [] hidden_states = inputs_embeds + for encoder_layer in self.layers: hidden_states, _ = encoder_layer(hidden_states) - + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool return hidden_states @@ -509,6 +518,7 @@ def __init__( embed_dim = config.hidden_size self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( config, quant_config=quant_config, @@ -546,23 +556,33 @@ def forward( self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = True, + feature_sample_layers: Optional[list[int]] = None, ) -> torch.Tensor: + hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, ) - encoder_outputs = self.encoder(inputs_embeds=hidden_states) + return_all_hidden_states = feature_sample_layers is not None + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states, + ) - if self.post_layernorm is None: - return encoder_outputs + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, feature_sample_layers, self.post_layernorm, + self.config.num_hidden_layers) - last_hidden_state = self.post_layernorm(encoder_outputs) - # TODO: add this back when pooled_output is used in inference + # TODO: add this back when pooled_output is used in inference. # if self.use_head: - # pooled_output = self.head(last_hidden_state) + # pooled_output = self.head(encoder_outputs) - return last_hidden_state + return encoder_outputs class SiglipVisionModel(nn.Module): @@ -595,10 +615,12 @@ def forward( self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, + feature_sample_layers: Optional[list[int]] = None, ) -> torch.Tensor: return self.vision_model( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, + feature_sample_layers=feature_sample_layers, ) def load_weights(self, weights: Iterable[Tuple[str, diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 6d6fafc5ab0eb..f58710d215056 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -167,6 +167,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.attn", ) def forward( diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index e11d2e916730a..6b2107bef0a66 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -77,7 +77,8 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -131,7 +132,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_key_value_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -155,9 +157,13 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() - self.self_attn = StablelmAttention(config, cache_config, quant_config) + self.self_attn = StablelmAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) @@ -207,8 +213,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: StablelmDecoderLayer(config, cache_config, - quant_config), + lambda prefix: StablelmDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) norm_eps = getattr(config, "norm_eps", diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 74c66042226de..15e8f2af52cda 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config @@ -105,7 +106,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Starcoder2Attention(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -213,7 +217,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Starcoder2DecoderLayer( - config, cache_config, quant_config=quant_config), + config, cache_config, quant_config=quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py new file mode 100644 index 0000000000000..39c9103527f01 --- /dev/null +++ b/vllm/model_executor/models/telechat2.py @@ -0,0 +1,131 @@ +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, Set, Tuple + +import torch + +from vllm.config import VllmConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel + +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter) + + +class TeleChat2Model(LlamaModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # 1. Initialize the LlamaModel with bias + vllm_config.model_config.hf_config.bias = True + vllm_config.model_config.hf_config.mlp_bias = True + super().__init__(vllm_config=vllm_config, prefix=prefix) + # 2. Remove the bias from the qkv_proj and gate_up_proj based on config + # Telechat2's gate_up_proj and qkv_proj don't have bias + # see: https://github.com/vllm-project/vllm/pull/10311#issuecomment-2490297566 + for layer in self.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.qkv_proj.bias = None + layer.self_attn.qkv_proj.skip_bias_add = True + layer.mlp.gate_up_proj.bias = None + layer.mlp.gate_up_proj.skip_bias_add = True + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + ('gate_up_proj', 'gate_proj', 0), + ('gate_up_proj', 'up_proj', 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + total_num_heads = self.config.n_head + head_dim = self.config.hidden_size // total_num_heads + for name, loaded_weight in weights: + if "self_attn.key_value" in name: + k_weight = [] + v_weight = [] + for i in range(total_num_heads): + start = i * head_dim * 2 + k_weight.append(loaded_weight[start:start + head_dim, :]) + v_weight.append(loaded_weight[start + head_dim:start + + 2 * head_dim:]) + k_weight = torch.cat(k_weight, dim=0) + v_weight = torch.cat(v_weight, dim=0) + name = name.replace("key_value", "qkv_proj") + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, k_weight, "k") + weight_loader(param, v_weight, "v") + elif "query" in name: + name = name.replace("query", "qkv_proj") + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, "q") + else: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class TeleChat2ForCausalLM(LlamaForCausalLM): + + def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): + return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "transformer.": "model.", + }, + orig_to_new_substr={ + ".h.": ".layers.", + ".self_attention.": ".self_attn.", + ".word_embeddings.": ".embed_tokens.", + ".dense.": ".o_proj.", + ".ln_f.": ".norm.", + }, + ) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 512adbc7db35e..ea1e5401d42c0 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -360,9 +360,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): )) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) if config.text_model_id is not None: # this prefix is not for initialization, but for loading weights # note the trailing dot @@ -449,10 +450,36 @@ def _process_audio_input( return result - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return None + audio_embeddings = self._process_audio_input(audio_input) + return audio_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + + # TODO(ywang96): use merge_multimodal_embeddings after + # v0 is deprecated + merge_multimodal_embeddings_from_map( + inputs_embeds, multimodal_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["audio"]) + return inputs_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[torch.Tensor], + intermediate_tensors: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Ultravox @@ -466,30 +493,28 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Args: audio_features: A batch of audio inputs [B, N, 80, M]. """ + if intermediate_tensors is not None: - input_ids = None inputs_embeds = None - else: - audio_input = self._parse_and_validate_audio_input(**kwargs) - if audio_input is not None: - audio_embeddings = self._process_audio_input(audio_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - merge_multimodal_embeddings_from_map( - inputs_embeds, audio_embeddings, - attn_metadata.multi_modal_placeholder_index_maps["audio"]) - input_ids = None - else: - inputs_embeds = None - - hidden_states = self.language_model.model( - input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + + # TODO(ywang96): remove attn_metadata from get_input_embeddings + # after v0 is deprecated + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings, + attn_metadata) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 03226f42ee053..7a1e1f9bf2be4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass, field -from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Protocol, Set, Tuple, Union, overload) +from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, + Protocol, Set, Tuple, Union, overload) import torch import torch.nn as nn @@ -9,13 +9,13 @@ from transformers import PretrainedConfig import vllm.envs as envs -from vllm.attention.selector import (_Backend, backend_name_to_enum, +from vllm.attention.selector import (backend_name_to_enum, get_global_forced_attn_backend) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors -from vllm.platforms import current_platform +from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -173,8 +173,15 @@ def _load_module( module_load_weights = getattr(module, "load_weights", None) if callable(module_load_weights): loaded_params = module_load_weights(weights) - yield from map(lambda x: self._get_qualname(base_prefix, x), - loaded_params) + if loaded_params is None: + logger.warning( + "Unable to collect loaded parameters " + "for module %s", module) + else: + yield from map( + lambda x: self._get_qualname(base_prefix, x), + loaded_params, + ) child_modules = dict(module.named_children()) child_params = dict(module.named_parameters(recurse=False)) @@ -232,17 +239,24 @@ def load_weights( def init_vllm_registered_model( - hf_config: PretrainedConfig, vllm_config: VllmConfig, + *, prefix: str = "", + hf_config: Optional[PretrainedConfig] = None, + architectures: Optional[list[str]] = None, ) -> nn.Module: """ Helper function to initialize an inner model registered to vLLM, based on the arguments passed to the outer vLLM model. """ from vllm.model_executor.model_loader.loader import _initialize_model - vllm_config = vllm_config.with_hf_config(hf_config) - return _initialize_model(vllm_config, prefix) + + if hf_config is not None: + vllm_config = vllm_config.with_hf_config(hf_config) + + return _initialize_model(vllm_config=vllm_config, + prefix=prefix, + architectures=architectures) @overload @@ -356,8 +370,7 @@ def embed_multimodal( input_ids: torch.Tensor, multimodal_token_id: int, get_text_embeds: Callable[[torch.Tensor], torch.Tensor], - get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor, - List[torch.Tensor]]], + multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]], ) -> torch.Tensor: """ Embed token IDs and multimodal inputs and combine their embeddings. @@ -374,8 +387,6 @@ def embed_multimodal( is_text = ~is_multimodal text_embeds = get_text_embeds(input_ids[is_text]) - multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal]) - merged_embeds = torch.empty( (input_ids.shape[0], text_embeds.shape[1]), dtype=text_embeds.dtype, @@ -563,30 +574,6 @@ def make_empty_intermediate_tensors( return make_empty_intermediate_tensors -class LLMWrapper(nn.Module): - """ - To align with the key names of LoRA trained with PEFT, we need to add an - additional layer to the llm's implementation. - """ - - def __init__(self, llm: nn.Module, name: str) -> None: - super().__init__() - self.model_name = name - setattr(self, name, llm) - - def __getattr__(self, key: str): - llm = super().__getattr__(self.model_name) - if key == self.model_name: - return llm - - return getattr(llm, key) - - # We need to explicitly override this - def __call__(self, *args: Any, **kwargs: Any) -> Any: - llm = super().__getattr__(self.model_name) - return llm(*args, **kwargs) - - def get_vit_attn_backend(support_fa: bool = False) -> _Backend: """ Get the available attention backend for Vision Transformer. @@ -629,3 +616,24 @@ def maybe_prefix(prefix: str, name: str) -> str: The string "prefix.name" if prefix was non-empty, otherwise just "name". """ return name if not prefix else f"{prefix}.{name}" + + +def extract_layer_index(layer_name: str) -> int: + """ + Extract the layer index from the module name. + Examples: + - "encoder.layers.0" -> 0 + - "encoder.layers.1.self_attn" -> 1 + - "2.self_attn" -> 2 + - "model.encoder.layers.0.sub.1" -> ValueError + """ + subnames = layer_name.split(".") + int_vals: List[int] = [] + for subname in subnames: + try: + int_vals.append(int(subname)) + except ValueError: + continue + assert len(int_vals) == 1, (f"layer name {layer_name} should" + " only contain one integer") + return int_vals[0] diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py deleted file mode 100644 index bc37a997eabb5..0000000000000 --- a/vllm/model_executor/models/xverse.py +++ /dev/null @@ -1,419 +0,0 @@ -# Adapted from -# https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union - -import torch -from torch import nn -from transformers import PretrainedConfig - -from vllm.attention import Attention, AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -class XverseMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate, _ = self.gate_up_proj(x) - x = self.act_fn(gate) - x, _ = self.down_proj(x) - return x - - -class XverseAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - cache_config: Optional[CacheConfig] = None, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - # partition the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=bias, - quant_config=quant_config, - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class XverseDecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = XverseAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - quant_config=quant_config, - bias=getattr(config, "bias", False), - cache_config=cache_config, - ) - self.mlp = XverseMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -@support_torch_compile -class XverseModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: XverseDecoderLayer(config, cache_config, - quant_config), - prefix=f"{prefix}.layers", - ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = XverseModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if ("rotary_emb.inv_freq" in name - or "rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 84f35f75a0c32..1df8f84ed4093 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -454,6 +454,7 @@ def from_sampling_metadata( if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 6eec660e42ac4..bbb8fb4bc1cd1 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -7,7 +7,7 @@ from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import (get_allowed_kwarg_only_overrides, +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, resolve_mm_processor_kwargs) if TYPE_CHECKING: @@ -54,8 +54,8 @@ class MultiModalPlugin(ABC): """ def __init__(self) -> None: - self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {} - self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {} + self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() + self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() @abstractmethod def get_data_key(self) -> str: diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 64a4c58d5509c..640c7c04b8817 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -6,7 +6,7 @@ import torch import torch.types from PIL.Image import Image -from typing_extensions import TypeAlias +from typing_extensions import NotRequired, TypeAlias from vllm.utils import JSONTree, is_list_of, json_map_leaves @@ -203,18 +203,14 @@ class MultiModalInputsV2(TypedDict): """The type of inputs.""" prompt: str - """ - The original, unprocessed prompt text. - - Note: - Since prompt text is not required by vLLM internals, we leave this - unprocessed to save CPU computation. You can still call - :code:`tokenizer.decode(prompt_token_ids)` to get the processed text. - """ + """The processed prompt text.""" prompt_token_ids: List[int] """The processed token IDs which includes placeholder tokens.""" + token_type_ids: NotRequired[List[int]] + """The token type IDs of the prompt.""" + mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 88a924da174a6..28c8dda581982 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,34 +1,91 @@ +import re +from abc import ABC, abstractmethod +from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass -from functools import lru_cache, partial -from typing import (Any, Callable, Collection, Generic, List, Mapping, - Optional, TypedDict, TypeVar, final) +from functools import lru_cache +from itertools import groupby +from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union +import numpy as np from transformers import BatchFeature -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, TypedDict from vllm.inputs import InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import is_list_of +from vllm.utils import flatten_2d_lists, full_groupby, is_list_of from .inputs import (AudioItem, ImageItem, MultiModalDataDict, MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, VideoItem) + +def bind_prompt_sequence( + seq: Union[str, list[int]], + tokenizer: AnyTokenizer, +) -> "_BoundPromptSequence": + """ + Bind a text or token sequence to a tokenizer so that it can be + lazily converted into the other format on demand. + """ + return _BoundPromptSequence( + tokenizer=tokenizer, + _text=seq if isinstance(seq, str) else None, + _token_ids=seq if isinstance(seq, list) else None, + ) + + _T = TypeVar("_T") +_S = TypeVar("_S", str, list[int]) -ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]] -""" -Given the original data item, HF-processed data, and index of the processed -item, output the replacement token IDs to be allocated in vLLM. -""" + +@dataclass +class PromptReplacement(Generic[_S, _T]): + target: _S + """The text or token sequence to find and replace.""" + + repl_unit: _S + """ + The unit making up the replacement text or token sequence. + + See :code:`repl_count` for more details. + """ + + repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int] + """ + Given the original multi-modal items for this modality, HF-processed data, + and index of the processed item, output the number of repetitions of + :code:`repl_unit` to build up the replacement text or token sequence. + + For convenience, you can pass in an integer if the number of repetitions is + a constant. + """ + + def __repr__(self) -> str: + return (f"{type(self).__name__}(target={self.target!r}, " + f"repl_unit={self.repl_unit!r})") + + def bind( + self, + modality: str, + tokenizer: AnyTokenizer, + ) -> "_BoundPromptReplacement[_T]": + return _BoundPromptReplacement( + modality=modality, + target=bind_prompt_sequence(self.target, tokenizer), + repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer), + repl_count=self.repl_count, + ) @dataclass class ModalityProcessingMetadata(Generic[_T]): - placeholder_replacements: Mapping[str, ReplacementFunc] + prompt_repls: Sequence[Union[PromptReplacement[str, _T], + PromptReplacement[list[int], _T]]] """ - A dictionary where each item represents the original placeholder in the - prompt text and the corresponding replacement. + Defines each text or token sequence to replace in the HF-processed prompt. + + This is skipped if the HF-processed prompt is found to already contain + the replacement prompts. """ @@ -52,46 +109,138 @@ class MultiModalProcessingMetadataBuiltins(TypedDict, total=False): Read more on that :ref:`here `. """ -MultiModalMultiData: TypeAlias = List[_T] -""" -A list of data items, where the number of data items allowed -per modality is restricted by :code:`--limit-mm-per-prompt`. -""" +def _encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: bool = False, +) -> list[int]: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.encode(text, add_special_tokens=...)`. + """ + if isinstance(tokenizer, MistralTokenizer): + return tokenizer.tokenizer.encode(text, + bos=add_special_tokens, + eos=add_special_tokens) -@final -class MultiModalMultiDataBuiltins(TypedDict, total=False): - """Type annotations for modality types predefined by vLLM.""" + return tokenizer.encode(text, add_special_tokens=add_special_tokens) - image: MultiModalMultiData[ImageItem] - """The input images.""" - video: MultiModalMultiData[VideoItem] - """The input videos.""" +@lru_cache(maxsize=2048) +def _cached_encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: bool = False, +) -> list[int]: + return _encode(tokenizer, text, add_special_tokens=add_special_tokens) - audio: MultiModalMultiData[AudioItem] - """The input audios.""" +def _decode( + tokenizer: AnyTokenizer, + token_ids: list[int], + *, + skip_special_tokens: bool = False, +) -> str: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`. + """ + return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) -MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]] -""" -A dictionary containing an entry for each modality type to input. -Note: - This dictionary also accepts modality keys defined outside - :class:`MultiModalMultiDataBuiltins` as long as a customized plugin - is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. - Read more on that :ref:`here `. -""" +@lru_cache(maxsize=2048) +def _cached_decode( + tokenizer: AnyTokenizer, + token_ids: tuple[int, ...], + *, + skip_special_tokens: bool = False, +) -> str: + return _decode(tokenizer, + list(token_ids), + skip_special_tokens=skip_special_tokens) + + +class _HasModalityAttr(Protocol): + modality: str + +class _HasModalityProp(Protocol): -def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: + @property + def modality(self) -> str: + ... + + +_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp]) + + +def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: + """Convenience function to apply :func:`full_groupby` based on modality.""" + return full_groupby(values, key=lambda x: x.modality) + + +@dataclass +class _BoundPromptSequence: + tokenizer: AnyTokenizer + _text: Optional[str] + _token_ids: Optional[list[int]] + + def __post_init__(self) -> None: + if self._text is None and self._token_ids is None: + raise ValueError("At least one of 'text' and 'token_ids' must be " + "specified") + + @property + def text(self) -> str: + if self._text is None: + assert self._token_ids is not None + self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) + + return self._text + + @property + def token_ids(self) -> list[int]: + if self._token_ids is None: + assert self._text is not None + self._token_ids = _cached_encode(self.tokenizer, self._text) + + return self._token_ids + + def __repr__(self) -> str: + return (f"{type(self).__name__}(_text={self._text!r}, " + f"_token_ids={self._token_ids!r})") + + +@dataclass +class _BoundPromptReplacement(Generic[_T]): + modality: str + target: _BoundPromptSequence + repl_unit: _BoundPromptSequence + repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int] + + def get_count( + self, + mm_items: list[_T], + hf_inputs: BatchFeature, + item_idx: int, + ) -> int: + repl_count = self.repl_count + if isinstance(repl_count, int): + return repl_count + + return repl_count(mm_items, hf_inputs, item_idx) + + +def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]: """ Convert a :class:`MultiModalDataDict` containing single data items to a :class:`MultiModalMultiDataDict` containing multiple data items per entry. """ - multi_data: Mapping[str, MultiModalMultiData[Any]] = {} + multi_data = dict[str, list[Any]]() for k, v in data.items(): # yapf: disable @@ -107,86 +256,279 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: return multi_data -def encode_no_special_tokens( - tokenizer: AnyTokenizer, - text: str, -) -> List[int]: +class _TokenRun(NamedTuple): + token_id: int + + start_idx: int + length: int + + +def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]: """ - Backend-agnostic equivalent of HF's - :code:`tokenizer.encode(text, add_special_tokens=False)`. + Yield the starting index and length of each run of tokens that are the same. """ - if isinstance(tokenizer, MistralTokenizer): - return tokenizer.tokenizer.encode(text, bos=False, eos=False) + start_idx = 0 + + for token_id, it in groupby(token_ids): + length = sum(1 for _ in it) + yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length) + + start_idx += length + + +class _PlaceholderInfo(NamedTuple): + modality: str + offset: int + length: int + + def to_range(self) -> PlaceholderRange: + return PlaceholderRange(offset=self.offset, length=self.length) + + +def iter_placeholders( + prompt_repls: Sequence[_BoundPromptReplacement[Any]], + token_ids: list[int], + *, + min_placeholder_count: int, +) -> Iterable[_PlaceholderInfo]: + """Yield each set of placeholder tokens found in :code:`token_ids`.""" + placeholder_ids_by_modality = { + modality: { + token_id + for prompt_repl in repls + for token_id in prompt_repl.repl_unit.token_ids + } + for modality, repls in full_groupby_modality(prompt_repls) + } - return tokenizer.encode(text, add_special_tokens=False) + for run_info in iter_token_runs(token_ids): + if run_info.length > min_placeholder_count: + for (modality, + placeholder_ids) in placeholder_ids_by_modality.items(): + if run_info.token_id in placeholder_ids: + yield _PlaceholderInfo( + modality=modality, + offset=run_info.start_idx, + length=run_info.length, + ) -@lru_cache -def candidate_placeholders( - tokenizer: AnyTokenizer, - placeholder_text: str, -) -> Collection[List[int]]: - """Generate token ID sequences that may represent a placeholder text.""" - # When the placeholder text is not mapped to a special token ID, - # it may be tokenized differently based on whether it is at the start/end - # of the string. So, we go through each combination of whether the text - # is at the start and end boundaries of the string - - # Matches the placeholder when it is in the middle of the string - start_id, = encode_no_special_tokens(tokenizer, "a") - end_id, = encode_no_special_tokens(tokenizer, "b") - - candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text) - - start_id_, *candidate_a = encode_no_special_tokens( - tokenizer, - f"a{placeholder_text}", - ) - assert start_id == start_id_ +class _TokenMatch(NamedTuple): + start_idx: int + end_idx: int - start_id_, *candidate_ab, end_id_ = encode_no_special_tokens( - tokenizer, - f"a{placeholder_text}b", - ) - assert start_id == start_id_ and end_id == end_id_ - *candidate_b, end_id_ = encode_no_special_tokens( - tokenizer, - f"{placeholder_text}b", - ) - assert end_id == end_id_ +def iter_token_matches( + token_ids: list[int], + match_ids: list[int], +) -> Iterable[_TokenMatch]: + """Yield each occurrence of :code:`match_ids` in :code:`token_ids`.""" + match_len = len(match_ids) - # Remove duplicates (need to convert to tuple to be hashable) - unique_candidates = { - tuple(c) - for c in [candidate_basic, candidate_a, candidate_ab, candidate_b] - } + last_end_idx = 0 + for start_idx in range(len(token_ids) - match_len + 1): + if start_idx < last_end_idx: + continue # Exclude overlapping matches - # Convert back to list - return [list(c) for c in unique_candidates] + end_idx = start_idx + match_len + if token_ids[start_idx:end_idx] == match_ids: + yield _TokenMatch(start_idx=start_idx, end_idx=end_idx) + last_end_idx = end_idx -def apply_placeholders( - token_ids: List[int], - placeholder_ids: List[int], - get_replacement_ids: Callable[[], List[int]], -) -> Optional[PlaceholderRange]: - """ - Find the first occurrence of :code:`placeholder_ids`, - and replace it with the output of :code:`get_replacement_ids`. +class _PromptReplacementMatch(ABC, Generic[_T, _S]): + prompt_repl: _BoundPromptReplacement[_T] + + @property + def modality(self) -> str: + return self.prompt_repl.modality + + @property + @abstractmethod + def start_idx(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def end_idx(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_repl( + self, + mm_items: list[_T], + hf_inputs: BatchFeature, + item_idx: int, + ) -> _S: + raise NotImplementedError + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r}, " + f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") + + +@dataclass(repr=False) +class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]): + prompt_repl: _BoundPromptReplacement[_T] + match: _TokenMatch + + @property + def start_idx(self) -> int: + return self.match.start_idx + + @property + def end_idx(self) -> int: + return self.match.end_idx + + def get_repl( + self, + mm_items: list[_T], + hf_inputs: BatchFeature, + item_idx: int, + ) -> list[int]: + prompt_repl = self.prompt_repl + count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) + return prompt_repl.repl_unit.token_ids * count - This function updates :code:`token_ids` in place. + +@dataclass(repr=False) +class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]): + prompt_repl: _BoundPromptReplacement[_T] + match: re.Match[str] + + @property + def start_idx(self) -> int: + return self.match.start() + + @property + def end_idx(self) -> int: + return self.match.end() + + def get_repl( + self, + mm_items: list[_T], + hf_inputs: BatchFeature, + item_idx: int, + ) -> str: + prompt_repl = self.prompt_repl + count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) + return prompt_repl.repl_unit.text * count + + +def find_token_matches( + prompt: list[int], + prompt_repls: Sequence[_BoundPromptReplacement[_T]], +) -> list[_PromptReplacementTokenMatch[_T]]: + """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" + return [ + _PromptReplacementTokenMatch(prompt_repl, match) + for prompt_repl in prompt_repls + for match in iter_token_matches(prompt, prompt_repl.target.token_ids) + ] + + +def find_text_matches( + prompt: str, + prompt_repls: Sequence[_BoundPromptReplacement[_T]], +) -> list[_PromptReplacementTextMatch[_T]]: + """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" + return [ + _PromptReplacementTextMatch(prompt_repl, match) + for prompt_repl in prompt_repls + for match in re.finditer(re.escape(prompt_repl.target.text), prompt) + ] + + +def _resolve_matches( + prompt: _S, + matches: Sequence[_PromptReplacementMatch[_T, _S]], +) -> list[_PromptReplacementMatch[_T, _S]]: + """ + Resolve :code:`matches` to ensure that there are no overlapping matches, + and sort them such that earlier matches take priority over later ones. """ - placeholder_length = len(placeholder_ids) + num_matches_by_idx = np.zeros(len(prompt), dtype=int) + for match in matches: + num_matches_by_idx[match.start_idx:match.end_idx] += 1 + + duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1) + if len(duplicate_matches_idxs) > 0: + raise ValueError("Unable to find a unique replacement " + f"at indices={duplicate_matches_idxs} " + f"of prompt={prompt}") + + return sorted(matches, key=lambda x: x.start_idx) + + +def _replace_matches( + prompt: _S, + matches: Sequence[_PromptReplacementMatch[_T, _S]], + mm_items_by_modality: Mapping[str, list[_T]], + hf_inputs: BatchFeature, +) -> list[_S]: + out_seqs = list[_S]() + prev_end_idx = 0 + next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality} + + for match in _resolve_matches(prompt, matches): + modality = match.modality + mm_items = mm_items_by_modality[modality] + + item_idx = next_idx_by_modality[modality] + if item_idx >= len(mm_items): + continue + + start_idx = match.start_idx + end_idx = match.end_idx + repl_ids = match.get_repl(mm_items, hf_inputs, item_idx) + + out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids) + prev_end_idx = end_idx + next_idx_by_modality[modality] += 1 + + out_seqs.append(prompt[prev_end_idx:]) + + return out_seqs + + +def replace_token_matches( + prompt: list[int], + matches: Sequence[_PromptReplacementMatch[_T, list[int]]], + mm_items_by_modality: Mapping[str, list[_T]], + hf_inputs: BatchFeature, +) -> list[int]: + """Apply :code:`prompt_repls` to :code:`prompt`.""" + if not matches: + return prompt + + token_id_seqs = _replace_matches( + prompt, + matches, + mm_items_by_modality, + hf_inputs, + ) + + return flatten_2d_lists(token_id_seqs) - for start_idx in range(len(token_ids) - placeholder_length + 1): - if token_ids[start_idx:placeholder_length] == placeholder_ids: - token_ids[start_idx:placeholder_length] = get_replacement_ids() - return PlaceholderRange(offset=start_idx, - length=placeholder_length) +def replace_text_matches( + prompt: str, + matches: Sequence[_PromptReplacementMatch[_T, str]], + mm_items_by_modality: Mapping[str, list[_T]], + hf_inputs: BatchFeature, +) -> str: + """Apply :code:`prompt_repls` to :code:`prompt`.""" + if not matches: + return prompt - return None + texts = _replace_matches( + prompt, + matches, + mm_items_by_modality, + hf_inputs, + ) + + return "".join(texts) class MultiModalProcessor: @@ -212,62 +554,166 @@ def __call__( ) -> MultiModalInputsV2: return self.apply(prompt, mm_data, mm_processor_kwargs) - def apply( + def _find_placeholders( + self, + all_prompt_repls: Sequence[_BoundPromptReplacement[Any]], + new_token_ids: list[int], + *, + # To avoid false positives from multi-input when detecting + # whether placeholder tokens have been inserted, in case + # the target sequence is a subset of the replacement tokens + min_placeholder_count: int = 16, + ) -> list[_PlaceholderInfo]: + return list( + iter_placeholders( + all_prompt_repls, + new_token_ids, + min_placeholder_count=min_placeholder_count, + )) + + def _apply_hf_processor( self, prompt: str, mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], - ) -> MultiModalInputsV2: - tokenizer = self.ctx.tokenizer + ) -> BatchFeature: hf_processor = self.ctx.get_hf_processor() - processed_inputs = hf_processor( + return hf_processor( text=prompt, # type: ignore **mm_data, **mm_processor_kwargs, ) - new_token_ids, = processed_inputs.pop("input_ids").tolist() - mm_kwargs = MultiModalKwargs(processed_inputs) - mm_placeholders: Mapping[str, List[PlaceholderRange]] = {} + def _bind_prompt_replacements( + self, + mm_data: MultiModalDataDict, + ) -> list[_BoundPromptReplacement[Any]]: + tokenizer = self.ctx.tokenizer - for modality, orig_inputs in to_multi_format(mm_data).items(): - assert isinstance(orig_inputs, list) + return [ + prompt_repl.bind(modality, tokenizer) + for modality, metadata in self.metadata.items() + if modality in mm_data for prompt_repl in metadata.prompt_repls + ] - metadata = self.metadata[modality] - placeholder_replacements = metadata.placeholder_replacements + def _apply_prompt_replacements( + self, + mm_data: MultiModalDataDict, + hf_inputs: BatchFeature, + token_ids: list[int], + prompt_repls: Sequence[_BoundPromptReplacement[Any]], + ) -> tuple[list[int], str, list[_PlaceholderInfo]]: + tokenizer = self.ctx.tokenizer - modality_placeholders: List[PlaceholderRange] = [] + mm_items = to_multi_format(mm_data) + token_matches = find_token_matches(token_ids, prompt_repls) + + # If the search text does not represent a special token, + # it may have different token IDs in the prompt, because + # the tokens may go across the boundaries of the search text. + # ---- + # e.g. when searching for "foo" in "food", if "food" itself makes + # up a token, then the token ID of "foo" will not appear at all + # ---- + # Since it is inefficient to search for all possible tokenizations + # of the search text in the prompt, we instead perform string + # replacement on the decoded token IDs, then encode them back. + if all( + len(matches) >= len(mm_data[modality]) + for modality, matches in full_groupby_modality(token_matches) + ): # yapf: disable + token_ids = replace_token_matches( + token_ids, + token_matches, + mm_items, + hf_inputs, + ) + + text = _decode(tokenizer, token_ids) + matched_repls = [match.prompt_repl for match in token_matches] + else: + text = _decode(tokenizer, token_ids) + + text_matches = find_text_matches(text, prompt_repls) + text = replace_text_matches( + text, + text_matches, + mm_items, + hf_inputs, + ) + + token_ids = _encode(tokenizer, text) + matched_repls = [match.prompt_repl for match in text_matches] + + placeholders = self._find_placeholders(matched_repls, token_ids) + + # Sanity check + assert len(placeholders) == len(matched_repls), dict( + # Log this information for easier debugging + text=text, + token_ids=token_ids, + placeholders=placeholders, + matched_repls=matched_repls, + ) - for item_idx, orig_item in enumerate(orig_inputs): - for match_text, replace_fn in placeholder_replacements.items(): - candidates = candidate_placeholders(tokenizer, match_text) - get_replacement_ids = partial( - replace_fn, - orig_item, - processed_inputs, - item_idx, - ) + return token_ids, text, placeholders - for match_ids in candidates: - # TODO(youkaichao): Don't update new_token_ids - placeholders = apply_placeholders( - new_token_ids, - match_ids, - get_replacement_ids, - ) + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + """ + Process multi-modal inputs to be used in vLLM. + + The main steps are: + + 1. Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + 2. Find and replace sequences in the token IDs with placeholder tokens. + The number of placeholder tokens equals the feature size of the + multi-modal data outputted by the multi-modal encoder. + 3. Extract information about the placeholder tokens from the + processed token IDs. + """ + tokenizer = self.ctx.tokenizer + + hf_inputs = self._apply_hf_processor(prompt_text, mm_data, + mm_processor_kwargs) + prompt_ids, = hf_inputs.pop("input_ids").tolist() + mm_kwargs = MultiModalKwargs(hf_inputs) - if placeholders is not None: - modality_placeholders.append(placeholders) + all_prompt_repls = self._bind_prompt_replacements(mm_data) - # yapf: disable - mm_placeholders[modality] = modality_placeholders # type: ignore[index] - # yapf: enable + # If HF processor already inserts placeholder tokens, + # there is no need for us to insert them + all_placeholders = self._find_placeholders(all_prompt_repls, + prompt_ids) + if all_placeholders: + prompt_text = _decode(tokenizer, prompt_ids) + else: + ( + prompt_ids, + prompt_text, + all_placeholders, + ) = self._apply_prompt_replacements( + mm_data, + hf_inputs, + prompt_ids, + all_prompt_repls, + ) + + mm_placeholders = { + modality: [item.to_range() for item in items] + for modality, items in full_groupby_modality(all_placeholders) + } return MultiModalInputsV2( type="multimodal", - prompt=prompt, - prompt_token_ids=new_token_ids, + prompt=prompt_text, + prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_placeholders=mm_placeholders, ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b992442d3b314..b73daee98bd80 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -9,6 +9,7 @@ from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import ClassRegistry from .audio import AudioPlugin from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc @@ -62,8 +63,8 @@ def __init__( plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: self._plugins = {p.get_data_key(): p for p in plugins} - self._processor_factories: Dict[Type[nn.Module], - MultiModalProcessorFactory] = {} + self._processor_factories = ClassRegistry[nn.Module, + MultiModalProcessorFactory]() # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 40194716bbf94..d4333b7519b47 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -6,6 +6,7 @@ import numpy as np import numpy.typing as npt +import torch from PIL import Image import vllm.envs as envs @@ -392,6 +393,49 @@ def encode_video_base64(frames: npt.NDArray): return ",".join(base64_frames) +def resolve_visual_encoder_outputs( + encoder_outputs: Union[torch.Tensor, list[torch.Tensor]], + feature_sample_layers: Optional[list[int]], + post_layer_norm: Optional[torch.nn.LayerNorm], + max_possible_layers: int, +) -> torch.Tensor: + """Given the outputs a visual encoder module that may correspond to the + output of the last layer, or a list of hidden states to be stacked, + handle post normalization and resolve it into a single output tensor. + + Args: + encoder_outputs: Output of encoder's last layer or all hidden states. + feature_sample_layers: Optional layer indices to grab from the encoder + outputs; if provided, encoder outputs must be a list. + post_layer_norm: Post norm to apply to the output of the encoder. + max_possible_layers: Total layers in the fully loaded visual encoder. + + """ + if feature_sample_layers is None: + if post_layer_norm is not None: + return post_layer_norm(encoder_outputs) + return encoder_outputs + + # Get the hidden states corresponding to the layer indices. + # Negative values are relative to the full visual encoder, + # so offset them depending on how many layers were loaded. + # NOTE: this assumes that encoder_outputs contains a list + # of hidden states in the same order as the encoder layers + # that produced them. + offset = max_possible_layers - len(encoder_outputs) + hs_pool = [ + encoder_outputs[layer_idx] + if layer_idx >= 0 else encoder_outputs[layer_idx + offset] + for layer_idx in feature_sample_layers + ] + + # Apply post-norm on the final hidden state if we are using it + uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1) + if post_layer_norm is not None and uses_last_layer: + hs_pool[-1] = post_layer_norm(encoder_outputs) + return torch.cat(hs_pool, dim=-1) + + # Utilities for input processors _T = TypeVar("_T", str, int) diff --git a/vllm/outputs.py b/vllm/outputs.py index 4ae9b377ae693..86264f604f6bc 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -53,18 +53,17 @@ def __repr__(self) -> str: @dataclass -class EmbeddingOutput: - """The output data of one completion output of a request. +class PoolingOutput: + """The output data of one pooling output of a request. Args: embedding: The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide. """ - embedding: List[float] def __repr__(self) -> str: - return (f"EmbeddingOutput(" + return (f"PoolingOutput(" f"embedding={len(self.embedding)})") @@ -317,18 +316,18 @@ def __repr__(self) -> str: f"multi_modal_placeholders={self.multi_modal_placeholders})") -class EmbeddingRequestOutput: +class PoolingRequestOutput: """ - The output data of an embedding request to the LLM. + The output data of a pooling request to the LLM. Args: - request_id (str): A unique identifier for the embedding request. - outputs (EmbeddingOutput): The embedding results for the given input. + request_id (str): A unique identifier for the pooling request. + outputs (PoolingOutput): The pooling results for the given input. prompt_token_ids (List[int]): A list of token IDs used in the prompt. - finished (bool): A flag indicating whether the embedding is completed. + finished (bool): A flag indicating whether the pooling is completed. """ - def __init__(self, request_id: str, outputs: "EmbeddingOutput", + def __init__(self, request_id: str, outputs: "PoolingOutput", prompt_token_ids: List[int], finished: bool): self.request_id = request_id self.prompt_token_ids = prompt_token_ids @@ -337,11 +336,11 @@ def __init__(self, request_id: str, outputs: "EmbeddingOutput", @classmethod def from_seq_group(cls, - seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": + seq_group: 'SequenceGroup') -> "PoolingRequestOutput": if seq_group.embeddings is None: raise ValueError( "Embeddings are missing in seq_group for EmbeddingRequest.") - output = EmbeddingOutput(seq_group.embeddings) + output = PoolingOutput(seq_group.embeddings) prompt_token_ids = seq_group.prompt_token_ids finished = seq_group.is_finished() @@ -349,20 +348,64 @@ def from_seq_group(cls, def __repr__(self): """ - Returns a string representation of an EmbeddingRequestOutput instance. + Returns a string representation of an PoolingRequestOutput instance. The representation includes the request_id and the number of outputs, - providing a quick overview of the embedding request's results. + providing a quick overview of the pooling request's results. Returns: - str: A string representation of the EmbeddingRequestOutput instance. + str: A string representation of the PoolingRequestOutput instance. """ - return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " + return (f"PoolingRequestOutput(request_id='{self.request_id}', " f"outputs={repr(self.outputs)}, " f"prompt_token_ids={self.prompt_token_ids}, " f"finished={self.finished})") +@dataclass +class ScoreOutput: + """The output data of one completion output of a request. + + Args: + score: The score, which is a list of floats. + index: The correspondent text index of the score. + """ + index: int + score: List[float] + + def __repr__(self) -> str: + return (f"ScoreOutput(" + f"score={self.score}), " + f"index={self.index})") + + +class ScoreRequestOutput: + """ + The output data of an score request to the LLM. + + Args: + request_id (str): A unique identifier for the score request. + outputs (score): The embedding results for the given input. + """ + + def __init__(self, request_id: str, outputs: "ScoreOutput"): + self.request_id = request_id + self.outputs = outputs + + def __repr__(self): + """ + Returns a string representation of an ScoreRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the embedding request's results. + + Returns: + str: A string representation of the ScoreRequestOutput instance. + """ + return (f"ScoreRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}") + + class RequestOutputFactory: @staticmethod @@ -372,7 +415,30 @@ def create(seq_group: SequenceGroup, # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: - return EmbeddingRequestOutput.from_seq_group(seq_group) + return PoolingRequestOutput.from_seq_group(seq_group) else: return RequestOutput.from_seq_group(seq_group, use_cache, seq_id_to_seq_group) + + +def __getattr__(name: str): + import warnings + + if name == "EmbeddingOutput": + msg = ("EmbeddingOutput has been renamed to PoolingOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingOutput + + if name == "EmbeddingRequestOutput": + msg = ("EmbeddingRequestOutput has been renamed to " + "PoolingRequestOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingRequestOutput + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 9e740837381f8..7cb8ac4b0a1e0 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,3 +1,4 @@ +from .interface import _Backend # noqa: F401 from .interface import Platform, PlatformEnum, UnspecifiedPlatform current_platform: Platform @@ -27,7 +28,15 @@ finally: pynvml.nvmlShutdown() except Exception: - pass + # CUDA is supported on Jetson, but NVML may not be. + import os + + def cuda_is_jetson() -> bool: + return os.path.isfile("/etc/nv_tegra_release") \ + or os.path.exists("/sys/class/tegra-firmware") + + if cuda_is_jetson(): + is_cuda = True is_rocm = False diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 42bee31dfb0e9..b5333fbd6f502 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -5,7 +5,9 @@ from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import Platform, PlatformEnum, _Backend + +logger = init_logger(__name__) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -17,11 +19,20 @@ class CpuPlatform(Platform): _enum = PlatformEnum.CPU + device_name: str = "cpu" + device_type: str = "cpu" + dispatch_key: str = "CPU" @classmethod def get_device_name(cls, device_id: int = 0) -> str: return "cpu" + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + if selected_backend != _Backend.TORCH_SDPA: + logger.info("Cannot use %s backend on CPU.", selected_backend) + return _Backend.TORCH_SDPA + @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: return psutil.virtual_memory().total @@ -45,11 +56,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config - if cache_config.enable_prefix_caching: - logger.warning( - "Prefix caching is not supported on CPU, disable it.") - cache_config.enable_prefix_caching = False - kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE if kv_cache_space >= 0: @@ -66,10 +72,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: f" {kv_cache_space}, expect a positive integer value.") scheduler_config = vllm_config.scheduler_config - if scheduler_config.chunked_prefill_enabled: - logger.warning( - "Chunked prefill is not supported on CPU, disable it.") - scheduler_config.chunked_prefill_enabled = False + if ((scheduler_config.chunked_prefill_enabled + or cache_config.enable_prefix_caching) + and model_config.dtype == torch.half): + logger.warning("Chunked-prefill on the CPU backend only does not" + " support fp16 for now, cast to bf16.") + model_config.dtype = torch.bfloat16 parallel_config = vllm_config.parallel_config if (parallel_config.distributed_executor_backend is not None @@ -78,3 +86,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "distributed executor backend."), parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "mp" + if parallel_config.worker_cls == "auto": + if vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.cpu_worker.CPUWorker" + else: + parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 9c5212ace1346..846a1869da228 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,16 +4,23 @@ import os from functools import lru_cache, wraps -from typing import Callable, List, Tuple, TypeVar +from typing import TYPE_CHECKING, Callable, List, TypeVar import pynvml import torch from typing_extensions import ParamSpec +# import custom ops, trigger op registration +import vllm._C # noqa from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + logger = init_logger(__name__) _P = ParamSpec("_P") @@ -31,10 +38,23 @@ # see https://github.com/huggingface/diffusers/issues/9704 for details torch.backends.cuda.enable_cudnn_sdp(False) -# NVML utils -# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, -# all the related functions work on real physical device ids. -# the major benefit of using NVML is that it will not initialize CUDA + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if device_ids == [""]: + msg = ( + "CUDA_VISIBLE_DEVICES is set to empty string, which means" + " GPU support is disabled. If you are using ray, please unset" + " the environment variable `CUDA_VISIBLE_DEVICES` inside the" + " worker/actor. " + "Check https://github.com/vllm-project/vllm/issues/8402 for" + " more information.") + raise RuntimeError(msg) + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: @@ -50,79 +70,78 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: return wrapper -@lru_cache(maxsize=8) -@with_nvml_context -def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]: - handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) - return pynvml.nvmlDeviceGetCudaComputeCapability(handle) - - -@lru_cache(maxsize=8) -@with_nvml_context -def get_physical_device_name(device_id: int = 0) -> str: - handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) - return pynvml.nvmlDeviceGetName(handle) - - -@lru_cache(maxsize=8) -@with_nvml_context -def get_physical_device_total_memory(device_id: int = 0) -> int: - handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) - return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total) - +class CudaPlatformBase(Platform): + _enum = PlatformEnum.CUDA + device_name: str = "cuda" + device_type: str = "cuda" + dispatch_key: str = "CUDA" -@with_nvml_context -def warn_if_different_devices(): - device_ids: int = pynvml.nvmlDeviceGetCount() - if device_ids > 1: - device_names = [get_physical_device_name(i) for i in range(device_ids)] - if len(set(device_names)) > 1 and os.environ.get( - "CUDA_DEVICE_ORDER") != "PCI_BUS_ID": - logger.warning( - "Detected different devices in the system: \n%s\nPlease" - " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " - "avoid unexpected behavior.", "\n".join(device_names)) + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + raise NotImplementedError + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError -try: - from sphinx.ext.autodoc.mock import _MockModule + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + raise NotImplementedError - if not isinstance(pynvml, _MockModule): - warn_if_different_devices() -except ModuleNotFoundError: - warn_if_different_devices() + @classmethod + def is_full_nvlink(cls, device_ids: List[int]) -> bool: + raise NotImplementedError + @classmethod + def log_warnings(cls): + pass -def device_id_to_physical_device_id(device_id: int) -> int: - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - if device_ids == [""]: - raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string," - " which means GPU support is disabled.") - physical_device_id = device_ids[device_id] - return int(physical_device_id) - else: - return device_id + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" -class CudaPlatform(Platform): - _enum = PlatformEnum.CUDA +# NVML utils +# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. +# the major benefit of using NVML is that it will not initialize CUDA +class NvmlCudaPlatform(CudaPlatformBase): @classmethod + @lru_cache(maxsize=8) + @with_nvml_context def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: physical_device_id = device_id_to_physical_device_id(device_id) - major, minor = get_physical_device_capability(physical_device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) return DeviceCapability(major=major, minor=minor) @classmethod + @lru_cache(maxsize=8) + @with_nvml_context def get_device_name(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) - return get_physical_device_name(physical_device_id) + return cls._get_physical_device_name(physical_device_id) @classmethod + @lru_cache(maxsize=8) + @with_nvml_context def get_device_total_memory(cls, device_id: int = 0) -> int: physical_device_id = device_id_to_physical_device_id(device_id) - return get_physical_device_total_memory(physical_device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total) @classmethod @with_nvml_context @@ -138,13 +157,86 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: if i < j: try: p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle, peer_handle, - pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + handle, + peer_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK, + ) if p2p_status != pynvml.NVML_P2P_STATUS_OK: return False except pynvml.NVMLError: logger.exception( - "NVLink detection failed. This is normal if your" - " machine has no NVLink equipped.") + "NVLink detection failed. This is normal if" + " your machine has no NVLink equipped.") return False return True + + @classmethod + def _get_physical_device_name(cls, device_id: int = 0) -> str: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + return pynvml.nvmlDeviceGetName(handle) + + @classmethod + @with_nvml_context + def log_warnings(cls): + device_ids: int = pynvml.nvmlDeviceGetCount() + if device_ids > 1: + device_names = [ + cls._get_physical_device_name(i) for i in range(device_ids) + ] + if (len(set(device_names)) > 1 + and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"): + logger.warning( + "Detected different devices in the system: \n%s\nPlease" + " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " + "avoid unexpected behavior.", + "\n".join(device_names), + ) + + +class NonNvmlCudaPlatform(CudaPlatformBase): + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + device_props = torch.cuda.get_device_properties(device_id) + return device_props.total_memory + + @classmethod + def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: + logger.exception( + "NVLink detection not possible, as context support was" + " not found. Assuming no NVLink available.") + return False + + +# Autodetect either NVML-enabled or non-NVML platform +# based on whether NVML is available. +nvml_available = False +try: + try: + pynvml.nvmlInit() + nvml_available = True + except Exception: + # On Jetson, NVML is not supported. + nvml_available = False +finally: + if nvml_available: + pynvml.nvmlShutdown() + +CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if not isinstance(pynvml, _MockModule): + CudaPlatform.log_warnings() +except ModuleNotFoundError: + CudaPlatform.log_warnings() \ No newline at end of file diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 170cfff94f90d..10aaa6d54962c 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -1,11 +1,41 @@ +from typing import TYPE_CHECKING + import torch -from .interface import Platform, PlatformEnum +from .interface import Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None class HpuPlatform(Platform): _enum = PlatformEnum.HPU + device_name: str = "hpu" + device_type: str = "hpu" + dispatch_key: str = "HPU" + + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + return _Backend.HPU_ATTN @staticmethod def inference_mode(): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + + scheduler_config = vllm_config.scheduler_config + if scheduler_config.is_multi_step: + raise NotImplementedError( + "Multi-step execution is not implemented for HPU") + + if vllm_config.speculative_config is not None: + raise NotImplementedError( + "Speculative decoding is not implemented for HPU") + + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 970c0d1be617e..eac2b413f9271 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -11,6 +11,20 @@ VllmConfig = None +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + FLASH_ATTN_VLLM_V1 = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() + FLASHINFER = enum.auto() + HPU_ATTN = enum.auto() + PALLAS = enum.auto() + IPEX = enum.auto() + NO_ATTENTION = enum.auto() + + class PlatformEnum(enum.Enum): CUDA = enum.auto() ROCM = enum.auto() @@ -42,6 +56,13 @@ def to_int(self) -> int: class Platform: _enum: PlatformEnum + device_name: str + device_type: str + # available dispatch keys: + # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa + # use "CPU" as a fallback for platforms not registered in PyTorch + dispatch_key: str = "CPU" + supported_quantization: list[str] = [] def is_cuda(self) -> bool: return self._enum == PlatformEnum.CUDA @@ -71,6 +92,11 @@ def is_cuda_alike(self) -> bool: """Stateless version of :func:`torch.cuda.is_available`.""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend): + """Get the default attention backend of a device.""" + return None + @classmethod def get_device_capability( cls, @@ -147,6 +173,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: """ pass + @classmethod + def verify_quantization(cls, quant: str) -> None: + """ + Verify whether the quantization is supported by the current platform. + """ + if cls.supported_quantization and \ + quant not in cls.supported_quantization: + raise ValueError( + f"{quant} quantization is currently not supported in " + f"{cls.device_name}.") + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED + device_type = "" diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 07d8398eda525..87655ea198303 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,9 +1,26 @@ +from typing import TYPE_CHECKING + from .interface import Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class NeuronPlatform(Platform): _enum = PlatformEnum.NEURON + device_name: str = "neuron" + device_type: str = "neuron" + supported_quantization: list[str] = ["neuron_quant"] @classmethod def get_device_name(cls, device_id: int = 0) -> str: return "neuron" + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = \ + "vllm.worker.neuron_worker.NeuronWorker" diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 31fe3f1fcbfe4..29b61e955d9ab 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,15 +1,37 @@ +from typing import TYPE_CHECKING + import torch import vllm.envs as envs from vllm.logger import init_logger -from .interface import Platform, PlatformEnum +from .interface import Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None logger = init_logger(__name__) +try: + import openvino as ov + import openvino.properties.hint as hints +except ImportError as e: + logger.warning("Failed to import OpenVINO with %r", e) + class OpenVinoPlatform(Platform): _enum = PlatformEnum.OPENVINO + device_name: str = "openvino" + device_type: str = "openvino" + dispatch_key: str = "CPU" + + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + if selected_backend != _Backend.OPENVINO: + logger.info("Cannot use %s backend on OpenVINO.", selected_backend) + return _Backend.OPENVINO @classmethod def get_device_name(self, device_id: int = 0) -> str: @@ -31,3 +53,81 @@ def is_openvino_gpu(self) -> bool: def is_pin_memory_available(self) -> bool: logger.warning("Pin memory is not supported on OpenViNO.") return False + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + from vllm.utils import GiB_bytes + + parallel_config = vllm_config.parallel_config + assert ( + parallel_config.world_size == 1 + ), "OpenVINOExecutor only supports single CPU socket currently." + + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = \ + "vllm.worker.openvino_worker.OpenVINOWorker" + + # check and update model config + model_config = vllm_config.model_config + if model_config.dtype != torch.float32: + logger.warning( + f"Only float32 dtype is supported on OpenVINO, casting from {model_config.dtype}." # noqa: G004, E501 + ) + model_config.dtype = torch.float32 + if not model_config.enforce_eager: + logger.warning( + "CUDA graph is not supported on OpenVINO backend, fallback to " + "the eager mode.") + model_config.enforce_eager = True + + # check and update cache config + ov_core = ov.Core() + cache_config = vllm_config.cache_config + if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": + if not OpenVinoPlatform.is_openvino_cpu(): + logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is" + "ignored for GPU, f16 data type will be used.") + cache_config.cache_dtype = ov.Type.f16 + else: + logger.info("KV cache type is overridden to u8 via " + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") + cache_config.cache_dtype = ov.Type.u8 + else: + if OpenVinoPlatform.is_openvino_cpu(): + ov_device = envs.VLLM_OPENVINO_DEVICE + inference_precision = ov_core.get_property( + ov_device, hints.inference_precision) + if inference_precision == ov.Type.bf16: + cache_config.cache_dtype = ov.Type.bf16 + else: + cache_config.cache_dtype = ov.Type.f16 + else: + cache_config.cache_dtype = ov.Type.f16 + + if OpenVinoPlatform.is_openvino_cpu(): + if cache_config.block_size != 32: + logger.info( + f"OpenVINO CPU optimal block size is 32, overriding currently set {cache_config.block_size}" # noqa: G004, E501 + ) + cache_config.block_size = 32 + else: + if cache_config.block_size != 16: + logger.info( + f"OpenVINO GPU optimal block size is 16, overriding currently set {cache_config.block_size}" # noqa: G004, E501 + ) + cache_config.block_size = 16 + + kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE + if kv_cache_space >= 0: + if kv_cache_space == 0 and OpenVinoPlatform.is_openvino_cpu(): + cache_config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore + logger.warning( + "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " + "for OpenVINO backend is not set, using 4 by default.") + else: + cache_config.openvino_kvcache_space_bytes = ( # type: ignore + kv_cache_space * GiB_bytes) + else: + raise RuntimeError( + "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index dc46b03886e1d..00efa056f7ef0 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,18 +1,35 @@ import os from functools import lru_cache, wraps -from typing import List +from typing import TYPE_CHECKING, List import torch from amdsmi import (AmdSmiException, amdsmi_get_gpu_board_info, amdsmi_get_processor_handles, amdsmi_init, amdsmi_shut_down, amdsmi_topo_get_link_type) +import vllm.envs as envs from vllm.logger import init_logger -from .interface import DeviceCapability, Platform, PlatformEnum +from .interface import DeviceCapability, Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None logger = init_logger(__name__) +try: + import vllm._C # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + +# import custom ops, trigger op registration +try: + import vllm._rocm_C # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._rocm_C with %r", e) + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: logger.warning("`fork` method is not supported by ROCm. " "VLLM_WORKER_MULTIPROC_METHOD is overridden to" @@ -57,6 +74,25 @@ def device_id_to_physical_device_id(device_id: int) -> int: class RocmPlatform(Platform): _enum = PlatformEnum.ROCM + device_name: str = "rocm" + device_type: str = "cuda" + dispatch_key: str = "CUDA" + supported_quantization: list[str] = [ + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8", "gguf" + ] + + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + selected_backend = (_Backend.ROCM_FLASH if selected_backend + == _Backend.FLASH_ATTN else selected_backend) + if selected_backend == _Backend.ROCM_FLASH: + if not cls.has_device_capability(90): + # not Instinct series GPUs. + logger.info("flash_attn is not supported on NAVI GPUs.") + else: + logger.info("%s is not supported in AMD GPUs.", selected_backend) + return _Backend.ROCM_FLASH @classmethod @lru_cache(maxsize=8) @@ -102,3 +138,26 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" + + @classmethod + def verify_quantization(cls, quant: str) -> None: + super().verify_quantization(quant) + if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ: + logger.warning( + "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" + " is not set, enabling VLLM_USE_TRITON_AWQ.") + envs.VLLM_USE_TRITON_AWQ = True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 643db835c85ff..b138f7e1c54c5 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,18 +1,31 @@ -import os from typing import TYPE_CHECKING import torch -from .interface import Platform, PlatformEnum +from vllm.logger import init_logger + +from .interface import Platform, PlatformEnum, _Backend if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None +logger = init_logger(__name__) + class TpuPlatform(Platform): _enum = PlatformEnum.TPU + device_name: str = "tpu" + device_type: str = "tpu" + dispatch_key: str = "XLA" + supported_quantization: list[str] = ["tpu_int8"] + + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + if selected_backend != _Backend.PALLAS: + logger.info("Cannot use %s backend on TPU.", selected_backend) + return _Backend.PALLAS @classmethod def get_device_name(cls, device_id: int = 0) -> str: @@ -30,10 +43,23 @@ def inference_mode(cls): def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.config import CompilationLevel compilation_config = vllm_config.compilation_config - if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + if compilation_config.level == CompilationLevel.NO_COMPILATION: + # TPU does not support NO_COMPILATION compilation_config.level = CompilationLevel.DYNAMO_ONCE assert compilation_config.level < CompilationLevel.PIECEWISE,\ "TPU does not support Inductor." if compilation_config.backend == "": compilation_config.backend = "openxla" + + assert vllm_config.speculative_config is None, \ + "TPU does not support speculative decoding" + + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 106e8eddf458f..9665786f4c499 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,10 +1,30 @@ +from typing import TYPE_CHECKING + import torch -from .interface import DeviceCapability, Platform, PlatformEnum +from vllm.logger import init_logger + +from .interface import DeviceCapability, Platform, PlatformEnum, _Backend + +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + +logger = init_logger(__name__) class XPUPlatform(Platform): _enum = PlatformEnum.XPU + device_name: str = "xpu" + device_type: str = "xpu" + dispatch_key: str = "XPU" + + @classmethod + def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: + if selected_backend != _Backend.IPEX: + logger.info("Cannot use %s backend on XPU.", selected_backend) + return _Backend.IPEX @staticmethod def get_device_capability(device_id: int = 0) -> DeviceCapability: @@ -24,3 +44,33 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @staticmethod def inference_mode(): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + # check and update model config + model_config = vllm_config.model_config + if model_config.dtype == torch.bfloat16: + logger.warning( + "bfloat16 is not fully supported on XPU, casting to float16.") + model_config.dtype = torch.float16 + if not model_config.enforce_eager: + logger.warning( + "CUDA graph is not supported on XPU, fallback to the eager " + "mode.") + model_config.enforce_eager = True + + if vllm_config.speculative_config is not None: + raise NotImplementedError( + "XPU does not support speculative decoding") + + # check and update parallel config + parallel_config = vllm_config.parallel_config + if (parallel_config.distributed_executor_backend is not None + and parallel_config.distributed_executor_backend != "ray"): + logger.warning( + "%s is not supported on XPU, fallback to ray distributed" + " executor backend.", + parallel_config.distributed_executor_backend) + parallel_config.distributed_executor_backend = "ray" + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index a0c73a752b5e8..3c64726ca3344 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,23 +1,34 @@ import logging -from contextlib import contextmanager -from typing import TYPE_CHECKING, Optional +import os -import vllm.envs as envs +import torch -if TYPE_CHECKING: - from vllm.config import CompilationConfig, VllmConfig -else: - CompilationConfig = None - VllmConfig = None +import vllm.envs as envs logger = logging.getLogger(__name__) +# make sure one process only loads plugins once +plugins_loaded = False + def load_general_plugins(): """WARNING: plugins can be loaded for multiple times in different processes. They should be designed in a way that they can be loaded multiple times without causing issues. """ + + # all processes created by vllm will load plugins, + # and here we can inject some common environment variables + # for all processes. + + # see https://github.com/vllm-project/vllm/issues/10480 + os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' + # see https://github.com/vllm-project/vllm/issues/10619 + torch._inductor.config.compile_threads = 1 + global plugins_loaded + if plugins_loaded: + return + plugins_loaded = True import sys if sys.version_info < (3, 10): from importlib_metadata import entry_points @@ -48,41 +59,3 @@ def load_general_plugins(): logger.info("plugin %s loaded.", plugin.name) except Exception: logger.exception("Failed to load plugin %s", plugin.name) - - -_compilation_config: Optional[CompilationConfig] = None - - -def set_compilation_config(config: Optional[CompilationConfig]): - global _compilation_config - _compilation_config = config - - -def get_compilation_config() -> Optional[CompilationConfig]: - return _compilation_config - - -_current_vllm_config: Optional[VllmConfig] = None - - -@contextmanager -def set_current_vllm_config(vllm_config: VllmConfig): - """ - Temporarily set the current VLLM config. - Used during model initialization. - We save the current VLLM config in a global variable, - so that all modules can access it, e.g. custom ops - can access the VLLM config to determine how to dispatch. - """ - global _current_vllm_config - old_vllm_config = _current_vllm_config - try: - _current_vllm_config = vllm_config - yield - finally: - _current_vllm_config = old_vllm_config - - -def get_current_vllm_config() -> VllmConfig: - assert _current_vllm_config is not None, "Current VLLM config is not set." - return _current_vllm_config diff --git a/vllm/sequence.py b/vllm/sequence.py index 3b41d25a2fe42..669124319c4f4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -449,6 +449,10 @@ def prompt_token_ids(self) -> List[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: return self.inputs.prompt_embeds + @property + def token_type_ids(self) -> List[int]: + return self.inputs.token_type_ids + @property def multi_modal_data(self) -> "MultiModalDataDict": return self.inputs.multi_modal_data @@ -579,6 +583,9 @@ def get_num_new_tokens(self) -> int: return 1 return self.data.get_num_uncomputed_tokens() + def get_num_computed_tokens(self) -> int: + return self.data.get_num_computed_tokens() + def is_prefill(self) -> bool: return self.data.stage == SequenceStage.PREFILL @@ -684,6 +691,10 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: return (self.encoder_seq.prompt_token_ids if self.encoder_seq is not None else None) + @property + def token_type_ids(self) -> Optional[List[int]]: + return self.first_seq.token_type_ids + @property def multi_modal_data(self) -> MultiModalDataDict: return self.first_seq.multi_modal_data @@ -906,6 +917,7 @@ class SequenceGroupMetadata( default_factory=lambda: SequenceGroupState()) # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. + token_type_ids: Optional[List[int]] = None multi_modal_data: Optional[Any] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 25ef27b8378f0..01b9cdad963da 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -307,28 +307,16 @@ def _create_target_seq_group_metadata( token_ids_to_score = self._get_token_ids_to_score( proposal_token_ids[batch_index]) - # Use simpler sampling parameters apart from for final token - # (in particular don't do seeded sampling) since those sampled tokens - # aren't used. - # We don't replace the sampling_params in the greedy case because - # this also controls whether the probs get modified in the sampler - # (see use of _modify_greedy_probs_inplace there). sampling_params = input_seq_group_metadata.sampling_params - non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \ - if sampling_params.temperature else sampling_params - target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - last_index = len(token_ids_to_score) - 1 for i, token_ids in enumerate(token_ids_to_score): - target_sampling_params = sampling_params if i == last_index \ - else non_bonus_sampling_params target_seq_group_metadata_list.append( self._create_single_target_seq_group_metadata( input_seq_group_metadata, input_seq_id, next(target_seq_ids_iter), token_ids, - sampling_params=target_sampling_params, + sampling_params=sampling_params, )) return target_seq_group_metadata_list diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index cd4d7eb0e6e4e..fe5fd39f42ac9 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -20,8 +20,9 @@ from vllm.logger import init_logger from vllm.multimodal import MultiModalKwargs from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, - ModelRunner) +from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerWrapperBase) logger = init_logger(__name__) @@ -33,7 +34,7 @@ allow_gpu_advance_step = True -class TP1DraftModelRunner(ModelRunner): +class TP1DraftModelRunner(ModelRunnerWrapperBase): """Specialized model runner for speculative decoding draft model. Since the draft model always execute k forward passes consecutively to generate k speculative tokens in a single speculative decoding step, @@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner): any broadcasting inside execute_model). """ - def __init__(self, *args, **kwargs): - if kwargs.get("return_hidden_states"): + def __init__(self, model_runner: ModelRunnerBase): + if hasattr( + model_runner, + "return_hidden_states") and model_runner.return_hidden_states: raise ValueError( "return_hidden_states is not supported for TP1DraftModelRunner." ) - - super().__init__(*args, **kwargs) + super().__init__(model_runner) self.indices_of_seq_with_bonus_tokens = None @@ -73,10 +75,8 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, assert seq_group.prompt_logprob_indices == [] # No prompt assert seq_group.sample_indices == [i] # Simple - def _gpu_advance_step( - self, model_input: ModelInputForGPUWithSamplingMetadata, - last_output: SamplerOutput - ) -> ModelInputForGPUWithSamplingMetadata: + def _gpu_advance_step(self, model_input: ModelRunnerInputBase, + last_output: SamplerOutput) -> ModelRunnerInputBase: # Currently, we expect "decode mode" only assert not model_input.is_prompt @@ -168,7 +168,7 @@ def set_indices_of_seq_with_bonus_tokens(self, @torch.inference_mode() def execute_model( self, - model_input: ModelInputForGPUWithSamplingMetadata, + model_input: ModelRunnerInputBase, kv_caches: List[torch.Tensor], previous_hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, @@ -273,7 +273,8 @@ def execute_model( if previous_hidden_states is not None else {} # Run model - with set_forward_context(model_input.attn_metadata): + with set_forward_context(model_input.attn_metadata, + self.vllm_config): hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 029f56460f5c1..a4fe0f13c8db1 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Set +from typing import Optional, Set, Union import torch @@ -75,9 +75,11 @@ def get_spec_proposals( class SpeculativeScorer(ABC): - def __init__(self, scorer_worker: WorkerBase, device: str, - vocab_size: int): + def __init__(self, scorer_worker: WorkerBase, + device: Union[torch.device, str], vocab_size: int): self._scorer_worker = scorer_worker + if isinstance(device, torch.device): + device = device.type self._device = device self._vocab_size = vocab_size diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index 0d233f393cb8c..1ab691a7ef047 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -9,21 +9,22 @@ from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker import Worker +from vllm.worker.worker_base import WorkerWrapperBase -class MedusaWorker(NonLLMProposerWorkerBase, Worker): +class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase): """Worker for Medusa. """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + super().__init__(kwargs.get("vllm_config")) + self.init_worker(*args, **kwargs) # Lazy initialization list. self._proposer: Top1Proposer def init_device(self): - super().init_device() + self.worker.init_device() self._proposer = Top1Proposer( weakref.proxy(self), # type: ignore[arg-type] diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 89ccaba70e93c..03dc46600d8a9 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -1,11 +1,12 @@ import time -from typing import Callable, Optional +from typing import Callable, Optional, Union import msgspec import torch from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) +from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @@ -81,8 +82,20 @@ def init_gpu_tensors(self, rank: int) -> None: self._rank = rank self._copy_stream = torch.cuda.Stream() + def init_tensors(self, + rank: int, + device_type: Union[torch.device, str] = 'cuda') -> None: + self._rank = rank + if isinstance(device_type, torch.device): + device_type = device_type.type + if device_type == 'cuda': + self._copy_stream = torch.cuda.Stream() + def maybe_collect_rejsample_metrics( self, k: int) -> Optional[SpecDecodeWorkerMetrics]: + # currently using cuda.Event, skip for any non_cuda_alike platform + if not current_platform.is_cuda_alike(): + return None # If a copy was initiated in the previous call, collect and return. if self._in_flight_copy is not None: diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index f49b98f5c9528..d249b37c780e4 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -5,17 +5,21 @@ import torch from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, SequenceGroupMetadata) -from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + +if current_platform.is_cuda_alike(): + from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker import Worker +from vllm.worker.worker_base import WorkerWrapperBase -class MultiStepWorker(Worker, ProposerWorkerBase): +class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase): """The MultiStepWorker is equivalent to a Worker except that it allows multiple forward passes in a single call, assuming the scheduler has allocated enough space to store the additional KV. This reduces overhead @@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase): """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + super().__init__(kwargs.get("vllm_config")) + self.init_worker(*args, **kwargs) # Lazy initialization list. self._proposer: SpeculativeProposer def init_device(self) -> None: - super().init_device() + self.worker.init_device() self._proposer = Top1Proposer( weakref.proxy(self), # type: ignore[arg-type] @@ -51,6 +56,18 @@ def set_should_modify_greedy_probs_inplace(self) -> None: self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( True) + def determine_num_available_blocks(self) -> Tuple[int, int]: + return self.worker.determine_num_available_blocks() + + def get_cache_block_size_bytes(self) -> int: + return self.worker.get_cache_block_size_bytes() + + def initialize_cache(self, *args, **kwargs) -> None: + self.worker.initialize_cache(*args, **kwargs) + + def execute_model(self, *args, **kwargs) -> List[SamplerOutput]: + return self.worker.execute_model(*args, **kwargs) + @torch.inference_mode() def sampler_output( self, @@ -75,7 +92,7 @@ def sampler_output( # Run model sample_len times. model_outputs: List[SamplerOutput] = [] - if isinstance( + if current_platform.is_cuda_alike() and isinstance( self.model_runner, TP1DraftModelRunner ) and self.model_runner.supports_gpu_multi_step(expanded_request): # Here we run the draft_model_runner with multi-step prepare @@ -92,7 +109,7 @@ def sampler_output( # and other restrictions that are part of DraftModelRunner's # supports_gpu_multi_step(..) for _ in range(sample_len): - model_output: List[SamplerOutput] = super().execute_model( + model_output: List[SamplerOutput] = self.worker.execute_model( execute_model_req=expanded_request) assert (len(model_output) == 1 ), "composing multistep workers not supported" diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index debb3b2d5ec30..bb6b99135580e 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -22,6 +22,7 @@ def __init__(self, *args, **kwargs): # Get local_rank/vocab_size from kwargs attribute self.local_rank = kwargs["local_rank"] self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size() + self.device_type = kwargs.get("device_type", "cuda") # Lazy initialization list. self._proposer: Top1Proposer @@ -34,7 +35,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int, self.ngram_prompt_lookup_min = ngram_prompt_lookup_min def init_device(self): - self.device = torch.device(f"cuda:{self.local_rank}") + self.device = torch.device(f"{self.device_type}:{self.local_rank}") self.load_model = lambda *args, **kwargs: None # Current NGramWorker only supports Top1Proposer diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index b57742c2ebfdd..53634f7b0b366 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -14,12 +14,16 @@ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) +from vllm.platforms import current_platform from vllm.sequence import (VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer -from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + +if current_platform.is_cuda_alike(): + from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner + from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.medusa_worker import MedusaWorker @@ -36,8 +40,8 @@ get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) -from vllm.worker.worker import Worker -from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase +from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase, + WorkerWrapperBase) logger = init_logger(__name__) @@ -53,7 +57,11 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": draft_worker_kwargs = kwargs.copy() kwargs["model_runner_cls"] = TargetModelRunner - target_worker = Worker(*args, **kwargs) + target_worker_config = copy.deepcopy(vllm_config) + target_worker_config.parallel_config.worker_cls =\ + target_worker_config.parallel_config.sd_worker_cls + target_worker = WorkerWrapperBase(vllm_config=target_worker_config) + target_worker.init_worker(*args, **kwargs) # Set the disable_logprobs variable in the TargetModelRunner instance # as per its value specified in the SpeculativeConfig. target_worker.model_runner.disable_logprobs =\ @@ -65,6 +73,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": draft_worker_config.model_config, vllm_config.load_config, ) + speculative_config.draft_parallel_config.worker_cls =\ + draft_worker_config.parallel_config.sd_worker_cls draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa # TODO allow draft-model specific load config. @@ -125,7 +135,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): @classmethod def create_worker( cls, - scorer_worker: Worker, + scorer_worker: WorkerBase, draft_worker_kwargs: Dict[str, Any], disable_mqa_scorer: bool, disable_by_batch_size: Optional[int], @@ -145,6 +155,8 @@ def create_worker( draft_parallel_config: ParallelConfig = draft_worker_kwargs[ 'vllm_config'].parallel_config if ngram_prompt_lookup_max > 0: + draft_worker_kwargs[ + "device_type"] = scorer_worker.device_config.device.type proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) @@ -158,8 +170,9 @@ def create_worker( proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: - draft_worker_kwargs[ - "model_runner_cls"] = TP1DraftModelRunner + if current_platform.is_cuda_alike(): + draft_worker_kwargs[ + "model_runner_cls"] = TP1DraftModelRunner else: if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( @@ -306,8 +319,9 @@ def init_device(self) -> None: self.scorer_worker.load_model() self.proposer_worker.load_model() - self._metrics.init_gpu_tensors(self.rank) - self.spec_decode_sampler.init_gpu_tensors(self.rank) + self._metrics.init_tensors(self.rank, device_type=self.device) + self.spec_decode_sampler.init_tensors(self.rank, + device_type=self.device) scorer_cls: Type[SpeculativeScorer] if self.disable_mqa_scorer: @@ -408,7 +422,20 @@ def execute_model( disable_all_speculation = self._should_disable_all_speculation( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots - + all_prompt = True + atleast_one_prompt = False + all_zero_spec_tokens = True + for sgm in execute_model_req.seq_group_metadata_list: + all_prompt = all_prompt and sgm.is_prompt + atleast_one_prompt = atleast_one_prompt or sgm.is_prompt + all_zero_spec_tokens = all_zero_spec_tokens and ( + sgm.num_speculative_tokens == 0) + + if all_prompt and execute_model_req.seq_group_metadata_list: + assert num_lookahead_slots == 0, ( + "Prompt only runs should have num_lookahead_slots equal to 0. " + "This should never happen, please file a bug at " + "https://github.com/vllm-project/vllm/issues") # Speculative decoding is disabled in the following cases: # 1. Prefill phase: Speculative decoding is not # used during the prefill phase. @@ -419,11 +446,8 @@ def execute_model( # In any of these cases, the proposer and scorer workers # are called normally. # We expect `num_speculative_tokens` to be None for prefills. - no_spec = all( - sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list - ) or num_lookahead_slots == 0 or disable_all_speculation or all( - sgm.num_speculative_tokens == 0 - for sgm in execute_model_req.seq_group_metadata_list) + no_spec = (num_lookahead_slots == 0 or disable_all_speculation + or all_zero_spec_tokens) # Broadcast how many lookahead slots are scheduled for this step, and # whether all speculation is disabled, to all non-driver workers. @@ -442,6 +466,15 @@ def execute_model( num_lookahead_slots=num_lookahead_slots, no_spec=no_spec, disable_all_speculation=disable_all_speculation, + # When both chunked prefill and speculative decoding are enabled + # it is possible that the same batch contains both prefill + # and decodes. If that happens in the scorer we run the batch + # as one single forward pass. However, in the proposer we + # run them as 2 different batches - one for prefill and + # the other for decodes. The variable indicates to the non-driver + # worker that there are prefills as part of the speculative batch + # and hence it needs to run an extra prefill forward pass. + run_spec_proposer_for_prefill=atleast_one_prompt, ) broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) @@ -653,6 +686,8 @@ def _run_non_driver_rank(self) -> bool: if not data["no_spec"]: self.scorer_worker.execute_model() + if data["run_spec_proposer_for_prefill"]: + self.proposer_worker.execute_model() return True @@ -1090,11 +1125,11 @@ def get_cache_block_size_bytes(self): raise NotImplementedError def start_profile(self): - if isinstance(self.scorer_worker, Worker): + if isinstance(self.scorer_worker, WorkerBase): self.scorer_worker.start_profile() def stop_profile(self): - if isinstance(self.scorer_worker, Worker): + if isinstance(self.scorer_worker, WorkerBase): self.scorer_worker.stop_profile() diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py index e61cde5b17f20..56540744b73a9 100644 --- a/vllm/spec_decode/target_model_runner.py +++ b/vllm/spec_decode/target_model_runner.py @@ -1,12 +1,12 @@ from typing import List, Optional -from vllm.config import VllmConfig from vllm.sequence import SequenceGroupMetadata -from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, - ModelRunner) +from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerWrapperBase) -class TargetModelRunner(ModelRunner): +class TargetModelRunner(ModelRunnerWrapperBase): """Specialized model runner for speculative decoding target model. In speculative decoding, the log probabilities selected finally may not be the same ones as selected by the target model sampling. This means @@ -18,32 +18,21 @@ class TargetModelRunner(ModelRunner): requested or not. """ - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - return_hidden_states: bool = False, - ): + def __init__(self, model_runner: ModelRunnerBase): # An internal boolean member variable to indicate if token log # probabilities are needed or not. + super().__init__(model_runner) self.disable_logprobs = True - super().__init__( - vllm_config=vllm_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - return_hidden_states=return_hidden_states, - ) def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithSamplingMetadata: - model_input: ModelInputForGPUWithSamplingMetadata = super( - ).prepare_model_input(seq_group_metadata_list, virtual_engine, - finished_requests_ids) + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelRunnerInputBase: + model_input: ModelRunnerInputBase =\ + self.model_runner.prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) # If token log probabilities is disabled then skip generating sampler # CPU output. We directly serialize the GPU sampled_token_id tensors # as needed. If log probabilities is enabled then synchronize all the diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 193ef870dfceb..da8706658d09a 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -5,6 +5,7 @@ import torch from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SequenceGroupMetadata, SequenceOutput) @@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs): Arguments: msg (string): message to associate with the range """ - torch.cuda.nvtx.range_push(msg.format(*args, **kwargs)) - try: + if current_platform.is_cuda_alike(): + torch.cuda.nvtx.range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + torch.cuda.nvtx.range_pop() + else: yield - finally: - torch.cuda.nvtx.range_pop() class Timer: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 85a754f1d63c7..57572cf05f73e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -9,6 +9,7 @@ from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) +from torch import nn from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) @@ -27,10 +28,12 @@ MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, NVLM_D_Config, - RWConfig, SolarConfig, + Olmo2Config, RWConfig, + SolarConfig, Telechat2Config, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file +from vllm.utils import resolve_obj_by_qualname if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -60,7 +63,9 @@ "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, + "olmo2": Olmo2Config, "solar": SolarConfig, + "telechat": Telechat2Config, "ultravox": UltravoxConfig, "grok-1": Grok1Config, **_CONFIG_REGISTRY_OVERRIDE_HF @@ -108,6 +113,15 @@ def patch_rope_scaling(config: PretrainedConfig) -> None: def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None: + if "rope_type" in rope_scaling and "type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + rope_type_legacy = rope_scaling["type"] + if rope_type != rope_type_legacy: + raise ValueError( + f"Found conflicts between 'rope_type={rope_type}' (modern " + f"field) and 'type={rope_type_legacy}' (legacy field). " + "You should only specify one of them.") + if "rope_type" not in rope_scaling and "type" in rope_scaling: rope_scaling["rope_type"] = rope_scaling["type"] logger.info("Replacing legacy 'type' key with 'rope_type'") @@ -573,3 +587,16 @@ def try_get_generation_config( return GenerationConfig.from_model_config(config) except OSError: # Not found return None + + +def get_cross_encoder_activation_function(config: PretrainedConfig): + if (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): + + function_name = config.sbert_ce_default_activation_function + assert function_name.startswith("torch.nn.modules."), \ + "Loading of activation functions is restricted to " \ + "torch.nn.modules for security reasons" + return resolve_obj_by_qualname(function_name)() + else: + return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 6ef3b13198c91..5abfe17a2a937 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -16,7 +16,9 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config +from vllm.transformers_utils.configs.olmo2 import Olmo2Config from vllm.transformers_utils.configs.solar import SolarConfig +from vllm.transformers_utils.configs.telechat2 import Telechat2Config from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ @@ -34,7 +36,9 @@ "MLPSpeculatorConfig", "NemotronConfig", "NVLM_D_Config", + "Olmo2Config", "SolarConfig", + "Telechat2Config", "UltravoxConfig", "Grok1Config", ] \ No newline at end of file diff --git a/vllm/transformers_utils/configs/aria.py b/vllm/transformers_utils/configs/aria.py new file mode 100644 index 0000000000000..d253da0d96a34 --- /dev/null +++ b/vllm/transformers_utils/configs/aria.py @@ -0,0 +1,47 @@ +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2VisionConfig) +from transformers.models.llama.configuration_llama import LlamaConfig + + +class AriaVisionConfig(Idefics2VisionConfig): + model_type = "aria_vision_model" + + +class AriaMoELMConfig(LlamaConfig): + """ + Configuration class for AriaMoE language model. + + This class extends the LlamaConfig to include additional parameters specific + to the Mixture of Experts (MoE) architecture. + """ + + model_type = "aria_moe_lm" + + def __init__( + self, + moe_intermediate_size: int = 4096, + moe_num_experts: int = 8, + moe_topk: int = 2, + moe_num_shared_experts: int = 2, + **kwargs, + ): + """ + Initialize the AriaMoELMConfig. + + Args: + moe_intermediate_size (int): The intermediate size for MoE layers. + Default is 4096. + moe_num_experts (int): The number of experts in the MoE layer. + Default is 8. + moe_topk (int): The number of top experts to route to for each + token. Default is 2. + moe_num_shared_experts (int): The number of shared experts. Default + is 2. + **kwargs: Additional keyword arguments to be passed to the parent + LlamaConfig. + """ + super().__init__(**kwargs) + self.moe_intermediate_size = moe_intermediate_size + self.moe_num_experts = moe_num_experts + self.moe_topk = moe_topk + self.moe_num_shared_experts = moe_num_shared_experts diff --git a/vllm/transformers_utils/configs/olmo2.py b/vllm/transformers_utils/configs/olmo2.py new file mode 100644 index 0000000000000..0e6d8e4879b06 --- /dev/null +++ b/vllm/transformers_utils/configs/olmo2.py @@ -0,0 +1,166 @@ +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/configuration_olmo2.py +"""OLMo 2 configuration.""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class Olmo2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Olmo2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + + ```python + >>> from transformers import Olmo2Model, Olmo2Config + + >>> # Initializing a Olmo2 7B style configuration + >>> configuration = Olmo2Config() + + >>> # Initializing a model from the Olmo2 7B style configuration + >>> model = Olmo2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "olmo2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/vllm/transformers_utils/configs/telechat2.py b/vllm/transformers_utils/configs/telechat2.py new file mode 100644 index 0000000000000..eb6f5a059169f --- /dev/null +++ b/vllm/transformers_utils/configs/telechat2.py @@ -0,0 +1,61 @@ +# adapted from https://www.modelscope.cn/models/TeleAI/TeleChat2-3B/resolve/master/configuration_telechat2.py +""" Telechat configuration compatible with LlamaConfig. """ + +from transformers.configuration_utils import PretrainedConfig + + +class Telechat2Config(PretrainedConfig): + + model_type = "telechat" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "intermediate_size": "ffn_hidden_size", + "rms_norm_eps": "layer_norm_epsilon" + } + + def __init__( + self, + vocab_size=160256, + hidden_size=4096, + n_layer=30, + n_head=32, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.0, + attention_dropout=0.0, + ffn_hidden_size=12288, + training_seqlen=8192, + logn=True, + embed_layernorm=False, + hidden_act="silu", + **kwargs, + ): + self.vocab_size = vocab_size + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm) + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.logn = logn + self.training_seqlen = training_seqlen + self.embed_layernorm = embed_layernorm + self.num_key_value_heads = kwargs.pop("num_key_value_heads", None) + self.ffn_hidden_size = ffn_hidden_size + self.hidden_act = hidden_act + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) diff --git a/vllm/utils.py b/vllm/utils.py index af9ed4fa549a4..65b26ebb9a44f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,5 +1,6 @@ import argparse import asyncio +import concurrent import contextlib import datetime import enum @@ -19,7 +20,8 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task -from collections.abc import Mapping +from collections import UserDict, defaultdict +from collections.abc import Iterable, Mapping from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, @@ -467,7 +469,10 @@ def in_wsl() -> bool: return "microsoft" in " ".join(uname()).lower() -def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]: +def make_async( + func: Callable[P, T], + executor: Optional[concurrent.futures.Executor] = None +) -> Callable[P, Awaitable[T]]: """Take a blocking function, and run it on in an executor thread. This function prevents the blocking function from blocking the @@ -478,7 +483,7 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]: def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: loop = asyncio.get_event_loop() p_func = partial(func, *args, **kwargs) - return loop.run_in_executor(executor=None, func=p_func) + return loop.run_in_executor(executor=executor, func=p_func) return _async_wrapper @@ -583,6 +588,13 @@ async def collect_from_async_generator( def get_ip() -> str: host_ip = envs.VLLM_HOST_IP + if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ: + logger.warning( + "The environment variable HOST_IP is deprecated and ignored, as" + " it is often used by Docker and other software to" + "interact with the container's network stack. Please" + "use VLLM_HOST_IP instead to set the IP address for vLLM processes" + " to communicate with each other.") if host_ip: return host_ip @@ -822,6 +834,12 @@ def create_kv_caches_with_random( return key_caches, value_caches +@lru_cache +def print_info_once(msg: str) -> None: + # Set the stacklevel to 2 to print the caller's line info + logger.info(msg, stacklevel=2) + + @lru_cache def print_warning_once(msg: str) -> None: # Set the stacklevel to 2 to print the caller's line info @@ -1016,6 +1034,23 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]: return [item for sublist in lists for item in sublist] +_K = TypeVar("_K", bound=Hashable) +_V = TypeVar("_V") + + +def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): + """ + Unlike :class:`itertools.groupby`, groups are not broken by + non-contiguous data. + """ + groups = defaultdict[_K, list[_V]](list) + + for value in values: + groups[key(value)].append(value) + + return groups.items() + + # TODO: This function can be removed if transformer_modules classes are # serialized by value when communicating between processes def init_cached_hf_modules() -> None: @@ -1303,6 +1338,10 @@ def parse_args(self, args=None, namespace=None): else: processed_args.append('--' + arg[len('--'):].replace('_', '-')) + elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: + # allow -O flag to be used without space, e.g. -O3 + processed_args.append('-O') + processed_args.append(arg[2:]) else: processed_args.append(arg) @@ -1595,19 +1634,22 @@ def value(self): # Adapted from: https://stackoverflow.com/a/47212782/5082708 -class LazyDict(Mapping, Generic[T]): +class LazyDict(Mapping[str, T], Generic[T]): def __init__(self, factory: Dict[str, Callable[[], T]]): self._factory = factory self._dict: Dict[str, T] = {} - def __getitem__(self, key) -> T: + def __getitem__(self, key: str) -> T: if key not in self._dict: if key not in self._factory: raise KeyError(key) self._dict[key] = self._factory[key]() return self._dict[key] + def __setitem__(self, key: str, value: Callable[[], T]): + self._factory[key] = value + def __iter__(self): return iter(self._factory) @@ -1615,13 +1657,20 @@ def __len__(self): return len(self._factory) -def combine_fx_passes(passes: List[Callable]) -> Callable: +class ClassRegistry(UserDict[type[T], _V]): + + def __getitem__(self, key: type[T]) -> _V: + for cls in key.mro(): + if cls in self.data: + return self.data[cls] - def combined_fx(graph) -> None: - for fx in passes: - fx(graph) + raise KeyError(key) + + def __contains__(self, key: object) -> bool: + if not isinstance(key, type): + return False - return combined_fx + return any(cls in self.data for cls in key.mro()) def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: @@ -1706,6 +1755,7 @@ def direct_register_custom_op( mutates_args: List[str], fake_impl: Optional[Callable] = None, target_lib: Optional[Library] = None, + dispatch_key: str = "CUDA", ): """ `torch.library.custom_op` can have significant overhead because it @@ -1734,7 +1784,7 @@ def direct_register_custom_op( schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str) - my_lib.impl(op_name, op_func, "CUDA") + my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) if fake_impl is not None: my_lib._register_fake(op_name, fake_impl) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e73a1e60b2730..d37989055c2e5 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -6,8 +6,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.forward_context import get_forward_context -from vllm.utils import direct_register_custom_op from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -19,7 +17,7 @@ def get_supported_head_sizes() -> List[int]: @staticmethod def get_name() -> str: - return "flash-attn-vllm-v1" + return "FLASH_ATTN_VLLM_V1" @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: @@ -113,13 +111,14 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: @@ -135,116 +134,42 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - output = torch.empty_like(query) - torch.ops.vllm.unified_v1_flash_attention( - output, - query, - key, - value, - self.num_heads, - self.head_size, - self.num_kv_heads, - kv_cache, + if attn_metadata is None: + # Profiling run. + return output + + num_actual_tokens = attn_metadata.num_actual_tokens + + # Reshape the input keys and values and store them in the cache. + key_cache = kv_cache[0] + value_cache = kv_cache[1] + torch.ops._C_cache_ops.reshape_and_cache_flash( + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + attn_metadata.slot_mapping, self.kv_cache_dtype, k_scale, v_scale, - self.scale, - self.sliding_window, - self.alibi_slopes, - self.logits_soft_cap, ) - return output + # Compute attention and update output up to `num_actual_tokens`. + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=attn_metadata.block_table, + softcap=self.logits_soft_cap, + ) -def unified_v1_flash_attention( - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> None: - current_metadata = get_forward_context() - if current_metadata is None: - # Profiling run. - return - - assert current_metadata is not None - assert isinstance(current_metadata, FlashAttentionMetadata) - attn_metadata: FlashAttentionMetadata = current_metadata - num_actual_tokens = attn_metadata.num_actual_tokens - - # Reshape the query, key, and value tensors. - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - # Reshape the input keys and values and store them in the cache. - key_cache = kv_cache[0] - value_cache = kv_cache[1] - torch.ops._C_cache_ops.reshape_and_cache_flash( - key[:num_actual_tokens], - value[:num_actual_tokens], - key_cache, - value_cache, - attn_metadata.slot_mapping, - kv_cache_dtype, - k_scale, - v_scale, - ) - - attn_output = flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - cu_seqlens_q=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_query_len, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_k=attn_metadata.max_seq_len, - softmax_scale=softmax_scale, - causal=True, - alibi_slopes=alibi_slopes, - window_size=window_size, - block_table=attn_metadata.block_table, - softcap=logits_soft_cap, - ) - attn_output = attn_output.view(num_actual_tokens, -1) - # TODO(woosuk): Optimize this. - output[:num_actual_tokens].copy_(attn_output) - - -def unified_v1_flash_attention_fake( - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> None: - return - - -direct_register_custom_op( - op_name="unified_v1_flash_attention", - op_func=unified_v1_flash_attention, - mutates_args=["kv_cache", "output"], - fake_impl=unified_v1_flash_attention_fake, -) + return output diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 38f1c03a4d3ac..b492a755e6dd5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -17,12 +17,15 @@ def __init__( self, block_size: int, num_gpu_blocks: int, + max_model_len: int, sliding_window: Optional[int] = None, enable_caching: bool = True, num_preallocate_tokens: int = 64, ) -> None: self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks + self.max_model_len = max_model_len + self.max_num_blocks_per_req = cdiv(max_model_len, block_size) self.sliding_window = sliding_window self.enable_caching = enable_caching # NOTE(woosuk): To avoid frequent block allocation, we preallocate some @@ -79,6 +82,9 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: return [] computed_blocks = [] + + # TODO(rickyx): potentially we could cache this so we don't have to + # recompute it every time. block_hashes = hash_request_tokens(self.block_size, request.all_token_ids) @@ -120,47 +126,52 @@ def append_slots( # slots, but we cannot allocate new blocks due to the limit. return None - # When caching is enabled, assign token IDs to already allocated blocks. - new_token_ids = None - parent_block = None - if self.enable_caching: - # Figure out the token IDs to add to the blocks. - new_token_ids = request.all_token_ids[ - request.num_computed_tokens:request.num_computed_tokens + - num_tokens] - - # Find the last full block index. - # TODO: This may be optimized by calculating the computed tokens. - last_full_block_idx = len(req_blocks) - 1 - while (last_full_block_idx >= 0 - and req_blocks[last_full_block_idx].block_hash is None): - last_full_block_idx -= 1 - - parent_block = (req_blocks[last_full_block_idx] - if last_full_block_idx >= 0 else None) - token_id_idx = self._add_token_ids_to_blocks( - blocks=req_blocks[last_full_block_idx + 1:], - token_ids=new_token_ids, - parent_block=parent_block) - - new_token_ids = new_token_ids[token_id_idx:] - parent_block = req_blocks[-1] - - # No new block is needed. When caching is enabled, we make sure - # token_id_idx is equal to len(new_token_ids), meaning that all tokens - # are added to allocated blocks. - if num_required_blocks <= len(req_blocks): - assert not self.enable_caching or token_id_idx == num_tokens, \ - f"{token_id_idx=} != {num_tokens=}" - return [] + if num_new_blocks <= 0: + # No new block is needed. + new_blocks = [] + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_new_blocks = min( + num_new_blocks + self.num_preallocate_blocks, + self.free_block_queue.num_free_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req - len(req_blocks), + ) + assert num_new_blocks > 0 + + new_blocks = self._get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + + if not self.enable_caching: + return new_blocks + + num_computed_full_blocks = (request.num_computed_tokens // + self.block_size) + + # NOTE(rickyx): We are assuming the `num_tokens` are actual + # tokens rather than lookahead slots (e.g. for speculative decoding). + # TODO(rickyx): When supporting speculative decoding, we will need to + # differentiate between them so that we can know how many blocks are + # full after appending the actual tokens. + num_full_blocks_after_append = (request.num_computed_tokens + + num_tokens) // self.block_size + assert num_full_blocks_after_append <= len(req_blocks) + + new_full_blocks = req_blocks[ + num_computed_full_blocks:num_full_blocks_after_append] + self._cache_full_blocks( + request=request, + blk_start_idx=num_computed_full_blocks, + full_blocks=new_full_blocks, + prev_block=req_blocks[num_computed_full_blocks - 1] + if num_computed_full_blocks >= 1 else None, + ) - # Allocate new blocks considering preallocated blocks, and - # add token IDs to them if caching is enabled. - num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks) - new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids, - parent_block) - req_blocks.extend(new_blocks) return new_blocks def allocate_slots( @@ -184,11 +195,20 @@ def allocate_slots( raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = len( - [blk for blk in computed_blocks if blk.ref_cnt == 0]) + # Touch the computed blocks to make sure they won't be evicted. + num_evictable_computed_blocks = 0 + if self.enable_caching: + self._touch(computed_blocks) + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. + num_evictable_computed_blocks = len( + [blk for blk in computed_blocks if blk.ref_cnt == 0]) + else: + assert not computed_blocks, ( + "Computed blocks should be empty when " + "prefix caching is disabled") num_required_blocks = cdiv(num_tokens, self.block_size) if (num_required_blocks > self.free_block_queue.num_free_blocks - @@ -201,35 +221,35 @@ def allocate_slots( num_new_blocks = min( num_required_blocks + self.num_preallocate_blocks, self.free_block_queue.num_free_blocks - - num_evictable_computed_blocks) + num_evictable_computed_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req - len(computed_blocks), + ) + assert num_new_blocks > 0 - num_computed_tokens = len(computed_blocks) * self.block_size - - # When caching is enabled, get the new token IDs and the parent block - # ID to generate cache keys. - new_token_ids = None - parent_block = None - if self.enable_caching: - # Touch the computed blocks to make sure they won't be evicted. - self._touch(computed_blocks) + # Concatenate the computed block IDs and the new block IDs. + new_blocks = self._get_new_blocks(num_new_blocks) + self.req_to_blocks[request.request_id] = computed_blocks + new_blocks - # Get the token IDs for the blocks being allocated for hashing. - new_token_ids = request.all_token_ids[ - num_computed_tokens:num_computed_tokens + num_tokens] - if not new_token_ids: - raise RuntimeError( - "Failed to infer the token IDs for allocation. " - f"#all_tokens={len(request.all_token_ids)} < " - f"#computed_tokens={num_computed_tokens}") + if not self.enable_caching: + return new_blocks - # Get the parent block ID to construct the block chain. - parent_block = computed_blocks[-1] if computed_blocks else None + num_computed_tokens = len(computed_blocks) * self.block_size + num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids, - parent_block) + self._cache_full_blocks( + request=request, + blk_start_idx=len(computed_blocks), + # The new full blocks are the full blocks that are not computed. + full_blocks=self.req_to_blocks[request.request_id] + [len(computed_blocks):num_full_blocks], + prev_block=computed_blocks[-1] if computed_blocks else None, + ) - # Concatenate the computed block IDs and the new block IDs. - self.req_to_blocks[request.request_id] = computed_blocks + new_blocks return new_blocks def free(self, request: Request) -> None: @@ -248,24 +268,17 @@ def free(self, request: Request) -> None: blocks = reversed(blocks) for block in blocks: - block.ref_cnt -= 1 + block.decr_ref() if block.ref_cnt == 0: self.free_block_queue.append(block) - def _get_new_blocks( - self, - num_blocks: int, - token_ids: Optional[List[int]] = None, - parent_block: Optional[int] = None) -> List[KVCacheBlock]: - """Get new blocks from the free block pool, and add token IDs to - allocated blocks if caching is enabled. + def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: + """Get new blocks from the free block pool. + Note that we do not check block cache in this function. Args: num_blocks: The number of blocks to allocate. - token_ids: The token IDs in the blocks. None if caching is disabled. - parent_block: The parent block. Used to include block chain - in the block hash. Returns: A list of new block. @@ -274,56 +287,38 @@ def _get_new_blocks( raise ValueError( f"Cannot get {num_blocks} free blocks from the pool") - # First allocate blocks. ret: List[KVCacheBlock] = [] idx = 0 while idx < num_blocks: + # First allocate blocks. curr_block = self.free_block_queue.popleft() assert curr_block.ref_cnt == 0 - # Evict blocks from the cache. + # If the block is cached, evict it. if self.enable_caching: - block_hash = curr_block.block_hash - if (block_hash is not None - and block_hash in self.cached_block_hash_to_block): - if len(self.cached_block_hash_to_block[block_hash]) == 1: - del self.cached_block_hash_to_block[block_hash] - else: - del self.cached_block_hash_to_block[block_hash][ - curr_block.block_id] - curr_block.reset() - - curr_block.ref_cnt = 1 + self._evict_cached_block(curr_block) + + curr_block.incr_ref() ret.append(curr_block) idx += 1 - # Then assign token IDs to the allocated blocks. - if self.enable_caching: - assert token_ids is not None - token_id_idx = self._add_token_ids_to_blocks( - blocks=ret, token_ids=token_ids, parent_block=parent_block) - assert token_id_idx == len(token_ids) - return ret - def _cache_full_block(self, - block: KVCacheBlock, - parent_block: Optional[KVCacheBlock] = None) -> None: - """Cache a full block for prefix caching. + def _evict_cached_block(self, block: KVCacheBlock) -> None: + """ + If a block is cached in `cached_block_hash_to_block`, we reset its hash + metadata and evict it from the cache. Args: - block: The block to cache. - parent_block: The parent block. None if this is the first block. + block: The block to evict. """ - parent_block_hash = (parent_block.block_hash - if parent_block is not None else None) - assert len(block.token_ids) == self.block_size - block.token_ids = tuple(block.token_ids) - block_hash = hash_block_tokens(parent_block_hash, block.token_ids) - block.block_hash = block_hash - block.num_hashed_tokens = self.block_size + ( - parent_block.num_hashed_tokens if parent_block is not None else 0) - self.cached_block_hash_to_block[block_hash][block.block_id] = block + block_hash = block.block_hash + if block_hash and block_hash in self.cached_block_hash_to_block: + block.reset_hash() + del self.cached_block_hash_to_block[block_hash][block.block_id] + + if len(self.cached_block_hash_to_block[block_hash]) == 0: + del self.cached_block_hash_to_block[block_hash] def _get_cached_block(self, block_hash: BlockHashType) -> Optional[KVCacheBlock]: @@ -355,43 +350,50 @@ def _touch(self, blocks: List[KVCacheBlock]) -> None: # candidate), so remove it. if block.ref_cnt == 0: self.free_block_queue.remove(block) - block.ref_cnt += 1 - - def _add_token_ids_to_blocks( - self, - blocks: List[KVCacheBlock], - token_ids: List[int], - parent_block: Optional[KVCacheBlock] = None) -> int: - """Add token IDs to a list of allocated blocks. - If a block becomes full after adding token IDs, cache it. - Return the token ID index that has not been added to the blocks - if the blocks are not enough to hold all the token IDs. + block.incr_ref() - Args: - blocks: A list of blocks to add token IDs. - token_ids: A list of token IDs to add. - parent_block: The parent block. None if this is the - first block. + def _cache_full_blocks( + self, + request: Request, + blk_start_idx: int, + full_blocks: List[KVCacheBlock], + prev_block: Optional[KVCacheBlock], + ) -> None: + """Cache a list of full blocks for prefix caching. - Returns: - The starting token ID index that has not been added to the blocks - due to insufficient given blocks. + This function takes a list of blocks that will have their block hash + metadata to be updated and cached. Given a request, it computes the + block hashes for the blocks starting from `blk_start_idx` to the end + of the request's full blocks, updating the metadata for each block + and caching them in the `cached_block_hash_to_block`. + + Args: + request: The request to cache the blocks. + blk_start_idx: The index of the first block in the request's blocks + to cache. + full_blocks: The list of blocks to update hash metadata. + prev_block: The previous block in the chain. """ - token_id_start = 0 - for curr_block in blocks: - # If all token IDs are added, then the rest of the blocks are - # preallocated blocks, so we only need to update the - # parent_block_id. FIXME - if token_id_start == len(token_ids): - continue - - # Add token IDs to the empty slots in the block. - empty_slots = self.block_size - len(curr_block.token_ids) - token_id_end = min(token_id_start + empty_slots, len(token_ids)) - curr_block.token_ids.extend(token_ids[token_id_start:token_id_end]) - # Cache the block if it becomes full. - if len(curr_block.token_ids) == self.block_size: - self._cache_full_block(curr_block, parent_block) - parent_block = curr_block - token_id_start = token_id_end - return token_id_start + # Update the new blocks with the block hashes through the chain. + prev_block_hash = (prev_block.block_hash + if prev_block is not None else None) + for i, blk in enumerate(full_blocks): + blk_idx = blk_start_idx + i + + block_tokens = request.all_token_ids[blk_idx * + self.block_size:(blk_idx + + 1) * + self.block_size] + assert len(block_tokens) == self.block_size, ( + f"Expected {self.block_size} tokens, got {len(block_tokens)} " + f"at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash, + tuple(block_tokens)) + + # Update and added the full block to the cache. + blk.block_hash = block_hash + self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + prev_block_hash = block_hash diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 33dbfb7377bfd..fb666c364bfb2 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,6 +1,6 @@ """KV-Cache Utilities.""" -from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import List, Optional, Tuple from vllm.logger import init_logger @@ -16,27 +16,34 @@ class KVCacheBlock: block_id: int # Reference count. ref_cnt: int = 0 - # Token IDs in the block. When the block is full, the type of token_ids - # should be Tuple[int] for fast matching. - token_ids: Union[List[int], Tuple[int]] = field(default_factory=list) # The hash of the block composed of (block hash, tuple of token IDs). # It is only available when the block is full. - block_hash: Optional[BlockHashType] = None - # The number of hashed tokens. More hashed tokens means the block - # is closer to the end of a prompt and more likely to be evicted. - num_hashed_tokens: int = 0 + _block_hash: Optional[BlockHashType] = None # Used to construct a doubly linked list for free blocks. # These two attributes should only be manipulated by FreeKVCacheBlockQueue. prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None - def reset(self): - """Reset the block metadata.""" - self.ref_cnt = 0 - self.token_ids = [] - self.block_hash = None - self.num_hashed_tokens = 0 + def incr_ref(self): + self.ref_cnt += 1 + + def decr_ref(self): + self.ref_cnt -= 1 + + @property + def block_hash(self) -> Optional[BlockHashType]: + return self._block_hash + + @block_hash.setter + def block_hash(self, block_hash: BlockHashType): + assert self.block_hash is None, ( + "The block already has a hash. This should not happen.") + self._block_hash = block_hash + + def reset_hash(self): + """Reset the block hash when the block is evicted.""" + self._block_hash = None class FreeKVCacheBlockQueue: diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ba50a9786d805..f1f26f4e8d443 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -33,22 +33,23 @@ def __init__( # TODO: Support LoRA. assert lora_config is None, "V1 does not support LoRA yet." + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + num_gpu_blocks = cache_config.num_gpu_blocks assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 - # Create the block space manager. + # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( block_size=self.cache_config.block_size, num_gpu_blocks=num_gpu_blocks, + max_model_len=self.max_model_len, sliding_window=self.cache_config.sliding_window, enable_caching=self.cache_config.enable_prefix_caching) self.block_size = self.cache_config.block_size - # Scheduling constraints. - self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens - self.max_model_len = self.scheduler_config.max_model_len - # req_id -> Request self.requests: Dict[str, Request] = {} # Priority queues for requests. diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index edfb8bd7c2fc1..967124fd850ea 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -68,6 +68,11 @@ class EngineCoreOutputs(msgspec.Struct, outputs: List[EngineCoreOutput] +@dataclass +class EngineCoreProfile: + is_start: bool + + class EngineCoreRequestType(enum.Enum): """ Request types defined as hex byte strings, so it can be sent over sockets @@ -75,3 +80,4 @@ class EngineCoreRequestType(enum.Enum): """ ADD = b'\x00' ABORT = b'\x01' + PROFILE = b'\x02' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 09bff9655a882..7335c637f0f79 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -9,7 +9,7 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -94,7 +94,7 @@ def from_engine_args( # Create the engine configs. if engine_config is None: - vllm_config = engine_args.create_engine_config() + vllm_config = engine_args.create_engine_config(usage_context) else: vllm_config = engine_config @@ -133,7 +133,7 @@ async def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: """Add new request to the AsyncLLM.""" if self.detokenizer.is_request_active(request_id): @@ -346,10 +346,10 @@ async def check_health(self) -> None: logger.debug("Called check_health.") async def start_profile(self) -> None: - raise ValueError("Not supported on V1 yet.") + await self.engine_core.profile(True) async def stop_profile(self) -> None: - raise ValueError("Not supported on V1 yet.") + await self.engine_core.profile(False) @property def is_running(self) -> bool: diff --git a/vllm/v1/engine/async_stream.py b/vllm/v1/engine/async_stream.py index 3e6c759ad5ebd..35449238c3259 100644 --- a/vllm/v1/engine/async_stream.py +++ b/vllm/v1/engine/async_stream.py @@ -1,11 +1,11 @@ import asyncio from typing import Any, AsyncGenerator, Callable, Optional, Type, Union -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput class AsyncStream: - """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + """A stream of RequestOutputs or PoolingRequestOutputs for a request that can be iterated over asynchronously via an async generator.""" STOP_ITERATION = Exception() # Sentinel @@ -16,7 +16,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + def put(self, item: Union[RequestOutput, PoolingRequestOutput, Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -32,7 +32,7 @@ def finish( async def generator( self - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: finished = False try: while True: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 35ed131d50de9..34f99dd30ef2e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,4 +1,5 @@ import multiprocessing +import pickle import queue import threading import time @@ -16,7 +17,8 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, - EngineCoreRequest, EngineCoreRequestType) + EngineCoreProfile, EngineCoreRequest, + EngineCoreRequestType) from vllm.v1.engine.mm_input_mapper import MMInputMapper from vllm.v1.executor.gpu_executor import GPUExecutor from vllm.v1.request import Request, RequestStatus @@ -39,19 +41,6 @@ def __init__( executor_class: Type[GPUExecutor], usage_context: UsageContext, ): - # Override the configs for V1. - # FIXME - if usage_context == UsageContext.LLM_CLASS: - vllm_config.scheduler_config.max_num_seqs = 1024 - vllm_config.scheduler_config.max_num_batched_tokens = 8192 - elif usage_context == UsageContext.OPENAI_API_SERVER: - vllm_config.scheduler_config.max_num_seqs = 1024 - vllm_config.scheduler_config.max_num_batched_tokens = 2048 - - # TODO (ywang96): Enable APC by default when VLM supports it. - if not vllm_config.model_config.is_multimodal_model: - vllm_config.cache_config.enable_prefix_caching = True - assert vllm_config.model_config.task != "embedding" logger.info("Initializing an LLM engine (v%s) with config: %s", @@ -126,6 +115,9 @@ def step(self) -> List[EngineCoreOutput]: scheduler_output, output) return engine_core_outputs + def profile(self, is_start=True): + self.model_executor.worker.profile(is_start) + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" @@ -312,11 +304,14 @@ def _log_stats(self): self._last_logging_time = now def _handle_client_request( - self, request: Union[EngineCoreRequest, List[str]]) -> None: + self, request: Union[EngineCoreRequest, EngineCoreProfile, + List[str]]) -> None: """Handle EngineCoreRequest or EngineCoreABORT from Client.""" if isinstance(request, EngineCoreRequest): self.add_request(request) + elif isinstance(request, EngineCoreProfile): + self.model_executor.worker.profile(request.is_start) else: # TODO: make an EngineCoreAbort wrapper assert isinstance(request, list) @@ -341,6 +336,8 @@ def process_input_socket(self, input_path: str): request = decoder_add_req.decode(request_data) elif request_type == EngineCoreRequestType.ABORT.value: request = decoder_abort_req.decode(request_data) + elif request_type == EngineCoreRequestType.PROFILE.value: + request = pickle.loads(request_data) else: raise ValueError(f"Unknown RequestType: {request_type}") diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 09801e20e16ca..835963f7ee86c 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -9,7 +9,8 @@ from vllm.logger import init_logger from vllm.utils import get_open_zmq_ipc_path from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, - EngineCoreRequest, EngineCoreRequestType) + EngineCoreProfile, EngineCoreRequest, + EngineCoreRequestType) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.serial_utils import PickleEncoder @@ -58,6 +59,9 @@ def get_output(self) -> List[EngineCoreOutput]: def add_request(self, request: EngineCoreRequest) -> None: raise NotImplementedError + async def profile(self, is_start=True) -> None: + raise NotImplementedError + def abort_requests(self, request_ids: List[str]) -> None: raise NotImplementedError @@ -95,6 +99,9 @@ def add_request(self, request: EngineCoreRequest) -> None: def abort_requests(self, request_ids: List[str]) -> None: self.engine_core.abort_requests(request_ids) + async def profile(self, is_start=True) -> None: + self.engine_core.profile(is_start) + class MPClient(EngineCoreClient): """ @@ -177,8 +184,10 @@ def get_output(self) -> List[EngineCoreOutput]: engine_core_outputs = self.decoder.decode(frame.buffer).outputs return engine_core_outputs - def _send_input(self, request_type: EngineCoreRequestType, - request: Union[EngineCoreRequest, List[str]]) -> None: + def _send_input( + self, request_type: EngineCoreRequestType, + request: Union[EngineCoreRequest, EngineCoreProfile, + List[str]]) -> None: # (RequestType, SerializedRequest) msg = (request_type.value, self.encoder.encode(request)) @@ -190,6 +199,10 @@ def add_request(self, request: EngineCoreRequest) -> None: def abort_requests(self, request_ids: List[str]) -> None: self._send_input(EngineCoreRequestType.ABORT, request_ids) + async def profile(self, is_start=True) -> None: + self._send_input(EngineCoreRequestType.PROFILE, + EngineCoreProfile(is_start)) + class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" @@ -205,8 +218,9 @@ async def get_output_async(self) -> List[EngineCoreOutput]: return engine_core_outputs async def _send_input( - self, request_type: EngineCoreRequestType, - request: Union[EngineCoreRequest, List[str]]) -> None: + self, request_type: EngineCoreRequestType, + request: Union[EngineCoreRequest, EngineCoreProfile, + List[str]]) -> None: msg = (request_type.value, self.encoder.encode(request)) await self.input_socket.send_multipart(msg, copy=False) @@ -217,3 +231,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: async def abort_requests_async(self, request_ids: List[str]) -> None: if len(request_ids) > 0: await self._send_input(EngineCoreRequestType.ABORT, request_ids) + + async def profile(self, is_start=True) -> None: + await self._send_input(EngineCoreRequestType.PROFILE, + EngineCoreProfile(is_start)) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 75a77be750acd..bd19d998a4adb 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -82,7 +82,7 @@ def from_engine_args( """Creates an LLM engine from the engine arguments.""" # Create the engine configs. - vllm_config = engine_args.create_engine_config() + vllm_config = engine_args.create_engine_config(usage_context) executor_class = cls._get_executor_cls(vllm_config) if VLLM_ENABLE_V1_MULTIPROCESSING: @@ -161,13 +161,13 @@ def step(self) -> List[RequestOutput]: # TODO(rob): Can we get rid of these? def get_model_config(self): - pass + return self.model_config def start_profile(self): - pass + self.engine_core.profile(True) def stop_profile(self): - pass + self.engine_core.profile(False) def get_tokenizer_group(self, group_type): pass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d60f93a44f6dd..1fa47f553dfd6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,3 +1,4 @@ +import gc import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -8,13 +9,13 @@ import torch.nn as nn from vllm.compilation.compile_context import set_compile_context -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs -from vllm.plugins import set_compilation_config from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, is_pin_memory_available) @@ -362,7 +363,8 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): # 2. A list (length: num_images) of tensors, each of shape # [feature_size, hidden_size] in case when the feature size is # dynamic depending on input images. - encoder_outputs = self.model.process_mm_inputs(**batched_mm_inputs) + encoder_outputs = self.model.get_multimodal_embeddings( + **batched_mm_inputs) # Cache the encoder outputs. for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): @@ -447,7 +449,7 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( input_ids=None, positions=self.positions[:num_input_tokens], @@ -508,20 +510,6 @@ def execute_model( return model_runner_output def load_model(self) -> None: - if self.use_cuda_graph: - # NOTE(woosuk): Currently, we use inductor because the piecewise - # CUDA graphs do not work properly with the custom CUDA kernels. - # FIXME(woosuk): Disable inductor to reduce the compilation time - # and avoid any potential issues with the inductor. - set_compilation_config( - CompilationConfig( - custom_ops=["none"], - use_cudagraph=True, - non_cudagraph_ops=["vllm.unified_v1_flash_attention"], - use_inductor=True, - enable_fusion=False, - )) - logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) @@ -530,7 +518,25 @@ def load_model(self) -> None: logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) - def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: + @torch.inference_mode() + def _dummy_run( + self, + model: nn.Module, + num_tokens: int, + kv_caches: List[torch.Tensor], + ) -> torch.Tensor: + with set_forward_context(None, self.vllm_config): + hidden_states = model( + input_ids=None, + positions=self.positions[:num_tokens], + kv_caches=kv_caches, + attn_metadata=None, + inputs_embeds=self.inputs_embeds[:num_tokens]) + return hidden_states + + def profile_run(self) -> None: + # TODO(woosuk): Profile the max memory usage of the encoder and + # the encoder cache. # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. # the `dtype` argument does not matter, and we use `float32` as @@ -542,46 +548,33 @@ def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: torch.tensor([], dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] - with set_forward_context(None): # noqa: SIM117 - with set_compile_context(self.cudagraph_batch_sizes): - # Trigger compilation for general shape. - model(input_ids=None, - positions=self.positions, - kv_caches=dummy_kv_caches, - attn_metadata=None, - inputs_embeds=self.inputs_embeds) - - @torch.inference_mode() - def profile_run(self) -> None: - # TODO(woosuk): Profile the max memory usage of the encoder and - # the encoder cache. - self._dummy_run(self.model, self.max_num_tokens) + with set_compile_context(self.cudagraph_batch_sizes): + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.model, self.max_num_tokens, + dummy_kv_caches) + logits = self.model.compute_logits(hidden_states, None) + logits = logits[:self.max_num_tokens] + # TODO(woosuk): Consider the memory usage of the sampler. torch.cuda.synchronize() + del hidden_states, logits + gc.collect() - @torch.inference_mode() def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( - "Skipping CUDA graph capture. Please set " - "VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.", - CompilationLevel.PIECEWISE) + "Skipping CUDA graph capture. Please add " + "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) return start_time = time.perf_counter() start_free_gpu_memory = torch.cuda.mem_get_info()[0] - with set_forward_context(None): - # Trigger CUDA graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + with graph_capture(): for num_tokens in reversed(self.cudagraph_batch_sizes): - self.model( - input_ids=None, - positions=self.positions[:num_tokens], - kv_caches=self.kv_caches, - attn_metadata=None, - inputs_embeds=self.inputs_embeds[:num_tokens], - ) + self._dummy_run(self.model, num_tokens, self.kv_caches) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c8192b7f86eb0..d33b55a8a9f9a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,6 +6,7 @@ import torch import torch.distributed +import vllm.envs as envs from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -56,6 +57,22 @@ def __init__( init_cached_hf_modules() self.model_runner = GPUModelRunner(vllm_config) + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None def initialize(self): if self.device_config.device.type == "cuda": @@ -105,35 +122,48 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + _, total_gpu_memory = torch.cuda.mem_get_info() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. self.model_runner.profile_run() - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. torch.cuda.synchronize() - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + + free_gpu_memory, _ = torch.cuda.mem_get_info() # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. - peak_memory = self.init_gpu_memory - free_gpu_memory - assert peak_memory > 0, ( + assert self.init_gpu_memory > free_gpu_memory, ( "Error in memory profiling. " f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") + # Get the peak memory allocation recorded by torch + peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] + + # Check for any memory left around that may have been allocated on the + # gpu outside of `torch`. NCCL operations, for example, can use a few + # GB during a forward pass + torch.cuda.empty_cache() + torch_allocated_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + total_allocated_bytes = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + non_torch_allocations = total_allocated_bytes - torch_allocated_bytes + if non_torch_allocations > 0: + peak_memory += non_torch_allocations + available_kv_cache_memory = ( + total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. cache_block_size = _get_cache_block_size(self.cache_config, self.model_config, self.parallel_config) - num_gpu_blocks = int( - (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) + num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) - # if self.model_runner.lora_manager: - # self.model_runner.remove_all_loras() - gc.collect() - torch.cuda.empty_cache() return num_gpu_blocks, 0 def initialize_cache(self, num_gpu_blocks: int) -> None: @@ -171,6 +201,14 @@ def execute_model( # TODO(woosuk): Send the output to the engine process. return output + def profile(self, is_start=True): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + self.profiler.start() + else: + self.profiler.stop() + def init_worker_distributed_environment( parallel_config: ParallelConfig, diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index d040831870bd8..cc24cfe04d2ba 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -4,6 +4,7 @@ import torch from vllm.attention import AttentionMetadata +from vllm.forward_context import set_forward_context from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs @@ -34,6 +35,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "input_positions": self.input_positions, "encoder_input_tokens": self.encoder_input_tokens, "encoder_input_positions": self.encoder_input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -303,7 +305,8 @@ def execute_model( intermediate_tensors, } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d3e1202c15e61..420aaf8a1b4cd 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -2,14 +2,15 @@ import weakref from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, + Union) import torch from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -19,7 +20,6 @@ MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) -from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None + token_type_ids: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None @@ -54,6 +55,7 @@ def as_broadcastable_tensor_dict( tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -78,11 +80,14 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): Used by the ModelRunner. """ sampling_metadata: Optional["SamplingMetadata"] = None + is_prompt: Optional[bool] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -104,65 +109,233 @@ def from_broadcasted_tensor_dict( class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): + class ModelInputData: + + def __init__(self, use_mrope: bool): + self.use_mrope = use_mrope + self.input_tokens: List[int] = [] + self.input_positions: Optional[ + List[int]] = [] if not self.use_mrope else None + self.token_type_ids: Optional[List[int]] = [] + self.seq_lens: List[int] = [] + self.query_lens: List[int] = [] + self.prefill_block_tables: List[List[int]] = [] + self.decode_block_tables: List[List[int]] = [] + self.max_decode_seq_len: int = 0 + self.num_prefills: int = 0 + self.num_prefill_tokens: int = 0 + self.num_decode_tokens: int = 0 + self.slot_mapping: List[int] = [] + self.multi_modal_inputs_list: List[MultiModalKwargs] = [] + self.multi_modal_placeholder_maps: Dict[ + str, MultiModalPlaceholderMap] = defaultdict( + MultiModalPlaceholderMap) + self.input_mrope_positions: Optional[List[List[int]]] = [ + [] for _ in range(3) + ] if self.use_mrope else None + def __init__(self, runner: "CPUModelRunner", finished_requests_ids: Optional[List[str]] = None) -> None: super().__init__() self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] self.runner = runner + + self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled + or runner.cache_config.enable_prefix_caching) self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.device = self.runner.device self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.input_data = ModelInputForCPUBuilder.ModelInputData( + self.runner.model_config.uses_mrope) + self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()( + self) def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) + def set_seq_group_list( + self, seq_group_metadata_list: List[SequenceGroupMetadata]): + self.seq_group_metadata_list = seq_group_metadata_list + def build(self) -> ModelInputForCPU: + self._build_input_data() + + input_data = self.input_data + input_tokens = torch.tensor(input_data.input_tokens, + dtype=torch.long, + device="cpu") + input_positions = torch.tensor( + input_data.input_positions + if not input_data.use_mrope else input_data.input_mrope_positions, + dtype=torch.long, + device="cpu") + token_type_ids = torch.tensor(input_data.token_type_ids, + dtype=torch.long, + device="cpu") \ + if input_data.token_type_ids else None + + # For multi-modal models multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = self.seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) = self._prepare_prompt( - self.seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode( - self.seq_group_metadata_list) - seq_lens = None + if len(input_data.multi_modal_inputs_list) != 0: + multi_modal_kwargs = MultiModalKwargs.batch( + input_data.multi_modal_inputs_list) + + attn_metadata = self.att_metadata_builder.build( + input_data.seq_lens, input_data.query_lens, -1, -1) return self.model_input_cls( input_tokens=input_tokens, input_positions=input_positions, + token_type_ids=token_type_ids, + seq_lens=input_data.seq_lens, + query_lens=input_data.query_lens, attn_metadata=attn_metadata, multi_modal_kwargs=multi_modal_kwargs, - # query_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens=seq_lens, - query_lens=seq_lens, ) - def _compute_multi_modal_input( - self, - seq_data: SequenceData, - computed_len: int, - seq_group_metadata: SequenceGroupMetadata, - ): + def _build_input_data(self): + for seq_group_metadata in self.seq_group_metadata_list: + for seq_id, seq_data in seq_group_metadata.seq_data.items(): + if seq_group_metadata.is_prompt: + self._compute_prompt_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + if seq_group_metadata.multi_modal_data: + self._compute_multi_modal_input( + seq_group_metadata, seq_data) + else: + self._compute_decode_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + + def _compute_decode_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute decode input tokens, positions, block table and slot mapping. + """ + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + + tokens = seq_data.get_last_token_id() + token_positions = seq_len - 1 + block_number = block_table[token_positions // block_size] + block_offset = token_positions % block_size + slot = block_number * block_size + block_offset + + # For paged_attention kernel + if self.runner.sliding_window: + start_idx = max(0, seq_len - self.runner.sliding_window) + start_block = start_idx // block_size + start_idx = start_block * block_size + seq_len = seq_len - start_idx + block_table = block_table[start_block:] + + # For MRotaryEmbedding + if data.input_positions is None: + next_pos = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + for idx in range(3): + data.input_mrope_positions[idx].extend( # type: ignore + next_pos[idx]) + else: + data.input_positions.append(token_positions) # type: ignore + + # Update fields + data.input_tokens.append(tokens) + data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len) + data.num_decode_tokens += 1 + data.slot_mapping.append(slot) + data.decode_block_tables.append(block_table) + data.query_lens.append(1) + data.seq_lens.append(seq_len) + + def _compute_prompt_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute prompt input tokens, positions, block table and slot mapping. + """ + token_chunk_size = seq_group_metadata.token_chunk_size + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + seq_len = min(seq_len, context_len + token_chunk_size) + + # For prefix caching + prefix_cache_block_num = len(seq_group_metadata.computed_block_nums) + if prefix_cache_block_num > 0: + prefix_cache_len = (prefix_cache_block_num * + self.runner.block_size) + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass + elif context_len < prefix_cache_len < seq_len: + # Partial hit. Compute the missing part. + context_len = prefix_cache_len + token_chunk_size = seq_len - context_len + elif seq_len <= prefix_cache_len: + # Full hit. Only compute the last token to avoid + # erroneous behavior. FIXME: Ideally we should directly + # mark all tokens as computed in the scheduler and do not + # schedule this sequence, so this case should not happen. + context_len = seq_len - 1 + token_chunk_size = 1 + + tokens = seq_data.get_token_ids() + tokens = tokens[context_len:seq_len] + token_positions = range(context_len, seq_len) + token_types = seq_group_metadata.token_type_ids + + # For encoder-only models, the block_table is None, + # and there is no need to initialize the slot_mapping. + if block_table is not None: + slot_mapping = [_PAD_SLOT_ID] * len(token_positions) + for i, pos in enumerate(token_positions): + block_number = block_table[pos // block_size] + block_offset = pos % block_size + slot = block_number * block_size + block_offset + slot_mapping[i] = slot + data.slot_mapping.extend(slot_mapping) + + # The MROPE positions are prepared in _compute_multi_modal_input + if data.input_positions is not None: + data.input_positions.extend(token_positions) + + if data.token_type_ids is not None: + data.token_type_ids.extend(token_types if token_types else []) + + # Update fields + data.input_tokens.extend(tokens) + data.num_prefills += 1 + data.num_prefill_tokens += len(tokens) + data.query_lens.append(len(tokens)) + data.prefill_block_tables.append(block_table) + data.seq_lens.append(seq_len) + + def _compute_multi_modal_input(self, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData): + computed_len = seq_data.get_num_computed_tokens() + seq_len = self.input_data.seq_lens[-1] + # NOTE: mm_data only includes the subset of multi-modal items that # intersect with the current prefill positions. mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( - seq_group_metadata, - range(computed_len, len(seq_data.get_token_ids())), - ) + seq_group_metadata, range(computed_len, seq_len)) if not mm_data: - return None, None, None + return if self.runner.mm_registry.has_processor(self.runner.model_config): mm_kwargs = mm_data @@ -173,8 +346,10 @@ def _compute_multi_modal_input( ) # special processing for mrope position deltas. - mrope_positions = None if self.runner.model_config.uses_mrope: + assert not self.chunked_prefill, \ + "MROPE on CPU does not support chunked-prefill." + image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( @@ -198,226 +373,15 @@ def _compute_multi_modal_input( context_len=computed_len, ) seq_data.mrope_position_delta = mrope_position_delta - return mm_kwargs, placeholder_maps, mrope_positions - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - input_mrope_positions: List[List[int]] = [[] for _ in range(3)] - - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - multi_modal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - computed_len = seq_data.get_num_computed_tokens() - seq_len = len(prompt_tokens) - - seq_lens.append(seq_len) # Prompt token num - input_tokens.extend(prompt_tokens) # Token ids - - mrope_positions = None - if seq_group_metadata.multi_modal_data: - ( - mm_kwargs, - placeholder_maps, - mrope_positions, - ) = self._compute_multi_modal_input(seq_data, computed_len, - seq_group_metadata) - - multi_modal_kwargs_list.append(mm_kwargs) - for modality, placeholder_map in placeholder_maps.items(): - multi_modal_placeholder_maps[modality].extend( - placeholder_map) - - # Token position ids - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - if mrope_positions: - for idx in range(3): - input_mrope_positions[idx].extend(mrope_positions[idx]) - else: - input_positions.extend(list(range(computed_len, seq_len))) - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(computed_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - # For encoder-only models, the block_table is None, - # and there is no need to initialize the slot_mapping. - if block_table is not None: - block_number = block_table[i // - self.block_size] # type: ignore - block_offset = i % self.block_size # type: ignore - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - if any(input_mrope_positions): - input_positions = None # type: ignore - else: - input_mrope_positions = None # type: ignore + for i in range(3): + self.input_data.input_mrope_positions[ # type: ignore + i].extend(mrope_positions[i]) - num_prompt_tokens = len(input_tokens) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) # type: ignore - input_positions = torch.tensor(input_positions - or input_mrope_positions, - dtype=torch.long, - device=self.device) # type: ignore - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) # type: ignore - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - multi_modal_placeholder_maps.items() - } - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - seq_lens=seq_lens, - seq_lens_tensor=torch.tensor([]), - max_decode_seq_len=0, - num_prefills=len(seq_lens), - num_prefill_tokens=num_prompt_tokens, - num_decode_tokens=0, - block_tables=torch.tensor([]), - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=placeholder_index_maps, - ) - - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - input_mrope_positions: List[List[int]] = [[] for _ in range(3)] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - block_tables: List[List[int]] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - - seq_ids = list(seq_group_metadata.seq_data.keys()) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) - - seq_len = seq_data.get_len() - position = seq_len - 1 - if seq_data.mrope_position_delta is not None: - context_len = seq_data.get_num_computed_tokens() - next_pos = MRotaryEmbedding.get_next_input_positions( - seq_data.mrope_position_delta, - context_len, - seq_len, - ) - for idx in range(3): - input_mrope_positions[idx].extend(next_pos[idx]) - else: - input_positions.append(position) - - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) - seq_lens.append(seq_len) - - block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) - - if any(input_mrope_positions): - input_positions = None # type: ignore - else: - input_mrope_positions = None # type: ignore - - max_decode_seq_len = max(seq_lens) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions - or input_mrope_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - - block_tables = make_tensor_with_pad( - block_tables, - pad=0, - dtype=torch.int, - device=self.device, - ) - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_decode_seq_len=max_decode_seq_len, - num_prefill_tokens=0, - num_decode_tokens=len(input_tokens), - num_prefills=0, - block_tables=block_tables, - ) - return ( - input_tokens, - input_positions, - attn_metadata, - ) + self.input_data.multi_modal_inputs_list.append(mm_kwargs) + for modality, placeholder_map in placeholder_maps.items(): + self.input_data.multi_modal_placeholder_maps[modality].extend( + placeholder_map) class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): @@ -432,29 +396,34 @@ def __init__( vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + return_hidden_states: bool = False, *args, **kwargs, ): ModelRunnerBase.__init__(self, vllm_config) - # Currently, CPU worker doesn't support chunked prefill. - assert self.scheduler_config.chunked_prefill_enabled is False model_config = self.model_config cache_config = self.cache_config self.is_driver_worker = is_driver_worker + self.return_hidden_states = return_hidden_states self.device = self.device_config.device + self.pin_memory = False self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) self.attn_backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, - ) + ) if needs_attn_backend else None # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY @@ -479,11 +448,19 @@ def _prepare_model_input_tensors( """ builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) - for seq_group_metadata in seq_group_metadata_list: - builder.add_seq_group(seq_group_metadata) + builder.set_seq_group_list(seq_group_metadata_list) return builder.build() # type: ignore + # sampler property will be used by spec_decode_worker + @property + def sampler(self): + return self.model.sampler + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( @@ -520,9 +497,12 @@ def prepare_model_input( pin_memory=False, generators=generators) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) return dataclasses.replace(model_input, sampling_metadata=sampling_metadata, - virtual_engine=virtual_engine) + virtual_engine=virtual_engine, + is_prompt=is_prompt) @torch.no_grad() def execute_model( @@ -531,28 +511,33 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + previous_hidden_states: Optional[torch.Tensor] = None, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( "CPU worker does not support multi-step execution.") model_executable = self.model - execute_model_kwargs = { - "input_ids": - model_input.input_tokens, - "positions": - model_input.input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - model_input.attn_metadata, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - "intermediate_tensors": - intermediate_tensors, - } - hidden_states = model_executable(**execute_model_kwargs) + multimodal_kwargs = {} + if model_input.multi_modal_kwargs is not None: + multimodal_kwargs = MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs, device=self.device) + execute_model_kwargs = {} + if previous_hidden_states is not None: + execute_model_kwargs.update( + {"previous_hidden_states": previous_hidden_states}) + + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **execute_model_kwargs, + **multimodal_kwargs, + ) # Compute the logits. logits = self.model.compute_logits(hidden_states, @@ -567,4 +552,12 @@ def execute_model( logits=logits, sampling_metadata=model_input.sampling_metadata, ) + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + if model_input.is_prompt: + output.prefill_hidden_states = hidden_states + output.hidden_states = hidden_states return [output] + + def generate_proposals(self, *args, **kwargs): + return self.model.generate_proposals(*args, **kwargs) diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py similarity index 91% rename from vllm/worker/cpu_embedding_model_runner.py rename to vllm/worker/cpu_pooling_model_runner.py index d0b8fec48d74f..17b2fd2564a04 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -3,6 +3,7 @@ import torch +from vllm.forward_context import set_forward_context from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams @@ -15,12 +16,12 @@ @dataclasses.dataclass(frozen=True) class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): """ - Used by the CPUEmbeddingModelRunner. + Used by the CPUPoolingModelRunner. """ pooling_metadata: Optional["PoolingMetadata"] = None -class CPUEmbeddingModelRunner( +class CPUPoolingModelRunner( CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]): _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = ( ModelInputForCPUWithPoolingMetadata) @@ -49,6 +50,9 @@ def execute_model( ] model_executable = self.model + cross_enc_kwargs = {} + if model_input.token_type_ids is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids execute_model_kwargs = { "input_ids": model_input.input_tokens, @@ -60,11 +64,13 @@ def execute_model( model_input.attn_metadata, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), + **cross_enc_kwargs, "intermediate_tensors": intermediate_tensors, } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. if not self.is_driver_worker: diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index bc9164bd9d5df..4fad1a3f4caeb 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -14,9 +14,9 @@ from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase +from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) @@ -128,6 +128,7 @@ def __init__( distributed_init_method: str, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, + model_runner_cls: Optional[Type[CPUModelRunner]] = None, ) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) @@ -151,15 +152,29 @@ def __init__( else: self.local_omp_cpuid = omp_cpuids.split("|")[rank] + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.model == + model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator", "eagle"]) \ + else {"return_hidden_states": True} ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner if self.model_config.task == "embedding": - ModelRunnerClass = CPUEmbeddingModelRunner + ModelRunnerClass = CPUPoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = CPUEncoderDecoderModelRunner self.model_runner: CPUModelRunnerBase = ModelRunnerClass( vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) + is_driver_worker=is_driver_worker, + **speculative_args, + ) + if model_runner_cls is not None: + self.model_runner = model_runner_cls(self.model_runner) # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: List[CPUCacheEngine] @@ -197,7 +212,7 @@ def init_device(self) -> None: ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) if ret: logger.info(ret) - + self.device = torch.device("cpu") self.init_distributed_environment() # Set random seed. set_random_seed(self.model_config.seed) @@ -297,6 +312,14 @@ def do_metadata_broadcast(self) -> bool: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: return self.cpu_cache + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + def execute_worker( self, worker_input: WorkerInput, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 79b515546d849..c2ebf853f41d8 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -8,7 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata) from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, +from vllm.attention.selector import (get_env_variable_attn_backend, get_global_forced_attn_backend) from vllm.config import VllmConfig from vllm.forward_context import set_forward_context @@ -18,6 +18,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, MultiModalRegistry) +from vllm.platforms import _Backend from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) @@ -175,7 +176,7 @@ def execute_model( } if self.has_inner_state else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - with set_forward_context(model_input.attn_metadata): + with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f917198740e07..741bdf6346c49 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -21,7 +21,7 @@ from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group +from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None + token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None @@ -200,6 +201,7 @@ class InterDataForSeqGroup: def simple_reinit(self): self.input_tokens[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore + self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore @@ -226,6 +228,7 @@ def __init__( # Input tokens and positions. input_tokens: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None, + token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, # The sequence length (may be capped to the sliding window). @@ -291,6 +294,12 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_positions[seq_id].clear() + if token_types: + self.token_types = token_types + else: + for seq_id in range(len(self.seq_ids)): + self.token_types[seq_id].clear() + self.mrope_input_positions = None if seq_lens: @@ -354,6 +363,7 @@ def __init__( else: self.input_tokens = input_tokens or [] self.input_positions = input_positions or [] + self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] @@ -386,6 +396,7 @@ def __post_init__(self): self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)] + self.token_types = [[] for _ in range(self.n_seqs)] self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs @@ -498,12 +509,15 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] + token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.token_types[seq_idx].extend( + token_types if token_types else []) inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: @@ -561,6 +575,8 @@ def _compute_for_prefix_cache_hit( seq_idx][uncomputed_start:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][uncomputed_start:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + uncomputed_start:] context_len = prefix_cache_len inter_data.context_lens[seq_idx] = context_len @@ -575,6 +591,8 @@ def _compute_for_prefix_cache_hit( seq_idx][-1:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][-1:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + -1:] inter_data.query_lens[seq_idx] = 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 @@ -803,9 +821,12 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. input_tokens = [] + token_types = [] for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) + for cur_token_types in inter_data.token_types: + token_types.extend(cur_token_types) if not input_tokens: # This may happen when all prefill requests hit @@ -874,6 +895,12 @@ def build(self) -> ModelInputForGPU: input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, self.runner.device, self.runner.pin_memory) + + token_types_tensor = async_tensor_h2d(token_types, torch.long, + self.runner.device, + self.runner.pin_memory) \ + if token_types else None + if mrope_input_positions is not None: for idx in range(3): mrope_input_positions[idx].extend( @@ -952,6 +979,7 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, + token_types=token_types_tensor, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens, @@ -1503,7 +1531,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, self.vllm_config): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1639,6 +1667,24 @@ def execute_model( else: model_executable = self.model + # Receive KV cache in distributed KV cache transfer setting + # In disagg prefill setting, it will also recv hidden states and bypass + # model forwarding + # In KV cache database setting, it will change the model input so that + # we can skip prefilling on tokens that successfully received KV caches + # NOTE: The receive operation is blocking + bypass_model_exec = False + if self.need_recv_kv(model_input, kv_caches): + hidden_or_intermediate_states, bypass_model_exec, model_input = \ + get_kv_transfer_group().recv_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can receive KV for only those + # layers. + model_executable, + model_input, + kv_caches=kv_caches + ) + multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -1650,21 +1696,36 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - with set_forward_context(model_input.attn_metadata): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + if not bypass_model_exec: + with set_forward_context(model_input.attn_metadata, + self.vllm_config): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() + # Sending KV cache in distributed KV cache transfer setting + # NOTE: the send operation is non-blocking + if self.need_send_kv(model_input, kv_caches): + get_kv_transfer_group().send_kv_caches_and_hidden_states( + # model_executable is used to know which layer the current + # worker is working on, so that we can send KV for only those + # layers. + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: if (self.is_driver_worker @@ -1732,6 +1793,56 @@ def execute_model( return [output] + def need_recv_kv(self, model_input, kv_caches) -> bool: + """Check if we need to receive kv-cache from the other worker. + We need to receive KV when + 1. current vLLM instance is KV cache consumer/decode vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + if self.vllm_config.kv_transfer_config is None: + return False + + return self.vllm_config.kv_transfer_config.is_kv_consumer and ( + not is_profile_run) and is_prefill_run + + def need_send_kv(self, model_input, kv_caches) -> bool: + """Check if we need to send kv-cache to the other worker. + We need to send KV when + 1. current vLLM instance is KV cache producer/prefill vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + if self.vllm_config.kv_transfer_config is None: + return False + + return self.vllm_config.kv_transfer_config.is_kv_producer and ( + not is_profile_run) and is_prefill_run + # NOTE: this is nn.Module so the profiler can properly capture/group # kernels calls made within the graph diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 9e529f86b46bb..cd4770202a186 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -289,3 +289,18 @@ def get_generators(self, finished_request_ids: Optional[List[str]] = None): self.generators.pop(request_id, None) return self.generators + + +class ModelRunnerWrapperBase: + """ + The whole point of this class is to lazily initialize the model_runner. + """ + + def __init__( + self, + moderl_runner: ModelRunnerBase, + ) -> None: + self.model_runner: ModelRunnerBase = moderl_runner + + def __getattr__(self, attr): + return getattr(self.model_runner, attr) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/pooling_model_runner.py similarity index 95% rename from vllm/worker/embedding_model_runner.py rename to vllm/worker/pooling_model_runner.py index 37cfcbf13d7a3..1beae1e3884c5 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -21,12 +21,12 @@ @dataclasses.dataclass(frozen=True) class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): """ - Used by the EmbeddingModelRunner. + Used by the PoolingModelRunner. """ pooling_metadata: Optional["PoolingMetadata"] = None -class EmbeddingModelRunner( +class PoolingModelRunner( GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( ModelInputForGPUWithPoolingMetadata) @@ -52,7 +52,7 @@ def execute_model( ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( - "EmbeddingModelRunner does not support multi-step execution.") + "PoolingModelRunner does not support multi-step execution.") if self.lora_config: assert model_input.lora_requests is not None @@ -97,7 +97,11 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - with set_forward_context(model_input.attn_metadata): + cross_enc_kwargs = {} + if model_input.token_types is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_types + + with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -105,7 +109,8 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device)) + device=self.device), + **cross_enc_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index d7a641857a613..9a054eb8a4cf7 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,3 +1,4 @@ +import enum import time from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, @@ -11,7 +12,6 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput @@ -39,6 +39,15 @@ _MAX_NUM_SAMPLES = 128 +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + @dataclass(frozen=True) class ModelInputForTPU(ModelRunnerInputBase): token_ids: torch.Tensor @@ -140,16 +149,21 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - self.model = ModelWrapper(model, self.vllm_config) + model = ModelWrapper(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) def _dummy_run( self, batch_size: int, seq_len: int, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - is_prompt: bool, + exec_mode: ExecutionMode, ) -> None: - if is_prompt: + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): seq_len = (seq_len + 15) // 16 * 16 token_ids = torch.zeros((batch_size, seq_len), dtype=torch.int32, @@ -160,18 +174,38 @@ def _dummy_run( slot_mapping = torch.zeros((batch_size, seq_len), dtype=torch.int64, device=self.device) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=batch_size, - num_prefill_tokens=batch_size * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, - ) input_lens = torch.ones((batch_size, ), dtype=torch.int32, device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + else: + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device=self.device) + effective_query_lens = torch.ones_like(context_lens) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) else: assert seq_len == 1 token_ids = torch.zeros((batch_size, seq_len), @@ -204,7 +238,7 @@ def _dummy_run( ) t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile @@ -213,7 +247,7 @@ def _dummy_run( # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). - if is_prompt: + if exec_mode.is_prefill(): # Prefll torch._dynamo.mark_dynamic(token_ids, 1) torch._dynamo.mark_dynamic(position_ids, 1) @@ -229,15 +263,8 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, - num_samples, - kv_caches, - is_prompt=is_prompt) + self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, + num_samples, kv_caches) def warmup_model( self, @@ -248,13 +275,13 @@ def warmup_model( start = time.time() for batch_size in [1]: seq_len = 16 - while True: - self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True) + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFILL) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - - if seq_len >= self.model_config.max_model_len: - break num_tokens = batch_size * seq_len if num_tokens >= self.scheduler_config.max_num_batched_tokens: break @@ -263,12 +290,39 @@ def warmup_model( end = time.time() logger.info("Compilation for prefill done in %.2f s.", end - start) + # Prefix prefill + if self.cache_config.enable_prefix_caching: + logger.info("Compiling the model with different input shapes for " + "prefix prefill...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFIX_PREFILL) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, + seq_len) + num_tokens = batch_size * seq_len + if (num_tokens >= + self.scheduler_config.max_num_batched_tokens): + break + seq_len = seq_len * 2 + end = time.time() + logger.info("Compilation for prefix prefill done in %.2f s.", + end - start) + # Decode start = time.time() seq_len = 1 batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: - self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.DECODE) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) @@ -287,9 +341,11 @@ def _prepare_prompt( input_tokens: List[int] = [] input_positions: List[int] = [] prompt_lens: List[int] = [] + context_lens: List[int] = [] slot_mapping: List[int] = [] - for seq_group_metadata in seq_group_metadata_list: + for batch_idx, seq_group_metadata in enumerate( + seq_group_metadata_list): assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 @@ -298,19 +354,31 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] # Could include output tokens when a request is preempted. prompt_tokens = seq_data.get_token_ids() + seq_len = len(prompt_tokens) + + num_computed_blocks = len(seq_group_metadata.computed_block_nums) + num_computed_tokens = num_computed_blocks * self.block_size + if num_computed_tokens > 0: + prompt_tokens = prompt_tokens[num_computed_tokens:] + context_lens.append(seq_len) + else: + context_lens.append(0) + prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) input_tokens.extend(prompt_tokens) - input_positions.extend(list(range(prompt_len))) + input_positions.extend(range(num_computed_tokens, seq_len)) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] - for i in range(prompt_len): + for i in range(num_computed_tokens, seq_len): block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if num_computed_tokens > 0: + self.block_tables[batch_idx, :len(block_table)] = block_table # Add paddings to EACH prompt to the smallest power of 2 that is # greater than or equal to the prompt length. @@ -338,14 +406,21 @@ def _prepare_prompt( prompt_lens = torch.tensor(prompt_lens, dtype=torch.int32, device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:num_prefills], + dtype=torch.int32, + device="cpu") attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, num_prefill_tokens=0, # NOTE: This is not used. num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=prompt_lens, ) return input_tokens, input_positions, attn_metadata, prompt_lens @@ -550,6 +625,10 @@ def execute_model( # process them separately. This is a temporary hack that should be # optimized by using SplashAttention. orig_slot_mapping = model_input.attn_metadata.slot_mapping + orig_block_tables = model_input.attn_metadata.block_tables + orig_context_lens = model_input.attn_metadata.context_lens + orig_effective_query_lens = \ + model_input.attn_metadata.effective_query_lens batch_size = model_input.input_lens.shape[0] start_idx = 0 next_token_ids = [] @@ -568,18 +647,24 @@ def execute_model( attn_metadata.num_prefills = 1 attn_metadata.slot_mapping = orig_slot_mapping[ None, start_idx:end_idx].to(self.device) + if orig_context_lens[i].item() > 0: + attn_metadata.context_lens = orig_context_lens[i:i + 1].to( + self.device) + attn_metadata.block_tables = orig_block_tables[ + i].unsqueeze(0).to(self.device) + attn_metadata.effective_query_lens = \ + orig_effective_query_lens[i:i + 1].to(self.device) + else: + attn_metadata.context_lens = None + attn_metadata.block_tables = None + attn_metadata.effective_query_lens = None input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - output_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, p, model_input.num_samples, - kv_caches, - is_prompt=True) + kv_caches) next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -624,15 +709,10 @@ def execute_model( input_lens = model_input.input_lens.to(self.device) for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping - output_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, p, model_input.num_samples, - kv_caches, - is_prompt=False) + kv_caches) self.cached_step_outputs.append(output_token_ids) if i < num_steps - 1: @@ -667,34 +747,11 @@ def execute_model( return [sampler_output] -class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): +class ModelWrapper(nn.Module): - def __init__(self, model: nn.Module, vllm_config: VllmConfig): + def __init__(self, model: nn.Module): + super().__init__() self.model = model - compiled_callable = torch.compile(self.forward, - backend="openxla", - fullgraph=True, - dynamic=False) - super().__init__( - compiled_callable, - compilation_level=vllm_config.compilation_config.level) - - def __call__(self, *args, is_prompt: bool, **kwargs): - if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: - # not fully compiled yet, or not using the custom dispatcher, - # let PyTorch handle it - return self.compiled_callable(*args, **kwargs) - # the 3 compiled codes are: - # 0: for profiling - # 1: for prompt - # 2: for decode - # dispatch to the compiled code directly, skip PyTorch - if is_prompt: - with self.dispatch_to_code(1): - return self.forward(*args, **kwargs) - else: - with self.dispatch_to_code(2): - return self.forward(*args, **kwargs) def forward( self, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 096cb23416909..8754f7538f251 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -13,7 +13,7 @@ from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size -from vllm.worker.tpu_model_runner import TPUModelRunner +from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) @@ -112,7 +112,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: batch_size=1, seq_len=self.scheduler_config.max_num_batched_tokens, kv_caches=kv_caches, - is_prompt=True, + exec_mode=ExecutionMode.PREFILL, ) # Synchronize before measuring the memory usage. xm.wait_device_ops() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f10cffe0863d2..0c4d73752e319 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,6 +1,7 @@ """A GPU worker class.""" import gc import os +import time from pathlib import Path from typing import Dict, List, Optional, Set, Tuple, Type, Union @@ -8,8 +9,9 @@ import torch.distributed import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, +from vllm.config import VllmConfig +from vllm.distributed import (ensure_kv_transfer_initialized, + ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger @@ -22,9 +24,9 @@ from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine -from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) @@ -74,10 +76,8 @@ def __init__( else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner - if model_runner_cls is not None: - ModelRunnerClass = model_runner_cls - elif model_config.task == "embedding": - ModelRunnerClass = EmbeddingModelRunner + if model_config.task == "embedding": + ModelRunnerClass = PoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( @@ -86,6 +86,9 @@ def __init__( is_driver_worker=is_driver_worker, **speculative_args, ) + if model_runner_cls is not None: + self.model_runner = model_runner_cls(self.model_runner) + # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: List[CacheEngine] @@ -170,7 +173,7 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, + init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. @@ -217,6 +220,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.cuda.reset_peak_memory_stats() free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() + start_time = time.time() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. @@ -257,12 +261,18 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + end_time = time.time() logger.info( - "Memory profiling results: total_gpu_memory=%.2fGiB" - " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB" - " memory_usage_post_profile=%.2fGiB" - " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB" - " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3), + "Memory profiling results: " + "duration=%.2f seconds, " + "total_gpu_memory=%.2fGiB, " + "initial_memory_usage=%.2fGiB, " + "peak_torch_memory=%.2fGiB, " + "memory_usage_post_profile=%.2fGiB, " + "non_torch_memory=%.2fGiB, " + "kv_cache_size=%.2fGiB, " + "gpu_memory_utilization=%.2f.", end_time - start_time, + total_gpu_memory / (1024**3), (total_gpu_memory - free_memory_pre_profile) / (1024**3), (peak_memory - non_torch_allocations) / (1024**3), total_allocated_bytes / (1024**3), @@ -476,20 +486,22 @@ def get_cache_block_size_bytes(self) -> int: def init_worker_distributed_environment( - parallel_config: ParallelConfig, + vllm_config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(vllm_config) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index cf8a4946a71c4..7c0bc5a678956 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,9 +1,8 @@ import dataclasses -import importlib import os import time from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -15,7 +14,7 @@ from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, - update_environment_variables) + resolve_obj_by_qualname, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) @@ -44,6 +43,7 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + self.kv_transfer_config = vllm_config.kv_transfer_config @abstractmethod def init_device(self) -> None: @@ -411,23 +411,14 @@ class WorkerWrapperBase: We first instantiate the WorkerWrapper, which remembers the worker module and class name. Then, when we call `update_environment_variables`, and the real initialization happens in `init_worker`. - - If worker_class_fn is specified, it will be executed to get the worker - class. - Otherwise, the worker class will be obtained by dynamically importing it - using worker_module_name and worker_class_name. """ def __init__( self, - worker_module_name: str, - worker_class_name: str, - trust_remote_code: bool = False, - worker_class_fn: Optional[Callable[[], - Type[WorkerBase]]] = None) -> None: - self.worker_module_name = worker_module_name - self.worker_class_name = worker_class_name - self.worker_class_fn = worker_class_fn + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + trust_remote_code = vllm_config.model_config.trust_remote_code self.worker: Optional[WorkerBase] = None if trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -456,12 +447,8 @@ def init_worker(self, *args, **kwargs): from vllm.plugins import load_general_plugins load_general_plugins() - if self.worker_class_fn: - worker_class = self.worker_class_fn() - else: - mod = importlib.import_module(self.worker_module_name) - worker_class = getattr(mod, self.worker_class_name) - + worker_class = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_cls) self.worker = worker_class(*args, **kwargs) assert self.worker is not None @@ -480,6 +467,9 @@ def execute_method(self, method, *args, **kwargs): logger.exception(msg) raise e + def __getattr__(self, attr): + return getattr(self.worker, attr) + def extract_previous_hidden_states( data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \