diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 329cc42558da6..860272e71fd84 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -107,11 +107,12 @@ fi PARALLEL_JOB_COUNT=8 # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. if [[ $commands == *"--shard-id="* ]]; then + # assign job count as the number of shards used + commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do - #replace shard arguments - commands=${commands//"--shard-id= "/"--shard-id=${GPU} "} - commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} - echo "Shard ${GPU} commands:$commands" + # assign shard-id for each shard + commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "} + echo "Shard ${GPU} commands:$commands_gpu" docker run \ --device /dev/kfd --device /dev/dri \ --network host \ @@ -123,7 +124,7 @@ if [[ $commands == *"--shard-id="* ]]; then -e HF_HOME=${HF_MOUNT} \ --name ${container_name}_${GPU} \ ${image_name} \ - /bin/bash -c "${commands}" \ + /bin/bash -c "${commands_gpu}" \ |& while read -r line; do echo ">>Shard $GPU: $line"; done & PIDS+=($!) done diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 6989c94d46a89..988d5aef5fb8c 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -12,4 +12,4 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a21acd9671eeb..4f54eea564ecb 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -14,6 +14,15 @@ updates: reviewers: ["khluu", "simon-mo"] allow: - dependency-type: "all" + ignore: + - dependency-name: "torch" + - dependency-name: "torchvision" + - dependency-name: "xformers" + - dependency-name: "lm-format-enforcer" + - dependency-name: "gguf" + - dependency-name: "compressed-tensors" + - dependency-name: "ray[adag]" + - dependency-name: "lm-eval" groups: patch-update: applies-to: version-updates diff --git a/CMakeLists.txt b/CMakeLists.txt index 1a6a311e97633..943424bc4edfa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,8 +49,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # requirements.txt files and should be kept consistent. 
The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.5.0") -set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1") +set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1") # # Try to find python package with an executable that exactly matches diff --git a/Dockerfile b/Dockerfile index 0a562253c537b..343364da2ebf5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -206,7 +206,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' bitsandbytes>=0.44.0 timm==0.9.10 + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10 ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 0d0d8df94578c..2143315d2a078 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -31,7 +31,7 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi RUN python3 -m pip install -U \ - cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ -r requirements-neuron.txt ENV VLLM_TARGET_DEVICE neuron diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index cd5fcf481f07c..b19c6ddec7948 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -21,7 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \ # These packages will be in rocketce eventually RUN --mount=type=cache,target=/root/.cache/pip \ pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ - cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ torch==2.3.1 \ -r requirements-cpu.txt \ xformers uvloop==0.20.0 diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 562117a313020..8fb79afaebe97 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -52,7 +52,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip uninstall -y torch torchvision \ && python3 -m pip install --pre \ torch==2.6.0.dev20240918 \ - setuptools-scm>=8 \ + 'setuptools-scm>=8' \ torchvision==0.20.0.dev20240918 \ --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \ *) ;; esac diff --git a/Dockerfile.tpu b/Dockerfile.tpu index dd8f9ad4714a9..b43442e4c0af1 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -25,7 +25,7 @@ ENV VLLM_TARGET_DEVICE="tpu" RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 -m pip install \ - cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ -r requirements-tpu.txt RUN python3 setup.py develop diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 0d205014b15bf..ff06622628219 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -406,9 +406,9 @@ def calculate_metrics( median_itl_ms=np.median(itls or 0) * 1000, percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles], - mean_e2el_ms=np.median(e2els or 0) * 1000, + mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, - median_e2el_ms=np.mean(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles], ) diff --git 
a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index ee41c8ea38382..262b8652e49ff 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -4,7 +4,7 @@ import json import random import time -from typing import List, Optional, Tuple +from typing import List, Optional import torch import uvloop @@ -15,16 +15,35 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) +from vllm.inputs import TextPrompt +from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import BeamSearchParams from vllm.utils import FlexibleArgumentParser, merge_async_iterators +@dataclasses.dataclass +class SampleRequest: + """A class representing a single inference request for benchmarking. + + Attributes: + prompt: The input text prompt for the model. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. + images). + prompt_len: The length of the prompt in tokens. + expected_output_len: The expected length of the output in tokens. + """ + prompt: str + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[MultiModalDataDict] = None + + def sample_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int], -) -> List[Tuple[str, int, int]]: +) -> List[SampleRequest]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -41,7 +60,7 @@ def sample_requests( random.shuffle(dataset) # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] + filtered_dataset: List[SampleRequest] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break @@ -60,13 +79,16 @@ def sample_requests( if prompt_len > 1024 or prompt_len + output_len > 2048: # Prune too long sequences. continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append( + SampleRequest(prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len)) return filtered_dataset def run_vllm( - requests: List[Tuple[str, int, int]], + requests: List[SampleRequest], n: int, engine_args: EngineArgs, ) -> float: @@ -74,17 +96,17 @@ def run_vllm( llm = LLM(**dataclasses.asdict(engine_args)) # Add the requests to the engine. - prompts: List[str] = [] + prompts: List[TextPrompt] = [] sampling_params: List[SamplingParams] = [] - for prompt, _, output_len in requests: - prompts.append(prompt) + for request in requests: + prompts.append(TextPrompt(prompt=request.prompt)) sampling_params.append( SamplingParams( n=n, temperature=1.0, top_p=1.0, ignore_eos=True, - max_tokens=output_len, + max_tokens=request.expected_output_len, )) use_beam_search = False @@ -94,11 +116,11 @@ def run_vllm( llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() else: - prompts = [prompt for prompt, _, _ in requests] + prompts = [request.prompt for request in requests] # output_len should be the same for all requests. 
output_len = requests[0][2] - for prompt, input_len, _output_len in requests: - assert _output_len == output_len + for request in requests: + assert request.expected_output_len == output_len start = time.perf_counter() llm.beam_search( prompts, @@ -112,7 +134,7 @@ def run_vllm( async def run_vllm_async( - requests: List[Tuple[str, int, int]], + requests: List[SampleRequest], n: int, engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, @@ -123,17 +145,17 @@ async def run_vllm_async( engine_args, disable_frontend_multiprocessing) as llm: # Add the requests to the engine. - prompts: List[str] = [] + prompts: List[TextPrompt] = [] sampling_params: List[SamplingParams] = [] - for prompt, _, output_len in requests: - prompts.append(prompt) + for request in requests: + prompts.append(TextPrompt(prompt=request.prompt)) sampling_params.append( SamplingParams( n=n, temperature=1.0, top_p=1.0, ignore_eos=True, - max_tokens=output_len, + max_tokens=request.expected_output_len, )) generators = [] @@ -149,7 +171,7 @@ async def run_vllm_async( def run_hf( - requests: List[Tuple[str, int, int]], + requests: List[SampleRequest], model: str, tokenizer: PreTrainedTokenizerBase, n: int, @@ -207,14 +229,14 @@ def run_hf( def run_mii( - requests: List[Tuple[str, int, int]], + requests: List[SampleRequest], model: str, tensor_parallel_size: int, output_len: int, ) -> float: from mii import client, serve llm = serve(model, tensor_parallel=tensor_parallel_size) - prompts = [prompt for prompt, _, _ in requests] + prompts = [request.prompt for request in requests] start = time.perf_counter() llm.generate(prompts, max_new_tokens=output_len) @@ -243,8 +265,12 @@ def main(args: argparse.Namespace): else: raise ValueError( f"Failed to synthesize a prompt with {args.input_len} tokens.") - requests = [(prompt, args.input_len, args.output_len) - for _ in range(args.num_prompts)] + requests = [ + SampleRequest(prompt=prompt, + prompt_len=args.input_len, + expected_output_len=args.output_len) + for _ in range(args.num_prompts) + ] else: requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.output_len) @@ -270,9 +296,10 @@ def main(args: argparse.Namespace): args.output_len) else: raise ValueError(f"Unknown backend: {args.backend}") - total_num_tokens = sum(prompt_len + output_len - for _, prompt_len, output_len in requests) - total_output_tokens = sum(output_len for _, _, output_len in requests) + total_num_tokens = sum(request.prompt_len + request.expected_output_len + for request in requests) + total_output_tokens = sum(request.expected_output_len + for request in requests) print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " f"{total_output_tokens / elapsed_time:.2f} output tokens/s") @@ -299,7 +326,9 @@ def main(args: argparse.Namespace): parser.add_argument("--dataset", type=str, default=None, - help="Path to the dataset.") + help="Path to the dataset. The dataset is expected to " + "be a json in form of List[Dict[..., conversations: " + "List[Dict[..., value: ]]]]") parser.add_argument("--input-len", type=int, default=None, diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index edba209986f6a..f0c812b941c1f 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -1,35 +1,167 @@ .. 
_installation_tpu: +##################### Installation with TPU -===================== +##################### -vLLM supports Google Cloud TPUs using PyTorch XLA. +Tensor Processing Units (TPUs) are Google's custom-developed application-specific +integrated circuits (ASICs) used to accelerate machine learning workloads. TPUs +are available in different versions, each with different hardware specifications. +For more information about TPUs, see `TPU System Architecture `_. +For more information on the TPU versions supported with vLLM, see: + +* `TPU v6e `_ +* `TPU v5e `_ +* `TPU v5p `_ +* `TPU v4 `_ + +These TPU versions allow you to configure the physical arrangements of the TPU +chips. This can improve throughput and networking performance. For more +information, see: + +* `TPU v6e topologies `_ +* `TPU v5e topologies `_ +* `TPU v5p topologies `_ +* `TPU v4 topologies `_ + +In order for you to use Cloud TPUs, you need to have TPU quota granted to your +Google Cloud Platform project. TPU quotas specify how many TPUs you can use in a +GCP project and are specified in terms of TPU version, the number of TPUs you +want to use, and quota type. For more information, see `TPU quota `_. + +For TPU pricing information, see `Cloud TPU pricing `_. + +You may need additional persistent storage for your TPU VMs. For more +information, see `Storage options for Cloud TPU data `_. Requirements ------------ -* Google Cloud TPU VM (single & multi host) -* TPU versions: v5e, v5p, v4 -* Python: 3.10 +* Google Cloud TPU VM +* TPU versions: v6e, v5e, v5p, v4 +* Python: 3.10 or newer + +Provision Cloud TPUs +==================== + +You can provision Cloud TPUs using the `Cloud TPU API `_ +or the `queued resources `_ +API. This section shows how to create TPUs using the queued resource API. +For more information about using the Cloud TPU API, see `Create a Cloud TPU using the Create Node API `_. +`Queued resources `_ +enable you to request Cloud TPU resources in a queued manner. When you request +queued resources, the request is added to a queue maintained by the Cloud TPU +service. When the requested resource becomes available, it's assigned to your +Google Cloud project for your immediate exclusive use. + +Provision a Cloud TPU with the queued resource API +-------------------------------------------------- +Create a TPU v5e with 4 TPU chips: + +.. code-block:: console + + gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ + --node-id TPU_NAME \ + --project PROJECT_ID \ + --zone ZONE \ + --accelerator-type ACCELERATOR_TYPE \ + --runtime-version RUNTIME_VERSION \ + --service-account SERVICE_ACCOUNT + +.. list-table:: Parameter descriptions + :header-rows: 1 + + * - Parameter name + - Description + * - QUEUED_RESOURCE_ID + - The user-assigned ID of the queued resource request. + * - TPU_NAME + - The user-assigned name of the TPU which is created when the queued + resource request is allocated. + * - PROJECT_ID + - Your Google Cloud project + * - ZONE + - The `zone `_ where you + want to create your Cloud TPU. + * - ACCELERATOR_TYPE + - The TPU version you want to use. Specify the TPU version, followed by a + '-' and the number of TPU cores. For example `v5e-4` specifies a v5e TPU + with 4 cores. For more information, see `TPU versions `_. + * - RUNTIME_VERSION + - The TPU VM runtime version to use. For more information, see `TPU VM images `_. + * - SERVICE_ACCOUNT + - The email address for your service account. You can find it in the IAM + Cloud Console under *Service Accounts*.
For example: + `tpu-service-account@.iam.gserviceaccount.com` + +Connect to your TPU using SSH: + +.. code-block:: bash + + gcloud compute tpus tpu-vm ssh TPU_NAME + +Create and activate a Conda environment for vLLM: + +.. code-block:: bash -Installation options: + conda create -n vllm python=3.10 -y + conda activate vllm -1. :ref:`Build a docker image with Dockerfile `. -2. :ref:`Build from source `. +Clone the vLLM repository and go to the vLLM directory: + +.. code-block:: bash + + git clone https://github.com/vllm-project/vllm.git && cd vllm + +Uninstall the existing `torch` and `torch_xla` packages: + +.. code-block:: bash + + pip uninstall torch torch-xla -y + +Install `torch` and `torch_xla` + +.. code-block:: bash + + pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu + pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html + +Install JAX and Pallas: + +.. code-block:: bash + + pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + +Install other build dependencies: + +.. code-block:: bash + + pip install -r requirements-tpu.txt + VLLM_TARGET_DEVICE="tpu" python setup.py develop + sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev + +Provision Cloud TPUs with GKE +----------------------------- + +For more information about using TPUs with GKE, see +https://cloud.google.com/kubernetes-engine/docs/how-to/tpus +https://cloud.google.com/kubernetes-engine/docs/concepts/tpus +https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus .. _build_docker_tpu: Build a docker image with :code:`Dockerfile.tpu` ------------------------------------------------ -`Dockerfile.tpu `_ is provided to build a docker image with TPU support. +You can use `Dockerfile.tpu `_ +to build a Docker image with TPU support. .. code-block:: console $ docker build -f Dockerfile.tpu -t vllm-tpu . - -You can run the docker image with the following command: +Run the Docker image with the following command: .. code-block:: console @@ -75,14 +207,12 @@ Next, build vLLM from source. This will only take a few seconds: $ VLLM_TARGET_DEVICE="tpu" python setup.py develop - .. note:: Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape. The compilation time may take 20~30 minutes in the first run. However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default). - .. tip:: If you encounter the following error: @@ -93,7 +223,7 @@ Next, build vLLM from source. This will only take a few seconds: ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory - Please install OpenBLAS with the following command: + Install OpenBLAS with the following command: .. 
code-block:: console diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 80714a90df5c2..55835d945b00c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -160,13 +160,13 @@ Text Generation - - ✅︎ * - :code:`GraniteForCausalLM` - - PowerLM - - :code:`ibm/PowerLM-3b` etc. + - Granite 3.0, PowerLM + - :code:`ibm-granite/granite-3.0-2b-base`, :code:`ibm-granite/granite-3.0-8b-instruct`, :code:`ibm/PowerLM-3b`, etc. - ✅︎ - ✅︎ * - :code:`GraniteMoeForCausalLM` - - PowerMoE - - :code:`ibm/PowerMoE-3b` etc. + - Granite 3.0 MoE, PowerMoE + - :code:`ibm-granite/granite-3.0-1b-a400m-base`, :code:`ibm-granite/granite-3.0-3b-a800m-instruct`, :code:`ibm/PowerMoE-3b`, etc. - ✅︎ - ✅︎ * - :code:`InternLMForCausalLM` @@ -440,6 +440,12 @@ Text Generation - :code:`THUDM/glm-4v-9b` etc. - - ✅︎ + * - :code:`H2OVLChatModel` + - H2OVL + - T + I\ :sup:`E+` + - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. + - + - ✅︎ * - :code:`InternVLChatModel` - InternVL2 - T + I\ :sup:`E+` diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 3377502a6db28..112e9db6a41de 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -242,6 +242,10 @@ To consume the server, you can use the OpenAI client like in the example below: A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py `_. +.. tip:: + Loading from local file paths is also supported on vLLM: You can specify the allowed local media path via ``--allowed-local-media-path`` when launching the API server/engine, + and pass the file path as ``url`` in the API request. + .. tip:: There is no need to place image placeholders in the text content of the API request - they are already represented by the image content. In fact, you can place image placeholders in the middle of the text by interleaving text and image content. 
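The local-media tip added to docs/source/models/vlm.rst above can be made concrete with a small sketch. The model name, the /data/images directory, and the file:// URL form are assumptions for illustration and are not part of this patch; the exact URL format accepted for local files should be checked against the vLLM documentation. Assuming the server was started with something like ``vllm serve llava-hf/llava-1.5-7b-hf --allowed-local-media-path /data/images``, a request could look like:

.. code-block:: python

    from openai import OpenAI

    # Point the standard OpenAI client at the locally running vLLM server.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    chat_response = client.chat.completions.create(
        model="llava-hf/llava-1.5-7b-hf",  # placeholder model name
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # Local file under the allowed media path, passed as the image URL.
                # The file:// form is an assumption; verify against the
                # --allowed-local-media-path documentation.
                {"type": "image_url",
                 "image_url": {"url": "file:///data/images/example.jpg"}},
            ],
        }],
    )
    print(chat_response.choices[0].message.content)

The request body follows the same OpenAI chat format as the remote-URL example referenced in examples/openai_chat_completion_client_for_multimodal.py; only the ``url`` value changes.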
diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 37ec667d96a77..050b791b62adb 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int): tokenize=False, add_generation_prompt=True) - llm = LLM(model=model_name, - enforce_eager=True, - enable_chunked_prefill=False, - max_model_len=8192, - limit_mm_per_prompt={"audio": audio_count}) + llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count}) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 60cdb186331fe..4fd002caf1763 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -176,6 +176,31 @@ def run_minicpmv(question: str, modality: str): return llm, prompt, stop_token_ids +# H2OVL-Mississippi +def run_h2ovl(question: str, modality: str): + assert modality == "image" + + model_name = "h2oai/h2ovl-mississippi-2b" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + messages = [{'role': 'user', 'content': f"\n{question}"}] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for H2OVL-Mississippi + # https://huggingface.co/h2oai/h2ovl-mississippi-2b + stop_token_ids = [tokenizer.eos_token_id] + return llm, prompt, stop_token_ids + + # InternVL def run_internvl(question: str, modality: str): assert modality == "image" @@ -363,6 +388,7 @@ def run_glm4v(question: str, modality: str): "chameleon": run_chameleon, "minicpmv": run_minicpmv, "blip-2": run_blip2, + "h2ovl_chat": run_h2ovl, "internvl_chat": run_internvl, "NVLM_D": run_nvlm_d, "qwen_vl": run_qwen_vl, @@ -475,4 +501,4 @@ def main(args): default=16, help='Number of frames to extract from the video.') args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index e28514bf403f7..d99684078ff3d 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -107,6 +107,40 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: ) +def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData: + model_name = "h2oai/h2ovl-mississippi-2b" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"max_dynamic_patch": 4}, + ) + + placeholders = "\n".join(f"Image-{i}: \n" + for i, _ in enumerate(image_urls, start=1)) + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for H2OVL-Mississippi + # https://huggingface.co/h2oai/h2ovl-mississippi-2b + stop_token_ids = [tokenizer.eos_token_id] + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + def 
load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "OpenGVLab/InternVL2-2B" @@ -258,6 +292,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData: model_example_map = { "phi3_v": load_phi3v, + "h2ovl_chat": load_h2onvl, "internvl_chat": load_internvl, "NVLM_D": load_nvlm_d, "qwen2_vl": load_qwen2_vl, diff --git a/pyproject.toml b/pyproject.toml index e78f5652f486b..0bbab3cd3fbc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging", "setuptools>=61", "setuptools-scm>=8.0", - "torch == 2.5.0", + "torch == 2.5.1", "wheel", "jinja2", ] diff --git a/requirements-build.txt b/requirements-build.txt index 7b16d9778c1a6..fec01caaf25ef 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -4,6 +4,6 @@ ninja packaging setuptools>=61 setuptools-scm>=8 -torch==2.5.0 +torch==2.5.1 wheel jinja2 diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 282ab11838bf4..058ab7c1ee9df 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -4,7 +4,7 @@ # Dependencies for NVIDIA GPUs ray >= 2.9 nvidia-ml-py >= 12.560.30 # for pynvml package -torch == 2.5.0 +torch == 2.5.1 # These must be updated alongside torch -torchvision == 0.20 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -xformers == 0.0.28.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.0 +torchvision == 0.20.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +xformers == 0.0.28.post3; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.1 diff --git a/requirements-lint.txt b/requirements-lint.txt index 07f738873e1a8..f9132bbf96437 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -1,7 +1,7 @@ # formatting yapf==0.32.0 toml==0.10.2 -tomli==2.0.1 +tomli==2.0.2 ruff==0.6.5 codespell==2.3.0 isort==5.13.2 diff --git a/requirements-openvino.txt b/requirements-openvino.txt index 7ad0d1e7f704b..95e5914757812 100644 --- a/requirements-openvino.txt +++ b/requirements-openvino.txt @@ -1,7 +1,7 @@ # Common dependencies -r requirements-common.txt -torch == 2.5.0 # should be aligned with "common" vLLM torch version +torch == 2.5.1 # should be aligned with "common" vLLM torch version openvino >= 2024.4.0 # since 2024.4.0 both CPU and GPU support Paged Attention optimum @ git+https://github.com/huggingface/optimum.git@main # latest optimum is used to support latest transformers version diff --git a/requirements-test.in b/requirements-test.in index 3881f2566b556..560c005fd6157 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -18,7 +18,7 @@ ray[adag]==2.35 sentence-transformers # required for embedding soundfile # required for audio test timm # required for internvl test -torch==2.5.0 +torch==2.5.1 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test datamodel_code_generator # required for minicpm3 test @@ -32,6 +32,6 @@ aiohttp # quantization bitsandbytes>=0.44.0 -buildkite-test-collector==0.1.8 +buildkite-test-collector==0.1.9 numpy < 2.0.0 diff --git a/requirements-test.txt b/requirements-test.txt index c474c2ec34b22..518e81021cbcb 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -36,20 +36,20 @@ attrs==24.2.0 # referencing audioread==3.0.1 # via librosa -awscli==1.35.16 +awscli==1.35.19 # via -r requirements-test.in 
bitsandbytes==0.44.1 # via -r requirements-test.in black==24.10.0 # via datamodel-code-generator -boto3==1.35.50 +boto3==1.35.53 # via tensorizer -botocore==1.35.50 +botocore==1.35.53 # via # awscli # boto3 # s3transfer -buildkite-test-collector==0.1.8 +buildkite-test-collector==0.1.9 # via -r requirements-test.in certifi==2024.8.30 # via @@ -426,7 +426,7 @@ requests==2.32.3 # transformers rouge-score==0.1.2 # via lm-eval -rpds-py==0.20.0 +rpds-py==0.20.1 # via # jsonschema # referencing @@ -492,7 +492,7 @@ timm==1.0.11 # via -r requirements-test.in tokenizers==0.20.1 # via transformers -torch==2.5.0 +torch==2.5.1 # via # -r requirements-test.in # accelerate @@ -503,7 +503,7 @@ torch==2.5.0 # tensorizer # timm # torchvision -torchvision==0.20.0 +torchvision==0.20.1 # via timm tqdm==4.66.6 # via @@ -552,7 +552,7 @@ xxhash==3.5.0 # via # datasets # evaluate -yarl==1.17.0 +yarl==1.17.1 # via aiohttp zstandard==0.23.0 # via lm-eval diff --git a/tests/core/utils.py b/tests/core/utils.py index a95a573db7cd3..cd0caa4704e11 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -4,6 +4,7 @@ from typing import Tuple from vllm import SamplingParams +from vllm.inputs import EncoderDecoderInputs, token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, Sequence, SequenceGroup @@ -27,10 +28,7 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), - inputs={ - "prompt": prompt_str, - "prompt_token_ids": prompt_tokens, - }, + inputs=token_inputs(prompt_tokens, prompt=prompt_str), block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[prompt], @@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder( encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - inputs = { - "prompt": decoder_prompt_str, - "prompt_token_ids": decoder_prompt_tokens, - "encoder_prompt": encoder_prompt_str, - "encoder_prompt_token_ids": encoder_prompt_tokens, - "multi_modal_data": None, + inputs: EncoderDecoderInputs = { + "decoder": token_inputs(decoder_prompt_tokens, + prompt=decoder_prompt_str), + "encoder": token_inputs(encoder_prompt_tokens, + prompt=encoder_prompt_str), } decoder_prompt = Sequence(int(request_id), - inputs=inputs, - block_size=block_size, - from_decoder_prompt=True) + inputs=inputs["decoder"], + block_size=block_size) encoder_prompt = Sequence(int(request_id), - inputs=inputs, - block_size=block_size, - from_decoder_prompt=False) + inputs=inputs["encoder"], + block_size=block_size) + seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], sampling_params=SamplingParams(best_of=best_of), @@ -108,7 +104,7 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - inputs={"prompt_token_ids": prompt_token_ids}, + inputs=token_inputs(prompt_token_ids), block_size=16, ) @@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder( prompt_token_ids = [0] * seq_prompt_len - inputs = { - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "encoder_prompt": "", - "encoder_prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, + inputs: EncoderDecoderInputs = { + "decoder": token_inputs(prompt_token_ids), + "encoder": token_inputs(prompt_token_ids), } seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): # Construct decoder input 
sequences - seq = Sequence(seq_id=seq_id_start + seq_id_offset, - inputs=inputs, - block_size=16, - from_decoder_prompt=True) + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + inputs=inputs["decoder"], + block_size=16, + ) for i in range(output_len): seq.append_token_id( @@ -167,10 +161,11 @@ def create_seq_group_encoder_decoder( seqs.append(seq) # Encoder input sequence - encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens), - inputs=inputs, - block_size=16, - from_decoder_prompt=False) + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + inputs=inputs["encoder"], + block_size=16, + ) return SequenceGroup(request_id=request_id, seqs=seqs, diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index bef0c515b9073..f2d7e9fd78cf3 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -7,12 +7,18 @@ import pytest from transformers import AutoModelForSeq2SeqLM +from vllm.attention.selector import (_Backend, + global_force_attn_backend_context_manager) from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs from ..conftest import DecoderPromptType from ..models.utils import check_logprobs_close +LIST_ENC_DEC_SUPPORTED_BACKENDS = [ + _Backend.XFORMERS, _Backend.FLASH_ATTN, None +] + def vllm_to_hf_output( vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], @@ -29,7 +35,8 @@ def vllm_to_hf_output( @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @@ -48,6 +55,7 @@ def test_encoder_decoder_e2e( num_logprobs: int, decoder_prompt_type: DecoderPromptType, enforce_eager: bool, + attn_backend: _Backend, ) -> None: ''' End-to-End (E2E) test for the encoder-decoder framework. @@ -56,43 +64,49 @@ def test_encoder_decoder_e2e( implementations to ensure that both implementations produce consistent and correct results. 
''' - test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] + with global_force_attn_backend_context_manager(attn_backend): + if attn_backend == _Backend.FLASH_ATTN: + # Flash Attention works only with bfloat16 data-type + dtype = 'bfloat16' + test_case_prompts = example_encoder_decoder_prompts[ + decoder_prompt_type] - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - with vllm_runner(model, dtype=dtype, - enforce_eager=enforce_eager) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + with vllm_runner(model, dtype=dtype, + enforce_eager=enforce_eager) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_case_prompts, max_tokens, num_logprobs) - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) + hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE + else 0) - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py index 0d84443c51f99..cc14e8cbf75df 100644 --- a/tests/engine/output_processor/test_stop_checker.py +++ b/tests/engine/output_processor/test_stop_checker.py @@ -4,6 +4,7 @@ from transformers import PreTrainedTokenizer from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.inputs import token_inputs from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob, Sequence, SequenceStatus @@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str, """ seq = Sequence( seq_id=0, - inputs={"prompt_token_ids": []}, + inputs=token_inputs([]), block_size=16, eos_token_id=eos_token_id, ) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 63beaaba29a80..a16e95f94171e 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -10,6 +10,8 @@ import lm_eval import pytest +from vllm.platforms import current_platform + from ...utils import RemoteOpenAIServer MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" @@ -18,12 +20,21 @@ FILTER = 
"exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUE = 0.58 -DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] +DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"] MORE_ARGS_LIST = [ + [], # Default ["--enable-chunked-prefill"], # Chunked ["--num-scheduler-steps", "8"], # MS ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream ] +MAX_WAIT_SECONDS = None + +if current_platform.is_tpu(): + MORE_ARGS_LIST = [ + [], # Default + # ["--num-scheduler-steps", "8"], # Multi-step << currently fails + ] + MAX_WAIT_SECONDS = 600 @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) @@ -33,7 +44,9 @@ def test_lm_eval_accuracy(more_args): print(f"Running with: {args}") - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, args, + max_wait_seconds=MAX_WAIT_SECONDS) as remote_server: url = f"{remote_server.url_for('v1')}/completions" model_args = ( diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index b3f1fea91d13e..6523c8b6297c6 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -70,10 +70,14 @@ async def client(server): [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), ("_count", _NUM_REQUESTS)], "vllm:request_params_n": [("_count", _NUM_REQUESTS)], + "vllm:request_params_max_tokens": + [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS)], "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], - "vllm:generation_tokens": - [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], + "vllm:generation_tokens": [ + ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST) + ], "vllm:request_success": [("_total", _NUM_REQUESTS)], } @@ -149,6 +153,9 @@ async def test_metrics_counts(server: RemoteOpenAIServer, "vllm:request_params_n_sum", "vllm:request_params_n_bucket", "vllm:request_params_n_count", + "vllm:request_params_max_tokens_sum", + "vllm:request_params_max_tokens_bucket", + "vllm:request_params_max_tokens_count", "vllm:num_preemptions_total", "vllm:prompt_tokens_total", "vllm:generation_tokens_total", diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 5b8311a33c361..e2b4778b94b9e 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -258,19 +258,20 @@ def test_reshape_and_cache_flash( del key_caches del value_caches + k_scale = key.amax().item() / 256 + v_scale = value.amax().item() / 256 + # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache) + ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache) + ops.convert_fp8(cloned_value_cache, value_cache, v_scale, + kv_cache_dtype) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() - # Using default kv_scale - k_scale = v_scale = 1.0 - # Call the reshape_and_cache kernel. 
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, @@ -281,9 +282,15 @@ def test_reshape_and_cache_flash( if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(result_key_cache, key_cache) + ops.convert_fp8(result_key_cache, + key_cache, + k_scale, + kv_dtype=kv_cache_dtype) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(result_value_cache, value_cache) + ops.convert_fp8(result_value_cache, + value_cache, + v_scale, + kv_dtype=kv_cache_dtype) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index bc99c5559d388..a1dd5eeeaa398 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,13 +16,13 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.attention.selector import (_Backend, +from vllm.attention.selector import (_Backend, get_attn_backend, global_force_attn_backend_context_manager) +from vllm.forward_context import set_forward_context from vllm.platforms import current_platform # List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] - +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -145,7 +145,8 @@ class that Attention will automatically select when it is constructed. test_pt.num_heads, test_pt.head_size, test_pt.block_size, - device=CUDA_DEVICE) + device=CUDA_DEVICE, + backend=test_pt.backend_name) return TestResources(scale, attn_backend, attn, kv_cache) @@ -592,6 +593,7 @@ def _run_encoder_attention_test( attn: Attention, encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run encoder attention. @@ -610,6 +612,8 @@ def _run_encoder_attention_test( (number_of_tokens x num_heads x head_size) query/key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed {query,key,value} and @@ -619,20 +623,31 @@ def _run_encoder_attention_test( attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, - packed_qkv.key, - packed_qkv.value, - torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. 
+ reshaped_query = packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + packed_qkv.key, + packed_qkv.value, + torch.tensor([], + dtype=torch.float32, + device=packed_qkv.query.device), + attn_metadata, + attn_type=attn_type) def _run_decoder_self_attention_test( test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run decoder self-attention test. @@ -650,6 +665,8 @@ def _run_decoder_self_attention_test( query/key/value fields * attn_metadata: attention metadata for decoder-self attention (contains KV cache memory-mapping) + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache @@ -660,12 +677,22 @@ def _run_decoder_self_attention_test( kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, - packed_qkv.key, - packed_qkv.value, - kv_cache, - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. + reshaped_query = packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test( decoder_test_params: PhaseTestParameters, cross_test_params: Optional[PhaseTestParameters], attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test( (number_of_tokens x num_heads x head_size) key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache @@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, - key, - value, - kv_cache, - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. 
+ reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) + + +@pytest.fixture(autouse=True) +def set_reset_environment(attn_backend): + # Set the default torch datatype to bfloat16 to enable + # testing of the Flash Attention backend. Also clear the + # cached value of the backend. + default_dtype = torch.get_default_dtype() + if attn_backend.name == 'FLASH_ATTN': + torch.set_default_dtype(torch.bfloat16) + get_attn_backend.cache_clear() + yield + # Reset the torch datatype to what it was before the test + # so as not to impact the remaining tests. + torch.set_default_dtype(default_dtype) @pytest.mark.skipif(current_platform.is_rocm(), @@ -773,10 +828,8 @@ def test_encoder_only( * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences ''' - # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test @@ -807,10 +860,14 @@ def test_encoder_only( # PREFILL: encoder attention enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( - test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + test_pt=test_pt)) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, + attn_backend.name) @pytest.mark.skipif(current_platform.is_rocm(), @@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn( * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences ''' - # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test @@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn( enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_test_params, - prephase_attn_metadata) + prephase_attn_metadata, + test_pt=test_pt) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, + attn_backend.name) # PREFILL: decoder self-attention test prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) + test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, - prephase_dec_pckd_act_out) + prephase_dec_pckd_act_out, + attn_backend.name) # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - prephase_attn_metadata) + test_rsrcs, + prephase_dec_test_params, + prephase_cross_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is prefill encoder/decoder cross-attention correct? 
assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out) + prephase_cross_pckd_act_out, + attn_backend.name) # DECODE: build decode-phase attention metadata @@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn( # DECODE: decoder self-attention test decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) + test_rsrcs, + decphase_dec_test_params, + decphase_attn_metadata, + test_pt=test_pt) # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out) + decphase_dec_pckd_act_out, + attn_backend.name) # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) + test_rsrcs, + decphase_dec_test_params, + None, + decphase_attn_metadata, + test_pt=test_pt) # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) + decphase_cross_pckd_act_out, + attn_backend.name) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index a2d414f636e13..e7865fb2500ef 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,8 +13,8 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, - make_tensor_with_pad) +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, + STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend: if backend_name == STR_XFORMERS_ATTN_VAL: # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. from vllm.attention.backends.xformers import XFormersBackend - return XFormersBackend() + elif backend_name == STR_FLASH_ATTN_VAL: + from vllm.attention.backends.flash_attn import FlashAttentionBackend + return FlashAttentionBackend() + raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") def _make_metadata_tensors( - seq_lens: Optional[List[int]], context_lens: Optional[List[int]], - encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] -) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], - torch.Tensor, Optional[int]]: + seq_lens: Optional[List[int]], + context_lens: Optional[List[int]], + encoder_seq_lens: Optional[List[int]], + device: Union[torch.device, str], +) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], + torch.Tensor, torch.Tensor, Optional[int]]: ''' Build scalar & tensor values required to build attention metadata structure. 
@@ -553,6 +558,8 @@ def _make_metadata_tensors( * max_context_len: max(context_lens) * max_seq_len: max(seq_lens) * seq_start_loc: start idx of each sequence + * encoder_seq_lens_tensor: encoder seq_lens list, as tensor + * encoder_seq_start_loc: start idx of each encoder sequence * max_encoder_seq_len: encoder seq_lens list, as tensor ''' seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) @@ -566,8 +573,26 @@ def _make_metadata_tensors( seq_start_loc = None + if seq_lens_tensor is not None: + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=seq_lens_tensor.device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=encoder_seq_lens_tensor.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, - seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len) + seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, + max_encoder_seq_len) def make_kv_cache(num_blocks: int, @@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int, head_size: int, block_size: int, device: Union[torch.device, str], + backend: str, default_val: float = 0.0) -> torch.Tensor: ''' Create a fake KV cache. @@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int, Returns: * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + * for backend 'XFORMERS' + * kv_cache: 2 x num_blocks x block_size x num_heads x head_size + * for backend 'FLASH_ATTN' ''' - - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) + if backend == 'XFORMERS': + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) + elif backend == 'FLASH_ATTN': + kv_cache = torch.rand( + (2, num_blocks, block_size, num_heads, head_size)).to(device) + else: + raise ValueError( + f"Unknown backend value: '{backend}'. 
Expected 'XFORMERS' or " + f"'FLASH_ATTN'.") if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache @@ -858,8 +894,9 @@ def make_test_metadata( context_lens_tensor, _, _, - _, + seq_start_loc, encoder_seq_lens_tensor, + encoder_seq_start_loc, max_encoder_seq_len, ) = _make_metadata_tensors(seq_lens, context_lens, @@ -869,10 +906,12 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), + multi_modal_placeholder_index_maps=None, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + seq_start_loc=seq_start_loc, max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, @@ -881,6 +920,7 @@ def make_test_metadata( num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, + encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), @@ -903,8 +943,9 @@ def make_test_metadata( context_lens_tensor, _, _, - _, + seq_start_loc, encoder_seq_lens_tensor, + encoder_seq_start_loc, max_encoder_seq_len, ) = _make_metadata_tensors(seq_lens, context_lens, @@ -914,18 +955,22 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, + multi_modal_placeholder_index_maps=None, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + seq_start_loc=seq_start_loc, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), + max_decode_query_len=1, context_lens_tensor=context_lens_tensor, block_tables=kv_mmap.block_tables, use_cuda_graph=False, num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, + encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), @@ -934,7 +979,8 @@ def make_test_metadata( def assert_actual_matches_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor) -> None: + output_under_test: torch.Tensor, + backend: str) -> None: ''' Assert that observed output matches the ideal output contained in the test parameters data structure. @@ -945,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, * output_under_test: actually observed output value ''' ideal_output = test_params.packed_qkvo.ideal_output - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output)) + if backend == 'XFORMERS': + torch.testing.assert_close(ideal_output, + output_under_test.view_as(ideal_output)) + + elif backend == 'FLASH_ATTN': + # For FlashAttention override the accuracy thresholds to non default + # values since we notice a higher difference between the ideal and + # actual output. + torch.testing.assert_close(ideal_output, + output_under_test.view_as(ideal_output), + atol=0.01, + rtol=0.016) + else: + raise ValueError( + f"Unknown backend value: '{backend}'. 
Expected 'XFORMERS' or " + f"'FLASH_ATTN'.") # Copied/modified from torch._refs.__init__.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index e40f0dd74602e..816d3986fe333 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings(): cleanup_dist_env_and_memory(shutdown_ray=True) get_model_old = get_model - def get_model_patched(*, model_config, device_config, **kwargs): - kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8) - return get_model_old(model_config=model_config, - device_config=device_config, - **kwargs) + def get_model_patched(**kwargs): + kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, + max_lora_rank=8) + return get_model_old(**kwargs) with patch("vllm.worker.model_runner.get_model", get_model_patched): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index c8edb02a88d4b..eada902c891f7 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -138,13 +138,7 @@ def test_rotary_emb_replaced(dist_init): enable_lora=True) engine_config = engine_args.create_engine_config() model_runner = ModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, - lora_config=engine_config.lora_config, + vllm_config=engine_config, is_driver_worker=True, ) model_runner.load_model() diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index be040060d02b2..2c45ce5141f7d 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -1,8 +1,11 @@ from typing import List +import pytest + import vllm from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" @@ -53,6 +56,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.mark.xfail( + current_platform.is_rocm(), + reason="MiniCPM-V dependency xformers incompatible with ROCm") def test_minicpmv_lora(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -63,7 +69,6 @@ def test_minicpmv_lora(minicpmv_lora_files): trust_remote_code=True, gpu_memory_utilization=0.97 # This model is pretty big for CI gpus ) - output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): assert EXPECTED_OUTPUT[i].startswith(output1[i]) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 2f7ac85507425..9d814f657ac43 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -4,7 +4,8 @@ from unittest.mock import patch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig) + ModelConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -12,7 +13,7 @@ @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - worker = Worker( + vllm_config = VllmConfig( model_config=ModelConfig( "meta-llama/Llama-2-7b-hf", task="auto", @@ -34,10 +35,13 @@ def test_worker_apply_lora(sql_lora_files): gpu_memory_utilization=1., swap_space=0, cache_dtype="auto"), - 
local_rank=0, - rank=0, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, max_loras=32), + ) + worker = Worker( + vllm_config=vllm_config, + local_rank=0, + rank=0, distributed_init_method=f"file://{tempfile.mkstemp()[1]}", ) worker.init_device() diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 7a361ef320810..4a824c7acef21 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -365,6 +365,7 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool, "vllm:request_prompt_tokens", "vllm:request_generation_tokens", "vllm:request_params_n", + "vllm:request_params_max_tokens", ] for metric_name in request_histogram_metrics: metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index b9089e75ffab8..d14e88b4e5b26 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -2,8 +2,10 @@ import numpy as np import pytest +import pytest_asyncio from transformers import AutoModel, AutoTokenizer, BatchEncoding +from tests.utils import RemoteOpenAIServer from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -17,6 +19,13 @@ VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" HF_PLACEHOLDER = "<|audio|>" +CHUNKED_PREFILL_KWARGS = { + "enable_chunked_prefill": True, + "max_num_seqs": 2, + # Use a very small limit to exercise chunked prefill. + "max_num_batched_tokens": 16 +} + @pytest.fixture(scope="session") def audio_assets(): @@ -30,6 +39,26 @@ def audio(request): return AudioAsset(request.param) +@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS)) +def server(request, audio_assets): + args = [ + "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", + f"--limit-mm-per-prompt=audio={len(audio_assets)}" + ] + [ + f"--{key.replace('_','-')}={value}" + for key, value in request.param.items() + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + def _get_prompt(audio_count, question, placeholder): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) placeholder = f"{placeholder}\n" * audio_count @@ -68,8 +97,7 @@ def run_test( dtype: str, max_tokens: int, num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + **kwargs, ): """Inference result should be the same between hf and vllm.""" torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -79,11 +107,8 @@ def run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). 
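The `server` fixture added above turns engine keyword arguments such as CHUNKED_PREFILL_KWARGS into command-line flags for the OpenAI-compatible server. Below is a minimal, illustrative sketch of that conversion; the helper name is hypothetical and not part of the patch, only the f-string pattern is taken from the fixture.

from typing import Any, Dict, List

def kwargs_to_cli_flags(engine_kwargs: Dict[str, Any]) -> List[str]:
    # Mirror the fixture's f"--{key.replace('_','-')}={value}" conversion:
    # snake_case keys become kebab-case flags, values are stringified.
    return [
        f"--{key.replace('_', '-')}={value}"
        for key, value in engine_kwargs.items()
    ]

# Example: the chunked-prefill settings used by the test fixture.
chunked_prefill_kwargs = {
    "enable_chunked_prefill": True,
    "max_num_seqs": 2,
    "max_num_batched_tokens": 16,
}
print(kwargs_to_cli_flags(chunked_prefill_kwargs))
# ['--enable-chunked-prefill=True', '--max-num-seqs=2', '--max-num-batched-tokens=16']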
- with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True, + **kwargs) as vllm_model: vllm_outputs_per_audio = [ vllm_model.generate_greedy_logprobs([vllm_prompt], max_tokens, @@ -135,18 +160,16 @@ def run_multi_audio_test( dtype: str, max_tokens: int, num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + **kwargs, ): with vllm_runner(model, dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, enforce_eager=True, limit_mm_per_prompt={ "audio": max((len(audio) for _, audio in prompts_and_audios)) - }) as vllm_model: + }, + **kwargs) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( [prompt for prompt, _ in prompts_and_audios], max_tokens, @@ -162,8 +185,9 @@ def run_multi_audio_test( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS]) def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, - num_logprobs: int) -> None: + num_logprobs: int, vllm_kwargs: dict) -> None: vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER) hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER) @@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, - tensor_parallel_size=1, + **vllm_kwargs, ) @@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS]) def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: + max_tokens: int, num_logprobs: int, + vllm_kwargs: dict) -> None: vllm_prompt = _get_prompt(len(audio_assets), "Describe each of the audios above.", @@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, - tensor_parallel_size=1, + **vllm_kwargs, ) + + +@pytest.mark.asyncio +async def test_online_inference(client, audio_assets): + """Exercises online inference with/without chunked prefill enabled.""" + + messages = [{ + "role": + "user", + "content": [ + *[{ + "type": "audio_url", + "audio_url": { + "url": audio.url + } + } for audio in audio_assets], + { + "type": + "text", + "text": + f"What's happening in these {len(audio_assets)} audio clips?" 
+ }, + ], + }] + + chat_completion = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=10) + + assert len(chat_completion.choices) == 1 + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" diff --git a/tests/models/decoder_only/vision_language/test_h2ovl.py b/tests/models/decoder_only/vision_language/test_h2ovl.py new file mode 100644 index 0000000000000..ad9aa3104750b --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_h2ovl.py @@ -0,0 +1,130 @@ +from typing import Optional, Tuple + +import pytest +import torch +from PIL.Image import Image +from transformers import AutoConfig + +# Import the functions to test +from vllm.model_executor.models.h2ovl import (calculate_num_blocks, + image_to_pixel_values_wrapper) +from vllm.multimodal.utils import rescale_image_size + +models = [ + "h2oai/h2ovl-mississippi-800m", # Replace with your actual model names + "h2oai/h2ovl-mississippi-2b", +] +target_dtype = "bfloat16" + + +def run_preprocessing_test( + image: Image, + config, + max_dynamic_patch: Optional[int] = None, +) -> Tuple[torch.Tensor, int]: + """Test the image preprocessing and calculate expected blocks.""" + + if max_dynamic_patch is None: + max_dynamic_patch = config.max_dynamic_patch + + width, height = image.size + use_MSAC = config.use_msac + + # Create the mapper function with the provided configuration + mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC) + pixel_values = mapper(image) + + # Calculate the expected number of blocks + if use_MSAC: + # First pass + blocks1, _, _, aspect_ratio = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, # Thumbnail is handled separately + prior_aspect_ratio=None, + ) + + # Second pass + blocks2, _, _, _ = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, + prior_aspect_ratio=aspect_ratio, + ) + + # Add thumbnail if use_thumbnail is True and total_blocks > 1 + if config.use_thumbnail: + blocks1 += 1 if blocks1 > 1 else 0 + blocks2 += 1 if blocks2 > 1 else 0 + + # Total blocks is the sum of blocks from both passes minus overlapping + total_blocks = blocks1 + blocks2 - 1 + + expected_blocks = total_blocks + + else: + blocks, _, _, _ = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, + prior_aspect_ratio=None, + ) + expected_blocks = blocks + + if config.use_thumbnail and expected_blocks > 1: + expected_blocks += 1 + + return pixel_values, expected_blocks + + +@pytest.mark.parametrize("model_name", models) +@pytest.mark.parametrize( + "size_factors", + [ + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8]) +def test_image_preprocessing(image_assets, model_name, size_factors, + max_dynamic_patch): + """Test image preprocessing pipeline with different configurations.""" + # Load the configuration from the model + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + for asset in image_assets: + image = asset.pil_image + for factor in size_factors: + scaled_image = rescale_image_size(image, factor) + + # Test preprocessing and get expected number of blocks + pixel_values, expected_blocks = run_preprocessing_test( + 
scaled_image, config, max_dynamic_patch) + + # Verify output shapes and properties + actual_blocks = pixel_values.shape[0] + assert actual_blocks == expected_blocks, ( + f"Expected {expected_blocks} blocks, got {actual_blocks}") + + # Check image dimensions + expected_size = ( + 3, # Number of channels (C, H, W) + config.vision_config.image_size, + config.vision_config.image_size, + ) + for img in pixel_values: + assert img.shape == expected_size, ( + f"Expected image size {expected_size}, got {img.shape}") diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index e49ea6f98324d..cfd2d61f2b633 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -187,6 +187,23 @@ marks=[large_gpu_mark(min_gb=48)], patch_hf_runner=model_utils.glm_patch_hf_runner, ), + "h2ovl": VLMTestInfo( + models = [ + "h2oai/h2ovl-mississippi-800m", + "h2oai/h2ovl-mississippi-2b", + ], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts({ + "stop_sign": "\nWhat's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "\nWhat is the season?", + }), + multi_image_prompt="Image-1: \nImage-2: \nDescribe the two images in short.", # noqa: E501 + max_model_len=8192, + dtype="bfloat16", + use_tokenizer_eos=True, + patch_hf_runner=model_utils.h2ovl_patch_hf_runner, + ), "intern_vl": VLMTestInfo( models=[ "OpenGVLab/InternVL2-1B", diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index e925934db0e7c..849857b4232e7 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -259,6 +259,66 @@ def processor(*args, text="", images=None, **kwargs): return hf_model +def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for H2OVL.""" + + class H2OVLProcessor: + """A simple processor for H2OVL models.""" + + def __init__(self, hf_runner: HfRunner): + self.num_image_token = hf_runner.model.num_image_token + self.tokenizer = hf_runner.tokenizer + self.dtype = hf_runner.model.dtype + + self.config = AutoConfig.from_pretrained(hf_runner.model_name, + trust_remote_code=True) + self.vision_config = self.config.vision_config + self.use_thumbnail = self.config.use_thumbnail + self.min_num = self.config.min_dynamic_patch + self.max_num = self.config.max_dynamic_patch + self.image_size = self.vision_config.image_size + + def __call__(self, text: str, images: Union[Image, List[Image]], + **kwargs): + # yapf: disable + from vllm.model_executor.models.h2ovl import ( + IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values) + + # yapf: enable + images = [images] if isinstance(images, Image) else images + pixel_values = [ + image_to_pixel_values(image, + self.image_size, + self.min_num, + self.max_num, + self.use_thumbnail, + use_MSAC=self.config.use_msac).to( + self.dtype) for image in images + ] + num_patches_list = [ + pixel_value.shape[0] for pixel_value in pixel_values + ] + pixel_values = torch.cat(pixel_values, dim=0) + for num_patches in num_patches_list: + context_tokens = IMG_CONTEXT * self.num_image_token \ + * num_patches + image_tokens = IMG_START + context_tokens + 
IMG_END + text = text.replace('', image_tokens, 1) + prompt = self.tokenizer(text, return_tensors="pt") + prompt.update({"pixel_values": pixel_values}) + return prompt + + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( + "") + hf_model.model.img_context_token_id = img_context_token_id + hf_model.processor = H2OVLProcessor(hf_model) + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.language_model.get_output_embeddings() + hf_model.model.generate = types.MethodType(_internvl_generate, + hf_model.model) + return hf_model + + def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for InternVL.""" diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index 483773f069133..d686f1da3fa17 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -85,7 +85,7 @@ def run_test( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 5044740c3e734..4d3bbd805c152 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -5,8 +5,8 @@ import pytest import torch -from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs -from vllm.inputs.registry import InputRegistry +from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext, + InputRegistry, token_inputs) from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -56,7 +56,7 @@ def custom_dummy_data_factory(self, num_crops=DEFAULT_NUM_CROPS): seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) - return seq_data, None + return DummyData(seq_data, None) with patch( "vllm.inputs.registry.InputRegistry._default_dummy_data_factory", @@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the mm_processor_kwargs. - seq_data, _ = dummy_registry.dummy_data_for_profiling( + dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(seq_data.prompt_token_ids) == expected_seq_count + assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count @pytest.mark.parametrize( @@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the mm_processor_kwargs. 
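The hunks around here replace the old `(seq_data, mm_data)` tuple returned by the dummy-data factory with a `DummyData` object whose fields are accessed by name. The sketch below shows what that interface change looks like to a caller; the `DummyData` definition here is a stand-in NamedTuple, and its field names are assumptions rather than vLLM's actual class.

from typing import Any, NamedTuple, Optional

class DummyData(NamedTuple):
    # Stand-in for vllm.inputs.DummyData; field names are assumptions.
    seq_data: Any
    multi_modal_data: Optional[dict] = None

def old_style(factory_result):
    seq_data, _ = factory_result          # tuple unpacking (before this patch)
    return len(seq_data)

def new_style(factory_result: DummyData):
    return len(factory_result.seq_data)   # attribute access (after this patch)

tokens = [0] * 4
assert old_style((tokens, None)) == new_style(DummyData(tokens)) == 4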
- seq_data, _ = dummy_registry.dummy_data_for_profiling( + dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS ### Test overrides for the max token count per multimodal instance diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 38cd48629f903..9869c8123f001 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -1,11 +1,12 @@ import base64 import mimetypes -from tempfile import NamedTemporaryFile +import os +from tempfile import NamedTemporaryFile, TemporaryDirectory from typing import Dict, Tuple import numpy as np import pytest -from PIL import Image +from PIL import Image, ImageChops from transformers import AutoConfig, AutoTokenizer from vllm.multimodal.utils import (async_fetch_image, fetch_image, @@ -84,6 +85,40 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], assert _image_equals(data_image_sync, data_image_async) +@pytest.mark.asyncio +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_fetch_image_local_files(image_url: str): + with TemporaryDirectory() as temp_dir: + origin_image = fetch_image(image_url) + origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)), + quality=100, + icc_profile=origin_image.info.get('icc_profile')) + + image_async = await async_fetch_image( + f"file://{temp_dir}/{os.path.basename(image_url)}", + allowed_local_media_path=temp_dir) + + image_sync = fetch_image( + f"file://{temp_dir}/{os.path.basename(image_url)}", + allowed_local_media_path=temp_dir) + # Check that the images are equal + assert not ImageChops.difference(image_sync, image_async).getbbox() + + with pytest.raises(ValueError): + await async_fetch_image( + f"file://{temp_dir}/../{os.path.basename(image_url)}", + allowed_local_media_path=temp_dir) + with pytest.raises(ValueError): + await async_fetch_image( + f"file://{temp_dir}/../{os.path.basename(image_url)}") + + with pytest.raises(ValueError): + fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}", + allowed_local_media_path=temp_dir) + with pytest.raises(ValueError): + fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}") + + @pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) def test_repeat_and_pad_placeholder_tokens(model): config = AutoConfig.from_pretrained(model) @@ -92,18 +127,50 @@ def test_repeat_and_pad_placeholder_tokens(model): tokenizer = AutoTokenizer.from_pretrained(model) test_cases = [ - ("", 2, "", [32000, 32000]), - ("", 2, "", [32000, 32000, 32000]), - ("", [3, 2], "", - [32000, 32000, 32000, 32000, 32000]), - ("Image:Image:!", [3, 2], - "Image:Image:!", - [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), - ("", [3, 2], "", [32000, 32000, 32000]), - ] - - for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + ( + "", + 2, + "", + [32000, 32000], + [{ "offset": 0, "length": 2 }], + ), + ( + "", + 2, + "", + [32000, 32000, 32000], + [{ "offset": 0, "length": 2 }]), + ( + "", + [3, 2], + "", + [32000, 32000, 32000, 32000, 32000], + [{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }], + ), + ( + "Image:Image:!", + [3, 2], + "Image:Image:!", + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [{ "offset": 2, "length": 3 }, { "offset": 7, 
"length": 2 }], + ), + ( + "", + [3, 2], + "", + [32000, 32000, 32000], + [{ "offset": 0, "length": 3 }], + ), + ] # yapf: disable + + for ( + prompt, + repeat_count, + expected_prompt, + expected_token_ids, + expected_ranges, + ) in test_cases: + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer=tokenizer, prompt=prompt, prompt_token_ids=tokenizer.encode(prompt, @@ -113,3 +180,4 @@ def test_repeat_and_pad_placeholder_tokens(model): ) assert new_prompt == expected_prompt assert new_token_ids == expected_token_ids + assert ranges == expected_ranges diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f683942a5854b..6cf0cfb09b8fa 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -81,12 +81,7 @@ def create_worker(cls: Callable[..., T], get_ip(), get_open_port()) worker = cls( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3576a4834ebc3..e8f8499aa88ca 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -6,6 +6,7 @@ import pytest +from vllm.inputs import token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import Sequence from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) seq = Sequence(seq_id, - inputs={ - "prompt": prompt, - "prompt_token_ids": prompt_token_ids, - }, + inputs=token_inputs(prompt_token_ids, + prompt=prompt), block_size=block_size, eos_token_id=tokenizer.tokenizer.eos_token_id, lora_request=lora_request) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index f4551ed42efb8..a3e70a40db979 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -1,17 +1,24 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Generator, List, Optional import pytest from transformers import AutoTokenizer +from vllm.inputs import token_inputs from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup from vllm.transformers_utils.detokenizer import (Detokenizer, detokenize_incrementally) from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer TRUTH = [ "Hello here, this is a simple test", "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa - "我很感谢你的热情" + "我很感谢你的热情", + # Burmese text triggers an edge-case for Mistral's V3-Tekken tokenizer (eg. 
+ # for mistralai/Pixtral-12B-2409) where tokens may map to bytes with + # incomplete UTF-8 characters + # see https://github.com/vllm-project/vllm/pull/9625 + "ပုံပြင်လေးပြောပြပါ်", ] TOKENIZERS = [ "facebook/opt-125m", @@ -24,6 +31,7 @@ "tiiuae/falcon-7b", "meta-llama/Llama-2-7b-hf", "codellama/CodeLlama-7b-hf", + "mistralai/Pixtral-12B-2409", ] @@ -49,15 +57,55 @@ def _run_incremental_decode(tokenizer, all_input_ids, return decoded_text +@pytest.fixture +def tokenizer(tokenizer_name): + return (MistralTokenizer.from_pretrained(tokenizer_name) + if "mistral" in tokenizer_name else + AutoTokenizer.from_pretrained(tokenizer_name)) + + +@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"]) +@pytest.mark.parametrize( + "truth", + [ + # Burmese text triggers an edge-case where tokens may map to bytes with + # incomplete UTF-8 characters + "ပုံပြင်လေးပြောပြပါ", + # Using "URGENCY" since "CY" has token id 130282 + "URGENCY🌶️", + ]) +def test_mistral_edge_case(tokenizer, truth): + """Test for a specific edge cases with V3-Tekken MistralTokenizer. + + See https://github.com/vllm-project/vllm/pull/9625 + """ + starting_index = 0 + all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids + + decoded_text = _run_incremental_decode(tokenizer, + all_input_ids, + skip_special_tokens=True, + starting_index=starting_index) + assert decoded_text == truth + + +@pytest.fixture +def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: + if "mistral" in tokenizer_name: + yield ( + bool(True) if request.param else + pytest.skip("mistral doesn't support skip_special_tokens=False")) + else: + yield bool(True) if request.param else bool(False) + + @pytest.mark.parametrize("truth", TRUTH) @pytest.mark.parametrize("with_prompt", [True, False]) -@pytest.mark.parametrize("tokenizer_id", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", (True, False)) -def test_decode_streaming(tokenizer_id, truth, with_prompt, - skip_special_tokens): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) +@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) +@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True) +def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens): if with_prompt: - truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"] + truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids prompt_input_ids = truth_tokens[:len(truth) // 2] generated_input_ids = truth_tokens[len(truth) // 2:] all_input_ids = prompt_input_ids + generated_input_ids @@ -68,7 +116,7 @@ def test_decode_streaming(tokenizer_id, truth, with_prompt, else: generated = truth starting_index = 0 - all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"] + all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids if skip_special_tokens: if tokenizer.bos_token_id is not None: all_input_ids = [tokenizer.bos_token_id] + all_input_ids @@ -98,7 +146,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: enable_lora=False, max_num_seqs=100, max_input_length=None, - tokenizer_mode="auto", + tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", trust_remote_code=False, revision=None, ) @@ -113,9 +161,8 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: @pytest.fixture(name="complete_sequence_token_ids") def create_complete_sequence_token_ids(complete_sequence: str, - tokenizer_name: str) -> List[int]: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - 
complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"] + tokenizer) -> List[int]: + complete_sequence_token_ids = tokenizer(complete_sequence).input_ids return complete_sequence_token_ids @@ -123,10 +170,7 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - }, + inputs=token_inputs(prompt_token_ids, prompt=""), block_size=16, ) @@ -150,7 +194,7 @@ def create_dummy_prompt_logprobs( @pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", [True, False]) +@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) def test_decode_sequence_logprobs(complete_sequence: str, complete_sequence_token_ids: List[int], detokenizer: Detokenizer, @@ -208,9 +252,9 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], # decoded_prompt_logprobs doesn't contain the first token. token_ids = complete_sequence_token_ids - tokenzier = detokenizer.get_tokenizer_for_seq(seq) - text_full = tokenzier.decode(token_ids, skip_special_tokens=True) - text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True) + tokenizer = detokenizer.get_tokenizer_for_seq(seq) + text_full = tokenizer.decode(token_ids, skip_special_tokens=True) + text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True) text = text_full[len(text_first):] # Text for logprobs for the chosen token should be the same as the diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index e75884a7395e2..9e166ae64dbfb 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -19,14 +19,7 @@ def _create_model_runner(model: str, *args, engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = EncoderDecoderModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, - lora_config=engine_config.lora_config, - prompt_adapter_config=engine_config.prompt_adapter_config, + vllm_config=engine_config, is_driver_worker=True, ) return model_runner diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 1e7f560fc68cc..b36e8bfe73ff3 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -73,6 +73,7 @@ def test_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_index_maps=None, ) model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), @@ -124,6 +125,7 @@ def test_embedding_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_index_maps=None, ) model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), @@ -174,6 +176,7 @@ def test_multi_step_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_index_maps=None, ) frozen_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), diff --git a/tests/worker/test_model_runner.py 
b/tests/worker/test_model_runner.py index fe97199bac62d..433a9b30ba57a 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -16,15 +16,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = ModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, - lora_config=engine_config.lora_config, - prompt_adapter_config=engine_config.prompt_adapter_config, - observability_config=engine_config.observability_config, + vllm_config=engine_config, is_driver_worker=True, ) return model_runner diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py index acd2ed6836365..194ea2aa506f4 100644 --- a/tests/worker/test_profile.py +++ b/tests/worker/test_profile.py @@ -24,12 +24,7 @@ def test_gpu_memory_profiling(): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = Worker( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 7aa439ba0a154..acede959f59f8 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -19,12 +19,7 @@ def test_swap() -> None: distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = Worker( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 9ea89eca01f5b..a504cb1f7e318 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -7,6 +7,8 @@ import torch +from vllm.multimodal import MultiModalPlaceholderMap + if TYPE_CHECKING: from vllm.worker.model_runner_base import (ModelRunnerBase, ModelRunnerInputBase, @@ -108,6 +110,15 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor + # The index maps that relate multi-modal embeddings to the corresponding + # placeholders. + # + # N.B. These aren't really related to attention and don't belong on this + # type -- this is just a temporary solution to make them available to + # `model_executable`. 
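As the comment above notes, these index maps only relate multi-modal embedding rows to their placeholder positions in the token sequence. The sketch below illustrates the idea with a deliberately simplified stand-in structure; the real `MultiModalPlaceholderMap.IndexMap` layout in vLLM may differ, so treat the field names and the scatter helper as assumptions for illustration only.

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class IndexMap:
    # Hypothetical simplification: source rows in the stacked multi-modal
    # embedding tensor and the destination token positions they fill.
    src: List[int]
    dest: List[int]

# One map per modality, e.g. an image whose 3 embedding rows land at
# token positions 5..7 of the flattened input.
placeholder_index_maps: Dict[str, IndexMap] = {
    "image": IndexMap(src=[0, 1, 2], dest=[5, 6, 7]),
}

def scatter_embeddings(inputs_embeds, mm_embeds, index_map: IndexMap):
    # Copy multi-modal embedding rows into the placeholder slots.
    inputs_embeds[index_map.dest] = mm_embeds[index_map.src]
    return inputs_embeds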
+ multi_modal_placeholder_index_maps: Optional[Dict[ + str, MultiModalPlaceholderMap.IndexMap]] + @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index c216d195c9e7e..409a42187f46c 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -215,6 +215,8 @@ def prefill_metadata( num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -243,6 +245,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c294fcf7f08fe..26da0d89def29 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,5 +1,7 @@ """Attention layer with FlashAttention.""" +from collections import defaultdict from dataclasses import dataclass +from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -9,11 +11,13 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, - compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) from vllm.forward_context import get_forward_context +from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import (async_tensor_h2d, direct_register_custom_op, make_tensor_with_pad) @@ -71,7 +75,6 @@ def swap_blocks( src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @@ -83,6 +86,7 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) @@ -109,26 +113,12 @@ class FlashAttentionMetadata(AttentionMetadata): # |-------------------- seq_len ---------------------| # |-- query_len ---| - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. 
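The `*_start_loc` convention documented above (lengths [4, 6] become start locations [0, 4, 10]) is an exclusive prefix sum over the sequence lengths. A small, self-contained illustration of how such a tensor can be built, in the same spirit as the cumsum-based construction used elsewhere in this patch:

import torch

seq_lens = torch.tensor([4, 6], dtype=torch.int32)

# Allocate batch_size + 1 slots and write the cumulative sum after index 0.
seq_start_loc = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=seq_start_loc.dtype,
             out=seq_start_loc[1:])

print(seq_start_loc.tolist())  # [0, 4, 10]

# Token rows for sequence i in a packed [num_tokens, ...] tensor are then
# packed[seq_start_loc[i]:seq_start_loc[i + 1]].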
E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] @@ -144,11 +134,62 @@ class FlashAttentionMetadata(AttentionMetadata): # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. 
+ ''' + return is_all_cross_attn_metadata_set(self) + @property def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_prefills == 0: @@ -157,30 +198,52 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - assert self.seq_start_loc is not None + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) self._cached_prefill_metadata = FlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_query_len=0, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=False, - ) + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -190,16 +253,25 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: if self._cached_decode_metadata is not None: return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) self._cached_decode_metadata = FlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=self.max_query_len, max_prefill_seq_len=0, @@ -209,9 +281,15 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: seq_start_loc=self.seq_start_loc[self.num_prefills:] if self.seq_start_loc is not None else None, context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], + block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, - ) + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata def advance_step(self, @@ -297,6 +375,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -327,6 +408,12 @@ def _add_seq_group( self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -417,6 +504,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) num_seqs = len(seq_lens) if use_captured_graph: @@ -439,24 +528,18 @@ def build(self, seq_lens: List[int], query_lens: List[int], device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) - query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, - self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } return FlashAttentionMetadata( num_prefills=self.num_prefills, @@ -464,13 +547,14 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, @@ -566,16 +650,20 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - 
raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + output = torch.ops.vllm.unified_flash_attention( query, key, @@ -588,6 +676,7 @@ def forward( k_scale, v_scale, self.scale, + attn_type.value, self.sliding_window, self.alibi_slopes, self.logits_soft_cap, @@ -596,6 +685,89 @@ def forward( return output +def _get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether the input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. + + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four values: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: + # This is cross-attention, where the key comes from the + # precomputed encoder output and the query comes from the + # decoder input sequence. + # Choose query max length based on whether it is prompt + # or not. + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + # For encoder attention, both the query and the key are the same, + # i.e. the encoder sequence. + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_causal_option(attn_type: AttentionType) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms.
+ + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) + + def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -608,60 +780,76 @@ def unified_flash_attention( k_scale: float, v_scale: float, softmax_scale: float, + attn_type_int_val: int, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: + # Convert integer attn_type to enum + try: + attn_type = AttentionType(attn_type_int_val) + except ValueError as err: + raise AttributeError( + f"Invalid attention type {str(attn_type_int_val)}") from err + current_metadata = get_forward_context() assert current_metadata is not None assert isinstance(current_metadata, FlashAttentionMetadata) attn_metadata: FlashAttentionMetadata = current_metadata num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) + if (key is not None) and (value is not None): + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the KV + # cache is already populated with the cross-attention tensor. + # Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + k_scale, + v_scale, + ) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. 
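
Because the unified attention function is registered as a custom op, the attention type crosses that boundary as a plain integer (`attn_type.value`) and is rebuilt into the enum inside, as shown above. A minimal sketch of that round-trip, assuming an enum-style `AttentionType` with the members used in this patch (the concrete values here are an assumption; the real enum lives in vllm.attention.backends.abstract):

    from enum import Enum, auto

    class AttentionType(Enum):
        # Member set mirrors the patch; values are illustrative.
        DECODER = auto()
        ENCODER = auto()
        ENCODER_ONLY = auto()
        ENCODER_DECODER = auto()

    def encode_attn_type(attn_type: AttentionType) -> int:
        # Custom ops accept plain scalars, not Python enums, so pass the value.
        return attn_type.value

    def decode_attn_type(attn_type_int_val: int) -> AttentionType:
        try:
            return AttentionType(attn_type_int_val)
        except ValueError as err:
            raise AttributeError(
                f"Invalid attention type {attn_type_int_val}") from err

    assert decode_attn_type(
        encode_attn_type(AttentionType.ENCODER_DECODER)) is AttentionType.ENCODER_DECODER
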
- torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + query = query[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None @@ -669,22 +857,30 @@ def unified_flash_attention( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + prefill_output = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, softmax_scale=softmax_scale, - causal=True, + causal=_get_causal_option(attn_type), window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, ) else: # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) prefill_output = flash_attn_varlen_func( # noqa @@ -709,6 +905,8 @@ def unified_flash_attention( # because different queries might have different lengths. assert decode_meta.max_decode_query_len is not None if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1") decode_output = flash_attn_varlen_func( q=decode_query, k=key_cache, @@ -726,12 +924,17 @@ def unified_flash_attention( ) else: # Use flash_attn_with_kvcache for normal decoding. 
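
The prefill/decode split above works on flattened token counts rather than batch entries, and with cross-attention the key/value count can differ from the query count. A self-contained sketch of that slicing with assumed numbers standing in for the values returned by `get_num_prefill_decode_query_kv_tokens`:

    import torch

    # Assumed counts: 10 prefill query tokens, 7 encoder (key/value) tokens,
    # 2 decode query tokens.
    num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens = 10, 7, 2

    num_heads, head_size = 8, 64
    query = torch.randn(num_prefill_query_tokens + num_decode_query_tokens,
                        num_heads, head_size)
    key = torch.randn(num_prefill_kv_tokens, num_heads, head_size)
    value = torch.randn(num_prefill_kv_tokens, num_heads, head_size)

    decode_query = query[num_prefill_query_tokens:]   # attends against the KV cache
    prefill_query = query[:num_prefill_query_tokens]  # runs varlen prefill attention
    prefill_key = key[:num_prefill_kv_tokens]
    prefill_value = value[:num_prefill_kv_tokens]

    assert prefill_query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens
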
+ ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, softmax_scale=softmax_scale, causal=True, window_size=window_size, @@ -741,10 +944,10 @@ def unified_flash_attention( if prefill_output is None: assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) + return decode_output.view(num_decode_query_tokens, hidden_size) if decode_output is None: assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) + return prefill_output.view(num_prefill_query_tokens, hidden_size) # Chunked prefill does not work with speculative decoding. # Therefore, the query length for decode should be 1 in chunked prefill. @@ -766,6 +969,7 @@ def unified_flash_attention_fake( k_scale: float, v_scale: float, softmax_scale: float, + attn_type_int_val: int, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 234c87d5c4edb..107e3bbf79666 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,7 +1,10 @@ +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type +from vllm.multimodal import MultiModalPlaceholderMap + try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper @@ -215,6 +218,7 @@ def graph_capture_get_metadata_for_batch( attn_metadata = self.runner.attn_backend.make_metadata( num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, @@ -470,6 +474,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -519,6 +526,11 @@ def _add_seq_group( inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -651,6 +663,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, @@ -694,6 +711,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], decode_query_len=decode_query_len, num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, + 
multi_modal_placeholder_index_maps=placeholder_index_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, @@ -759,8 +777,6 @@ def forward( v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: - assert k_scale == 1.0 and v_scale == 1.0, ( - "key/v_scale is not supported in FlashInfer.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -874,7 +890,12 @@ def unified_flash_infer( assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None prefill_output = prefill_meta.prefill_wrapper.forward( - query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True) + query, + kv_cache, + logits_soft_cap=logits_soft_cap, + causal=True, + k_scale=k_scale, + v_scale=v_scale) if decode_meta := attn_metadata.decode_metadata: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 4116fbf00020c..888adbffb8578 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,5 +1,6 @@ +from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -7,6 +8,7 @@ AttentionMetadata, AttentionMetadataBuilder) from vllm.attention.backends.utils import CommonAttentionState +from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -135,6 +137,8 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_decode_query_len=0, @@ -167,6 +171,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, @@ -189,6 +194,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -213,6 +221,12 @@ def _add_seq_group( self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -280,6 +294,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, @@ -296,6 +315,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 30859dfa60634..b129d0d992f2f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -150,6 +150,8 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -178,6 +180,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 32fccd0dfb496..12800668af223 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,12 +1,16 @@ """Attention backend utils""" +from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union import numpy as np import torch from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) +from vllm.attention.backends.abstract import AttentionType +from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -123,6 +127,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -147,6 +154,12 @@ def _add_seq_group( inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -204,6 +217,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) @@ -232,28 +247,23 @@ def build(self, seq_lens: List[int], query_lens: List[int], device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) - query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, - self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, 
torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -261,8 +271,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_query_len=max_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, @@ -305,6 +315,7 @@ def graph_capture_get_metadata_for_batch( num_prefill_tokens=0, num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, @@ -318,11 +329,13 @@ def graph_capture_get_metadata_for_batch( use_cuda_graph=True, ) if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - assert self.runner.attn_backend.get_name() == "XFORMERS", \ - f"Expected attn_backend name to be 'XFORMERS', but "\ - f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -338,11 +351,13 @@ def get_graph_input_buffers( "block_tables": attn_metadata.decode_metadata.block_tables, } if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - assert self.runner.attn_backend.get_name() == "XFORMERS", \ - f"Expected attn_backend name to be 'XFORMERS', but "\ - f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._add_additonal_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers @@ -357,11 +372,13 @@ def prepare_graph_input_buffers( input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - assert self.runner.attn_backend.get_name() == "XFORMERS", \ - f"Expected attn_backend name to be 'XFORMERS', but "\ - f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. 
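
The common builder above now computes the cumulative start offsets on the host with `itertools.accumulate` and uploads them once via `async_tensor_h2d`, replacing the per-build `torch.cumsum` kernels. A quick sketch with example lengths:

    from itertools import accumulate

    seq_lens = [4, 6, 3]      # example per-sequence lengths in a batch
    query_lens = [4, 1, 1]    # example per-sequence query lengths

    seq_start_loc = list(accumulate(seq_lens, initial=0))
    query_start_loc = list(accumulate(query_lens, initial=0))

    print(seq_start_loc)      # [0, 4, 10, 13] -- length batch_size + 1
    print(query_start_loc)    # [0, 4, 5, 6]

The resulting lists are then shipped to the device as int32 tensors, matching the cumulative-length (`cu_seqlens`) layout the varlen attention kernels index into.
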
+ assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._prepare_input_buffers_for_enc_dec_model( attn_metadata, input_buffers) @@ -393,6 +410,7 @@ def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, attn_metadata.encoder_seq_lens_tensor = torch.full( (batch_size, ), 1, dtype=torch.int).cuda() attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.num_encoder_tokens = 0 def _add_additonal_input_buffers_for_enc_dec_model( self, attn_metadata, input_buffers: Dict[str, Any]): @@ -435,3 +453,122 @@ def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, input_buffers["cross_block_tables"].copy_( attn_metadata.decode_metadata.cross_block_tables, non_blocking=True) + + +def is_all_encoder_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((attn_metadata.encoder_seq_lens is not None) + and (attn_metadata.encoder_seq_lens_tensor is not None) + and (attn_metadata.max_encoder_seq_len is not None)) + + +def is_all_cross_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (attn_metadata.is_all_encoder_attn_metadata_set + and (attn_metadata.cross_slot_mapping is not None) + and (attn_metadata.cross_block_tables is not None)) + + +def get_seq_len_block_table_args( + attn_metadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def get_num_prefill_decode_query_kv_tokens( + attn_metadata, + attn_type: AttentionType, +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill and decode tokens for query, key/value + based on the attention metadata and the specified attention type. + + Args: + attn_metadata (FlashAttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill query tokens. + - The number of prefill key/value tokens. + - The number of decode query tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ + num_prefill_query_tokens = 0 + num_decode_query_tokens = 0 + num_prefill_kv_tokens = 0 + if attn_type == AttentionType.ENCODER: + # Encoder attention is only invoked during the prefill phase. + # The same input serves as both query and key. + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_encoder_tokens + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + # The key is the encoder/cross-attention.
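
For the encoder/decoder branch handled just below, a worked example with illustrative counts:

    # Illustrative counts for a cross-attention batch.
    num_prefill_tokens, num_decode_tokens, num_encoder_tokens = 10, 2, 7

    # Queries come from the decoder; keys/values come from the encoder output.
    num_prefill_query_tokens = num_prefill_tokens   # 10
    num_prefill_kv_tokens = num_encoder_tokens      # 7
    num_decode_query_tokens = num_decode_tokens     # 2
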
+ num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + num_prefill_kv_tokens = attn_metadata.num_prefill_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5aaf13d8ea744..4725413baade7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,8 +11,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) +from vllm.attention.backends.utils import ( + CommonAttentionState, CommonMetadataBuilder, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None @@ -162,9 +169,7 @@ def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + return is_all_encoder_attn_metadata_set(self) @property def is_all_cross_attn_metadata_set(self): @@ -173,9 +178,7 @@ def is_all_cross_attn_metadata_set(self): Superset of encoder attention required metadata. ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + return is_all_cross_attn_metadata_set(self) @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -212,6 +215,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -255,6 +260,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -326,64 +332,6 @@ def _set_attn_bias( raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_seq_len_block_table_args( - attn_metadata: XFormersMetadata, - is_prompt: bool, - attn_type: AttentionType, -) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. - - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_lens_tensor, max_seq_len, - attn_metadata.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.cross_block_tables) - elif attn_type == AttentionType.ENCODER: - # No block tables associated with encoder attention - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, None) - elif attn_type == AttentionType.ENCODER_ONLY: - assert is_prompt, "Should not have decode for encoder only model." - - # No block tables associated with encoder attention - return (attn_metadata.seq_lens_tensor, - attn_metadata.max_prefill_seq_len, None) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): _metadata_cls = XFormersMetadata @@ -571,45 +519,21 @@ def forward( updated_slot_mapping, self.kv_cache_dtype, k_scale, v_scale) - - if attn_type == AttentionType.ENCODER: - # Encoder attention - chunked prefill is not applicable; - # derive token-count from query shape & and treat them - # as 100% prefill tokens - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 - elif attn_type == AttentionType.DECODER: - # Decoder self-attention supports chunked prefill. 
- num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - # Only enforce this shape-constraint for decoder - # self-attention - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - else: # attn_type == AttentionType.ENCODER_DECODER - # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) - num_prefill_tokens = attn_metadata.num_prefill_tokens - if attn_metadata.num_encoder_tokens is not None: - num_encoder_tokens = attn_metadata.num_encoder_tokens - else: - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] + decode_query = query[num_prefill_query_tokens:] # QKV for prefill. - query = query[:num_prefill_tokens] + query = query[:num_prefill_query_tokens] if key is not None and value is not None: - key = key[:num_encoder_tokens] - value = value[:num_encoder_tokens] + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. @@ -619,8 +543,8 @@ def forward( # prefix. out = self._run_memory_efficient_xformers_forward( query, key, value, prefill_meta, attn_type=attn_type) - assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out + assert out.shape == output[:num_prefill_query_tokens].shape + output[:num_prefill_query_tokens] = out else: assert attn_type != AttentionType.ENCODER_ONLY, ( "Encoder-only models should not have prefix attention.") @@ -649,8 +573,8 @@ def forward( k_scale, v_scale, ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + assert output[:num_prefill_query_tokens].shape == out.shape + output[:num_prefill_query_tokens] = out if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( @@ -660,9 +584,9 @@ def forward( seq_lens_arg, max_seq_len_arg, block_tables_arg, - ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - output[num_prefill_tokens:] = PagedAttention.forward_decode( + output[num_prefill_query_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, value_cache, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 376b3136f0fb8..8a59cf41a689e 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -98,7 +98,6 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( @@ -108,6 +107,7 @@ def get_attn_backend( backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, is_attention_free) if backend == _Backend.FLASH_ATTN: + 
logger.info("Using Flash Attention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 10cf49e19eccc..96ddcba467c5b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -243,6 +243,65 @@ def split_graph(graph: fx.GraphModule, return split_gm, outputs +# we share the global graph pool among all the backends +global_graph_pool = None + + +class PiecewiseCompileInterpreter(torch.fx.Interpreter): + """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. + It runs the given graph with fake inputs, and compile some + submodules specified by `compile_submod_names` with the given + compilation configs. + """ + + def __init__(self, module: torch.fx.GraphModule, + compile_submod_names: List[str], + compilation_configs: CompilationConfig, graph_pool): + super().__init__(module) + from torch._guards import detect_fake_mode + self.fake_mode = detect_fake_mode() + self.compile_submod_names = compile_submod_names + self.compilation_configs = compilation_configs + self.graph_pool = graph_pool + self.have_seen_first_graph = False + + def run(self, *args): + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + return super().run(*fake_args) + + def call_module(self, target: torch.fx.node.Target, + args: Tuple[torch.fx.node.Argument, + ...], kwargs: Dict[str, Any]) -> Any: + assert isinstance(target, str) + output = super().call_module(target, args, kwargs) + + if target in self.compile_submod_names: + submod = self.fetch_attr(target) + sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + compiled_graph_for_general_shape = wrap_inductor( + submod, + args, + self.compilation_configs.inductor_compile_config, + runtime_shape=None, + do_logging=not self.have_seen_first_graph, + use_inductor=self.compilation_configs.use_inductor) + + self.module.__dict__[target] = PiecewiseBackend( + submod, self.compilation_configs, self.graph_pool, + not self.have_seen_first_graph, sym_shape_indices, + compiled_graph_for_general_shape) + + self.have_seen_first_graph = True + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + + return output + + class VllmBackend: """The compilation backend for `torch.compile` with VLLM. It is used for compilation level of `CompilationLevel.PIECEWISE`, @@ -263,8 +322,14 @@ class VllmBackend: returned_callable: Callable def __init__(self, ): - # every instance of VllmBackend has its own graph pool - self.graph_pool = torch.cuda.graph_pool_handle() + global global_graph_pool + if global_graph_pool is None: + global_graph_pool = torch.cuda.graph_pool_handle() + + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. 
+ # only investigate this when we use multiple streams + self.graph_pool = global_graph_pool # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -286,55 +351,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.non_cudagraph_ops) - returned_callable: Callable # type: ignore + from torch._dynamo.utils import lazy_format_graph_code + logger.debug("%s", + lazy_format_graph_code("stiching module", self.split_gm)) - if len(self.piecewise_graphs) == 0: - compilation_counter.num_piecewise_graphs_seen += 1 - compilation_counter.num_piecewise_capturable_graphs_seen += 1 - returned_callable = PiecewiseBackend(graph, - self.compilation_configs, - self.graph_pool, - is_first_graph=True) - else: - from torch._dynamo.utils import lazy_format_graph_code - logger.debug( - "%s", lazy_format_graph_code("stiching module", self.split_gm)) - - is_first_graph = True - - for item in self.piecewise_graphs: - compilation_counter.num_piecewise_graphs_seen += 1 - compilation_counter.num_piecewise_capturable_graphs_seen += not item.is_splitting_graph # noqa - if not item.is_splitting_graph: - # cannot setattr to a module, so we need to set - # the attribute in the __dict__ - self.split_gm.__dict__[ - item.submod_name] = PiecewiseBackend( - item.graph, self.compilation_configs, - self.graph_pool, is_first_graph) - is_first_graph = False - returned_callable = self.split_gm - - self.returned_callable = returned_callable - # trigger the first compilation - # code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206 # noqa - # to turn the inputs into fake tensors - import torch._guards - from torch._guards import detect_fake_mode - fake_mode = detect_fake_mode(example_inputs) - fake_args = [] - for arg in example_inputs: - if isinstance(arg, torch.Tensor) and not isinstance( - arg, torch._subclasses.FakeTensor): - fake_args.append( - torch._dynamo.utils.to_fake_tensor(arg, fake_mode)) - else: - fake_args.append(arg) - self.returned_callable(*fake_args) + compilation_counter.num_piecewise_graphs_seen += len( + self.piecewise_graphs) + submod_names_to_compile = [ + item.submod_name for item in self.piecewise_graphs + if not item.is_splitting_graph + ] + + # propagate the split graph to the piecewise backend, + # compile submodules with symbolic shapes + PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, + self.compilation_configs, + self.graph_pool).run(*example_inputs) self._called = True - return self.returned_callable + return self.split_gm @dataclasses.dataclass @@ -352,11 +388,10 @@ class ConcreteSizeEntry: class PiecewiseBackend: - def __init__(self, - graph: fx.GraphModule, - compilation_configs: CompilationConfig, - graph_pool: Any, - is_first_graph: bool = False): + def __init__(self, graph: fx.GraphModule, + compilation_configs: CompilationConfig, graph_pool: Any, + is_first_graph: bool, sym_shape_indices: List[int], + compiled_graph_for_general_shape: Callable): """ The backend for piecewise compilation. It mainly handles the compilation and cudagraph capturing. 
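
The rewritten `__call__` above hands compilation off to the interpreter instead of eagerly compiling each piece. A trimmed sketch of the underlying `torch.fx.Interpreter` pattern: run the split graph once with example inputs and, whenever a selected submodule is called, swap in a compiled version. Here plain `torch.compile` stands in for the inductor wrapping and fake-tensor handling, and `CompileOnCall` is a made-up name:

    import torch
    import torch.fx as fx

    class CompileOnCall(fx.Interpreter):
        """Walk the split graph once and replace selected submodules with
        compiled callables as they are encountered."""

        def __init__(self, gm: fx.GraphModule, submod_names_to_compile):
            super().__init__(gm)
            self.to_compile = set(submod_names_to_compile)

        def call_module(self, target, args, kwargs):
            output = super().call_module(target, args, kwargs)
            if target in self.to_compile:
                submod = self.fetch_attr(target)
                # A compiled callable is not an nn.Module, so it is stashed
                # directly in __dict__ rather than registered via setattr.
                self.module.__dict__[target] = torch.compile(submod)
            return output
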
@@ -381,12 +416,11 @@ def __init__(self, self.compilation_configs.capture_sizes ) if self.compilation_configs.use_cudagraph else set() - self.compile_finished = False self.first_run_finished = False - self.compiled_graph_for_general_shape: Callable = None # type: ignore + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - self.sym_shape_indices: List[int] = [] + self.sym_shape_indices = sym_shape_indices # the entries for different shapes that we need to either # compile or capture cudagraph @@ -399,27 +433,6 @@ def __init__(self, ) def __call__(self, *args) -> Any: - - if not self.compile_finished: - self.compile_finished = True - - # this is the first compilation, we will compile a graph with - # dynamic shape, as the caller will mark first dimension as dynamic - - self.sym_shape_indices = [ - i for i, x in enumerate(args) if isinstance(x, torch.SymInt) - ] - - self.compiled_graph_for_general_shape = wrap_inductor( - self.graph, - args, - self.compilation_configs.inductor_compile_config, - runtime_shape=None, - do_logging=self.is_first_graph, - use_inductor=self.compilation_configs.use_inductor) - - return self.graph(*args) - if not self.first_run_finished: self.first_run_finished = True return self.compiled_graph_for_general_shape(*args) diff --git a/vllm/config.py b/vllm/config.py index 223981d630552..431bf55d6efb4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,6 @@ import enum import json -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) @@ -23,9 +23,13 @@ from ray.util.placement_group import PlacementGroup from vllm.executor.executor_base import ExecutorBase + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.model_loader.loader import BaseModelLoader from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) +else: + QuantizationConfig = None logger = init_logger(__name__) @@ -55,6 +59,10 @@ class ModelConfig: "mistral" will always use the tokenizer from `mistral_common`. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + allowed_local_media_path: Allowing API requests to read local images or + videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. dtype: Data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. @@ -134,6 +142,7 @@ def __init__( trust_remote_code: bool, dtype: Union[str, torch.dtype], seed: int, + allowed_local_media_path: str = "", revision: Optional[str] = None, code_revision: Optional[str] = None, rope_scaling: Optional[dict] = None, @@ -163,6 +172,7 @@ def __init__( self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode self.trust_remote_code = trust_remote_code + self.allowed_local_media_path = allowed_local_media_path self.seed = seed self.revision = revision self.code_revision = code_revision @@ -1317,6 +1327,8 @@ def maybe_create_spec_config( tokenizer=target_model_config.tokenizer, tokenizer_mode=target_model_config.tokenizer_mode, trust_remote_code=target_model_config.trust_remote_code, + allowed_local_media_path=target_model_config. 
+ allowed_local_media_path, dtype=target_model_config.dtype, seed=target_model_config.seed, revision=draft_revision, @@ -1939,9 +1951,9 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") -@dataclass(frozen=True) -class EngineConfig: - """Dataclass which contains all engine-related configuration. This +@dataclass +class VllmConfig: + """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. """ @@ -1951,11 +1963,40 @@ class EngineConfig: scheduler_config: SchedulerConfig device_config: DeviceConfig load_config: LoadConfig - lora_config: Optional[LoRAConfig] - speculative_config: Optional[SpeculativeConfig] - decoding_config: Optional[DecodingConfig] - observability_config: Optional[ObservabilityConfig] - prompt_adapter_config: Optional[PromptAdapterConfig] + lora_config: Optional[LoRAConfig] = None + speculative_config: Optional[SpeculativeConfig] = None + decoding_config: Optional[DecodingConfig] = None + observability_config: Optional[ObservabilityConfig] = None + prompt_adapter_config: Optional[PromptAdapterConfig] = None + quant_config: Optional[QuantizationConfig] = None + + @staticmethod + def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + from vllm.model_executor.model_loader.weight_utils import ( + get_quant_config) + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. Minimum " + f"capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1974,8 +2015,7 @@ def __post_init__(self): self.prompt_adapter_config.verify_with_model_config( self.model_config) - def to_dict(self): - """Return the configs as a dictionary, for use in **kwargs. - """ - return dict( - (field.name, getattr(self, field.name)) for field in fields(self)) + if self.quant_config is None and \ + self.model_config is not None and self.load_config is not None: + self.quant_config = VllmConfig._get_quantization_config( + self.model_config, self.load_config) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 88733b8f53b86..e56d5cddce424 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -828,8 +828,7 @@ def _schedule_priority_preemption( num_running_seqs) #Preempt out the victim sequence group - self._preempt(vseq_group, blocks_to_swap_out, - PreemptionMode.RECOMPUTE) + self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 #Put the sequence back into the waiting queue @@ -1309,6 +1308,8 @@ def schedule( # `multi_modal_data` will be None. 
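
`VllmConfig` drops the frozen/`to_dict()` style in favor of optional fields plus a `__post_init__` that derives the quantization config when the caller has not supplied one. A toy sketch of that pattern (the class and fields here are stand-ins, not the real sub-configs):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyVllmConfig:                     # toy stand-in only
        model_dtype: str
        quantization: Optional[str] = None   # optional pieces get defaults
        quant_config: Optional[dict] = None

        def __post_init__(self):
            # Derive the quantization config lazily unless the caller already
            # injected one (e.g. a test supplying a prebuilt config).
            if self.quant_config is None and self.quantization is not None:
                self.quant_config = {"method": self.quantization,
                                     "act_dtype": self.model_dtype}

    cfg = ToyVllmConfig(model_dtype="bfloat16", quantization="fp8")
    assert cfg.quant_config == {"method": "fp8", "act_dtype": "bfloat16"}
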
multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + multi_modal_placeholders=seq_group.multi_modal_placeholders + if scheduler_outputs.num_prefill_groups > 0 else None, mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) @@ -1451,12 +1452,8 @@ def _append_slots(self, if len(cows) > 0: blocks_to_copy.extend(cows) - def _preempt( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], - preemption_mode: Optional[PreemptionMode] = None, - ) -> PreemptionMode: + def _preempt(self, seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than # swapping. However, when the sequence group has multiple sequences diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 7d526b25ed193..2ff1a1ead99c1 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,3 +1,4 @@ +import os import pickle import time from contextlib import contextmanager @@ -18,12 +19,6 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL -# time to wait if the queue is full or empty -# if we sleep for too short, it will consume too much CPU -# if we sleep for too long, it will slow down the writer/reader -# 0.1 us is a good balance -RINGBUFFER_SLEEP_INTERVAL = 1e-7 - logger = init_logger(__name__) @@ -333,8 +328,8 @@ def acquire_write(self): # if this block is not ready to write, # we need to wait until it is read by all readers - # wait for a while - time.sleep(RINGBUFFER_SLEEP_INTERVAL) + # Release the processor to other threads + os.sched_yield() # if we wait for a long time, we should warn the user if (time.monotonic() - start_time > @@ -387,8 +382,8 @@ def acquire_read(self): # if this block is not ready, # we need to wait until it is written - # wait for a while - time.sleep(RINGBUFFER_SLEEP_INTERVAL) + # Release the processor to other threads + os.sched_yield() # if we wait for a long time, we should warn the user if (time.monotonic() - start_time > diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 16bd5a9056a67..609f651d8e8c7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,10 +9,11 @@ import vllm.envs as envs from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, - DeviceConfig, EngineConfig, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig, TaskOption, TokenizerPoolConfig) + DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TaskOption, TokenizerPoolConfig, + VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -90,6 +91,7 @@ class EngineArgs: skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' trust_remote_code: bool = False + allowed_local_media_path: str = "" download_dir: Optional[str] = None load_format: str = 'auto' config_format: ConfigFormat = ConfigFormat.AUTO @@ -259,6 +261,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> 
FlexibleArgumentParser: parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') + parser.add_argument( + '--allowed-local-media-path', + type=str, + help="Allowing API requests to read local images or videos" + "from directories specified by the server file system." + "This is a security risk." + "Should only be enabled in trusted environments") parser.add_argument('--download-dir', type=nullable_str, default=EngineArgs.download_dir, @@ -909,6 +918,7 @@ def create_model_config(self) -> ModelConfig: tokenizer=cast(str, self.tokenizer), tokenizer_mode=self.tokenizer_mode, trust_remote_code=self.trust_remote_code, + allowed_local_media_path=self.allowed_local_media_path, dtype=self.dtype, seed=self.seed, revision=self.revision, @@ -945,7 +955,7 @@ def create_load_config(self) -> LoadConfig: ignore_patterns=self.ignore_patterns, ) - def create_engine_config(self) -> EngineConfig: + def create_engine_config(self) -> VllmConfig: # gguf file needs a specific model loader and doesn't use hf_repo if check_gguf_file(self.model): self.quantization = self.load_format = "gguf" @@ -1157,7 +1167,7 @@ def create_engine_config(self) -> EngineConfig: or "all" in detailed_trace_modules, ) - return EngineConfig( + return VllmConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5198467a6ac40..b0fdc67776bbd 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,8 +7,8 @@ from weakref import ReferenceType import vllm.envs as envs -from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -604,7 +604,7 @@ def __del__(self): @classmethod def _get_executor_cls( - cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: + cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) if isinstance(distributed_executor_backend, type): @@ -663,7 +663,7 @@ def _get_executor_cls( def from_engine_args( cls, engine_args: AsyncEngineArgs, - engine_config: Optional[EngineConfig] = None, + engine_config: Optional[VllmConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -680,7 +680,7 @@ def from_engine_args( # Create the async LLM engine. 
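
`from_engine_args` now builds the engine from the single config object rather than splatting `engine_config.to_dict()` into keyword arguments, as the hunk just below shows. A toy illustration of the design choice, using hypothetical classes:

    from dataclasses import dataclass

    @dataclass
    class ToyConfig:          # hypothetical bundle of sub-configs
        model: str
        block_size: int = 16

    class ToyEngine:
        def __init__(self, vllm_config: ToyConfig, log_stats: bool = True):
            # Keep the bundle; unpack locally only where convenient.
            self.vllm_config = vllm_config
            self.model = vllm_config.model
            self.log_stats = log_stats

    # Splatting asdict(cfg) into kwargs breaks every constructor signature
    # whenever a new sub-config is added; passing the object keeps them stable.
    engine = ToyEngine(vllm_config=ToyConfig(model="facebook/opt-125m"),
                       log_stats=False)
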
engine = cls( - **engine_config.to_dict(), + vllm_config=engine_config, executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 51aaa77a75d72..bb8b98134c082 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -10,14 +10,12 @@ from typing import Set, Type, Union, cast, overload import torch -from typing_extensions import TypeIs, TypeVar +from typing_extensions import TypeVar import vllm.envs as envs -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -31,9 +29,9 @@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderInputs, InputRegistry, PromptType, - TokensPrompt) +from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, + PromptType) +from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.logits_process import get_bad_words_logits_processors @@ -222,17 +220,7 @@ def validate_outputs( def __init__( self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - decoding_config: Optional[DecodingConfig], - observability_config: Optional[ObservabilityConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], + vllm_config: VllmConfig, executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -240,6 +228,22 @@ def __init__( input_registry: InputRegistry = INPUT_REGISTRY, use_cached_outputs: bool = False, ) -> None: + + # TODO: remove the local variables and use self.* throughout the class. 
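
The chained assignments that follow bind each sub-config to both a local name and an engine attribute in one statement, which keeps the existing local-variable code working during the transition the TODO describes. A tiny illustration with a stand-in config dict:

    class DemoEngine:
        def __init__(self, vllm_config: dict):
            # One statement, two bindings: the local alias and the attribute
            # refer to the same object.
            model_config = self.model_config = vllm_config["model_config"]
            assert model_config is self.model_config

    DemoEngine({"model_config": object()})
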
+ model_config = self.model_config = vllm_config.model_config + cache_config = self.cache_config = vllm_config.cache_config + lora_config = self.lora_config = vllm_config.lora_config + parallel_config = self.parallel_config = vllm_config.parallel_config + scheduler_config = self.scheduler_config = vllm_config.scheduler_config + device_config = self.device_config = vllm_config.device_config + speculative_config = self.speculative_config = vllm_config.speculative_config # noqa + load_config = self.load_config = vllm_config.load_config + decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + ) + prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + logger.info( "Initializing an LLM engine (v%s) with config: " "model=%r, speculative_config=%r, tokenizer=%r, " @@ -338,18 +342,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.input_processor = input_registry.create_input_processor( model_config) - self.model_executor = executor_class( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - speculative_config=speculative_config, - load_config=load_config, - prompt_adapter_config=prompt_adapter_config, - observability_config=self.observability_config, - ) + self.model_executor = executor_class(vllm_config=vllm_config, ) if self.model_config.task != "embedding": self._initialize_kv_caches() @@ -506,7 +499,7 @@ def _initialize_kv_caches(self) -> None: @classmethod def _get_executor_cls(cls, - engine_config: EngineConfig) -> Type[ExecutorBase]: + engine_config: VllmConfig) -> Type[ExecutorBase]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. @@ -580,7 +573,7 @@ def from_engine_args( executor_class = cls._get_executor_cls(engine_config) # Create the LLM engine. 
engine = cls( - **engine_config.to_dict(), + vllm_config=engine_config, executor_class=executor_class, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, @@ -643,7 +636,7 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], + processed_inputs: ProcessorInputs, params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], @@ -674,18 +667,19 @@ def _add_processed_request( seq_id = next(self.seq_counter) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + if is_encoder_decoder_inputs(processed_inputs): + decoder_inputs = processed_inputs["decoder"] + encoder_inputs = processed_inputs["encoder"] + else: + decoder_inputs = processed_inputs + encoder_inputs = None + + seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) - encoder_seq = None - if 'encoder_prompt_token_ids' in processed_inputs: - encoder_seq = Sequence(seq_id, - processed_inputs, - block_size, - eos_token_id, - lora_request, - prompt_adapter_request, - from_decoder_prompt=False) + encoder_seq = (None if encoder_inputs is None else Sequence( + seq_id, encoder_inputs, block_size, eos_token_id, lora_request, + prompt_adapter_request)) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -879,7 +873,7 @@ def _validate_token_prompt(self, prompt: PromptType, # This needs to happen before multimodal input pre-processing, which # may add dummy tokens that aren't part of the tokenizer's # vocabulary. - if self._is_token_prompt(prompt): + if is_token_prompt(prompt): prompt_ids = prompt["prompt_token_ids"] if len(prompt_ids) == 0: # Empty prompt check is handled later @@ -889,10 +883,6 @@ def _validate_token_prompt(self, prompt: PromptType, raise ValueError( "Token id {} is out of vocabulary".format(max_input_id)) - @staticmethod - def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: - return isinstance(prompt, dict) and "prompt_token_ids" in prompt - def _create_sequence_group_with_sampling( self, request_id: str, @@ -1690,6 +1680,7 @@ def _get_stats(self, num_prompt_tokens_requests: List[int] = [] num_generation_tokens_requests: List[int] = [] n_requests: List[int] = [] + max_tokens_requests: List[int] = [] finished_reason_requests: List[str] = [] # Lora requests @@ -1797,6 +1788,8 @@ def _get_stats(self, ]) if seq_group.sampling_params is not None: n_requests.append(seq_group.sampling_params.n) + max_tokens_requests.append( + seq_group.sampling_params.max_tokens) finished_reason_requests.extend([ SequenceStatus.get_finished_reason(seq.status) for seq in seq_group.get_finished_seqs() @@ -1852,6 +1845,7 @@ def _get_stats(self, num_prompt_tokens_requests=num_prompt_tokens_requests, num_generation_tokens_requests=num_generation_tokens_requests, n_requests=n_requests, + max_tokens_requests=max_tokens_requests, finished_reason_requests=finished_reason_requests, max_lora=str(max_lora_stat), waiting_lora_adapters=list(waiting_lora_adapters.keys()), @@ -1979,17 +1973,17 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() - def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs], + def _validate_model_inputs(self, inputs: ProcessorInputs, 
lora_request: Optional[LoRARequest]): - if self.model_config.is_multimodal_model: + if is_encoder_decoder_inputs(inputs): # For encoder-decoder multimodal models, the max_prompt_len # restricts the decoder prompt length - prompt_ids = inputs.get("prompt_token_ids") - elif self.is_encoder_decoder_model(): - prompt_ids = inputs.get("encoder_prompt_token_ids") + prompt_inputs = inputs["decoder" if self.model_config. + is_multimodal_model else "encoder"] else: - prompt_ids = inputs.get("prompt_token_ids") + prompt_inputs = inputs + + prompt_ids = prompt_inputs.get("prompt_token_ids") if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 9ed30e1e99857..3e3357ed74633 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -179,6 +179,12 @@ def __init__(self, labelnames: List[str], max_model_len: int): labelnames=labelnames, buckets=[1, 2, 5, 10, 20], ) + self.histogram_max_tokens_request = self._histogram_cls( + name="vllm:request_params_max_tokens", + documentation="Histogram of the max_tokens request parameter.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) self.counter_request_success = self._counter_cls( name="vllm:request_success_total", documentation="Count of successfully processed requests.", @@ -547,6 +553,8 @@ def _log_prometheus(self, stats: Stats) -> None: self.metrics.histogram_num_generation_tokens_request, stats.num_generation_tokens_requests) self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) + self._log_histogram(self.metrics.histogram_max_tokens_request, + stats.max_tokens_requests) def _log_prometheus_interval(self, prompt_throughput: float, generation_throughput: float) -> None: diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 510dd04bb3e55..25b7a7479672a 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -53,6 +53,7 @@ class Stats: num_prompt_tokens_requests: List[int] num_generation_tokens_requests: List[int] n_requests: List[int] + max_tokens_requests: List[int] finished_reason_requests: List[str] waiting_lora_adapters: List[str] running_lora_adapters: List[str] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 6e6630b3ff55f..882742c2fc61b 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -13,7 +13,7 @@ from zmq.asyncio import Socket from vllm import PoolingParams -from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block @@ -78,7 +78,7 @@ class MQLLMEngineClient(EngineClient): every N seconds, confirming the engine is healthy """ - def __init__(self, ipc_path: str, engine_config: EngineConfig, + def __init__(self, ipc_path: str, engine_config: VllmConfig, engine_pid: int): self.context = zmq.asyncio.Context() self._errored_with: Optional[BaseException] = None @@ -112,7 +112,11 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig, # Stream for each individual request. self.output_queues: Dict[str, asyncio.Queue] = {} - self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + # Loop to handle output of the LLMEngine periodically. 
+ # Started after the MQLLMEngine is ready so that we can + # build the Client in an executor to enable clean shutdown. + self.output_loop: Optional[asyncio.Task] = None # Loop to check health of the LLMEngine periodically. # Started after the MQLLMEngine is ready. @@ -247,6 +251,9 @@ async def run_output_handler_loop(self): async def setup(self): """Setup the client before it starts sending server requests.""" + # Start output_loop + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + with self.get_data_socket() as socket: # Wait until server is ready. response = await self._wait_for_server_rpc(socket) @@ -265,7 +272,8 @@ def close(self): # Cancel background tasks. if self.health_loop is not None: self.health_loop.cancel() - self.output_loop.cancel() + if self.output_loop is not None: + self.output_loop.cancel() def _set_errored(self, e: BaseException): logger.exception(repr(e)) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 0a7f430eca488..9dd6fa5b14315 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -5,10 +5,9 @@ import cloudpickle import zmq +from ray.exceptions import RayTaskError from vllm import AsyncEngineArgs, SamplingParams -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) # yapf conflicts with isort for this block # yapf: disable from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, @@ -30,9 +29,6 @@ else: from vllm.engine.llm_engine import LLMEngine -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - logger = init_logger(__name__) POLLING_TIMEOUT_MS = 10000 @@ -130,7 +126,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, return cls(ipc_path=ipc_path, use_async_sockets=use_async_sockets, - **engine_config.to_dict(), + vllm_config=engine_config, executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, @@ -310,6 +306,11 @@ def _health_check(self): def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" if outputs: + # RayTaskError might not be picklable here. We need to unpack the + # underlying exception as the real exception in the output.
+ if (isinstance(outputs, RPCError) + and isinstance(outputs.exception, RayTaskError)): + outputs.exception = outputs.exception.cause output_bytes = pickle.dumps(outputs) self.output_socket.send_multipart((output_bytes, ), copy=False) @@ -348,16 +349,22 @@ def stop_profile(self) -> None: self.engine.model_executor._run_workers("stop_profile") +def signal_handler(*_) -> None: + raise KeyboardInterrupt("MQLLMEngine terminated") + + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, - ipc_path: str): + ipc_path: str, engine_alive): + try: + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) - def signal_handler(*_) -> None: - # Interrupt server on sigterm - raise KeyboardInterrupt("MQLLMEngine terminated") + signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) + engine.start() - engine = MQLLMEngine.from_engine_args(engine_args=engine_args, - usage_context=usage_context, - ipc_path=ipc_path) - engine.start() + except BaseException as e: + logger.exception(e) + engine_alive.value = False + raise e diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 6a09361c56865..e0b59d94cfdc3 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,11 +1,12 @@ import asyncio from abc import ABC, abstractmethod -from typing import AsyncGenerator, List, Mapping, Optional, Union +from typing import AsyncGenerator, List, Mapping, Optional from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -60,7 +61,7 @@ def generate( async def beam_search( self, - prompt: Union[PromptType, List[int]], + prompt: PromptType, model_config: ModelConfig, request_id: str, params: BeamSearchParams, @@ -76,11 +77,19 @@ async def beam_search( tokenizer = await self.get_tokenizer() input_preprocessor = InputPreprocessor(model_config, tokenizer) - (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) = input_preprocessor._extract_prompt_components( - prompt, - request_id=request_id, - ) + if is_explicit_encoder_decoder_prompt(prompt): + raise NotImplementedError + else: + processed_inputs = input_preprocessor._prompt_to_llm_inputs( + prompt, + request_id=request_id, + ) + + prompt_token_ids = processed_inputs["prompt_token_ids"] + prompt_text = processed_inputs.get("prompt") + multi_modal_data = processed_inputs.get("multi_modal_data") + mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs") + tokenized_length = len(prompt_token_ids) sort_beams_key = create_sort_beams_key_function( diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 8a5f18c860cb9..53ee7ac5a02f7 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -327,7 +327,8 @@ def _placeholder_str(self, modality: ModalityStr, if model_type.startswith("llava"): return self._cached_token_str(self._tokenizer, hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat", "NVLM_D"): + if model_type in ("chameleon", "internvl_chat", "NVLM_D", + "h2ovl_chat"): return "<image>" if model_type == "mllama": return "<|image|>" @@ -446,7 +447,9 @@ def
__init__(self, tracker: MultiModalItemTracker) -> None: self._tracker = tracker def parse_image(self, image_url: str) -> None: - image = get_and_parse_image(image_url) + image = get_and_parse_image(image_url, + allowed_local_media_path=self._tracker. + _model_config.allowed_local_media_path) placeholder = self._tracker.add("image", image) self._add_placeholder(placeholder) @@ -466,7 +469,10 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: self._tracker = tracker def parse_image(self, image_url: str) -> None: - image_coro = async_get_and_parse_image(image_url) + image_coro = async_get_and_parse_image( + image_url, + allowed_local_media_path=self._tracker._model_config. + allowed_local_media_path) placeholder = self._tracker.add("image", image_coro) self._add_placeholder(placeholder) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0cb56261ebd77..4472511f19e9e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -60,6 +60,10 @@ class LLM: from the input. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + allowed_local_media_path: Allowing API requests to read local images + or videos from directories specified by the server file system. + This is a security risk. Should only be enabled in trusted + environments. tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. dtype: The data type for the model weights and activations. Currently, @@ -141,6 +145,7 @@ def __init__( tokenizer_mode: str = "auto", skip_tokenizer_init: bool = False, trust_remote_code: bool = False, + allowed_local_media_path: str = "", tensor_parallel_size: int = 1, dtype: str = "auto", quantization: Optional[str] = None, @@ -181,6 +186,7 @@ def __init__( tokenizer_mode=tokenizer_mode, skip_tokenizer_init=skip_tokenizer_init, trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, tensor_parallel_size=tensor_parallel_size, dtype=dtype, quantization=quantization, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9a5e44468e335..5582651b31102 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -172,39 +172,44 @@ async def build_async_engine_client_from_engine_args( # so we need to spawn a new process context = multiprocessing.get_context("spawn") + # The Process can raise an exception during startup, which may + # not actually result in an exitcode being reported. As a result + # we use a shared variable to communicate the information. + engine_alive = multiprocessing.Value('b', True, lock=False) engine_process = context.Process(target=run_mp_engine, args=(engine_args, UsageContext.OPENAI_API_SERVER, - ipc_path)) + ipc_path, engine_alive)) engine_process.start() engine_pid = engine_process.pid - assert engine_pid is not None, "Engine process failed to start" + assert engine_pid is not None, "Engine process failed to start." logger.info("Started engine process with PID %d", engine_pid) # Build RPCClient, which conforms to EngineClient Protocol. - # NOTE: Actually, this is not true yet. 
We still need to support - # embedding models via RPC (see TODO above) engine_config = engine_args.create_engine_config() - mp_engine_client = MQLLMEngineClient(ipc_path, engine_config, - engine_pid) - + build_client = partial(MQLLMEngineClient, ipc_path, engine_config, + engine_pid) + mq_engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_client) try: while True: try: - await mp_engine_client.setup() + await mq_engine_client.setup() break except TimeoutError: - if not engine_process.is_alive(): + if (not engine_process.is_alive() + or not engine_alive.value): raise RuntimeError( - "Engine process failed to start") from None + "Engine process failed to start. See stack " + "trace for the root cause.") from None - yield mp_engine_client # type: ignore[misc] + yield mq_engine_client # type: ignore[misc] finally: # Ensure rpc server process was terminated engine_process.terminate() # Close all open connections to the backend - mp_engine_client.close() + mq_engine_client.close() # Wait for engine process to join engine_process.join(4) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index e32993e0e452e..ab3ebb4e43d18 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -138,18 +138,11 @@ def _create_worker( assert self.distributed_init_method is not None kwargs = dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=self.distributed_init_method, - lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=rank == 0, ) wrapper.init_worker(**kwargs) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c96cb0f2c2981..9cba189dd57f9 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,10 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -23,27 +20,19 @@ class ExecutorBase(ABC): def __init__( self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - observability_config: Optional[ObservabilityConfig], + vllm_config: VllmConfig, ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config + self.vllm_config = vllm_config + self.model_config = 
vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config self._init_executor() @abstractmethod diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index ed30d3186a453..c65d0836e5ff7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -49,21 +49,12 @@ def _get_worker_kwargs( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - speculative_config=self.speculative_config, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0), - observability_config=self.observability_config, ) def _get_worker_module_and_class( diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index f2fcfa58b26e1..02d37cd7fbf23 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -29,11 +29,7 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = NeuronWorker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, + vllm_config=self.vllm_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method) diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index d0c0333854dae..d06b0ccb7906e 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -48,16 +48,10 @@ def _init_worker(self): get_ip(), get_open_port()) self.driver_worker = OpenVINOWorker( ov_core=self.ov_core, - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 972649dedf33e..e37e8973790db 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -44,12 +44,7 @@ def _get_worker_kwargs( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, 
distributed_init_method=distributed_init_method, diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 5f78993ddc4b4..36b7e2265efab 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -2,10 +2,7 @@ import torch -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import ModelConfig, ParallelConfig from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger @@ -21,38 +18,13 @@ class XPUExecutor(GPUExecutor): uses_ray: bool = False - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - speculative_config: Optional[SpeculativeConfig], - observability_config: Optional[ObservabilityConfig], - ) -> None: - assert device_config.device_type == "xpu" - assert (not speculative_config - ), "Speculative decoding not yet supported for XPU backend" - - model_config = _verify_and_get_model_config(model_config) - - self.model_config = model_config - self.cache_config = cache_config - self.load_config = load_config - self.lora_config = lora_config - self.parallel_config = _verify_and_get_parallel_config(parallel_config) - self.scheduler_config = scheduler_config - self.device_config = device_config - self.prompt_adapter_config = prompt_adapter_config - self.speculative_config = None - self.observability_config = observability_config - - # Instantiate the worker and load the model to GPU. 
- self._init_executor() + def _init_executor(self) -> None: + assert self.device_config.device_type == "xpu" + assert self.speculative_config is None, ( + "Speculative decoding not yet supported for XPU backend") + + self.model_config = _verify_and_get_model_config(self.model_config) + GPUExecutor._init_executor(self) def _get_worker_module_and_class( self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 7b73922ddd2c5..68ac50a2c5a16 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,9 +1,9 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs, - SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, - token_inputs, zip_enc_dec_prompts) -from .registry import InputContext, InputRegistry + ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, + SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) +from .registry import DummyData, InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() """ @@ -22,13 +22,15 @@ "ExplicitEncoderDecoderPrompt", "TokenInputs", "token_inputs", - "SingletonInputs", "DecoderOnlyInputs", "EncoderDecoderInputs", + "ProcessorInputs", + "SingletonInputs", "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", "INPUT_REGISTRY", + "DummyData", "InputContext", "InputRegistry", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 9a094191eda38..46b41f431bec7 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,10 +1,10 @@ -from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, Optional, Tuple, Union, cast) from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict + from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict class TextPrompt(TypedDict): @@ -122,21 +122,30 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): class TokenInputs(TypedDict): """Represents token-based inputs.""" + + type: Literal["token"] + """The type of inputs.""" + prompt_token_ids: List[int] """The token IDs of the prompt.""" - prompt: NotRequired[Optional[str]] + prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. """ - multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + multi_modal_data: NotRequired["MultiModalDataDict"] """ Optional multi-modal data to pass to the model, if the model supports it. """ - mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"] + """ + Placeholder ranges for the multi-modal data. + """ + + mm_processor_kwargs: NotRequired[Dict[str, Any]] """ Optional multi-modal processor kwargs to be forwarded to the multimodal input mapper & processor. 
Note that if multiple modalities @@ -149,27 +158,24 @@ def token_inputs( prompt_token_ids: List[int], prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" - inputs = TokenInputs(prompt_token_ids=prompt_token_ids) + inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) if prompt is not None: inputs["prompt"] = prompt if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data + if multi_modal_placeholders is not None: + inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: inputs["mm_processor_kwargs"] = mm_processor_kwargs return inputs -SingletonInputs = TokenInputs -""" -A processed :class:`SingletonPrompt` which can be passed to -:class:`vllm.sequence.Sequence`. -""" - DecoderOnlyInputs = TokenInputs """ The inputs in :class:`~vllm.LLMEngine` before they are @@ -178,28 +184,30 @@ def token_inputs( """ -class EncoderDecoderInputs(TokenInputs): +class EncoderDecoderInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. This specifies the required data for encoder-decoder models. """ - encoder_prompt_token_ids: List[int] - """The token IDs of the encoder prompt.""" + encoder: TokenInputs + """The inputs for the encoder portion.""" - encoder_prompt: NotRequired[Optional[str]] - """ - The original encoder prompt text corresponding to the token IDs, if - available. - """ + decoder: TokenInputs + """The inputs for the decoder portion.""" - encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] - """ - Optional multi-modal data to pass to the encoder model, - if the model supports it. - """ +SingletonInputs = TokenInputs +""" +A processed :class:`SingletonPrompt` which can be passed to +:class:`vllm.sequence.Sequence`. +""" + +ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] +""" +The inputs to :data:`vllm.inputs.InputProcessor`. 
+""" _T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e79d2c813bb4f..09f1ff2cb42e9 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -4,9 +4,9 @@ from vllm.utils import is_list_of -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, - TextPrompt, TokensPrompt) +from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, + ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, + TokensPrompt) class ParsedText(TypedDict): @@ -98,12 +98,15 @@ def parse_singleton_prompt( raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") +def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: + return isinstance(prompt, dict) and "prompt_token_ids" in prompt + + def is_explicit_encoder_decoder_prompt( prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt def is_encoder_decoder_inputs( - inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], -) -> TypeIs[EncoderDecoderInputs]: - return "encoder_prompt_token_ids" in inputs + inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: + return "encoder" in inputs and "decoder" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 82ce7d392b719..a5c787a56b5a9 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional from typing_extensions import assert_never @@ -10,22 +10,12 @@ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.utils import print_warning_once -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType, - SingletonPrompt) +from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, + PromptType, SingletonInputs, SingletonPrompt, token_inputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt -if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict - logger = init_logger(__name__) -PromptComponents = Tuple[Optional[str], List[int], - Optional["MultiModalDataDict"], Optional[Dict[str, - Any]]] -DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional["MultiModalDataDict"], - Optional[Dict[str, Any]]] - class InputPreprocessor: @@ -115,7 +105,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: "default" decoder prompt be <BOS>. However, it is possible that in the future - other models may have different or more + other models may have different or more complex logic for the default decoder prompt. This motivates having a special helper method for default decoder prompts. @@ -132,7 +122,6 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: def _prepare_decoder_input_ids_for_generation( self, decoder_input_ids: Optional[List[int]], - force_bos: bool = True, ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models.
@@ -162,8 +151,8 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if force_bos and (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -209,12 +198,12 @@ async def _tokenize_prompt_async( prompt=prompt, lora_request=lora_request) - def _extract_prompt_components( + def _prompt_to_llm_inputs( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: + ) -> SingletonInputs: ''' Extract the components of any single encoder or decoder input prompt. @@ -241,34 +230,52 @@ def _extract_prompt_components( request_id=request_id, lora_request=lora_request, ) - multi_modal_data = None - mm_processor_kwargs = None - elif parsed["type"] == "tokens": - prompt_text = None - prompt_token_ids = parsed["content"]["prompt_token_ids"] - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if parsed["type"] == "tokens": + tokens_content = parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + return token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + if parsed["type"] == "text": + text_content = parsed["content"] + + prompt_text = text_content["prompt"] prompt_token_ids = self._tokenize_prompt( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - else: - assert_never(parsed) + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) - return (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) + assert_never(parsed) - async def _extract_prompt_components_async( + async def _prompt_to_llm_inputs_async( self, prompt: SingletonPrompt, request_id: str, lora_request: Optional[LoRARequest] = None, - ) -> PromptComponents: + ) -> SingletonInputs: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) @@ -279,59 +286,74 @@ async def _extract_prompt_components_async( request_id=request_id, lora_request=lora_request, ) - multi_modal_data = None - mm_processor_kwargs = None - elif parsed["type"] == "tokens": - prompt_text = None - prompt_token_ids = parsed["content"]["prompt_token_ids"] - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + ) + + if parsed["type"] == "tokens": + tokens_content = 
parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + return token_inputs( + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + if parsed["type"] == "text": + text_content = parsed["content"] + + prompt_text = text_content["prompt"] prompt_token_ids = await self._tokenize_prompt_async( prompt_text, request_id=request_id, lora_request=lora_request, ) - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - else: - assert_never(parsed) + multi_modal_data = text_content.get("multi_modal_data") + mm_processor_kwargs = text_content.get("mm_processor_kwargs") + + return token_inputs( + prompt=prompt_text, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) - return (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) + assert_never(parsed) def _build_enc_dec_llm_inputs( self, - encoder_comps: PromptComponents, - decoder_comps: DecoderPromptComponents, - mm_processor_kwargs: Dict[str, Any], + encoder_inputs: SingletonInputs, + decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps - - # Reminder: Please update docs/source/serving/compatibility_matrix.rst - # If the feature combo become valid - if decoder_mm_data is not None: - raise ValueError( - "Multi-modality decoder inputs of encoder-decoder models are " - "not supported yet") - - # For Multi-Modal models (e.g., mllama), the text input can be - # <|image|><|begin_of_text|>hello world. And we should not add - # another <|begin_of_text|> to the beginning. - decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation( - decoder_prompt_ids, - force_bos=(encoder_mm_data is None and decoder_mm_data is None))) + if encoder_inputs["type"] == "token": + pass + else: + assert_never(encoder_inputs) + + if decoder_inputs is None: + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + None) + decoder_inputs = token_inputs(dec_token_ids) + elif decoder_inputs["type"] == "token": + dec_token_ids = self._prepare_decoder_input_ids_for_generation( + decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] = dec_token_ids + + if "multi_modal_data" in decoder_inputs: + raise ValueError("Multi-modal decoder inputs of encoder-" + "decoder models are not supported yet") + else: + assert_never(encoder_inputs) return EncoderDecoderInputs( - prompt_token_ids=decoder_prompt_ids, - prompt=decoder_prompt, - multi_modal_data=decoder_mm_data, - mm_processor_kwargs=mm_processor_kwargs, - encoder_prompt_token_ids=encoder_prompt_ids, - encoder_prompt=encoder_prompt, - encoder_multi_modal_data=encoder_mm_data, + encoder=encoder_inputs, + decoder=decoder_inputs, ) def _process_encoder_decoder_prompt( @@ -341,8 +363,7 @@ def _process_encoder_decoder_prompt( ) -> EncoderDecoderInputs: ''' For encoder/decoder models only: - Process an input prompt into an - :class:`EncoderDecoderInputs` instance. + Process an input prompt into an :class:`EncoderDecoderInputs` instance. 
There are two types of input prompts: singleton prompts which carry only the @@ -361,7 +382,7 @@ def _process_encoder_decoder_prompt( have any possible singleton type; thus this method relies on helper functions to obtain token ids for the sub-prompts. - + Arguments: * prompt: an input prompt @@ -372,40 +393,31 @@ def _process_encoder_decoder_prompt( * :class:`EncoderDecoderInputs` instance ''' - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] if is_explicit_encoder_decoder_prompt(prompt): - encoder_comps = self._extract_prompt_components( + encoder_inputs = self._prompt_to_llm_inputs( prompt["encoder_prompt"], request_id=request_id, ) if (decoder_input := prompt["decoder_prompt"]) is None: - decoder_comps = None, None, None, None + decoder_inputs = None else: - decoder_comps = self._extract_prompt_components( + decoder_inputs = self._prompt_to_llm_inputs( decoder_input, request_id=request_id, ) - # Handle this carefully in case it was directly initialized by user - mm_processor_kwargs = prompt.get("mm_processor_kwargs", {}) else: - encoder_comps = self._extract_prompt_components( + encoder_inputs = self._prompt_to_llm_inputs( prompt, request_id=request_id, ) - # If there are no decoder components, we assume the - # mm_processor_kwargs are in the encoder prompt - mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ - -1] is not None else {} - decoder_comps = None, None, None, None - - return self._build_enc_dec_llm_inputs( - encoder_comps, - decoder_comps, - mm_processor_kwargs, - ) + + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) async def _process_encoder_decoder_prompt_async( self, @@ -413,59 +425,50 @@ async def _process_encoder_decoder_prompt_async( request_id: str, ) -> EncoderDecoderInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" - encoder_comps: PromptComponents - decoder_comps: DecoderPromptComponents + encoder_inputs: SingletonInputs + decoder_inputs: Optional[SingletonInputs] if is_explicit_encoder_decoder_prompt(prompt): - encoder_task = self._extract_prompt_components_async( + encoder_task = self._prompt_to_llm_inputs_async( prompt["encoder_prompt"], request_id=request_id, ) if (decoder_input := prompt["decoder_prompt"]) is None: - encoder_comps = await encoder_task - decoder_comps = None, None, None, None + encoder_inputs = await encoder_task + decoder_inputs = None else: - decoder_task = self._extract_prompt_components_async( + decoder_task = self._prompt_to_llm_inputs_async( decoder_input, request_id=request_id, ) - encoder_comps, decoder_comps = await asyncio.gather( + encoder_inputs, decoder_inputs = await asyncio.gather( encoder_task, decoder_task) - mm_processor_kwargs = prompt["mm_processor_kwargs"] else: - encoder_comps = await self._extract_prompt_components_async( + encoder_inputs = await self._prompt_to_llm_inputs_async( prompt, request_id=request_id, ) - # If there are no decoder components, we assume the - # mm_processor_kwargs are in the encoder prompt - mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ - -1] is not None else {} - decoder_comps = None, None, None, None - - return self._build_enc_dec_llm_inputs( - encoder_comps, - decoder_comps, - mm_processor_kwargs, - ) + + decoder_inputs = None + + return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) def _build_decoder_only_llm_inputs( self, - prompt_comps: PromptComponents, + prompt_inputs: 
DecoderOnlyInputs, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: - (prompt, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) = prompt_comps - - prompt_token_ids = self._apply_prompt_adapter( - prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + if prompt_inputs["type"] == "token": + prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( + prompt_inputs["prompt_token_ids"], + prompt_adapter_request=prompt_adapter_request, + ) + else: + assert_never(prompt_inputs) - return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids, - prompt=prompt, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs) + return prompt_inputs def _process_decoder_only_prompt( self, @@ -490,7 +493,7 @@ def _process_decoder_only_prompt( * :class:`DecoderOnlyInputs` instance ''' - prompt_comps = self._extract_prompt_components( + prompt_comps = self._prompt_to_llm_inputs( prompt, request_id=request_id, lora_request=lora_request, @@ -509,7 +512,7 @@ async def _process_decoder_only_prompt_async( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> DecoderOnlyInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" - prompt_comps = await self._extract_prompt_components_async( + prompt_comps = await self._prompt_to_llm_inputs_async( prompt, request_id=request_id, lora_request=lora_request, @@ -526,7 +529,7 @@ def preprocess( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: + ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of @@ -554,7 +557,7 @@ async def preprocess_async( request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]: + ) -> ProcessorInputs: """Async version of :meth:`preprocess`.""" if self.is_encoder_decoder_model(): # Encoder-decoder model requires special mapping of diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 4cebc91ce715c..7d7a797be4f60 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,8 +1,8 @@ import functools from collections import UserDict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, - Protocol, Tuple, Type) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, + Optional, Protocol, Type, cast) from torch import nn from transformers import PretrainedConfig @@ -12,11 +12,12 @@ from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, resolve_mm_processor_kwargs) -from .data import DecoderOnlyInputs +from .data import ProcessorInputs if TYPE_CHECKING: from vllm.config import ModelConfig - from vllm.multimodal import MultiModalDataDict, MultiModalRegistry + from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, + MultiModalRegistry) from vllm.sequence import SequenceData logger = init_logger(__name__) @@ -63,6 +64,14 @@ def get_hf_image_processor_config(self) -> Dict[str, Any]: N = TypeVar("N", bound=Type[nn.Module]) +class DummyData(NamedTuple): + """Dummy data used for profiling.""" + + seq_data: "SequenceData" + multi_modal_data: Optional["MultiModalDataDict"] = None + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None + + class 
DummyDataFactory(Protocol): def __call__( @@ -71,7 +80,7 @@ def __call__( seq_len: int, mm_counts: Mapping[str, int], **mm_processor_kwargs: Any, - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyData: """ Create dummy data to be inputted into the model. @@ -100,7 +109,7 @@ def __getitem__(self, key: str) -> int: raise KeyError(msg) from exc -InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs] +InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs] """Preprocess the inputs to the model.""" @@ -123,7 +132,7 @@ def _default_dummy_data_factory( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyData: """ The default dummy data factory represents the longest possible text that can be inputted to the model. @@ -134,10 +143,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) - dummy_multi_modal_data = None - - return dummy_seq_data, dummy_multi_modal_data + return DummyData(SequenceData.from_prompt_token_counts((0, seq_len))) def register_dummy_data(self, factory: DummyDataFactory): """ @@ -195,7 +201,7 @@ def dummy_data_for_profiling( seq_len: int, mm_registry: "MultiModalRegistry", is_encoder_data: bool = False, - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyData: """ Create dummy data for profiling the memory usage of a model. @@ -220,12 +226,12 @@ def dummy_data_for_profiling( mm_processor_kwargs = get_allowed_kwarg_only_overrides( dummy_factory, overrides=model_config.mm_processor_kwargs) - seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) + dummy_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine - num_tokens = seq_data.prompt_token_ids + num_tokens = dummy_data.seq_data.prompt_token_ids if len(num_tokens) < seq_len: if is_encoder_data: print_warning_once( @@ -235,21 +241,21 @@ def dummy_data_for_profiling( raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " f"but found {len(num_tokens)} tokens instead.") - if mm_data is not None: - for k, v in mm_data.items(): + if dummy_data.multi_modal_data is not None: + for k, v in dummy_data.multi_modal_data.items(): num_items = len(v) if isinstance(v, list) else 1 num_expected = mm_counts[k] assert num_items >= num_expected, ( f"Expected at least {num_expected} dummy '{k}' instances " f"for profiling, but found {num_items} instances instead.") - return seq_data, mm_data + return dummy_data def _default_input_processor( self, ctx: InputContext, - inputs: DecoderOnlyInputs, - ) -> DecoderOnlyInputs: + inputs: ProcessorInputs, + ) -> ProcessorInputs: """The default input processor is a no-op.""" return inputs @@ -282,7 +288,7 @@ def _get_model_input_processor(self, model_cls: Type[nn.Module]): .get(model_cls, self._default_input_processor) def process_input(self, model_config: "ModelConfig", - inputs: DecoderOnlyInputs) -> DecoderOnlyInputs: + inputs: ProcessorInputs) -> ProcessorInputs: """ Apply an input processor to an instance of model inputs. 
@@ -302,7 +308,7 @@ def process_input(self, model_config: "ModelConfig", # If it's empty, it'll fall back to the default kwarg values mm_processor_kwargs = resolve_mm_processor_kwargs( model_config.mm_processor_kwargs, - inputs.get("mm_processor_kwargs"), + cast(Dict[str, Any], inputs.get("mm_processor_kwargs")), processor, ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py new file mode 100644 index 0000000000000..8ef0a6cdf2c52 --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -0,0 +1,217 @@ +import torch +from torch import nn +from torch.nn.parameter import Parameter + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.models.mamba_cache import MambaCacheParams +from vllm.model_executor.utils import set_weight_attrs + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +@CustomOp.register("mamba_mixer") +class MambaMixer(CustomOp): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + time_step_rank: int, + use_conv_bias: bool, + use_bias: bool, + use_rms_norm: bool, + rms_norm_eps: float = 1e-5, + activation="silu"): + super().__init__() + self.time_step_rank = time_step_rank + self.ssm_state_size = ssm_state_size + self.use_rms_norm = use_rms_norm + self.activation = activation + + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=intermediate_size, + bias=use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(hidden_size, + [intermediate_size] * 2, + bias=use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + intermediate_size, + time_step_rank + ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. 
+ self.dt_proj = ColumnParallelLinear(time_step_rank, + intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + intermediate_size // tp_size, + ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + ) + + self.dt_layernorm = RMSNorm(time_step_rank, + eps=rms_norm_eps) if use_rms_norm else None + + self.b_layernorm = RMSNorm(ssm_state_size, + eps=rms_norm_eps) if use_rms_norm else None + + self.c_layernorm = RMSNorm(ssm_state_size, + eps=rms_norm_eps) if use_rms_norm else None + + def forward_native(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, ssm_state: torch.Tensor): + pass + + def forward_cuda(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. 
input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + if self.use_rms_norm: + assert self.dt_layernorm is not None + assert self.b_layernorm is not None + assert self.c_layernorm is not None + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + return contextualized_states diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 95ec12daeeeb5..ea69bee45f8d9 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -9,7 +9,9 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod, set_weight_attrs) +from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils import replace_parameter @@ -36,13 +38,18 @@ class AWQMarlinConfig(QuantizationConfig): 8: scalar_types.uint8, } - def __init__(self, weight_bits: int, group_size: int, has_zp: bool, - lm_head_quantized: bool) -> None: + def __init__(self, + weight_bits: int, + group_size: int, + zero_point: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[List[str]] = None) -> None: self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size - self.has_zp = has_zp + self.zero_point = zero_point self.lm_head_quantized = lm_head_quantized self.weight_bits = weight_bits + self.modules_to_not_convert = modules_to_not_convert or [] if self.weight_bits not in self.TYPE_MAP: raise ValueError(f"Unsupported num_bits = {self.weight_bits}. 
" @@ -52,13 +59,14 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool, verify_marlin_supported(self.quant_type, group_size=self.group_size, - has_zp=self.has_zp) + has_zp=self.zero_point) def __repr__(self) -> str: return (f"AWQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " - f"has_zp={self.has_zp}, " - f"lm_head_quantized={self.lm_head_quantized})") + f"zero_point={self.zero_point}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"modules_to_not_convert={self.modules_to_not_convert})") @classmethod def get_name(cls) -> str: @@ -80,10 +88,13 @@ def get_config_filenames(cls) -> List[str]: def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig": weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) - has_zp = cls.get_from_keys(config, ["zero_point"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) - return cls(weight_bits, group_size, has_zp, lm_head_quantized) + modules_to_not_convert = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None) + return cls(weight_bits, group_size, zero_point, lm_head_quantized, + modules_to_not_convert) @classmethod def override_quantization_method(cls, hf_quant_cfg, @@ -109,6 +120,8 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if (isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() return AWQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): return AWQMoEMethod(self) @@ -123,7 +136,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") group_size = quant_config.get("group_size") - has_zp = quant_config.get("zero_point") + zero_point = quant_config.get("zero_point") if not current_platform.is_cuda(): return False @@ -132,7 +145,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): return False # If we cannot find the info needed in the config, cannot convert. 
- if (num_bits is None or group_size is None or has_zp is None): + if (num_bits is None or group_size is None or zero_point is None): return False if num_bits not in cls.TYPE_MAP: @@ -140,7 +153,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, - has_zp=has_zp) + has_zp=zero_point) class AWQMarlinLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index dc5f47eb9b0fb..9694f2b8208e2 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -141,8 +141,11 @@ def create_weights( layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: Module) -> None: - max_w_scale, weight = requantize_with_max_scale( - layer.weight, layer.weight_scale, layer.logical_widths) + weight = layer.weight + max_w_scale = layer.weight_scale.max() + if not (layer.weight_scale == layer.weight_scale[0]).all(): + max_w_scale, weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py index 8cd938fc85fb2..bce91f1d7fd5e 100644 --- a/vllm/model_executor/layers/resampler.py +++ b/vllm/model_executor/layers/resampler.py @@ -41,6 +41,7 @@ from torch.nn.init import trunc_normal_ from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -154,15 +155,15 @@ class BaseResampler(nn.Module): A tensor with the shape of (grid_size**2, embed_dim) """ - def __init__( - self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - do_post_projection: bool = True, - ) -> None: + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.num_queries = num_queries @@ -172,7 +173,11 @@ def __init__( self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) trunc_normal_(self.query, std=0.02) if kv_dim is not None and kv_dim != embed_dim: - self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) + self.kv_proj = ReplicatedLinear(kv_dim, + embed_dim, + bias=False, + quant_config=quant_config, + prefix=prefix) else: # Maintain the same return value with ReplicatedLinear.forward self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa @@ -209,22 +214,24 @@ class Resampler2(BaseResampler): present in minicpmv2.0, but not qwen-vl. 
""" - def __init__( - self, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - adaptive: bool = False, - do_post_projection: bool = True, - ) -> None: + def __init__(self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, norm_layer, - do_post_projection=do_post_projection) + do_post_projection=do_post_projection, + quant_config=quant_config, + prefix=prefix) self.adaptive = adaptive pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index d1ec171c9ec2a..12468997e4653 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -1,27 +1,15 @@ -from typing import Optional - from torch import nn -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.model_executor.model_loader.loader import (BaseModelLoader, get_model_loader) from vllm.model_executor.model_loader.utils import ( get_architecture_class_name, get_model_architecture) -def get_model(*, model_config: ModelConfig, load_config: LoadConfig, - device_config: DeviceConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig) -> nn.Module: - loader = get_model_loader(load_config) - return loader.load_model(model_config=model_config, - device_config=device_config, - lora_config=lora_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config) +def get_model(*, vllm_config: VllmConfig) -> nn.Module: + loader = get_model_loader(vllm_config.load_config) + return loader.load_model(vllm_config=vllm_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 79703bb7ded7a..c3e0290f270ae 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -21,13 +21,15 @@ from transformers import AutoModelForCausalLM, PretrainedConfig from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, MultiModalConfig, - ParallelConfig, PoolerConfig, SchedulerConfig) +from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + PoolerConfig, SchedulerConfig, VllmConfig) from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ReplicatedLinear, + RowParallelLinear) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( @@ -38,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, - 
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator, + get_gguf_extra_tensor_names, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models import (has_inner_state, supports_lora, @@ -92,32 +94,6 @@ def device_loading_context(module: torch.nn.Module, logger = init_logger(__name__) -def _get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: - """Get the quantization config.""" - if model_config.quantization is not None: - quant_config = get_quant_config(model_config, load_config) - capability_tuple = current_platform.get_device_capability() - - if capability_tuple is not None: - capability = capability_tuple.to_int() - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} " - "is not supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") - return quant_config - return None - - def _get_model_initialization_kwargs( model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], @@ -150,6 +126,7 @@ def _get_model_initialization_kwargs( def build_model(model_class: Type[nn.Module], + vllm_config: Optional[VllmConfig], hf_config: PretrainedConfig, cache_config: Optional[CacheConfig], quant_config: Optional[QuantizationConfig], @@ -166,26 +143,31 @@ def build_model(model_class: Type[nn.Module], if prefix: extra_kwargs["prefix"] = prefix + # TODO: unify all the module initialization code + # to only take the `VllmConfig` object as input + from vllm.plugins import set_vllm_config + set_vllm_config(vllm_config) + return model_class(config=hf_config, cache_config=cache_config, quant_config=quant_config, **extra_kwargs) -def _initialize_model( - model_config: ModelConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig, - scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: +def _initialize_model(vllm_config: VllmConfig) -> nn.Module: """Initialize a model with the given configurations.""" + model_config = vllm_config.model_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + cache_config = vllm_config.cache_config model_class, _ = get_model_architecture(model_config) return build_model( model_class, + vllm_config, model_config.hf_config, cache_config=cache_config, - quant_config=_get_quantization_config(model_config, load_config), + quant_config=vllm_config.quant_config, lora_config=lora_config, multimodal_config=model_config.multimodal_config, scheduler_config=scheduler_config, @@ -205,12 +187,7 @@ def download_model(self, model_config: ModelConfig) -> None: raise NotImplementedError @abstractmethod - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: """Load a model with the given configurations.""" raise NotImplementedError @@ -396,18 
+373,14 @@ def download_model(self, model_config: ModelConfig) -> None: model_config.revision, fall_back_to_pt=True) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config, - scheduler_config) + model = _initialize_model(vllm_config=vllm_config) model.load_weights(self._get_all_weights(model_config, model)) @@ -436,17 +409,12 @@ def __init__(self, load_config: LoadConfig): def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config, - scheduler_config) + model = _initialize_model(vllm_config=vllm_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) @@ -488,10 +456,7 @@ def _get_weights_iterator( def _load_model_serialized_cpu( self, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig, + vllm_config: VllmConfig, ) -> nn.Module: """Load a serialized model with tensorizer to the CPU. @@ -500,31 +465,34 @@ def _load_model_serialized_cpu( default HuggingFace loading, but will be slower than loading a vLLM-tensorized model. """ + device_config = vllm_config.device_config + model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config) + model = _initialize_model(vllm_config=vllm_config) model.load_weights(self._get_weights_iterator()) return model.eval() def _load_model_serialized( self, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig, + vllm_config: VllmConfig, ) -> nn.Module: """Load a serialized model with tensorizer. Expects a vLLM-tensorized model. 
See the examples/tensorize_vllm_model.py example script for serializing vLLM models.""" + + device_config = vllm_config.device_config + model_config = vllm_config.model_config + lora_config = vllm_config.lora_config + cache_config = vllm_config.cache_config + with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] - quant_config = _get_quantization_config( - model_config, self.load_config) + quant_config = vllm_config.quant_config extra_kwargs = _get_model_initialization_kwargs( model_class, lora_config, model_config.multimodal_config) extra_kwargs["quant_config"] = quant_config @@ -544,12 +512,9 @@ def download_model(self, model_config: ModelConfig) -> None: with self.tensorizer_config.open_stream(): pass - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) if parallel_config.tensor_parallel_size > 1: @@ -559,10 +524,8 @@ def load_model(self, *, model_config: ModelConfig, % get_tensor_model_parallel_rank() if is_vllm_tensorized(self.tensorizer_config): - return self._load_model_serialized(model_config, device_config, - lora_config, cache_config) - return self._load_model_serialized_cpu(model_config, device_config, - lora_config, cache_config) + return self._load_model_serialized(vllm_config=vllm_config) + return self._load_model_serialized_cpu(vllm_config=vllm_config) @staticmethod def save_model( @@ -648,12 +611,9 @@ def _prepare_weights(self, model_name_or_path: str, def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank @@ -663,8 +623,7 @@ def load_model(self, *, model_config: ModelConfig, with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config) + model = _initialize_model(vllm_config=vllm_config) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: @@ -769,6 +728,10 @@ class BitsAndBytesModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) + # Save the module names without sharding. + self.unsharded_weights_modules: List[str] = [] + # Save the module names that are sharded by column. + self.column_sharded_weights_modules: List[str] = [] # we don't need to quantize the whole model, only the target modules # that are specified in the adapter config file. If the adapter config # file is not provided, we will quantize the default modules. 
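The two new lists are filled later in _load_weights by scanning layer types, replacing the static per-model column_parallel_weights_modules attribute (removed from Falcon/Gemma further down). A minimal, framework-free sketch of that classification, with plain nn.Linear stand-ins for vLLM's ReplicatedLinear and RowParallelLinear:

    import torch.nn as nn

    class ReplicatedLinear(nn.Linear):      # stand-in for vLLM's ReplicatedLinear
        pass

    class RowParallelLinear(nn.Linear):     # stand-in for vLLM's RowParallelLinear
        pass

    model = nn.ModuleDict({
        "kv_proj": ReplicatedLinear(16, 16),     # not sharded: copied to every TP rank
        "down_proj": RowParallelLinear(16, 16),  # sharded along the column dim (dim=-1)
    })

    unsharded_weights_modules, column_sharded_weights_modules = [], []
    for name, module in model.named_modules():
        if isinstance(module, ReplicatedLinear):
            unsharded_weights_modules.append(name)
        elif isinstance(module, RowParallelLinear):
            column_sharded_weights_modules.append(name)

    print(unsharded_weights_modules, column_sharded_weights_modules)
    # ['kv_proj'] ['down_proj']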
@@ -1005,16 +968,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, if any(target_module in weight_name for target_module in self.target_modules) and weight_name.endswith(".weight"): weight_name = weight_name.replace(".weight", ".qweight") - - if any(module in weight_name - for module in self.column_parallel_weights_modules): - + # Without sharding + if any( + weight_name.startswith(module) + for module in self.unsharded_weights_modules): + weight_sub_tensor = weight_tensor + # Shard by column + elif any( + weight_name.startswith(module) + for module in self.column_sharded_weights_modules): total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) weight_sub_tensor = weight_tensor[..., start_index:end_index] - + # Shard by row else: total_size = weight_tensor.size(0) start_index = total_size // tp_size * tp_rank @@ -1063,11 +1031,16 @@ def _load_weights(self, model_config: ModelConfig, else: self.target_modules = self.default_target_modules - if hasattr(model, 'column_parallel_weights_modules'): - self.column_parallel_weights_modules = \ - model.column_parallel_weights_modules - else: - self.column_parallel_weights_modules = [] + for name, module in model.named_modules(): + # Some modules like `ReplicatedLinear` should not have their weights + # sharded. The reason for implementing it this way is to avoid new + # static variable in the model implementation. + if isinstance(module, (ReplicatedLinear, )): + self.unsharded_weights_modules.append(name) + # In TP, these weights are partitioned along the column + # dimension (dim=-1) + elif isinstance(module, (RowParallelLinear, )): + self.column_sharded_weights_modules.append(name) self.model_type = type(model).__name__ @@ -1115,7 +1088,13 @@ def _load_weights(self, model_config: ModelConfig, for shard_name, ( weight_name, index ) in model.bitsandbytes_stacked_params_mapping.items(): - if shard_name in quant_param_name: + + shard_pos = quant_param_name.find(shard_name) + # Some models, such as MiniCPM V2.5/2.6, contain both + # module names 'kv_proj' and 'qkv_proj'. 
To prevent 'kv_proj' + # from being incorrectly identified as being present in + # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight + if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".": shard_index = index quant_param_name = quant_param_name.replace( shard_name, weight_name) @@ -1157,16 +1136,12 @@ def _load_weights(self, model_config: ModelConfig, def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config) + model = _initialize_model(vllm_config=vllm_config) self._load_weights(model_config, model) @@ -1235,13 +1210,9 @@ def _get_weights_iterator( def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: - + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights @@ -1251,8 +1222,7 @@ def load_model(self, *, model_config: ModelConfig, with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config) + model = _initialize_model(vllm_config=vllm_config) model.load_weights( self._get_weights_iterator(local_model_path, gguf_weights_map)) return model diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index cbdacf779b089..0543ca978b7dd 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -624,8 +624,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Decoder output torch.Tensor """ # retrieve input_ids and inputs_embeds - - input_ids = input_ids.view(-1, input_ids.shape[-1]) inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions( diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 1f2d7384076ed..e612010677364 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -98,6 +98,11 @@ def input_processor_for_blip( if multi_modal_data is None or "image" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "image" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. 
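One more note on the BitsAndBytes loader hunk above: the stacked-parameter lookup now requires a "." immediately before the matched shard name, so 'kv_proj' no longer fires inside 'qkv_proj'. A self-contained illustration of the boundary check; the parameter names are made up for the example.

    for quant_param_name in (
            "vpm.encoder.layers.0.self_attn.qkv_proj.qweight",  # must NOT match 'kv_proj'
            "llm.layers.0.self_attn.kv_proj.qweight",           # legitimate 'kv_proj' weight
    ):
        shard_pos = quant_param_name.find("kv_proj")
        matched = shard_pos > 0 and quant_param_name[shard_pos - 1] == "."
        print(quant_param_name, "->", matched)
    # prints False for the qkv_proj name and True for the kv_proj name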
+ return inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -105,7 +110,7 @@ def input_processor_for_blip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -116,7 +121,8 @@ def input_processor_for_blip( # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index c3b3cc8a4ddb6..db1f92649bd49 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -9,13 +9,14 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData from .blip import (BlipVisionModel, dummy_image_for_blip, @@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + "image": + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_data_for_blip2(ctx: InputContext, seq_len: int, @@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, vision_config = hf_config.vision_config num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_blip2( + seq_data, ranges = dummy_seq_data_for_blip2( hf_config, seq_len, num_images, @@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, if isinstance(vision_config, Blip2VisionConfig): mm_data = dummy_image_for_blip(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index aaf559ca386cc..9f6c6786c0fa4 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -11,8 +11,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import SiluAndMul from 
vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -30,6 +30,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import print_warning_once @@ -73,7 +74,11 @@ def dummy_seq_data_for_chameleon( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + "image": + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_chameleon( @@ -97,14 +102,14 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_chameleon( + seq_data, ranges = dummy_seq_data_for_chameleon( seq_len, num_images, image_token_id=CHAMELEON_IMAGE_TOKEN_ID, ) mm_data = dummy_image_for_chameleon(num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) def input_processor_for_chameleon(ctx: InputContext, @@ -120,9 +125,14 @@ def input_processor_for_chameleon(ctx: InputContext, if multi_modal_data is None or "image" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "image" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return inputs + model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ca90d10e9f9fb..c3c9ec703c1e6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -14,8 +14,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -31,8 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, - MultiModalInputs) +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.multimodal.base import MultiModalData from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -117,16 +116,15 @@ def get_max_glmv_image_tokens(ctx: InputContext): raise NotImplementedError(msg) -def dummy_data_for_glmv( - ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] -) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: +def dummy_data_for_glmv(ctx: InputContext, 
seq_len: int, + mm_counts: Mapping[str, int]) -> DummyData: hf_config = ctx.get_hf_config(ChatGLMConfig) vision_config = getattr(hf_config, 'vision_config', None) if vision_config is None: token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len) seq_data = SequenceData(token_ids) - return seq_data, None + return DummyData(seq_data, None) elif isinstance(vision_config, dict): image_size = vision_config["image_size"] image_placeholder_length = calculate_image_placeholder(vision_config) @@ -141,7 +139,7 @@ def dummy_data_for_glmv( "image": Image.new("RGB", (image_size, image_size), color=0) } - return seq_data, mm_data + return DummyData(seq_data, mm_data) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index a3293020c042e..2d81b9266826b 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -49,14 +50,13 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: return get_clip_image_feature_size(hf_config) -def dummy_seq_data_for_clip( - hf_config: CLIPVisionConfig, - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): +def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, + mm_key: str = "image"): if image_feature_size_override is None: image_feature_size = get_clip_image_feature_size(hf_config) else: @@ -65,7 +65,11 @@ def dummy_seq_data_for_clip( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + mm_key: + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_clip( @@ -117,6 +121,11 @@ def input_processor_for_clip( if multi_modal_data is None or "image" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "image" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. 
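Alongside these early-return guards, the processors now report where the image placeholder tokens sit: repeat_and_pad_placeholder_tokens returns the ranges it created, and the dummy-data helpers build consecutive ranges up front. A rough sketch of what "consecutive placeholder ranges" means; the dict shape is an assumption for illustration, not necessarily vLLM's exact PlaceholderRange type.

    def consecutive_ranges(num_items: int, item_size: int):
        # e.g. two images of 576 placeholder tokens each, packed back to back
        # at the start of the dummy prompt
        return [{"offset": i * item_size, "length": item_size}
                for i in range(num_items)]

    print(consecutive_ranges(num_items=2, item_size=576))
    # [{'offset': 0, 'length': 576}, {'offset': 576, 'length': 576}]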
+ return inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -130,7 +139,7 @@ def input_processor_for_clip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -141,7 +150,8 @@ def input_processor_for_clip( # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 36c85e37783ab..c376347811965 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -401,8 +401,6 @@ class FalconForCausalLM(nn.Module, SupportsPP): ".dense_h_to_4h.", ".dense_4h_to_h.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."] def __init__( self, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 358d1dd288c49..0de590d1d8372 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -27,8 +27,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -37,9 +37,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) +from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings @@ -103,7 +105,11 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData(token_ids), { + "image": + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_fuyu( @@ -119,15 +125,15 @@ def dummy_image_for_fuyu( def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) + seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) mm_data = dummy_image_for_fuyu(num_images, image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT) - return seq_data, 
mm_data + return DummyData(seq_data, mm_data, ranges) def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, - data: Image.Image): + data: List[Image.Image]): image_encoding = image_processor.preprocess(data, return_tensors="pt") batch_images = torch.stack([img[0] for img in image_encoding["images"] ]).unsqueeze(1) @@ -158,8 +164,10 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): model_config = ctx.model_config image_data = multi_modal_data["image"] new_multi_modal_data = {} + image_list = image_data if isinstance(image_data, list) else [image_data] + # process image data - if isinstance(image_data, Image.Image): + if is_list_of(image_list, Image.Image): # Fuyu's image_processor can also finish token padding image_processor: FuyuImageProcessor = cached_get_image_processor( model_config.model) @@ -171,7 +179,7 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): ]) new_multi_modal_data["image"] = image_patches - elif isinstance(image_data, torch.Tensor): + elif is_list_of(image_list, torch.Tensor): raise NotImplementedError("Embeddings input is not supported yet") else: raise TypeError(f"Invalid image type: {type(image_data)}") @@ -198,12 +206,13 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): def input_mapper_for_fuyu(ctx: InputContext, data: object): model_config = ctx.model_config - if isinstance(data, Image.Image): + data_list = data if isinstance(data, list) else [data] + if is_list_of(data_list, Image.Image): # Fuyu's image_processor can also finish token padding image_processor: FuyuImageProcessor = cached_get_image_processor( model_config.model) - model_image_input = _fuyu_image_preprocess(image_processor, data) + model_image_input = _fuyu_image_preprocess(image_processor, data_list) data = torch.stack([ image_patch[0] for image_patch in model_image_input["image_patches"] diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 57b2b43c82f89..029178af61da0 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -350,7 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "gate_up_proj", "down_proj", ] - # BitandBytes specific attributes default_bitsandbytes_target_modules = [ ".gate_proj.", @@ -361,8 +360,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ".v_proj.", ".o_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 693f32160a289..9238ed839c9de 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -390,8 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ".v_proj.", ".o_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py new file mode 100644 index 0000000000000..43242fe370ba2 --- /dev/null +++ b/vllm/model_executor/models/h2ovl.py @@ -0,0 +1,401 @@ +# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py 
+# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- +from functools import partial +from typing import List, Optional, Tuple + +import torch +from PIL import Image +from transformers import PretrainedConfig + +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.utils import is_list_of + +from .intern_vit import InternVisionModel +from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel, + InternVLInputPipeline, build_transform, + find_closest_aspect_ratio, get_internvl_num_patches) + + +# modified to include blocks generated in second pass +def calculate_num_blocks( + orig_width: int, + orig_height: int, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, + prior_aspect_ratio=None, +) -> Tuple[int, int, int, Tuple[int, int]]: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for n in range(min_num, max_num + 1) + for i in range(1, n + 1) for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # if prior_aspect_ratio is provided, filter the target ratios + if prior_aspect_ratio is not None: + target_ratios = [ + ratio for ratio in target_ratios if prior_aspect_ratio[0] % + ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0 + ] + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, + target_ratios, orig_width, + orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # add thumbnail image if num_blocks > 1 + if use_thumbnail and blocks > 1: + blocks += 1 + return blocks, target_width, target_height, target_aspect_ratio + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +# refactored to handle prior_aspect_ratio as optional +def dynamic_preprocess( + image: Image.Image, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, + prior_aspect_ratio: Optional[Tuple[int, int]] = None, +) -> Tuple[List[Image.Image], Tuple[int, int]]: + orig_width, orig_height = image.size + + # calculate the number of blocks based on prior aspect ratio if available + blocks, target_width, target_height, target_aspect_ratio = ( + calculate_num_blocks( + orig_width, + orig_height, + min_num, + max_num, + image_size, + use_thumbnail=False, + prior_aspect_ratio=prior_aspect_ratio, + )) + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + 
processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images, target_aspect_ratio + + +def load_image( + image: Image.Image, + input_size=448, + min_num=1, + max_num=6, + use_thumbnail=True, + prior_aspect_ratio: Optional[Tuple[int, int]] = None, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + transform = build_transform(input_size=input_size) + images, target_aspect_ratio = dynamic_preprocess( + image, + image_size=input_size, + use_thumbnail=use_thumbnail, + min_num=min_num, + max_num=max_num, + prior_aspect_ratio=prior_aspect_ratio, + ) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values, target_aspect_ratio + + +# refactored to use the combined load_image function +def image_to_pixel_values( + image: Image.Image, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, + use_MSAC: bool, +) -> torch.Tensor: + # when MSAC is turned on, we need to process the image twice + if use_MSAC: + # first pass + pixel_values, target_aspect_ratio = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=True, + ) + # second pass + pixel_values2, _ = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + prior_aspect_ratio=target_aspect_ratio, + ) + # combine pixel values + pixel_values = torch.cat( + [pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0) + + else: + pixel_values, _ = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=use_thumbnail, + ) + + return pixel_values + + +def image_to_pixel_values_wrapper(hf_config: PretrainedConfig, + max_dynamic_patch: Optional[int] = None, + use_MSAC: Optional[bool] = None): + image_size = hf_config.vision_config.image_size + min_num = hf_config.min_dynamic_patch + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + if use_MSAC is None: + use_MSAC = hf_config.use_msac + use_thumbnail = hf_config.use_thumbnail + return partial( + image_to_pixel_values, + input_size=image_size, + min_num=min_num, + max_num=max_dynamic_patch, + use_thumbnail=use_thumbnail, + use_MSAC=use_MSAC, + ) + + +def get_max_internvl_image_tokens(ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None): + """ + Calculate the maximum number of tokens with/without MSAC and thumbnail + """ + hf_config = ctx.get_hf_config() + use_thumbnail = hf_config.use_thumbnail + use_MSAC = hf_config.use_msac + + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + + num_patches = get_internvl_num_patches(hf_config) + + coefficient = 2 if use_MSAC else 1 + num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0) + + return num_blocks * num_patches + + +class H2OVLInputPipeline(InternVLInputPipeline): + """ + Input pipeline for processing image and text data for the H2OVL model. 
+ """ + + def input_processor( + self, + ctx: InputContext, + inputs: DecoderOnlyInputs, + *, + max_dynamic_patch: Optional[int] = None, + ) -> DecoderOnlyInputs: + # get multi_modal_data + multi_modal_data = inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config() + use_MSAC = hf_config.use_msac + + image_data = multi_modal_data["image"] + num_patches = get_internvl_num_patches(hf_config) + + image_pixel_values_mapper = image_to_pixel_values_wrapper( + hf_config, max_dynamic_patch=max_dynamic_patch) + + # single image + if isinstance(image_data, Image.Image): + pixel_values = image_pixel_values_mapper(image_data, + use_MSAC=use_MSAC) + num_blocks = pixel_values.shape[0] + image_feature_sizes = [num_blocks * num_patches] + pixel_values = pixel_values.unsqueeze(0) + + # multi images + elif is_list_of(image_data, Image.Image): + # Do not use MSAC for multi images + image_feature_sizes = [] + pixel_values = [ + image_pixel_values_mapper(image, use_MSAC=False) + for image in image_data + ] + for pixel_value in pixel_values: + num_blocks = pixel_value.shape[0] + image_feature_sizes.append(num_blocks * num_patches) + + # image embeddings as input + elif isinstance(image_data, torch.Tensor): + _, image_feature_size, _ = image_data.shape + image_feature_sizes = [image_feature_size] + pixel_values = None + + # multi-image image embeddings + elif is_list_of(image_data, torch.Tensor): + + image_feature_sizes = [] + for image_embed in image_data: + _, image_feature_size, _ = image_embed.shape + image_feature_sizes.append(image_feature_size) + pixel_values = None + + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) + + new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, + num_patches) + new_prompt_token_ids = tokenizer.encode(new_prompt) + + # Wrap image processing in input_processor to avoid duplication + image_token_id = tokenizer.encode( + self.img_context_token, + add_special_tokens=False, + return_tensors="pt", + )[0] + + # Update multi_modal_data to return + if pixel_values is not None: + multi_modal_data = { + "image": { + "pixel_values": pixel_values, + "image_token_id": image_token_id, + } + } + else: + multi_modal_data = {"image": {"image_embeds": image_data}} + + return token_inputs( + prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + + def input_mapper( + self, + ctx: InputContext, + data: object, + *, + max_dynamic_patch: Optional[int] = None, + ) -> MultiModalInputs: + + # NOTE: Preprocessing for the image data is done in the + # 'input_processor' function during actual inference. + if isinstance(data, dict): + return MultiModalInputs(data) + + # The section below is only used with dummy data during + # memory profiling. 
+ hf_config = ctx.get_hf_config() + + image_pixel_values_mapper = image_to_pixel_values_wrapper( + hf_config, max_dynamic_patch) + + if isinstance(data, Image.Image): + pixel_values = image_pixel_values_mapper(data) + pixel_values = pixel_values.unsqueeze(0) + + elif is_list_of(data, Image.Image): + hf_config.use_msac = False + pixel_values = [image_pixel_values_mapper(img) for img in data] + + else: + return MultiModalInputs({"image_embeds": data}) + model_config = ctx.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + image_token_id = tokenizer.encode( + self.img_context_token, + add_special_tokens=False, + return_tensors="pt", + )[0] + + return MultiModalInputs({ + "pixel_values": pixel_values, + "image_token_id": image_token_id + }) + + +input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) +@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data) +@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) +class H2OVLChatModel(InternVLChatModel): + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = (config.vision_config.num_hidden_layers + + vision_feature_layer + 1) + else: + num_hidden_layers = vision_feature_layer + 1 + + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + else: + msg = "Monolith mode is not applicable to H2OVL" + raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 1c1fde5b30983..d2ec0ff6e74c6 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,8 +17,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.quantization import (AWQConfig, QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -379,7 +379,7 @@ def dummy_data( model_config.tokenizer, trust_remote_code=model_config.trust_remote_code) - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( hf_config.vision_config, seq_len, num_images, @@ -398,7 +398,7 @@ def dummy_data( image_height_override=max_image_height, ) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index fddd39fb8c85b..6f7949c880e61 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -12,26 +12,19 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - 
QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) @@ -41,179 +34,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] -# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer -class JambaMambaMixer(nn.Module): - """ - Compute ∆, A, B, C, and D the state space parameters and compute - the `contextualized_states`. A, D are input independent - (see Mamba paper [1] Section 3.5.2 "Interpretation of A" - for why A isn't selective) ∆, B, C are input-dependent - (this is a key difference between Mamba and the linear time - invariant S4, and is why Mamba is called - **selective** state spaces) - """ - - def __init__(self, config: JambaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.ssm_state_size = config.mamba_d_state - self.conv_kernel_size = config.mamba_d_conv - self.intermediate_size = config.mamba_expand * config.hidden_size - self.time_step_rank = config.mamba_dt_rank - self.use_conv_bias = config.mamba_conv_bias - self.use_bias = config.mamba_proj_bias - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.intermediate_size, - bias=self.use_conv_bias, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear(self.hidden_size, - [self.intermediate_size] * 2, - bias=self.use_bias) - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.intermediate_size, - self.time_step_rank + self.ssm_state_size * 2, - bias=False, - ) - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. 
- self.dt_proj = ColumnParallelLinear(self.time_step_rank, - self.intermediate_size, - bias=True, - skip_bias_add=True) - - tp_size = get_tensor_model_parallel_world_size() - self.A = nn.Parameter( - torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - - set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) - a_weight_loader = composed_weight_loader( - sharded_weight_loader(0), lambda x: -torch.exp(x.float())) - set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) - - self.out_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=self.use_bias, - input_is_parallel=True, - ) - self.activation = config.hidden_act - - self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=config.rms_norm_eps) - self.b_layernorm = RMSNorm(self.ssm_state_size, - eps=config.rms_norm_eps) - self.c_layernorm = RMSNorm(self.ssm_state_size, - eps=config.rms_norm_eps) - - def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams): - - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) - - # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), - mamba_cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. 
input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - mamba_cache_params.ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = selective_state_update( - mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor) - scan_outputs = scan_outputs.transpose(0, 1) - - # 4. Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] - return contextualized_states - - class JambaMoE(nn.Module): def __init__(self, @@ -284,9 +104,18 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() - self.layer_idx = layer_idx self.config = config - self.mamba = JambaMambaMixer(config) + self.mamba = MambaMixer(hidden_size= config.hidden_size, + ssm_state_size = config.mamba_d_state, + conv_kernel_size = config.mamba_d_conv, + intermediate_size = config.mamba_expand *\ + config.hidden_size, + time_step_rank = config.mamba_dt_rank, + use_conv_bias = config.mamba_conv_bias, + use_bias = config.mamba_proj_bias, + use_rms_norm=True, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act) num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8a9e5203972be..38a31f420cec9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -464,8 +464,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ".v_proj.", ".o_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 27055e7ced865..7fbd59ebd98fd 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -10,7 +10,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext) from vllm.model_executor.layers.activation import get_act_fn from 
vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -111,7 +112,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, image_feature_size = get_max_llava_image_tokens(ctx) if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_images, @@ -120,9 +121,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_clip(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -131,9 +132,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, PixtralVisionConfig): - seq_data = dummy_seq_data_for_pixtral_hf( + seq_data, ranges = dummy_seq_data_for_pixtral_hf( vision_config, seq_len, num_images, @@ -142,7 +143,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index e8540d85ff565..7a2c95594ddcd 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext) from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -180,7 +181,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, max_feat_height, max_feat_width = pinpoint if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_images, @@ -195,9 +196,9 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_height_override=max_feat_height, ) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -212,7 +213,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_height_override=max_feat_height, ) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -605,7 +606,6 @@ def forward( :class:`LlavaNextImageInputs` """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: image_input = self._parse_and_validate_image_input(**kwargs) @@ -617,9 +617,14 @@ def forward( self.language_model.model.get_input_embeddings, lambda _: self._process_image_input(image_input), 
) - input_ids = None else: - inputs_embeds = None + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index b8051d5fc6ae2..b755e2347f6ed 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -11,8 +11,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -108,33 +108,35 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, video_feature_size = frames_per_video * tokens_per_frame if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, + mm_key="video", ) pil_frame = dummy_image_for_clip(vision_config, num_images=1) np_frame = np.array(pil_frame["image"]) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data = {"video": mm_data_per_video} - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, + mm_key="video", ) pil_frame = dummy_image_for_siglip(vision_config, num_images=1) np_frame = np.array(pil_frame["image"]) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data = {"video": mm_data_per_video} - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -145,6 +147,12 @@ def input_processor_for_llava_next_video(ctx: InputContext, multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: return inputs + + if "multi_modal_placeholders" in inputs and "video" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. 
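+    # (They were attached by an earlier processing pass, so re-running the
+    # placeholder expansion below would duplicate the video tokens.)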
+ return inputs + video_data = multi_modal_data["video"] model_config = ctx.model_config @@ -160,7 +168,7 @@ def input_processor_for_llava_next_video(ctx: InputContext, tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -170,7 +178,8 @@ def input_processor_for_llava_next_video(ctx: InputContext, return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"video": ranges}) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index a0cf208a65f36..f410d64577a77 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -15,8 +15,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -218,31 +218,31 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, - ) + mm_key="video") mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames, num_videos=num_videos) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, - ) + mm_key="video") mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames, num_videos=num_videos) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -320,7 +320,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -330,7 +330,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"video": ranges}) elif is_list_of(video_data, np.ndarray): video_feature_size = [] diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 
9f4f391a6682e..ec726dc4ff4fa 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -10,27 +10,19 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (HasInnerState, IsAttentionFree) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) @@ -38,194 +30,27 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] -# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer -class MambaMixer(nn.Module): - """ - Compute ∆, A, B, C, and D the state space parameters and compute - the `contextualized_states`. A, D are input independent - (see Mamba paper [1] Section 3.5.2 "Interpretation of A" - for why A isn't selective) ∆, B, C are input-dependent - (this is a key difference between Mamba and the linear time - invariant S4, and is why Mamba is called - **selective** state spaces) - """ - - def __init__(self, config: MambaConfig, layer_idx): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.intermediate_size = config.intermediate_size - self.time_step_rank = int(config.time_step_rank) - self.is_falcon_mamba = config.model_type == "falcon_mamba" - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.intermediate_size, - bias=config.use_conv_bias, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. 
- # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear(self.hidden_size, - [self.intermediate_size] * 2, - bias=config.use_bias) - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.intermediate_size, - self.time_step_rank + self.ssm_state_size * 2, - bias=False, - ) - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. - self.dt_proj = ColumnParallelLinear(self.time_step_rank, - self.intermediate_size, - bias=True, - skip_bias_add=True) - - tp_size = get_tensor_model_parallel_world_size() - self.A = nn.Parameter( - torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - - set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) - a_weight_loader = composed_weight_loader( - sharded_weight_loader(0), lambda x: -torch.exp(x.float())) - set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) - - self.out_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=config.use_bias, - input_is_parallel=True, - ) - self.activation = config.hidden_act - if self.is_falcon_mamba: - self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=config.mixer_rms_eps) - self.b_layernorm = RMSNorm(self.ssm_state_size, - eps=config.mixer_rms_eps) - self.c_layernorm = RMSNorm(self.ssm_state_size, - eps=config.mixer_rms_eps) - - def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - mamba_cache_params: MambaCacheParams): - - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) - hidden_states, gate = projected_states.chunk(2, dim=-2) - - # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=mamba_cache_params.conv_state, - has_initial_state=attn_metadata.context_lens_tensor > 0, - cache_indices=mamba_cache_params.state_indices_tensor, - query_start_loc=attn_metadata.query_start_loc) - else: - hidden_states = causal_conv1d_update( - hidden_states.transpose(0, 1), - mamba_cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor) - hidden_states = hidden_states.transpose(0, 1) - - # 3. State Space Model sequence transformation - # 3.a. 
input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) - # Note that Jamba and FalconMamba normalizes B, C, and time_step here - # but Mamba doesn't. - if self.is_falcon_mamba: - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - - if attn_metadata.query_start_loc is not None \ - and attn_metadata.context_lens_tensor is not None: - scan_outputs = selective_scan_fn( - hidden_states, - mamba_cache_params.ssm_state, - discrete_time_step, - self.A, - B.transpose(-2, -1), - C.transpose(-2, -1), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - cache_indices=mamba_cache_params.state_indices_tensor, - has_initial_state=attn_metadata.context_lens_tensor > 0, - query_start_loc=attn_metadata.query_start_loc) - else: - scan_outputs = selective_state_update( - mamba_cache_params.ssm_state, - hidden_states.transpose(0, 1), - discrete_time_step.transpose(0, 1), - self.A, - B, - C, - self.D, - gate.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor) - scan_outputs = scan_outputs.transpose(0, 1) - - # 4. Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(-2, - -1))[0] - return contextualized_states - - class MambaDecoderLayer(nn.Module): def __init__(self, config: MambaConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() - self.layer_idx = layer_idx self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" - self.mixer = MambaMixer(config, layer_idx) + mixer_rms_rps = config.mixer_rms_rps if self.is_falcon_mamba else None + self.mamba = MambaMixer(hidden_size=config.hidden_size, + ssm_state_size=config.state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=config.intermediate_size, + time_step_rank=config.time_step_rank, + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + use_rms_norm=self.is_falcon_mamba, + rms_norm_eps=mixer_rms_rps, + activation=config.hidden_act) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 4917c33136069..f90df6b7df036 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -36,8 +36,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, @@ -131,16 +131,22 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): class Resampler2_5(BaseResampler): - def __init__( - self, - num_queries: int, - 
embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - max_size: Tuple[int, int] = (70, 70), - ) -> None: - super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer) + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + quant_config=quant_config, + prefix=prefix) self.max_size = max_size self._set_2d_pos_cache(self.max_size) @@ -277,7 +283,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data) def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): @@ -404,7 +410,10 @@ def __init__( self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else self.vpm.embeddings.embed_dim) self.embed_dim = self.config.hidden_size - self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) + self.resampler = self.init_resampler(self.embed_dim, + self.vision_dim, + quant_config=quant_config, + prefix="resampler") self.resampler.to(device="cuda", dtype=param_dtype) # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm self.lm_head = ParallelLMHead(config.vocab_size, @@ -564,8 +573,13 @@ def forward( vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None + output = self.llm( - input_ids=None, + input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, @@ -661,7 +675,11 @@ def init_vision_module( ) -> nn.Module: raise NotImplementedError - def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: raise NotImplementedError def get_vision_embedding( @@ -738,16 +756,21 @@ def init_vision_module( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_tokens(input_ids) - def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: with set_default_torch_dtype(torch.float16): - resampler = Resampler2( - embed_dim=embed_dim, - num_heads=embed_dim // 128, - grid_size=int(math.sqrt(self.config.query_num)), - kv_dim=vision_dim, - adaptive=False, - do_post_projection=True, - ) + resampler = Resampler2(embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int( + math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=False, + do_post_projection=True, + quant_config=quant_config, + prefix=prefix) return resampler @@ -820,9 +843,17 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ".k_proj.", ".v_proj.", ".o_proj.", + # vision encoder + ".fc1.", + ".fc2.", + # Currently, vllm does not support BNB quantization for the `out_proj` + # of the resampler, so it's necessary to distinguish between the 
+ # vision encoder and the resampler's out_proj. The same applies to + # MiniCPMV2_6. + ".self_attn.out_proj.", # vision encoder out_proj + # resampler + ".kv_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -872,14 +903,18 @@ def init_vision_module( model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: with set_default_torch_dtype(torch.float16): - resampler = Resampler2_5( - num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - ) + resampler = Resampler2_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) return resampler def get_vision_embedding( @@ -962,9 +997,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ".k_proj.", ".v_proj.", ".o_proj.", + # vision encoder + ".fc1.", + ".fc2.", + ".self_attn.out_proj.", + # resampler + ".kv_proj.", ] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -1014,15 +1053,19 @@ def init_vision_module( model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: with set_default_torch_dtype(torch.float16): # The resampler in 2.6 remains consistent with the one in 2.5. 
- resampler = Resampler2_5( - num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - ) + resampler = Resampler2_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) return resampler def get_vision_embedding( diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 5cf5272cae878..251bfc079684e 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -36,8 +36,8 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, - EncoderDecoderInputs, InputContext) +from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, + InputContext, TokenInputs, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -52,6 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import SequenceData +from vllm.utils import is_list_of from .clip import CLIPMLP from .interfaces import SupportsMultiModal @@ -86,41 +87,58 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: return num_images -def input_processor_for_mllama(ctx: InputContext, - inputs: Union[DecoderOnlyInputs, - EncoderDecoderInputs]): - # move encoder_prompt to prompt - if inputs.get("prompt") is None: - inputs["prompt"] = inputs["encoder_prompt"] - inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"] +def input_processor_for_mllama( + ctx: InputContext, + inputs: EncoderDecoderInputs, +) -> EncoderDecoderInputs: + # Example input to processor: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000], + # }, + # } + + # move encoder prompt to decoder + dec_inputs = TokenInputs(**inputs["encoder"]) + + multi_modal_data = dec_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + # text-only + return EncoderDecoderInputs( + encoder=token_inputs([]), + decoder=dec_inputs, + ) - # process multi-modal data - multi_modal_data = inputs.get("encoder_multi_modal_data") + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_data = [image_data] - if multi_modal_data is None or "image" not in multi_modal_data \ - or multi_modal_data["image"] is None: - # text-only - inputs["encoder_prompt"] = "" - inputs["encoder_prompt_token_ids"] = [] - inputs["encoder_multi_modal_data"] = {} - return inputs + assert is_list_of(image_data, Image.Image) - if isinstance(multi_modal_data['image'], Image.Image): - multi_modal_data['image'] = [multi_modal_data['image']] # Since only the last group of consecutive images # are attended by the decoded tokens, we only need to # get the number of tiles for those images. 
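    # (Worked example, assuming the Llama-3.2 defaults of image_size=560 and
    # patch size 14: token_per_chunk = (560 // 14) ** 2 + 1 = 1601, so an image
    # tiled onto a 2x2 canvas reserves 4 * 1601 = 6404 encoder slots below.)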
num_decode_images = _get_num_image_in_last_group( - inputs["prompt_token_ids"]) + dec_inputs["prompt_token_ids"]) + hf_config = ctx.model_config.hf_config + vision_config = hf_config.vision_config + num_tiles = 0 - for image in multi_modal_data["image"][::-1]: + for image in image_data[::-1]: width, height = image.size - tile_size = hf_config.vision_config.image_size + tile_size = vision_config.image_size canvas_height, canvas_width = get_optimal_tiled_canvas( image_height=height, image_width=width, - max_image_tiles=hf_config.vision_config.max_num_tiles, + max_image_tiles=vision_config.max_num_tiles, tile_size=tile_size, ) num_tiles_height = canvas_height // tile_size @@ -133,14 +151,34 @@ def input_processor_for_mllama(ctx: InputContext, # Set encoder prompt length based on the number of tiles. # This tells the block manager to allocate correct number # of slots for encoder tokens. - assert hf_config.vision_config.image_size % 14 == 0, \ + assert vision_config.image_size % 14 == 0, \ "chunk size should be multiple of 14" - token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + token_per_chunk = (vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk - inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens - inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens - return inputs + # Example output from processor: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128256, 128256, ..., 128256], + # 'prompt': '<|image|><|image|>...<|image|>', + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # } + return EncoderDecoderInputs( + encoder=token_inputs( + prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens, + prompt=MLLAMA_IMAGE_TOKEN * num_tokens, + multi_modal_data=multi_modal_data, + ), + decoder=dec_inputs, + ) def get_max_mllama_image_tokens(ctx: InputContext) -> int: @@ -176,13 +214,14 @@ def dummy_image(num_images: int, ): def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - return dummy_decoder_seq_data(seq_len, num_images), None + return DummyData(dummy_decoder_seq_data(seq_len, num_images)) def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) + return DummyData(dummy_encoder_seq_data(ctx, num_images), + dummy_image(num_images)) def _prepare_aspect_ratio_attention_mask( @@ -1055,9 +1094,12 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ".k_proj.", ".v_proj.", ".o_proj.", + ".fc1.", + ".fc2.", + # The `multi_modal_projector` is at the top level of the model, + # so we can't add a dot in front of it. + "multi_modal_projector." 
] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 3c34227767e05..ba798833e26a9 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -15,6 +15,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.selector import _Backend +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -713,6 +714,7 @@ def forward( return image_features +@support_torch_compile class MolmoModel(nn.Module): def __init__( @@ -1141,7 +1143,6 @@ def forward( **kwargs: object, ) -> SamplerOutput: if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: image_input = self._parse_and_validate_image_input(**kwargs) @@ -1156,10 +1157,13 @@ def forward( image_input["image_input_idx"], image_input["seq_len"], ) - - input_ids = None else: - inputs_embeds = None + inputs_embeds = self.model.embed_tokens(input_ids) + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 10cca8b56268a..7521ab749e10f 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -343,8 +343,6 @@ class OPTForCausalLM(nn.Module, SupportsPP): default_bitsandbytes_target_modules = [ ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2." 
] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".out_proj.", ".fc2."] def __init__( self, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8e29c6079b994..4b6061e113cb2 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -7,8 +7,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -58,7 +58,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, vision_config = hf_config.vision_config num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -66,7 +66,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) def input_processor_for_paligemma(ctx: InputContext, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 497eae4e8905b..4e7935a7636c5 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -274,8 +274,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): default_bitsandbytes_target_modules = [ ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense." 
] - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".fc2.", ".dense."] embedding_modules = {} embedding_padding_modules = [] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4928e447d5b9e..5b477a8ed5f49 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -28,8 +28,8 @@ from vllm.attention import AttentionMetadata from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig, PoolerConfig) -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig @@ -380,7 +380,7 @@ def dummy_data_for_phi3v(ctx: InputContext, image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, seq_len, num_images, @@ -394,7 +394,7 @@ def dummy_data_for_phi3v(ctx: InputContext, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) @lru_cache diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 6b53bf5660096..051454c49bff8 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -17,8 +17,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig @@ -28,7 +28,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges) from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_list_of @@ -81,7 +82,12 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, ) mm_data = {"image": num_images * [image]} - return seq_data, mm_data + mm_placeholders = { + "image": + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } + return DummyData(seq_data, mm_data, mm_placeholders) def input_mapper_for_pixtral(ctx: InputContext, @@ -630,13 +636,13 @@ def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int: def dummy_seq_data_for_pixtral_hf( - hf_config: PixtralVisionConfig, - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): + hf_config: PixtralVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, + mm_key: str = "image"): if image_feature_size_override is None: image_feature_size = 
get_max_pixtral_hf_image_feature_size(hf_config) else: @@ -645,7 +651,11 @@ def dummy_seq_data_for_pixtral_hf( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + mm_key: + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_pixtral_hf( diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 61665768eacf5..b2b5c70182135 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -23,8 +23,8 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -810,7 +810,7 @@ def dummy_data_for_qwen( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], -) -> Tuple[SequenceData, Optional[Dict]]: +) -> DummyData: """Build dummy data for warming up Qwen models; this will only contain text matching the defaults for VLLM unless the model has a visual config. @@ -829,7 +829,7 @@ def dummy_data_for_qwen( if not hasattr(hf_config, "visual"): seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) mm_data = None - return seq_data, mm_data + return DummyData(seq_data, mm_data) # We have a visual component - use images to warm up num_images = mm_counts["image"] @@ -861,7 +861,7 @@ def dummy_data_for_qwen( # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) mm_data = {"image": image if num_images == 1 else [image] * num_images} - return seq_data, mm_data + return DummyData(seq_data, mm_data) class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index db7556b3b5f4b..72b286fe6f6d6 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -395,9 +395,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ".v_proj.", ".o_proj.", ] - - # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 3d049eeb920b7..6114548bda42c 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -31,8 +31,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -44,6 +44,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model 
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal, SupportsPP @@ -85,7 +86,8 @@ def forward(self, audio_features): def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_audios = mm_counts["audio"] - max_llm_audio_tokens = get_max_qwen2_audio_audio_tokens(ctx) * num_audios + max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx) + max_llm_audio_tokens = max_tokens_per_audio * num_audios if seq_len - max_llm_audio_tokens - 2 < 0: raise RuntimeError( f"Qwen2-Audio cannot process {num_audios} audios in a prompt, " @@ -99,7 +101,12 @@ def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int, (0, seq_len - max_llm_audio_tokens), ) dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.) - return dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios} + return DummyData( + dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, { + "audio": + consecutive_placeholder_ranges(num_items=num_audios, + item_size=max_tokens_per_audio) + }) def get_processor( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1e12c2332b65e..d801903f8f9fe 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -44,8 +44,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -744,9 +744,10 @@ def dummy_data_for_qwen2_vl( dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), color=0) - return dummy_seqdata, { - "image": dummy_image if num_images == 1 else [dummy_image] * num_images - } + return DummyData(dummy_seqdata, { + "image": + dummy_image if num_images == 1 else [dummy_image] * num_images + }) def _get_llm_num_vision_tokens( diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f50ceaccb1bbe..af52fbffba19e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -128,6 +128,7 @@ def add_embedding_models(base_models, embedding_models): "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 @@ -342,6 +343,11 @@ def register_model( def _raise_for_unsupported(self, architectures: List[str]): all_supported_archs = self.get_supported_archs() + if any(arch in all_supported_archs for arch in architectures): + raise ValueError( + f"Model architectures {architectures} failed " + "to be inspected. 
Please check the logs for more details.") + raise ValueError( f"Model architectures {architectures} are not supported for now. " f"Supported architectures: {all_supported_archs}") @@ -482,4 +488,4 @@ def _run() -> None: if __name__ == "__main__": - _run() + _run() \ No newline at end of file diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 2e7ae32055aaf..acaf4afdecfe5 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -23,6 +23,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -61,6 +62,7 @@ def dummy_seq_data_for_siglip( *, image_token_id: int, image_feature_size_override: Optional[int] = None, + mm_key: str = "image", ): if image_feature_size_override is None: image_feature_size = get_siglip_image_feature_size(hf_config) @@ -70,7 +72,11 @@ def dummy_seq_data_for_siglip( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + mm_key: + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_siglip( @@ -122,6 +128,11 @@ def input_processor_for_siglip( if multi_modal_data is None or "image" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "image" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -135,7 +146,7 @@ def input_processor_for_siglip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -144,11 +155,10 @@ def input_processor_for_siglip( ) # NOTE: Create a defensive copy of the original inputs - return token_inputs( - prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - ) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index f08e4aa355086..749750fc9c16e 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,7 +2,6 @@ """PyTorch Ultravox model.""" import math -from array import array from functools import cached_property, lru_cache from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union, cast) @@ -17,27 +16,27 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import DecoderOnlyInputs, token_inputs -from vllm.inputs.registry import InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from 
vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs, NestedTensors +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, + NestedTensors) from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, merge_multimodal_embeddings) + init_vllm_registered_model, + merge_multimodal_embeddings_from_map) _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -46,13 +45,13 @@ class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: NestedTensors - """Shape: `(batch_size, num_audios, 80, M)""" + """Shape: `(batch_size, num_audios, 80, M)`""" class UltravoxAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`""" UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, @@ -79,17 +78,16 @@ def dummy_seq_data_for_ultravox( seq_len: int, audio_count: int, ): - audio_placeholder = array( - VLLM_TOKEN_ID_ARRAY_TYPE, - [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx) + audio_length = min(get_ultravox_max_audio_tokens(ctx), + seq_len // audio_count) - # Add a separator between each chunk. - audio_token_ids = (audio_placeholder + - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count - other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - len(audio_token_ids)) - - return SequenceData(audio_token_ids + other_token_ids) + return SequenceData.from_prompt_token_counts( + (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count), + (0, seq_len - audio_length * audio_count)), { + "audio": + consecutive_placeholder_ranges(num_items=audio_count, + item_size=audio_length) + } def dummy_audio_for_ultravox( @@ -107,10 +105,10 @@ def dummy_data_for_ultravox( mm_counts: Mapping[str, int], ): audio_count = mm_counts["audio"] - seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) + seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) mm_dict = dummy_audio_for_ultravox(ctx, audio_count) - return (seq_data, mm_dict) + return DummyData(seq_data, mm_dict, ranges) def input_mapper_for_ultravox(ctx: InputContext, data: object): @@ -164,6 +162,11 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): if multi_modal_data is None or "audio" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "audio" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. 
+ return inputs + feature_extractor = whisper_feature_extractor(ctx) audios = multi_modal_data["audio"] if not isinstance(audios, list): @@ -197,7 +200,7 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -208,7 +211,8 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"audio": ranges}) class StackAudioFrames(nn.Module): @@ -472,9 +476,9 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, audio_embeddings, - _AUDIO_PLACEHOLDER_TOKEN) + merge_multimodal_embeddings_from_map( + inputs_embeds, audio_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["audio"]) input_ids = None else: inputs_embeds = None diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 0aecb5d151a45..fee97e8922a76 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -18,7 +18,7 @@ from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import ModelRegistry -from vllm.multimodal.base import NestedTensors +from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -252,6 +252,7 @@ def init_vllm_registered_model( return build_model( model_class, + None, hf_config, cache_config, quant_config, @@ -326,6 +327,22 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: _embedding_count_expression(inner) for inner in embeddings) +def merge_multimodal_embeddings_from_map( + inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, + placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided + placeholder map . + + Note: + This updates ``inputs_embeds`` in place. 
+ """ + flattened_embeddings = _flatten_embeddings(multimodal_embeddings) + inputs_embeds[placeholder_map.dest] = flattened_embeddings[ + placeholder_map.src] + return inputs_embeds + + def _merge_multimodal_embeddings( inputs_embeds: torch.Tensor, is_multimodal: torch.Tensor, diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 489e1e51f05cb..53da2badb9b98 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,6 +1,7 @@ from .base import (BatchedTensorInputs, MultiModalDataBuiltins, - MultiModalDataDict, MultiModalInputs, MultiModalPlugin, - NestedTensors) + MultiModalDataDict, MultiModalInputs, + MultiModalPlaceholderDict, MultiModalPlaceholderMap, + MultiModalPlugin, NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -17,6 +18,8 @@ "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalInputs", + "MultiModalPlaceholderDict", + "MultiModalPlaceholderMap", "MultiModalPlugin", "NestedTensors", "MULTIMODAL_REGISTRY", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 84e71cbf60df7..6b10d0c609f13 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,8 +1,9 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, - TypedDict, TypeVar, Union, cast, final) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, + NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar, + Union, cast, final) import numpy as np import torch @@ -11,12 +12,15 @@ from torch import nn from typing_extensions import TypeAlias -from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, json_map_leaves, resolve_mm_processor_kwargs) +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.sequence import SequenceGroupMetadata + logger = init_logger(__name__) NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] @@ -151,6 +155,30 @@ class MultiModalDataBuiltins(TypedDict, total=False): Read more on that :ref:`here `. """ + +class PlaceholderRange(TypedDict): + """ + Placeholder location information for multi-modal data. + + For example: + Prompt: AAAA BBBB What is in these images? + Images A and B will have: + A: { "offset": 0, "length": 4 } + B: { "offset": 5, "length": 4 } + """ + + offset: int + """The start index of the placeholder in the prompt.""" + + length: int + """The length of the placeholder.""" + + +MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]] +""" +A dictionary containing placeholder ranges. +""" + MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], MultiModalInputs] """ @@ -243,7 +271,7 @@ def wrapper(model_cls: N) -> N: return wrapper - def map_input(self, model_config: ModelConfig, + def map_input(self, model_config: "ModelConfig", data: MultiModalData[object], mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: """ @@ -332,7 +360,7 @@ def wrapper(model_cls: N) -> N: return wrapper - def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens for profiling the memory usage of a model. 
@@ -366,3 +394,179 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: self._validate_max_multimodal_tokens(max_mm_tokens) return max_mm_tokens + + +class MultiModalPlaceholderMap: + """ + Relates multi-modal embeddings to their corresponding placeholders. + """ + + class IndexMap(NamedTuple): + src: List[int] + dest: List[int] + + src_ranges: List[range] + """ + The indices of the multi-modal embeddings that will replace the + corresponding placeholder embeddings pointed to by ``dest_ranges``. + """ + + src_len: int + """ + The total number of flattened multi-modal embeddings. + """ + + dest_ranges: List[range] + """ + The indices of the placeholder embeddings that will be replaced by the + multimodal embeddings. + """ + + dest_len: int + """ + The total number of embeddings in the destination tensor. + """ + + def __init__(self): + self.src_ranges = [] + self.src_len = 0 + self.dest_ranges = [] + self.dest_len = 0 + + @classmethod + def from_seq_group( + cls, seq_group: "SequenceGroupMetadata", positions: range + ) -> Tuple[Optional[MultiModalDataDict], Dict[str, + "MultiModalPlaceholderMap"]]: + """ + Returns the multi-modal items that intersect with the portion of a + prompt (``seq_group``) represented by ``positions``, as well as a + ``MultiModalPlaceholderMap`` that relates the multi-modal embedding + vectors to their corresponding placeholders. + + Consider the following scenarios: + + Prompt: |AAAA BBBB What's in these images?| + Positions: |.................................| + + images = [A, B] + src_ranges = [(0, 4), (4, 8)] + dest_ranges = [(0, 4), (5, 9)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ..... | + + images = [A, B] + src_ranges = [(2, 4), (4, 6)] + dest_ranges = [(0, 2), (3, 5)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ......... | + + images = [B] + src_ranges = [(0, 4)] + dest_ranges = [(0, 4)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | .......................| + + images = [] + src_ranges = [] + dest_ranges = [] + """ + if (not seq_group.multi_modal_data + or not seq_group.multi_modal_placeholders): + return seq_group.multi_modal_data, {} + + mm_data = {**seq_group.multi_modal_data} + placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict( + MultiModalPlaceholderMap) + + for modality, placeholders in seq_group.multi_modal_placeholders.items( + ): + mm_items = mm_data.pop(modality) + if not isinstance(mm_items, list): + mm_items = [mm_items] + + if positions: + intersecting_items = placeholder_maps[ + modality].append_items_from_seq_group( + positions, mm_items, placeholders) + + if intersecting_items: + mm_data[modality] = intersecting_items + + return mm_data, placeholder_maps + + def append_items_from_seq_group( + self, positions: range, multi_modal_items: List[_T], + multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]: + """ + Adds the multi-modal items that intersect ```positions`` to this + placeholder map and returns the intersecting items. + """ + intersecting_items = [] + + if len(multi_modal_items) != len(multi_modal_placeholders): + raise ValueError( + "Multi-modal placeholders and items must have the same length." 
+ ) + for placeholder_dict, mm_item in zip(multi_modal_placeholders, + multi_modal_items): + placeholder = range( + placeholder_dict["offset"], + placeholder_dict["offset"] + placeholder_dict["length"]) + intersection = range(max(positions.start, placeholder.start), + min(positions.stop, placeholder.stop)) + + if not intersection: + # Skip this multi-modal item. + continue + + token_embedding_range = range(intersection.start - positions.start, + intersection.stop - positions.start) + + multimodal_embedding_range = range( + intersection.start - placeholder.start + self.src_len, + intersection.stop - placeholder.start + self.src_len) + + intersecting_items.append(mm_item) + self.dest_ranges.append(token_embedding_range) + self.src_ranges.append(multimodal_embedding_range) + self.src_len += len(placeholder) + + self.dest_len += len(positions) + return intersecting_items + + def extend(self, other: "MultiModalPlaceholderMap"): + """ + Adds the placeholders from another ``MultiModalPlaceholderMap`` to this + instance based on the source and destination tensors being + concatenated. + """ + + self.src_ranges.extend( + range(self.src_len + r.start, self.src_len + r.stop) + for r in other.src_ranges) + self.src_len += other.src_len + self.dest_ranges.extend( + range(self.dest_len + r.start, self.dest_len + r.stop) + for r in other.dest_ranges) + self.dest_len += other.dest_len + + def index_map(self) -> "IndexMap": + """ + Finalizes the placeholder map into lists of indices that can be used to + index the source and destination tensors. + """ + + src_indices = [i for r in self.src_ranges for i in r] + dest_indices = [i for r in self.dest_ranges for i in r] + + if len(src_indices) != len(dest_indices): + raise ValueError( + f"The number of source ({len(src_indices)}) and destination " + f"indices ({len(dest_indices)}) must be the same.") + + return MultiModalPlaceholderMap.IndexMap(src=src_indices, + dest=dest_indices) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 5f74bcea65ce2..3f6bb6c8338d2 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,11 +1,10 @@ from functools import lru_cache -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional import torch from PIL import Image from transformers.image_processing_base import BatchFeature -from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.processor import get_image_processor @@ -13,6 +12,9 @@ from .base import MultiModalData, MultiModalInputs, MultiModalPlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) cached_get_image_processor = lru_cache(get_image_processor) @@ -26,7 +28,7 @@ def get_data_key(self) -> str: def _get_hf_image_processor( self, - model_config: ModelConfig, + model_config: "ModelConfig", mm_processor_kwargs: Optional[Dict[str, Any]] = None, ): if mm_processor_kwargs is None: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 5e9b8bd518de3..bce2f4c6abe5b 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,8 +1,7 @@ import functools from collections import UserDict -from typing import Any, Dict, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence -from vllm.config import ModelConfig from vllm.logger import init_logger from .audio import AudioPlugin @@ -11,6 +10,9 @@ from .image import ImagePlugin from .video 
import VideoPlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) @@ -20,7 +22,7 @@ class _MultiModalLimits(UserDict): when attempting to access a model that does not exist. """ - def __getitem__(self, key: ModelConfig) -> Dict[str, int]: + def __getitem__(self, key: "ModelConfig") -> Dict[str, int]: try: return super().__getitem__(key) except KeyError as exc: @@ -98,7 +100,7 @@ def register_image_input_mapper( def map_input( self, - model_config: ModelConfig, + model_config: "ModelConfig", data: MultiModalDataDict, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> MultiModalInputs: @@ -139,7 +141,7 @@ def map_input( return MultiModalInputs(merged_dict) - def create_input_mapper(self, model_config: ModelConfig): + def create_input_mapper(self, model_config: "ModelConfig"): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ @@ -177,7 +179,7 @@ def register_max_image_tokens( """ return self.register_max_multimodal_tokens("image", max_mm_tokens) - def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens for profiling the memory usage of a model. @@ -195,7 +197,7 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: def init_mm_limits_per_prompt( self, - model_config: ModelConfig, + model_config: "ModelConfig", ) -> None: """ Initialize the maximum number of multi-modal input instances for each @@ -231,7 +233,7 @@ def init_mm_limits_per_prompt( def get_mm_limits_per_prompt( self, - model_config: ModelConfig, + model_config: "ModelConfig", ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3c801464383ad..283c23c94d330 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,4 +1,5 @@ import base64 +import os from functools import lru_cache from io import BytesIO from typing import Any, List, Optional, Tuple, TypeVar, Union @@ -10,7 +11,7 @@ from vllm.connections import global_http_connection from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT from vllm.logger import init_logger -from vllm.multimodal.base import MultiModalDataDict +from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer logger = init_logger(__name__) @@ -18,19 +19,60 @@ cached_get_tokenizer = lru_cache(get_tokenizer) -def _load_image_from_bytes(b: bytes): +def _load_image_from_bytes(b: bytes) -> Image.Image: image = Image.open(BytesIO(b)) image.load() return image -def _load_image_from_data_url(image_url: str): +def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool: + # Get the common path + common_path = os.path.commonpath([ + os.path.abspath(image_path), + os.path.abspath(allowed_local_media_path) + ]) + # Check if the common path is the same as allowed_local_media_path + return common_path == os.path.abspath(allowed_local_media_path) + + +def _load_image_from_file(image_url: str, + allowed_local_media_path: str) -> Image.Image: + if not allowed_local_media_path: + raise ValueError("Invalid 'image_url': Cannot load local files without" + "'--allowed-local-media-path'.") + if allowed_local_media_path: + if not os.path.exists(allowed_local_media_path): + raise ValueError( + "Invalid '--allowed-local-media-path': " + f"The path 
{allowed_local_media_path} does not exist.") + if not os.path.isdir(allowed_local_media_path): + raise ValueError( + "Invalid '--allowed-local-media-path': " + f"The path {allowed_local_media_path} must be a directory.") + + # Only split once and assume the second part is the image path + _, image_path = image_url.split("file://", 1) + if not _is_subpath(image_path, allowed_local_media_path): + raise ValueError( + f"Invalid 'image_url': The file path {image_path} must" + " be a subpath of '--allowed-local-media-path'" + f" '{allowed_local_media_path}'.") + + image = Image.open(image_path) + image.load() + return image + + +def _load_image_from_data_url(image_url: str) -> Image.Image: # Only split once and assume the second part is the base64 encoded image _, image_base64 = image_url.split(",", 1) return load_image_from_base64(image_base64) -def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: +def fetch_image(image_url: str, + *, + image_mode: str = "RGB", + allowed_local_media_path: str = "") -> Image.Image: """ Load a PIL image from a HTTP or base64 data URL. @@ -43,16 +85,19 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: elif image_url.startswith('data:image'): image = _load_image_from_data_url(image_url) + elif image_url.startswith('file://'): + image = _load_image_from_file(image_url, allowed_local_media_path) else: raise ValueError("Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image' or 'http'.") + "with either 'data:image', 'file://' or 'http'.") return image.convert(image_mode) async def async_fetch_image(image_url: str, *, - image_mode: str = "RGB") -> Image.Image: + image_mode: str = "RGB", + allowed_local_media_path: str = "") -> Image.Image: """ Asynchronously load a PIL image from a HTTP or base64 data URL. 
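For reference, the new file:// branch only accepts image paths inside --allowed-local-media-path, and the check reduces to comparing os.path.commonpath of the two absolute paths against the allowed root. The standalone sketch below mirrors _is_subpath from this patch; the function name and the example paths are hypothetical and for illustration only.

import os


def is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
    # Longest common prefix path of the two absolute paths.
    common_path = os.path.commonpath([
        os.path.abspath(image_path),
        os.path.abspath(allowed_local_media_path),
    ])
    # Allowed only when the common prefix is exactly the allowed root.
    return common_path == os.path.abspath(allowed_local_media_path)


# Hypothetical paths, for illustration only.
print(is_subpath("/data/images/cat.png", "/data/images"))  # True
print(is_subpath("/etc/passwd", "/data/images"))           # False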
@@ -65,9 +110,11 @@ async def async_fetch_image(image_url: str, elif image_url.startswith('data:image'): image = _load_image_from_data_url(image_url) + elif image_url.startswith('file://'): + image = _load_image_from_file(image_url, allowed_local_media_path) else: raise ValueError("Invalid 'image_url': A valid 'image_url' must start " - "with either 'data:image' or 'http'.") + "with either 'data:image', 'file://' or 'http'.") return image.convert(image_mode) @@ -126,8 +173,12 @@ def get_and_parse_audio(audio_url: str) -> MultiModalDataDict: return {"audio": (audio, sr)} -def get_and_parse_image(image_url: str) -> MultiModalDataDict: - image = fetch_image(image_url) +def get_and_parse_image( + image_url: str, + *, + allowed_local_media_path: str = "") -> MultiModalDataDict: + image = fetch_image(image_url, + allowed_local_media_path=allowed_local_media_path) return {"image": image} @@ -136,8 +187,12 @@ async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: return {"audio": (audio, sr)} -async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: - image = await async_fetch_image(image_url) +async def async_get_and_parse_image( + image_url: str, + *, + allowed_local_media_path: str = "") -> MultiModalDataDict: + image = await async_fetch_image( + image_url, allowed_local_media_path=allowed_local_media_path) return {"image": image} @@ -258,7 +313,7 @@ def repeat_and_pad_placeholder_tokens( repeat_count: Union[int, List[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, -) -> Tuple[Optional[str], List[int]]: +) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]: if isinstance(repeat_count, int): repeat_count = [repeat_count] @@ -301,6 +356,7 @@ def repeat_and_pad_placeholder_tokens( new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] + placeholder_ranges: List[PlaceholderRange] = [] placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: @@ -310,6 +366,10 @@ def repeat_and_pad_placeholder_tokens( pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) + placeholder_ranges.append({ + "offset": len(new_token_ids), + "length": len(replacement_ids) + }) new_token_ids.extend(replacement_ids) placeholder_token_idx += 1 @@ -320,4 +380,14 @@ def repeat_and_pad_placeholder_tokens( else: new_token_ids.append(token) - return new_prompt, new_token_ids + return new_prompt, new_token_ids, placeholder_ranges + + +def consecutive_placeholder_ranges(num_items: int, + item_size: int) -> List[PlaceholderRange]: + """Returns a list of consecutive PlaceholderRanges of a fixed size""" + + return [ + PlaceholderRange(offset=i * item_size, length=item_size) + for i in range(num_items) + ] diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index c3235c4acb6fd..6c2c6720f4276 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,18 +1,19 @@ from functools import lru_cache -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import numpy as np -from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import is_list_of from .base import MultiModalData, MultiModalInputs from .image import ImagePlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = 
init_logger(__name__) cached_get_video_processor = lru_cache(get_video_processor) @@ -38,7 +39,7 @@ def get_data_key(self) -> str: def _get_hf_video_processor( self, - model_config: ModelConfig, + model_config: "ModelConfig", mm_processor_kwargs: Optional[Dict[str, Any]] = None, ): if mm_processor_kwargs is None: @@ -56,7 +57,10 @@ def _default_input_mapper( ) -> MultiModalInputs: model_config = ctx.model_config - if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): + if isinstance(data, list) and len(data) == 1: + data = data[0] + + if isinstance(data, np.ndarray): video_processor = self._get_hf_video_processor( model_config, mm_processor_kwargs, diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 35dbe22abf7ff..31fe3f1fcbfe4 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,10 +1,12 @@ import torch import vllm.envs as envs -from vllm.utils import print_warning_once +from vllm.logger import init_logger from .interface import Platform, PlatformEnum +logger = init_logger(__name__) + class OpenVinoPlatform(Platform): _enum = PlatformEnum.OPENVINO @@ -27,5 +29,5 @@ def is_openvino_gpu(self) -> bool: @classmethod def is_pin_memory_available(self) -> bool: - print_warning_once("Pin memory is not supported on OpenViNO.") + logger.warning("Pin memory is not supported on OpenViNO.") return False diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 4338cbc37f6c1..3336569f59467 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,8 +1,14 @@ import logging -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import vllm.envs as envs -from vllm.compilation.config import CompilationConfig + +if TYPE_CHECKING: + from vllm.compilation.config import CompilationConfig + from vllm.config import VllmConfig +else: + CompilationConfig = None + VllmConfig = None logger = logging.getLogger(__name__) @@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]): def get_compilation_config() -> Optional[CompilationConfig]: return _compilation_config + + +_vllm_config: Optional[VllmConfig] = None + + +def set_vllm_config(config: Optional[VllmConfig]): + global _vllm_config + _vllm_config = config + + +def get_vllm_config() -> Optional[VllmConfig]: + return _vllm_config diff --git a/vllm/sequence.py b/vllm/sequence.py index ff59f333f00b4..7d7ddc7ec4447 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,22 +6,23 @@ from collections import defaultdict from dataclasses import dataclass, field from functools import cached_property, reduce -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, + Mapping, Optional) from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union, cast +from typing import Set, Tuple, Union import msgspec import torch +from typing_extensions import assert_never -from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams if TYPE_CHECKING: from vllm.inputs import SingletonInputs - from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -256,7 +257,8 @@ def output_token_ids(self) 
-> Tuple[int, ...]: return tuple(self._output_token_ids) @output_token_ids.setter - def output_token_ids(self, new_output_token_ids: List[int]) -> None: + def output_token_ids(self, + new_output_token_ids: GenericSequence[int]) -> None: self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, new_output_token_ids) self._update_cached_all_tokens() @@ -377,15 +379,10 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. - - The sequence is constructed from the :code:`SingletonInputs` instance - passed in through the :code:`inputs` constructor argument. - - For encoder/decoder models, SingletonInputs encapsulates both a - decoder and encoder prompt, creating an ambiguity about which - prompt to construct the sequence from. The `from_decoder_prompt` - constructor argument signals whether to construct the Sequence - from the SingletonInputs decoder prompt, or encoder prompt. + + The sequence is constructed from the :data:`DecoderOnlyInputs` + (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder) + instance passed in through the :code:`inputs` constructor argument. Args: seq_id: The ID of the sequence. @@ -395,10 +392,6 @@ class Sequence: eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. lora_request: LoRA request. prompt_adapter_request: Prompt Adapter request. - from_decoder_prompt: Construct Sequence from SingletonInputs decoder - prompt (True) or encoder prompt (False.) Must be - True for decoder-only model. - """ def __init__( @@ -409,7 +402,6 @@ def __init__( eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - from_decoder_prompt: bool = True, ) -> None: self.seq_id = seq_id self.inputs = inputs @@ -417,33 +409,6 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.from_decoder_prompt = from_decoder_prompt - - # For decoder-only models, a Sequence is constructed - # from an DecoderOnlyInputs instance (the `inputs` arg.) - # - # For encoder/decoder models the same `inputs` - # instance could be utilized to construct either an - # encoder sequence or a decoder sequence, because - # `DecoderOnlyInputs` has both decoder- and encoder-oriented - # member variables (i.e. it encapsulates both an encoder - # and a decoder prompt.) The decision of which type of sequence - # to generate is determined by the `from_decoder_prompt` argument. - # - # When constructing a encoder sequence - # (`from_decoder_prompt` False) it matters that - # the `DecoderOnlyInputs` instance stored in `inputs` is valid - # in the sense that its encoder-related member variables are - # populated; below, an exception is raised if this is - # not the case. - # - # When constructing a decoder sequence (`from_decoder_prompt` True) - # it does not matter whether `inputs` has its encoder-related - # member variables populated. 
- if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)): - raise ValueError("Cannot extract encoder input prompt from " - f"invalid input {inputs}; did you forget the " - "encoder input prompt fields?") self.data = SequenceData.from_seqs(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -468,41 +433,57 @@ def n_blocks(self) -> int: @cached_property def prompt(self) -> Optional[str]: - # Select decoder or encoder input prompt str, as appropriate - prompt_key: str = ("prompt" - if self.from_decoder_prompt else "encoder_prompt") + inputs = self.inputs - return cast(Optional[str], self.inputs.get(prompt_key)) + if inputs["type"] == "token": + return inputs.get("prompt") + + assert_never(inputs) @cached_property def prompt_token_ids(self) -> List[int]: - # Select decoder or encoder input prompt token ids, as appropriate - prompt_token_ids_key: str = ("prompt_token_ids" - if self.from_decoder_prompt else - "encoder_prompt_token_ids") + inputs = self.inputs - # Cache computed prompt token ids - return cast(List[int], self.inputs.get(prompt_token_ids_key)) + if inputs["type"] == "token": + return inputs.get("prompt_token_ids", []) - @property + assert_never(inputs) + + @cached_property + def prompt_embeds(self) -> Optional[torch.Tensor]: + inputs = self.inputs + + if inputs["type"] == "token": + return None + + assert_never(inputs) + + @cached_property def multi_modal_data(self) -> "MultiModalDataDict": inputs = self.inputs - if (inputs.get("multi_modal_data") - and inputs.get("encoder_multi_modal_data")): - raise ValueError( - "Multi-modal data in both encoder and decoder is not supported." - ) + if inputs["type"] == "token": + return inputs.get("multi_modal_data", {}) - return cast( - "MultiModalDataDict", - (inputs.get("multi_modal_data") - or inputs.get("encoder_multi_modal_data") or {}), - ) + assert_never(inputs) - @property + @cached_property def mm_processor_kwargs(self) -> Dict[str, Any]: - return self.inputs.get("mm_processor_kwargs") or {} + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("mm_processor_kwargs", {}) + + assert_never(inputs) + + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_placeholders", {}) + + assert_never(inputs) @property def lora_int_id(self) -> int: @@ -728,9 +709,13 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: if self.encoder_seq is not None else None) @property - def multi_modal_data(self) -> "MultiModalDataDict": + def multi_modal_data(self) -> MultiModalDataDict: return self.first_seq.multi_modal_data + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + return self.first_seq.multi_modal_placeholders + @property def mm_processor_kwargs(self) -> Dict[str, Any]: return self.first_seq.mm_processor_kwargs @@ -946,6 +931,7 @@ class SequenceGroupMetadata( # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. multi_modal_data: Optional[Any] = None + multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[List[int]] = None @@ -1164,7 +1150,7 @@ def get_all_seq_ids_and_request_ids( sequence ids. 
""" seq_ids: List[int] = [] - request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set) for sg in seq_group_metadata_list: for seq_id in sg.seq_data: seq_ids.append(seq_id) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3aa999fcb9ebb..17cc0ad1a4a3a 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -17,9 +17,6 @@ "Draft model speculative decoding currently only supports" "CUDA and ROCm flash attention backend.") from err -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.multimodal import MultiModalInputs from vllm.sequence import ExecuteModelRequest, IntermediateTensors @@ -49,40 +46,13 @@ class TP1DraftModelRunner(ModelRunner): any broadcasting inside execute_model). """ - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, - ): - if return_hidden_states: + def __init__(self, *args, **kwargs): + if kwargs.get("return_hidden_states"): raise ValueError( "return_hidden_states is not supported for TP1DraftModelRunner." ) - super().__init__( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - load_config=load_config, - lora_config=lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - return_hidden_states=return_hidden_states, - observability_config=observability_config, - ) + super().__init__(*args, **kwargs) def _update_sampling_metadata(self, sampling_metadata, num_seqs, num_queries): diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index a777e5c3f22a7..debb3b2d5ec30 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -21,7 +21,7 @@ class NGramWorker(NonLLMProposerWorkerBase): def __init__(self, *args, **kwargs): # Get local_rank/vocab_size from kwargs attribute self.local_rank = kwargs["local_rank"] - self.vocab_size = kwargs["model_config"].get_vocab_size() + self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size() # Lazy initialization list. 
self._proposer: Top1Proposer diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9f7ef2f8d851c..eb3c2e88e668c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,10 +1,11 @@ +import copy from collections import defaultdict from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple, Type import torch -from vllm.config import ParallelConfig, SpeculativeConfig +from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler @@ -45,8 +46,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": """Helper method that is the entrypoint for Executors which use WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. """ - assert "speculative_config" in kwargs - speculative_config: SpeculativeConfig = kwargs.get("speculative_config") + vllm_config: VllmConfig = kwargs.get("vllm_config") + speculative_config: SpeculativeConfig = vllm_config.speculative_config assert speculative_config is not None draft_worker_kwargs = kwargs.copy() @@ -58,14 +59,20 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": target_worker.model_runner.disable_logprobs =\ speculative_config.disable_logprobs + draft_worker_config = copy.deepcopy(vllm_config) + draft_worker_config.model_config = speculative_config.draft_model_config + draft_worker_config.quant_config = VllmConfig._get_quantization_config( + draft_worker_config.model_config, + vllm_config.load_config, + ) + draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa + # TODO allow draft-model specific load config. + # Override draft-model specific worker args. draft_worker_kwargs.update( - model_config=speculative_config.draft_model_config, - parallel_config=speculative_config.draft_parallel_config, + vllm_config=draft_worker_config, ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, - # TODO allow draft-model specific load config. 
- #load_config=load_config, ) spec_decode_worker = SpecDecodeWorker.create_worker( @@ -134,29 +141,27 @@ def create_worker( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + draft_model_config = draft_worker_kwargs["vllm_config"].model_config + draft_parallel_config: ParallelConfig = draft_worker_kwargs[ + 'vllm_config'].parallel_config if ngram_prompt_lookup_max > 0: proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) else: - draft_parallel_config: ParallelConfig = draft_worker_kwargs[ - 'parallel_config'] draft_tp = draft_parallel_config.tensor_parallel_size target_tp = scorer_worker.parallel_config.tensor_parallel_size - if draft_worker_kwargs[ - "model_config"].hf_config.model_type == "mlp_speculator": + if draft_model_config.hf_config.model_type == "mlp_speculator": proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) - elif draft_worker_kwargs[ - "model_config"].hf_config.model_type == "medusa": + elif draft_model_config.hf_config.model_type == "medusa": proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: - if draft_worker_kwargs[ - "model_config"].hf_config.model_type == "eagle": + if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( "EAGLE does not support TP > 1 yet") @@ -190,8 +195,8 @@ def create_worker( "[Speculative Decoding] Disabling MQA scorer as the " "MQA is only available with flash attn backend.") - if "model_config" in draft_worker_kwargs and \ - draft_worker_kwargs["model_config"].max_model_len < \ + if draft_model_config and \ + draft_model_config.max_model_len < \ scorer_worker.model_config.max_model_len: disable_mqa_scorer = True logger.info( diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py index 2bb7af7d7c600..e61cde5b17f20 100644 --- a/vllm/spec_decode/target_model_runner.py +++ b/vllm/spec_decode/target_model_runner.py @@ -1,8 +1,6 @@ from typing import List, Optional -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.sequence import SequenceGroupMetadata from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) @@ -20,35 +18,21 @@ class TargetModelRunner(ModelRunner): requested or not. """ - def __init__(self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None): + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + ): # An internal boolean member variable to indicate if token log # probabilities are needed or not. 
self.disable_logprobs = True super().__init__( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - load_config=load_config, - lora_config=lora_config, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, - observability_config=observability_config, ) def prepare_model_input( diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 9bd2531d7a15c..08697274854e0 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -19,6 +19,7 @@ # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, + H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -52,6 +53,7 @@ "medusa": MedusaConfig, "eagle": EAGLEConfig, "exaone": ExaoneConfig, + "h2ovl_chat": H2OVLChatConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index f0d79197a82c5..d1e19c9a33c24 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -6,6 +6,7 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -22,6 +23,7 @@ "DbrxConfig", "MPTConfig", "RWConfig", + "H2OVLChatConfig", "InternVLChatConfig", "JAISConfig", "MedusaConfig", @@ -33,4 +35,4 @@ "NVLM_D_Config", "SolarConfig", "UltravoxConfig", -] +] \ No newline at end of file diff --git a/vllm/transformers_utils/configs/h2ovl.py b/vllm/transformers_utils/configs/h2ovl.py new file mode 100644 index 0000000000000..b94c5b77e4b7f --- /dev/null +++ b/vllm/transformers_utils/configs/h2ovl.py @@ -0,0 +1,13 @@ +# Adapted from +# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- + +from .internvl import InternVLChatConfig + + +class H2OVLChatConfig(InternVLChatConfig): + model_type = "h2ovl_chat" diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 80e21c2d32ecc..896f70bc1dafd 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -16,9 +16,13 @@ from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy, Tekkenizer) +from vllm.logger import init_logger + if TYPE_CHECKING: from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +logger = init_logger(__name__) + @dataclass class Encoding: @@ -72,20 +76,21 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None: # Make sure special tokens will not raise tokenizer_.special_token_policy = 
SpecialTokenPolicy.IGNORE - self._vocab = { - token: idx - for idx, token in enumerate(tokenizer_.vocab()) - } elif isinstance(tokenizer_, SentencePieceTokenizer): - self._vocab = { - token: idx - for idx, token in enumerate(tokenizer_.vocab()) - } + pass else: raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") + self._vocab = tokenizer_.vocab() + # Convert to a Dict[str, int] to match protocol, but this is a lossy + # conversion. There may be multiple token ids that decode to the same + # string due to partial UTF-8 byte sequences being converted to � + self._vocab_dict = { + token: idx + for idx, token in enumerate(self._vocab) + } self.tokenizer = tokenizer_ - self._max_token_id = max(self._vocab.values()) + self._max_token_id = self.vocab_size - 1 @classmethod def from_pretrained(cls, @@ -182,7 +187,9 @@ def __call__( return Encoding(input_ids=input_ids) def get_vocab(self) -> Dict[str, int]: - return self._vocab + # NB: the dictionary form of the vocabulary collapses token ids that map + # to the same string but have different bytes + return self._vocab_dict def get_added_vocab(self) -> Dict[str, int]: # Mistral tokenizers have no added vocabulary @@ -220,14 +227,20 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: if any(isinstance(t, bytes) for t in tokens): # we need to encode and decode all tokens again shift = self.tokenizer.num_special_tokens - byte_tokens = [ - t.encode("utf-8") if not isinstance(t, bytes) else t - for t in tokens - ] - ids = [ - self.tokenizer._tekken_token2id_nospecial[t] + shift - for t in byte_tokens - ] + + def _token_to_id(t: str): + t_bytes = t.encode("utf-8") \ + if not isinstance(t, bytes) else t + try: + return shift + \ + self.tokenizer._tekken_token2id_nospecial[t_bytes] + except KeyError: + logger.warning( + "Failed to convert token %s to id," + " replacing with <unk>", t_bytes) + return self.tokenizer.unk_id + + ids = [_token_to_id(t) for t in tokens] decoded = self.tokenizer.decode(ids) else: decoded = "".join(tokens) @@ -236,7 +249,13 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: return decoded - def decode(self, ids: Union[List[int], int]) -> str: + def decode(self, + ids: Union[List[int], int], + skip_special_tokens: bool = True) -> str: + assert ( + skip_special_tokens + ), "Skipping special tokens is not supported for Mistral tokenizers."
+ if isinstance(ids, int): ids = [ids] return self.tokenizer.decode(ids) @@ -257,10 +276,11 @@ def convert_ids_to_tokens( tokens = [self.tokenizer.id_to_piece(id) for id in ids] - if any(t.strip() == "�" for t in tokens): - # if any stripped decoded token is undefined - # because it's invalid unicode then pass bytes + if any("�" in t for t in tokens): + # if a decoded token contains the replacement character, then the + # token has an incomplete UTF-8 character so we must use bytes # See: https://github.com/vllm-project/vllm/pull/8640 + # https://github.com/vllm-project/vllm/pull/9625 tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] return tokens diff --git a/vllm/utils.py b/vllm/utils.py index 5488719cc99b0..a742ec8d76908 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -80,8 +80,8 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " - "currently supported with encoder/" +STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only " + "backends currently supported with encoder/" "decoder models.") STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " @@ -1551,7 +1551,14 @@ def direct_register_custom_op( """ if is_in_doc_build(): return - schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + import torch.library + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, + mutates_args=mutates_args) + else: + # for pytorch 2.4 + import torch._custom_op.impl + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str) my_lib.impl(op_name, op_func, "CUDA") diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 072e52bcd686a..64cc18149d6c5 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -2,11 +2,9 @@ from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union) -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, @@ -35,17 +33,7 @@ class LLMEngine: def __init__( self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - decoding_config: Optional[DecodingConfig], - observability_config: Optional[ObservabilityConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], + vllm_config: VllmConfig, executor_class: Type[GPUExecutor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -53,6 +41,22 @@ def __init__( input_registry: InputRegistry = INPUT_REGISTRY, use_cached_outputs: bool = False, ) -> None: + + # TODO: remove the local variables and use self.* throughout the class. 
+ model_config = self.model_config = vllm_config.model_config + cache_config = self.cache_config = vllm_config.cache_config + lora_config = self.lora_config = vllm_config.lora_config + parallel_config = self.parallel_config = vllm_config.parallel_config + scheduler_config = self.scheduler_config = vllm_config.scheduler_config + device_config = self.device_config = vllm_config.device_config + speculative_config = self.speculative_config = vllm_config.speculative_config # noqa + load_config = self.load_config = vllm_config.load_config + decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + ) + prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + # Override the configs for V1. # FIXME if usage_context == UsageContext.LLM_CLASS: @@ -112,18 +116,6 @@ def __init__( model_config.mm_processor_kwargs, ) - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.load_config = load_config - self.decoding_config = decoding_config or DecodingConfig() - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config or ObservabilityConfig( - ) self.log_stats = log_stats assert not self.model_config.skip_tokenizer_init @@ -154,18 +146,7 @@ def __init__( # Request id -> RequestOutput self.request_outputs: Dict[str, RequestOutput] = {} - self.model_executor = executor_class( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - speculative_config=speculative_config, - load_config=load_config, - prompt_adapter_config=prompt_adapter_config, - observability_config=self.observability_config, - ) + self.model_executor = executor_class(vllm_config=vllm_config) assert self.model_config.task != "embedding" self._initialize_kv_caches() @@ -203,7 +184,7 @@ def from_engine_args( executor_class = cls._get_executor_cls(engine_config) # Create the LLM engine. 
engine = cls( - **engine_config.to_dict(), + vllm_config=engine_config, executor_class=executor_class, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, @@ -497,7 +478,7 @@ def get_lora_config(self) -> LoRAConfig: return self.lora_config @classmethod - def _get_executor_cls(cls, engine_config: EngineConfig): + def _get_executor_cls(cls, engine_config: VllmConfig): return GPUExecutor def is_tracing_enabled(self) -> bool: diff --git a/vllm/v1/executor/gpu_executor.py b/vllm/v1/executor/gpu_executor.py index c780c7031c3d6..f71fa16b16e27 100644 --- a/vllm/v1/executor/gpu_executor.py +++ b/vllm/v1/executor/gpu_executor.py @@ -1,10 +1,7 @@ import os from typing import Optional, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.outputs import ModelRunnerOutput @@ -15,29 +12,18 @@ class GPUExecutor: - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - observability_config: Optional[ObservabilityConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config self.worker = self._create_worker() self.worker.initialize() @@ -56,19 +42,10 @@ def _create_worker( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - speculative_config=self.speculative_config, - prompt_adapter_config=self.prompt_adapter_config, - observability_config=self.observability_config, ) def determine_num_available_blocks(self) -> Tuple[int, int]: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 28614377b27b9..9ef36f2e6b212 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,5 +1,5 @@ from 
dataclasses import dataclass -from typing import List, Optional +from typing import Dict import torch @@ -16,7 +16,6 @@ class SamplingMetadata: no_top_p: bool no_top_k: bool - generators: List[Optional[torch.Generator]] - no_generator: bool + generators: Dict[int, torch.Generator] max_num_logprobs: int diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 157c4dd6d771e..927f274541c4d 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,5 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import List, Optional +from typing import Dict import torch import torch.nn as nn @@ -84,22 +84,21 @@ def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor: def random_sample( self, probs: torch.Tensor, - generators: List[Optional[torch.Generator]], - no_generator: bool, + generators: Dict[int, torch.Generator], ) -> torch.Tensor: q = torch.empty_like(probs) # NOTE(woosuk): To batch-process the requests without their own seeds, # which is the common case, we first assume that every request does # not have its own seed. Then, we overwrite the values for the requests # that have their own seeds. - q.exponential_() - if not no_generator: - assert len(generators) == probs.shape[0] + if len(generators) != probs.shape[0]: + # This might still be done here unnecessarily if there are greedies + q.exponential_() + if generators: # TODO(woosuk): This can be slow because we handle each request # one by one. Optimize this. - for i, generator in enumerate(generators): - if generator is not None: - q[i].exponential_(generator=generator) + for i, generator in generators.items(): + q[i].exponential_(generator=generator) return probs.div_(q).argmax(dim=-1).view(-1) def sample( @@ -112,13 +111,11 @@ def sample( if sampling_metadata.all_greedy: return self.greedy_sample(probs) if sampling_metadata.all_random: - return self.random_sample(probs, sampling_metadata.generators, - sampling_metadata.no_generator) + return self.random_sample(probs, sampling_metadata.generators) greedy_sampled = self.greedy_sample(probs) random_sampled = self.random_sample(probs, - sampling_metadata.generators, - sampling_metadata.no_generator) + sampling_metadata.generators) sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, greedy_sampled, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e84645ac7a4ae..ae4239f8e1fab 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -7,9 +7,7 @@ import torch.distributed import torch.nn as nn -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -33,26 +31,25 @@ class GPUModelRunner: def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, + vllm_config: VllmConfig, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config 
- self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config - + # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config self.device = self.device_config.device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype @@ -131,13 +128,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add new requests to the cached states. for req_data in scheduler_output.scheduled_new_reqs: req_id = req_data.req_id + sampling_params = req_data.sampling_params + if sampling_params.seed is not None: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=req_data.prompt_token_ids, prompt=req_data.prompt, multi_modal_data=req_data.multi_modal_data, - sampling_params=req_data.sampling_params, - generator=None, # TODO + sampling_params=sampling_params, + generator=generator, block_ids=req_data.block_ids, num_computed_tokens=req_data.num_computed_tokens, output_token_ids=[], @@ -345,11 +349,9 @@ def execute_model( else: # Ignore the sampled token from the partial request. # Rewind the generator state as if the token was not sampled. 
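# (Editorial note, not part of the patch.) With generators now stored as a
# dict keyed by batch index, .get(i) returns None for unseeded requests, so
# only seeded requests take the rewind path below: the token sampled for a
# partially-scheduled (chunked-prefill) request is discarded, and stepping the
# generator offset back by one draw keeps the seeded random stream identical
# to what it would have been had the throwaway sample never been taken.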
- generator = self.input_batch.generators[i] + generator = self.input_batch.generators.get(i) if generator is not None: - offset = generator.get_offset() - generator = generator.set_offset(offset - 1) - self.input_batch.generators[i] = generator + generator.set_offset(generator.get_offset() - 1) if sampler_output.logprob_token_ids is None: logprob_token_ids = None @@ -372,13 +374,7 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 with patch("vllm.model_executor.layers.sampler.Sampler", Sampler): - self.model = get_model(model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -503,8 +499,8 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() - self.generators: List[Optional[torch.Generator]] = [None - ] * max_num_reqs + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} self.num_logprobs: Dict[str, int] = {} self.prompt_logprob_reqs: Set[str] = set() @@ -518,8 +514,9 @@ def add_request( req_index = self.num_reqs assert req_index < self.max_num_reqs - self.req_ids[req_index] = request.req_id - self.req_id_to_index[request.req_id] = req_index + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. num_prompt_tokens = len(request.prompt_token_ids) @@ -537,27 +534,24 @@ def add_request( sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_index) - elif sampling_params.sampling_type == SamplingType.RANDOM: - self.random_reqs.add(req_index) - elif sampling_params.sampling_type == SamplingType.RANDOM_SEED: - # TODO(woosuk): Support per-request random seed. 
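# (Editorial sketch, not part of the patch.) Per-request seeds are now
# supported: a request with sampling_params.seed gets its own
# torch.Generator(device=...).manual_seed(seed), stored in a dict keyed by its
# row in the batch. Sampling then draws exponential noise for the whole batch
# in one call and only re-draws the rows that carry a seeded generator, as in
# this simplified CPU-only illustration (the patch additionally skips the
# batched draw when every row is seeded):
import torch
from typing import Dict

def toy_random_sample(probs: torch.Tensor,
                      generators: Dict[int, torch.Generator]) -> torch.Tensor:
    # probs: [num_reqs, vocab_size]; generators: batch row -> seeded generator.
    q = torch.empty_like(probs)
    q.exponential_()                      # one batched draw for unseeded rows
    for i, gen in generators.items():
        q[i].exponential_(generator=gen)  # overwrite only the seeded rows
    return probs.div(q).argmax(dim=-1)    # exponential race == sample from probs

probs = torch.softmax(torch.randn(4, 8), dim=-1)
seeded = {2: torch.Generator().manual_seed(42)}
print(toy_random_sample(probs, seeded))   # row 2's noise comes from its own seeded generator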
- raise NotImplementedError("Per-request seed is not supported yet.") + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) self.top_p_cpu[req_index] = sampling_params.top_p if sampling_params.top_p < 1: - self.top_p_reqs.add(req_index) + self.top_p_reqs.add(req_id) self.top_k_cpu[req_index] = sampling_params.top_k if sampling_params.top_k > 0: - self.top_k_reqs.add(req_index) + self.top_k_reqs.add(req_id) self.generators[req_index] = request.generator num_logprobs = sampling_params.logprobs if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[request.req_id] = num_logprobs + self.num_logprobs[req_id] = num_logprobs if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_index) + self.prompt_logprob_reqs.add(req_id) def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) @@ -569,7 +563,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) - self.generators[req_index] = None + self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) return req_index @@ -621,7 +615,9 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.generators[empty_index] = self.generators[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -645,8 +641,7 @@ def make_sampling_metadata( top_k=self.top_k[:self.num_reqs], no_top_p=self.no_top_p, no_top_k=self.no_top_k, - generators=self.generators[:self.num_reqs], - no_generator=self.no_generator, + generators=self.generators, max_num_logprobs=self.max_num_logprobs, ) @@ -670,16 +665,9 @@ def no_top_p(self) -> bool: def no_top_k(self) -> bool: return len(self.top_k_reqs) == 0 - @property - def no_generator(self) -> bool: - return len(self.generators) == 0 - @property def max_num_logprobs(self) -> int: - if self.num_logprobs: - return max(self.num_logprobs.values()) - else: - return 0 + return max(self.num_logprobs.values()) if self.num_logprobs else 0 @property def no_logprob(self) -> bool: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 8c5ca2ec35666..c8192b7f86eb0 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,10 +6,7 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -30,48 +27,35 @@ class Worker: def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - speculative_config: Optional[SpeculativeConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - 
observability_config: Optional[ObservabilityConfig] = None, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + + # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = GPUModelRunner( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=lora_config, - ) + self.model_runner = GPUModelRunner(vllm_config) def initialize(self): if self.device_config.device.type == "cuda": diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 5032896600b3b..fdd72a452f2ad 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,5 +1,6 @@ import dataclasses import weakref +from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -7,16 +8,14 @@ from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) from vllm.transformers_utils.config import uses_mrope @@ -148,9 +147,18 @@ def build(self) -> ModelInputForCPU: query_lens=seq_lens, ) - def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, - computed_len: int, + def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, + seq_data: SequenceData, computed_len: int, mm_processor_kwargs: Dict[str, Any]): + + # NOTE: mm_data only includes the subset of multi-modal items that + # intersect with the current prefill positions. 
+ mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + seq_group, range(computed_len, len(seq_data.get_token_ids()))) + + if not mm_data: + return + mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) # special processing for mrope position deltas. @@ -179,7 +187,7 @@ def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, context_len=computed_len, ) seq_data.mrope_position_delta = mrope_position_delta - return mm_kwargs, mrope_positions + return mm_kwargs, placeholder_maps, mrope_positions def _prepare_prompt( self, @@ -194,6 +202,9 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -210,11 +221,15 @@ def _prepare_prompt( input_tokens.extend(prompt_tokens) # Token ids mrope_positions = None - if (mm_data := seq_group_metadata.multi_modal_data): - mm_kwargs, mrope_positions = self._compute_multi_modal_input( - seq_data, mm_data, computed_len, + if seq_group_metadata.multi_modal_data: + mm_kwargs, placeholder_maps, mrope_positions = self \ + ._compute_multi_modal_input( + seq_group_metadata, seq_data, computed_len, seq_group_metadata.mm_processor_kwargs) multi_modal_inputs_list.append(mm_kwargs) + for modality, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[modality].extend( + placeholder_map) # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt @@ -264,6 +279,11 @@ def _prepare_prompt( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + multi_modal_placeholder_maps.items() + } attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -275,6 +295,7 @@ def _prepare_prompt( num_decode_tokens=0, block_tables=torch.tensor([]), slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) @@ -366,6 +387,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_decode_seq_len=max_decode_seq_len, @@ -388,29 +410,18 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config + ModelRunnerBase.__init__(self, vllm_config) # Currently, CPU worker doesn't support chunked prefill. 
assert self.scheduler_config.chunked_prefill_enabled is False - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config - self.load_config = load_config + model_config = self.model_config + cache_config = self.cache_config + self.is_driver_worker = is_driver_worker self.device = self.device_config.device @@ -442,13 +453,7 @@ def model_is_mrope(self) -> bool: return uses_mrope(self.model_config.hf_config) def load_model(self) -> None: - self.model = get_model(model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + self.model = get_model(vllm_config=self.vllm_config) def make_model_input_from_broadcasted_tensor_dict( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index ab93471b5af74..3778707ae07e8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,9 +6,8 @@ import vllm.envs as envs from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -18,7 +17,8 @@ from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) logger = init_logger(__name__) @@ -121,31 +121,19 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + WorkerBase.__init__(self, vllm_config=vllm_config) + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
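Editor's sketch (not part of the patch). The recurring shape of this refactor: every worker and model runner now accepts a single VllmConfig and delegates the fan-out of per-aspect configs to a shared base __init__. A minimal, self-contained illustration of that pattern, using hypothetical Toy* stand-ins rather than the real vLLM classes:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ToyVllmConfig:                 # hypothetical stand-in for vllm.config.VllmConfig
    model_config: dict = field(default_factory=dict)
    cache_config: dict = field(default_factory=dict)
    parallel_config: dict = field(default_factory=dict)
    lora_config: Optional[dict] = None

class ToyWorkerBase:
    def __init__(self, vllm_config: ToyVllmConfig) -> None:
        # One place fans the composed config out into the attributes the
        # subclasses already rely on.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.parallel_config = vllm_config.parallel_config
        self.lora_config = vllm_config.lora_config

class ToyCPUWorker(ToyWorkerBase):
    def __init__(self, vllm_config: ToyVllmConfig, rank: int) -> None:
        super().__init__(vllm_config)    # replaces ~10 positional config args
        self.rank = rank

worker = ToyCPUWorker(ToyVllmConfig(model_config={"model": "facebook/opt-125m"}), rank=0)
print(worker.model_config, worker.rank)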
@@ -166,15 +154,8 @@ def __init__( if self._is_encoder_decoder_model(): ModelRunnerClass = CPUEncoderDecoderModelRunner self.model_runner: CPUModelRunner = ModelRunnerClass( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=self.load_config, - lora_config=self.lora_config, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a7f5b2d4fdd1f..ff288d5ca1512 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -3,9 +3,7 @@ import torch -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger @@ -36,29 +34,13 @@ class EmbeddingModelRunner( def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, ): - super().__init__(model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=lora_config, + super().__init__(vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - observability_config=observability_config) + is_driver_worker=is_driver_worker) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 6a00444f5098b..90a43196084ea 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -11,14 +11,13 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, get_global_forced_attn_backend, global_force_attn_backend) -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import ModelConfig, VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.utils import get_architecture_class_name from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, MultiModalRegistry) from vllm.sampling_params import SamplingParams @@ -36,6 +35,11 @@ logger = init_logger(__name__) +# The Mllama model has PagedAttention specific logic because of which it +# can only be run with the XFORMERS backend +# TODO Make Mllama model work with Flash Attention backend. 
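# (Editorial note, not part of the patch.) The XFormers requirement is now
# expressed as a per-architecture allowlist checked via
# get_architecture_class_name(), so encoder-decoder models other than Mllama
# keep backend auto-selection, and an explicit user override is accepted when
# it names either XFormers or FlashAttention.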
+_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"] + @dataclasses.dataclass(frozen=True) class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): @@ -79,17 +83,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): @@ -101,17 +97,10 @@ def __init__( models) but these arguments are present here for compatibility with the base-class constructor. ''' - - self._maybe_force_supported_attention_backend() + self._maybe_force_supported_attention_backend(vllm_config.model_config) super().__init__( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=None, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, ) @@ -119,7 +108,12 @@ def __init__( # Crash for unsupported encoder/scenarios assert_enc_dec_mr_supported_scenario(self) - def _maybe_force_supported_attention_backend(self): + def _is_xformers_only_encoder_decoder_model(self, + model: ModelConfig) -> bool: + return get_architecture_class_name( + model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS + + def _maybe_force_supported_attention_backend(self, model: ModelConfig): ''' Force vLLM to use the XFormers attention backend, which is currently the only supported option. @@ -135,22 +129,26 @@ def raise_backend_err(): is_forced_by_global = maybe_global_forced_backend is not None is_forced_by_env_var = maybe_env_var_forced_backend is not None - if not (is_forced_by_global or is_forced_by_env_var): + if not (is_forced_by_global or is_forced_by_env_var) \ + and self._is_xformers_only_encoder_decoder_model(model): # The user has not already specified an attention backend # override - logger.info("EncoderDecoderModelRunner requires " - "XFormers backend; overriding backend " - "auto-selection and forcing XFormers.") + logger.info( + "Encoder-Decoder Model Architecture %s requires XFormers " + "backend; overriding backend auto-selection and " + "forcing XFormers.", get_architecture_class_name(model)) global_force_attn_backend(_Backend.XFORMERS) elif is_forced_by_global: # Backend override enforced by global variable takes # precedence over vLLM backend environment variable. 
- if maybe_global_forced_backend != _Backend.XFORMERS: + if maybe_global_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: raise_backend_err() elif is_forced_by_env_var: # Backend override enforced by vLLM backend # environment variable - if maybe_env_var_forced_backend != _Backend.XFORMERS: + if maybe_env_var_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: raise_backend_err() def _list_to_int32_tensor( @@ -306,13 +304,12 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - decoder_seq_data, decoder_dummy_multi_modal_data \ - = self.input_registry.dummy_data_for_profiling( - self.model_config, + decoder_dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry, is_encoder_data=False) - encoder_seq_data, encoder_dummy_multi_modal_data \ + encoder_dummy_data \ = self.input_registry.dummy_data_for_profiling( self.model_config, seq_len, @@ -320,26 +317,31 @@ def profile_run(self) -> None: is_encoder_data=True) # Having more tokens is over-conservative but otherwise fine - assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( + assert len( + decoder_dummy_data.seq_data.prompt_token_ids + ) >= seq_len, ( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(decoder_seq_data.prompt_token_ids)}") + f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}" + ) - assert decoder_dummy_multi_modal_data is None or \ - encoder_dummy_multi_modal_data is None, ( + assert decoder_dummy_data.multi_modal_data is None or \ + encoder_dummy_data.multi_modal_data is None, ( "Multi-modal data can't be provided in both encoder and decoder" ) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: decoder_seq_data}, + seq_data={group_id: decoder_dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, - encoder_seq_data=encoder_seq_data, + encoder_seq_data=encoder_dummy_data.seq_data, cross_block_table=None, - multi_modal_data=decoder_dummy_multi_modal_data - or encoder_dummy_multi_modal_data, - ) + multi_modal_data=decoder_dummy_data.multi_modal_data + or encoder_dummy_data.multi_modal_data, + multi_modal_placeholders=decoder_dummy_data. + multi_modal_placeholders + or encoder_dummy_data.multi_modal_placeholders) seqs.append(seq) # Run the model with the dummy inputs. 
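Editor's sketch (not part of the patch), summarizing the backend-forcing hunks above in one place: auto-selection is only overridden for architectures on the XFormers-only list, and an explicit override (global or environment) is accepted as long as it is XFormers or FlashAttention. A compressed sketch, with a hypothetical Backend enum standing in for vLLM's _Backend:

from enum import Enum, auto
from typing import Optional

class Backend(Enum):                  # hypothetical stand-in for _Backend
    XFORMERS = auto()
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()

XFORMERS_ONLY_ARCHS = {"MllamaForConditionalGeneration"}

def resolve_enc_dec_backend(arch: str,
                            forced: Optional[Backend]) -> Optional[Backend]:
    """Return a backend to force, or None to leave auto-selection alone."""
    if forced is None:
        # No user override: force XFormers only for the architectures that
        # still depend on its PagedAttention-specific code path.
        return Backend.XFORMERS if arch in XFORMERS_ONLY_ARCHS else None
    if forced not in (Backend.XFORMERS, Backend.FLASH_ATTN):
        raise ValueError("encoder/decoder models support only the XFormers "
                         "or FlashAttention backends")
    return forced

print(resolve_enc_dec_backend("MllamaForConditionalGeneration", None))  # Backend.XFORMERS
print(resolve_enc_dec_backend("BartForConditionalGeneration", None))    # None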
@@ -528,6 +530,7 @@ def _prepare_encoder_model_input_tensors( attn_metadata.encoder_seq_lens, attn_metadata.encoder_seq_lens_tensor, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, attn_metadata.cross_slot_mapping, attn_metadata.cross_block_tables, ) = ( @@ -535,6 +538,7 @@ def _prepare_encoder_model_input_tensors( encoder_seq_lens, encoder_seq_lens_tensor, max_encoder_seq_len, + encoder_seq_start_loc, cross_slot_mapping_tensor, cross_block_tables, ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 891637dafbb14..328dab598f8ef 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,9 +20,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context from vllm.compilation.levels import CompilationLevel -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture @@ -40,7 +38,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalRegistry) + MultiModalInputs, MultiModalPlaceholderMap, + MultiModalRegistry) from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest @@ -242,6 +241,8 @@ def __init__( # Multi-modal inputs. multi_modal_inputs: Optional[MultiModalInputs] = None, + multi_modal_placeholder_maps: Optional[Dict[ + str, MultiModalPlaceholderMap]] = None, # Whether the prefix cache is hit (prefill only). prefix_cache_hit: bool = False, @@ -361,6 +362,7 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.multi_modal_inputs = multi_modal_inputs + self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.prefix_cache_hit = prefix_cache_hit self.n_seqs = len(self.seq_ids) @@ -635,7 +637,12 @@ def _compute_prompt_adapter_input( def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata): """If multi-modal data is given, add it to the input.""" - mm_data = seq_group_metadata.multi_modal_data + # NOTE: mm_data only includes the subset of multi-modal items that + # intersect with the current prefill positions. + positions = inter_data.input_positions[0] + mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + seq_group_metadata, + range(positions[0], positions[0] + len(positions))) if not mm_data: return @@ -643,6 +650,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, mm_data, mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) inter_data.multi_modal_inputs = mm_kwargs + inter_data.multi_modal_placeholder_maps = placeholder_maps # special processing for mrope position deltas. 
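# (Editorial note, not part of the patch.) In the GPU runner the prefill chunk
# is identified from inter_data.input_positions, so
# range(positions[0], positions[0] + len(positions)) covers exactly the
# positions being prefilled this step; MultiModalPlaceholderMap.from_seq_group
# then returns only the multi-modal items whose placeholders intersect that
# range, so only the items relevant to this chunk are mapped.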
if self.runner.model_is_mrope: @@ -945,32 +953,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config + + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config + cache_config = self.cache_config + self.is_driver_worker = is_driver_worker - self.prompt_adapter_config = prompt_adapter_config self.return_hidden_states = return_hidden_states - self.observability_config = observability_config self.device = self.device_config.device self.pin_memory = is_pin_memory_available() @@ -1055,13 +1051,7 @@ def __init__( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: - self.model = get_model(model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -1255,7 +1245,7 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = self.input_registry \ + dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) @@ -1263,12 +1253,13 @@ def profile_run(self) -> None: seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, - multi_modal_data=dummy_multi_modal_data, + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders, ) seqs.append(seq) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 86883cf152449..9e529f86b46bb 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -9,6 +9,7 @@ import torch from torch import is_tensor +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform @@ -46,9 +47,8 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. 
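# (Editorial note, not part of the patch.) The loop below now keys off
# membership in tensor_dict rather than a non-None value: fields whose
# broadcast value is legitimately None (for example, the new
# multi_modal_placeholder_index_maps is None for decode-only batches) are now
# popped and passed through to make_metadata instead of being dropped.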
valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = tensor_dict.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val + if field.name in tensor_dict: + valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) tensor_dict["attn_metadata"] = attn_metadata @@ -221,6 +221,22 @@ class ModelRunnerBase(ABC, Generic[T]): ModelRunnerInputBase subclass. """ + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + # Map of request_id -> generator used for seeded random sampling generators: Dict[str, torch.Generator] = {} diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index be2f0d79154d6..3ee0fb4dc943e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -304,6 +304,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): # mypy: enable-error-code=type-var def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) # Check attention backend support. diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index bf66f32d7d244..1f982fe103366 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -27,17 +27,9 @@ def __init__(self, *args, **kwargs): # for multi-step model, wrap the model runner with MultiStepModelRunner self.model_runner = MultiStepModelRunner( base_model_runner, - base_model_runner.model_config, - base_model_runner.parallel_config, - base_model_runner.scheduler_config, - base_model_runner.device_config, - base_model_runner.cache_config, - load_config=base_model_runner.load_config, - lora_config=self.lora_config, + vllm_config=base_model_runner.vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=base_model_runner.is_driver_worker, - prompt_adapter_config=base_model_runner.prompt_adapter_config, - observability_config=base_model_runner.observability_config, ) pipeline_parallel_size = self.parallel_config.pipeline_parallel_size diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index b8c760c4b5396..2da22cbfc7cb5 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -7,8 +7,7 @@ from torch import nn from transformers_neuronx.config import GenerationConfig -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -57,20 +56,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, + vllm_config: VllmConfig, ): - 
self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config if model_config is not None and model_config.get_sliding_window(): logger.warning("Sliding window is not supported on Neuron. " "The model will run without sliding window.") - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device self.pin_memory = is_pin_memory_available() diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index fff14d6402b44..3f6269684ac93 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -4,15 +4,15 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): @@ -21,20 +21,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config + WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -44,7 +36,7 @@ def __init__( init_cached_hf_modules() self.model_runner: NeuronModelRunner = NeuronModelRunner( - model_config, parallel_config, scheduler_config, device_config) + vllm_config=vllm_config) self.is_driver_worker = True def init_device(self) -> None: diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index a164fbe3393c4..c9c87ea748081 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, NamedTuple, Optional, Tuple +from collections import defaultdict +from typing import Dict, List, NamedTuple, Optional, Tuple import openvino as ov import torch @@ -6,16 +7,15 @@ from vllm.attention import get_attn_backend from vllm.attention.backends.openvino import OpenVINOAttentionMetadata -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.openvino import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import SequenceGroupMetadata +from 
vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) @@ -38,33 +38,21 @@ def empty(cls, device): multi_modal_kwargs={}) -class OpenVINOModelRunner: +class OpenVINOModelRunner(ModelRunnerBase): def __init__( self, ov_core: ov.Core, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, *args, **kwargs, ): self.ov_core = ov_core - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.multimodal_config = multimodal_config - self.load_config = load_config + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + cache_config = self.cache_config + model_config = self.model_config self.is_driver_worker = is_driver_worker self.device = self.device_config.device @@ -115,6 +103,9 @@ def _prepare_model_input( past_lens: List[int] = [] query_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) subsequence_begins: List[int] = [] block_indices: List[int] = [] @@ -168,15 +159,6 @@ def _prepare_model_input( and self.sliding_window is None and is_prompt) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata. - mm_processor_kwargs, - ) - multi_modal_inputs_list.append(mm_kwargs) - block_table = seq_group_metadata.block_tables[seq_id] # TODO(sang): Combine chunked prefill and prefix caching by # only allowing multiple of block_size chunk size. @@ -220,7 +202,8 @@ def _prepare_model_input( query_lens.append(query_len) input_tokens.extend(tokens) - input_positions.extend(list(range(computed_len, seq_len))) + positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) past_lens.append(computed_len) subsequence_begins.append(subsequence_begins[-1] + query_len) @@ -233,6 +216,22 @@ def _prepare_model_input( ), "seq_len: {}, computed_len: {}, query_len: {}".format( seq_len, computed_len, query_len) + if seq_group_metadata.multi_modal_data: + # NOTE: mm_data only includes the subset of multi-modal + # items that intersect with the current prefill positions. + mm_data, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) + + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. 
+ mm_processor_kwargs) + multi_modal_inputs_list.append(mm_kwargs) + + for modality, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[modality].extend( + placeholder_map, ) + max_query_len = max(query_lens) assert max_query_len > 0, "query_lens: {}".format(query_lens) @@ -261,12 +260,19 @@ def _prepare_model_input( max_context_len, dtype=torch.int32, device=self.device) # type: ignore + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + multi_modal_placeholder_maps.items() + } + attn_metadata = self.attn_backend.make_openvino_metadata( past_lens=past_lens_tensor, subsequence_begins=subsequence_begins_tensor, block_indices=block_indices_tensor, block_indices_begins=block_indices_begins_tensor, max_context_len=max_context_len_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) @@ -350,3 +356,9 @@ def execute_model( sampling_metadata=sampling_metadata, ) return output + + def prepare_model_input(self, *args, **kwargs): + raise NotImplementedError + + def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs): + raise NotImplementedError diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index a420d390c1ae4..205f8a337ce6c 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -7,9 +7,8 @@ import vllm.envs as envs from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -22,7 +21,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.worker.openvino_model_runner import OpenVINOModelRunner -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) @@ -212,33 +211,19 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): def __init__( self, ov_core: ov.Core, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - multimodal_config: Optional[MultiModalConfig] = None, kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, is_driver_worker: bool = False, ) -> None: self.ov_core = ov_core - self.model_config = model_config - self.parallel_config = parallel_config + WorkerBase.__init__(self, vllm_config) self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.multimodal_config = multimodal_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
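Editor's sketch (not part of the patch). The same placeholder-map plumbing recurs across the CPU, OpenVINO, XPU and GPU runners: MultiModalPlaceholderMap.from_seq_group restricts mm_data to the items intersecting the positions being prefilled, the per-sequence maps are merged per modality, and their index maps ride along in the attention metadata. A toy illustration of the merge step, with a hypothetical stand-in class rather than the real MultiModalPlaceholderMap:

from collections import defaultdict
from typing import Dict, List

class ToyPlaceholderMap:          # hypothetical stand-in for MultiModalPlaceholderMap
    def __init__(self) -> None:
        self.ranges: List[range] = []

    def extend(self, other: "ToyPlaceholderMap") -> None:
        self.ranges.extend(other.ranges)

    def index_map(self) -> List[int]:
        return [i for r in self.ranges for i in r]

# Merge each sequence's per-modality map for the batch, then flatten to the
# index maps that travel inside the attention metadata.
batch_maps: Dict[str, ToyPlaceholderMap] = defaultdict(ToyPlaceholderMap)
for seq_placeholder_positions in (range(4, 8), range(0, 3)):   # two sequences
    per_seq = ToyPlaceholderMap()
    per_seq.ranges.append(seq_placeholder_positions)
    batch_maps["image"].extend(per_seq)

placeholder_index_maps = {m: pm.index_map() for m, pm in batch_maps.items()}
print(placeholder_index_maps)     # {'image': [4, 5, 6, 7, 0, 1, 2]}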
@@ -250,14 +235,7 @@ def __init__( init_cached_hf_modules() self.model_runner = OpenVINOModelRunner( self.ov_core, - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=self.load_config, - lora_config=self.lora_config, - multimodal_config=self.multimodal_config, + vllm_config=self.vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, ) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 87ced7818a676..a721186137328 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -12,8 +12,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model @@ -90,20 +89,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, is_driver_worker: bool = False, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + ModelRunnerBase.__init__(self, vllm_config=vllm_config) self.is_driver_worker = is_driver_worker self.block_size = self.cache_config.block_size @@ -148,15 +137,7 @@ def load_model(self) -> None: "vllm.model_executor.layers.vocab_parallel_embedding." "get_tensor_model_parallel_rank", return_value=xm_tp_rank): - model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - parallel_config=self.parallel_config, - cache_config=self.cache_config, - scheduler_config=self.scheduler_config, - lora_config=None, - ) + model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() self.model = ModelWrapper(model) @@ -184,6 +165,7 @@ def _dummy_run( num_prefill_tokens=batch_size * seq_len, num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, block_tables=None, context_lens=None, ) @@ -216,6 +198,7 @@ def _dummy_run( num_prefill_tokens=0, num_decode_tokens=batch_size * seq_len, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, block_tables=block_tables, context_lens=context_lens, ) @@ -360,6 +343,7 @@ def _prepare_prompt( num_prefill_tokens=0, # NOTE: This is not used. 
num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, block_tables=None, context_lens=None, ) @@ -429,6 +413,7 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=batch_size, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, block_tables=block_tables, context_lens=context_lens, ) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index de6f7ab0072fd..096cb23416909 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -6,8 +6,7 @@ import torch_xla.runtime as xr import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -16,7 +15,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.worker.tpu_model_runner import TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) logger = init_logger(__name__) @@ -25,24 +25,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, is_driver_worker: bool, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config + WorkerBase.__init__(self, vllm_config=vllm_config) self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -56,13 +46,7 @@ def __init__( self.cache_config.cache_dtype] self.model_runner: TPUModelRunner = TPUModelRunner( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - is_driver_worker=is_driver_worker) + vllm_config=vllm_config, is_driver_worker=is_driver_worker) def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index fd30962e5d6bb..8928936b4f9fc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,10 +7,7 @@ import torch.distributed import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -27,7 +24,8 @@ from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner -from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, + WorkerInput) logger = init_logger(__name__) @@ -42,46 +40,31 @@ class Worker(LocalOrDistributedWorkerBase): def __init__( self, - model_config: 
ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, - observability_config: Optional[ObservabilityConfig] = None, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config + WorkerBase.__init__(self, vllm_config) self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker - if parallel_config and is_driver_worker: - assert rank % parallel_config.tensor_parallel_size == 0, \ + if is_driver_worker: + assert rank % self.parallel_config.tensor_parallel_size == 0, \ "Driver worker should be rank 0 of tensor parallel group." if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.observability_config = observability_config # Return hidden states from target model if the draft model is an # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config speculative_args = {} if speculative_config is None \ or (speculative_config.draft_model_config.model == model_config.model) \ @@ -97,17 +80,9 @@ def __init__( elif self._is_encoder_decoder_model(): ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=load_config, - lora_config=self.lora_config, + vllm_config=self.vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - observability_config=observability_config, **speculative_args, ) # Uninitialized cache engine. Will be initialized by diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6ba4f272315ce..cf8a4946a71c4 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -7,7 +7,7 @@ import torch -from vllm.config import ObservabilityConfig +from vllm.config import ObservabilityConfig, VllmConfig from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -29,6 +29,22 @@ class WorkerBase(ABC): communicate request metadata to other workers. 
""" + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + @abstractmethod def init_device(self) -> None: """Initialize device state, such as loading the model or other on-device diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 75a6de3b24ba4..bae8b469767b2 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,6 +1,7 @@ import dataclasses import time import weakref +from collections import defaultdict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar) @@ -9,9 +10,7 @@ import torch.nn as nn from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -19,7 +18,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalRegistry) + MultiModalInputs, MultiModalPlaceholderMap, + MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad @@ -161,6 +161,9 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -179,7 +182,21 @@ def _prepare_prompt( # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) + positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) + + if seq_group_metadata.multi_modal_data: + # NOTE: mm_data only includes the subset of multi-modal items + # that intersect with the current prefill positions. 
+ mm_data, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) + + mm_kwargs = self.runner.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) + + for modality, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[modality].extend( + placeholder_map) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -220,6 +237,11 @@ def _prepare_prompt( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + multi_modal_placeholder_maps.items() + } max_seqlen = max(seq_lens) tmp = [0] @@ -230,6 +252,7 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, @@ -313,6 +336,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens=seq_lens, seqlen_q=torch.tensor([]), max_seqlen=0, @@ -337,33 +361,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config + + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + model_config = self.model_config + cache_config = self.cache_config self.is_driver_worker = is_driver_worker - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config - if self.observability_config is not None: - print(f"observability_config is {self.observability_config}") self.return_hidden_states = return_hidden_states self.device = self.device_config.device @@ -396,15 +405,7 @@ def __init__( def load_model(self) -> None: with DeviceMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config, - ) + self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -450,7 +451,7 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = self.input_registry \ + dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) @@ -458,12 +459,12 @@ 
def profile_run(self) -> None: seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, lora_request=None, - multi_modal_data=dummy_multi_modal_data, - ) + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders) seqs.append(seq) # Run the model with the dummy inputs. diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index c1d836bb0d318..1295666055b04 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -8,10 +8,7 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -19,7 +16,7 @@ from vllm.platforms import current_platform from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.worker.xpu_model_runner import XPUModelRunner logger = init_logger(__name__) @@ -36,53 +33,32 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, - observability_config: Optional[ObservabilityConfig] = None, ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + device_config = self.device_config + parallel_config = self.parallel_config assert device_config.device_type == "xpu" assert current_platform.is_xpu() - self.model_config = model_config - self.parallel_config = parallel_config self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker - self.observability_config = observability_config if parallel_config and is_driver_worker: assert rank % parallel_config.tensor_parallel_size == 0, \ "Driver worker should be rank 0 of tensor parallel group." self.model_runner = XPUModelRunner( # type: ignore - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=self.load_config, - lora_config=self.lora_config, + vllm_config=vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - observability_config=self.observability_config, ) # Uninitialized cache engine. Will be initialized by # initialize_cache.
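
The hunks above for `vllm/worker/worker.py`, `vllm/worker/tpu_worker.py`, `vllm/worker/xpu_worker.py`, and `vllm/worker/worker_base.py` all converge on one pattern: each worker now accepts a single `VllmConfig` and lets `WorkerBase.__init__` fan it out into the per-component config attributes the rest of the code already reads. The sketch below is a self-contained, hedged illustration of that pattern only; the `_ModelConfig`, `_ParallelConfig`, `_VllmConfig`, `_WorkerBase`, and `_Worker` names are simplified stand-ins, not the real vLLM types.

```python
# Minimal sketch of the "single bundled config" constructor pattern in this diff.
# All classes here are illustrative stand-ins, not vLLM's actual implementations.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _ModelConfig:
    model: str = "facebook/opt-125m"
    trust_remote_code: bool = False


@dataclass
class _ParallelConfig:
    tensor_parallel_size: int = 1
    rank: int = 0


@dataclass
class _VllmConfig:
    # One top-level object that bundles every per-component config.
    model_config: _ModelConfig = field(default_factory=_ModelConfig)
    parallel_config: _ParallelConfig = field(default_factory=_ParallelConfig)
    lora_config: Optional[object] = None
    speculative_config: Optional[object] = None


class _WorkerBase:
    def __init__(self, vllm_config: _VllmConfig) -> None:
        # Fan the bundled config out into the attributes downstream code reads,
        # so subclasses no longer need a long list of constructor parameters.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.lora_config = vllm_config.lora_config
        self.speculative_config = vllm_config.speculative_config


class _Worker(_WorkerBase):
    def __init__(self, vllm_config: _VllmConfig, rank: int,
                 is_driver_worker: bool = False) -> None:
        _WorkerBase.__init__(self, vllm_config)
        # Same invariant the real diff keeps after dropping the explicit
        # parallel_config parameter: read it from the base-class attribute.
        self.parallel_config.rank = rank
        self.rank = rank
        self.is_driver_worker = is_driver_worker
        if is_driver_worker:
            assert rank % self.parallel_config.tensor_parallel_size == 0, \
                "Driver worker should be rank 0 of tensor parallel group."


worker = _Worker(_VllmConfig(), rank=0, is_driver_worker=True)
print(worker.model_config.model)
```

The payoff of this shape is visible in the removed lines above: every constructor loses roughly ten positional config parameters, and model runners can be built with just `vllm_config=...` plus the worker-specific flags.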
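
The `xpu_model_runner.py` hunks also thread multi-modal placeholder maps through `_prepare_prompt`: per-sequence maps are merged into a `defaultdict` keyed by modality, converted to index maps, and attached to the attention metadata (with `None` on the decode path). The sketch below mirrors only that flow; `_PlaceholderMap` is a simplified stand-in for `vllm.multimodal.MultiModalPlaceholderMap`, and its `index_map()` return type is deliberately reduced to a plain list.

```python
# Sketch of the per-modality placeholder-map aggregation added to _prepare_prompt.
# _PlaceholderMap is a simplified stand-in; only the shape of the flow
# (defaultdict -> extend per sequence -> index_map per modality) matches the diff.
from collections import defaultdict
from typing import Dict, List, Optional


class _PlaceholderMap:
    """Simplified stand-in for vllm.multimodal.MultiModalPlaceholderMap."""

    def __init__(self, positions: Optional[List[int]] = None) -> None:
        self._positions: List[int] = list(positions or [])

    def extend(self, other: "_PlaceholderMap") -> None:
        # Append the other map's placeholder positions after our own,
        # mirroring how per-sequence maps are merged for one modality.
        self._positions.extend(other._positions)

    def index_map(self) -> List[int]:
        return list(self._positions)


def _aggregate(
        per_seq_maps: List[Dict[str, _PlaceholderMap]]) -> Dict[str, List[int]]:
    merged: Dict[str, _PlaceholderMap] = defaultdict(_PlaceholderMap)
    for placeholder_maps in per_seq_maps:
        for modality, placeholder_map in placeholder_maps.items():
            merged[modality].extend(placeholder_map)
    # The attention metadata receives index maps keyed by modality;
    # the decode hunk instead passes multi_modal_placeholder_index_maps=None.
    return {modality: m.index_map() for modality, m in merged.items()}


print(_aggregate([{"image": _PlaceholderMap([3, 4, 5])},
                  {"image": _PlaceholderMap([10, 11])}]))
# {'image': [3, 4, 5, 10, 11]}
```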
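
Finally, the `profile_run` hunk swaps the old `(seq_data, dummy_multi_modal_data)` tuple for a single dummy-data object whose `seq_data`, `multi_modal_data`, and `multi_modal_placeholders` fields are forwarded into the dummy `SequenceGroupMetadata`. The shape below is inferred from the attributes read in the diff; the `_DummyData` class and `_build_profiling_group` helper are illustrative, not vLLM's actual registry API.

```python
# Sketch of the dummy-data bundle consumed by profile_run after this change.
# Field names mirror the attributes read in the diff; the class is illustrative.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class _DummyData:
    seq_data: Any
    multi_modal_data: Optional[Dict[str, Any]] = None
    multi_modal_placeholders: Dict[str, List[Any]] = field(default_factory=dict)


def _build_profiling_group(group_id: int, dummy_data: _DummyData) -> Dict[str, Any]:
    # profile_run now forwards all three fields instead of unpacking a tuple,
    # so placeholder ranges survive into the dummy sequence-group metadata.
    return {
        "request_id": str(group_id),
        "is_prompt": True,
        "seq_data": {group_id: dummy_data.seq_data},
        "multi_modal_data": dummy_data.multi_modal_data,
        "multi_modal_placeholders": dummy_data.multi_modal_placeholders,
    }
```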