From b482b9a5b13ba7d126adabbedb3ba66f48d4d83b Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 2 Aug 2024 16:51:22 -0400 Subject: [PATCH 01/36] [CI/Build] Add support for Python 3.12 (#7035) --- .github/workflows/mypy.yaml | 2 +- .github/workflows/publish.yml | 2 +- .github/workflows/ruff.yml | 2 +- .github/workflows/yapf.yml | 2 +- CMakeLists.txt | 2 +- docs/source/getting_started/installation.rst | 2 +- setup.py | 1 + 7 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 68e3a3fefdc5d..8d423657630c2 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 607fda754bf23..aeeaf6efab043 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -48,7 +48,7 @@ jobs: fail-fast: false matrix: os: ['ubuntu-20.04'] - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt. cuda-version: ['11.8', '12.1'] diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 773def58fd966..1a794af572fef 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index 04f307bcf8b0e..c89f82dfaaaf6 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 77a8af549b027..dbe688186f17f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. # -set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") +set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 0253717da3cda..57ad8bacedfcc 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -9,7 +9,7 @@ Requirements ------------ * OS: Linux -* Python: 3.8 -- 3.11 +* Python: 3.8 -- 3.12 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.) 
Install with pip diff --git a/setup.py b/setup.py index 63c1f466d2910..91307e8a94062 100644 --- a/setup.py +++ b/setup.py @@ -465,6 +465,7 @@ def _read_requirements(filename: str) -> List[str]: "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "License :: OSI Approved :: Apache Software License", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], From a8d604ca2a2912b3a5352821c53c080383580df1 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 2 Aug 2024 16:51:58 -0400 Subject: [PATCH 02/36] [Misc] Disambiguate quantized types via a new ScalarType (#6396) --- CMakeLists.txt | 52 ++- Dockerfile.openvino | 3 + benchmarks/kernels/benchmark_marlin.py | 50 +-- cmake/cpu_extension.cmake | 1 - csrc/{ => core}/registration.h | 0 csrc/core/scalar_type.hpp | 382 ++++++++++++++++++ csrc/core/torch_bindings.cpp | 16 + csrc/cpu/torch_bindings.cpp | 2 +- csrc/moe/torch_bindings.cpp | 2 +- csrc/ops.h | 8 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 66 ++- .../marlin/sparse/marlin_24_cuda_kernel.cu | 17 +- csrc/torch_bindings.cpp | 2 +- setup.py | 9 +- tests/kernels/test_int8_quant.py | 2 - tests/kernels/test_marlin_gemm.py | 75 ++-- tests/test_scalartype.py | 36 ++ vllm/_core_ext.py | 177 ++++++++ vllm/_custom_ops.py | 29 +- .../layers/quantization/awq_marlin.py | 49 ++- .../schemes/compressed_tensors_w4a16_24.py | 18 +- .../schemes/compressed_tensors_wNa16.py | 29 +- .../layers/quantization/gptq_marlin.py | 43 +- .../layers/quantization/gptq_marlin_24.py | 29 +- .../layers/quantization/utils/marlin_utils.py | 120 +++--- .../quantization/utils/marlin_utils_test.py | 29 +- .../utils/marlin_utils_test_24.py | 30 +- .../layers/quantization/utils/quant_utils.py | 148 +++---- vllm/scalar_type.py | 35 ++ 29 files changed, 1107 insertions(+), 352 deletions(-) rename csrc/{ => core}/registration.h (100%) create mode 100644 csrc/core/scalar_type.hpp create mode 100644 csrc/core/torch_bindings.cpp create mode 100644 tests/test_scalartype.py create mode 100644 vllm/_core_ext.py create mode 100644 vllm/scalar_type.py diff --git a/CMakeLists.txt b/CMakeLists.txt index dbe688186f17f..922613ec5ddaa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,6 +66,39 @@ endif() # find_package(Torch REQUIRED) +# +# Add the `default` target which detects which extensions should be +# built based on platform/architecture. This is the same logic that +# setup.py uses to select which extensions should be built and should +# be kept in sync. +# +# The `default` target makes direct use of cmake easier since knowledge +# of which extensions are supported has been factored in, e.g. +# +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. +# cmake --build . --target default +# +add_custom_target(default) +message(STATUS "Enabling core extension.") + +# Define _core_C extension +# built for (almost) every target platform, (excludes TPU and Neuron) + +set(VLLM_EXT_SRC + "csrc/core/torch_bindings.cpp") + +define_gpu_extension_target( + _core_C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + USE_SABI 3 + WITH_SOABI) + +add_dependencies(default _core_C) + # # Forward the non-CUDA device extensions to external CMake scripts. 
# @@ -74,7 +107,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND if (VLLM_TARGET_DEVICE STREQUAL "cpu") include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake) else() - message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}") + return() endif() return() endif() @@ -132,7 +165,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") endif() # -# Define extension targets +# Define other extension targets # # @@ -228,21 +261,6 @@ define_gpu_extension_target( -# -# Add the `default` target which detects which extensions should be -# built based on platform/architecture. This is the same logic that -# setup.py uses to select which extensions should be built and should -# be kept in sync. -# -# The `default` target makes direct use of cmake easier since knowledge -# of which extensions are supported has been factored in, e.g. -# -# mkdir build && cd build -# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. -# cmake --build . --target default -# -add_custom_target(default) - if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 7c62dd845aa99..c84dea419e58a 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -13,6 +13,9 @@ COPY requirements-common.txt /workspace/vllm/ COPY requirements-openvino.txt /workspace/vllm/ COPY vllm/ /workspace/vllm/vllm +COPY csrc/core /workspace/vllm/csrc/core +COPY cmake/utils.cmake /workspace/vllm/cmake/ +COPY CMakeLists.txt /workspace/vllm/ COPY setup.py /workspace/vllm/ # install build requirements diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 684985b81f690..536c133bb3341 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -7,16 +7,17 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS) + MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, marlin_quantize) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( marlin_24_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, quantize_weights, sort_weights) + gptq_pack, gptq_quantize_weights, sort_weights) +from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] @@ -27,13 +28,14 @@ def bench_run(results: List[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, num_bits: int, group_size: int, - size_m: int, size_k: int, size_n: int): + act_order: bool, is_k_full: bool, quant_type: ScalarType, + group_size: int, size_m: int, size_k: int, size_n: int): label = "Quant Matmul" - sub_label = ("{}, act={} k_full={}, b={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, - group_size, size_m, size_k, size_n)) + sub_label = ("{}, act={} k_full={}, q={}, g={}, " + 
"MKN=({}x{}x{})".format(model, act_order, is_k_full, + str(quant_type), group_size, size_m, + size_k, size_n)) print(f"Testing: {sub_label}") @@ -50,18 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_g_idx, marlin_sort_indices, marlin_rand_perm, - ) = marlin_quantize(b, num_bits, group_size, act_order) + ) = marlin_quantize(b, quant_type, group_size, act_order) # Marlin_24 quant (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) # GPTQ quant (w_ref, q_w, s, g_idx, - rand_perm) = quantize_weights(b, num_bits, group_size, act_order) - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" # so that group ids are increasing @@ -75,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str, marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL) + marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) globals = { # Gen params - "num_bits": num_bits, + "quant_type": quant_type, "group_size": group_size, "size_m": size_m, "size_n": size_n, @@ -128,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -138,19 +141,19 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp32", ).blocked_autorange(min_run_time=min_run_time)) - if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): results.append( benchmark.Timer( stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -160,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str, results.append( benchmark.Timer( stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, 
quant_type.size_bits)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -196,9 +199,10 @@ def main(args): ) > 0 and is_k_full not in args.limit_k_full: continue - for num_bits in MARLIN_SUPPORTED_NUM_BITS: - if len(args.limit_num_bits - ) > 0 and num_bits not in args.limit_num_bits: + for quant_type in query_marlin_supported_quant_types( + False): + if len(args.limit_num_bits) > 0 and \ + quant_type.size_bits not in args.limit_num_bits: continue for group_size in MARLIN_SUPPORTED_GROUP_SIZES: @@ -215,8 +219,8 @@ def main(args): for size_m in args.batch_sizes: bench_run(results, model, act_order, is_k_full, - num_bits, group_size, size_m, size_k, - size_n) + quant_type, group_size, size_m, + size_k, size_n) compare = benchmark.Compare(results) compare.print() diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 118f9b28e0ae3..3ba3a2b6a93cd 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -113,6 +113,5 @@ define_gpu_extension_target( WITH_SOABI ) -add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) diff --git a/csrc/registration.h b/csrc/core/registration.h similarity index 100% rename from csrc/registration.h rename to csrc/core/registration.h diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp new file mode 100644 index 0000000000000..9f78402eee2a7 --- /dev/null +++ b/csrc/core/scalar_type.hpp @@ -0,0 +1,382 @@ +#pragma once + +#include + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// ScalarTypeTorch is a subclass of ScalarType that is compatible with +// TORCH_LIBRARY, making it accessible from Python as well meaning this class +// can be used as a argument for custom operators, helping to simplify these +// interfaces. +// +// The type definitions on the Python side can be found in: vllm/_core_ext.pyi +// these type definitions should be kept up to date with any Python API changes +// here. 
+// +class ScalarType { + public: + enum NanRepr : int64_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa, + int64_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + bias(bias), + signed_(signed_), + finite_values_only(finite_values_only), + nan_repr(nan_repr){}; + + static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) { + return ScalarType(true, 0, size_bits - 1, bias); + } + + static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) { + return ScalarType(false, 0, size_bits, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(int64_t exponent, + int64_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(true, exponent, mantissa, 0, finite_values_only, + nan_repr); + } + + int64_t const exponent; // size of the exponent field (0 for integer types) + int64_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + int64_t const bias; // stored values equal value + bias, + // used for quantized type + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + + // Extra Floating point info + bool const finite_values_only; // i.e. 
no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + int64_t size_bits() const { return mantissa + exponent + is_signed(); } + bool is_signed() const { return signed_; } + bool is_integer() const { return exponent == 0; } + bool is_floating_point() const { return exponent > 0; } + bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; } + bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + std::variant _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + std::variant max() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + std::variant min() const { + return std::visit( + [this](auto x) -> std::variant { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from +// torch::CustomClassHolder), we use multiple inheritance here since we cannot +// have ScalarType inherit from torch::CustomClassHolder and have a constexpr +// constructor at the same time (torch::CustomClassHolder does not have a +// constexpr destructor) +class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { + public: + ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias, + bool _signed) + : ScalarType(exponent, mantissa, bias, _signed){}; + + ScalarTypeTorch(ScalarType type) : ScalarType(type){}; + + using Base = ScalarType; + using Self = ScalarTypeTorch; + using SelfPtr = c10::intrusive_ptr; + + static SelfPtr int_(int64_t size_bits, c10::optional bias) { + return c10::make_intrusive( + ScalarType::int_(size_bits, bias.value_or(0))); + } + + static SelfPtr uint(int64_t size_bits, c10::optional bias) { + return c10::make_intrusive( + ScalarType::uint(size_bits, bias.value_or(0))); + } + + static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) { + return c10::make_intrusive( + ScalarType::float_IEEE754(exponent, mantissa)); + } + + static SelfPtr float_(int64_t exponent, int64_t mantissa, + bool finite_values_only, int64_t nan_repr) { + return c10::make_intrusive(ScalarType::float_( + exponent, mantissa, finite_values_only, NanRepr(nan_repr))); + } + + template + static void bind_readonly_property(torch::class_& cls, + std::string const& name, T Base::*field) { + auto getter_func = [field = std::move(field)](SelfPtr const& self) { + if constexpr (std::is_member_function_pointer_v) { + return (self.get()->*field)(); + } else { + return self.get()->*field; + } + }; + + cls.def_property(name, getter_func); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + MemberFunc Cls::*member) { + cls.def(name, [member = std::move(member)](SelfPtr const& self) { + return (self.get()->*member)(); + }); + } + + template + static void bind_function(torch::class_& cls, const std::string& name, + Func func) { + cls.def(name, func); + } + + template + static void bind_static_function(torch::class_& cls, + const std::string& 
name, Func func) { + cls.def_static(name, func); + } + + static void bind_class(torch::Library& lib) { + auto cls = lib.class_("ScalarType") + .def(torch::init()); + + // Bind Properties + bind_readonly_property(cls, "mantissa", &Base::mantissa); + bind_readonly_property(cls, "exponent", &Base::exponent); + bind_readonly_property(cls, "bias", &Base::bias); + bind_readonly_property(cls, "signed", &Base::is_signed); + bind_readonly_property(cls, "size_bits", &Base::size_bits); + + // Bind member functions + bind_function(cls, "is_signed", &Base::is_signed); + bind_function(cls, "is_integer", &Base::is_integer); + bind_function(cls, "is_floating_point", &Base::is_floating_point); + bind_function(cls, "is_ieee_754", &Base::is_ieee_754); + bind_function(cls, "has_nans", &Base::has_nans); + bind_function(cls, "has_infs", &Base::has_infs); + bind_function(cls, "has_bias", &Base::has_bias); + + bind_function(cls, "max", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->max()); + }); + bind_function(cls, "min", [](SelfPtr const& self) { + return std::visit([](auto arg) { return c10::IValue(arg); }, + self.get()->min()); + }); + + bind_function(cls, "__str__", &Base::str); + bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) { + return *self == *other; + }); + bind_function(cls, "__repr__", [](SelfPtr const& self) { + return "ScalarType." + self.get()->str(); + }); + + // Bind static functions (convenience constructors) + bind_static_function(cls, "int_", &ScalarTypeTorch::int_); + bind_static_function(cls, "uint", &ScalarTypeTorch::uint); + bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754); + bind_static_function(cls, "float_", &ScalarTypeTorch::float_); + } +}; + +using ScalarTypeTorchPtr = c10::intrusive_ptr; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names 
+static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +}; // namespace vllm diff --git a/csrc/core/torch_bindings.cpp b/csrc/core/torch_bindings.cpp new file mode 100644 index 0000000000000..f60254189a2f7 --- /dev/null +++ b/csrc/core/torch_bindings.cpp @@ -0,0 +1,16 @@ +#include + +#include "scalar_type.hpp" +#include "registration.h" + +// Note the CORE exstension will be built for (almost) all hardware targets so +// new additions must account for this. (currently not built for TPU and Neuron) + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) { + // ScalarType, a custom class for representing data types that supports + // quantized types, declared here so it can be used when creating interfaces + // for custom ops. + vllm::ScalarTypeTorch::bind_class(lib); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 7d549e271a30d..cf7d977da7c1c 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -1,6 +1,6 @@ #include "cache.h" #include "ops.h" -#include "registration.h" +#include "core/registration.h" #include diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 243752b9a9e8c..86e42af44df15 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,4 +1,4 @@ -#include "registration.h" +#include "core/registration.h" #include "moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { diff --git a/csrc/ops.h b/csrc/ops.h index f274a7e647b95..3bd4a9eda5ee3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -3,6 +3,8 @@ #include #include +#include "core/scalar_type.hpp" + void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, @@ -84,14 +86,16 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k); torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, bool use_fp32_reduce); diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 26cc248e6ac5d..edf19365c8098 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -21,6 +21,7 @@ #include "marlin.cuh" #include "marlin_dtypes.cuh" +#include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ @@ -71,14 +72,15 @@ __global__ void Marlin( bool use_fp32_reduce // whether to use fp32 global reduce ) {} -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, 
int64_t size_n, int64_t size_k, - bool is_k_full) { + bool is_k_full, bool has_zp) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -1963,18 +1965,29 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp, - void* s, void* zp, void* g_idx, void* perm, void* a_tmp, - int prob_m, int prob_n, int prob_k, void* workspace, - int num_bits, bool has_act_order, bool is_k_full, - bool has_zp, int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par, bool use_fp32_reduce) { - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, + void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, bool has_zp, int num_groups, int group_size, + int dev, cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool use_fp32_reduce) { + if (has_zp) { + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128, + "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + q_type.str()); + } + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + // TODO: remove alias when we start supporting other 8bit types + int num_bits = q_type.size_bits(); int tot_m = prob_m; int tot_m_blocks = div_ceil(tot_m, 16); int pad = 16 * tot_m_blocks - tot_m; @@ -2126,19 +2139,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp, } } -} // namespace gptq_marlin +} // namespace marlin torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, bool use_fp32_reduce) { - // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; + if (has_zp) { + TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", + b_q_type->str()); + } else { + TORCH_CHECK( + *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128 when has_zp = False. 
Got = ", + b_q_type->str()); + } + + int pack_factor = 32 / b_q_type->size_bits(); // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), @@ -2265,21 +2287,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { - marlin::marlin_mm_f16i4( + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else if (a.scalar_type() == at::ScalarType::BFloat16) { - marlin::marlin_mm_f16i4( + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else { diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 3c50f1786bc68..93445a386593b 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -27,6 +27,7 @@ #include #include "common/base.h" +#include "core/scalar_type.hpp" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -86,7 +87,8 @@ __global__ void Marlin_24( torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -1025,13 +1027,14 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, - torch::Tensor& workspace, int64_t num_bits, + torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, int64_t size_k) { // Verify num_bits - TORCH_CHECK(num_bits == 4 || num_bits == 8, - "num_bits must be 4 or 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; + TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "num_bits must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); + int pack_factor = 32 / b_q_type->size_bits(); // Verify M TORCH_CHECK(size_m == a.size(0), @@ -1126,8 +1129,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, marlin_24::marlin_cuda_2_4( a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), - num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, - thread_m, sms, max_par); + b_q_type->size_bits(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par); return c; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bf8cefa8d4713..7c0d617fc8b3b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -1,7 +1,7 @@ #include "cache.h" #include "cuda_utils.h" #include "ops.h" -#include "registration.h" +#include "core/registration.h" #include diff --git a/setup.py b/setup.py index 91307e8a94062..b146299f8269d 100644 --- a/setup.py +++ b/setup.py @@ -271,6 +271,10 @@ def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() +def _build_core_ext() -> bool: + return not _is_neuron() and not _is_tpu() + + def get_hipcc_rocm_version(): # Run the hipcc --version command result = subprocess.run(['hipcc', '--version'], @@ -433,6 +437,9 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] +if _build_core_ext(): + ext_modules.append(CMakeExtension(name="vllm._core_C")) + if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) @@ -477,7 +484,7 @@ def _read_requirements(filename: str) -> List[str]: extras_require={ "tensorizer": ["tensorizer>=2.9.0"], }, - cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, + cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {}, package_data=package_data, entry_points={ "console_scripts": [ diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index 03acbf7968ff1..0b7ed26a39e1e 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -1,8 +1,6 @@ import pytest import torch -# ruff: noqa: F401 -import vllm._C from tests.kernels.quant_utils import ref_dynamic_per_token_quant from vllm._custom_ops import scaled_int8_quant diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index a9e34ac8a7aa8..2f58ffda21408 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -9,14 +9,14 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) from vllm.model_executor.layers.quantization.qqq import ( MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N, MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS, - marlin_make_empty_g_idx, marlin_permute_scales) + MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, + marlin_permute_scales, query_marlin_supported_quant_types) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( pack_fp8_to_int32) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( @@ -27,8 
+27,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501 marlin_qqq_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - awq_pack, gptq_pack, quantize_weights, quantize_weights_with_zp, - sort_weights) + awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] @@ -65,12 +64,13 @@ def rand_data(shape, dtype=torch.float16): reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, - mnk_factors): +def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, + act_order, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor @@ -95,11 +95,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, b_weight = rand_data((size_k, size_n)) # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits, - group_size, act_order) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + b_weight, quant_type, group_size, act_order) # Pack to GPTQ format - q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing @@ -108,8 +108,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - weight_perm = get_weight_perm(num_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -117,7 +118,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, sort_indices, size_k, size_n, - num_bits, + quant_type.size_bits, ) torch.cuda.synchronize() @@ -128,10 +129,11 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, +def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors @@ -150,22 +152,25 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, b_weight = rand_data((size_k, size_n)) # Quantize - w_ref, q_w, s, zp = quantize_weights_with_zp(b_weight, num_bits, - group_size) + w_ref, q_w, s, zp = quantize_weights(b_weight, + quant_type, + group_size, + 
zero_points=True) # Pack to AWQ format - q_w_awq = awq_pack(q_w, num_bits, size_k, size_n) + q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) # Pack to Marlin format - weight_perm = get_weight_perm(num_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.awq_marlin_repack( q_w_awq, size_k, size_n, - num_bits, + quant_type.size_bits, ) torch.cuda.synchronize() @@ -176,7 +181,8 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @@ -185,7 +191,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size, def test_gptq_marlin_gemm( k_chunk, n_chunk, - num_bits, + quant_type, group_size, mnk_factors, act_order, @@ -211,7 +217,7 @@ def test_gptq_marlin_gemm( b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, num_bits, group_size, act_order) + b_weight, quant_type, group_size, act_order) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) @@ -226,7 +232,7 @@ def test_gptq_marlin_gemm( g_idx, sort_indices, workspace.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], @@ -248,10 +254,10 @@ def test_gptq_marlin_gemm( reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) -@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, +def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors @@ -266,7 +272,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, b_weight = rand_data((size_k, size_n)) (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size) + marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size) workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL) @@ -279,7 +285,7 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, marlin_24_meta, marlin_24_s, workspace_24.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], @@ -371,14 +377,15 @@ def test_fp8_marlin_gemm( reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("quant_type", + query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("group_size", 
MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) def test_awq_marlin_gemm( k_chunk, n_chunk, - num_bits, + quant_type, group_size, mnk_factors, use_fp32_reduce, @@ -396,7 +403,7 @@ def test_awq_marlin_gemm( b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, num_bits, group_size) + b_weight, quant_type, group_size) g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) @@ -414,7 +421,7 @@ def test_awq_marlin_gemm( g_idx, sort_indices, workspace.scratch, - num_bits, + quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py new file mode 100644 index 0000000000000..1201aaa92ea89 --- /dev/null +++ b/tests/test_scalartype.py @@ -0,0 +1,36 @@ +import pytest +import torch + +from vllm.scalar_type import scalar_types + + +@pytest.mark.parametrize("type_tuple", ( + (-8, 7, scalar_types.int4), + (0, 15, scalar_types.uint4), + (-8, 7, scalar_types.uint4b8), + (-128, 127, scalar_types.uint8b128), + (-28., 28., scalar_types.float6_e3m2f), + (torch.int8, scalar_types.int8), + (torch.uint8, scalar_types.uint8), + (torch.float8_e5m2, scalar_types.float8_e5m2), + (torch.float8_e4m3fn, scalar_types.float8_e4m3fn), + (torch.bfloat16, scalar_types.float16_e8m7), + (torch.float16, scalar_types.float16_e5m10), +), + ids=lambda x: str(x)) +def test_scalar_type_min_max(type_tuple): + print(type_tuple) + if len(type_tuple) == 3: + min, max, t = type_tuple + else: + torch_type, t = type_tuple + if torch_type.is_floating_point: + min = torch.finfo(torch_type).min + max = torch.finfo(torch_type).max + else: + min = torch.iinfo(torch_type).min + max = torch.iinfo(torch_type).max + + print(t, min, max, t.min(), t.max()) + assert min == t.min() + assert max == t.max() diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py new file mode 100644 index 0000000000000..e3b9fbb938915 --- /dev/null +++ b/vllm/_core_ext.py @@ -0,0 +1,177 @@ +import importlib.util +from enum import Enum +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) +core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +if TYPE_CHECKING or not core_C_available: + # On platforms were we cannot use/build the C++ core extension (i.e. namely + # neuron and tpu), we define the mock ScalarType class here that partially + # mimics the C++ ScalarType class. + # + # We also use this provide type signatures to the Python LSP for the methods + # in the C++ ScalarType class. So these type signatures should be kept + # in sync with csrc/core/scalar_type.hpp + + from dataclasses import dataclass + + @dataclass(frozen=True) + class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). 
The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + _finite_values_only: bool = False + """ + Private: if NANs are supported, used `has_infs()` instead. + """ + + nan_repr: int = NanRepr.IEEE_754.value + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + @property + def size_bits(self): + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + raise NotImplementedError + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + raise NotImplementedError + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + ... + + def is_floating_point(self): + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self): + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self): + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self): + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self): + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + raise NotImplementedError + + def __repr__(self) -> str: + raise NotImplementedError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + return cls(size_bits - 1, size_bits, bias if bias else 0, True) + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + return cls(size_bits, size_bits, bias if bias else 0, False) + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + return cls(exponent, mantissa, 0, True) + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: int): + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + return cls(exponent, mantissa, 0, True, finite_values_only, + nan_repr) + +elif core_C_available: + try: + import vllm._core_C # noqa: F401 + except ImportError as e: + logger.warning("Failed to import from vllm._core_C with %r", e) + + ScalarType = torch.classes._core_C.ScalarType diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6cd77f75cae8d..ad7e5bd199339 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -4,6 +4,7 @@ import torch +from vllm._core_ext import ScalarType from vllm.logger import init_logger logger = init_logger(__name__) @@ -220,10 +221,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # marlin_24 def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, num_bits: int, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, num_bits, size_m, + workspace, b_q_type, size_m, size_n, size_k) @@ -279,14 +280,22 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) -def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, b_zeros: torch.Tensor, - g_idx: torch.Tensor, perm: torch.Tensor, - workspace: torch.Tensor, num_bits: int, size_m: int, - size_n: int, size_k: int, is_k_full: bool, has_zp: bool, - use_fp32_reduce: bool) -> torch.Tensor: +def gptq_marlin_gemm(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, - g_idx, perm, workspace, num_bits, + g_idx, perm, workspace, b_q_type, size_m, size_n, size_k, is_k_full, has_zp, use_fp32_reduce) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 5ffbb8e854e87..2cc080608c7a9 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -10,11 +10,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_awq_marlin_linear, awq_to_marlin_zero_points, - check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_awq_marlin_supported, - verify_marlin_supports_shape) + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, + replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -22,20 +22,31 @@ class AWQMarlinConfig(QuantizationConfig): """Config class for AWQ Marlin""" + # num_bits -> type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + def __init__(self, weight_bits: int, group_size: int, has_zp: bool, lm_head_quantized: bool) -> None: - 
self.weight_bits = weight_bits - self.pack_factor = 32 // self.weight_bits # packed into 32bits + self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.has_zp = has_zp self.lm_head_quantized = lm_head_quantized - verify_awq_marlin_supported(num_bits=self.weight_bits, - group_size=self.group_size, - has_zp=self.has_zp) + if weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {weight_bits}. " + f"Supported num_bits = {self.TYPE_MAP.keys()}") + + self.quant_type = self.TYPE_MAP[weight_bits] + + verify_marlin_supported(self.quant_type, + group_size=self.group_size, + has_zp=self.has_zp) def __repr__(self) -> str: - return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, " + return (f"AWQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " f"has_zp={self.has_zp}, " f"lm_head_quantized={self.lm_head_quantized})") @@ -110,11 +121,13 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): if (num_bits is None or group_size is None or has_zp is None): return False - return check_awq_marlin_supported( - num_bits=num_bits, - group_size=group_size, - has_zp=has_zp, - min_capability=cls.get_min_capability()) + if num_bits not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], + group_size=group_size, + has_zp=has_zp, + min_capability=cls.get_min_capability()) class AWQMarlinLinearMethod(LinearMethodBase): @@ -226,7 +239,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qweight, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. 
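The config above now carries a ScalarType instead of a bare bit width; call sites that still need the width (such as the repack call above) recover it from size_bits. A minimal sketch of that round trip, assuming a build that includes the vllm/scalar_type.py registry added later in this patch:

    from vllm.scalar_type import scalar_types

    # num_bits -> type, mirroring AWQMarlinConfig.TYPE_MAP above
    TYPE_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}

    quant_type = TYPE_MAP[4]                  # AWQ-style unsigned 4-bit (runtime zero-point)
    pack_factor = 32 // quant_type.size_bits  # 8 quantized values per packed int32
    assert quant_type.size_bits == 4          # recovers the original num_bits for the kernels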
@@ -242,7 +255,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qzeros, size_k=layer.num_groups, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qzeros", marlin_zp) # Not-used @@ -263,7 +276,7 @@ def apply( g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.quant_config.weight_bits, + quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index b8ffb22d7a89d..c1adfdb2980b6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -9,9 +9,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsW4A16Sparse24"] -W4A16SPARSE24_SUPPORTED_BITS = [4] +W4A16SPARSE24_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, +} +W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): @@ -22,9 +26,15 @@ def __init__(self, group_size: Optional[int] = None): self.strategy = strategy self.group_size = group_size - self.num_bits = num_bits self.tile_size = 16 + if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. 
" + f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}") + + self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] + if self.strategy == "group" and self.group_size is None: raise ValueError( "group_size must be given when using strategy group") @@ -43,7 +53,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): - pack_factor = 32 // self.num_bits + pack_factor = 32 // self.quant_type.size_bits output_size_per_partition = sum(output_partition_sizes) qweight = Parameter( @@ -138,7 +148,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, size_n = scales.shape[1] output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, - workspace, self.num_bits, size_m, + workspace, self.quant_type, size_m, size_n, size_k) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index a41962ccd66d8..b8880f7ac136f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -8,12 +8,17 @@ CompressedTensorsScheme) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported, + marlin_permute_scales, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] -WNA16_SUPPORTED_BITS = [4, 8] +WNA16_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, +} +WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) class CompressedTensorsWNA16(CompressedTensorsScheme): @@ -22,8 +27,8 @@ def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None): - self.num_bits = num_bits - self.pack_factor = 32 // self.num_bits + + self.pack_factor = 32 // num_bits self.strategy = strategy self.group_size: int @@ -37,10 +42,16 @@ def __init__(self, else: self.group_size = group_size + if num_bits not in WNA16_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] + # Verify supported on platform. - verify_gptq_marlin_supported(num_bits=self.num_bits, - group_size=self.group_size, - is_sym=True) + verify_marlin_supported(quant_type=self.quant_type, + group_size=self.group_size) @classmethod def get_min_capability(cls) -> int: @@ -150,7 +161,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.num_bits) + num_bits=self.quant_type.size_bits) replace_tensor(layer, "weight_packed", marlin_qweight) # Permute scales from compressed-tensors format to marlin format. 
@@ -172,7 +183,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.num_bits, + wtype=self.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=True, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index bdcc9c3b4f0c5..4a11b14971076 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,11 +10,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full, + apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_gptq_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -22,6 +23,12 @@ class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool, lm_head_quantized: bool) -> None: if desc_act and group_size == -1: @@ -29,20 +36,23 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, # (since we have only one group per output channel) desc_act = False - self.weight_bits = weight_bits - self.pack_factor = 32 // self.weight_bits # packed into int32 + self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.desc_act = desc_act - self.is_sym = is_sym self.lm_head_quantized = lm_head_quantized + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + # Verify supported on platform. 
- verify_gptq_marlin_supported(num_bits=self.weight_bits, - group_size=self.group_size, - is_sym=self.is_sym) + verify_marlin_supported(quant_type=self.quant_type, + group_size=self.group_size) def __repr__(self) -> str: - return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " + return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " f"desc_act={self.desc_act}, " f"lm_head_quantized={self.lm_head_quantized})") @@ -122,11 +132,12 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): or desc_act is None): return False - return check_gptq_marlin_supported( - num_bits=num_bits, - group_size=group_size, - is_sym=sym, - min_capability=cls.get_min_capability()) + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], + group_size=group_size, + min_capability=cls.get_min_capability()) class GPTQMarlinLinearMethod(LinearMethodBase): @@ -293,7 +304,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.quant_type.size_bits) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -319,7 +330,7 @@ def apply( g_idx=layer.g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, - num_bits=self.quant_config.weight_bits, + wtype=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index e708c4da95af3..cafd100a2f40c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs +from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_K = 128 GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [ + scalar_types.uint4b8, scalar_types.uint8b128 +] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] -GPTQ_MARLIN_24_SUPPORTED_SYM = [True] class GPTQMarlin24Config(QuantizationConfig): @@ -31,14 +33,19 @@ def __init__( weight_bits: int, group_size: int, ) -> None: - self.weight_bits = weight_bits + quant_type = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, + }.get(weight_bits) + self.group_size = group_size # Verify - if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + if quant_type is None or \ + quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: raise ValueError( - f"Marlin_24 does not support weight_bits = {self.weight_bits}. " - f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} " + f"Marlin_24 does not support quant_type = {quant_type}. 
" + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " "are supported.") if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: raise ValueError( @@ -46,8 +53,10 @@ def __init__( f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " "are supported.") + self.quant_type = quant_type + # 4 Bits packed into 32 bit datatype. - self.pack_factor = 32 // self.weight_bits + self.pack_factor = 32 // self.quant_type.size_bits # Tile size used by marlin kernels. self.tile_size = 16 @@ -66,8 +75,8 @@ def __init__( self.perm_len = 1024 def __repr__(self) -> str: - return "Marlin24Config(weight_bits={}, group_size={})".format( - self.weight_bits, self.group_size) + return "Marlin24Config(quant_type={}, group_size={})".format( + self.quant_type, self.group_size) @classmethod def get_name(cls) -> str: @@ -279,7 +288,7 @@ def apply( output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, workspace, - self.quant_config.weight_bits, + self.quant_config.quant_type, size_m, size_n, size_k) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index b789ca20cadb3..6e84d36219361 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -5,6 +5,7 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types from .quant_utils import pack_cols, unpack_cols @@ -13,7 +14,6 @@ GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 -MARLIN_SUPPORTED_NUM_BITS = [4, 8] MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # In case there is a performance issue with Marlin, the variable below can be @@ -22,76 +22,70 @@ USE_FP32_REDUCE_DEFAULT = True -def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, - min_capability: Optional[int], - has_zp: bool) -> Tuple[bool, Optional[str]]: - if min_capability is not None: +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. +# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types(has_zp: bool, + min_capability: Optional[int] = None): + if min_capability is None: major, minor = current_platform.get_device_capability() - device_capability = major * 10 + minor - if device_capability < min_capability: - return (False, "Marlin does not support device_capability = {}" - ", the min_capability required is {}".format( - device_capability, min_capability)) - - if num_bits not in MARLIN_SUPPORTED_NUM_BITS: - return (False, "Marlin does not support weight_bits = {}. " - "Only weight_bits = {} are supported.".format( - num_bits, MARLIN_SUPPORTED_NUM_BITS)) - - if group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return (False, "Marlin does not support group_size = {}. 
Only " - "group_sizes = {} are supported.".format( - group_size, MARLIN_SUPPORTED_GROUP_SIZES)) - - if not has_zp and not is_sym: - return (False, - "Marlin without zero_points must have symmetric quantization") + min_capability = major * 10 + minor - return True, None + if min_capability < 80: + return [] + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] -def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool, - min_capability: int) -> bool: - cond, _ = _check_marlin_supported(num_bits, - group_size, - is_sym, - min_capability, - has_zp=False) - return cond +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: -def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool, - min_capability: int) -> bool: - cond, _ = _check_marlin_supported(num_bits, - group_size, - False, - min_capability, - has_zp=has_zp) - return cond + if min_capability is None: + major, minor = current_platform.get_device_capability() + min_capability = major * 10 + minor + supported_types = query_marlin_supported_quant_types( + has_zp, min_capability) -def verify_gptq_marlin_supported(num_bits: int, group_size: int, - is_sym: bool) -> None: - cond, err_msg = _check_marlin_supported(num_bits, - group_size, - is_sym, - min_capability=None, - has_zp=False) - if not cond: - assert err_msg is not None - raise ValueError("GPTQ" + err_msg) + if quant_type not in supported_types: + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"min_capability = {min_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + min_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + min_capability) + return cond -def verify_awq_marlin_supported(num_bits: int, group_size: int, - has_zp: bool) -> None: - cond, err_msg = _check_marlin_supported(num_bits, - group_size, - False, - min_capability=None, - has_zp=has_zp) +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None - raise ValueError("AWQ" + err_msg) + raise ValueError(err_msg) def verify_marlin_supports_shape(output_size_per_partition: int, @@ -245,7 +239,7 @@ def apply_gptq_marlin_linear( g_idx: torch.Tensor, g_idx_sort_indices: torch.Tensor, workspace: torch.Tensor, - num_bits: int, + wtype: ScalarType, output_size_per_partition: int, input_size_per_partition: int, is_k_full: bool, @@ -261,7 +255,7 @@ def apply_gptq_marlin_linear( g_idx, g_idx_sort_indices, workspace, - num_bits, + wtype, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, @@ -283,7 +277,7 @@ def apply_awq_marlin_linear( g_idx: torch.Tensor, g_idx_sort_indices: torch.Tensor, workspace: torch.Tensor, - num_bits: int, + quant_type: ScalarType, output_size_per_partition: int, input_size_per_partition: int, bias: Optional[torch.Tensor] = None, @@ -298,7 +292,7 @@ def apply_awq_marlin_linear( g_idx, g_idx_sort_indices, workspace, - num_bits, + quant_type, size_m=reshaped_x.shape[0], size_n=output_size_per_partition, size_k=input_size_per_partition, diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 541d148c761fc..7d08ac6f87469 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -5,10 +5,12 @@ import numpy as np import torch +from vllm.scalar_type import ScalarType + from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points) -from .quant_utils import (get_pack_factor, quantize_weights, - quantize_weights_with_zp, sort_weights) +from .quant_utils import (get_pack_factor, gptq_quantize_weights, + quantize_weights, sort_weights) class MarlinWorkspace: @@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, +def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, act_order: bool): size_k, size_n = w.shape + num_bits = quant_type.size_bits # Normalize group_size if group_size == -1: @@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, assert group_size <= size_k # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, - act_order) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing @@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int, return res_list -def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): +def 
awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, + group_size: int): size_k, size_n = w.shape # Normalize group_size @@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int): num_groups = size_k // group_size # Quantize with zp - w_ref, q_w, s, zp = quantize_weights_with_zp(w, num_bits, group_size) + w_ref, q_w, s, zp = quantize_weights(w, + quant_type, + group_size, + zero_points=True) # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, num_bits) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, + quant_type.size_bits) # Create result res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py index 648c32249a571..17d09055b1eac 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -6,8 +6,10 @@ import numpy import torch +from vllm.scalar_type import ScalarType + from .marlin_utils_test import marlin_weights -from .quant_utils import quantize_weights +from .quant_utils import gptq_quantize_weights # This is PyTorch implementation of main part of reorder_meta() @@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False): print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") -def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): assert q_24.shape == (size_k, size_n) - # Remove zp to normalize over 0 - max_q_val = (1 << num_bits) - 1 - zp = (max_q_val + 1) // 2 - q_24_no_zp = q_24 - zp + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias # Compress q_24_no_zp = q_24_no_zp.t().contiguous() @@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): q_24_no_zp) q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - # Restore zp - q_24_comp = q_24_no_zp_comp + zp + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias # Resize meta to its actual shape (without moving any data) meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) @@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, def marlin_24_quantize( w: torch.Tensor, - num_bits: int, + quant_type: ScalarType, group_size: int, ): size_k, size_n = w.shape @@ -441,20 +441,18 @@ def marlin_24_quantize( w_24, mask_24 = inject_24(w, size_k, size_n) # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, - num_bits, - group_size, - act_order=False) + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False) # Compress quantized weight q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, - num_bits) + quant_type) size_k_comp = size_k // 2 # Reformat to marlin - weight_perm = get_weight_perm_24(num_bits) + weight_perm = get_weight_perm_24(quant_type.size_bits) marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, - num_bits, weight_perm) + quant_type.size_bits, weight_perm) marlin_24_s 
= marlin_permute_scales_24(s, size_k, size_n, group_size) # Create result diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 7ade8bf664ccc..7f9081b257705 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -4,7 +4,11 @@ import numpy import torch -SUPPORTED_NUM_BITS = [4, 8] +from vllm.model_executor.layers.quantization.qqq import ( + MARLIN_QQQ_SUPPORTED_NUM_BITS) +from vllm.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] # Note: this is a hack. We should update each model to register the @@ -45,7 +49,7 @@ def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: def get_pack_factor(num_bits): - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits @@ -74,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): ) -def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, - act_order: bool): +def quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + zero_points: bool = False): + assert quant_type.is_integer(), \ + "Floating point quantization may work but has not been tested" + orig_device = w.device + orig_type = w.dtype size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" if group_size == -1: group_size = size_k assert group_size <= size_k - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - # Reshape to [groupsize, -1] if group_size < size_k: w = w.reshape((-1, group_size, size_n)) @@ -99,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, w = w.reshape((group_size, -1)) # Compute scale for each group - s = torch.max(torch.abs(w), 0, keepdim=True)[0] - s *= 2 / max_q_val # 2 => symmetric + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + maybe_w_zp = None # Quantize - q_w = torch.round(w / s).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias # Restore original shapes if group_size < size_k: @@ -119,90 +140,48 @@ def reshape_w(w): w = w.reshape((size_k, size_n)).contiguous() return w - q_w = reshape_w(q_w) + w_q = 
reshape_w(w_q) w_ref = reshape_w(w_ref) - s = s.reshape((-1, size_n)).contiguous() + w_s = w_s.reshape((-1, size_n)).contiguous() - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k) - - w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + if zero_points: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) return ( w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), + w_q.to(device=orig_device), + w_s.to(device=orig_device), + maybe_w_zp, ) -def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape +def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, + group_size: int, act_order: bool): + size_k, _ = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + f"Unsupported gptq type = {quant_type}" assert group_size in SUPPORTED_GROUP_SIZES + [ size_k ], f"Unsupported groupsize = {group_size}" - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - max_q_val = 2**num_bits - 1 - min_q_val = 0 + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - # Reshape to [groupsize, -1] - if group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max = torch.max(w, 0, keepdim=True)[0] - min = torch.min(w, 0, keepdim=True)[0] - s = (max - min).clamp(min=1e-5) / max_q_val - - # Compute zero-point for each group - zp = (-torch.round(min / s)).clamp(min_q_val, max_q_val).int() - - # Quantize - q_w = torch.round(w / s).int() + zp - q_w = torch.clamp(q_w, min_q_val, max_q_val) - - # Compute ref (dequantized) - w_ref = (q_w - zp).half() * s - - # Restore original shapes - if group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) - s = s.reshape((-1, size_n)).contiguous() - zp = zp.reshape((-1, size_n)).contiguous() + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s.to(device=orig_device), - zp.to(device=orig_device), - ) + return w_ref, w_q, w_s, g_idx, rand_perm # QQQ employs different quant schemes for per-group and @@ -212,7 +191,8 @@ def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): size_k, size_n = w.shape assert w.is_floating_point(), "w must be float" - assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ + f"Unsupported num_bits = {num_bits}" assert group_size in SUPPORTED_GROUP_SIZES + [ size_k ], 
f"Unsupported groupsize = {group_size}" diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py new file mode 100644 index 0000000000000..eb491dd1554a8 --- /dev/null +++ b/vllm/scalar_type.py @@ -0,0 +1,35 @@ +from ._core_ext import NanRepr, ScalarType + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, + NanRepr.EXTD_RANGE_MAX_MIN.value) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) + + # "gptq" types + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 From 05308891e203329a733bcf29a3452b15b75b5eb4 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:55:40 -0700 Subject: [PATCH 03/36] [Core] Pipeline parallel with Ray ADAG (#6837) Support pipeline-parallelism with Ray accelerated DAG. Signed-off-by: Rui Qiao --- Dockerfile | 2 + MANIFEST.in | 1 + requirements-adag.txt | 3 + requirements-test.txt | 3 + tests/distributed/test_pipeline_parallel.py | 51 +++++--- tests/utils.py | 31 ++++- vllm/envs.py | 12 +- vllm/executor/ray_gpu_executor.py | 137 +++++++++++++------- vllm/executor/ray_utils.py | 30 ++++- vllm/worker/worker_base.py | 6 +- 10 files changed, 199 insertions(+), 77 deletions(-) create mode 100644 requirements-adag.txt diff --git a/Dockerfile b/Dockerfile index 7294707046abc..49aaea2949ac6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,6 +42,7 @@ WORKDIR /workspace # install build and runtime dependencies COPY requirements-common.txt requirements-common.txt +COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt @@ -78,6 +79,7 @@ COPY setup.py setup.py COPY cmake cmake COPY CMakeLists.txt CMakeLists.txt COPY requirements-common.txt requirements-common.txt +COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt COPY pyproject.toml pyproject.toml COPY vllm vllm diff --git a/MANIFEST.in b/MANIFEST.in index 82be639ef4d73..5a41e5e714184 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include LICENSE +include requirements-adag.txt include requirements-common.txt include requirements-cuda.txt include requirements-rocm.txt diff --git a/requirements-adag.txt b/requirements-adag.txt new file mode 100644 index 0000000000000..e77f90fb8f85d --- /dev/null +++ b/requirements-adag.txt @@ -0,0 +1,3 @@ +# Dependencies for Ray accelerated DAG +cupy-cuda12x +ray >= 2.32 \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index df247496be16c..5f3fd15c7ee56 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,3 +1,6 @@ +# Needed 
for Ray accelerated DAG tests +-r requirements-adag.txt + # testing pytest tensorizer>=2.9.0 diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index f632caba9017e..ab325e0966929 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -15,22 +15,31 @@ @pytest.mark.parametrize( - "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND", - [ - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - ]) -@fork_new_process_for_each_test + ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " + "MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [ + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), + ]) def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, - DIST_BACKEND): + DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") @@ -67,8 +76,18 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, if EAGER_MODE: pp_args.append("--enforce-eager") tp_args.append("--enforce-eager") + pp_env = None + if USE_RAY_ADAG: + assert DIST_BACKEND == "ray", ( + "Ray ADAG is only supported with Ray distributed backend") + pp_env = { + "VLLM_USE_RAY_COMPILED_DAG": "1", + "VLLM_USE_RAY_SPMD_WORKER": "1", + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": + str(int(USE_RAY_ADAG_NCCL)), + } - compare_two_settings(MODEL_NAME, pp_args, tp_args) + compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env) @pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [ diff --git a/tests/utils.py b/tests/utils.py index f3ee801ee7742..dd8af8e3afe70 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,7 @@ import 
warnings from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import openai import ray @@ -57,6 +57,7 @@ def __init__( model: str, cli_args: List[str], *, + env_dict: Optional[Dict[str, str]] = None, auto_port: bool = True, ) -> None: if auto_port: @@ -77,6 +78,8 @@ def __init__( # the current process might initialize cuda, # to be safe, we should use spawn method env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + if env_dict is not None: + env.update(env_dict) self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args, env=env, stdout=sys.stdout, @@ -89,6 +92,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.proc.terminate() + try: + self.proc.wait(3) + except subprocess.TimeoutExpired: + # force kill if needed + self.proc.kill() def _wait_for_server(self, *, url: str, timeout: float): # run health check @@ -127,10 +135,21 @@ def get_async_client(self): ) -def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): +def compare_two_settings(model: str, + arg1: List[str], + arg2: List[str], + env1: Optional[Dict[str, str]] = None, + env2: Optional[Dict[str, str]] = None): """ - Launch API server with two different sets of arguments and compare the - results of the API calls. The arguments are after the model name. + Launch API server with two different sets of arguments/environments + and compare the results of the API calls. + + Args: + model: The model to test. + arg1: The first set of arguments to pass to the API server. + arg2: The second set of arguments to pass to the API server. + env1: The first set of environment variables to pass to the API server. + env2: The second set of environment variables to pass to the API server. """ tokenizer = AutoTokenizer.from_pretrained(model) @@ -138,8 +157,8 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]): prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] results = [] - for args in (arg1, arg2): - with RemoteOpenAIServer(model, args) as server: + for args, env in ((arg1, env1), (arg2, env2)): + with RemoteOpenAIServer(model, args, env_dict=env) as server: client = server.get_client() # test models list diff --git a/vllm/envs.py b/vllm/envs.py index 9bcb26f8e5a64..5b8a65bd6545c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -38,6 +38,7 @@ VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 @@ -273,13 +274,20 @@ def get_default_config_root(): # execution on all workers. # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. "VLLM_USE_RAY_SPMD_WORKER": - lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)), + lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. "VLLM_USE_RAY_COMPILED_DAG": - lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)), + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), + + # If the env var is set, it uses NCCL for communication in + # Ray's compiled DAG. This flag is ignored if + # VLLM_USE_RAY_COMPILED_DAG is not set. 
+ "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1")) + ), # Use dedicated multiprocess context for workers. # Both spawn and fork work diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 14007e5518d4a..46d216910a08a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -105,12 +105,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The remaining workers are the actual ray actors. self.workers: List[RayWorkerWrapper] = [] + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( ray_remote_kwargs) + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) # Create the workers. driver_ip = get_ip() + logger.info("driver_ip: %s", driver_ip) worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): @@ -142,42 +149,49 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Else, added to the list of workers. self.workers.append(worker) + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " "GPU node.") + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + # Get the set of GPU IDs used on each node. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) - # the order in `worker_node_and_gpu_ids` does not necessarily match - # the machine boundaries. We need to make sure that workers in the - # same node are assigned consecutive ranks. 
- # examples: - # [('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [1]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [2]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [3]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [1]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [2]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [3])] # noqa - - # initialize worker ranks with -1 (unassigned) - worker_ranks = [-1 for x in worker_node_and_gpu_ids] - current_rank = 0 - while -1 in worker_ranks: - # whenever we find an unassigned worker, find the node - index = worker_ranks.index(-1) - current_node_id = worker_node_and_gpu_ids[index][0] - # assign ranks to all workers in the same node - for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): - if node_id == current_node_id: - worker_ranks[i] = current_rank - current_rank += 1 - # with the above example, worker_ranks will be [0, 4, 5, 6, 7, 1, 2, 3] - node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids - for worker_rank, (node_id, gpu_ids) in zip(worker_ranks, - worker_node_and_gpu_ids): - node_workers[node_id].append(worker_rank) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) # `gpu_ids` can be a list of strings or integers. # convert them to integers for consistency. # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), @@ -202,16 +216,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) - if len(node_gpus) == 1: - # in single node case, we don't need to get the IP address. - # the loopback address is sufficient - # NOTE: a node may have several IP addresses, one for each - # network interface. `get_ip()` might return any of them, - # while they might not work for communication inside the node - # if the network setup is complicated. Using the loopback address - # solves this issue, as it always works for communication inside - # the node. - driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) @@ -221,8 +225,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank=node_workers[node_id].index(rank), rank=rank, distributed_init_method=distributed_init_method, - ) for rank, (node_id, - _) in zip(worker_ranks, worker_node_and_gpu_ids) + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) ] self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) @@ -231,6 +234,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + # This is the list of workers that are rank 0 of each TP group EXCEPT # global rank 0. 
These are the workers that will broadcast to the # rest of the workers. @@ -241,9 +257,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self.non_driver_workers: List[RayWorkerWrapper] = [] # Enforce rank order for correct rank to return final output. - for rank, worker in sorted(zip(worker_ranks[1:], self.workers)): - # We need to skip the driver worker, which we - # do by skipping worker_ranks[0] which is always 0. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 if rank % self.parallel_config.tensor_parallel_size == 0: self.tp_driver_workers.append(worker) else: @@ -376,16 +392,47 @@ def _compiled_ray_dag(self, enable_asyncio: bool): raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") - from ray.dag import InputNode, MultiOutputNode assert self.parallel_config.use_ray + from ray.dag import InputNode, MultiOutputNode + from ray.experimental.channel.torch_tensor_type import TorchTensorType - # Right now, compiled DAG requires at least 1 arg. We send - # a dummy value for now. It will be fixed soon. + logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) with InputNode() as input_data: - forward_dag = MultiOutputNode([ - worker.execute_model_spmd.bind( # type: ignore[attr-defined] - input_data) for worker in self.workers - ]) + # Example DAG: PP=2, TP=4 + # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 + # -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501 + # -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501 + # -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501 + + # All workers in the first TP group will take in the + # ExecuteModelRequest as input. + outputs = [input_data for _ in self.pp_tp_workers[0]] + for pp_rank, tp_group in enumerate(self.pp_tp_workers): + # Each PP worker takes in the output of the previous PP worker, + # and the TP group executes in SPMD fashion. + outputs = [ + worker.execute_model_spmd. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + + last_pp_rank = len(self.pp_tp_workers) - 1 + if pp_rank < last_pp_rank: + # Specify how intermediate tensors should be passed + # between pp stages, no need to specify for the last + # pp stage. 
+ transport = "nccl" \ + if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \ + else "auto" + outputs = [ + output.with_type_hint( + TorchTensorType(transport=transport)) + for output in outputs + ] + + forward_dag = MultiOutputNode(outputs) + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) def __del__(self): diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 58b864070f727..ac948331e81e0 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,8 +1,8 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest +from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import get_ip, is_hip, is_tpu, is_xpu from vllm.worker.worker_base import WorkerWrapperBase @@ -31,9 +31,17 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): - """Used only when SPMD worker and compiled DAG are both - enabled.""" + def execute_model_spmd( + self, req_or_tuple: Union[ExecuteModelRequest, + Tuple[ExecuteModelRequest, + IntermediateTensors]]): + """Execute model in SPMD fashion: used only when SPMD worker and + compiled DAG are both enabled. + + Args: + req_or_tuple: The request to execute the model, or a tuple + containing the request and intermediate tensors. + """ # TODO(swang): This is needed right now because Ray aDAG executes # on a background thread, so we need to reset torch's current # device. @@ -42,7 +50,17 @@ def execute_model_spmd(self, execute_model_req: ExecuteModelRequest): torch.cuda.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - return self.worker._execute_model_spmd(execute_model_req) + if isinstance(req_or_tuple, tuple): + execute_model_req, intermediate_tensors = req_or_tuple + else: + execute_model_req = req_or_tuple + intermediate_tensors = None + + output = self.worker._execute_model_spmd(execute_model_req, + intermediate_tensors) + if isinstance(output, IntermediateTensors): + return execute_model_req, output + return output ray_import_err = None diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 8a4d1958c65a0..e56440693b895 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -285,7 +285,9 @@ def execute_model( return output def _execute_model_spmd( - self, execute_model_req: ExecuteModelRequest + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None ) -> Optional[List[SamplerOutput]]: """ Execute model in Single Program Multiple Data (SPMD) fashion. 
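For reference, a hedged sketch of opting into the new path from a launcher script; the flag names come from vllm/envs.py above and mirror pp_env in the updated pipeline-parallel test, and they must be set before the engine creates its Ray workers (for example before invoking `vllm serve ... --distributed-executor-backend ray`):

    import os

    # Mirrors pp_env in tests/distributed/test_pipeline_parallel.py above.
    os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"                 # SPMD-style worker execution
    os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"                # Ray accelerated (compiled) DAG
    os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "1"   # NCCL channel between PP stages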
@@ -309,7 +311,7 @@ def _execute_model_spmd( return self.model_runner.execute_model( model_input, self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None) + if self.kv_cache is not None else None, intermediate_tensors) class WorkerWrapperBase: From 22e718ff1a51930231d87c89d6c43676af59860b Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:50:00 -0700 Subject: [PATCH 04/36] [Misc] Revive to use loopback address for driver IP (#7091) Signed-off-by: Rui Qiao --- vllm/executor/ray_gpu_executor.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 46d216910a08a..4a6825c01fcf8 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -216,6 +216,16 @@ def sort_by_driver_then_worker_ip(worker): self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) From 708989341ef6361a5981d890a0e2f1b794323458 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 16:18:45 -0700 Subject: [PATCH 05/36] [misc] add a flag to enable compile (#7092) --- vllm/envs.py | 4 ++++ vllm/worker/model_runner.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index 5b8a65bd6545c..595058bcbb027 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -174,6 +174,10 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), + # Internal flag to enable Dynamo graph capture + "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": + lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7773442899585..f9c26e0c318b1 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,6 +23,7 @@ BatchPrefillWithPagedKVCacheWrapper = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 +import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, @@ -786,6 +787,11 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. 
" "This may lead to less accurate results!") + if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE: + self.model = torch.compile(self.model, + fullgraph=True, + backend="eager") + def save_sharded_state( self, path: str, From ed812a73fae77bb520b739cfeaad36dbd61e2b03 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 2 Aug 2024 21:27:28 -0400 Subject: [PATCH 06/36] [ Frontend ] Multiprocessing for OpenAI Server with `zeromq` (#6883) Signed-off-by: Joe Runde Co-authored-by: Joe Runde Co-authored-by: Joe Runde Co-authored-by: Nick Hill Co-authored-by: Simon Mo --- tests/entrypoints/openai/test_disable_mp.py | 715 ++++++++++++++++++ vllm/engine/async_llm_engine.py | 27 +- vllm/engine/llm_engine.py | 36 +- vllm/engine/protocol.py | 84 ++ vllm/entrypoints/openai/api_server.py | 132 +++- vllm/entrypoints/openai/cli_args.py | 9 +- vllm/entrypoints/openai/logits_processors.py | 19 +- vllm/entrypoints/openai/rpc/__init__.py | 42 + vllm/entrypoints/openai/rpc/client.py | 248 ++++++ vllm/entrypoints/openai/rpc/server.py | 216 ++++++ vllm/entrypoints/openai/serving_chat.py | 16 +- vllm/entrypoints/openai/serving_completion.py | 19 +- vllm/entrypoints/openai/serving_embedding.py | 13 +- vllm/entrypoints/openai/serving_engine.py | 8 +- .../openai/serving_tokenization.py | 10 +- vllm/envs.py | 6 + .../outlines_logits_processors.py | 19 + vllm/tracing.py | 2 +- .../tokenizer_group/__init__.py | 19 +- vllm/utils.py | 28 +- 20 files changed, 1567 insertions(+), 101 deletions(-) create mode 100644 tests/entrypoints/openai/test_disable_mp.py create mode 100644 vllm/engine/protocol.py create mode 100644 vllm/entrypoints/openai/rpc/__init__.py create mode 100644 vllm/entrypoints/openai/rpc/client.py create mode 100644 vllm/entrypoints/openai/rpc/server.py diff --git a/tests/entrypoints/openai/test_disable_mp.py b/tests/entrypoints/openai/test_disable_mp.py new file mode 100644 index 0000000000000..12c805413311c --- /dev/null +++ b/tests/entrypoints/openai/test_disable_mp.py @@ -0,0 +1,715 @@ +""" +Repeat of tests in test_completion.py with the non-mp backend. 
+""" + +# imports for guided decoding tests +import json +import re +import shutil +from tempfile import TemporaryDirectory +from typing import List + +import jsonschema +import openai # use the official client for correctness check +import pytest +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import AutoTokenizer + +from vllm.transformers_utils.tokenizer import get_tokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +# technically these adapters use a different base model, +# but we're not testing generation quality here +LORA_NAME = "typeof/zephyr-7b-beta-lora" +PA_NAME = "swapnilbp/llama_tweet_ptune" +# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also +# need to change to match the prompt adapter +PA_NUM_VIRTUAL_TOKENS = 8 + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def zephyr_pa_files(): + return snapshot_download(repo_id=PA_NAME) + + +@pytest.fixture(scope="module") +def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, + zephyr_pa_files): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # lora config + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_added_tokens_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + # pa config + "--enable-prompt-adapter", + "--prompt-adapters", + f"zephyr-pa={zephyr_pa_files}", + f"zephyr-pa2={zephyr_pa_files}", + "--max-prompt-adapters", + "2", + "--max-prompt-adapter-token", + "128", + "--disable-frontend-multiprocessing" + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, then test prompt adapters + "model_name,num_virtual_tokens", + [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), + ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), + ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], +) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, + num_virtual_tokens: int): + completion = await client.completions.create(model=model_name, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + assert len(choice.text) >= 5 + assert choice.finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( 
+ completion_tokens=5, + prompt_tokens=6 + num_virtual_tokens, + total_tokens=11 + num_virtual_tokens) + + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 1 + + +@pytest.mark.asyncio +async def test_added_lora_tokens(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model="zephyr-lora2", + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should appear in tokenized prompt + assert completion.choices[0].text.startswith("vllm1vllm2vllm3") + + +@pytest.mark.asyncio +async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 32000, 32001, 32002], + echo=True, + max_tokens=5, + temperature=0.0, + ) + # Added tokens should not appear in tokenized prompt + assert "vllm" not in completion.choices[0].text + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras, then test prompt adapters + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], +) +async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # just test 1 lora and 1 pa hereafter + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=0, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): + # test using token IDs + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, + model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=21, + ) + ... 
+ with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + # vLLM has higher default max_logprobs (20 instead of 5) to support + # both Completion API and Chat Completion API + logprobs=30, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert len(completion.choices[0].text) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_streaming(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is an LLM?" + + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: List[str] = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_completion_stream_options(client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" 
+ + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + False, + }) + + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": False, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": False} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + False, + }) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=True, stream_options= + # {"include_usage": True, "continuous_usage_stats": True} + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": + True, + }) + async for chunk in stream: + assert chunk.usage is not None + assert chunk.usage.prompt_tokens > 0 + assert chunk.usage.completion_tokens > 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if chunk.choices[0].finish_reason is not None: + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options= + # {"include_usage": None} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options= + # {"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": None} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"continuous_usage_stats": None}) + + # Test stream=False, stream_options= + # {"continuous_usage_stats": True} + with pytest.raises(BadRequestError): + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + 
stream_options={"continuous_usage_stats": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): + # test both text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but not necessary + # for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + +@pytest.mark.asyncio +async def test_logits_bias(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 5 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + token_id = 1000 + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token_id): 100}, + seed=42, + ) + assert len(completion.choices[0].text) >= 5 + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + expected_tokens = tokenizer(tokenizer.decode([token_id] * 5), + add_special_tokens=False)["input_ids"] + assert all([ + response == expected + for response, expected in zip(response_tokens, expected_tokens) + ]) + + # Test ban + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + ) + response_tokens = tokenizer(completion.choices[0].text, + add_special_tokens=False)["input_ids"] + first_response = completion.choices[0].text + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + logit_bias={str(token): -100 + for token in response_tokens}, + ) + assert first_response != completion.choices[0].text + + +@pytest.mark.asyncio +async def test_allowed_token_ids(client: openai.AsyncOpenAI): + prompt = "Hello, my name is" + max_tokens = 1 + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + + # Test exclusive selection + allowed_ids = [21555, 21557, 21558] + completion = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=max_tokens, + temperature=0.0, + seed=42, + extra_body=dict(allowed_token_ids=allowed_ids), + logprobs=1, + ) + response_tokens = completion.choices[0].logprobs.tokens + assert len(response_tokens) == 1 + assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}", + n=3, + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_json=sample_json_schema, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + output_json = json.loads(completion.choices[i].text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_regex): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example IPv4 address with this regex: {sample_regex}", + n=3, + temperature=1.0, + max_tokens=20, + extra_body=dict(guided_regex=sample_regex, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 3 + for i in range(3): + assert re.fullmatch(sample_regex, + completion.choices[i].text) is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_guided_choice): + completion = await client.completions.create( + model=MODEL_NAME, + prompt="The best language for type-safe systems programming is ", + n=2, + temperature=1.0, + max_tokens=10, + extra_body=dict(guided_choice=sample_guided_choice, + guided_decoding_backend=guided_decoding_backend)) + + assert completion.id is not None + assert len(completion.choices) == 2 + for i in range(2): + assert completion.choices[i].text in sample_guided_choice + + +@pytest.mark.asyncio +async def test_guided_grammar(client: openai.AsyncOpenAI, + sample_sql_statements): + + completion = await client.completions.create( + model=MODEL_NAME, + prompt=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + temperature=1.0, + max_tokens=500, + extra_body=dict(guided_grammar=sample_sql_statements)) + + content = completion.choices[0].text + + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(content) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") + + assert content.strip() == ground_truth + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +async def test_echo_logprob_completion(client: openai.AsyncOpenAI, + model_name: str, logprobs_arg: int): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg) + + prompt_text = 
tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert re.search(r"^" + prompt_text, completion.choices[0].text) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) > 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, + guided_decoding_backend: str, + sample_json_schema, sample_regex): + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON that fits this schema: 42", + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) + + with pytest.raises(openai.BadRequestError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example string that fits this regex", + extra_body=dict(guided_regex=sample_regex, + guided_json=sample_json_schema)) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d3f9a0ab00f10..c39caca25cc7a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,8 @@ from transformers import PreTrainedTokenizer import vllm.envs as envs -from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -928,6 +929,14 @@ async def get_model_config(self) -> ModelConfig: else: return self.engine.get_model_config() + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_parallel_config.remote( # type: ignore + ) + else: + return self.engine.get_parallel_config() + async def get_decoding_config(self) -> DecodingConfig: """Get the decoding configuration of the vLLM engine.""" if self.engine_use_ray: @@ -936,6 +945,22 @@ async def get_decoding_config(self) -> DecodingConfig: else: return self.engine.get_decoding_config() + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_scheduler_config.remote( # type: ignore + ) + else: + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_lora_config.remote( # type: ignore + ) + else: + return self.engine.get_lora_config() + async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1efe2206abe81..3747f93b16cd1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -38,9 +38,8 @@ init_tracer) from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer -from 
vllm.transformers_utils.tokenizer_group import (AnyTokenizer, - BaseTokenizerGroup, - get_tokenizer_group) +from vllm.transformers_utils.tokenizer_group import ( + AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -485,19 +484,12 @@ def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: - init_kwargs = dict( - tokenizer_id=self.model_config.tokenizer, - enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code, - revision=self.model_config.tokenizer_revision) - init_kwargs.update(tokenizer_init_kwargs) - - return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, - **init_kwargs) + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config)) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -759,10 +751,22 @@ def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + def get_decoding_config(self) -> DecodingConfig: """Gets the decoding configuration.""" return self.decoding_config + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" return sum(scheduler.get_num_unfinished_seq_groups() diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py new file mode 100644 index 0000000000000..fc94ef6662e0a --- /dev/null +++ b/vllm/engine/protocol.py @@ -0,0 +1,84 @@ +from typing import (AsyncIterator, List, Mapping, Optional, Protocol, + runtime_checkable) + +from transformers import PreTrainedTokenizer + +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.inputs.data import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput + + +@runtime_checkable +class AsyncEngineClient(Protocol): + """Protocol class for Clients to AsyncLLMEngine""" + + @property + def is_running(self) -> bool: + ... + + @property + def is_stopped(self) -> bool: + ... + + @property + def errored(self) -> bool: + ... 
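# --- Illustrative sketch, not part of this patch ---------------------------
# AsyncEngineClient is declared as a runtime_checkable Protocol so that any
# object exposing the right attributes satisfies it structurally: the
# in-process AsyncLLMEngine and the new RPC client can be swapped without a
# shared base class. The toy Pingable protocol and client classes below are
# hypothetical; note that isinstance() only checks attribute presence, not
# signatures.
from typing import Protocol, runtime_checkable


@runtime_checkable
class Pingable(Protocol):

    @property
    def is_running(self) -> bool:
        ...

    async def check_health(self) -> None:
        ...


class InProcessClient:

    @property
    def is_running(self) -> bool:
        return True

    async def check_health(self) -> None:
        return None


class RemoteClient:
    is_running = True  # a plain attribute also satisfies the protocol

    async def check_health(self) -> None:
        return None


assert isinstance(InProcessClient(), Pingable)
assert isinstance(RemoteClient(), Pingable)
# ---------------------------------------------------------------------------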
+ + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncIterator[RequestOutput]: + """Generates outputs for a request""" + + async def encode( + self, + inputs: PromptInputs, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + ) -> AsyncIterator[EmbeddingRequestOutput]: + """Generate outputs for a request from an embedding model.""" + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Args: + request_id: The unique id of the request. + """ + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> PreTrainedTokenizer: + """Get the appropriate Tokenizer for the request""" + + async def is_tracing_enabled(self) -> bool: + pass + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + pass + + async def check_health(self) -> None: + """Raise if unhealthy""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0fe4dd245b5e6..e330ee81f7e44 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -5,7 +5,8 @@ import signal from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Optional, Set +from multiprocessing import Process +from typing import AsyncIterator, Set import fastapi import uvicorn @@ -17,8 +18,10 @@ from starlette.routing import Mount import vllm.envs as envs +from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -31,6 +34,8 @@ EmbeddingRequest, ErrorResponse, TokenizeRequest, TokenizeResponse) +from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient +from vllm.entrypoints.openai.rpc.server import run_rpc_server # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -39,12 +44,12 @@ OpenAIServingTokenization) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, get_open_port from vllm.version import __version__ as VLLM_VERSION TIMEOUT_KEEP_ALIVE = 5 # seconds -engine: AsyncLLMEngine +async_engine_client: AsyncEngineClient engine_args: AsyncEngineArgs openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion @@ -56,13 +61,22 @@ _running_tasks: Set[asyncio.Task] = set() +def model_is_embedding(model_name: str) -> bool: + return ModelConfig(model=model_name, + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16").embedding_mode + + @asynccontextmanager 
async def lifespan(app: fastapi.FastAPI): async def _force_log(): while True: await asyncio.sleep(10) - await engine.do_log_stats() + await async_engine_client.do_log_stats() if not engine_args.disable_log_stats: task = asyncio.create_task(_force_log()) @@ -72,6 +86,52 @@ async def _force_log(): yield +@asynccontextmanager +async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: + # Context manager to handle async_engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + global engine_args + engine_args = AsyncEngineArgs.from_cli_args(args) + + # Backend itself still global for the silly lil' health handler + global async_engine_client + + # If manually triggered or embedding model, use AsyncLLMEngine in process. + # TODO: support embedding model via RPC. + if (model_is_embedding(args.model) + or args.disable_frontend_multiprocessing): + async_engine_client = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + yield async_engine_client + return + + # Otherwise, use the multiprocessing AsyncLLMEngine. + else: + # Start RPCServer in separate process (holds the AsyncLLMEngine). + port = get_open_port(envs.VLLM_RPC_PORT) + rpc_server_process = Process(target=run_rpc_server, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + port)) + rpc_server_process.start() + + # Build RPCClient, which conforms to AsyncEngineClient Protocol. + async_engine_client = AsyncEngineRPCClient(port) + await async_engine_client.setup() + + try: + yield async_engine_client + finally: + # Ensure rpc server process was terminated + rpc_server_process.terminate() + + # Close all open connections to the backend + async_engine_client.close() + + # Wait for server process to join + rpc_server_process.join() + + router = APIRouter() @@ -86,7 +146,7 @@ def mount_metrics(app: fastapi.FastAPI): @router.get("/health") async def health() -> Response: """Health check.""" - await openai_serving_chat.engine.check_health() + await async_engine_client.check_health() return Response(status_code=200) @@ -215,8 +275,8 @@ async def authentication(request: Request, call_next): async def build_server( + async_engine_client: AsyncEngineClient, args, - llm_engine: Optional[AsyncLLMEngine] = None, **uvicorn_kwargs, ) -> uvicorn.Server: app = build_app(args) @@ -226,14 +286,7 @@ async def build_server( else: served_model_names = [args.model] - global engine, engine_args - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = (llm_engine - if llm_engine is not None else AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER)) - - model_config = await engine.get_model_config() + model_config = await async_engine_client.get_model_config() if args.disable_log_requests: request_logger = None @@ -246,7 +299,7 @@ async def build_server( global openai_serving_tokenization openai_serving_chat = OpenAIServingChat( - engine, + async_engine_client, model_config, served_model_names, args.response_role, @@ -257,7 +310,7 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_completion = OpenAIServingCompletion( - engine, + async_engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -266,13 +319,13 @@ async def build_server( return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) openai_serving_embedding = OpenAIServingEmbedding( - engine, + async_engine_client, model_config, served_model_names, request_logger=request_logger, 
) openai_serving_tokenization = OpenAIServingTokenization( - engine, + async_engine_client, model_config, served_model_names, lora_modules=args.lora_modules, @@ -304,32 +357,39 @@ async def build_server( return uvicorn.Server(config) -async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None: +async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - server = await build_server( - args, - llm_engine, - **uvicorn_kwargs, - ) + shutdown_task = None + async with build_async_engine_client(args) as async_engine_client: + + server = await build_server( + async_engine_client, + args, + **uvicorn_kwargs, + ) + + loop = asyncio.get_running_loop() - loop = asyncio.get_running_loop() + server_task = loop.create_task(server.serve()) - server_task = loop.create_task(server.serve()) + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() - def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - server_task.cancel() + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) + try: + await server_task + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + shutdown_task = server.shutdown() - try: - await server_task - except asyncio.CancelledError: - print("Gracefully stopping http server") - await server.shutdown() + if shutdown_task: + # NB: Await server shutdown only after the backend context is exited + await shutdown_task if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a4192937980f7..1facedac72ca8 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -131,9 +131,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--return-tokens-as-token-ids", action="store_true", - help="When --max-logprobs is specified, represents single tokens as" - "strings of the form 'token_id:{token_id}' so that tokens that" + help="When --max-logprobs is specified, represents single tokens as " + "strings of the form 'token_id:{token_id}' so that tokens that " "are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index f8e04e7f18e0f..84871fc83ef5f 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -1,4 +1,4 @@ -from functools import lru_cache +from functools import lru_cache, partial from typing import Dict, FrozenSet, Iterable, List, Optional, Union import torch @@ -40,6 +40,14 @@ def _get_allowed_token_ids_logits_processor( return AllowedTokenIdsLogitsProcessor(allowed_token_ids) +def logit_bias_logits_processor(logit_bias: Dict[str, + float], token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + def get_logits_processors( logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], 
allowed_token_ids: Optional[List[int]], @@ -64,13 +72,8 @@ def get_logits_processors( raise ValueError("token_id in logit_bias contains " "out-of-vocab token id") - def logit_bias_logits_processor(token_ids: List[int], - logits: torch.Tensor) -> torch.Tensor: - for token_id, bias in clamped_logit_bias.items(): - logits[token_id] += bias - return logits - - logits_processors.append(logit_bias_logits_processor) + logits_processors.append( + partial(logit_bias_logits_processor, clamped_logit_bias)) if allowed_token_ids is not None: logits_processors.append( diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py new file mode 100644 index 0000000000000..8a7b12201cab7 --- /dev/null +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Mapping, Optional, Union + +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +VLLM_RPC_SUCCESS_STR = "SUCCESS" +VLLM_RPC_HEALTHY_STR = "HEALTHY" + + +@dataclass +class RPCGenerateRequest: + inputs: PromptInputs + sampling_params: SamplingParams + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCUtilityRequest(Enum): + IS_SERVER_READY = 1 + GET_MODEL_CONFIG = 2 + GET_DECODING_CONFIG = 3 + GET_PARALLEL_CONFIG = 4 + GET_SCHEDULER_CONFIG = 5 + GET_LORA_CONFIG = 6 + DO_LOG_STATS = 7 + CHECK_HEALTH = 8 + IS_TRACING_ENABLED = 9 + + +RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, + RPCUtilityRequest] diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py new file mode 100644 index 0000000000000..45bf88b5bf574 --- /dev/null +++ b/vllm/entrypoints/openai/rpc/client.py @@ -0,0 +1,248 @@ +from contextlib import contextmanager +from typing import Any, AsyncIterator, Mapping, Optional + +import cloudpickle +import zmq +import zmq.asyncio + +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, + VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) +from vllm.inputs import PromptInputs +from vllm.lora.request import LoRARequest +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + + +class AsyncEngineRPCClient: + + def __init__(self, port: int): + self.context = zmq.asyncio.Context() + self.path = f"tcp://localhost:{port}" + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Wait until server is ready. + await self.wait_for_server() + + # Get the configs. + self.model_config = await self._get_model_config_rpc() + self.decoding_config = await self._get_decoding_config_rpc() + self.tracing_flag = await self._is_tracing_enabled_rpc() + + # Create the tokenizer group. + # TODO: refactor OAI server to avoid needing this info. 
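# --- Illustrative sketch, not part of this patch ---------------------------
# The logits-processor change above replaces a nested closure with
# functools.partial over a module-level function. One motivation, sketched
# here with hypothetical helpers (add_bias, make_closure): a locally defined
# closure cannot be serialized by the standard pickle module, while a partial
# of a module-level function can, which matters once sampling parameters have
# to cross a process boundary.
import pickle
from functools import partial


def add_bias(bias, token_ids, logits):
    # Module-level function: picklable by reference.
    for token_id, value in bias.items():
        logits[token_id] += value
    return logits


def make_closure(bias):

    def processor(token_ids, logits):
        for token_id, value in bias.items():
            logits[token_id] += value
        return logits

    return processor


picklable = partial(add_bias, {0: 1.0})
print(len(pickle.dumps(picklable)))  # round-trips fine

try:
    pickle.dumps(make_closure({0: 1.0}))  # local function: not picklable
except (pickle.PicklingError, AttributeError) as exc:
    print(type(exc).__name__)
# ---------------------------------------------------------------------------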
+ self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self._get_scheduler_config_rpc()), + parallel_config=(await self._get_parallel_config_rpc()), + enable_lora=bool(await self._get_lora_config_rpc()), + ) + + def close(self): + """Destroy the ZeroMQ Context.""" + self.context.destroy() + + @contextmanager + def socket(self): + # Ensure client sockets are always closed after use + + # Connect to RPC socket for Request-Reply pattern, + # Note that we use DEALER to enable asynchronous communication + # to enable streaming. + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.path) + yield socket + finally: + socket.close() + + async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + expected_type: Any, + error_message: str) -> Any: + """Send an RPC request that is expecting data back.""" + + with self.socket() as socket: + + # Ping RPCServer with a request. + await socket.send(cloudpickle.dumps(request)) + + # Await the data from the Server. + data = cloudpickle.loads(await socket.recv()) + + if not isinstance(data, expected_type): + # LoRAConfig can be None. + if expected_type == LoRAConfig and data is None: + pass + else: + raise ValueError(error_message) + + return data + + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, + error_message: str): + """Send one-way RPC request to trigger an action.""" + with self.socket() as socket: + # Ping RPC Server with request. + await socket.send(cloudpickle.dumps(request)) + + # Await acknowledgement from RPCServer. + response = cloudpickle.loads(await socket.recv()) + + if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: + raise ValueError(error_message) + + return response + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def wait_for_server(self): + """Wait for the RPCServer to start up.""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.IS_SERVER_READY, + error_message="Unable to start RPC Server.") + + async def _get_model_config_rpc(self) -> ModelConfig: + """Get the ModelConfig object from the RPC Server""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_MODEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server") + + async def _get_decoding_config_rpc(self) -> DecodingConfig: + """Get DecodingConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_DECODING_CONFIG, + expected_type=DecodingConfig, + error_message="Could not get DecodingConfig from RPC Server") + + async def _get_parallel_config_rpc(self) -> ParallelConfig: + """Get ParallelConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_PARALLEL_CONFIG, + expected_type=ParallelConfig, + error_message="Could not get ParallelConfig from RPC Server") + + async def _get_scheduler_config_rpc(self) -> SchedulerConfig: + """Get SchedulerConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + expected_type=SchedulerConfig, + error_message="Could not get SchedulerConfig from RPC 
Server") + + async def _get_lora_config_rpc(self): + """Get LoRAConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_LORA_CONFIG, + expected_type=LoRAConfig, + error_message="Could not get LoRAConfig from RPC Server") + + async def _is_tracing_enabled_rpc(self) -> ParallelConfig: + """Get is_tracing_enabled flag from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.IS_TRACING_ENABLED, + expected_type=bool, + error_message="Could not get is_tracing_enabled flag from RPC " + "Server") + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), + error_message=f"RPCAbortRequest {request_id} failed") + + async def do_log_stats(self): + """Send a DO_LOG_STATS signal to the RPC Server""" + + await self._send_one_way_rpc_request( + request=RPCUtilityRequest.DO_LOG_STATS, + error_message="RPCRequest DO_LOG_STATS failed.") + + async def generate( + self, + inputs: PromptInputs, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None + ) -> AsyncIterator[RequestOutput]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + with self.socket() as socket: + + # Send RPCGenerateRequest to the RPCServer. + await socket.send_multipart([ + cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + ]) + + # Stream back the results from the RPC Server. + while True: + message = await socket.recv() + request_output = cloudpickle.loads(message) + + if isinstance(request_output, Exception): + raise request_output + + if request_output.finished: + break + yield request_output + + yield request_output + + async def check_health(self) -> None: + """Raise if unhealthy""" + + with self.socket() as socket: + + # Ping RPCServer with CHECK_HEALTH request. + await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH) + ) + + # Await the reply from the server. + # TODO: do we need an internal timeout here? + # Or do we expect the external probe to timeout and let this chill? 
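# --- Illustrative sketch, not part of this patch ---------------------------
# A compact, synchronous version of the DEALER/ROUTER exchange the client
# above performs for each request: the DEALER sends a pickled message, the
# ROUTER receives it prefixed with the sender's identity frame and can reply
# to exactly that peer. The patch uses cloudpickle and zmq.asyncio; plain
# pickle, blocking sockets, and the inproc transport are enough for this toy
# round trip.
import pickle
import zmq

ctx = zmq.Context()

router = ctx.socket(zmq.ROUTER)  # server side: sees sender identities
router.bind("inproc://rpc-sketch")

dealer = ctx.socket(zmq.DEALER)  # client side: one socket per request
dealer.connect("inproc://rpc-sketch")

dealer.send(pickle.dumps("GET_MODEL_CONFIG"))

identity, payload = router.recv_multipart()
print("server got:", pickle.loads(payload))

router.send_multipart([identity, pickle.dumps("SUCCESS")])
print("client got:", pickle.loads(dealer.recv()))

dealer.close()
router.close()
ctx.term()
# ---------------------------------------------------------------------------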
+ health_message = cloudpickle.loads(await socket.recv()) + + if isinstance(health_message, Exception): + raise health_message + + if health_message != VLLM_RPC_HEALTHY_STR: + raise ValueError("Expected healthy response from backend but got " + "f{health_message}") + + async def encode(self, *args, + **kwargs) -> AsyncIterator[EmbeddingRequestOutput]: + raise NotImplementedError( + "Embeddings not supported with multiprocessing backend") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py new file mode 100644 index 0000000000000..7a72a6f732c99 --- /dev/null +++ b/vllm/entrypoints/openai/rpc/server.py @@ -0,0 +1,216 @@ +import asyncio +import signal +from typing import Any, Coroutine + +import cloudpickle +import zmq +import zmq.asyncio +from typing_extensions import Never + +from vllm import AsyncEngineArgs, AsyncLLMEngine +from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext + +logger = init_logger(__name__) + + +class AsyncEngineRPCServer: + + def __init__(self, async_engine_args: AsyncEngineArgs, + usage_context: UsageContext, port: int): + # Initialize engine first. + self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, + usage_context) + + # Initialize context. + self.context = zmq.asyncio.Context() + + # Init socket for readiness state. + self.socket = self.context.socket(zmq.constants.ROUTER) + self.socket.bind(f"tcp://localhost:{port}") + + def cleanup(self): + """Cleanup all resources.""" + self.socket.close() + self.context.destroy() + + async def get_model_config(self, identity): + """Send the ModelConfig""" + model_config = await self.engine.get_model_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(model_config)]) + + async def get_decoding_config(self, identity): + """Send the DecodingConfig""" + decoding_config = await self.engine.get_decoding_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(decoding_config)]) + + async def get_lora_config(self, identity): + lora_config = await self.engine.get_lora_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(lora_config)]) + + async def get_scheduler_config(self, identity): + """Send the SchedulerConfig""" + parallel_config = await self.engine.get_scheduler_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(parallel_config)]) + + async def get_parallel_config(self, identity): + """Send the ParallelConfig""" + parallel_config = await self.engine.get_parallel_config() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(parallel_config)]) + + async def is_tracing_enabled(self, identity): + """Send the is_tracing_enabled flag""" + tracing_flag = await self.engine.is_tracing_enabled() + + await self.socket.send_multipart( + [identity, cloudpickle.dumps(tracing_flag)]) + + async def do_log_stats(self, identity): + """Log stats and confirm success.""" + await self.engine.do_log_stats() + + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def is_server_ready(self, identity): + """Notify the client that we are ready.""" + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def abort(self, identity, request: RPCAbortRequest): + """Abort request and notify the client of success.""" + # 
Abort the request in the llm engine. + await self.engine.abort(request.request_id) + + # Send confirmation to the client. + await self.socket.send_multipart([ + identity, + cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), + ]) + + async def generate(self, identity, generate_request: RPCGenerateRequest): + try: + results_generator = self.engine.generate( + generate_request.inputs, + sampling_params=generate_request.sampling_params, + request_id=generate_request.request_id, + lora_request=generate_request.lora_request, + trace_headers=generate_request.trace_headers, + prompt_adapter_request=generate_request.prompt_adapter_request) + + async for request_output in results_generator: + await self.socket.send_multipart( + [identity, cloudpickle.dumps(request_output)]) + + except Exception as e: + ### Notify client of all failures + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + + async def check_health(self, identity): + try: + await self.engine.check_health() + await self.socket.send_multipart( + [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) + except Exception as e: + await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + + def _make_handler_coro(self, identity, + message) -> Coroutine[Any, Any, Never]: + """Route the zmq message to the handler coroutine.""" + + request = cloudpickle.loads(message) + + if isinstance(request, RPCGenerateRequest): + return self.generate(identity, request) + + elif isinstance(request, RPCAbortRequest): + return self.abort(identity, request) + + elif isinstance(request, RPCUtilityRequest): + if request == RPCUtilityRequest.GET_MODEL_CONFIG: + return self.get_model_config(identity) + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + return self.get_parallel_config(identity) + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + return self.get_decoding_config(identity) + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + return self.get_scheduler_config(identity) + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + return self.get_lora_config(identity) + elif request == RPCUtilityRequest.DO_LOG_STATS: + return self.do_log_stats(identity) + elif request == RPCUtilityRequest.IS_SERVER_READY: + return self.is_server_ready(identity) + elif request == RPCUtilityRequest.CHECK_HEALTH: + return self.check_health(identity) + elif request == RPCUtilityRequest.IS_TRACING_ENABLED: + return self.is_tracing_enabled(identity) + else: + raise ValueError(f"Unknown RPCUtilityRequest type: {request}") + + else: + raise ValueError(f"Unknown RPCRequest type: {request}") + + async def run_server_loop(self): + """Inner RPC Server Loop""" + + running_tasks = set() + while True: + # Wait for a request. + identity, message = await self.socket.recv_multipart() + + # Process the request async. + task = asyncio.create_task( + self._make_handler_coro(identity, message)) + + # We need to keep around a strong reference to the task, + # to avoid the task disappearing mid-execution as running tasks + # can be GC'ed. Below is a common "fire-and-forget" tasks + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + running_tasks.add(task) + task.add_done_callback(running_tasks.discard) + + +async def run_server(server: AsyncEngineRPCServer): + # Put the server task into the asyncio loop. + loop = asyncio.get_running_loop() + server_task = loop.create_task(server.run_server_loop()) + + # Interruption handling. 
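# --- Illustrative sketch, not part of this patch ---------------------------
# run_server_loop above keeps each handler task in a set because the event
# loop only holds weak references to tasks: a fire-and-forget task with no
# other reference can be garbage-collected before it completes. This minimal
# example shows the same pattern with a hypothetical handler.
import asyncio


async def handle(msg: str) -> None:
    await asyncio.sleep(0.01)
    print("handled", msg)


async def main() -> None:
    running_tasks = set()
    for i in range(3):
        task = asyncio.create_task(handle(f"msg-{i}"))
        running_tasks.add(task)  # strong reference keeps the task alive
        task.add_done_callback(running_tasks.discard)  # drop it when done
    # Wait for the in-flight handlers before the sketch exits.
    await asyncio.gather(*running_tasks)


asyncio.run(main())
# ---------------------------------------------------------------------------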
+ def signal_handler() -> None: + # Kill the server on interrupt / terminate + server_task.cancel() + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + except asyncio.CancelledError: + logger.info("vLLM ZMQ RPC Server was interrupted.") + finally: + # Clean up all resources. + server.cleanup() + + +def run_rpc_server(async_engine_args: AsyncEngineArgs, + usage_context: UsageContext, port: int): + server = AsyncEngineRPCServer(async_engine_args, usage_context, port) + asyncio.run(run_server(server)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c832cf2a24b50..ebb1d57fbb9a6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], response_role: str, @@ -50,7 +50,7 @@ def __init__( chat_template: Optional[str], return_tokens_as_token_ids: bool = False, ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -89,7 +89,8 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) model_config = self.model_config - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -161,7 +162,8 @@ async def create_chat_completion( if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data - is_tracing_enabled = await self.engine.is_tracing_enabled() + is_tracing_enabled = ( + await self.async_engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled and raw_request: trace_headers = extract_trace_headers(raw_request.headers) @@ -169,7 +171,7 @@ async def create_chat_completion( and contains_trace_headers(raw_request.headers)): log_tracing_disabled_warning() - result_generator = self.engine.generate( + result_generator = self.async_engine_client.generate( engine_inputs, sampling_params, request_id, @@ -441,7 +443,7 @@ async def chat_completion_full_generator( async for res in result_generator: if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.engine.abort(request_id) + await self.async_engine_client.abort(request_id) return self.create_error_response("Client disconnected") final_res = res assert final_res is not None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7765c5903f341..edc83d83fbba7 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -42,7 +42,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -51,7 +51,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -91,7 +91,8 @@ async def create_completion(self, request: CompletionRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -119,7 +120,8 @@ async def create_completion(self, request: CompletionRequest, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = await self.engine.is_tracing_enabled() + is_tracing_enabled = ( + await self.async_engine_client.is_tracing_enabled()) trace_headers = None if is_tracing_enabled: trace_headers = extract_trace_headers(raw_request.headers) @@ -127,7 +129,7 @@ async def create_completion(self, request: CompletionRequest, raw_request.headers): log_tracing_disabled_warning() - generator = self.engine.generate( + generator = self.async_engine_client.generate( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id_item, @@ -168,7 +170,7 @@ async def create_completion(self, request: CompletionRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") + await self.async_engine_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res @@ -230,7 +232,8 @@ async def completion_stream_generator( # Abort the request if the client disconnects. 
if await raw_request.is_disconnected(): - await self.engine.abort(f"{request_id}-{prompt_idx}") + await self.async_engine_client.abort( + f"{request_id}-{prompt_idx}") raise StopAsyncIteration() for output in res.outputs: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index bccc90894e79f..e61c82f9a8a6c 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -6,7 +6,7 @@ from fastapi import Request from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, @@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, request_logger: Optional[RequestLogger], ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=None, @@ -99,7 +99,8 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer( + lora_request) pooling_params = request.to_pooling_params() @@ -124,7 +125,7 @@ async def create_embedding(self, request: EmbeddingRequest, "Prompt adapter is not supported " "for embedding models") - generator = self.engine.encode( + generator = self.async_engine_client.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, @@ -146,7 +147,7 @@ async def create_embedding(self, request: EmbeddingRequest, async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. 
- await self.engine.abort(f"{request_id}-{i}") + await self.async_engine_client.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8c7929a12e9a0..df4932d8fe185 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -8,7 +8,7 @@ from typing_extensions import Annotated from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -61,7 +61,7 @@ class OpenAIServing: def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -72,7 +72,7 @@ def __init__( ): super().__init__() - self.engine = engine + self.async_engine_client = async_engine_client self.model_config = model_config self.max_model_len = model_config.max_model_len @@ -155,7 +155,7 @@ def create_streaming_error_response( async def _guided_decode_logits_processor( self, request: Union[ChatCompletionRequest, CompletionRequest], tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.engine.get_decoding_config() + decoding_config = await self.async_engine_client.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend return await get_guided_decoding_logits_processor( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 94e1b03ed4036..c4350881a27a6 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,9 +1,9 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -from vllm.engine.async_llm_engine import AsyncLLMEngine # yapf conflicts with isort for this block # yapf: disable +from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -24,7 +24,7 @@ class OpenAIServingTokenization(OpenAIServing): def __init__( self, - engine: AsyncLLMEngine, + async_engine_client: AsyncEngineClient, model_config: ModelConfig, served_model_names: List[str], *, @@ -32,7 +32,7 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], ): - super().__init__(engine=engine, + super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, @@ -57,7 +57,7 @@ async def create_tokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): model_config = self.model_config @@ -113,7 +113,7 @@ async def create_detokenize( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine.get_tokenizer(lora_request) + tokenizer = await self.async_engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, diff --git a/vllm/envs.py b/vllm/envs.py index 595058bcbb027..a78bad6a2b273 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -4,6 +4,7 @@ if TYPE_CHECKING: VLLM_HOST_IP: 
str = ""
     VLLM_PORT: Optional[int] = None
+    VLLM_RPC_PORT: int = 5570
     VLLM_USE_MODELSCOPE: bool = False
     VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
     VLLM_INSTANCE_ID: Optional[str] = None
@@ -140,6 +141,11 @@ def get_default_config_root():
     lambda: int(os.getenv('VLLM_PORT', '0'))
     if 'VLLM_PORT' in os.environ else None,
 
+    # used when the frontend api server is running in multi-processing mode,
+    # to communicate with the backend engine process over ZMQ.
+    'VLLM_RPC_PORT':
+    lambda: int(os.getenv('VLLM_RPC_PORT', '5570')),
+
     # If true, will load models from ModelScope instead of Hugging Face Hub.
     # note that the value is true or false, not numbers
     "VLLM_USE_MODELSCOPE":
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index 1c8f6cccb3e9a..554dcc0ed43ed 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -21,6 +21,8 @@
 from typing import Callable, DefaultDict, Dict, List, Union
 
 import torch
+from lark import Lark
+from outlines import grammars
 from outlines.caching import cache
 from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
 from outlines.fsm.json_schema import build_regex_from_schema
@@ -44,6 +46,23 @@ def __call__(self, input_ids: List[int],
             last_seq_id = hash(tuple(input_ids[:-1]))
             self._fsm_state[seq_id] = self._guide.get_next_state(
                 state=self._fsm_state[last_seq_id], token_id=last_token)
+        else:
+            # Note: this is a hack.
+            # Lark pickling does not work properly (silent failure),
+            # which breaks the RPC (which uses Python pickling).
+            # We need to find a better solution.
+            # On the first time this is called, we simply re-create
+            # the Lark object.
+ if isinstance(self._guide, CFGGuide): + self._guide.parser = Lark( + self._guide.cfg_string, + parser="lalr", + lexer="contextual", + propagate_positions=False, + maybe_placeholders=False, + regex=True, + import_paths=[grammars.GRAMMAR_PATH], + ) instruction = self._guide.get_next_instruction( state=self._fsm_state[seq_id]) diff --git a/vllm/tracing.py b/vllm/tracing.py index dc8377f2396f2..7ac38e6a0f663 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -60,7 +60,7 @@ def get_span_exporter(endpoint): OTLPSpanExporter) elif protocol == "http/protobuf": from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter) + OTLPSpanExporter) # type: ignore else: raise ValueError( f"Unsupported OTLP protocol '{protocol}' is configured") diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 7a0436dd1fb16..eeab19899b022 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,6 +1,7 @@ from typing import Optional, Type -from vllm.config import TokenizerPoolConfig +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + TokenizerPoolConfig) from vllm.executor.ray_utils import ray from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup @@ -13,6 +14,22 @@ RayTokenizerGroupPool = None # type: ignore +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + parallel_config: ParallelConfig, + enable_lora: bool): + init_kwargs = dict(tokenizer_id=model_config.tokenizer, + enable_lora=enable_lora, + max_num_seqs=scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision) + + return get_tokenizer_group(parallel_config.tokenizer_pool_config, + **init_kwargs) + + def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: tokenizer_cls: Type[BaseTokenizerGroup] diff --git a/vllm/utils.py b/vllm/utils.py index c4c17bfbefc65..51bd72977a226 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -290,6 +290,10 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: return _async_wrapper +class ProducerFinished: + pass + + def merge_async_iterators( *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: """Merge multiple asynchronous iterators into a single iterator. @@ -298,9 +302,10 @@ def merge_async_iterators( When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. 
""" - queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, + Exception]] = asyncio.Queue() - finished = [False] * len(iterators) + producers = len(iterators) async def producer(i: int, iterator: AsyncIterator[T]): try: @@ -308,7 +313,8 @@ async def producer(i: int, iterator: AsyncIterator[T]): await queue.put((i, item)) except Exception as e: await queue.put(e) - finished[i] = True + # Signal to the consumer that we've finished + await queue.put(ProducerFinished()) _tasks = [ asyncio.create_task(producer(i, iterator)) @@ -316,9 +322,17 @@ async def producer(i: int, iterator: AsyncIterator[T]): ] async def consumer(): + remaining = producers try: - while not all(finished) or not queue.empty(): + while remaining or not queue.empty(): + # we think there is a race condition here item = await queue.get() + + if isinstance(item, ProducerFinished): + # Signal that a producer finished- not a real item + remaining -= 1 + continue + if isinstance(item, Exception): raise item yield item @@ -374,8 +388,10 @@ def get_distributed_init_method(ip: str, port: int) -> str: return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" -def get_open_port() -> int: - port = envs.VLLM_PORT +def get_open_port(port: Optional[int] = None) -> int: + if port is None: + # Default behavior here is to return a port for multi-gpu communication + port = envs.VLLM_PORT if port is not None: while True: try: From 69ea15e5cc823b2bc040921ce516807fb7357dd1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 21:05:16 -0700 Subject: [PATCH 07/36] [ci][distributed] shorten wait time if server hangs (#7098) --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index dd8af8e3afe70..974fece49f4b4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -50,7 +50,7 @@ def _nvml(): class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key - MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds + MAX_SERVER_START_WAIT_S = 120 # wait for server to start for 120 seconds def __init__( self, From 8c025fa7030350a81bfeb665c99ad622667bdac0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 3 Aug 2024 12:31:27 +0800 Subject: [PATCH 08/36] [Frontend] Factor out chat message parsing (#7055) --- vllm/entrypoints/chat_utils.py | 28 +++++++++++++++---- vllm/entrypoints/openai/serving_chat.py | 17 ++++------- .../openai/serving_tokenization.py | 21 +++++++------- 3 files changed, 39 insertions(+), 27 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index fbb7f70b55e16..072450a6146ee 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,7 +1,8 @@ import codecs -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import lru_cache -from typing import Awaitable, Iterable, List, Optional, Union, cast, final +from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast, + final) # yapf conflicts with isort for this block # yapf: disable @@ -65,8 +66,7 @@ class ConversationMessage(TypedDict): @dataclass(frozen=True) class ChatMessageParseResult: messages: List[ConversationMessage] - mm_futures: List[Awaitable[MultiModalDataDict]] = field( - default_factory=list) + mm_futures: List[Awaitable[MultiModalDataDict]] def load_chat_template(chat_template: Optional[str]) -> Optional[str]: @@ -174,7 +174,7 @@ def 
_parse_chat_message_content_parts( return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) -def parse_chat_message_content( +def _parse_chat_message_content( message: ChatCompletionMessageParam, model_config: ModelConfig, tokenizer: PreTrainedTokenizer, @@ -190,3 +190,21 @@ def parse_chat_message_content( return _parse_chat_message_content_parts(role, content, model_config, tokenizer) + + +def parse_chat_messages( + messages: List[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, +) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]: + conversation: List[ConversationMessage] = [] + mm_futures: List[Awaitable[MultiModalDataDict]] = [] + + for msg in messages: + parse_result = _parse_chat_message_content(msg, model_config, + tokenizer) + + conversation.extend(parse_result.messages) + mm_futures.extend(parse_result.mm_futures) + + return conversation, mm_futures diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ebb1d57fbb9a6..d215754993e82 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,6 +1,5 @@ import time -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, - Optional) +from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional from typing import Sequence as GenericSequence from typing import Union @@ -11,7 +10,7 @@ from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, - parse_chat_message_content) + parse_chat_messages) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -92,15 +91,8 @@ async def create_chat_completion( tokenizer = await self.async_engine_client.get_tokenizer( lora_request) - conversation: List[ConversationMessage] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] - - for msg in request.messages: - chat_parsed_result = parse_chat_message_content( - msg, model_config, tokenizer) - - conversation.extend(chat_parsed_result.messages) - mm_futures.extend(chat_parsed_result.mm_futures) + conversation, mm_futures = parse_chat_messages( + request.messages, model_config, tokenizer) tool_dicts = None if request.tools is None else [ tool.model_dump() for tool in request.tools @@ -115,6 +107,7 @@ async def create_chat_completion( chat_template=request.chat_template or self.chat_template, **(request.chat_template_kwargs or {}), ) + assert isinstance(prompt, str) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index c4350881a27a6..5b6b979b9b9e7 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,13 +1,11 @@ from typing import List, Optional, Union from vllm.config import ModelConfig -# yapf conflicts with isort for this block -# yapf: disable from vllm.engine.protocol import AsyncEngineClient -from vllm.entrypoints.chat_utils import (ConversationMessage, - load_chat_template, - parse_chat_message_content) +from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable from 
vllm.entrypoints.openai.protocol import (DetokenizeRequest, DetokenizeResponse, ErrorResponse, @@ -17,8 +15,11 @@ # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) +from vllm.logger import init_logger from vllm.utils import random_uuid +logger = init_logger(__name__) + class OpenAIServingTokenization(OpenAIServing): @@ -62,12 +63,12 @@ async def create_tokenize( if isinstance(request, TokenizeChatRequest): model_config = self.model_config - conversation: List[ConversationMessage] = [] + conversation, mm_futures = parse_chat_messages( + request.messages, model_config, tokenizer) - for message in request.messages: - result = parse_chat_message_content(message, model_config, - tokenizer) - conversation.extend(result.messages) + if mm_futures: + logger.warning( + "Multi-modal inputs are ignored during tokenization") prompt = tokenizer.apply_chat_template( add_generation_prompt=request.add_generation_prompt, From 04e55834254bf11770d544bbeebdbdb7731d9bbd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 21:33:53 -0700 Subject: [PATCH 09/36] [ci][distributed] merge distributed test commands (#7097) Co-authored-by: Cyrus Leung --- .buildkite/test-pipeline.yaml | 27 ++------- .../test_basic_distributed_correctness.py | 50 ++++++++++------ .../test_chunked_prefill_distributed.py | 35 +++++------- .../distributed/test_multimodal_broadcast.py | 57 +++++++++---------- 4 files changed, 78 insertions(+), 91 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 573c3740f0bbb..93b3e3fe91663 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -82,20 +82,9 @@ steps: num_gpus: 2 commands: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s 
distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py - - TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py + - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py + - pytest -v -s distributed/test_chunked_prefill_distributed.py + - pytest -v -s distributed/test_multimodal_broadcast.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py @@ -107,11 +96,6 @@ steps: fast_check: true commands: - pytest -v -s distributed/test_pynccl.py - # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here. - # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context. - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - label: Pipeline Parallelism Test @@ -279,9 +263,6 @@ steps: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s -x lora/test_mixtral.py diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 7a0e5673b2cc4..1de2ebab22db4 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -1,15 +1,10 @@ """Compare the outputs of HF and distributed vLLM when using greedy sampling. -vLLM will allocate all the available memory, so we need to run the tests one -by one. The solution is to pass arguments (model name) by environment -variables. 
+ Run: ```sh cd $VLLM_PATH/tests -TEST_DIST_MODEL=facebook/opt-125m pytest \ - distributed/test_basic_distributed_correctness.py -TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ - distributed/test_basic_distributed_correctness.py +pytest distributed/test_basic_distributed_correctness.py ``` """ import os @@ -19,27 +14,48 @@ from vllm.utils import cuda_device_count_stateless from ..models.utils import check_outputs_equal +from ..utils import fork_new_process_for_each_test -MODELS = [ - os.environ["TEST_DIST_MODEL"], -] -DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" +TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") @pytest.mark.skipif(cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize( + "model, distributed_executor_backend, attention_backend, test_suite", [ + ("facebook/opt-125m", "ray", "", "L4"), + ("facebook/opt-125m", "mp", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "ray", "", "L4"), + ("meta-llama/Llama-2-7b-hf", "mp", "", "L4"), + ("facebook/opt-125m", "ray", "", "A100"), + ("facebook/opt-125m", "mp", "", "A100"), + ("facebook/opt-125m", "mp", "FLASHINFER", "A100"), + ("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"), + ]) +@fork_new_process_for_each_test def test_models( hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, - max_tokens: int, + distributed_executor_backend: str, + attention_backend: str, + test_suite: str, ) -> None: - distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + if test_suite != TARGET_TEST_SUITE: + pytest.skip(f"Skip test for {test_suite}") + + if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa + # test ray adag + os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1" + os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1" + + if attention_backend: + os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend + + dtype = "half" + max_tokens = 5 # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 1ef085b933793..10921a3852f81 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -1,46 +1,39 @@ """Compare the outputs of HF and distributed vLLM when using greedy sampling. -vLLM will allocate all the available memory, so we need to run the tests one -by one. The solution is to pass arguments (model name) by environment -variables. 
Run: ```sh -TEST_DIST_MODEL=facebook/opt-125m pytest \ - test_chunked_prefill_distributed.py -TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ - test_chunked_prefill_distributed.py +pytest test_chunked_prefill_distributed.py ``` """ -import os import pytest from vllm.utils import cuda_device_count_stateless from ..models.utils import check_outputs_equal - -MODELS = [ - os.environ["TEST_DIST_MODEL"], -] -DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" +from ..utils import fork_new_process_for_each_test @pytest.mark.skipif(cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +@pytest.mark.parametrize("model, distributed_executor_backend", [ + ("facebook/opt-125m", "ray"), + ("meta-llama/Llama-2-7b-hf", "ray"), + ("facebook/opt-125m", "mp"), + ("meta-llama/Llama-2-7b-hf", "mp"), +]) +@fork_new_process_for_each_test def test_models( hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, + distributed_executor_backend: str, ) -> None: - distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + + dtype = "half" + max_tokens = 5 + chunked_prefill_token_size = 16 # Add a chunked prefill config. max_num_seqs = min(chunked_prefill_token_size, 256) diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py index a99917f586949..2c96358e2e6f2 100644 --- a/tests/distributed/test_multimodal_broadcast.py +++ b/tests/distributed/test_multimodal_broadcast.py @@ -1,44 +1,41 @@ """Compare the outputs of HF and distributed vLLM when using greedy sampling. -The second test will hang if more than one test is run per command, so we need -to run the tests one by one. The solution is to pass arguments (model name) by -environment variables. 
Run: ```sh -TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf \ - test_multimodal_broadcast.py -TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct \ - test_multimodal_broadcast.py +pytest -s -v test_multimodal_broadcast.py ``` """ -import os import pytest from vllm.utils import cuda_device_count_stateless -model = os.environ["TEST_DIST_MODEL"] - -if model.startswith("llava-hf/llava-1.5"): - from ..models.test_llava import models, run_test -elif model.startswith("llava-hf/llava-v1.6"): - from ..models.test_llava_next import models, run_test -else: - raise NotImplementedError(f"Unsupported model: {model}") - - -@pytest.mark.parametrize("tensor_parallel_size", [2]) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, - tensor_parallel_size: int, dtype: str, max_tokens: int, - num_logprobs: int) -> None: - if cuda_device_count_stateless() < tensor_parallel_size: - pytest.skip( - f"Need at least {tensor_parallel_size} GPUs to run the test.") - - distributed_executor_backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND") +from ..utils import fork_new_process_for_each_test + + +@pytest.mark.skipif(cuda_device_count_stateless() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model, distributed_executor_backend", [ + ("llava-hf/llava-1.5-7b-hf", "ray"), + ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"), + ("llava-hf/llava-1.5-7b-hf", "mp"), + ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"), +]) +@fork_new_process_for_each_test +def test_models(hf_runner, vllm_runner, image_assets, model: str, + distributed_executor_backend: str) -> None: + + dtype = "half" + max_tokens = 5 + num_logprobs = 5 + tensor_parallel_size = 2 + + if model.startswith("llava-hf/llava-1.5"): + from ..models.test_llava import models, run_test + elif model.startswith("llava-hf/llava-v1.6"): + from ..models.test_llava_next import models, run_test + else: + raise NotImplementedError(f"Unsupported model: {model}") run_test( hf_runner, From a0d164567cd2a82d827c81a49a21e3f2c75a522d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 2 Aug 2024 22:32:04 -0700 Subject: [PATCH 10/36] [ci][distributed] disable ray dag tests (#7099) --- tests/distributed/test_pipeline_parallel.py | 43 +++++++++------------ 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index ab325e0966929..8eb5ca9461c75 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -14,36 +14,29 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.parametrize( - ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " - "MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [ - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False), - (2, 2, 0, 
1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False), - ]) +@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " + "MODEL_NAME, DIST_BACKEND"), + [ + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), + (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + ]) def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, - DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL): + DIST_BACKEND): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") + USE_RAY_ADAG_NCCL = 0 + USE_RAY_ADAG = 0 + pp_args = [ # use half precision for speed and memory savings in CI environment "--dtype", From 0c25435daa0a399460a676e7c9b604bd23ea2d22 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 3 Aug 2024 13:36:14 +0800 Subject: [PATCH 11/36] [Model] Refactor and decouple weight loading logic for InternVL2 model (#7067) --- vllm/model_executor/models/intern_vit.py | 11 +++- vllm/model_executor/models/internvl.py | 82 ++++++++---------------- 2 files changed, 38 insertions(+), 55 deletions(-) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index c6c692deca2e1..54c933e3e4959 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -4,7 +4,7 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- -from typing import Optional +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn @@ -16,6 +16,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader NORM2FN = { 'rms_norm': RMSNorm, @@ -268,3 +269,11 @@ def forward( encoder_outputs = self.encoder(inputs_embeds=hidden_states) return encoder_outputs + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index eabc283b1efdb..4749251271487 100644 --- a/vllm/model_executor/models/internvl.py +++ 
b/vllm/model_executor/models/internvl.py @@ -4,6 +4,7 @@ # Copyright (c) 2023 OpenGVLab # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- +import itertools from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch @@ -414,58 +415,31 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - (".gate_up_proj", ".w1", 0), - (".gate_up_proj", ".w3", 1), - ] - params_dict = dict(self.named_parameters()) + def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], + prefix: str): for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if self.config.text_config.tie_word_embeddings \ - and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # We only do sharding for language model - # and not vision model for now. - if "vision_embed_tokens" in name and self.vision_embed_tokens: - continue - if weight_name not in name: - continue - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - if "wqkv" in name: - config = self.config.text_config - kv_groups = (config.num_attention_heads // - config.num_key_value_heads) - head_dim = config.hidden_size // config.num_attention_heads - loaded_weight = loaded_weight.view(-1, 2 + kv_groups, - head_dim, - loaded_weight.shape[-1]) - wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], - dim=1) - wq = wq.reshape(-1, wq.shape[-1]) - wk = wk.reshape(-1, wk.shape[-1]) - wv = wv.reshape(-1, wv.shape[-1]) - weight_loader = param.weight_loader - weight_loader(param, wq, 'q') - weight_loader(param, wk, 'k') - weight_loader(param, wv, 'v') - continue - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + name = name.split(".") + if prefix == name.pop(0): + name = ".".join(name) + yield name, loaded_weight + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # prepare weight iterators for components + vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + + # load vision encoder + vit_weights = self._filter_weights(vit_weights, "vision_model") + self.vision_model.load_weights(vit_weights) + + # load mlp projector + mlp_weights = self._filter_weights(mlp_weights, "mlp1") + mlp_params_dict = dict(self.mlp1.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = self._filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) From fb2c1c86c196aa1531435d0c445fbea4c9dd4aa5 Mon Sep 17 00:00:00 2001 From: Zach Zheng Date: Fri, 2 Aug 2024 22:38:15 -0700 Subject: [PATCH 12/36] [Bugfix] Fix block table for seqs that have prefix cache hits (#7018) --- 
tests/prefix_caching/test_prefix_caching.py | 56 +++++++++++++++++++++
 vllm/attention/backends/flash_attn.py       | 12 +++--
 2 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index 7985001d34eb1..9821dbd066a59 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -6,10 +6,17 @@
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
 from vllm.block import PhysicalTokenBlock
 from vllm.core.block_manager_v1 import CachedBlockAllocator
 from vllm.utils import Device
 
+from ..models.utils import check_outputs_equal
+
+MODELS = [
+    "facebook/opt-125m",
+]
+
 
 @pytest.mark.parametrize("block_size", [16])
 @pytest.mark.parametrize("num_blocks", [16])
@@ -76,3 +83,52 @@ def test_eviction(num_blocks: int, ):
     assert (realloc_block != new_block)
     assert (new_block.block_hash == new_block_hash)
     assert (new_block.block_number == 2)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("cached_position", [0, 1])
+@pytest.mark.parametrize("use_v2_block_manager", [False, True])
+def test_mixed_requests(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    backend: str,
+    dtype: str,
+    max_tokens: int,
+    cached_position: int,
+    use_v2_block_manager: bool,
+    monkeypatch,
+) -> None:
+    """
+    Test the case when some sequences have a prefix cache hit
+    and the others don't. The cached position determines where
+    the cached sequence sits in the batch of prefills.
+    """
+    override_backend_env_variable(monkeypatch, backend)
+
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    cached_prompt = example_prompts[cached_position]
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enable_prefix_caching=True,
+            use_v2_block_manager=use_v2_block_manager,
+    ) as vllm_model:
+        # Run the first prompt so the cache is populated
+        vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
+
+        # Run all the prompts
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 00654dca2adfa..26b3159682b3e 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -209,6 +209,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
+        self.has_prefix_cache_hit = False
 
         self.input_builder = input_builder
         self.runner = input_builder.runner
@@ -219,7 +220,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
 
     def _add_seq_group(
             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
-            chunked_prefill_enabled: bool):
+            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
         """Add a sequence group to the metadata. Specifically update/append
         1. context length.
         2. block table.
@@ -252,7 +253,7 @@ def _add_seq_group(
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
block_table = [] - if inter_data.prefix_cache_hit: + if prefix_cache_hit: # NOTE(woosuk): For flash-attn, the block table should # include the entries for the incoming prefill tokens. block_table = block_tables[seq_id] @@ -281,9 +282,14 @@ def build(self, seq_lens: List[int], query_lens: List[int], -1 if cuda graph is not used. batch_size: The maybe padded batch size. """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) for inter_data in self.input_builder.inter_data_list: self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled) + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 From 99d7cabd7b8b789e837a0682982fd7ec94a843b1 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 3 Aug 2024 13:40:19 +0800 Subject: [PATCH 13/36] [LoRA] ReplicatedLinear support LoRA (#7081) --- tests/lora/test_layers.py | 103 ++++++++++++++++++++++++++++++++++++++ vllm/lora/layers.py | 94 ++++++++++++++++++++++++++++++++++ vllm/lora/utils.py | 2 + 3 files changed, 199 insertions(+) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 6f33f56616fcd..d8cc68d5e9599 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -22,6 +22,7 @@ MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, QKVParallelLinearWithLora, + ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable @@ -31,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope @@ -545,6 +547,107 @@ def _pretest(): atol=atol) +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_linear_replicated(dist_init, num_loras, device, stage) -> None: + + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_linear_replicated_layer(): + + linear = ReplicatedLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = ReplicatedLinearWithLoRA(linear) + + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_replicated_layer() + lora_linear.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + ) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in 
zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + ) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3176badabbc7f..42ec99e6ea2c8 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import ( @@ -262,6 +263,99 @@ def can_replace_layer( return type(source_layer) is VocabParallelEmbedding +class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + lora_a_output_size = lora_config.max_lora_rank + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_a_output_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + return output + + def forward(self, input_): + """Forward of ReplicatedLinearWithLoRA + + Args: + input_: Tensor whose last dimension is `input_size`. 
+ + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear + + class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): """ LoRA on top of ColumnParallelLinear layer. diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 4513337299e16..ee983328e2c5b 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -23,6 +23,7 @@ MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, QKVParallelLinearWithLora, + ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable @@ -38,6 +39,7 @@ QKVParallelLinearWithLora, MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLora, From 67d745cc68d9ad31bf683a88f00a1aee9782f541 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Fri, 2 Aug 2024 23:52:44 -0700 Subject: [PATCH 14/36] [CI] Temporarily turn off H100 performance benchmark (#7104) --- .../benchmark-pipeline.yaml | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 02c0ee534d72c..8490c9f1da221 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -42,20 +42,20 @@ steps: - name: devshm emptyDir: medium: Memory - - label: "H100" - agents: - queue: H100 - plugins: - - docker#v5.11.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - command: - - bash - - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh - mount-buildkite-agent: true - propagate-environment: true - ipc: host - gpus: all - environment: - - VLLM_USAGE_SOURCE - - HF_TOKEN + # - label: "H100" + # agents: + # queue: H100 + # plugins: + # - docker#v5.11.0: + # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + # command: + # - bash + # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + # mount-buildkite-agent: true + # propagate-environment: true + # ipc: host + # gpus: all + # environment: + # - VLLM_USAGE_SOURCE + # - HF_TOKEN From 44dcb52e39ee6b2c9ef9e6497525e1e183c9d24b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 3 Aug 2024 10:44:53 -0700 Subject: [PATCH 15/36] [ci][test] finalize fork_new_process_for_each_test (#7114) --- tests/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/utils.py b/tests/utils.py index 974fece49f4b4..666694299d397 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -360,6 +360,9 @@ def wait_for_gpu_memory_to_clear(devices: List[int], def fork_new_process_for_each_test(f): + """Decorator to fork a new process for each test function. + See https://github.com/vllm-project/vllm/issues/7053 for more details. 
+ """ @functools.wraps(f) def wrapper(*args, **kwargs): From 825b044863a8e3af82a82a80cd2617486cc829ca Mon Sep 17 00:00:00 2001 From: Jeff Fialho Date: Sat, 3 Aug 2024 20:01:38 -0300 Subject: [PATCH 16/36] [Frontend] Warn if user `max_model_len` is greater than derived `max_model_len` (#7080) Signed-off-by: Jefferson Fialho Co-authored-by: Nick Hill --- vllm/config.py | 19 +++++++++++++------ vllm/envs.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ef56e2b6395be..028f4eed8f4a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -6,6 +6,7 @@ import torch from transformers import PretrainedConfig +import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry @@ -1541,15 +1542,21 @@ def _get_and_verify_max_len( "Disabling sliding window is not supported for models " "model_max_length in the config. Please raise an issue " "so we can investigate.") - pass else: - raise ValueError( + msg = ( f"User-specified max_model_len ({max_model_len}) is greater " - "than the derived max_model_len " - f"({max_len_key}={derived_max_model_len} or model_max_length=" + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" f"{model_max_length} in model's config.json). This may lead " - "to incorrect model outputs or CUDA errors. Make sure the " - "value is correct and within the model context size.") + "to incorrect model outputs or CUDA errors.") + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning( + "%s Make sure the value is correct and within the " + "model context size.", msg) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") return int(max_model_len) diff --git a/vllm/envs.py b/vllm/envs.py index a78bad6a2b273..089a39d8e029d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -50,6 +50,7 @@ VLLM_NO_DEPRECATION_WARNING: bool = False CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False + VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False def get_default_cache_root(): @@ -331,6 +332,15 @@ def get_default_config_root(): # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), + + # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows + # the user to specify a max sequence length greater than + # the max length derived from the model's config.json. + # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. 
+ "VLLM_ALLOW_LONG_MAX_MODEL_LEN": + lambda: + (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in + ("1", "true")), } # end-env-vars-definition From 654bc5ca49bde0969bc95e4b1dbe7fabbb8f631c Mon Sep 17 00:00:00 2001 From: Yihuan Bu <88394319+kevinbu233@users.noreply.github.com> Date: Sat, 3 Aug 2024 23:12:09 -0400 Subject: [PATCH 17/36] Support for guided decoding for offline LLM (#6878) Co-authored-by: Cyrus Leung --- docs/source/conf.py | 1 + tests/entrypoints/{openai => }/conftest.py | 22 ++- tests/entrypoints/llm/test_guided_generate.py | 142 ++++++++++++++++++ vllm/entrypoints/llm.py | 44 +++++- vllm/entrypoints/openai/protocol.py | 26 +++- .../guided_decoding/__init__.py | 26 +++- .../guided_decoding/guided_fields.py | 38 +++++ .../lm_format_enforcer_decoding.py | 39 +++++ .../guided_decoding/outlines_decoding.py | 26 +++- 9 files changed, 352 insertions(+), 12 deletions(-) rename tests/entrypoints/{openai => }/conftest.py (83%) create mode 100644 tests/entrypoints/llm/test_guided_generate.py create mode 100644 vllm/model_executor/guided_decoding/guided_fields.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 1093b30bca11d..f1eb8524d4e9c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -111,6 +111,7 @@ def setup(app): "tqdm", "tensorizer", "pynvml", + "outlines", ] for mock_target in autodoc_mock_imports: diff --git a/tests/entrypoints/openai/conftest.py b/tests/entrypoints/conftest.py similarity index 83% rename from tests/entrypoints/openai/conftest.py rename to tests/entrypoints/conftest.py index 0837644f26bde..e7ef5637c8ccb 100644 --- a/tests/entrypoints/openai/conftest.py +++ b/tests/entrypoints/conftest.py @@ -1,6 +1,26 @@ import pytest +@pytest.fixture +def sample_prompts(): + return [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + +@pytest.fixture +def sample_token_ids(): + return [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], + ] + + @pytest.fixture def sample_regex(): return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" @@ -66,4 +86,4 @@ def sample_sql_statements(): table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") \ No newline at end of file +""") diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py new file mode 100644 index 0000000000000..873e115421257 --- /dev/null +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -0,0 +1,142 @@ +import json +import re +import weakref + +import jsonschema +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams + +from ...conftest import cleanup + +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, max_model_len=1024) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + del llm + cleanup() + + +@pytest.mark.skip_global_cleanup +def test_guided_regex(sample_regex, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + ) + outputs = llm.generate( + prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert 
isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert re.fullmatch(sample_regex, generated_text) is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_json_completion(sample_json_schema, llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + ) + outputs = llm.generate( + prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_json=sample_json_schema)) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=sample_json_schema) + + +@pytest.mark.skip_global_cleanup +def test_guided_choice_completion(sample_guided_choice, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + ) + outputs = llm.generate( + prompts="The best language for type-safe systems programming is ", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_choice=sample_guided_choice)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + assert generated_text in sample_guided_choice + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_grammar(sample_sql_statements, llm): + + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=1000, + ) + outputs = llm.generate( + prompts=("Generate a sql state that select col_1 from " + "table_1 where it is equals to 1"), + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_grammar=sample_sql_statements)) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_statements) + parser.parse(generated_text) + + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") + + assert generated_text.strip() == ground_truth + + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1d..262cba79e5712 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -10,6 +10,9 @@ parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + GuidedDecodingRequest, get_local_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions from vllm.outputs import EmbeddingRequestOutput, RequestOutput from 
vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -262,6 +265,8 @@ def generate( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -303,6 +308,14 @@ def generate( else: inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) + if isinstance(guided_options_request, dict): + if len(guided_options_request) > 1: + raise ValueError( + "You can only use one guided decoding but multiple is " + f"specified: {guided_options_request}") + guided_options_request = GuidedDecodingRequest( + **guided_options_request) + if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() @@ -311,7 +324,8 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + guided_options=guided_options_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -508,6 +522,7 @@ def _validate_and_add_requests( Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], + guided_options: Optional[GuidedDecodingRequest] = None, ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -523,6 +538,15 @@ def _validate_and_add_requests( raise ValueError("The lengths of prompts and lora_request " "must be the same.") + if isinstance(params, list): + params = [ + self._add_guided_processor(param, guided_options) + if isinstance(param, SamplingParams) else param + for param in params + ] + elif isinstance(params, SamplingParams): + params = self._add_guided_processor(params, guided_options) + # Add requests to the engine. 
for i, request_inputs in enumerate(inputs): self._add_request( @@ -548,6 +572,24 @@ def _add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) + def _add_guided_processor( + self, + params: SamplingParams, + guided_options: Optional[GuidedDecodingRequest] = None): + if guided_options: + if guided_options.guided_decoding_backend is None: + decoding_config = self.llm_engine.get_decoding_config() + guided_options.guided_decoding_backend = ( + decoding_config.guided_decoding_backend) + guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa + guided_options.guided_decoding_backend, guided_options, + self.get_tokenizer()) + if guided_logits_processor: + if params.logits_processors is None: + params.logits_processors = [] + params.logits_processors.append(guided_logits_processor) + return params + def _run_engine( self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3b35ae1ebd705..76318a1271229 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,6 +1,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time +from argparse import Namespace from typing import Any, Dict, List, Literal, Optional, Union import torch @@ -14,6 +15,23 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid +# torch is mocked during docs generation, +# so we have to provide the values as literals +_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if isinstance(torch, _MockModule): + _LONG_INFO = _MOCK_LONG_INFO + else: + _LONG_INFO = torch.iinfo(torch.long) +except ModuleNotFoundError: + _LONG_INFO = torch.iinfo(torch.long) + +assert _LONG_INFO.min == _MOCK_LONG_INFO.min +assert _LONG_INFO.max == _MOCK_LONG_INFO.max + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields @@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel): n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None - seed: Optional[int] = Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None @@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = 16 n: int = 1 presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = Field(None, - ge=torch.iinfo(torch.long).min, - le=torch.iinfo(torch.long).max) + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 50aa3ec379f4a..4a2476dd6314d 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -3,9 +3,10 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, CompletionRequest) -from 
vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( - get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor( request, tokenizer) if guided_decoding_backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_lm_format_enforcer_guided_decoding_logits_processor) return await get_lm_format_enforcer_guided_decoding_logits_processor( request, tokenizer) @@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor( "Must be one of 'outlines, 'lm-format-enforcer'") +def get_local_guided_decoding_logits_processor( + guided_decoding_backend: str, guided_options: GuidedDecodingRequest, + tokenizer) -> Optional[LogitsProcessor]: + # request = _adapt_request_for_tool_use(request) + + if guided_decoding_backend == 'outlines': + return get_local_outlines_guided_decoding_logits_processor( + guided_options, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_options, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_decoding_backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") + + def _adapt_request_for_tool_use(request: Union[CompletionRequest, ChatCompletionRequest]): # the legacy completion API does not support tool use diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py new file mode 100644 index 0000000000000..3082ac1510ccc --- /dev/null +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, TypedDict, Union + +from pydantic import BaseModel + + +class LLMGuidedOptions(TypedDict, total=False): + guided_json: Union[Dict, BaseModel, str] + guided_regex: str + guided_choice: List[str] + guided_grammar: str + guided_decoding_backend: str + guided_whitespace_pattern: str + guided_json_object: bool + + +@dataclass +class GuidedDecodingRequest: + """One of the fields will be used to retrieve the logit processor.""" + guided_json: Optional[Union[Dict, BaseModel, str]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None + guided_grammar: Optional[str] = None + guided_decoding_backend: Optional[str] = None + guided_whitespace_pattern: Optional[str] = None + guided_json_object: Optional[bool] = None + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.guided_json is not None, self.guided_regex is not None, + self.guided_choice is not None, self.guided_grammar is not None, + self.guided_json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py 
b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index d0a5ca5592f9d..b2188c9cbc2bb 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -12,7 +12,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_local_outlines_guided_decoding_logits_processor, get_outlines_guided_decoding_logits_processor) from vllm.sampling_params import LogitsProcessor @@ -54,6 +57,42 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( return logits_processor +def get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_options: GuidedDecodingRequest, + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if guided_options.guided_json: + schema = _normalize_json_schema_object(guided_options.guided_json) + character_level_parser = JsonSchemaParser(schema) + elif guided_options.guided_choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in guided_options.guided_choice]) + elif guided_options.guided_regex: + character_level_parser = RegexParser(guided_options.guided_regex) + elif guided_options.guided_grammar: + # CFG grammar not supported by LMFE, revert to outlines + return get_local_outlines_guided_decoding_logits_processor( + guided_options, tokenizer) + elif guided_options.guided_json_object: + # None means any json object + character_level_parser = JsonSchemaParser(None) + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: if isinstance(schema, str): return json_loads(schema) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 721f7e0530cb7..bc62224dabecf 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -10,6 +10,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) @@ -77,8 +79,27 @@ async def get_outlines_guided_decoding_logits_processor( mode, request.guided_whitespace_pattern) +def get_local_outlines_guided_decoding_logits_processor( + guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. 
+ We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + guide, mode = _get_guide_and_mode(guided_options) + if not guide or not mode: + return None + + return _get_logits_processor(guide, tokenizer, mode, + guided_options.guided_whitespace_pattern) + + def _get_guide_and_mode( - request: Union[CompletionRequest, ChatCompletionRequest] + request: Union[CompletionRequest, ChatCompletionRequest, + GuidedDecodingRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: @@ -102,7 +123,8 @@ def _get_guide_and_mode( return choices_regex, GuidedDecodingMode.CHOICE elif request.guided_grammar: return request.guided_grammar, GuidedDecodingMode.GRAMMAR - elif (request.response_format is not None + elif (not isinstance(request, GuidedDecodingRequest) + and request.response_format is not None and request.response_format.type == "json_object"): return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR else: From 9fadc7b7a03f798036d0e8710587870e13bae759 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 3 Aug 2024 22:03:46 -0700 Subject: [PATCH 18/36] [misc] add zmq in collect env (#7119) --- collect_env.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/collect_env.py b/collect_env.py index 083cb768f5399..244e4ddd5aed5 100644 --- a/collect_env.py +++ b/collect_env.py @@ -65,6 +65,7 @@ "optree", "nccl", "transformers", + "zmq", } DEFAULT_PIP_PATTERNS = { @@ -77,6 +78,7 @@ "onnx", "nccl", "transformers", + "zmq", } From 83c644fe7ecee05d3ebe5057acb6e008d7e81eb8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 4 Aug 2024 00:22:19 -0700 Subject: [PATCH 19/36] [core][misc] simply output processing with shortcut code path (#7117) --- vllm/engine/output_processor/single_step.py | 39 ++++++++++++++++----- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 59eb4bc439d1f..4a46c93f84256 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -81,6 +81,29 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: + sampling_params = seq_group.sampling_params + if sampling_params.n == 1 and not sampling_params.use_beam_search: + # only have one output sample + sample = outputs.samples[0] + # only have one sequence + seq = seq_group.seqs[0] + seq.append_token_id(sample.output_token, sample.logprobs) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) + return + # Process samples samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) @@ -127,20 +150,20 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, child_seqs.append((parent, parent)) for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize and self.detokenizer: + if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( - seq, seq_group.sampling_params) + seq, sampling_params) else: new_char_count = 0 self.stop_checker.maybe_stop_sequence( seq, 
new_char_count, - seq_group.sampling_params, + sampling_params, lora_req=seq_group.lora_request, ) # Non-beam search case - if not seq_group.sampling_params.use_beam_search: + if not sampling_params.use_beam_search: # For newly created child sequences, add them to the sequence group # and fork them in block manager if they are not finished. for seq, parent in child_seqs: @@ -164,8 +187,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Select the child sequences to keep in the sequence group. selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] - beam_width = seq_group.sampling_params.best_of - length_penalty = seq_group.sampling_params.length_penalty + beam_width = sampling_params.best_of + length_penalty = sampling_params.length_penalty # Select the newly finished sequences with the highest scores # to replace existing finished sequences. @@ -219,8 +242,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, best_running_seq = running_child_seqs[0][0] current_worst_seq = all_finished_seqs[beam_width - 1][0] stop_beam_search = self._check_beam_search_early_stopping( - seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) + sampling_params.early_stopping, sampling_params, + best_running_seq, current_worst_seq) if stop_beam_search: # Stop the beam search and remove all the running sequences from From 179a6a36f2a585df49ce9c26701b1b9d894bd00e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 4 Aug 2024 16:12:41 +0800 Subject: [PATCH 20/36] [Model]Refactor MiniCPMV (#7020) Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.rst | 2 +- .../models/idefics2_vision_model.py | 296 +++++ vllm/model_executor/models/minicpmv.py | 1023 ++++++++++------- vllm/model_executor/models/na_vit.py | 2 +- 4 files changed, 937 insertions(+), 386 deletions(-) create mode 100644 vllm/model_executor/models/idefics2_vision_model.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index a1ea366b82b04..fd5d154006ae7 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -220,7 +220,7 @@ Vision Language Models - Phi-3-Vision - :code:`microsoft/Phi-3-vision-128k-instruct`, etc. - - * - :code:`MiniCPM-V` + * - :code:`MiniCPMV` - MiniCPM-V - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py new file mode 100644 index 0000000000000..cc448ed28d2dc --- /dev/null +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -0,0 +1,296 @@ +# coding=utf-8 + +# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Idefics2 model.""" + +from typing import Optional + +import torch +from torch import nn +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2Config, Idefics2VisionConfig) +from xformers import ops as xops + +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings + ` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the + need to resize them to the same fixed size. In particular, we start from the + original pre-trained SigLIP model(which uses images of fixed-size square + images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, + max_nb_patches_h * max_nb_patches_w), + fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics2VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + 
self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + qkv, _ = self.qkv_proj( + hidden_states + ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim + query_states, key_states, value_states = qkv.chunk(3, dim=-1) + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + # see: https://facebookresearch.github.io/xformers/components/ops.html + out = xops.memory_efficient_attention_forward( + query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale, + ) + out = out.view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) + return attn_output + + +class Idefics2VisionMLP(nn.Module): + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + + def __init__(self, config: Idefics2Config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics2VisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Idefics2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention + layers. 
Each layer is a + [`Idefics2EncoderLayer`]. + + Args: + config: Idefics2Config + """ + + def __init__(self, config: Idefics2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Idefics2EncoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (torch.Tensor): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectorsthan the model's + internal embedding lookup matrix. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs + return hidden_states + + +class Idefics2VisionTransformer(nn.Module): + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + embed_dim = config.hidden_size + self.config = config + self.embeddings = Idefics2VisionEmbeddings(config) + self.encoder = Idefics2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + ) -> torch.tensor: + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask) + encoder_outputs = self.encoder(hidden_states) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2a7fe7ba0ebac..095bb49f6ba76 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -24,7 +24,8 @@ import math import re from functools import partial -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict, + Union) import numpy as np import torch @@ -38,11 +39,14 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.llama import LlamaModel @@ -54,12 +58,45 @@ cached_get_tokenizer) from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData +from .idefics2_vision_model import Idefics2VisionTransformer + +logger = init_logger(__name__) + _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", "llm.model": "llm", } +class MiniCPMVImagePixelInputs(TypedDict): + pixel_values: List[torch.Tensor] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that the image size may vary, so we pass it as a list + instead of a batched tensor. 
+ """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + tgt_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +MiniCPMVImageInputs = MiniCPMVImagePixelInputs + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # abs_pos: L, C # tgt_size: (H, W) @@ -68,23 +105,25 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # tgt_size = int(math.sqrt(tgt_size)) dtype = abs_pos.dtype - return F.interpolate( + return (F.interpolate( abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), size=(tgt_size[0], tgt_size[1]), mode="bicubic", align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed(embed_dim: int, - grid_size: Union[int, Tuple[int, int]], - cls_token: bool = False, - version: Tuple[int, int] = (2, 0)): +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +): """ grid_size: int of the grid height and width return: - pos_embed: [grid_size*grid_size, embed_dim] or + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ if isinstance(grid_size, int): @@ -109,7 +148,7 @@ def get_2d_sincos_pos_embed(embed_dim: int, def get_2d_sincos_pos_embed_from_grid(embed_dim: int, - grid: Union[int, Tuple[int, int]], + grid: np.ndarray, version: Tuple[int, int] = (2, 0)): assert embed_dim % 2 == 0 @@ -127,7 +166,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, def get_1d_sincos_pos_embed_from_grid(embed_dim: int, - pos: int, + pos: np.ndarray, version: Tuple[int, int] = (2, 0)): """ embed_dim: output dimension for each position @@ -136,24 +175,24 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2. - omega = 1. 
/ 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) if version == (2, 0): pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) else: - out = np.einsum('hw,d->hwd', pos, omega) # (H, W, D/2), outer product + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product emb_sin = np.sin(out) # (H, W, D/2) emb_cos = np.cos(out) # (H, W, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) return emb -class Resampler(nn.Module): +class BaseResampler(nn.Module): """ A 2D perceiver-resampler network with one cross attention layers by (grid_size**2) learnable queries and 2d sincos pos_emb @@ -161,89 +200,151 @@ class Resampler(nn.Module): A tensor with the shape of (grid_size**2, embed_dim) """ - default_norm_layer = partial(nn.LayerNorm, eps=1e-6) - - def __init__(self, - num_queries: int, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: nn.Module = default_norm_layer, - adaptive: bool = False, - max_size: Tuple[int, int] = (70, 70), - version: Tuple[int, int] = (2, 0)): + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + ) -> None: super().__init__() - self.version = version - if self.version == (2, 0): - self.num_queries = grid_size**2 - else: - self.num_queries = num_queries - self.max_size = max_size + self.num_queries = num_queries self.embed_dim = embed_dim self.num_heads = num_heads - self.adaptive = adaptive self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) - trunc_normal_(self.query, std=.02) + trunc_normal_(self.query, std=0.02) if kv_dim is not None and kv_dim != embed_dim: - self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) + self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) else: - self.kv_proj = nn.Identity() + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( + nn.Identity()(*args, **kwargs), + None, + ) self.attn = nn.MultiheadAttention(embed_dim, num_heads) self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) - self.ln_post = norm_layer(embed_dim) self.proj = nn.Parameter( (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) - if self.version == (2, 0): - self.pos_embed = nn.Parameter( - torch.from_numpy( - get_2d_sincos_pos_embed( - embed_dim, grid_size, - version=self.version)).float()).requires_grad_(False) + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + ) -> None: + super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, + norm_layer) + + self.adaptive = adaptive + + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 
0)) + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).float()).requires_grad_(False) + + self.apply(self._init_weights) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ): + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + dtype=x.dtype) else: - self._set_2d_pos_cache(self.max_size) + pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + + x = self.ln_post(x) + x = x @ self.proj + return x + + +class Resampler2_5(BaseResampler): + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + ) -> None: + super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer) + + self.max_size = max_size + self._set_2d_pos_cache(self.max_size) self.apply(self._init_weights) def _set_2d_pos_cache(self, max_size: Tuple[int, int], - device: torch.types.Device = 'cpu'): - pos_embed = torch.from_numpy( - get_2d_sincos_pos_embed(self.embed_dim, - max_size, - version=self.version)).float().to(device) + device: torch.types.Device = "cpu") -> None: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + max_size, + version=(2, 5)) + pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, - device: torch.types.Device): - max_h = torch.max(tgt_sizes[:, 0]) - max_w = torch.max(tgt_sizes[:, 1]) + device: torch.types.Device) -> None: + max_h = tgt_sizes[:, 0].max().item() + max_w = tgt_sizes[:, 1].max().item() + assert isinstance(max_h, int) and isinstance(max_w, int) + if max_h > self.max_size[0] or max_w > self.max_size[1]: - self.max_size = [ + self.max_size = ( max(max_h, self.max_size[0]), - max(max_w, self.max_size[1]) - ] + max(max_w, self.max_size[1]), + ) self._set_2d_pos_cache(self.max_size, device) - def _init_weights(self, m: nn.Module): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward_2_5(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, + tgt_sizes: torch.Tensor) -> torch.Tensor: assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -254,25 +355,25 @@ def forward_2_5(self, self._adjust_pos_cache(tgt_sizes, device=device) - max_patch_len = torch.max(patch_len) + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device) pos_embed = [] for i in range(bs): - tgt_h, tgt_w = tgt_sizes[i] + tgt_h, tgt_w = tgt_sizes[i].tolist() pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( (tgt_h * tgt_w, -1)).to(dtype)) # patches * D key_padding_mask[i, patch_len[i]:] = True - pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, 
padding_value=0.0).permute( 1, 0, 2) # BLD => L * B * D - - x = self.kv_proj(x) # B * L * D + x, _ = self.kv_proj(x) # B * L * D x = self.ln_kv(x).permute(1, 0, 2) # L * B * D q = self.ln_q(self.query) # Q * D @@ -281,7 +382,8 @@ def forward_2_5(self, self._repeat(q, bs), # Q * B * D x + pos_embed, # L * B * D + L * B * D x, - key_padding_mask=key_padding_mask)[0] + key_padding_mask=key_padding_mask, + )[0] # out: Q * B * D x = out.permute(1, 0, 2) # B * Q * D @@ -289,45 +391,6 @@ def forward_2_5(self, x = x @ self.proj return x - def forward_2(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None): - if self.adaptive: - pos_embed = torch.Tensor( - get_2d_sincos_pos_embed(self.embed_dim, - tgt_sizes)).float().to(device=x.device, - dtype=x.dtype) - else: - pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) - - x = self.kv_proj(x) - x = self.ln_kv(x).permute(1, 0, 2) - - N = x.shape[1] - q = self.ln_q(self.query) - out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask)[0] - x = out.permute(1, 0, 2) - - x = self.ln_post(x) - x = x @ self.proj - return x - - def forward(self, - x: torch.Tensor, - tgt_sizes: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None): - if self.version == (2, 0): - return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask) - else: - return self.forward_2_5(x, tgt_sizes=tgt_sizes) - - def _repeat(self, query, N: int): - return query.unsqueeze(1).repeat(1, N, 1) - def get_max_minicpmv_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(PretrainedConfig) @@ -348,10 +411,7 @@ def dummy_image_for_minicpmv(hf_config: PretrainedConfig): def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(PretrainedConfig) - # image_feature_size = get_max_minicpmv_image_tokens(ctx) - seq_data = dummy_seq_data_for_minicpmv(seq_len) - mm_data = dummy_image_for_minicpmv(hf_config) return seq_data, mm_data @@ -376,25 +436,36 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): pattern = "(./)" image = multi_modal_data["image"] image_tags = re.findall(pattern, prompt) - assert len(image_tags) <= 1 - text_chunks = prompt.split(pattern) - new_prompt = text_chunks[0] \ - + image_processor.get_slice_image_placeholder(image.size) \ - + text_chunks[1] - new_token_ids = tokenizer.encode(new_prompt) - - llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + if len(image_tags) == 0: + new_token_ids = token_ids + new_prompt = prompt + else: + if len(image_tags) > 1: + logger.warning("Multiple image input is not supported yet, " + "so any extra image tokens will be treated " + "as plain text.") + + text_chunks = prompt.split(pattern) + new_prompt = (text_chunks[0] + + image_processor.get_slice_image_placeholder(image.size) + + "".join(text_chunks[1:])) + + new_token_ids = tokenizer.encode(new_prompt) + + llm_inputs = LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) return llm_inputs -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) -@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) -class MiniCPMV(nn.Module, SupportsVision): +class MiniCPMVBaseModel(nn.Module, SupportsVision): + """ + The abstract 
class of MiniCPMV can only be inherited, but cannot be + instantiated. + """ def __init__( self, @@ -419,8 +490,8 @@ def __init__( self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() self.vpm.to(dtype=param_dtype) - self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \ - else self.vpm.embeddings.embed_dim + self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else + self.vpm.embeddings.embed_dim) self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) self.resampler.to(device="cuda", dtype=param_dtype) @@ -430,248 +501,100 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - def init_llm(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): - if self.version == (2, 0): - return MiniCPMModel(config, - cache_config=cache_config, - quant_config=quant_config) - elif self.version == (2, 5): - return LlamaModel(config, - cache_config=cache_config, - quant_config=quant_config) - else: - return Qwen2Model(config, - cache_config=cache_config, - quant_config=quant_config) - - def init_vision_module(self): - if self.version == (2, 0): - try: - import timm - except ImportError: - raise ImportError( - 'Please install timm==0.9.10') from ImportError - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float16) - model = timm.create_model('vit_so400m_patch14_siglip_384.webli', - pretrained=False, - num_classes=0, - dynamic_img_size=True, - dynamic_img_pad=True) - torch.set_default_dtype(default_dtype) - if isinstance(model, timm.models.VisionTransformer - ) and model.attn_pool is not None: - model.attn_pool = torch.nn.Identity() - - if self.config.drop_vision_last_layer: - model.blocks = model.blocks[:-1] - elif self.version == (2, 5): - from transformers.models.idefics2.modeling_idefics2 import ( - Idefics2VisionTransformer) - model = Idefics2VisionTransformer(self.config.vision_config) - if self.config.drop_vision_last_layer: - model.encoder.layers = model.encoder.layers[:-1] - else: - from vllm.model_executor.models.na_vit import ( - SiglipVisionTransformer) - if self.config._attn_implementation == 'flash_attention_2': - self.config.vision_config._attn_implementation \ - = 'flash_attention_2' - else: - # not support sdpa - self.config.vision_config._attn_implementation = 'eager' - model = SiglipVisionTransformer(self.config.vision_config) - if self.config.drop_vision_last_layer: - model.encoder.layers = model.encoder.layers[:-1] - return model - - def init_resampler(self, embed_dim: int, vision_dim: int): - default_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float16) - if self.version == (2, 0): - resampler = Resampler(grid_size=int( - math.sqrt(self.config.query_num)), - num_queries=None, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - adaptive=True, - version=self.version) + def get_embedding( + self, + input_ids: torch.Tensor, + image_inputs: Optional[MiniCPMVImageInputs], + ) -> Tuple[torch.Tensor, torch.Tensor]: + vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) + if hasattr(self.config, "scale_emb"): + vlm_embedding *= self.config.scale_emb + + if image_inputs is None: # No image + vision_hidden_states = torch.tensor([], device=input_ids.device) else: - resampler = Resampler(num_queries=self.config.query_num, - grid_size=None, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - 
kv_dim=vision_dim, - adaptive=True, - version=self.version) - torch.set_default_dtype(default_dtype) - return resampler + vision_hidden_states = self.get_vision_hidden_states(image_inputs) + + # See NOTE in _parse_and_validate_inputs + image_bounds = image_inputs["image_bounds"] + if len(image_bounds) > 0: + image_indices = torch.stack([ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ]).to(vlm_embedding.device) + vlm_embedding.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, + vlm_embedding.shape[-1]), + vision_hidden_states.view(-1, + vision_hidden_states.shape[-1]), + ) - def get_vision_embedding(self, - pixel_values: List[List[torch.Tensor]], - patch_attn_mask: Optional[torch.Tensor] = None, - tgt_sizes: Optional[torch.Tensor] = None, - version: Tuple[int, int] = (2, 0)): - if version == (2, 0): - res = [] - dtype = self.vpm.pos_embed.data.dtype - for pixel_value in pixel_values: - # V2.0 start - H, W = pixel_value[0].shape[-2:] - tgt_size = (math.ceil(H / self.vpm.patch_embed.patch_size[0]), - math.ceil(W / self.vpm.patch_embed.patch_size[0])) - # V2.0 end - vision_embedding = self.vpm.forward_features( - pixel_value.unsqueeze(0).type(dtype)) - if hasattr(self.vpm, 'num_prefix_tokens' - ) and self.vpm.num_prefix_tokens > 0: - vision_embedding = vision_embedding[:, self.vpm. - num_prefix_tokens:] - res.append(self.resampler(vision_embedding, tgt_size)) - return torch.vstack(res) - elif version == (2, 5): - vision_embedding = self.vpm( - pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask).last_hidden_state - vision_embedding = self.resampler(vision_embedding, tgt_sizes) - else: - vision_embedding = self.vpm(pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state + return vlm_embedding, vision_hidden_states - def get_image_bounds(self, input_ids: torch.Tensor): + def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor: tokenizer = cached_get_tokenizer(self.config._name_or_path, trust_remote_code=True) - if not hasattr(tokenizer, "slice_start_id"): - start_cond = input_ids == tokenizer.im_start_id - end_cond = input_ids == tokenizer.im_end_id - else: - start_cond = (input_ids == tokenizer.im_start_id) | ( - input_ids == tokenizer.slice_start_id) - end_cond = (input_ids == tokenizer.im_end_id) | ( - input_ids == tokenizer.slice_end_id) + start_cond = input_ids == tokenizer.im_start_id + end_cond = input_ids == tokenizer.im_end_id + if hasattr(tokenizer, "slice_start_id"): + start_cond |= (input_ids == tokenizer.slice_start_id) + end_cond |= (input_ids == tokenizer.slice_end_id) - image_start_tokens = torch.where(start_cond)[0] + image_start_tokens, = torch.where(start_cond) image_start_tokens += 1 - image_end_tokens = torch.where(end_cond)[0] + image_end_tokens, = torch.where(end_cond) valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + if valid_image_nums == 0: - return [] - image_bound = torch.hstack([ + return torch.zeros((0, 2), device=input_ids.device) + + return torch.hstack([ image_start_tokens[:valid_image_nums].unsqueeze(-1), image_end_tokens[:valid_image_nums].unsqueeze(-1), ]) - return image_bound - - def get_vision_hidden_states(self, data: Dict[str, - Union[List[torch.Tensor], - torch.Tensor]]): - if "vision_hidden_states" not in data: - pixel_values = data["pixel_values"] - tgt_sizes = data["tgt_sizes"] - vision_hidden_states = [] - if self.version == (2, 0): - if pixel_values is not None and len(pixel_values) > 0: - 
vision_hidden_states = self.get_vision_embedding( - pixel_values) - else: - vision_hidden_states = torch.tensor([]).to( - data["input_ids"].device) - else: - device = self.vpm.embeddings.position_embedding.weight.device - dtype = self.vpm.embeddings.position_embedding.weight.dtype - all_pixel_values = [ - i.flatten(end_dim=1).permute(1, 0) for i in pixel_values - ] - if all_pixel_values: - tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) - max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) - all_pixel_values = torch.nn.utils.rnn.pad_sequence( - all_pixel_values, batch_first=True, padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute( - 0, 2, 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=device) - if self.version == (2, 5): - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask - ).last_hidden_state - else: - for i in range(B): - patch_attn_mask[i, 0, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state - - vision_hidden_states = self.resampler( - vision_embedding, tgt_sizes) - - else: # no image - dummy_feature = [] - vision_hidden_states = dummy_feature - else: - vision_hidden_states = data["vision_hidden_states"] - - return vision_hidden_states - - def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], - torch.Tensor]]): - input_ids = data["input_ids"] - - vision_hidden_states = self.get_vision_hidden_states(data) - if vision_hidden_states is not None and len(vision_hidden_states) > 0: - image_bounds = self.get_image_bounds(input_ids) - else: - image_bounds = [] - - if hasattr(self.config, 'scale_emb'): - vlm_embedding = self.llm.embed_tokens( - input_ids) * self.config.scale_emb - else: - vlm_embedding = self.llm.embed_tokens(input_ids) - vision_hidden_states = [ - i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i - for i in vision_hidden_states - ] - - if len(vision_hidden_states) > 0 and len(image_bounds) > 0: - vision_hidden_states = torch.cat(vision_hidden_states, dim=0) - image_indices = torch.stack([ - torch.arange(r[0], r[1], dtype=torch.long) - for r in image_bounds - ]).to(vlm_embedding.device) - vlm_embedding.scatter_( - 0, - image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), - vision_hidden_states.view(-1, vision_hidden_states.shape[-1])) - return vlm_embedding, vision_hidden_states - - def process_multimodal_inputs(self, inputs: Dict[str, - Union[List[torch.Tensor], - torch.Tensor]]): - pixel_values = [] - tgt_sizes = [] - for b in range(len(inputs["pixel_values"])): - pixel_values += inputs["pixel_values"][b] - tgt_sizes += inputs["tgt_sizes"][b] - return { - "pixel_values": pixel_values, - "input_ids": inputs["input_ids"], - "tgt_sizes": tgt_sizes - } + def _parse_and_validate_inputs( + self, + input_ids: torch.Tensor, + **kwargs: object, + ) -> Optional[MiniCPMVImageInputs]: + pixel_values = kwargs.pop("pixel_values", []) + tgt_sizes = kwargs.pop("tgt_sizes", []) + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of target sizes. 
" + f"Got type: {type(tgt_sizes)}") + + if len(pixel_values) != len(tgt_sizes): + raise ValueError("Inconsistent batch lengths, found: " + f"{len(pixel_values)} vs. {len(tgt_sizes)}") + + pixel_values_flat: List[torch.Tensor] = [] + tgt_sizes_flat: List[torch.Tensor] = [] + for b in range(len(pixel_values)): + pixel_values_flat += pixel_values[b] + tgt_sizes_flat += tgt_sizes[b] + + # NOTE: Input IDs does not contain image tokens during memory profiling, + # so we allow it to be empty + if len(pixel_values_flat) != len(tgt_sizes_flat): + raise ValueError("Inconsistent flattened lengths, found: " + f"{len(pixel_values_flat)} vs. " + f"{len(tgt_sizes_flat)}") + + if len(pixel_values_flat) == 0: + return None + + return MiniCPMVImageInputs( + image_bounds=self._get_image_bounds(input_ids), + pixel_values=pixel_values_flat, + tgt_sizes=torch.stack(tgt_sizes_flat), + ) def forward( self, @@ -680,23 +603,20 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object, - ): - inputs = { - "pixel_values": kwargs.pop("pixel_values", []), - "input_ids": input_ids, - "tgt_sizes": kwargs.pop("tgt_sizes", None), - } - inputs = self.process_multimodal_inputs(inputs) - - vlm_embeddings, vision_hidden_states = self.get_embedding(inputs) - - output = self.llm(input_ids=None, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=vlm_embeddings) + **kwargs: Any, + ) -> torch.Tensor: + image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) + + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + + output = self.llm( + input_ids=None, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=vlm_embeddings, + ) return output def compute_logits(self, hidden_states: torch.Tensor, @@ -735,13 +655,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # the checkpoint. Skip them. continue use_default_weight_loading = False - if "vpm" in name or 'resampler' in name: - # We only do sharding for language model and - # not vision model for now. 
+ if self.is_default_weight_loading(name): use_default_weight_loading = True else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue param = params_dict[name.replace(weight_name, param_name)] @@ -755,3 +672,341 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + raise NotImplementedError + + def init_vision_module(self) -> nn.Module: + raise NotImplementedError + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + raise NotImplementedError + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + raise NotImplementedError + + def is_default_weight_loading(self, name: str) -> bool: + raise NotImplementedError + + +class MiniCPMV2(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + assert self.version == (2, 0) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + # TODO :refactor this vision model + try: + import timm + except ImportError: + raise ImportError("Please install timm==0.9.10") from ImportError + with set_default_torch_dtype(torch.float16): + model = timm.create_model( + "vit_so400m_patch14_siglip_384.webli", + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True, + ) + + if (isinstance(model, timm.models.VisionTransformer) + and model.attn_pool is not None): + model.attn_pool = torch.nn.Identity() + + if self.config.drop_vision_last_layer: + model.blocks = model.blocks[:-1] + + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2( + embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int(math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=True, + ) + + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + res = [] + dtype = self.vpm.pos_embed.data.dtype + for pixel_value in pixel_values: + H, W = pixel_value[0].shape[-2:] + tgt_size = ( + math.ceil(H / self.vpm.patch_embed.patch_size[0]), + math.ceil(W / self.vpm.patch_embed.patch_size[0]), + ) + vision_embedding = self.vpm.forward_features( + pixel_value.unsqueeze(0).type(dtype)) + if (hasattr(self.vpm, "num_prefix_tokens") + and self.vpm.num_prefix_tokens > 0): + vision_embedding = vision_embedding[:, self.vpm. 
+ num_prefix_tokens:] + res.append(self.resampler(vision_embedding, tgt_size)) + return torch.vstack(res) + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + + return self.get_vision_embedding(pixel_values) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name or "vpm" in name + + +class MiniCPMV2_5(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + assert self.version == (2, 5) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + model = Idefics2VisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm(pixel_values, + patch_attention_mask=patch_attn_mask) + vision_embedding = self.resampler(vision_embedding, tgt_sizes) + return vision_embedding + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, + 1).reshape(B, 3, -1, L) + + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=device) + for i in range(B): + patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + + return self.get_vision_embedding(all_pixel_values.type(dtype), + patch_attn_mask, tgt_sizes) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name + + +# NOTE: Currently, information about this model is unavailable. We are +# temporarily using `MiniCPMVQwen2` as it's name. The name may need +# to be modified in the future. 
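The `get_vision_hidden_states` implementation above pads a batch of variable-length patch sequences and builds a boolean patch mask before calling the vision tower. A self-contained sketch of that padding-and-masking step, using invented patch-grid sizes and an invented feature width:

import torch

# Hypothetical per-image patch grids (rows, cols); not taken from real images.
tgt_sizes = torch.tensor([[2, 3], [4, 4]], dtype=torch.int32)
patch_dim = 14 * 14 * 3  # flattened values per patch, illustrative only

# One (num_patches, patch_dim) tensor per image, with different lengths.
seqs = [torch.randn(int(r * c), patch_dim) for r, c in tgt_sizes.tolist()]

max_patches = int((tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item())
padded = torch.nn.utils.rnn.pad_sequence(seqs,
                                         batch_first=True,
                                         padding_value=0.0)  # (B, max_patches, D)

# Boolean mask marking real (non-padded) patches, mirroring patch_attn_mask.
mask = torch.zeros((len(seqs), 1, max_patches), dtype=torch.bool)
for i, (r, c) in enumerate(tgt_sizes.tolist()):
    mask[i, 0, :r * c] = True

print(padded.shape, mask.sum(dim=-1).flatten())  # (2, 16, 588), tensor([ 6, 16])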
+class MiniCPMVQwen2(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return Qwen2Model(config, + cache_config=cache_config, + quant_config=quant_config) + + def init_vision_module(self) -> nn.Module: + # A custom version of SiglipVisionTransformer, won't work with TP + from vllm.model_executor.models.na_vit import SiglipVisionTransformer + + if self.config._attn_implementation == "flash_attention_2": + self.config.vision_config._attn_implementation = "flash_attention_2" + else: + # not support sdpa + self.config.vision_config._attn_implementation = "eager" + model = SiglipVisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm( + pixel_values, + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ).last_hidden_state + return vision_embedding + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, + 1).reshape(B, 3, -1, L) + + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=device) + for i in range(B): + patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ).last_hidden_state + + return self.resampler(vision_embedding, tgt_sizes) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name or "vpm" in name + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) +@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) +class MiniCPMV(MiniCPMVBaseModel): + """ + Different versions of MiniCPMV use different visual encoders and LLMs, + which is not conducive to the current integration logic of LoRA and + bitsandbytes in vLLM. Therefore, it is necessary to separate them. 
+ """ + + def __new__( + cls, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + if not hasattr(config, "version"): + if config.hidden_size == 2304 and config.query_num == 64: + version = (2, 0) + else: + version = (2, 5) + else: + version = str(config.version).split(".") + version = tuple([int(x) for x in version]) + # Dispatch class based on version + if version == (2, 0): + instance_class = MiniCPMV2 + elif version == (2, 5): + instance_class = MiniCPMV2_5 + else: + instance_class = MiniCPMVQwen2 + return instance_class(config, multimodal_config, cache_config, + quant_config) diff --git a/vllm/model_executor/models/na_vit.py b/vllm/model_executor/models/na_vit.py index 871e4128b66e1..1d6f26f0d4fb5 100644 --- a/vllm/model_executor/models/na_vit.py +++ b/vllm/model_executor/models/na_vit.py @@ -100,7 +100,7 @@ def _get_unpad_data(attention_mask): indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, From b1c9aa3daa7dcd981f0f77231b46883624b72dd0 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Sun, 4 Aug 2024 16:13:18 +0200 Subject: [PATCH 21/36] [Bugfix] [SpecDecode] Default speculative_draft_tensor_parallel_size to 1 when using MLPSpeculator (#7105) Signed-off-by: Thomas Parnell --- vllm/config.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 028f4eed8f4a2..0524514f6633a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1068,7 +1068,7 @@ def maybe_create_spec_config( draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( target_parallel_config, - speculative_draft_tensor_parallel_size)) + speculative_draft_tensor_parallel_size, draft_hf_config)) if num_speculative_tokens is None: raise ValueError( @@ -1136,15 +1136,23 @@ def _maybe_override_draft_max_model_len( @staticmethod def create_draft_parallel_config( target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int] + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, ) -> ParallelConfig: """Create a parallel config for use by the draft worker. This is mostly a copy of the target parallel config, except the tp_size. 
""" if speculative_draft_tensor_parallel_size is None: - speculative_draft_tensor_parallel_size = \ - target_parallel_config.tensor_parallel_size + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "MLPSpeculator cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1") + else: + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size elif speculative_draft_tensor_parallel_size != 1: # TODO(wooyeon): allow tp values larger than 1 raise ValueError( From 16a1cc9bb2b4bba82d78f329e5a89b44a5523ac8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 4 Aug 2024 11:31:51 -0700 Subject: [PATCH 22/36] [misc][distributed] improve libcudart.so finding (#7127) --- .../device_communicators/cuda_wrapper.py | 44 +++++++++---------- .../custom_all_reduce_utils.py | 4 +- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index 5cac3c1d57bca..9c7f41a1f9d62 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -4,9 +4,6 @@ """ import ctypes -import glob -import os -import sys from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -36,24 +33,25 @@ class Function: argtypes: List[Any] -def get_pytorch_default_cudart_library_path() -> str: - # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa - lib_folder = "cuda_runtime" - lib_name = "libcudart.so.*[0-9]" - lib_path = None - for path in sys.path: - nvidia_path = os.path.join(path, "nvidia") - if not os.path.exists(nvidia_path): - continue - candidate_lib_paths = glob.glob( - os.path.join(nvidia_path, lib_folder, "lib", lib_name)) - if candidate_lib_paths and not lib_path: - lib_path = candidate_lib_paths[0] - if lib_path: - break - if not lib_path: - raise ValueError(f"{lib_name} not found in the system path {sys.path}") - return lib_path +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. 
+ """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + start = line.index("/") + path = line[start:].strip() + return path class CudaRTLibrary: @@ -100,7 +98,9 @@ class CudaRTLibrary: def __init__(self, so_file: Optional[str] = None): if so_file is None: - so_file = get_pytorch_default_cudart_library_path() + so_file = find_loaded_library("libcudart.so") + assert so_file is not None, \ + "libcudart.so is not loaded in the current process" if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index d27d7ee9a2496..37ae94c671e33 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -145,6 +145,7 @@ def can_actually_p2p( p_tgt.start() p_src.join() p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 result: List[bool] = [] for src, tgt in zip(batch_src, batch_tgt): a = result_queue.get() @@ -221,7 +222,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: # wrap raised exception to provide more information raise RuntimeError( f"Error happened when batch testing " - f"peer-to-peer access from {batch_src} to {batch_tgt}") from e + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e result = pickle.loads(returned.stdout) for _i, _j, r in zip(batch_src, batch_tgt, result): cache[f"{_i}->{_j}"] = r From f80ab3521ca2aa74e121e26a27b87da7a1065939 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 5 Aug 2024 06:37:08 +0800 Subject: [PATCH 23/36] Clean up remaining Punica C information (#7027) --- .github/workflows/clang-format.yml | 6 ------ cmake/utils.cmake | 2 +- format.sh | 6 ------ vllm/config.py | 2 +- vllm/lora/layers.py | 2 +- 5 files changed, 3 insertions(+), 15 deletions(-) diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml index e9b6e28fa6bcb..79b85d8cad0d5 100644 --- a/.github/workflows/clang-format.yml +++ b/.github/workflows/clang-format.yml @@ -30,12 +30,6 @@ jobs: run: | EXCLUDES=( 'csrc/moe/topk_softmax_kernels.cu' - 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' - 'csrc/punica/bgmv/bgmv_config.h' - 'csrc/punica/bgmv/bgmv_impl.cuh' - 'csrc/punica/bgmv/vec_dtypes.cuh' - 'csrc/punica/punica_ops.cu' - 'csrc/punica/type_convert.h' ) find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 4869cad541135..69998b45be70a 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -181,7 +181,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) # # The torch cmake setup hardcodes the detected architecture flags in # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it - # can't modified on a per-target basis, e.g. for the `punica` extension. + # can't modified on a per-target basis. # So, all the `-gencode` flags need to be extracted and removed from # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. 
# Since it's not possible to use `target_compiler_options` for adding target diff --git a/format.sh b/format.sh index abc688c702aa6..baaebc811d405 100755 --- a/format.sh +++ b/format.sh @@ -242,12 +242,6 @@ echo 'vLLM isort: Done' # NOTE: Keep up to date with .github/workflows/clang-format.yml CLANG_FORMAT_EXCLUDES=( 'csrc/moe/topk_softmax_kernels.cu' - 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' - 'csrc/punica/bgmv/bgmv_config.h' - 'csrc/punica/bgmv/bgmv_impl.cuh' - 'csrc/punica/bgmv/vec_dtypes.cuh' - 'csrc/punica/punica_ops.cu' - 'csrc/punica/type_convert.h' ) # Format specified files with clang-format diff --git a/vllm/config.py b/vllm/config.py index 0524514f6633a..35945e34452d2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1304,7 +1304,7 @@ class LoRAConfig: long_lora_scaling_factors: Optional[Tuple[float]] = None def __post_init__(self): - # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + # TODO: Increase the range of rank possible_max_ranks = (8, 16, 32, 64) possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 42ec99e6ea2c8..d3978ff6f4ff1 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1073,7 +1073,7 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: - # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + # TODO: Verify if this condition can be relaxed if 32000 < self.base_layer.vocab_size > 128512: raise ValueError("When using LoRA, vocab size must be " "32000 >= vocab_size <= 128512") From 7b86e7c9cd6541abdf5d083b0a8a98ee667a91d1 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Mon, 5 Aug 2024 09:23:17 +0800 Subject: [PATCH 24/36] [Model] Add multi-image support for minicpmv (#7122) Co-authored-by: hezhihui Co-authored-by: Cyrus Leung --- tests/conftest.py | 5 +- tests/models/test_minicpmv.py | 146 ++++++++++++++++++++++--- vllm/model_executor/models/minicpmv.py | 56 ++++++---- vllm/multimodal/image.py | 2 +- 4 files changed, 172 insertions(+), 37 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 999ca60d07a4f..c7a349f1e9e2a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import os import sys from collections import UserList -from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar +from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union import pytest import torch @@ -508,7 +508,8 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, - images: Optional[List[Image.Image]] = None, + images: Optional[Union[List[Image.Image], + List[List[Image.Image]]]] = None, stop_token_ids: Optional[List[int]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, diff --git a/tests/models/test_minicpmv.py b/tests/models/test_minicpmv.py index c57f0f8c08548..c3b2a7bcbaafd 100644 --- a/tests/models/test_minicpmv.py +++ b/tests/models/test_minicpmv.py @@ -14,6 +14,18 @@ pytestmark = pytest.mark.vlm + +class NestedInputs(UserDict): + + def __init__(self, model_inputs: BatchFeature): + super().__init__({"model_inputs": model_inputs}) + + self.model_inputs = model_inputs + + def to(self, device: torch.types.Device): + return NestedInputs(self.model_inputs.to(device)) + + # The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = 
IMAGE_ASSETS.prompts({ "stop_sign": @@ -23,7 +35,7 @@ "cherry_blossom": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \ "(./)\nWhat is the season?<|eot_id|>" \ - "<|start_header_id|>assistant<|end_header_id|>\n\n" + "<|start_header_id|>assistant<|end_header_id|>\n\n", }) models = ["openbmb/MiniCPM-Llama3-V-2_5"] @@ -94,22 +106,10 @@ def run_test( ] with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): - - class NestedInputs(UserDict): - - def __init__(self, model_inputs: BatchFeature): - super().__init__({"model_inputs": model_inputs}) - - self.model_inputs = model_inputs - - def to(self, device: torch.types.Device): - return NestedInputs(self.model_inputs.to(device)) - hf_processor = hf_model.processor hf_model.processor = lambda **kw: NestedInputs( hf_processor(**kw) # type: ignore ) - hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, @@ -161,3 +161,123 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + + +HF_MULTIIMAGE_IMAGE_PROMPT = \ + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \ + "(./)\n(./)\n" \ + "Describe these images.<|eot_id|>" \ + "<|start_header_id|>assistant<|end_header_id|>\n\n" + + +def run_multi_image_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + images = [asset.pil_image for asset in image_assets] + + inputs_per_case = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
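For reference, the multi-image path this test exercises can be driven offline roughly as follows. This is a hedged sketch, assuming vLLM's dict-style prompt input with `multi_modal_data`; the image file names and sampling settings are placeholders, and the prompt text simply mirrors `HF_MULTIIMAGE_IMAGE_PROMPT` above:

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="openbmb/MiniCPM-Llama3-V-2_5",
    max_model_len=4096,
    trust_remote_code=True,
)

# Placeholder image files; replace with real paths.
images = [Image.open("image_a.jpg"), Image.open("image_b.jpg")]
prompt = ("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
          "(./)\n(./)\n"
          "Describe these images.<|eot_id|>"
          "<|start_header_id|>assistant<|end_header_id|>\n\n")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": images}},
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)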
+ + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + max_model_len=4096, + max_num_seqs=1, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + tokenizer = vllm_model.model.get_tokenizer() + stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id] + vllm_outputs_per_case = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + stop_token_ids=stop_token_ids) + for prompts, images in inputs_per_case + ] + + with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): + hf_processor = hf_model.processor + hf_model.processor = lambda **kw: NestedInputs( + hf_processor(**kw) # type: ignore + ) + hf_outputs_per_case = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + tokenizer=tokenizer) + for prompts, images in inputs_per_case + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=[ + trunc_hf_output(hf_output) for hf_output in hf_outputs + ], + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + run_multi_image_test( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 095bb49f6ba76..0388259595628 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -392,6 +392,20 @@ def forward(self, x: torch.Tensor, return x +def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: + version_float = getattr(config, "version", None) + + # The old configs do not include version number + # TODO: Remove this after the HF repos are updated + if version_float is None: + if config.hidden_size == 2304 and config.query_num == 64: + return (2, 0) + return (2, 5) + + version_str = str(version_float) + return tuple(int(x) for x in version_str.split(".")) + + def get_max_minicpmv_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(PretrainedConfig) return getattr(hf_config, "query_num", 64) @@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs - model_config = ctx.model_config - + version = get_version_by_config(model_config.hf_config) tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) + image_processor = cached_get_image_processor(model_config.tokenizer) + + def get_placeholder(image_size: Tuple[int, int], num_image: int): + if version == (2, 0) or version == (2, 5): + return image_processor. 
\ + get_slice_image_placeholder(image_size) + return image_processor. \ + get_slice_image_placeholder(image_size, num_image) prompt = llm_inputs.get("prompt") if prompt is None: token_ids = llm_inputs.get("prompt_token_ids") prompt = tokenizer.decode(token_ids) - image_processor = cached_get_image_processor(model_config.tokenizer) pattern = "(./)" - image = multi_modal_data["image"] + images = multi_modal_data["image"] + if isinstance(images, Image.Image): + images = [images] image_tags = re.findall(pattern, prompt) if len(image_tags) == 0: new_token_ids = token_ids new_prompt = prompt else: - if len(image_tags) > 1: - logger.warning("Multiple image input is not supported yet, " - "so any extra image tokens will be treated " - "as plain text.") - text_chunks = prompt.split(pattern) - new_prompt = (text_chunks[0] + - image_processor.get_slice_image_placeholder(image.size) + - "".join(text_chunks[1:])) - + new_prompt_chunks: List[str] = [] + for i in range(len(images)): + new_prompt_chunks += [ + text_chunks[i], + get_placeholder(images[i].size, i) + ] + new_prompt_chunks.append(text_chunks[-1]) + new_prompt = "".join(new_prompt_chunks) new_token_ids = tokenizer.encode(new_prompt) llm_inputs = LLMInputs( @@ -478,14 +499,7 @@ def __init__( self.config = config self.multimodal_config = multimodal_config - if not hasattr(self.config, "version"): - if self.config.hidden_size == 2304 and self.config.query_num == 64: - self.version = (2, 0) - else: - self.version = (2, 5) - else: - self.version = str(self.config.version).split(".") - self.version = tuple([int(x) for x in self.version]) + self.version = get_version_by_config(self.config) self.llm = self.init_llm(config, cache_config, quant_config) self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 3b37ce9149fb8..b6a3909e95632 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -113,7 +113,7 @@ def _get_hf_image_processor(self, model_config: ModelConfig): def _default_input_mapper(self, ctx: InputContext, data: object) -> MultiModalInputs: model_config = ctx.model_config - if isinstance(data, Image.Image): + if isinstance(data, (Image.Image, list)): image_processor = self._get_hf_image_processor(model_config) if image_processor is None: raise RuntimeError("No HuggingFace processor is available " From cc08fc7225616aeb6709a2e75e5ac47ace124985 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 5 Aug 2024 11:40:51 +0800 Subject: [PATCH 25/36] [Frontend] Reapply "Factor out code for running uvicorn" (#7095) --- vllm/entrypoints/api_server.py | 77 ++++++++++++++++-------- vllm/entrypoints/launcher.py | 46 +++++++++++++++ vllm/entrypoints/openai/api_server.py | 84 +++++++++------------------ 3 files changed, 125 insertions(+), 82 deletions(-) create mode 100644 vllm/entrypoints/launcher.py diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 66941442c8c9c..672382717d119 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -5,21 +5,23 @@ We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. 
""" - +import asyncio import json import ssl -from typing import AsyncGenerator +from argparse import Namespace +from typing import Any, AsyncGenerator, Optional -import uvicorn from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.launcher import serve_http from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.version import __version__ as VLLM_VERSION logger = init_logger("vllm.entrypoints.api_server") @@ -81,6 +83,53 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return JSONResponse(ret) +def build_app(args: Namespace) -> FastAPI: + global app + + app.root_path = args.root_path + return app + + +async def init_app( + args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, +) -> FastAPI: + app = build_app(args) + + global engine + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER)) + + return app + + +async def run_server(args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs: Any) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + app = await init_app(args, llm_engine) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + await shutdown_task + + if __name__ == "__main__": parser = FlexibleArgumentParser() parser.add_argument("--host", type=str, default=None) @@ -105,25 +154,5 @@ async def stream_results() -> AsyncGenerator[bytes, None]: parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.API_SERVER) - - app.root_path = args.root_path - logger.info("Available routes are:") - for route in app.routes: - if not hasattr(route, 'methods'): - continue - methods = ', '.join(route.methods) - logger.info("Route: %s, Methods: %s", route.path, methods) - - uvicorn.run(app, - host=args.host, - port=args.port, - log_level=args.log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs) + asyncio.run(run_server(args)) diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py new file mode 100644 index 0000000000000..00826762f76a1 --- /dev/null +++ b/vllm/entrypoints/launcher.py @@ -0,0 +1,46 @@ +import asyncio +import signal +from typing import Any + +import uvicorn +from fastapi import FastAPI + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): + logger.info("Available routes are:") + for route in app.routes: + methods = getattr(route, "methods", None) + path = getattr(route, "path", None) + + if methods is None or path is None: + continue 
+ + logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + + config = uvicorn.Config(app, **uvicorn_kwargs) + server = uvicorn.Server(config) + + loop = asyncio.get_running_loop() + + server_task = loop.create_task(server.serve()) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + + async def dummy_shutdown() -> None: + pass + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + return dummy_shutdown() + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + return server.shutdown() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e330ee81f7e44..a0190f3d66b10 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,15 +2,13 @@ import importlib import inspect import re -import signal +from argparse import Namespace from contextlib import asynccontextmanager from http import HTTPStatus from multiprocessing import Process from typing import AsyncIterator, Set -import fastapi -import uvicorn -from fastapi import APIRouter, Request +from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -22,6 +20,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.protocol import AsyncEngineClient +from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block @@ -71,7 +70,7 @@ def model_is_embedding(model_name: str) -> bool: @asynccontextmanager -async def lifespan(app: fastapi.FastAPI): +async def lifespan(app: FastAPI): async def _force_log(): while True: @@ -135,7 +134,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: router = APIRouter() -def mount_metrics(app: fastapi.FastAPI): +def mount_metrics(app: FastAPI): # Add prometheus asgi middleware to route /metrics requests metrics_route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics @@ -225,8 +224,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) -def build_app(args): - app = fastapi.FastAPI(lifespan=lifespan) +def build_app(args: Namespace) -> FastAPI: + app = FastAPI(lifespan=lifespan) app.include_router(router) app.root_path = args.root_path @@ -274,11 +273,10 @@ async def authentication(request: Request, call_next): return app -async def build_server( +async def init_app( async_engine_client: AsyncEngineClient, - args, - **uvicorn_kwargs, -) -> uvicorn.Server: + args: Namespace, +) -> FastAPI: app = build_app(args) if args.served_model_name is not None: @@ -334,62 +332,31 @@ async def build_server( ) app.root_path = args.root_path - logger.info("Available routes are:") - for route in app.routes: - if not hasattr(route, 'methods'): - continue - methods = ', '.join(route.methods) - logger.info("Route: %s, Methods: %s", route.path, methods) - - config = uvicorn.Config( - app, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - timeout_keep_alive=TIMEOUT_KEEP_ALIVE, - ssl_keyfile=args.ssl_keyfile, - 
ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - **uvicorn_kwargs, - ) - - return uvicorn.Server(config) + return app async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) - shutdown_task = None async with build_async_engine_client(args) as async_engine_client: - - server = await build_server( - async_engine_client, - args, + app = await init_app(async_engine_client, args) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, **uvicorn_kwargs, ) - loop = asyncio.get_running_loop() - - server_task = loop.create_task(server.serve()) - - def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - server_task.cancel() - - loop.add_signal_handler(signal.SIGINT, signal_handler) - loop.add_signal_handler(signal.SIGTERM, signal_handler) - - try: - await server_task - except asyncio.CancelledError: - logger.info("Gracefully stopping http server") - shutdown_task = server.shutdown() - - if shutdown_task: - # NB: Await server shutdown only after the backend context is exited - await shutdown_task + # NB: Await server shutdown only after the backend context is exited + await shutdown_task if __name__ == "__main__": @@ -399,4 +366,5 @@ def signal_handler() -> None: description="vLLM OpenAI-Compatible RESTful API server.") parser = make_arg_parser(parser) args = parser.parse_args() + asyncio.run(run_server(args)) From c0d8f1636c58f5464e512eaabfed5aa29f2c5b7d Mon Sep 17 00:00:00 2001 From: Jungho Christopher Cho Date: Mon, 5 Aug 2024 15:22:12 +0900 Subject: [PATCH 26/36] [Model] SiglipVisionModel ported from transformers (#6942) Co-authored-by: Roger Wang --- examples/offline_inference_vision_language.py | 3 +- vllm/model_executor/models/paligemma.py | 79 +-- vllm/model_executor/models/siglip.py | 621 ++++++++++++++++++ 3 files changed, 650 insertions(+), 53 deletions(-) create mode 100644 vllm/model_executor/models/siglip.py diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 846246a2062a6..ce9dc9e457c09 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -65,7 +65,8 @@ def run_phi3v(question): # PaliGemma def run_paligemma(question): - prompt = question + # PaliGemma has special prompt format for VQA + prompt = "caption en" llm = LLM(model="google/paligemma-3b-mix-224") return llm, prompt diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index fe91611cd30ff..9ba53b8b59a2f 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,9 +1,8 @@ from typing import Iterable, List, Literal, Optional, Tuple, TypedDict import torch -from PIL import Image from torch import nn -from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel +from transformers import PaliGemmaConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -18,9 +17,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_tokenizer -from vllm.sequence 
import IntermediateTensors, SamplerOutput, SequenceData +from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsVision +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_max_siglip_image_tokens) from .utils import merge_vision_embeddings logger = init_logger(__name__) @@ -32,55 +33,22 @@ def get_max_paligemma_image_tokens(ctx: InputContext): hf_config = ctx.get_hf_config(PaliGemmaConfig) - text_config = hf_config.text_config - - return text_config.num_image_tokens - - -def dummy_seq_data_for_paligemma( - hf_config: PaliGemmaConfig, - seq_len: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): - if image_feature_size_override is None: - image_feature_size = hf_config.text_config.num_image_tokens - else: - image_feature_size = image_feature_size_override - - token_ids = [image_token_id] * image_feature_size - token_ids += [0] * (seq_len - image_feature_size) - return SequenceData(token_ids) - - -def dummy_image_for_paligemma( - hf_config: SiglipVisionConfig, - *, - image_width_override: Optional[int] = None, - image_height_override: Optional[int] = None, -): - width = height = hf_config.image_size - if image_width_override is not None: - width = image_width_override - if image_height_override is not None: - height = image_height_override + vision_config = hf_config.vision_config - image = Image.new("RGB", (width, height), color=0) - return {"image": image} + return get_max_siglip_image_tokens(vision_config) def dummy_data_for_paligemma(ctx: InputContext, seq_len: int): hf_config = ctx.get_hf_config(PaliGemmaConfig) vision_config = hf_config.vision_config - seq_data = dummy_seq_data_for_paligemma( - hf_config, + seq_data = dummy_seq_data_for_siglip( + vision_config, seq_len, image_token_id=hf_config.image_token_index, ) - mm_data = dummy_image_for_paligemma(vision_config) + mm_data = dummy_image_for_siglip(vision_config) return seq_data, mm_data @@ -208,30 +176,37 @@ def _parse_and_validate_image_input( data=self._validate_pixel_values(pixel_values), ) - def _image_pixels_to_features(self, vision_tower: SiglipVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: target_dtype = vision_tower.get_input_embeddings().weight.dtype - image_outputs = vision_tower(pixel_values.to(dtype=target_dtype), - output_hidden_states=True) - - selected_image_features = image_outputs.last_hidden_state + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) - return selected_image_features + return image_features def _process_image_pixels( - self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor: + self, + inputs: PaliGemmaImagePixelInputs, + ) -> torch.Tensor: assert self.vision_tower is not None pixel_values = inputs["data"] - return self._image_pixels_to_features(self.vision_tower, pixel_values) + return self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) def _process_image_input( - self, image_input: PaliGemmaImageInputs) -> torch.Tensor: + self, + image_input: PaliGemmaImageInputs, + ) -> torch.Tensor: assert self.vision_tower is not None - image_features = self._process_image_pixels(image_input) + image_features = self._process_image_pixels(image_input, ) return self.multi_modal_projector(image_features) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py new file mode 100644 index 
0000000000000..6faef45c9a6d3 --- /dev/null +++ b/vllm/model_executor/models/siglip.py @@ -0,0 +1,621 @@ +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +import math +from typing import Optional, Tuple + +import torch +from PIL import Image +from torch import nn +from transformers import SiglipConfig, SiglipVisionConfig +from transformers.models.siglip.modeling_siglip import SiglipAttention +from vllm_flash_attn import flash_attn_func +from xformers.ops import memory_efficient_attention + +from vllm.config import ModelConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.multimodal.image import (cached_get_tokenizer, + repeat_and_pad_image_tokens) +from vllm.sequence import SequenceData + + +def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_siglip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int: + return get_siglip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int: + return get_siglip_image_feature_size(hf_config) + + +def dummy_seq_data_for_siglip( + hf_config: SiglipVisionConfig, + seq_len: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_siglip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + token_ids = [image_token_id] * image_feature_size + token_ids += [0] * (seq_len - image_feature_size) + return SequenceData(token_ids) + + +def dummy_image_for_siglip( + hf_config: SiglipVisionConfig, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image} + + +def input_processor_for_siglip( + model_config: ModelConfig, + hf_config: SiglipVisionConfig, + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_feature_size = get_siglip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_image_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + image_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the 
original inputs + return LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + self.position_embedding = VocabParallelEmbedding( + self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, dtype=torch.int64).expand( + (1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """ + This method is an adapted method for SigLIP (due to SigLIP not having + class embedding unlike other ViTs) that allows the model to interpolate + the pre-trained position encodings such that it can be usable on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] + num_positions = position_embeddings.shape[1] + if num_patches == num_positions and height == width: + return position_embeddings + + dim = embeddings.shape[-1] + height = height // self.patch_size + width = width // self.patch_size + # we add a small number to avoid floating point error + # in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + + patch_pos_embed = position_embeddings.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), + dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if (int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1]): + raise ValueError("Width or height does not match with " + "the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding( + self.position_ids) + return embeddings + + +# NOTE: Not used - kept for later when we TP the ViT +# TODO(ChristopherCho): Implement TP version of Attention +class SiglipTPAttention(nn.Module): + + def __init__( + self, + config, + quant_config: 
Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + if self.total_num_heads % tp_size != 0: + raise ValueError( + f"Number of attention heads ({self.total_num_heads}) " + "must be divisible by the tensor model parallel size" + f" ({tp_size}).") + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.embed_dim // self.total_num_heads + if self.head_dim * self.total_num_heads != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads (got " + "`embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.qkv_size = self.num_heads * self.head_dim + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.attn_fn = self._basic_attention_forward + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + batch_size, q_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.split( + [self.qkv_size] * 3, dim=-1) + + attn_output = self.attn_fn( + q=query_states, + k=key_states, + v=value_states, + batch_size=batch_size, + q_len=q_len, + ) + + attn_output, _ = self.out_proj(attn_output) + return attn_output + + def _basic_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + k = k.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + v = v.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + k_v_seq_len = k.shape[-2] + attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale + + if attn_weights.size() != ( + batch_size, + self.num_heads, + q_len, + k_v_seq_len, + ): + raise ValueError( + "Attention weights should be of size " + f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}") + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to(q.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, v) + + if attn_output.size() != ( + batch_size, + self.num_heads, + q_len, + self.head_dim, + ): + raise ValueError( + "`attn_output` should be of size " + f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +# TODO(ChristopherCho): flash_attn_func is not working properly. +# It constantly throws a CUDA error. 
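The eager `_basic_attention_forward` path above is the reference implementation that the fused variants further below (flash-attn, SDPA, xFormers) are meant to replace. A minimal standalone sketch, not part of the patch, of the equivalence the SDPA subclass relies on; the shapes are hypothetical and dropout is assumed to be disabled:

import torch

B, H, S, D = 2, 4, 16, 32  # hypothetical batch, heads, sequence length, head dim
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
scale = D ** -0.5

# Reference path: explicit QK^T, fp32 softmax, then a weighted sum over V,
# mirroring _basic_attention_forward above.
attn = torch.softmax((q @ k.transpose(-2, -1)) * scale,
                     dim=-1, dtype=torch.float32).to(q.dtype)
ref = attn @ v

# Fused path, as used by SiglipSdpaAttention further below.
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=scale)

assert torch.allclose(ref, fused, atol=1e-4)

Keeping the eager path as the default is presumably what makes the TP'ed attention easy to validate against the fused kernels once they are enabled.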
+class SiglipFlashAttention2(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = self._flash_attention_forward + + # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449 + # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133 + def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args, + **kwargs): + """Implements the multihead softmax attention. + Arguments + --------- + q, k, v: The tensor containing the + query, key, and value. (B, S, H, D) + """ + + q = q.view(batch_size, q_len, self.num_heads, self.head_dim) + k = k.view(batch_size, q_len, self.num_heads, self.head_dim) + v = v.view(batch_size, q_len, self.num_heads, self.head_dim) + + attn_output = flash_attn_func( + q, + k, + v, + dropout_p=self.dropout, + causal=False, + ) + + attn_output = attn_output.reshape(batch_size, q_len, + self.embed_dim).contiguous() + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +class SiglipSdpaAttention(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False + self.attn_fn = self._sdpa_attention_forward + + def _sdpa_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + k = k.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + v = v.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +class SiglipxFormersAttention(SiglipTPAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attn_fn = self._xformers_attention_forward + + def _xformers_attention_forward(self, q, k, v, batch_size, q_len): + q = q.view(batch_size, q_len, self.num_heads, self.head_dim) + k = k.view(batch_size, q_len, self.num_heads, self.head_dim) + v = v.view(batch_size, q_len, self.num_heads, self.head_dim) + + attn_output = memory_efficient_attention(q, + k, + v, + p=0.0, + scale=self.scale) + attn_output = attn_output.reshape(batch_size, q_len, + self.embed_dim).contiguous() + + return attn_output + + +# NOTE: Not used - kept for later when we TP the ViT +SIGLIP_ATTENTION_CLASSES = { + "eager": SiglipTPAttention, + "flash_attention_2": SiglipFlashAttention2, + "sdpa": SiglipSdpaAttention, + "xformers": SiglipxFormersAttention, +} + + +class SiglipMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + + # For quantization, we require the hidden size to be a multiple of 64 + quantizable = (config.hidden_size % 64 == 0 + and config.intermediate_size % 64 == 0) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config if quantizable else None, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config if quantizable else None, + ) + + def forward(self, hidden_states: 
torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + + def __init__( + self, + config: SiglipConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.hidden_size + + # TODO(ChristopherCho): use TP'ed Attention block + self.self_attn = SiglipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config, + quant_config=quant_config, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None + + +class SiglipEncoder(nn.Module): + + def __init__( + self, + config: SiglipConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + SiglipEncoderLayer( + config, + quant_config=quant_config, + ) for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> Tuple: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states, _ = encoder_layer(hidden_states) + + return hidden_states + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config=config, quant_config=quant_config) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SiglipVisionTransformer(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( + config, + quant_config=quant_config, + ) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self.use_head = (True if not hasattr(config, "vision_use_head") else + config.vision_use_head) + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead( + config=config, quant_config=quant_config) + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = True, + ) -> torch.Tensor: + hidden_states = self.embeddings( 
+ pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + + last_hidden_state = self.post_layernorm(encoder_outputs) + + # TODO: add this back when pooled_output is used in inference + # if self.use_head: + # pooled_output = self.head(last_hidden_state) + + return last_hidden_state + + +class SiglipVisionModel(nn.Module): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.vision_model = SiglipVisionTransformer( + config, + quant_config, + ) + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) From 82a1b1a82b1fbb454c82a9ef95730b929c9b270c Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Mon, 5 Aug 2024 01:46:44 -0700 Subject: [PATCH 27/36] [Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963) --- tests/spec_decode/test_spec_decode_worker.py | 68 ++++++++++++++------ vllm/config.py | 8 ++- vllm/engine/arg_utils.py | 1 + vllm/spec_decode/spec_decode_worker.py | 68 ++++++++++++++++---- vllm/spec_decode/util.py | 15 +++++ 5 files changed, 125 insertions(+), 35 deletions(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 671c9bef294f9..9ae1b4bc40f0f 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + draft_worker, + target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector) exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) @@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + draft_worker, + target_worker, + mock_spec_decode_sampler(acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() vocab_size = 32_000 @@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, set_random_seed(1) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int, set_random_seed(1) spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - metrics_collector) + worker = SpecDecodeWorker(draft_worker, + target_worker, + 
spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector) worker.init_device() proposal_token_ids = torch.randint(low=0, @@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), False, - metrics_collector) + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int, set_random_seed(1) worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), False, - metrics_collector) + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + disable_logprobs=False, + metrics_collector=metrics_collector, + ) seq_group_metadata_list, _, _ = create_batch(batch_size, k, @@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str): spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler, - False, metrics_collector) + worker = SpecDecodeWorker( + proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=spec_decode_sampler, + disable_logprobs=False, + metrics_collector=metrics_collector, + ) worker.init_device() draft_worker.init_device.assert_called_once() @@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method): target_worker = mock_worker() metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + spec_decode_sampler=mock_spec_decode_sampler( + acceptance_sampler_method), + metrics_collector=metrics_collector) kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} worker.initialize_cache(**kwargs) @@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens(): seq_group_metadata_list=seq_group_metadata_list, accepted_token_ids=accepted_token_ids, target_logprobs=target_token_logprobs, - k=k) + k=k, + stage_times=(0, 0, 0)) # Verify that _seq_with_bonus_token_in_last_step contains the following: # 1. 
Sequence IDs that were already present in # _seq_with_bonus_token_in_last_step but were not part of the current diff --git a/vllm/config.py b/vllm/config.py index 35945e34452d2..bec0b63197ef4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -907,6 +907,7 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + disable_log_stats: bool, speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], @@ -1095,7 +1096,8 @@ def maybe_create_spec_config( typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=\ typical_acceptance_sampler_posterior_alpha, - disable_logprobs=disable_logprobs + disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, ) @staticmethod @@ -1189,6 +1191,7 @@ def __init__( typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, + disable_log_stats: bool, ): """Create a SpeculativeConfig object. @@ -1221,6 +1224,8 @@ def __init__( sampling, target sampling, and after accepted tokens are determined. If set to False, log probabilities will be returned. + disable_log_stats: Whether to disable periodic printing of stage + times in speculative decoding. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config @@ -1235,6 +1240,7 @@ def __init__( self.typical_acceptance_sampler_posterior_alpha = \ typical_acceptance_sampler_posterior_alpha self.disable_logprobs = disable_logprobs + self.disable_log_stats = disable_log_stats self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2737b50927f6b..acc0551af0154 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -792,6 +792,7 @@ def create_engine_config(self, ) -> EngineConfig: speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, + disable_log_stats=self.disable_log_stats, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, draft_token_acceptance_method=\ diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index ad8c0cee0b5b6..690aad505e215 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -27,7 +27,7 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_alpha=speculative_config. 
typical_acceptance_sampler_posterior_alpha, - disable_logprobs=speculative_config.disable_logprobs) + disable_logprobs=speculative_config.disable_logprobs, + disable_log_stats=speculative_config.disable_log_stats, + ) return spec_decode_worker @@ -116,6 +118,7 @@ def create_worker( typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_alpha: float, disable_logprobs: bool, + disable_log_stats: bool, ) -> "SpecDecodeWorker": allow_zero_draft_token_step = True @@ -171,6 +174,7 @@ def create_worker( proposer_worker, scorer_worker, disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, spec_decode_sampler=spec_decode_sampler, allow_zero_draft_token_step=allow_zero_draft_token_step) @@ -180,7 +184,8 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, - disable_logprobs: bool, + disable_logprobs: bool = False, + disable_log_stats: bool = False, metrics_collector: Optional[AsyncMetricsCollector] = None, disable_by_batch_size: Optional[int] = None, allow_zero_draft_token_step: Optional[bool] = True, @@ -203,6 +208,8 @@ def __init__( disable_logprobs: If set to True, token log probabilities will not be output in both the draft worker and the target worker. If set to False, log probabilities will be output by both. + disable_log_stats: If set to True, disable periodic printing of + speculative stage times. disable_by_batch_size: If the batch size is larger than this, disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set @@ -240,6 +247,7 @@ def __init__( # in the subsequent step. self.previous_hidden_states: Optional[HiddenStates] = None self._disable_logprobs = disable_logprobs + self._disable_log_stats = disable_log_stats def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -525,28 +533,37 @@ def _run_speculative_decoding_step( execute_model_req.previous_hidden_states = self.previous_hidden_states self.previous_hidden_states = None - # Generate proposals using draft worker. - proposals = self.proposer_worker.get_spec_proposals( - execute_model_req, self._seq_with_bonus_token_in_last_step) + with Timer() as proposal_timer: + # Generate proposals using draft worker. 
+ proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) if not self._allow_zero_draft_token_step and proposals.no_proposals: #TODO: Fix it #5814 raise RuntimeError("Cannot handle cases where distributed draft " "workers generate no tokens") - proposal_scores = self.scorer.score_proposals( - execute_model_req, - proposals, - ) - accepted_token_ids, target_logprobs = self._verify_tokens( - execute_model_req.seq_group_metadata_list, proposal_scores, - proposals, execute_model_req.num_lookahead_slots) + with Timer() as scoring_timer: + proposal_scores = self.scorer.score_proposals( + execute_model_req, + proposals, + ) + + with Timer() as verification_timer: + accepted_token_ids, target_logprobs = self._verify_tokens( + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) + + stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, + scoring_timer.elapsed_time_ms, + verification_timer.elapsed_time_ms) return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, target_logprobs=target_logprobs, - k=execute_model_req.num_lookahead_slots) + k=execute_model_req.num_lookahead_slots, + stage_times=stage_times) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -645,6 +662,7 @@ def _create_output_sampler_list( accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] k: int, + stage_times: Tuple[float, float, float], ) -> List[SamplerOutput]: """Given the accepted token ids, create a list of SamplerOutput. @@ -722,8 +740,30 @@ def _create_output_sampler_list( if maybe_rejsample_metrics is not None: sampler_output_list[ 0].spec_decode_worker_metrics = maybe_rejsample_metrics + + # Log time spent in each stage periodically. + # This is periodic because the rejection sampler emits metrics + # periodically. + self._maybe_log_stage_times(*stage_times) + return sampler_output_list + def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, + scoring_time_ms: float, + verification_time_ms: float) -> None: + """Log the speculative stage times. If stat logging is disabled, do + nothing. + """ + if self._disable_log_stats: + return + + logger.info( + "SpecDecodeWorker stage times: " + "average_time_per_proposal_tok_ms=%.02f " + "scoring_time_ms=%.02f verification_time_ms=%.02f", + average_time_per_proposal_tok_ms, scoring_time_ms, + verification_time_ms) + def _create_dummy_logprob_lists( self, batch_size: int, diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index ade546eef264e..c6223a97dba10 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,3 +1,4 @@ +import time from contextlib import contextmanager from typing import Dict, List, Optional, Tuple @@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs): yield finally: torch.cuda.nvtx.range_pop() + + +class Timer: + """Basic timer context manager for measuring CPU time. 
+ """ + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.end_time = time.time() + self.elapsed_time_s = self.end_time - self.start_time + self.elapsed_time_ms = self.elapsed_time_s * 1000 From e9630458c7b11732e147c120817c53420280d471 Mon Sep 17 00:00:00 2001 From: Bongwon Jang <152451401+bong-furiosa@users.noreply.github.com> Date: Tue, 6 Aug 2024 00:05:05 +0900 Subject: [PATCH 28/36] [SpecDecode] Support FlashInfer in DraftModelRunner (#6926) --- vllm/spec_decode/draft_model_runner.py | 47 ++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 0b755600ae824..b76a1ab4cf243 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -11,6 +11,17 @@ from vllm.attention.backends.rocm_flash_attn import ( ROCmFlashAttentionMetadata as FlashAttentionMetadata) +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -79,6 +90,11 @@ def __init__( return_hidden_states=return_hidden_states, ) + self.flashinfer_decode_workspace_buffer = None + self.flashinfer_decode_wrapper = None + self.flashinfer_prefill_workspace_buffer = None + self.flashinfer_prefill_wrapper = None + def _update_flash_attn_metadata(self, attn_metadata, num_seqs, num_queries): assert isinstance(attn_metadata, FlashAttentionMetadata) @@ -286,6 +302,37 @@ def execute_model( model_input.prompt_adapter_requests, model_input.prompt_adapter_mapping) + if self.attn_backend.get_name() == "flashinfer": + assert model_input.attn_metadata is not None + assert model_input.input_tokens is not None + if self.flashinfer_decode_workspace_buffer is None: + self.flashinfer_decode_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_decode_wrapper = \ + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_decode_workspace_buffer, "NHD") + self.flashinfer_prefill_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_prefill_wrapper = \ + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_prefill_workspace_buffer, "NHD") + + model_input.attn_metadata.prefill_wrapper = \ + self.flashinfer_prefill_wrapper + if model_input.attn_metadata.use_cuda_graph: + batch_size = model_input.input_tokens.shape[0] + model_input.attn_metadata.decode_wrapper = \ + self.graph_runners[model_input. 
+ virtual_engine][batch_size].flashinfer_decode_wrapper + else: + model_input.attn_metadata.decode_wrapper = \ + self.flashinfer_decode_wrapper + model_input.attn_metadata.begin_forward() + # Detect exec mode assert model_input.attn_metadata is not None use_cuda_graph = False From 003f8ee1287f90a7e8aa9b9e7d6246ac74ebefbe Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 5 Aug 2024 08:41:03 -0700 Subject: [PATCH 29/36] [BugFix] Use IP4 localhost form for zmq bind (#7163) --- vllm/entrypoints/openai/rpc/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 7a72a6f732c99..60bb23b9bde05 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -30,7 +30,9 @@ def __init__(self, async_engine_args: AsyncEngineArgs, # Init socket for readiness state. self.socket = self.context.socket(zmq.constants.ROUTER) - self.socket.bind(f"tcp://localhost:{port}") + # Note numeric form of localhost should be used for zmq bind(), + # see https://stackoverflow.com/a/8958414 + self.socket.bind(f"tcp://127.0.0.1:{port}") def cleanup(self): """Cleanup all resources.""" From 57f560aa23077ed9def5952ab81a65bc080ae234 Mon Sep 17 00:00:00 2001 From: Aditya Paliwal Date: Mon, 5 Aug 2024 09:26:14 -0700 Subject: [PATCH 30/36] [BugFix] Use args.trust_remote_code (#7121) --- vllm/entrypoints/openai/api_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a0190f3d66b10..88f0bd4ee4dbe 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -60,11 +60,11 @@ _running_tasks: Set[asyncio.Task] = set() -def model_is_embedding(model_name: str) -> bool: +def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool: return ModelConfig(model=model_name, tokenizer=model_name, tokenizer_mode="auto", - trust_remote_code=False, + trust_remote_code=trust_remote_code, seed=0, dtype="float16").embedding_mode @@ -97,7 +97,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: # If manually triggered or embedding model, use AsyncLLMEngine in process. # TODO: support embedding model via RPC. 
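The zmq change in the RPC server patch above binds to the numeric loopback address because, per the linked Stack Overflow answer, "localhost" can resolve differently for the two endpoints on some systems (e.g. ::1 on one side, 127.0.0.1 on the other). A minimal standalone sketch of the pattern, assuming the pyzmq package; the port number is made up for illustration:

import zmq

port = 5570  # hypothetical port
ctx = zmq.Context()

# Server side: bind to the numeric IPv4 form, as in the patch.
server = ctx.socket(zmq.ROUTER)
server.bind(f"tcp://127.0.0.1:{port}")

# Client side: using the same numeric form avoids an IPv4/IPv6 mismatch
# between the two endpoints.
client = ctx.socket(zmq.DEALER)
client.connect(f"tcp://127.0.0.1:{port}")

client.send(b"ping")
identity, payload = server.recv_multipart()  # ROUTER prepends the peer identity
assert payload == b"ping"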
- if (model_is_embedding(args.model) + if (model_is_embedding(args.model, args.trust_remote_code) or args.disable_frontend_multiprocessing): async_engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) From 997cf78308d292b03c8a1e68d8d1a1f599551937 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:10:16 -0700 Subject: [PATCH 31/36] [Misc] Fix typo in GroupCoordinator.recv() (#7167) Signed-off-by: Rui Qiao --- vllm/distributed/parallel_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d7ca8fd82e1a2..a20b92de81cda 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -713,8 +713,8 @@ def recv(self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None) -> torch.Tensor: - """Receives a tensor from the src rank.""" - """NOTE: `src` is the local rank of the destination rank.""" + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" if src is None: src = (self.rank_in_group - 1) % self.world_size From 8571ac4672c8b599338cb95e23dfd624016aab36 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 5 Aug 2024 15:13:43 -0400 Subject: [PATCH 32/36] [Kernel] Update CUTLASS to 3.5.1 (#7085) --- CMakeLists.txt | 6 +- .../broadcast_load_epilogue_c3x.hpp | 192 ++++++++++-------- .../cutlass_w8a8/scaled_mm_c2x.cuh | 8 +- .../cutlass_w8a8/scaled_mm_c3x.cu | 30 +-- 4 files changed, 129 insertions(+), 107 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 922613ec5ddaa..e5ac5516c2e46 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -193,8 +193,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # CUTLASS 3.5.0 - GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + # CUTLASS 3.5.1 + GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 # Shallow clone with depth 1 GIT_SHALLOW TRUE GIT_PROGRESS TRUE @@ -237,7 +237,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp index e4bc9752ed7db..58b1e8ff159fb 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp @@ -64,8 +64,6 @@ using namespace detail; // Row vector broadcast template< - // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least - // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races int Stages, class CtaTileShapeMNK, class Element, @@ -73,14 +71,12 @@ template< int Alignment = 128 / sizeof_bits_v > struct Sm90RowOrScalarBroadcast { - static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); - static_assert( - (cute::is_same_v>) || // row vector broadcast, e.g. 
per-col alpha/bias - (cute::is_same_v>)); // batched row vector broadcast + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); - // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem - struct SharedStorage { - alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; }; // This struct has been modified to have a bool indicating that ptr_row is a @@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { @@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast { CUTLASS_HOST_DEVICE Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) - : params(params), - smem_row(const_cast(shared_storage.smem_row.data())) { } + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } Params params; - Element* smem_row; + Element *smem = nullptr; CUTLASS_DEVICE bool is_producer_load_needed() const { - return true; + return false; } CUTLASS_DEVICE bool @@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast { return (!params.row_broadcast && *(params.ptr_row) == Element(0)); } - template - struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { - CUTLASS_DEVICE - ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) - : gRow(cute::forward(gRow)), - sRow(cute::forward(sRow)), - params(params) {} - - GTensor gRow; // (CTA_M,CTA_N) - STensor sRow; // (CTA_M,CTA_N,PIPE) - Params const& params; - - CUTLASS_DEVICE void - begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { - if (!params.row_broadcast) { - return; - } - - if (issue_tma_load) { - // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size - constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v / 8; - cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); - // Issue the TMA bulk copy - auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); - // Filter so we don't issue redundant copies over stride-0 modes - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); - } - } - }; - template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { - - auto [M, N, K, L] = args.problem_shape_mnkl; - auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ProducerLoadCallbacks( - cute::move(gRow), cute::move(sRow), params); + return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { 
CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) - : tCrRow(cute::forward(tCrRow)), - tCsRow(cute::forward(tCsRow)), - params(params) {} - - RTensor tCrRow; // (CPY,CPY_M,CPY_N) - STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; Params const& params; CUTLASS_DEVICE void - previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + begin() { if (!params.row_broadcast) { - fill(tCrRow, *(params.ptr_row)); + fill(tSR_rRow, *(params.ptr_row)); return; } + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. 
+ } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { if (epi_m == 0) { // Assumes M-major subtile loop - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; - copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); } } @@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - frg_row[i] = tCrRow(epi_v * FragmentSize + i); + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); } return frg_row; @@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); - Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) - make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), - make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) - - constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; - return ConsumerStoreCallbacks( - cute::move(tCrRow), cute::move(tCsRow), params); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); } }; @@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast { return args; } + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + template static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index ba620e85117b5..be8a5c0e54e8e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ 
b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -10,8 +10,6 @@ #include "cute/atom/mma_atom.hpp" #include "cutlass/numeric_types.h" -#include "cutlass/util/device_memory.h" - #include "cutlass/cutlass.h" #include "cutlass/gemm_coord.h" #include "cutlass/arch/mma_sm75.h" @@ -301,12 +299,14 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, // Launch the CUTLASS GEMM kernel. typename Gemm::Op gemm_op; size_t workspace_size = gemm_op.get_workspace_size(args); - cutlass::device_memory::allocation workspace(workspace_size); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); CUTLASS_CHECK(gemm_op.can_implement(args)); - cutlass::Status status = gemm_op(args, workspace.get(), stream); + cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index b3f5b62086609..088185188770d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -18,8 +18,6 @@ #include "cute/atom/mma_atom.hpp" #include "cutlass/numeric_types.h" -#include "cutlass/util/device_memory.h" - #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" @@ -72,13 +70,9 @@ struct ScaledEpilogueBase { 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, Stride, Int<0>, Int<0>>>; - using ScaleBDescriptor = - cutlass::epilogue::collective::detail::RowBroadcastDescriptor< - EpilogueDescriptor, float>; - using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, - typename ScaleBDescriptor::Element, Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + Stride, Int<1>, Int<0>>>; }; /* @@ -154,12 +148,8 @@ struct ScaledEpilogueBias cutlass::multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; - using BiasDescriptor = - cutlass::epilogue::collective::detail::RowBroadcastDescriptor< - EpilogueDescriptor, ElementD>; - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< - BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, false>; public: @@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, int64_t ldb = b.stride(1); int64_t ldc = out.stride(0); - using StrideA = Stride, Int<0>>; - using StrideB = Stride, Int<0>>; + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; using StrideC = typename Gemm::StrideC; - StrideA a_stride{lda, Int<1>{}, Int<0>{}}; - StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; using GemmKernel = typename Gemm::GemmKernel; @@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); - cutlass::device_memory::allocation workspace(workspace_size); + auto const workspace_options = + 
torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - cutlass::Status status = gemm_op.run(args, workspace.get(), stream); + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } From 6e4852ce28ad57dc440067778464ac61e0621899 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 5 Aug 2024 16:00:01 -0400 Subject: [PATCH 33/36] [CI/Build] Suppress divide-by-zero and missing return statement warnings (#7001) --- csrc/attention/dtype_bfloat16.cuh | 8 ++++++++ csrc/quantization/awq/dequantize.cuh | 1 + csrc/quantization/fp8/nvidia/quant_utils.cuh | 5 +++-- csrc/quantization/gptq_marlin/gptq_marlin.cu | 18 ++++++++++++------ 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 3cdcb95e08099..97a25baa1fc0d 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { #else return __bfloat1622float2(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { @@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #else return __bfloat162bfloat162(val); #endif + __builtin_unreachable(); // Suppress missing return statement warning } // Vector addition. @@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { return __hadd(a, b); #endif #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { @@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hadd2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { @@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #else return __hmul(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } template <> @@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #else return __hmul2(a, b); #endif + __builtin_unreachable(); // Suppress missing return statement warning } template <> @@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, #else return __hfma2(a, b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, @@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, #else return __hfma2(bf162bf162(a), b, c); #endif + __builtin_unreachable(); // Suppress missing return statement warning } inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index 813ec6716cf54..5fa4b5f640277 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { return result; #endif + __builtin_unreachable(); // Suppress missing return statement warning } } // namespace awq diff --git 
a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index e32684eaed24d..f8cd1dcba4ab3 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion( __NV_SATFINITE, fp8_type); return (uint8_t)res; #endif + __builtin_unreachable(); // Suppress missing return statement warning } // float -> fp8 @@ -508,7 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) { } #endif assert(false); - return {}; // Squash missing return statement warning + __builtin_unreachable(); // Suppress missing return statement warning } template @@ -521,7 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { } #endif assert(false); - return {}; // Squash missing return statement warning + __builtin_unreachable(); // Suppress missing return statement warning } // The following macro is used to dispatch the conversion function based on diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index edf19365c8098..e2b0f2b058164 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1130,12 +1130,12 @@ __global__ void Marlin( }; auto fetch_zp_to_registers = [&](int k, int full_pipe) { - if constexpr (has_zp) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(group_blocks != 0); + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + if constexpr (has_zp) { int pipe = full_pipe % stages; if constexpr (group_blocks == -1) { @@ -1161,7 +1161,13 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; From 4cf1dc39be80d81ddda9e7e55f4742a6bd57920c Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 5 Aug 2024 17:22:57 -0400 Subject: [PATCH 34/36] [Bugfix][CI/Build] Fix CUTLASS FetchContent (#7171) --- CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e5ac5516c2e46..8de0c034a7cb6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,8 +195,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") GIT_REPOSITORY https://github.com/nvidia/cutlass.git # CUTLASS 3.5.1 GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 - # Shallow clone with depth 1 - GIT_SHALLOW TRUE GIT_PROGRESS TRUE ) FetchContent_MakeAvailable(cutlass) From 4db5176d9758b720b05460c50ace3c01026eb158 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 5 Aug 2024 14:39:48 -0700 Subject: [PATCH 35/36] bump version to v0.5.4 (#7139) --- docs/source/getting_started/installation.rst | 2 +- vllm/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 57ad8bacedfcc..5294003aa9261 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -48,7 +48,7 @@ You can install vLLM 
using pip: .. code-block:: console - $ export VLLM_VERSION=0.5.2 # vLLM's main branch version is currently set to latest released tag + $ export VLLM_VERSION=0.5.4 # vLLM's main branch version is currently set to latest released tag $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl $ # You can also access a specific commit $ # export VLLM_COMMIT=... diff --git a/vllm/version.py b/vllm/version.py index 6930654710632..247036f1d6211 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -9,4 +9,4 @@ stacklevel=2) __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.5.3.post1" +__version__ = "0.5.4" From dfb1a15dcb4c24bf7ff0ba7ddfc5d623ad519d7d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 5 Aug 2024 15:59:22 -0700 Subject: [PATCH 36/36] [ci][frontend] deduplicate tests (#7101) --- tests/entrypoints/openai/test_completion.py | 14 +- tests/entrypoints/openai/test_disable_mp.py | 715 -------------------- 2 files changed, 6 insertions(+), 723 deletions(-) delete mode 100644 tests/entrypoints/openai/test_disable_mp.py diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 50add84087a95..05f667231738f 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -87,15 +87,13 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, ] -@pytest.fixture(scope="module") -def server(default_server_args): +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def client(default_server_args, request): + if request.param: + default_server_args.append(request.param) with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() + yield remote_server.get_async_client() @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_disable_mp.py b/tests/entrypoints/openai/test_disable_mp.py deleted file mode 100644 index 12c805413311c..0000000000000 --- a/tests/entrypoints/openai/test_disable_mp.py +++ /dev/null @@ -1,715 +0,0 @@ -""" -Repeat of tests in test_completion.py with the non-mp backend. 
-""" - -# imports for guided decoding tests -import json -import re -import shutil -from tempfile import TemporaryDirectory -from typing import List - -import jsonschema -import openai # use the official client for correctness check -import pytest -# downloading lora to test lora requests -from huggingface_hub import snapshot_download -from openai import BadRequestError -from transformers import AutoTokenizer - -from vllm.transformers_utils.tokenizer import get_tokenizer - -from ...utils import RemoteOpenAIServer - -# any model with a chat template should work here -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -# technically these adapters use a different base model, -# but we're not testing generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" -PA_NAME = "swapnilbp/llama_tweet_ptune" -# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also -# need to change to match the prompt adapter -PA_NUM_VIRTUAL_TOKENS = 8 - - -@pytest.fixture(scope="module") -def zephyr_lora_files(): - return snapshot_download(repo_id=LORA_NAME) - - -@pytest.fixture(scope="module") -def zephyr_lora_added_tokens_files(zephyr_lora_files): - tmp_dir = TemporaryDirectory() - tmp_model_dir = f"{tmp_dir.name}/zephyr" - shutil.copytree(zephyr_lora_files, tmp_model_dir) - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - # Copy tokenizer to adapter and add some unique tokens - # 32000, 32001, 32002 - added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], - special_tokens=True) - assert added == 3 - tokenizer.save_pretrained(tmp_model_dir) - yield tmp_model_dir - tmp_dir.cleanup() - - -@pytest.fixture(scope="module") -def zephyr_pa_files(): - return snapshot_download(repo_id=PA_NAME) - - -@pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, - zephyr_pa_files): - return [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "bfloat16", - "--max-model-len", - "8192", - "--max-num-seqs", - "128", - "--enforce-eager", - # lora config - "--enable-lora", - "--lora-modules", - f"zephyr-lora={zephyr_lora_files}", - f"zephyr-lora2={zephyr_lora_added_tokens_files}", - "--max-lora-rank", - "64", - "--max-cpu-loras", - "2", - # pa config - "--enable-prompt-adapter", - "--prompt-adapters", - f"zephyr-pa={zephyr_pa_files}", - f"zephyr-pa2={zephyr_pa_files}", - "--max-prompt-adapters", - "2", - "--max-prompt-adapter-token", - "128", - "--disable-frontend-multiprocessing" - ] - - -@pytest.fixture(scope="module") -def server(default_server_args): - with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras, then test prompt adapters - "model_name,num_virtual_tokens", - [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), - ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), - ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], -) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, - num_virtual_tokens: int): - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - - choice = completion.choices[0] - assert len(choice.text) >= 5 - assert choice.finish_reason == "length" - assert completion.usage == openai.types.CompletionUsage( 
- completion_tokens=5, - prompt_tokens=6 + num_virtual_tokens, - total_tokens=11 + num_virtual_tokens) - - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 1 - - -@pytest.mark.asyncio -async def test_added_lora_tokens(client: openai.AsyncOpenAI): - # test using token IDs - completion = await client.completions.create( - model="zephyr-lora2", - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - # Added tokens should appear in tokenized prompt - assert completion.choices[0].text.startswith("vllm1vllm2vllm3") - - -@pytest.mark.asyncio -async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): - # test using token IDs - completion = await client.completions.create( - model=MODEL_NAME, - prompt=[0, 0, 32000, 32001, 32002], - echo=True, - max_tokens=5, - temperature=0.0, - ) - # Added tokens should not appear in tokenized prompt - assert "vllm" not in completion.choices[0].text - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras, then test prompt adapters - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], -) -async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=None, - ) - choice = completion.choices[0] - assert choice.logprobs is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # just test 1 lora and 1 pa hereafter - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=0, - ) - choice = completion.choices[0] - assert choice.logprobs is not None - assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is not None - assert len(choice.logprobs.top_logprobs[0]) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): - # test using token IDs - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - logprobs=5, - ) - choice = completion.choices[0] - assert choice.logprobs is not None - assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is not None - assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, - model_name: str): - - with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs - await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - # vLLM has higher default max_logprobs (20 instead of 5) to support - # both Completion API and Chat Completion API - logprobs=21, - ) - ... 
- with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs - stream = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - # vLLM has higher default max_logprobs (20 instead of 5) to support - # both Completion API and Chat Completion API - logprobs=30, - stream=True, - ) - async for chunk in stream: - ... - - # the server should still work afterwards - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - assert len(completion.choices[0].text) >= 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_completion_streaming(client: openai.AsyncOpenAI, - model_name: str): - prompt = "What is an LLM?" - - single_completion = await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - ) - single_output = single_completion.choices[0].text - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) - chunks: List[str] = [] - finish_reason_count = 0 - async for chunk in stream: - chunks.append(chunk.choices[0].text) - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - # finish reason should only return in last block - assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == "length" - assert chunk.choices[0].text - assert "".join(chunks) == single_output - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], -) -async def test_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): - prompt = "What is the capital of France?" 
-
-    # Test stream=True, stream_options=
-    # {"include_usage": False, "continuous_usage_stats": False}
-    stream = await client.completions.create(model=model_name,
-                                             prompt=prompt,
-                                             max_tokens=5,
-                                             temperature=0.0,
-                                             stream=True,
-                                             stream_options={
-                                                 "include_usage": False,
-                                                 "continuous_usage_stats":
-                                                 False,
-                                             })
-
-    async for chunk in stream:
-        assert chunk.usage is None
-
-    # Test stream=True, stream_options=
-    # {"include_usage": False, "continuous_usage_stats": True}
-    stream = await client.completions.create(model=model_name,
-                                             prompt=prompt,
-                                             max_tokens=5,
-                                             temperature=0.0,
-                                             stream=True,
-                                             stream_options={
-                                                 "include_usage": False,
-                                                 "continuous_usage_stats":
-                                                 True,
-                                             })
-    async for chunk in stream:
-        assert chunk.usage is None
-
-    # Test stream=True, stream_options=
-    # {"include_usage": True, "continuous_usage_stats": False}
-    stream = await client.completions.create(model=model_name,
-                                             prompt=prompt,
-                                             max_tokens=5,
-                                             temperature=0.0,
-                                             stream=True,
-                                             stream_options={
-                                                 "include_usage": True,
-                                                 "continuous_usage_stats":
-                                                 False,
-                                             })
-    async for chunk in stream:
-        if chunk.choices[0].finish_reason is None:
-            assert chunk.usage is None
-        else:
-            assert chunk.usage is None
-            final_chunk = await stream.__anext__()
-            assert final_chunk.usage is not None
-            assert final_chunk.usage.prompt_tokens > 0
-            assert final_chunk.usage.completion_tokens > 0
-            assert final_chunk.usage.total_tokens == (
-                final_chunk.usage.prompt_tokens +
-                final_chunk.usage.completion_tokens)
-            assert final_chunk.choices == []
-
-    # Test stream=True, stream_options=
-    # {"include_usage": True, "continuous_usage_stats": True}
-    stream = await client.completions.create(model=model_name,
-                                             prompt=prompt,
-                                             max_tokens=5,
-                                             temperature=0.0,
-                                             stream=True,
-                                             stream_options={
-                                                 "include_usage": True,
-                                                 "continuous_usage_stats":
-                                                 True,
-                                             })
-    async for chunk in stream:
-        assert chunk.usage is not None
-        assert chunk.usage.prompt_tokens > 0
-        assert chunk.usage.completion_tokens > 0
-        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
-                                            chunk.usage.completion_tokens)
-        if chunk.choices[0].finish_reason is not None:
-            final_chunk = await stream.__anext__()
-            assert final_chunk.usage is not None
-            assert final_chunk.usage.prompt_tokens > 0
-            assert final_chunk.usage.completion_tokens > 0
-            assert final_chunk.usage.total_tokens == (
-                final_chunk.usage.prompt_tokens +
-                final_chunk.usage.completion_tokens)
-            assert final_chunk.choices == []
-
-    # Test stream=False, stream_options=
-    # {"include_usage": None}
-    with pytest.raises(BadRequestError):
-        await client.completions.create(model=model_name,
-                                        prompt=prompt,
-                                        max_tokens=5,
-                                        temperature=0.0,
-                                        stream=False,
-                                        stream_options={"include_usage": None})
-
-    # Test stream=False, stream_options=
-    # {"include_usage": True}
-    with pytest.raises(BadRequestError):
-        await client.completions.create(model=model_name,
-                                        prompt=prompt,
-                                        max_tokens=5,
-                                        temperature=0.0,
-                                        stream=False,
-                                        stream_options={"include_usage": True})
-
-    # Test stream=False, stream_options=
-    # {"continuous_usage_stats": None}
-    with pytest.raises(BadRequestError):
-        await client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            max_tokens=5,
-            temperature=0.0,
-            stream=False,
-            stream_options={"continuous_usage_stats": None})
-
-    # Test stream=False, stream_options=
-    # {"continuous_usage_stats": True}
-    with pytest.raises(BadRequestError):
-        await client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            max_tokens=5,
-            temperature=0.0,
-            stream=False,
-            stream_options={"continuous_usage_stats": True})
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
-)
-async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
-    # test both text and token IDs
-    for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
-        # test simple list
-        batch = await client.completions.create(
-            model=model_name,
-            prompt=prompts,
-            max_tokens=5,
-            temperature=0.0,
-        )
-        assert len(batch.choices) == 2
-        assert batch.choices[0].text == batch.choices[1].text
-
-        # test n = 2
-        batch = await client.completions.create(
-            model=model_name,
-            prompt=prompts,
-            n=2,
-            max_tokens=5,
-            temperature=0.0,
-            extra_body=dict(
-                # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-                # for official client.
-                use_beam_search=True),
-        )
-        assert len(batch.choices) == 4
-        assert batch.choices[0].text != batch.choices[
-            1].text, "beam search should be different"
-        assert batch.choices[0].text == batch.choices[
-            2].text, "two copies of the same prompt should be the same"
-        assert batch.choices[1].text == batch.choices[
-            3].text, "two copies of the same prompt should be the same"
-
-        # test streaming
-        batch = await client.completions.create(
-            model=model_name,
-            prompt=prompts,
-            max_tokens=5,
-            temperature=0.0,
-            stream=True,
-        )
-        texts = [""] * 2
-        async for chunk in batch:
-            assert len(chunk.choices) == 1
-            choice = chunk.choices[0]
-            texts[choice.index] += choice.text
-        assert texts[0] == texts[1]
-
-
-@pytest.mark.asyncio
-async def test_logits_bias(client: openai.AsyncOpenAI):
-    prompt = "Hello, my name is"
-    max_tokens = 5
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
-
-    # Test exclusive selection
-    token_id = 1000
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=prompt,
-        max_tokens=max_tokens,
-        temperature=0.0,
-        logit_bias={str(token_id): 100},
-        seed=42,
-    )
-    assert len(completion.choices[0].text) >= 5
-    response_tokens = tokenizer(completion.choices[0].text,
-                                add_special_tokens=False)["input_ids"]
-    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
-                                add_special_tokens=False)["input_ids"]
-    assert all([
-        response == expected
-        for response, expected in zip(response_tokens, expected_tokens)
-    ])
-
-    # Test ban
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=prompt,
-        max_tokens=max_tokens,
-        temperature=0.0,
-    )
-    response_tokens = tokenizer(completion.choices[0].text,
-                                add_special_tokens=False)["input_ids"]
-    first_response = completion.choices[0].text
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=prompt,
-        max_tokens=max_tokens,
-        temperature=0.0,
-        logit_bias={str(token): -100
-                    for token in response_tokens},
-    )
-    assert first_response != completion.choices[0].text
-
-
-@pytest.mark.asyncio
-async def test_allowed_token_ids(client: openai.AsyncOpenAI):
-    prompt = "Hello, my name is"
-    max_tokens = 1
-    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
-
-    # Test exclusive selection
-    allowed_ids = [21555, 21557, 21558]
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=prompt,
-        max_tokens=max_tokens,
-        temperature=0.0,
-        seed=42,
-        extra_body=dict(allowed_token_ids=allowed_ids),
-        logprobs=1,
-    )
-    response_tokens = completion.choices[0].logprobs.tokens
-    assert len(response_tokens) == 1
-    assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_json_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}", - n=3, - temperature=1.0, - max_tokens=500, - extra_body=dict(guided_json=sample_json_schema, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 3 - for i in range(3): - output_json = json.loads(completion.choices[i].text) - jsonschema.validate(instance=output_json, schema=sample_json_schema) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_regex_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_regex): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example IPv4 address with this regex: {sample_regex}", - n=3, - temperature=1.0, - max_tokens=20, - extra_body=dict(guided_regex=sample_regex, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 3 - for i in range(3): - assert re.fullmatch(sample_regex, - completion.choices[i].text) is not None - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_choice_completion(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_guided_choice): - completion = await client.completions.create( - model=MODEL_NAME, - prompt="The best language for type-safe systems programming is ", - n=2, - temperature=1.0, - max_tokens=10, - extra_body=dict(guided_choice=sample_guided_choice, - guided_decoding_backend=guided_decoding_backend)) - - assert completion.id is not None - assert len(completion.choices) == 2 - for i in range(2): - assert completion.choices[i].text in sample_guided_choice - - -@pytest.mark.asyncio -async def test_guided_grammar(client: openai.AsyncOpenAI, - sample_sql_statements): - - completion = await client.completions.create( - model=MODEL_NAME, - prompt=("Generate a sql state that select col_1 from " - "table_1 where it is equals to 1"), - temperature=1.0, - max_tokens=500, - extra_body=dict(guided_grammar=sample_sql_statements)) - - content = completion.choices[0].text - - # use Lark to parse the output, and make sure it's a valid parse tree - from lark import Lark - parser = Lark(sample_sql_statements) - parser.parse(content) - - # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") - - assert content.strip() == ground_truth - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - # first test base model, then test loras - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], -) -@pytest.mark.parametrize("logprobs_arg", [1, 0]) -async def test_echo_logprob_completion(client: openai.AsyncOpenAI, - model_name: str, logprobs_arg: int): - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - # test using text and token IDs - for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): - completion = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - echo=True, - logprobs=logprobs_arg) - - prompt_text = 
tokenizer.decode(prompt) if isinstance(prompt, - list) else prompt - assert re.search(r"^" + prompt_text, completion.choices[0].text) - logprobs = completion.choices[0].logprobs - assert logprobs is not None - assert len(logprobs.text_offset) > 5 - assert (len(logprobs.token_logprobs) > 5 - and logprobs.token_logprobs[0] is None) - assert (len(logprobs.top_logprobs) > 5 - and logprobs.top_logprobs[0] is None) - for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 - assert len(logprobs.tokens) > 5 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("guided_decoding_backend", - ["outlines", "lm-format-enforcer"]) -async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, - guided_decoding_backend: str, - sample_json_schema, sample_regex): - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42, - guided_decoding_backend=guided_decoding_backend)) - - with pytest.raises(openai.BadRequestError): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example string that fits this regex", - extra_body=dict(guided_regex=sample_regex, - guided_json=sample_json_schema))
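
For reference, a minimal usage sketch of the guided-decoding parameters exercised by the removed tests above (guided_json, guided_regex, guided_choice, guided_grammar, guided_decoding_backend, all passed through the client's extra_body). It assumes a locally running vLLM OpenAI-compatible server; the base_url, api_key, model choice, and example schema are illustrative assumptions, not values taken from this patch.

import asyncio

import openai


async def main():
    # Assumed server location and dummy key; adjust for your deployment.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    # Hypothetical schema, used only for illustration.
    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
    completion = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt=f"Give an example JSON that fits this schema: {schema}",
        max_tokens=100,
        temperature=1.0,
        # vLLM-specific options ride along in extra_body.
        extra_body=dict(guided_json=schema,
                        guided_decoding_backend="outlines"))
    print(completion.choices[0].text)


asyncio.run(main())
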