diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a21acd9671eeb..4f54eea564ecb 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -14,6 +14,15 @@ updates: reviewers: ["khluu", "simon-mo"] allow: - dependency-type: "all" + ignore: + - dependency-name: "torch" + - dependency-name: "torchvision" + - dependency-name: "xformers" + - dependency-name: "lm-format-enforcer" + - dependency-name: "gguf" + - dependency-name: "compressed-tensors" + - dependency-name: "ray[adag]" + - dependency-name: "lm-eval" groups: patch-update: applies-to: version-updates diff --git a/Dockerfile b/Dockerfile index 0a562253c537b..343364da2ebf5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -206,7 +206,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' bitsandbytes>=0.44.0 timm==0.9.10 + pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.44.0' timm==0.9.10 ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 0d0d8df94578c..2143315d2a078 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -31,7 +31,7 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi RUN python3 -m pip install -U \ - cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ -r requirements-neuron.txt ENV VLLM_TARGET_DEVICE neuron diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index cd5fcf481f07c..b19c6ddec7948 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -21,7 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \ # These packages will be in rocketce eventually RUN --mount=type=cache,target=/root/.cache/pip \ pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ - cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ torch==2.3.1 \ -r requirements-cpu.txt \ xformers uvloop==0.20.0 diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 562117a313020..8fb79afaebe97 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -52,7 +52,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip uninstall -y torch torchvision \ && python3 -m pip install --pre \ torch==2.6.0.dev20240918 \ - setuptools-scm>=8 \ + 'setuptools-scm>=8' \ torchvision==0.20.0.dev20240918 \ --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \ *) ;; esac diff --git a/Dockerfile.tpu b/Dockerfile.tpu index dd8f9ad4714a9..b43442e4c0af1 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -25,7 +25,7 @@ ENV VLLM_TARGET_DEVICE="tpu" RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 -m pip install \ - cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ -r requirements-tpu.txt RUN python3 setup.py develop diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 0d205014b15bf..ff06622628219 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -406,9 +406,9 @@ def calculate_metrics( median_itl_ms=np.median(itls or 0) * 1000, percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles], - mean_e2el_ms=np.median(e2els or 0) * 1000, + 
mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, - median_e2el_ms=np.mean(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles], ) diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index edba209986f6a..f0c812b941c1f 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -1,35 +1,167 @@ .. _installation_tpu: +##################### Installation with TPU -===================== +##################### -vLLM supports Google Cloud TPUs using PyTorch XLA. +Tensor Processing Units (TPUs) are Google's custom-developed application-specific +integrated circuits (ASICs) used to accelerate machine learning workloads. TPUs +are available in different versions, each with different hardware specifications. +For more information about TPUs, see `TPU System Architecture `_. +For more information on the TPU versions supported with vLLM, see: + +* `TPU v6e `_ +* `TPU v5e `_ +* `TPU v5p `_ +* `TPU v4 `_ + +These TPU versions allow you to configure the physical arrangements of the TPU +chips. This can improve throughput and networking performance. For more +information, see: + +* `TPU v6e topologies `_ +* `TPU v5e topologies `_ +* `TPU v5p topologies `_ +* `TPU v4 topologies `_ + +In order to use Cloud TPUs, you need to have TPU quota granted to your +Google Cloud Platform project. TPU quotas specify how many TPUs you can use in a +GCP project and are specified in terms of TPU version, the number of TPUs you +want to use, and quota type. For more information, see `TPU quota `_. + +For TPU pricing information, see `Cloud TPU pricing `_. + +You may need additional persistent storage for your TPU VMs. For more +information, see `Storage options for Cloud TPU data `_. Requirements ------------ -* Google Cloud TPU VM (single & multi host) -* TPU versions: v5e, v5p, v4 -* Python: 3.10 +* Google Cloud TPU VM +* TPU versions: v6e, v5e, v5p, v4 +* Python: 3.10 or newer + +Provision Cloud TPUs +==================== + +You can provision Cloud TPUs using the `Cloud TPU API `_ +or the `queued resources `_ +API. This section shows how to create TPUs using the queued resource API. +For more information about using the Cloud TPU API, see `Create a Cloud TPU using the Create Node API `_. +`Queued resources `_ +enable you to request Cloud TPU resources in a queued manner. When you request +queued resources, the request is added to a queue maintained by the Cloud TPU +service. When the requested resource becomes available, it's assigned to your +Google Cloud project for your immediate exclusive use. + +Provision a Cloud TPU with the queued resource API +-------------------------------------------------- +Create a TPU v5e with 4 TPU chips: + +.. code-block:: console + + gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ + --node-id TPU_NAME \ + --project PROJECT_ID \ + --zone ZONE \ + --accelerator-type ACCELERATOR_TYPE \ + --runtime-version RUNTIME_VERSION \ + --service-account SERVICE_ACCOUNT + +.. list-table:: Parameter descriptions + :header-rows: 1 + + * - Parameter name + - Description + * - QUEUED_RESOURCE_ID + - The user-assigned ID of the queued resource request. + * - TPU_NAME + - The user-assigned name of the TPU which is created when the queued + resource request is allocated. 
+ * - PROJECT_ID + - Your Google Cloud project + * - ZONE + - The `zone `_ where you + want to create your Cloud TPU. + * - ACCELERATOR_TYPE + - The TPU version you want to use. Specify the TPU version, followed by a + '-' and the number of TPU cores. For example `v5e-4` specifies a v5e TPU + with 4 cores. For more information, see `TPU versions `_. + * - RUNTIME_VERSION + - The TPU VM runtime version to use. For more information see `TPU VM images `_. + * - SERVICE_ACCOUNT + - The email address for your service account. You can find it in the IAM + Cloud Console under *Service Accounts*. For example: + `tpu-service-account@.iam.gserviceaccount.com` + +Connect to your TPU using SSH: + +.. code-block:: bash + + gcloud compute tpus tpu-vm ssh TPU_NAME + +Create and activate a Conda environment for vLLM: + +.. code-block:: bash -Installation options: + conda create -n vllm python=3.10 -y + conda activate vllm -1. :ref:`Build a docker image with Dockerfile `. -2. :ref:`Build from source `. +Clone the vLLM repository and go to the vLLM directory: + +.. code-block:: bash + + git clone https://github.com/vllm-project/vllm.git && cd vllm + +Uninstall the existing `torch` and `torch_xla` packages: + +.. code-block:: bash + + pip uninstall torch torch-xla -y + +Install `torch` and `torch_xla` + +.. code-block:: bash + + pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu + pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html + +Install JAX and Pallas: + +.. code-block:: bash + + pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + +Install other build dependencies: + +.. code-block:: bash + + pip install -r requirements-tpu.txt + VLLM_TARGET_DEVICE="tpu" python setup.py develop + sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev + +Provision Cloud TPUs with GKE +----------------------------- + +For more information about using TPUs with GKE, see +https://cloud.google.com/kubernetes-engine/docs/how-to/tpus +https://cloud.google.com/kubernetes-engine/docs/concepts/tpus +https://cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus .. _build_docker_tpu: Build a docker image with :code:`Dockerfile.tpu` ------------------------------------------------ -`Dockerfile.tpu `_ is provided to build a docker image with TPU support. +You can use `Dockerfile.tpu `_ +to build a Docker image with TPU support. .. code-block:: console $ docker build -f Dockerfile.tpu -t vllm-tpu . - -You can run the docker image with the following command: +Run the Docker image with the following command: .. code-block:: console @@ -75,14 +207,12 @@ Next, build vLLM from source. This will only take a few seconds: $ VLLM_TARGET_DEVICE="tpu" python setup.py develop - .. note:: Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape. The compilation time may take 20~30 minutes in the first run. 
However, the compilation time reduces to ~5 minutes afterwards because the XLA graphs are cached in the disk (in :code:`VLLM_XLA_CACHE_PATH` or :code:`~/.cache/vllm/xla_cache` by default). - .. tip:: If you encounter the following error: @@ -93,7 +223,7 @@ Next, build vLLM from source. This will only take a few seconds: ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory - Please install OpenBLAS with the following command: + Install OpenBLAS with the following command: .. code-block:: console diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 80714a90df5c2..55835d945b00c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -160,13 +160,13 @@ Text Generation - - ✅︎ * - :code:`GraniteForCausalLM` - - PowerLM - - :code:`ibm/PowerLM-3b` etc. + - Granite 3.0, PowerLM + - :code:`ibm-granite/granite-3.0-2b-base`, :code:`ibm-granite/granite-3.0-8b-instruct`, :code:`ibm/PowerLM-3b`, etc. - ✅︎ - ✅︎ * - :code:`GraniteMoeForCausalLM` - - PowerMoE - - :code:`ibm/PowerMoE-3b` etc. + - Granite 3.0 MoE, PowerMoE + - :code:`ibm-granite/granite-3.0-1b-a400m-base`, :code:`ibm-granite/granite-3.0-3b-a800m-instruct`, :code:`ibm/PowerMoE-3b`, etc. - ✅︎ - ✅︎ * - :code:`InternLMForCausalLM` @@ -440,6 +440,12 @@ Text Generation - :code:`THUDM/glm-4v-9b` etc. - - ✅︎ + * - :code:`H2OVLChatModel` + - H2OVL + - T + I\ :sup:`E+` + - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. + - + - ✅︎ * - :code:`InternVLChatModel` - InternVL2 - T + I\ :sup:`E+` diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 37ec667d96a77..050b791b62adb 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int): tokenize=False, add_generation_prompt=True) - llm = LLM(model=model_name, - enforce_eager=True, - enable_chunked_prefill=False, - max_model_len=8192, - limit_mm_per_prompt={"audio": audio_count}) + llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count}) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 60cdb186331fe..4fd002caf1763 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -176,6 +176,31 @@ def run_minicpmv(question: str, modality: str): return llm, prompt, stop_token_ids +# H2OVL-Mississippi +def run_h2ovl(question: str, modality: str): + assert modality == "image" + + model_name = "h2oai/h2ovl-mississippi-2b" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + messages = [{'role': 'user', 'content': f"\n{question}"}] + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for H2OVL-Mississippi + # https://huggingface.co/h2oai/h2ovl-mississippi-2b + stop_token_ids = [tokenizer.eos_token_id] + return llm, prompt, stop_token_ids + + # InternVL def run_internvl(question: str, modality: str): assert modality == "image" @@ -363,6 +388,7 @@ def run_glm4v(question: str, modality: str): "chameleon": run_chameleon, "minicpmv": run_minicpmv, "blip-2": run_blip2, + "h2ovl_chat": run_h2ovl, "internvl_chat": run_internvl, 
"NVLM_D": run_nvlm_d, "qwen_vl": run_qwen_vl, @@ -475,4 +501,4 @@ def main(args): default=16, help='Number of frames to extract from the video.') args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index e28514bf403f7..d99684078ff3d 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -107,6 +107,40 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: ) +def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData: + model_name = "h2oai/h2ovl-mississippi-2b" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"max_dynamic_patch": 4}, + ) + + placeholders = "\n".join(f"Image-{i}: \n" + for i, _ in enumerate(image_urls, start=1)) + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for H2OVL-Mississippi + # https://huggingface.co/h2oai/h2ovl-mississippi-2b + stop_token_ids = [tokenizer.eos_token_id] + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "OpenGVLab/InternVL2-2B" @@ -258,6 +292,7 @@ def load_mllama(question, image_urls: List[str]) -> ModelRequestData: model_example_map = { "phi3_v": load_phi3v, + "h2ovl_chat": load_h2onvl, "internvl_chat": load_internvl, "NVLM_D": load_nvlm_d, "qwen2_vl": load_qwen2_vl, diff --git a/requirements-lint.txt b/requirements-lint.txt index 07f738873e1a8..f9132bbf96437 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -1,7 +1,7 @@ # formatting yapf==0.32.0 toml==0.10.2 -tomli==2.0.1 +tomli==2.0.2 ruff==0.6.5 codespell==2.3.0 isort==5.13.2 diff --git a/requirements-test.in b/requirements-test.in index 3881f2566b556..5d44664c082a6 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -32,6 +32,6 @@ aiohttp # quantization bitsandbytes>=0.44.0 -buildkite-test-collector==0.1.8 +buildkite-test-collector==0.1.9 numpy < 2.0.0 diff --git a/requirements-test.txt b/requirements-test.txt index c474c2ec34b22..7477b7c3a79cd 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -36,20 +36,20 @@ attrs==24.2.0 # referencing audioread==3.0.1 # via librosa -awscli==1.35.16 +awscli==1.35.19 # via -r requirements-test.in bitsandbytes==0.44.1 # via -r requirements-test.in black==24.10.0 # via datamodel-code-generator -boto3==1.35.50 +boto3==1.35.53 # via tensorizer -botocore==1.35.50 +botocore==1.35.53 # via # awscli # boto3 # s3transfer -buildkite-test-collector==0.1.8 +buildkite-test-collector==0.1.9 # via -r requirements-test.in certifi==2024.8.30 # via @@ -426,7 +426,7 @@ requests==2.32.3 # transformers rouge-score==0.1.2 # via lm-eval -rpds-py==0.20.0 +rpds-py==0.20.1 # via # jsonschema # referencing @@ -552,7 +552,7 @@ xxhash==3.5.0 # via # datasets # evaluate -yarl==1.17.0 +yarl==1.17.1 # via aiohttp zstandard==0.23.0 # via lm-eval diff --git a/tests/encoder_decoder/test_e2e_correctness.py 
b/tests/encoder_decoder/test_e2e_correctness.py index bef0c515b9073..f2d7e9fd78cf3 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -7,12 +7,18 @@ import pytest from transformers import AutoModelForSeq2SeqLM +from vllm.attention.selector import (_Backend, + global_force_attn_backend_context_manager) from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs from ..conftest import DecoderPromptType from ..models.utils import check_logprobs_close +LIST_ENC_DEC_SUPPORTED_BACKENDS = [ + _Backend.XFORMERS, _Backend.FLASH_ATTN, None +] + def vllm_to_hf_output( vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], @@ -29,7 +35,8 @@ def vllm_to_hf_output( @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @@ -48,6 +55,7 @@ def test_encoder_decoder_e2e( num_logprobs: int, decoder_prompt_type: DecoderPromptType, enforce_eager: bool, + attn_backend: _Backend, ) -> None: ''' End-to-End (E2E) test for the encoder-decoder framework. @@ -56,43 +64,49 @@ def test_encoder_decoder_e2e( implementations to ensure that both implementations produce consistent and correct results. ''' - test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type] + with global_force_attn_backend_context_manager(attn_backend): + if attn_backend == _Backend.FLASH_ATTN: + # Flash Attention works only with bfloat16 data-type + dtype = 'bfloat16' + test_case_prompts = example_encoder_decoder_prompts[ + decoder_prompt_type] - # Configuration settings for HF baseline - hf_kwargs = { - "top_k": None, - "num_beams": 1, - "repetition_penalty": 1.0, - "top_p": 1.0, - "length_penalty": 1.0, - "early_stopping": False, - "no_repeat_ngram_size": None, - "min_length": 0 - } + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - test_case_prompts, - max_tokens, - num_logprobs, - **hf_kwargs, - )) - with vllm_runner(model, dtype=dtype, - enforce_eager=enforce_eager) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - test_case_prompts, max_tokens, num_logprobs) + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForSeq2SeqLM) as hf_model: + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + with vllm_runner(model, dtype=dtype, + enforce_eager=enforce_eager) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_case_prompts, max_tokens, num_logprobs) - hf_skip_tokens = (1 - if decoder_prompt_type == DecoderPromptType.NONE else 0) + hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE + else 0) - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, decoder_prompt_type) - for vllm_output in vllm_outputs - ], - name_0="hf", - 
name_1="vllm", - num_outputs_0_skip_tokens=hf_skip_tokens, - ) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, decoder_prompt_type) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens, + ) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 5b8311a33c361..e2b4778b94b9e 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -258,19 +258,20 @@ def test_reshape_and_cache_flash( del key_caches del value_caches + k_scale = key.amax().item() / 256 + v_scale = value.amax().item() / 256 + # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache) + ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache) + ops.convert_fp8(cloned_value_cache, value_cache, v_scale, + kv_cache_dtype) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() - # Using default kv_scale - k_scale = v_scale = 1.0 - # Call the reshape_and_cache kernel. opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, @@ -281,9 +282,15 @@ def test_reshape_and_cache_flash( if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(result_key_cache, key_cache) + ops.convert_fp8(result_key_cache, + key_cache, + k_scale, + kv_dtype=kv_cache_dtype) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(result_value_cache, value_cache) + ops.convert_fp8(result_value_cache, + value_cache, + v_scale, + kv_dtype=kv_cache_dtype) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor") diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index bc99c5559d388..a1dd5eeeaa398 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,13 +16,13 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.attention.selector import (_Backend, +from vllm.attention.selector import (_Backend, get_attn_backend, global_force_attn_backend_context_manager) +from vllm.forward_context import set_forward_context from vllm.platforms import current_platform # List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] - +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -145,7 +145,8 @@ class that Attention will automatically select when it is constructed. test_pt.num_heads, test_pt.head_size, test_pt.block_size, - device=CUDA_DEVICE) + device=CUDA_DEVICE, + backend=test_pt.backend_name) return TestResources(scale, attn_backend, attn, kv_cache) @@ -592,6 +593,7 @@ def _run_encoder_attention_test( attn: Attention, encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run encoder attention. 
@@ -610,6 +612,8 @@ def _run_encoder_attention_test( (number_of_tokens x num_heads x head_size) query/key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed {query,key,value} and @@ -619,20 +623,31 @@ def _run_encoder_attention_test( attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, - packed_qkv.key, - packed_qkv.value, - torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. + reshaped_query = packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + packed_qkv.key, + packed_qkv.value, + torch.tensor([], + dtype=torch.float32, + device=packed_qkv.query.device), + attn_metadata, + attn_type=attn_type) def _run_decoder_self_attention_test( test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run decoder self-attention test. @@ -650,6 +665,8 @@ def _run_decoder_self_attention_test( query/key/value fields * attn_metadata: attention metadata for decoder-self attention (contains KV cache memory-mapping) + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache @@ -660,12 +677,22 @@ def _run_decoder_self_attention_test( kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, - packed_qkv.key, - packed_qkv.value, - kv_cache, - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. + reshaped_query = packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -673,6 +700,7 @@ def _run_encoder_decoder_cross_attention_test( decoder_test_params: PhaseTestParameters, cross_test_params: Optional[PhaseTestParameters], attn_metadata: AttentionMetadata, + test_pt: TestPoint, ) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. 
@@ -701,6 +729,8 @@ def _run_encoder_decoder_cross_attention_test( (number_of_tokens x num_heads x head_size) key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention + * test_pt: The TestPoint object containing test details like number of + model heads, head size, name of the backend being used etc. Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache @@ -718,12 +748,37 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, - key, - value, - kv_cache, - attn_metadata, - attn_type=attn_type) + with set_forward_context(attn_metadata): + # In the test setup the shape of the query is + # [batch_size, seq_len, num_heads, head_size]. However + # the attention backend expect the shape to be + # [num_tokens, hidden_size]. Hence reshape the query before + # invoking the forward method. + # TODO - Update the way we construct the query so that it + # is shaped as [num_tokens, hidden_size] and we can skip the reshape. + reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( + -1, test_pt.num_heads * test_pt.head_size) + return attn.forward(reshaped_query, + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) + + +@pytest.fixture(autouse=True) +def set_reset_environment(attn_backend): + # Set the default torch datatype to bfloat16 to enable + # testing of the Flash Attention backend. Also clear the + # cached value of the backend. + default_dtype = torch.get_default_dtype() + if attn_backend.name == 'FLASH_ATTN': + torch.set_default_dtype(torch.bfloat16) + get_attn_backend.cache_clear() + yield + # Reset the torch datatype to what it was before the test + # so as not to impact the remaining tests. + torch.set_default_dtype(default_dtype) @pytest.mark.skipif(current_platform.is_rocm(), @@ -773,10 +828,8 @@ def test_encoder_only( * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences ''' - # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test @@ -807,10 +860,14 @@ def test_encoder_only( # PREFILL: encoder attention enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( - test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + test_pt=test_pt)) # - Is encoder attention result correct? 
- assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, + attn_backend.name) @pytest.mark.skipif(current_platform.is_rocm(), @@ -892,10 +949,8 @@ def test_e2e_enc_dec_attn( * max_dec_seq_len: max length of decoder input sequences * max_enc_seq_len: max length of encoder input sequences ''' - # Force Attention wrapper backend with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test @@ -955,29 +1010,39 @@ def test_e2e_enc_dec_attn( enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_test_params, - prephase_attn_metadata) + prephase_attn_metadata, + test_pt=test_pt) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, + attn_backend.name) # PREFILL: decoder self-attention test prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) + test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, - prephase_dec_pckd_act_out) + prephase_dec_pckd_act_out, + attn_backend.name) # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - prephase_attn_metadata) + test_rsrcs, + prephase_dec_test_params, + prephase_cross_test_params, + prephase_attn_metadata, + test_pt=test_pt) # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out) + prephase_cross_pckd_act_out, + attn_backend.name) # DECODE: build decode-phase attention metadata @@ -993,17 +1058,26 @@ def test_e2e_enc_dec_attn( # DECODE: decoder self-attention test decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) + test_rsrcs, + decphase_dec_test_params, + decphase_attn_metadata, + test_pt=test_pt) # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out) + decphase_dec_pckd_act_out, + attn_backend.name) # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) + test_rsrcs, + decphase_dec_test_params, + None, + decphase_attn_metadata, + test_pt=test_pt) # - Is decode-phase encoder/decoder cross-attention correct? 
assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) + decphase_cross_pckd_act_out, + attn_backend.name) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index a2d414f636e13..e7865fb2500ef 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,8 +13,8 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, - make_tensor_with_pad) +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, + STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -525,17 +525,22 @@ def make_backend(backend_name: str) -> AttentionBackend: if backend_name == STR_XFORMERS_ATTN_VAL: # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. from vllm.attention.backends.xformers import XFormersBackend - return XFormersBackend() + elif backend_name == STR_FLASH_ATTN_VAL: + from vllm.attention.backends.flash_attn import FlashAttentionBackend + return FlashAttentionBackend() + raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") def _make_metadata_tensors( - seq_lens: Optional[List[int]], context_lens: Optional[List[int]], - encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] -) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], - torch.Tensor, Optional[int]]: + seq_lens: Optional[List[int]], + context_lens: Optional[List[int]], + encoder_seq_lens: Optional[List[int]], + device: Union[torch.device, str], +) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], + torch.Tensor, torch.Tensor, Optional[int]]: ''' Build scalar & tensor values required to build attention metadata structure. @@ -553,6 +558,8 @@ def _make_metadata_tensors( * max_context_len: max(context_lens) * max_seq_len: max(seq_lens) * seq_start_loc: start idx of each sequence + * encoder_seq_lens_tensor: encoder seq_lens list, as tensor + * encoder_seq_start_loc: start idx of each encoder sequence * max_encoder_seq_len: encoder seq_lens list, as tensor ''' seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) @@ -566,8 +573,26 @@ def _make_metadata_tensors( seq_start_loc = None + if seq_lens_tensor is not None: + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=seq_lens_tensor.device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=encoder_seq_lens_tensor.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, - seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len) + seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, + max_encoder_seq_len) def make_kv_cache(num_blocks: int, @@ -575,6 +600,7 @@ def make_kv_cache(num_blocks: int, head_size: int, block_size: int, device: Union[torch.device, str], + backend: str, default_val: float = 0.0) -> torch.Tensor: ''' Create a fake KV cache. 
@@ -591,10 +617,20 @@ def make_kv_cache(num_blocks: int, Returns: * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + * for backend 'XFORMERS' + * kv_cache: 2 x num_blocks x block_size x num_heads x head_size + * for backend 'FLASH_ATTN' ''' - - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) + if backend == 'XFORMERS': + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) + elif backend == 'FLASH_ATTN': + kv_cache = torch.rand( + (2, num_blocks, block_size, num_heads, head_size)).to(device) + else: + raise ValueError( + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " + f"'FLASH_ATTN'.") if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache @@ -858,8 +894,9 @@ def make_test_metadata( context_lens_tensor, _, _, - _, + seq_start_loc, encoder_seq_lens_tensor, + encoder_seq_start_loc, max_encoder_seq_len, ) = _make_metadata_tensors(seq_lens, context_lens, @@ -869,10 +906,12 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), + multi_modal_placeholder_index_maps=None, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + seq_start_loc=seq_start_loc, max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, @@ -881,6 +920,7 @@ def make_test_metadata( num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, + encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), @@ -903,8 +943,9 @@ def make_test_metadata( context_lens_tensor, _, _, - _, + seq_start_loc, encoder_seq_lens_tensor, + encoder_seq_start_loc, max_encoder_seq_len, ) = _make_metadata_tensors(seq_lens, context_lens, @@ -914,18 +955,22 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, + multi_modal_placeholder_index_maps=None, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, + seq_start_loc=seq_start_loc, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), + max_decode_query_len=1, context_lens_tensor=context_lens_tensor, block_tables=kv_mmap.block_tables, use_cuda_graph=False, num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, + encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=(None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping), @@ -934,7 +979,8 @@ def make_test_metadata( def assert_actual_matches_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor) -> None: + output_under_test: torch.Tensor, + backend: str) -> None: ''' Assert that observed output matches the ideal output contained in the test parameters data structure. 
@@ -945,8 +991,22 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, * output_under_test: actually observed output value ''' ideal_output = test_params.packed_qkvo.ideal_output - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output)) + if backend == 'XFORMERS': + torch.testing.assert_close(ideal_output, + output_under_test.view_as(ideal_output)) + + elif backend == 'FLASH_ATTN': + # For FlashAttention override the accuracy thresholds to non default + # values since we notice a higher difference between the ideal and + # actual output. + torch.testing.assert_close(ideal_output, + output_under_test.view_as(ideal_output), + atol=0.01, + rtol=0.016) + else: + raise ValueError( + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " + f"'FLASH_ATTN'.") # Copied/modified from torch._refs.__init__.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index e40f0dd74602e..816d3986fe333 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings(): cleanup_dist_env_and_memory(shutdown_ray=True) get_model_old = get_model - def get_model_patched(*, model_config, device_config, **kwargs): - kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8) - return get_model_old(model_config=model_config, - device_config=device_config, - **kwargs) + def get_model_patched(**kwargs): + kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4, + max_lora_rank=8) + return get_model_old(**kwargs) with patch("vllm.worker.model_runner.get_model", get_model_patched): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index c8edb02a88d4b..eada902c891f7 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -138,13 +138,7 @@ def test_rotary_emb_replaced(dist_init): enable_lora=True) engine_config = engine_args.create_engine_config() model_runner = ModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, - lora_config=engine_config.lora_config, + vllm_config=engine_config, is_driver_worker=True, ) model_runner.load_model() diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 2f7ac85507425..9d814f657ac43 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -4,7 +4,8 @@ from unittest.mock import patch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig) + ModelConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -12,7 +13,7 @@ @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - worker = Worker( + vllm_config = VllmConfig( model_config=ModelConfig( "meta-llama/Llama-2-7b-hf", task="auto", @@ -34,10 +35,13 @@ def test_worker_apply_lora(sql_lora_files): gpu_memory_utilization=1., swap_space=0, cache_dtype="auto"), - local_rank=0, - rank=0, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, max_loras=32), + ) + worker = Worker( + vllm_config=vllm_config, + local_rank=0, + rank=0, distributed_init_method=f"file://{tempfile.mkstemp()[1]}", ) worker.init_device() diff --git 
a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index b9089e75ffab8..d14e88b4e5b26 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -2,8 +2,10 @@ import numpy as np import pytest +import pytest_asyncio from transformers import AutoModel, AutoTokenizer, BatchEncoding +from tests.utils import RemoteOpenAIServer from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -17,6 +19,13 @@ VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" HF_PLACEHOLDER = "<|audio|>" +CHUNKED_PREFILL_KWARGS = { + "enable_chunked_prefill": True, + "max_num_seqs": 2, + # Use a very small limit to exercise chunked prefill. + "max_num_batched_tokens": 16 +} + @pytest.fixture(scope="session") def audio_assets(): @@ -30,6 +39,26 @@ def audio(request): return AudioAsset(request.param) +@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS)) +def server(request, audio_assets): + args = [ + "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager", + f"--limit-mm-per-prompt=audio={len(audio_assets)}" + ] + [ + f"--{key.replace('_','-')}={value}" + for key, value in request.param.items() + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + def _get_prompt(audio_count, question, placeholder): tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) placeholder = f"{placeholder}\n" * audio_count @@ -68,8 +97,7 @@ def run_test( dtype: str, max_tokens: int, num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + **kwargs, ): """Inference result should be the same between hf and vllm.""" torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -79,11 +107,8 @@ def run_test( # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). 
- with vllm_runner(model, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + with vllm_runner(model, dtype=dtype, enforce_eager=True, + **kwargs) as vllm_model: vllm_outputs_per_audio = [ vllm_model.generate_greedy_logprobs([vllm_prompt], max_tokens, @@ -135,18 +160,16 @@ def run_multi_audio_test( dtype: str, max_tokens: int, num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, + **kwargs, ): with vllm_runner(model, dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, enforce_eager=True, limit_mm_per_prompt={ "audio": max((len(audio) for _, audio in prompts_and_audios)) - }) as vllm_model: + }, + **kwargs) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( [prompt for prompt, _ in prompts_and_audios], max_tokens, @@ -162,8 +185,9 @@ def run_multi_audio_test( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS]) def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, - num_logprobs: int) -> None: + num_logprobs: int, vllm_kwargs: dict) -> None: vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER) hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER) @@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, - tensor_parallel_size=1, + **vllm_kwargs, ) @@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS]) def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: + max_tokens: int, num_logprobs: int, + vllm_kwargs: dict) -> None: vllm_prompt = _get_prompt(len(audio_assets), "Describe each of the audios above.", @@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, - tensor_parallel_size=1, + **vllm_kwargs, ) + + +@pytest.mark.asyncio +async def test_online_inference(client, audio_assets): + """Exercises online inference with/without chunked prefill enabled.""" + + messages = [{ + "role": + "user", + "content": [ + *[{ + "type": "audio_url", + "audio_url": { + "url": audio.url + } + } for audio in audio_assets], + { + "type": + "text", + "text": + f"What's happening in these {len(audio_assets)} audio clips?" 
+ }, + ], + }] + + chat_completion = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=10) + + assert len(chat_completion.choices) == 1 + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" diff --git a/tests/models/decoder_only/vision_language/test_h2ovl.py b/tests/models/decoder_only/vision_language/test_h2ovl.py new file mode 100644 index 0000000000000..ad9aa3104750b --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_h2ovl.py @@ -0,0 +1,130 @@ +from typing import Optional, Tuple + +import pytest +import torch +from PIL.Image import Image +from transformers import AutoConfig + +# Import the functions to test +from vllm.model_executor.models.h2ovl import (calculate_num_blocks, + image_to_pixel_values_wrapper) +from vllm.multimodal.utils import rescale_image_size + +models = [ + "h2oai/h2ovl-mississippi-800m", # Replace with your actual model names + "h2oai/h2ovl-mississippi-2b", +] +target_dtype = "bfloat16" + + +def run_preprocessing_test( + image: Image, + config, + max_dynamic_patch: Optional[int] = None, +) -> Tuple[torch.Tensor, int]: + """Test the image preprocessing and calculate expected blocks.""" + + if max_dynamic_patch is None: + max_dynamic_patch = config.max_dynamic_patch + + width, height = image.size + use_MSAC = config.use_msac + + # Create the mapper function with the provided configuration + mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC) + pixel_values = mapper(image) + + # Calculate the expected number of blocks + if use_MSAC: + # First pass + blocks1, _, _, aspect_ratio = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, # Thumbnail is handled separately + prior_aspect_ratio=None, + ) + + # Second pass + blocks2, _, _, _ = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, + prior_aspect_ratio=aspect_ratio, + ) + + # Add thumbnail if use_thumbnail is True and total_blocks > 1 + if config.use_thumbnail: + blocks1 += 1 if blocks1 > 1 else 0 + blocks2 += 1 if blocks2 > 1 else 0 + + # Total blocks is the sum of blocks from both passes minus overlapping + total_blocks = blocks1 + blocks2 - 1 + + expected_blocks = total_blocks + + else: + blocks, _, _, _ = calculate_num_blocks( + width, + height, + config.min_dynamic_patch, + max_dynamic_patch, + config.vision_config.image_size, + use_thumbnail=False, + prior_aspect_ratio=None, + ) + expected_blocks = blocks + + if config.use_thumbnail and expected_blocks > 1: + expected_blocks += 1 + + return pixel_values, expected_blocks + + +@pytest.mark.parametrize("model_name", models) +@pytest.mark.parametrize( + "size_factors", + [ + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8]) +def test_image_preprocessing(image_assets, model_name, size_factors, + max_dynamic_patch): + """Test image preprocessing pipeline with different configurations.""" + # Load the configuration from the model + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + for asset in image_assets: + image = asset.pil_image + for factor in size_factors: + scaled_image = rescale_image_size(image, factor) + + # Test preprocessing and get expected number of blocks + pixel_values, expected_blocks = run_preprocessing_test( + 
scaled_image, config, max_dynamic_patch) + + # Verify output shapes and properties + actual_blocks = pixel_values.shape[0] + assert actual_blocks == expected_blocks, ( + f"Expected {expected_blocks} blocks, got {actual_blocks}") + + # Check image dimensions + expected_size = ( + 3, # Number of channels (C, H, W) + config.vision_config.image_size, + config.vision_config.image_size, + ) + for img in pixel_values: + assert img.shape == expected_size, ( + f"Expected image size {expected_size}, got {img.shape}") diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index e49ea6f98324d..cfd2d61f2b633 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -187,6 +187,23 @@ marks=[large_gpu_mark(min_gb=48)], patch_hf_runner=model_utils.glm_patch_hf_runner, ), + "h2ovl": VLMTestInfo( + models = [ + "h2oai/h2ovl-mississippi-800m", + "h2oai/h2ovl-mississippi-2b", + ], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts({ + "stop_sign": "\nWhat's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "\nWhat is the season?", + }), + multi_image_prompt="Image-1: \nImage-2: \nDescribe the two images in short.", # noqa: E501 + max_model_len=8192, + dtype="bfloat16", + use_tokenizer_eos=True, + patch_hf_runner=model_utils.h2ovl_patch_hf_runner, + ), "intern_vl": VLMTestInfo( models=[ "OpenGVLab/InternVL2-1B", diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index e925934db0e7c..849857b4232e7 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -259,6 +259,66 @@ def processor(*args, text="", images=None, **kwargs): return hf_model +def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for H2OVL.""" + + class H2OVLProcessor: + """A simple processor for H2OVL models.""" + + def __init__(self, hf_runner: HfRunner): + self.num_image_token = hf_runner.model.num_image_token + self.tokenizer = hf_runner.tokenizer + self.dtype = hf_runner.model.dtype + + self.config = AutoConfig.from_pretrained(hf_runner.model_name, + trust_remote_code=True) + self.vision_config = self.config.vision_config + self.use_thumbnail = self.config.use_thumbnail + self.min_num = self.config.min_dynamic_patch + self.max_num = self.config.max_dynamic_patch + self.image_size = self.vision_config.image_size + + def __call__(self, text: str, images: Union[Image, List[Image]], + **kwargs): + # yapf: disable + from vllm.model_executor.models.h2ovl import ( + IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values) + + # yapf: enable + images = [images] if isinstance(images, Image) else images + pixel_values = [ + image_to_pixel_values(image, + self.image_size, + self.min_num, + self.max_num, + self.use_thumbnail, + use_MSAC=self.config.use_msac).to( + self.dtype) for image in images + ] + num_patches_list = [ + pixel_value.shape[0] for pixel_value in pixel_values + ] + pixel_values = torch.cat(pixel_values, dim=0) + for num_patches in num_patches_list: + context_tokens = IMG_CONTEXT * self.num_image_token \ + * num_patches + image_tokens = IMG_START + context_tokens + 
IMG_END + text = text.replace('', image_tokens, 1) + prompt = self.tokenizer(text, return_tensors="pt") + prompt.update({"pixel_values": pixel_values}) + return prompt + + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( + "") + hf_model.model.img_context_token_id = img_context_token_id + hf_model.processor = H2OVLProcessor(hf_model) + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.language_model.get_output_embeddings() + hf_model.model.generate = types.MethodType(_internvl_generate, + hf_model.model) + return hf_model + + def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for InternVL.""" diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index 483773f069133..d686f1da3fa17 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -85,7 +85,7 @@ def run_test( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index 5044740c3e734..4d3bbd805c152 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -5,8 +5,8 @@ import pytest import torch -from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs -from vllm.inputs.registry import InputRegistry +from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext, + InputRegistry, token_inputs) from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -56,7 +56,7 @@ def custom_dummy_data_factory(self, num_crops=DEFAULT_NUM_CROPS): seq_data = SequenceData( array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) - return seq_data, None + return DummyData(seq_data, None) with patch( "vllm.inputs.registry.InputRegistry._default_dummy_data_factory", @@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the mm_processor_kwargs. - seq_data, _ = dummy_registry.dummy_data_for_profiling( + dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(seq_data.prompt_token_ids) == expected_seq_count + assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count @pytest.mark.parametrize( @@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, # NOTE: seq_len is thrown away here since this will leverage the # default dummy data factory that we have patched in, whose seq # len is solely dependent on the value of the mm_processor_kwargs. 
- seq_data, _ = dummy_registry.dummy_data_for_profiling( + dummy_data = dummy_registry.dummy_data_for_profiling( ctx.model_config, seq_len=-1, mm_registry=mm_registry) - assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS ### Test overrides for the max token count per multimodal instance diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 38cd48629f903..69f04f0a69c0b 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model): tokenizer = AutoTokenizer.from_pretrained(model) test_cases = [ - ("", 2, "", [32000, 32000]), - ("", 2, "", [32000, 32000, 32000]), - ("", [3, 2], "", - [32000, 32000, 32000, 32000, 32000]), - ("Image:Image:!", [3, 2], - "Image:Image:!", - [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), - ("", [3, 2], "", [32000, 32000, 32000]), - ] - - for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + ( + "", + 2, + "", + [32000, 32000], + [{ "offset": 0, "length": 2 }], + ), + ( + "", + 2, + "", + [32000, 32000, 32000], + [{ "offset": 0, "length": 2 }]), + ( + "", + [3, 2], + "", + [32000, 32000, 32000, 32000, 32000], + [{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }], + ), + ( + "Image:Image:!", + [3, 2], + "Image:Image:!", + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }], + ), + ( + "", + [3, 2], + "", + [32000, 32000, 32000], + [{ "offset": 0, "length": 3 }], + ), + ] # yapf: disable + + for ( + prompt, + repeat_count, + expected_prompt, + expected_token_ids, + expected_ranges, + ) in test_cases: + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer=tokenizer, prompt=prompt, prompt_token_ids=tokenizer.encode(prompt, @@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model): ) assert new_prompt == expected_prompt assert new_token_ids == expected_token_ids + assert ranges == expected_ranges diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f683942a5854b..6cf0cfb09b8fa 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -81,12 +81,7 @@ def create_worker(cls: Callable[..., T], get_ip(), get_open_port()) worker = cls( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index e75884a7395e2..9e166ae64dbfb 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -19,14 +19,7 @@ def _create_model_runner(model: str, *args, engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = EncoderDecoderModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - 
load_config=engine_config.load_config, - lora_config=engine_config.lora_config, - prompt_adapter_config=engine_config.prompt_adapter_config, + vllm_config=engine_config, is_driver_worker=True, ) return model_runner diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 1e7f560fc68cc..b36e8bfe73ff3 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -73,6 +73,7 @@ def test_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_index_maps=None, ) model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), @@ -124,6 +125,7 @@ def test_embedding_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_index_maps=None, ) model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), @@ -174,6 +176,7 @@ def test_multi_step_model_runner_input(): num_prefill_tokens=2, num_decode_tokens=3, slot_mapping=torch.zeros(1), + multi_modal_placeholder_index_maps=None, ) frozen_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index fe97199bac62d..433a9b30ba57a 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -16,15 +16,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: engine_args = EngineArgs(model, *args, **kwargs) engine_config = engine_args.create_engine_config() model_runner = ModelRunner( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, - lora_config=engine_config.lora_config, - prompt_adapter_config=engine_config.prompt_adapter_config, - observability_config=engine_config.observability_config, + vllm_config=engine_config, is_driver_worker=True, ) return model_runner diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py index acd2ed6836365..194ea2aa506f4 100644 --- a/tests/worker/test_profile.py +++ b/tests/worker/test_profile.py @@ -24,12 +24,7 @@ def test_gpu_memory_profiling(): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = Worker( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 7aa439ba0a154..acede959f59f8 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -19,12 +19,7 @@ def test_swap() -> None: distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = Worker( - model_config=engine_config.model_config, - parallel_config=engine_config.parallel_config, - scheduler_config=engine_config.scheduler_config, - device_config=engine_config.device_config, - cache_config=engine_config.cache_config, - load_config=engine_config.load_config, + vllm_config=engine_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git 
a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 9ea89eca01f5b..a504cb1f7e318 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -7,6 +7,8 @@ import torch +from vllm.multimodal import MultiModalPlaceholderMap + if TYPE_CHECKING: from vllm.worker.model_runner_base import (ModelRunnerBase, ModelRunnerInputBase, @@ -108,6 +110,15 @@ class AttentionMetadata: # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor + # The index maps that relate multi-modal embeddings to the corresponding + # placeholders. + # + # N.B. These aren't really related to attention and don't belong on this + # type -- this is just a temporary solution to make them available to + # `model_executable`. + multi_modal_placeholder_index_maps: Optional[Dict[ + str, MultiModalPlaceholderMap.IndexMap]] + @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index c216d195c9e7e..409a42187f46c 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -215,6 +215,8 @@ def prefill_metadata( num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -243,6 +245,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c294fcf7f08fe..26da0d89def29 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,5 +1,7 @@ """Attention layer with FlashAttention.""" +from collections import defaultdict from dataclasses import dataclass +from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -9,11 +11,13 @@ AttentionMetadata, AttentionMetadataBuilder, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, - compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, + compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, + get_seq_len_block_table_args, is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, is_block_tables_empty) from vllm.forward_context import get_forward_context +from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import (async_tensor_h2d, direct_register_custom_op, make_tensor_with_pad) @@ -71,7 +75,6 @@ def swap_blocks( src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @@ -83,6 +86,7 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] 
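The new multi_modal_placeholder_index_maps field above relates rows of the multi-modal embedding tensor to the placeholder token positions they should fill. A rough illustration with a hypothetical IndexMap stand-in; the real MultiModalPlaceholderMap.IndexMap layout is not spelled out in this diff:

    from typing import Dict, List, NamedTuple


    class IndexMap(NamedTuple):
        """Hypothetical stand-in pairing embedding rows with token positions."""
        src: List[int]   # rows in the stacked multi-modal embedding tensor
        dest: List[int]  # flattened positions of placeholder tokens in the batch


    # Suppose one image contributes 3 embedding rows and its placeholder tokens
    # occupy positions 2..4 of the flattened prompt (cf. the "offset"/"length"
    # ranges exercised in tests/multimodal/test_utils.py above).
    index_maps: Dict[str, IndexMap] = {
        "image": IndexMap(src=[0, 1, 2], dest=[2, 3, 4]),
    }

    assert len(index_maps["image"].src) == len(index_maps["image"].dest)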
value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) @@ -109,26 +113,12 @@ class FlashAttentionMetadata(AttentionMetadata): # |-------------------- seq_len ---------------------| # |-- query_len ---| - # Maximum query length in the batch. - max_query_len: Optional[int] - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] @@ -144,11 +134,62 @@ class FlashAttentionMetadata(AttentionMetadata): # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. 
+ ''' + return is_all_cross_attn_metadata_set(self) + @property def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self.num_prefills == 0: @@ -157,30 +198,52 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: if self._cached_prefill_metadata is not None: return self._cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - assert self.seq_start_loc is not None + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) self._cached_prefill_metadata = FlashAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_query_len=0, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=False, - ) + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -190,16 +253,25 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: if self._cached_decode_metadata is not None: return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) self._cached_decode_metadata = FlashAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, max_query_len=self.max_query_len, max_prefill_seq_len=0, @@ -209,9 +281,15 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: seq_start_loc=self.seq_start_loc[self.num_prefills:] if self.seq_start_loc is not None else None, context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], + block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, - ) + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata def advance_step(self, @@ -297,6 +375,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -327,6 +408,12 @@ def _add_seq_group( self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -417,6 +504,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) num_seqs = len(seq_lens) if use_captured_graph: @@ -439,24 +528,18 @@ def build(self, seq_lens: List[int], query_lens: List[int], device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) - query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, - self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } return FlashAttentionMetadata( num_prefills=self.num_prefills, @@ -464,13 +547,14 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, @@ -566,16 +650,20 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - 
raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + output = torch.ops.vllm.unified_flash_attention( query, key, @@ -588,6 +676,7 @@ def forward( k_scale, v_scale, self.scale, + attn_type.value, self.sliding_window, self.alibi_slopes, self.logits_soft_cap, @@ -596,6 +685,89 @@ def forward( return output +def _get_query_key_seq_metadata( + attn_metadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + """ + Returns sequence metadata for key and query based on the specified + attention type and whether input is a prompt. + + This function computes the starting locations and maximum sequence lengths + for key and query sequences for different attention types. + + Args: + attn_metadata: The attention metadata object + is_prompt (bool): A flag indicating if the input is a prompt + attn_type (AttentionType): The type of attention being used. + + Returns: + tuple: A tuple containing four integers: + - Starting location for the query sequence. + - Maximum sequence length for the query sequence. + - Starting location for the key sequence. + - Maximum sequence length for the key sequence. + + Raises: + AttributeError: If an invalid attention type is provided. + """ + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.seq_start_loc, max_seq_len) + + elif attn_type == AttentionType.ENCODER_DECODER: + # This is cross attention between the where the key + # is the precomputed encoder attention and query + # is the input sequence. + # Choose query max length based on whether it is prompt + # or not. + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_start_loc, max_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER: + # For encoder attention both the query and the key are same i.e the + # encoder sequence. + return (attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, + attn_metadata.max_encoder_seq_len) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, + attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_causal_option(attn_type: AttentionType) -> bool: + """ + Determine whether the given attention type is suitable for causal + attention mechanisms. 
+ + Args: + attn_type (AttentionType): The type of attention being evaluated + + Returns: + bool: Returns `True` if the attention type is suitable for causal + attention (i.e., not encoder, encoder-only, or encoder-decoder), + otherwise returns `False`. + """ + return not (attn_type == AttentionType.ENCODER + or attn_type == AttentionType.ENCODER_ONLY + or attn_type == AttentionType.ENCODER_DECODER) + + def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -608,60 +780,76 @@ def unified_flash_attention( k_scale: float, v_scale: float, softmax_scale: float, + attn_type_int_val: int, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: + # Convert integer attn_type to enum + try: + attn_type = AttentionType(attn_type_int_val) + except ValueError as err: + raise AttributeError( + f"Invalid attention type {str(attn_type_int_val)}") from err + current_metadata = get_forward_context() assert current_metadata is not None assert isinstance(current_metadata, FlashAttentionMetadata) attn_metadata: FlashAttentionMetadata = current_metadata num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) + if (key is not None) and (value is not None): + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the KV + # cache is already populated with the cross-attention tensor. + # Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + k_scale, + v_scale, + ) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. 
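A condensed sketch of the cache-update rule described in the comments above, using hypothetical helper and type names (illustrative only, not part of the patch): the KV cache is written only when keys and values are present and the op is not pure encoder attention, and cross-attention writes through the dedicated cross slot mapping.

    import enum
    from typing import Optional


    class AttnType(enum.Enum):  # stand-in for vllm's AttentionType
        DECODER = enum.auto()
        ENCODER = enum.auto()
        ENCODER_DECODER = enum.auto()


    def select_slot_mapping(attn_type: AttnType, key, value,
                            slot_mapping, cross_slot_mapping) -> Optional[object]:
        """Return the slot mapping to write the KV cache with, or None to skip."""
        if attn_type is AttnType.ENCODER or key is None or value is None:
            # (a) pure encoder attention never touches the KV cache;
            # (b) decode-phase cross-attention reuses the already-populated cache.
            return None
        if attn_type is AttnType.ENCODER_DECODER:
            return cross_slot_mapping  # cross-attention cache, prefill only
        return slot_mapping            # decoder self-attention, prefill/decode


    assert select_slot_mapping(AttnType.ENCODER_DECODER, None, None, "s", "xs") is None
    assert select_slot_mapping(AttnType.ENCODER_DECODER, "k", "v", "s", "xs") == "xs"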
- torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + query = query[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens prefill_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None @@ -669,22 +857,30 @@ def unified_flash_attention( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + prefill_output = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, softmax_scale=softmax_scale, - causal=True, + causal=_get_causal_option(attn_type), window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, ) else: # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) prefill_output = flash_attn_varlen_func( # noqa @@ -709,6 +905,8 @@ def unified_flash_attention( # because different queries might have different lengths. assert decode_meta.max_decode_query_len is not None if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1") decode_output = flash_attn_varlen_func( q=decode_query, k=key_cache, @@ -726,12 +924,17 @@ def unified_flash_attention( ) else: # Use flash_attn_with_kvcache for normal decoding. 
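The builder changes in this file replace on-device torch.cumsum with host-side itertools.accumulate before the host-to-device copy. A worked example of the cumulative start-location layout that flash_attn_varlen_func consumes, matching the [4, 6] -> [0, 4, 10] example in the field comments (the tensor conversion below uses plain torch rather than vllm's async_tensor_h2d helper):

    from itertools import accumulate

    import torch

    seq_lens = [4, 6]
    query_lens = [4, 1]  # e.g. one prefill of 4 tokens and one decode of 1 token

    seq_start_loc = list(accumulate(seq_lens, initial=0))
    query_start_loc = list(accumulate(query_lens, initial=0))

    assert seq_start_loc == [0, 4, 10]
    assert query_start_loc == [0, 4, 5]

    # The lists are then moved to the device as int32 tensors; in the diff this
    # happens through async_tensor_h2d with pinned memory.
    seq_start_loc_tensor = torch.tensor(seq_start_loc, dtype=torch.int32)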
+ ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) decode_output = flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, softmax_scale=softmax_scale, causal=True, window_size=window_size, @@ -741,10 +944,10 @@ def unified_flash_attention( if prefill_output is None: assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) + return decode_output.view(num_decode_query_tokens, hidden_size) if decode_output is None: assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) + return prefill_output.view(num_prefill_query_tokens, hidden_size) # Chunked prefill does not work with speculative decoding. # Therefore, the query length for decode should be 1 in chunked prefill. @@ -766,6 +969,7 @@ def unified_flash_attention_fake( k_scale: float, v_scale: float, softmax_scale: float, + attn_type_int_val: int, window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 234c87d5c4edb..107e3bbf79666 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,7 +1,10 @@ +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type +from vllm.multimodal import MultiModalPlaceholderMap + try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper @@ -215,6 +218,7 @@ def graph_capture_get_metadata_for_batch( attn_metadata = self.runner.attn_backend.make_metadata( num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, @@ -470,6 +474,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -519,6 +526,11 @@ def _add_seq_group( inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -651,6 +663,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, @@ -694,6 +711,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], decode_query_len=decode_query_len, num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, + 
multi_modal_placeholder_index_maps=placeholder_index_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, @@ -759,8 +777,6 @@ def forward( v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: - assert k_scale == 1.0 and v_scale == 1.0, ( - "key/v_scale is not supported in FlashInfer.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -874,7 +890,12 @@ def unified_flash_infer( assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None prefill_output = prefill_meta.prefill_wrapper.forward( - query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True) + query, + kv_cache, + logits_soft_cap=logits_soft_cap, + causal=True, + k_scale=k_scale, + v_scale=v_scale) if decode_meta := attn_metadata.decode_metadata: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 4116fbf00020c..888adbffb8578 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,5 +1,6 @@ +from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -7,6 +8,7 @@ AttentionMetadata, AttentionMetadataBuilder) from vllm.attention.backends.utils import CommonAttentionState +from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -135,6 +137,8 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_decode_query_len=0, @@ -167,6 +171,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, @@ -189,6 +194,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.prefill_seq_lens: List[int] = [] self.context_lens: List[int] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -213,6 +221,12 @@ def _add_seq_group( self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -280,6 +294,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, @@ -296,6 +315,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], return PlaceholderAttentionMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 30859dfa60634..b129d0d992f2f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -150,6 +150,8 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -178,6 +180,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 32fccd0dfb496..12800668af223 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,12 +1,16 @@ """Attention backend utils""" +from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union import numpy as np import torch from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) +from vllm.attention.backends.abstract import AttentionType +from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: @@ -123,6 +127,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) self.num_prefills = 0 self.num_prefill_tokens = 0 self.num_decode_tokens = 0 @@ -147,6 +154,12 @@ def _add_seq_group( inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + self.num_prefills += 1 self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) @@ -204,6 +217,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) @@ -232,28 +247,23 @@ def build(self, seq_lens: List[int], query_lens: List[int], device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) - query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, - self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) - query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, 
torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } return self._metadata_cls( # type: ignore num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -261,8 +271,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_query_len=max_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, @@ -305,6 +315,7 @@ def graph_capture_get_metadata_for_batch( num_prefill_tokens=0, num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], + multi_modal_placeholder_index_maps=None, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, @@ -318,11 +329,13 @@ def graph_capture_get_metadata_for_batch( use_cuda_graph=True, ) if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - assert self.runner.attn_backend.get_name() == "XFORMERS", \ - f"Expected attn_backend name to be 'XFORMERS', but "\ - f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -338,11 +351,13 @@ def get_graph_input_buffers( "block_tables": attn_metadata.decode_metadata.block_tables, } if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - assert self.runner.attn_backend.get_name() == "XFORMERS", \ - f"Expected attn_backend name to be 'XFORMERS', but "\ - f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. + assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._add_additonal_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers @@ -357,11 +372,13 @@ def prepare_graph_input_buffers( input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) if is_encoder_decoder_model: - # The encoder decoder model works only with XFormers backend. - # Assert the same. - assert self.runner.attn_backend.get_name() == "XFORMERS", \ - f"Expected attn_backend name to be 'XFORMERS', but "\ - f" got '{self.runner.attn_backend.get_name()}'" + # The encoder decoder model works only with XFormers and + # Flash Attention backend. Assert the same. 
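The metadata builders accumulate placeholder maps per modality in a defaultdict and snapshot them with index_map() inside build(). A minimal sketch of that aggregation pattern with a stand-in class; the real MultiModalPlaceholderMap API is richer than the two methods mimicked here:

    from collections import defaultdict
    from typing import Dict, List, Optional


    class PlaceholderMap:
        """Stand-in mimicking the extend()/index_map() calls used in the diff."""

        def __init__(self, positions: Optional[List[int]] = None) -> None:
            self._positions: List[int] = list(positions or [])

        def extend(self, other: "PlaceholderMap") -> None:
            self._positions.extend(other._positions)

        def index_map(self) -> List[int]:
            return list(self._positions)


    builder_maps: Dict[str, PlaceholderMap] = defaultdict(PlaceholderMap)

    # For each prefill sequence, merge that sequence's per-modality placeholders.
    builder_maps["image"].extend(PlaceholderMap([2, 3, 4]))
    builder_maps["image"].extend(PlaceholderMap([9, 10]))

    # At build() time, freeze one index map per modality for the attn metadata.
    placeholder_index_maps = {
        modality: pmap.index_map() for modality, pmap in builder_maps.items()
    }
    assert placeholder_index_maps == {"image": [2, 3, 4, 9, 10]}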
+ assert self.runner.attn_backend.get_name() in\ + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or "\ + f"'FLASH_ATTN', but "\ + f"got '{self.runner.attn_backend.get_name()}'" self._prepare_input_buffers_for_enc_dec_model( attn_metadata, input_buffers) @@ -393,6 +410,7 @@ def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, attn_metadata.encoder_seq_lens_tensor = torch.full( (batch_size, ), 1, dtype=torch.int).cuda() attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + attn_metadata.num_encoder_tokens = 0 def _add_additonal_input_buffers_for_enc_dec_model( self, attn_metadata, input_buffers: Dict[str, Any]): @@ -435,3 +453,122 @@ def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, input_buffers["cross_block_tables"].copy_( attn_metadata.decode_metadata.cross_block_tables, non_blocking=True) + + +def is_all_encoder_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((attn_metadata.encoder_seq_lens is not None) + and (attn_metadata.encoder_seq_lens_tensor is not None) + and (attn_metadata.max_encoder_seq_len is not None)) + + +def is_all_cross_attn_metadata_set(attn_metadata): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (attn_metadata.is_all_encoder_attn_metadata_set + and (attn_metadata.cross_slot_mapping is not None) + and (attn_metadata.cross_block_tables is not None)) + + +def get_seq_len_block_table_args( + attn_metadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def get_num_prefill_decode_query_kv_tokens( + attn_metadata, + attn_type: AttentionType, +) -> Tuple[int, int, int]: + """ + Calculate the number of prefill and decode tokens for query, key/value + based on the attention metadata and the specified attention type. + + Args: + attn_metadata (FlashAttentionMetadata): Attention Metadata object. + attn_type (AttentionType): The type of attention being used. + Returns: + Tuple[int, int, int]: A tuple containing three integers: + - The number of prefill query tokens. + - The number of prefill key/value tokens. + - The number of decode query tokens. + + Raises: + AssertionError: If the number of encoder tokens in `attn_metadata` + is `None` when required for the calculations. + """ + num_prefill_query_tokens = 0 + num_decode_query_tokens = 0 + num_prefill_kv_tokens = 0 + if attn_type == AttentionType.ENCODER: + # Encoder attention is only invoked during prefill phase. + # The same input servers a both query and key. + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_encoder_tokens + num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = 0 + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + # The key is the encoder/cross-attention. 
+ num_prefill_kv_tokens = attn_metadata.num_encoder_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + else: # attn_type == AttentionType.DECODER or + # attn_type == AttentionType.ENCODER_ONLY + num_prefill_query_tokens = attn_metadata.num_prefill_tokens + num_prefill_kv_tokens = attn_metadata.num_prefill_tokens + num_decode_query_tokens = attn_metadata.num_decode_tokens + + return (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5aaf13d8ea744..4725413baade7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,8 +11,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) +from vllm.attention.backends.utils import ( + CommonAttentionState, CommonMetadataBuilder, + get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, + is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -135,6 +137,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None @@ -162,9 +169,7 @@ def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + return is_all_encoder_attn_metadata_set(self) @property def is_all_cross_attn_metadata_set(self): @@ -173,9 +178,7 @@ def is_all_cross_attn_metadata_set(self): Superset of encoder attention required metadata. ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + return is_all_cross_attn_metadata_set(self) @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -212,6 +215,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. 
+ multi_modal_placeholder_index_maps, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -255,6 +260,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -326,64 +332,6 @@ def _set_attn_bias( raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_seq_len_block_table_args( - attn_metadata: XFormersMetadata, - is_prompt: bool, - attn_type: AttentionType, -) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. - - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_lens_tensor, max_seq_len, - attn_metadata.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.cross_block_tables) - elif attn_type == AttentionType.ENCODER: - # No block tables associated with encoder attention - return (attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, None) - elif attn_type == AttentionType.ENCODER_ONLY: - assert is_prompt, "Should not have decode for encoder only model." - - # No block tables associated with encoder attention - return (attn_metadata.seq_lens_tensor, - attn_metadata.max_prefill_seq_len, None) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): _metadata_cls = XFormersMetadata @@ -571,45 +519,21 @@ def forward( updated_slot_mapping, self.kv_cache_dtype, k_scale, v_scale) - - if attn_type == AttentionType.ENCODER: - # Encoder attention - chunked prefill is not applicable; - # derive token-count from query shape & and treat them - # as 100% prefill tokens - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_encoder_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 - elif attn_type == AttentionType.DECODER: - # Decoder self-attention supports chunked prefill. 
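A worked example of the query/KV token accounting that get_num_prefill_decode_query_kv_tokens centralizes for both the FlashAttention and XFormers backends (the counter values below are illustrative): for cross-attention the query count follows the decoder tokens while the key/value count follows the encoder tokens, which is why the two are tracked separately.

    import enum
    from dataclasses import dataclass


    class AttnType(enum.Enum):  # stand-in for vllm's AttentionType
        DECODER = enum.auto()
        ENCODER = enum.auto()
        ENCODER_DECODER = enum.auto()


    @dataclass
    class Meta:  # the handful of counters the helper reads from attn_metadata
        num_prefill_tokens: int
        num_decode_tokens: int
        num_encoder_tokens: int


    def query_kv_token_counts(meta: Meta, attn_type: AttnType):
        if attn_type is AttnType.ENCODER:
            # Encoder self-attention: the encoder sequence is both query and key.
            return meta.num_encoder_tokens, meta.num_encoder_tokens, 0
        if attn_type is AttnType.ENCODER_DECODER:
            # Cross-attention: decoder tokens query the cached encoder tokens.
            return (meta.num_prefill_tokens, meta.num_encoder_tokens,
                    meta.num_decode_tokens)
        # Decoder self-attention (and encoder-only models).
        return (meta.num_prefill_tokens, meta.num_prefill_tokens,
                meta.num_decode_tokens)


    meta = Meta(num_prefill_tokens=7, num_decode_tokens=3, num_encoder_tokens=12)
    assert query_kv_token_counts(meta, AttnType.ENCODER_DECODER) == (7, 12, 3)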
- num_prefill_tokens = attn_metadata.num_prefill_tokens - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - # Only enforce this shape-constraint for decoder - # self-attention - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - else: # attn_type == AttentionType.ENCODER_DECODER - # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) - num_prefill_tokens = attn_metadata.num_prefill_tokens - if attn_metadata.num_encoder_tokens is not None: - num_encoder_tokens = attn_metadata.num_encoder_tokens - else: - num_encoder_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] + decode_query = query[num_prefill_query_tokens:] # QKV for prefill. - query = query[:num_prefill_tokens] + query = query[:num_prefill_query_tokens] if key is not None and value is not None: - key = key[:num_encoder_tokens] - value = value[:num_encoder_tokens] + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. @@ -619,8 +543,8 @@ def forward( # prefix. out = self._run_memory_efficient_xformers_forward( query, key, value, prefill_meta, attn_type=attn_type) - assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out + assert out.shape == output[:num_prefill_query_tokens].shape + output[:num_prefill_query_tokens] = out else: assert attn_type != AttentionType.ENCODER_ONLY, ( "Encoder-only models should not have prefix attention.") @@ -649,8 +573,8 @@ def forward( k_scale, v_scale, ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + assert output[:num_prefill_query_tokens].shape == out.shape + output[:num_prefill_query_tokens] = out if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( @@ -660,9 +584,9 @@ def forward( seq_lens_arg, max_seq_len_arg, block_tables_arg, - ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - output[num_prefill_tokens:] = PagedAttention.forward_decode( + output[num_prefill_query_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, value_cache, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 376b3136f0fb8..8a59cf41a689e 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -98,7 +98,6 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") from vllm.attention.backends.blocksparse_attn import ( @@ -108,6 +107,7 @@ def get_attn_backend( backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, is_attention_free) if backend == _Backend.FLASH_ATTN: + 
logger.info("Using Flash Attention backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 10cf49e19eccc..96ddcba467c5b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -243,6 +243,65 @@ def split_graph(graph: fx.GraphModule, return split_gm, outputs +# we share the global graph pool among all the backends +global_graph_pool = None + + +class PiecewiseCompileInterpreter(torch.fx.Interpreter): + """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. + It runs the given graph with fake inputs, and compile some + submodules specified by `compile_submod_names` with the given + compilation configs. + """ + + def __init__(self, module: torch.fx.GraphModule, + compile_submod_names: List[str], + compilation_configs: CompilationConfig, graph_pool): + super().__init__(module) + from torch._guards import detect_fake_mode + self.fake_mode = detect_fake_mode() + self.compile_submod_names = compile_submod_names + self.compilation_configs = compilation_configs + self.graph_pool = graph_pool + self.have_seen_first_graph = False + + def run(self, *args): + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + return super().run(*fake_args) + + def call_module(self, target: torch.fx.node.Target, + args: Tuple[torch.fx.node.Argument, + ...], kwargs: Dict[str, Any]) -> Any: + assert isinstance(target, str) + output = super().call_module(target, args, kwargs) + + if target in self.compile_submod_names: + submod = self.fetch_attr(target) + sym_shape_indices = [ + i for i, x in enumerate(args) if isinstance(x, torch.SymInt) + ] + compiled_graph_for_general_shape = wrap_inductor( + submod, + args, + self.compilation_configs.inductor_compile_config, + runtime_shape=None, + do_logging=not self.have_seen_first_graph, + use_inductor=self.compilation_configs.use_inductor) + + self.module.__dict__[target] = PiecewiseBackend( + submod, self.compilation_configs, self.graph_pool, + not self.have_seen_first_graph, sym_shape_indices, + compiled_graph_for_general_shape) + + self.have_seen_first_graph = True + compilation_counter.num_piecewise_capturable_graphs_seen += 1 + + return output + + class VllmBackend: """The compilation backend for `torch.compile` with VLLM. It is used for compilation level of `CompilationLevel.PIECEWISE`, @@ -263,8 +322,14 @@ class VllmBackend: returned_callable: Callable def __init__(self, ): - # every instance of VllmBackend has its own graph pool - self.graph_pool = torch.cuda.graph_pool_handle() + global global_graph_pool + if global_graph_pool is None: + global_graph_pool = torch.cuda.graph_pool_handle() + + # TODO: in the future, if we want to use multiple + # streams, it might not be safe to share a global pool. 
+ # only investigate this when we use multiple streams + self.graph_pool = global_graph_pool # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -286,55 +351,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.non_cudagraph_ops) - returned_callable: Callable # type: ignore + from torch._dynamo.utils import lazy_format_graph_code + logger.debug("%s", + lazy_format_graph_code("stiching module", self.split_gm)) - if len(self.piecewise_graphs) == 0: - compilation_counter.num_piecewise_graphs_seen += 1 - compilation_counter.num_piecewise_capturable_graphs_seen += 1 - returned_callable = PiecewiseBackend(graph, - self.compilation_configs, - self.graph_pool, - is_first_graph=True) - else: - from torch._dynamo.utils import lazy_format_graph_code - logger.debug( - "%s", lazy_format_graph_code("stiching module", self.split_gm)) - - is_first_graph = True - - for item in self.piecewise_graphs: - compilation_counter.num_piecewise_graphs_seen += 1 - compilation_counter.num_piecewise_capturable_graphs_seen += not item.is_splitting_graph # noqa - if not item.is_splitting_graph: - # cannot setattr to a module, so we need to set - # the attribute in the __dict__ - self.split_gm.__dict__[ - item.submod_name] = PiecewiseBackend( - item.graph, self.compilation_configs, - self.graph_pool, is_first_graph) - is_first_graph = False - returned_callable = self.split_gm - - self.returned_callable = returned_callable - # trigger the first compilation - # code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206 # noqa - # to turn the inputs into fake tensors - import torch._guards - from torch._guards import detect_fake_mode - fake_mode = detect_fake_mode(example_inputs) - fake_args = [] - for arg in example_inputs: - if isinstance(arg, torch.Tensor) and not isinstance( - arg, torch._subclasses.FakeTensor): - fake_args.append( - torch._dynamo.utils.to_fake_tensor(arg, fake_mode)) - else: - fake_args.append(arg) - self.returned_callable(*fake_args) + compilation_counter.num_piecewise_graphs_seen += len( + self.piecewise_graphs) + submod_names_to_compile = [ + item.submod_name for item in self.piecewise_graphs + if not item.is_splitting_graph + ] + + # propagate the split graph to the piecewise backend, + # compile submodules with symbolic shapes + PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, + self.compilation_configs, + self.graph_pool).run(*example_inputs) self._called = True - return self.returned_callable + return self.split_gm @dataclasses.dataclass @@ -352,11 +388,10 @@ class ConcreteSizeEntry: class PiecewiseBackend: - def __init__(self, - graph: fx.GraphModule, - compilation_configs: CompilationConfig, - graph_pool: Any, - is_first_graph: bool = False): + def __init__(self, graph: fx.GraphModule, + compilation_configs: CompilationConfig, graph_pool: Any, + is_first_graph: bool, sym_shape_indices: List[int], + compiled_graph_for_general_shape: Callable): """ The backend for piecewise compilation. It mainly handles the compilation and cudagraph capturing. 
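For orientation, a ``torch.compile`` backend such as ``VllmBackend`` above is simply a callable that receives the traced ``fx.GraphModule`` and example inputs and returns something callable; after this change it returns the stitched ``split_gm`` directly. A toy sketch of that contract, independent of vLLM:

.. code-block:: python

    import torch

    def toy_backend(gm: torch.fx.GraphModule, example_inputs):
        # Inspect the captured FX graph, then hand back a callable that runs it.
        print(f"captured a graph with {len(list(gm.graph.nodes))} nodes")
        return gm.forward  # run the traced graph unmodified

    @torch.compile(backend=toy_backend)
    def f(x):
        return torch.relu(x) + 1

    print(f(torch.randn(4)))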
@@ -381,12 +416,11 @@ def __init__(self, self.compilation_configs.capture_sizes ) if self.compilation_configs.use_cudagraph else set() - self.compile_finished = False self.first_run_finished = False - self.compiled_graph_for_general_shape: Callable = None # type: ignore + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - self.sym_shape_indices: List[int] = [] + self.sym_shape_indices = sym_shape_indices # the entries for different shapes that we need to either # compile or capture cudagraph @@ -399,27 +433,6 @@ def __init__(self, ) def __call__(self, *args) -> Any: - - if not self.compile_finished: - self.compile_finished = True - - # this is the first compilation, we will compile a graph with - # dynamic shape, as the caller will mark first dimension as dynamic - - self.sym_shape_indices = [ - i for i, x in enumerate(args) if isinstance(x, torch.SymInt) - ] - - self.compiled_graph_for_general_shape = wrap_inductor( - self.graph, - args, - self.compilation_configs.inductor_compile_config, - runtime_shape=None, - do_logging=self.is_first_graph, - use_inductor=self.compilation_configs.use_inductor) - - return self.graph(*args) - if not self.first_run_finished: self.first_run_finished = True return self.compiled_graph_for_general_shape(*args) diff --git a/vllm/config.py b/vllm/config.py index c2a8c956b374a..17e9b1c100498 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,6 @@ import enum import json -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) @@ -1941,9 +1941,9 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") -@dataclass(frozen=True) -class EngineConfig: - """Dataclass which contains all engine-related configuration. This +@dataclass +class VllmConfig: + """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. """ @@ -1953,11 +1953,11 @@ class EngineConfig: scheduler_config: SchedulerConfig device_config: DeviceConfig load_config: LoadConfig - lora_config: Optional[LoRAConfig] - speculative_config: Optional[SpeculativeConfig] - decoding_config: Optional[DecodingConfig] - observability_config: Optional[ObservabilityConfig] - prompt_adapter_config: Optional[PromptAdapterConfig] + lora_config: Optional[LoRAConfig] = None + speculative_config: Optional[SpeculativeConfig] = None + decoding_config: Optional[DecodingConfig] = None + observability_config: Optional[ObservabilityConfig] = None + prompt_adapter_config: Optional[PromptAdapterConfig] = None def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1975,9 +1975,3 @@ def __post_init__(self): if self.prompt_adapter_config: self.prompt_adapter_config.verify_with_model_config( self.model_config) - - def to_dict(self): - """Return the configs as a dictionary, for use in **kwargs. 
- """ - return dict( - (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 88733b8f53b86..e56d5cddce424 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -828,8 +828,7 @@ def _schedule_priority_preemption( num_running_seqs) #Preempt out the victim sequence group - self._preempt(vseq_group, blocks_to_swap_out, - PreemptionMode.RECOMPUTE) + self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 #Put the sequence back into the waiting queue @@ -1309,6 +1308,8 @@ def schedule( # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + multi_modal_placeholders=seq_group.multi_modal_placeholders + if scheduler_outputs.num_prefill_groups > 0 else None, mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) @@ -1451,12 +1452,8 @@ def _append_slots(self, if len(cows) > 0: blocks_to_copy.extend(cows) - def _preempt( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: List[Tuple[int, int]], - preemption_mode: Optional[PreemptionMode] = None, - ) -> PreemptionMode: + def _preempt(self, seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than # swapping. However, when the sequence group has multiple sequences diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b1f0f8b9df925..da06ab186821e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,10 +9,11 @@ import vllm.envs as envs from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, - DeviceConfig, EngineConfig, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig, TaskOption, TokenizerPoolConfig) + DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TaskOption, TokenizerPoolConfig, + VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -955,7 +956,7 @@ def create_load_config(self) -> LoadConfig: ignore_patterns=self.ignore_patterns, ) - def create_engine_config(self) -> EngineConfig: + def create_engine_config(self) -> VllmConfig: # gguf file needs a specific model loader and doesn't use hf_repo if check_gguf_file(self.model): self.quantization = self.load_format = "gguf" @@ -1167,7 +1168,7 @@ def create_engine_config(self) -> EngineConfig: or "all" in detailed_trace_modules, ) - return EngineConfig( + return VllmConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5198467a6ac40..b0fdc67776bbd 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,8 +7,8 @@ from weakref import ReferenceType import vllm.envs as envs -from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VllmConfig) from 
vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -604,7 +604,7 @@ def __del__(self): @classmethod def _get_executor_cls( - cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: + cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) if isinstance(distributed_executor_backend, type): @@ -663,7 +663,7 @@ def _get_executor_cls( def from_engine_args( cls, engine_args: AsyncEngineArgs, - engine_config: Optional[EngineConfig] = None, + engine_config: Optional[VllmConfig] = None, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -680,7 +680,7 @@ def from_engine_args( # Create the async LLM engine. engine = cls( - **engine_config.to_dict(), + vllm_config=engine_config, executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index edef1f30a9e91..b12d29c4a8503 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -13,11 +13,9 @@ from typing_extensions import TypeIs, TypeVar import vllm.envs as envs -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -222,17 +220,7 @@ def validate_outputs( def __init__( self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - decoding_config: Optional[DecodingConfig], - observability_config: Optional[ObservabilityConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], + vllm_config: VllmConfig, executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -240,6 +228,22 @@ def __init__( input_registry: InputRegistry = INPUT_REGISTRY, use_cached_outputs: bool = False, ) -> None: + + # TODO: remove the local variables and use self.* throughout the class. 
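The constructor change above (together with the ``VllmConfig`` defaults earlier in this diff) amounts to passing one dataclass with optional sub-configs instead of a long keyword-argument list. A minimal sketch of the pattern, with purely illustrative names:

.. code-block:: python

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _SubConfig:            # stand-in for e.g. LoRAConfig
        rank: int = 8

    @dataclass
    class _TopLevelConfig:       # stand-in for VllmConfig
        required_value: int
        sub_config: Optional[_SubConfig] = None  # optional piece, defaults to None

        def __post_init__(self):
            # Cross-field validation runs once, right after construction.
            if self.sub_config is not None:
                assert self.sub_config.rank > 0

    cfg = _TopLevelConfig(required_value=1)             # optional pieces omitted
    cfg_lora = _TopLevelConfig(1, _SubConfig(rank=16))  # or supplied explicitly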
+ model_config = self.model_config = vllm_config.model_config + cache_config = self.cache_config = vllm_config.cache_config + lora_config = self.lora_config = vllm_config.lora_config + parallel_config = self.parallel_config = vllm_config.parallel_config + scheduler_config = self.scheduler_config = vllm_config.scheduler_config + device_config = self.device_config = vllm_config.device_config + speculative_config = self.speculative_config = vllm_config.speculative_config # noqa + load_config = self.load_config = vllm_config.load_config + decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + ) + prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + logger.info( "Initializing an LLM engine (v%s) with config: " "model=%r, speculative_config=%r, tokenizer=%r, " @@ -340,18 +344,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.input_processor = input_registry.create_input_processor( model_config) - self.model_executor = executor_class( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - speculative_config=speculative_config, - load_config=load_config, - prompt_adapter_config=prompt_adapter_config, - observability_config=self.observability_config, - ) + self.model_executor = executor_class(vllm_config=vllm_config, ) if self.model_config.task != "embedding": self._initialize_kv_caches() @@ -508,7 +501,7 @@ def _initialize_kv_caches(self) -> None: @classmethod def _get_executor_cls(cls, - engine_config: EngineConfig) -> Type[ExecutorBase]: + engine_config: VllmConfig) -> Type[ExecutorBase]: distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. @@ -582,7 +575,7 @@ def from_engine_args( executor_class = cls._get_executor_cls(engine_config) # Create the LLM engine. 
engine = cls( - **engine_config.to_dict(), + vllm_config=engine_config, executor_class=executor_class, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 6e6630b3ff55f..7f1ca621d91c4 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -13,7 +13,7 @@ from zmq.asyncio import Socket from vllm import PoolingParams -from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block @@ -78,7 +78,7 @@ class MQLLMEngineClient(EngineClient): every N seconds, confirming the engine is healthy """ - def __init__(self, ipc_path: str, engine_config: EngineConfig, + def __init__(self, ipc_path: str, engine_config: VllmConfig, engine_pid: int): self.context = zmq.asyncio.Context() self._errored_with: Optional[BaseException] = None diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 0a7f430eca488..a73b4c825b11c 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -5,10 +5,9 @@ import cloudpickle import zmq +from ray.exceptions import RayTaskError from vllm import AsyncEngineArgs, SamplingParams -from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) # yapf conflicts with isort for this block # yapf: disable from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, @@ -30,9 +29,6 @@ else: from vllm.engine.llm_engine import LLMEngine -CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, - SchedulerConfig, LoRAConfig] - logger = init_logger(__name__) POLLING_TIMEOUT_MS = 10000 @@ -130,7 +126,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, return cls(ipc_path=ipc_path, use_async_sockets=use_async_sockets, - **engine_config.to_dict(), + vllm_config=engine_config, executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, @@ -310,6 +306,11 @@ def _health_check(self): def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" if outputs: + # RayTaskError might not pickelable here. We need to unpack the + # underlying exception as the real exception in the output. 
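The RayTaskError note above (continued just below) concerns an exception wrapper that may not survive ``pickle.dumps``; the fix is to forward its underlying cause instead. As a generic illustration only (the real code uses Ray's ``RayTaskError.cause``), the same idea expressed with the standard ``__cause__`` attribute:

.. code-block:: python

    import pickle

    def ensure_picklable(exc: BaseException) -> BaseException:
        # If the exception wrapper cannot cross a pickle boundary,
        # fall back to the exception that caused it.
        try:
            pickle.dumps(exc)
            return exc
        except Exception:
            return exc.__cause__ or exc

    # e.g. ensure_picklable(wrapped_error) before sending it over the IPC socket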
+ if (isinstance(outputs, RPCError) + and isinstance(outputs.exception, RayTaskError)): + outputs.exception = outputs.exception.cause output_bytes = pickle.dumps(outputs) self.output_socket.send_multipart((output_bytes, ), copy=False) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index bc2de2d162473..c9552977710d1 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -187,7 +187,8 @@ def _placeholder_str(self, modality: ModalityStr, if model_type.startswith("llava"): return self._cached_token_str(self._tokenizer, hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat", "NVLM_D"): + if model_type in ("chameleon", "internvl_chat", "NVLM_D", + "h2ovl_chat"): return "" if model_type == "mllama": return "<|image|>" diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index e32993e0e452e..ab3ebb4e43d18 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -138,18 +138,11 @@ def _create_worker( assert self.distributed_init_method is not None kwargs = dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=self.distributed_init_method, - lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=rank == 0, ) wrapper.init_worker(**kwargs) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c96cb0f2c2981..9cba189dd57f9 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,10 +1,7 @@ from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -23,27 +20,19 @@ class ExecutorBase(ABC): def __init__( self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - observability_config: Optional[ObservabilityConfig], + vllm_config: VllmConfig, ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = 
vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config self._init_executor() @abstractmethod diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index ed30d3186a453..c65d0836e5ff7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -49,21 +49,12 @@ def _get_worker_kwargs( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - speculative_config=self.speculative_config, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0), - observability_config=self.observability_config, ) def _get_worker_module_and_class( diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index f2fcfa58b26e1..02d37cd7fbf23 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -29,11 +29,7 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = NeuronWorker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, + vllm_config=self.vllm_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method) diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index d0c0333854dae..d06b0ccb7906e 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -48,16 +48,10 @@ def _init_worker(self): get_ip(), get_open_port()) self.driver_worker = OpenVINOWorker( ov_core=self.ov_core, - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 972649dedf33e..e37e8973790db 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -44,12 +44,7 @@ def _get_worker_kwargs( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return dict( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 5f78993ddc4b4..36b7e2265efab 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -2,10 +2,7 @@ import torch -from vllm.config import 
(CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import ModelConfig, ParallelConfig from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger @@ -21,38 +18,13 @@ class XPUExecutor(GPUExecutor): uses_ray: bool = False - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - speculative_config: Optional[SpeculativeConfig], - observability_config: Optional[ObservabilityConfig], - ) -> None: - assert device_config.device_type == "xpu" - assert (not speculative_config - ), "Speculative decoding not yet supported for XPU backend" - - model_config = _verify_and_get_model_config(model_config) - - self.model_config = model_config - self.cache_config = cache_config - self.load_config = load_config - self.lora_config = lora_config - self.parallel_config = _verify_and_get_parallel_config(parallel_config) - self.scheduler_config = scheduler_config - self.device_config = device_config - self.prompt_adapter_config = prompt_adapter_config - self.speculative_config = None - self.observability_config = observability_config - - # Instantiate the worker and load the model to GPU. - self._init_executor() + def _init_executor(self) -> None: + assert self.device_config.device_type == "xpu" + assert self.speculative_config is None, ( + "Speculative decoding not yet supported for XPU backend") + + self.model_config = _verify_and_get_model_config(self.model_config) + GPUExecutor._init_executor(self) def _get_worker_module_and_class( self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 7b73922ddd2c5..ac7b3ca28b406 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -3,7 +3,7 @@ SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import InputContext, InputRegistry +from .registry import DummyData, InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() """ @@ -29,6 +29,7 @@ "to_enc_dec_tuple_list", "zip_enc_dec_prompts", "INPUT_REGISTRY", + "DummyData", "InputContext", "InputRegistry", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 9a094191eda38..ba393cbcce4eb 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -4,7 +4,7 @@ from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict + from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict class TextPrompt(TypedDict): @@ -136,6 +136,12 @@ class TokenInputs(TypedDict): if the model supports it. """ + multi_modal_placeholders: NotRequired[ + Optional["MultiModalPlaceholderDict"]] + """ + Placeholder ranges for the multi-modal data. 
+ """ + mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] """ Optional multi-modal processor kwargs to be forwarded to the @@ -149,6 +155,7 @@ def token_inputs( prompt_token_ids: List[int], prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" @@ -158,6 +165,8 @@ def token_inputs( inputs["prompt"] = prompt if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data + if multi_modal_placeholders is not None: + inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: inputs["mm_processor_kwargs"] = mm_processor_kwargs diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 4cebc91ce715c..fbf912a212568 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,8 +1,8 @@ import functools from collections import UserDict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, - Protocol, Tuple, Type) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, + Optional, Protocol, Type) from torch import nn from transformers import PretrainedConfig @@ -16,7 +16,8 @@ if TYPE_CHECKING: from vllm.config import ModelConfig - from vllm.multimodal import MultiModalDataDict, MultiModalRegistry + from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, + MultiModalRegistry) from vllm.sequence import SequenceData logger = init_logger(__name__) @@ -63,6 +64,14 @@ def get_hf_image_processor_config(self) -> Dict[str, Any]: N = TypeVar("N", bound=Type[nn.Module]) +class DummyData(NamedTuple): + """Dummy data used for profiling.""" + + seq_data: "SequenceData" + multi_modal_data: Optional["MultiModalDataDict"] = None + multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None + + class DummyDataFactory(Protocol): def __call__( @@ -71,7 +80,7 @@ def __call__( seq_len: int, mm_counts: Mapping[str, int], **mm_processor_kwargs: Any, - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyData: """ Create dummy data to be inputted into the model. @@ -123,7 +132,7 @@ def _default_dummy_data_factory( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyData: """ The default dummy data factory represents the longest possible text that can be inputted to the model. @@ -134,10 +143,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) - dummy_multi_modal_data = None - - return dummy_seq_data, dummy_multi_modal_data + return DummyData(SequenceData.from_prompt_token_counts((0, seq_len))) def register_dummy_data(self, factory: DummyDataFactory): """ @@ -195,7 +201,7 @@ def dummy_data_for_profiling( seq_len: int, mm_registry: "MultiModalRegistry", is_encoder_data: bool = False, - ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + ) -> DummyData: """ Create dummy data for profiling the memory usage of a model. 
@@ -220,12 +226,12 @@ def dummy_data_for_profiling( mm_processor_kwargs = get_allowed_kwarg_only_overrides( dummy_factory, overrides=model_config.mm_processor_kwargs) - seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) + dummy_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine - num_tokens = seq_data.prompt_token_ids + num_tokens = dummy_data.seq_data.prompt_token_ids if len(num_tokens) < seq_len: if is_encoder_data: print_warning_once( @@ -235,15 +241,15 @@ def dummy_data_for_profiling( raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " f"but found {len(num_tokens)} tokens instead.") - if mm_data is not None: - for k, v in mm_data.items(): + if dummy_data.multi_modal_data is not None: + for k, v in dummy_data.multi_modal_data.items(): num_items = len(v) if isinstance(v, list) else 1 num_expected = mm_counts[k] assert num_items >= num_expected, ( f"Expected at least {num_expected} dummy '{k}' instances " f"for profiling, but found {num_items} instances instead.") - return seq_data, mm_data + return dummy_data def _default_input_processor( self, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index dc5f47eb9b0fb..9694f2b8208e2 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -141,8 +141,11 @@ def create_weights( layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: Module) -> None: - max_w_scale, weight = requantize_with_max_scale( - layer.weight, layer.weight_scale, layer.logical_widths) + weight = layer.weight + max_w_scale = layer.weight_scale.max() + if not (layer.weight_scale == layer.weight_scale[0]).all(): + max_w_scale, weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py index 8cd938fc85fb2..bce91f1d7fd5e 100644 --- a/vllm/model_executor/layers/resampler.py +++ b/vllm/model_executor/layers/resampler.py @@ -41,6 +41,7 @@ from torch.nn.init import trunc_normal_ from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -154,15 +155,15 @@ class BaseResampler(nn.Module): A tensor with the shape of (grid_size**2, embed_dim) """ - def __init__( - self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - do_post_projection: bool = True, - ) -> None: + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.num_queries = num_queries @@ -172,7 +173,11 @@ def __init__( self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) trunc_normal_(self.query, std=0.02) if kv_dim is not None and kv_dim != 
embed_dim: - self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) + self.kv_proj = ReplicatedLinear(kv_dim, + embed_dim, + bias=False, + quant_config=quant_config, + prefix=prefix) else: # Maintain the same return value with ReplicatedLinear.forward self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa @@ -209,22 +214,24 @@ class Resampler2(BaseResampler): present in minicpmv2.0, but not qwen-vl. """ - def __init__( - self, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - adaptive: bool = False, - do_post_projection: bool = True, - ) -> None: + def __init__(self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, norm_layer, - do_post_projection=do_post_projection) + do_post_projection=do_post_projection, + quant_config=quant_config, + prefix=prefix) self.adaptive = adaptive pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index d1ec171c9ec2a..12468997e4653 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -1,27 +1,15 @@ -from typing import Optional - from torch import nn -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.model_executor.model_loader.loader import (BaseModelLoader, get_model_loader) from vllm.model_executor.model_loader.utils import ( get_architecture_class_name, get_model_architecture) -def get_model(*, model_config: ModelConfig, load_config: LoadConfig, - device_config: DeviceConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig) -> nn.Module: - loader = get_model_loader(load_config) - return loader.load_model(model_config=model_config, - device_config=device_config, - lora_config=lora_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - cache_config=cache_config) +def get_model(*, vllm_config: VllmConfig) -> nn.Module: + loader = get_model_loader(vllm_config.load_config) + return loader.load_model(vllm_config=vllm_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 79703bb7ded7a..07adf7c01eaaf 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -21,13 +21,14 @@ from transformers import AutoModelForCausalLM, PretrainedConfig from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, MultiModalConfig, - ParallelConfig, PoolerConfig, SchedulerConfig) +from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + PoolerConfig, SchedulerConfig, VllmConfig) from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger +from vllm.model_executor.layers.linear import 
ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( @@ -150,6 +151,7 @@ def _get_model_initialization_kwargs( def build_model(model_class: Type[nn.Module], + vllm_config: Optional[VllmConfig], hf_config: PretrainedConfig, cache_config: Optional[CacheConfig], quant_config: Optional[QuantizationConfig], @@ -166,23 +168,29 @@ def build_model(model_class: Type[nn.Module], if prefix: extra_kwargs["prefix"] = prefix + # TODO: unify all the module initialization code + # to only take the `VllmConfig` object as input + from vllm.plugins import set_vllm_config + set_vllm_config(vllm_config) + return model_class(config=hf_config, cache_config=cache_config, quant_config=quant_config, **extra_kwargs) -def _initialize_model( - model_config: ModelConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig, - scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: +def _initialize_model(vllm_config: VllmConfig) -> nn.Module: """Initialize a model with the given configurations.""" + model_config = vllm_config.model_config + lora_config = vllm_config.lora_config + scheduler_config = vllm_config.scheduler_config + cache_config = vllm_config.cache_config + load_config = vllm_config.load_config model_class, _ = get_model_architecture(model_config) return build_model( model_class, + vllm_config, model_config.hf_config, cache_config=cache_config, quant_config=_get_quantization_config(model_config, load_config), @@ -205,12 +213,7 @@ def download_model(self, model_config: ModelConfig) -> None: raise NotImplementedError @abstractmethod - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: """Load a model with the given configurations.""" raise NotImplementedError @@ -396,18 +399,14 @@ def download_model(self, model_config: ModelConfig) -> None: model_config.revision, fall_back_to_pt=True) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config, - scheduler_config) + model = _initialize_model(vllm_config=vllm_config) model.load_weights(self._get_all_weights(model_config, model)) @@ -436,17 +435,12 @@ def __init__(self, load_config: LoadConfig): def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with 
torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config, - scheduler_config) + model = _initialize_model(vllm_config=vllm_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) @@ -488,10 +482,7 @@ def _get_weights_iterator( def _load_model_serialized_cpu( self, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig, + vllm_config: VllmConfig, ) -> nn.Module: """Load a serialized model with tensorizer to the CPU. @@ -500,26 +491,30 @@ def _load_model_serialized_cpu( default HuggingFace loading, but will be slower than loading a vLLM-tensorized model. """ + device_config = vllm_config.device_config + model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config) + model = _initialize_model(vllm_config=vllm_config) model.load_weights(self._get_weights_iterator()) return model.eval() def _load_model_serialized( self, - model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - cache_config: CacheConfig, + vllm_config: VllmConfig, ) -> nn.Module: """Load a serialized model with tensorizer. Expects a vLLM-tensorized model. See the examples/tensorize_vllm_model.py example script for serializing vLLM models.""" + + device_config = vllm_config.device_config + model_config = vllm_config.model_config + lora_config = vllm_config.lora_config + cache_config = vllm_config.cache_config + with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] @@ -544,12 +539,9 @@ def download_model(self, model_config: ModelConfig) -> None: with self.tensorizer_config.open_stream(): pass - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) if parallel_config.tensor_parallel_size > 1: @@ -559,10 +551,8 @@ def load_model(self, *, model_config: ModelConfig, % get_tensor_model_parallel_rank() if is_vllm_tensorized(self.tensorizer_config): - return self._load_model_serialized(model_config, device_config, - lora_config, cache_config) - return self._load_model_serialized_cpu(model_config, device_config, - lora_config, cache_config) + return self._load_model_serialized(vllm_config=vllm_config) + return self._load_model_serialized_cpu(vllm_config=vllm_config) @staticmethod def save_model( @@ -648,12 +638,9 @@ def _prepare_weights(self, model_name_or_path: str, def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config 
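All of the loaders above construct the model inside the same pair of context managers, so parameters created during ``_initialize_model`` pick up the target dtype and device without threading them through every module constructor. A simplified stand-in for the ``set_default_torch_dtype`` helper, to illustrate the effect:

.. code-block:: python

    import torch
    from contextlib import contextmanager

    @contextmanager
    def _set_default_torch_dtype(dtype: torch.dtype):
        # Simplified stand-in for vllm's set_default_torch_dtype helper.
        old = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        try:
            yield
        finally:
            torch.set_default_dtype(old)

    with _set_default_torch_dtype(torch.float16):
        with torch.device("cpu"):
            layer = torch.nn.Linear(8, 8)  # created directly in fp16 on CPU

    assert layer.weight.dtype == torch.float16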
from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank @@ -663,8 +650,7 @@ def load_model(self, *, model_config: ModelConfig, with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config) + model = _initialize_model(vllm_config=vllm_config) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: @@ -786,6 +772,8 @@ def __init__(self, load_config: LoadConfig): with open(config_file_path, "r") as f: config = json.load(f) self.target_modules = config["target_modules"] + # Save the module names without sharding. + self.unsharded_weights_modules: List[str] = [] def _get_config_file(self, qlora_adapter: str) -> str: is_local = os.path.isdir(qlora_adapter) @@ -1005,16 +993,21 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, if any(target_module in weight_name for target_module in self.target_modules) and weight_name.endswith(".weight"): weight_name = weight_name.replace(".weight", ".qweight") - - if any(module in weight_name - for module in self.column_parallel_weights_modules): + # Without sharding + if any( + weight_name.startswith(module) + for module in self.unsharded_weights_modules): + weight_sub_tensor = weight_tensor + # Shard by column + elif any(module in weight_name + for module in self.column_parallel_weights_modules): total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) weight_sub_tensor = weight_tensor[..., start_index:end_index] - + # Shard by row else: total_size = weight_tensor.size(0) start_index = total_size // tp_size * tp_rank @@ -1068,7 +1061,15 @@ def _load_weights(self, model_config: ModelConfig, model.column_parallel_weights_modules else: self.column_parallel_weights_modules = [] - + # Some modules like `ReplicatedLinear` should not have their weights + # sharded. The reason for implementing it this way is to avoid new + # static variable in the model implementation. + # TODO: Can we reduce the static variables needed for BNB based on + # model information? + self.unsharded_weights_modules = [ + name for name, module in model.named_modules() + if isinstance(module, (ReplicatedLinear, )) + ] self.model_type = type(model).__name__ logger.info("Loading weights with BitsAndBytes quantization. " @@ -1115,7 +1116,13 @@ def _load_weights(self, model_config: ModelConfig, for shard_name, ( weight_name, index ) in model.bitsandbytes_stacked_params_mapping.items(): - if shard_name in quant_param_name: + + shard_pos = quant_param_name.find(shard_name) + # Some models, such as MiniCPM V2.5/2.6, contain both + # module names 'kv_proj' and 'qkv_proj'. 
To prevent 'kv_proj' + # from being incorrectly identified as being present in + # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight + if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".": shard_index = index quant_param_name = quant_param_name.replace( shard_name, weight_name) @@ -1157,16 +1164,12 @@ def _load_weights(self, model_config: ModelConfig, def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config) + model = _initialize_model(vllm_config=vllm_config) self._load_weights(model_config, model) @@ -1235,13 +1238,9 @@ def _get_weights_iterator( def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model) - def load_model(self, *, model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig) -> nn.Module: - + def load_model(self, vllm_config: VllmConfig) -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights @@ -1251,8 +1250,7 @@ def load_model(self, *, model_config: ModelConfig, with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config) + model = _initialize_model(vllm_config=vllm_config) model.load_weights( self._get_weights_iterator(local_model_path, gguf_weights_map)) return model diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index cbdacf779b089..0543ca978b7dd 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -624,8 +624,6 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Decoder output torch.Tensor """ # retrieve input_ids and inputs_embeds - - input_ids = input_ids.view(-1, input_ids.shape[-1]) inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions( diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 1f2d7384076ed..e612010677364 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -98,6 +98,11 @@ def input_processor_for_blip( if multi_modal_data is None or "image" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "image" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. 
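The check above (completed just below with ``return inputs``) keeps the processor idempotent when placeholders were already inserted; otherwise ``repeat_and_pad_placeholder_tokens`` expands the image token and now also reports the ranges it occupies. An illustrative sketch of that repeat-and-record idea (not the actual vLLM helper):

.. code-block:: python

    def _repeat_placeholder(token_ids, placeholder_id, feature_size):
        # Expand each placeholder token into `feature_size` copies and record
        # the range it occupies in the new token sequence.
        new_ids, ranges = [], []
        for tok in token_ids:
            if tok == placeholder_id:
                ranges.append({"offset": len(new_ids), "length": feature_size})
                new_ids.extend([placeholder_id] * feature_size)
            else:
                new_ids.append(tok)
        return new_ids, ranges

    ids, ranges = _repeat_placeholder([1, 99, 2], placeholder_id=99,
                                      feature_size=3)
    # ids == [1, 99, 99, 99, 2]; ranges == [{'offset': 1, 'length': 3}]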
+ return inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -105,7 +110,7 @@ def input_processor_for_blip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -116,7 +121,8 @@ def input_processor_for_blip( # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index c3b3cc8a4ddb6..db1f92649bd49 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -9,13 +9,14 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData from .blip import (BlipVisionModel, dummy_image_for_blip, @@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + "image": + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_data_for_blip2(ctx: InputContext, seq_len: int, @@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, vision_config = hf_config.vision_config num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_blip2( + seq_data, ranges = dummy_seq_data_for_blip2( hf_config, seq_len, num_images, @@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, if isinstance(vision_config, Blip2VisionConfig): mm_data = dummy_image_for_blip(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index aaf559ca386cc..9f6c6786c0fa4 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -11,8 +11,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import SiluAndMul from 
vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -30,6 +30,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import print_warning_once @@ -73,7 +74,11 @@ def dummy_seq_data_for_chameleon( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + "image": + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_chameleon( @@ -97,14 +102,14 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_chameleon( + seq_data, ranges = dummy_seq_data_for_chameleon( seq_len, num_images, image_token_id=CHAMELEON_IMAGE_TOKEN_ID, ) mm_data = dummy_image_for_chameleon(num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) def input_processor_for_chameleon(ctx: InputContext, @@ -120,9 +125,14 @@ def input_processor_for_chameleon(ctx: InputContext, if multi_modal_data is None or "image" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "image" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return inputs + model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ca90d10e9f9fb..c3c9ec703c1e6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -14,8 +14,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -31,8 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, - MultiModalInputs) +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs from vllm.multimodal.base import MultiModalData from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -117,16 +116,15 @@ def get_max_glmv_image_tokens(ctx: InputContext): raise NotImplementedError(msg) -def dummy_data_for_glmv( - ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] -) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: +def dummy_data_for_glmv(ctx: InputContext, 
seq_len: int, + mm_counts: Mapping[str, int]) -> DummyData: hf_config = ctx.get_hf_config(ChatGLMConfig) vision_config = getattr(hf_config, 'vision_config', None) if vision_config is None: token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len) seq_data = SequenceData(token_ids) - return seq_data, None + return DummyData(seq_data, None) elif isinstance(vision_config, dict): image_size = vision_config["image_size"] image_placeholder_length = calculate_image_placeholder(vision_config) @@ -141,7 +139,7 @@ def dummy_data_for_glmv( "image": Image.new("RGB", (image_size, image_size), color=0) } - return seq_data, mm_data + return DummyData(seq_data, mm_data) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index a3293020c042e..2d81b9266826b 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -49,14 +50,13 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: return get_clip_image_feature_size(hf_config) -def dummy_seq_data_for_clip( - hf_config: CLIPVisionConfig, - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): +def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, + mm_key: str = "image"): if image_feature_size_override is None: image_feature_size = get_clip_image_feature_size(hf_config) else: @@ -65,7 +65,11 @@ def dummy_seq_data_for_clip( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + mm_key: + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_clip( @@ -117,6 +121,11 @@ def input_processor_for_clip( if multi_modal_data is None or "image" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "image" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. 
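Several of the dummy-data helpers in this patch now return placeholder ranges built by `consecutive_placeholder_ranges(num_items=..., item_size=...)` alongside the sequence data. Conceptually, for N items whose placeholder tokens are laid out back to back, the ranges look like the sketch below; the dict-shaped range and the helper name used here are assumptions for illustration only, since the real type is not shown in these hunks:

    from typing import Dict, List

    def consecutive_ranges(num_items: int, item_size: int) -> List[Dict[str, int]]:
        # Item i's placeholder tokens occupy [i * item_size, (i + 1) * item_size).
        return [{"offset": i * item_size, "length": item_size}
                for i in range(num_items)]

    # Two images of 576 placeholder tokens each, packed back to back:
    assert consecutive_ranges(2, 576) == [
        {"offset": 0, "length": 576},
        {"offset": 576, "length": 576},
    ]
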
+ return inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -130,7 +139,7 @@ def input_processor_for_clip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -141,7 +150,8 @@ def input_processor_for_clip( # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 358d1dd288c49..0de590d1d8372 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -27,8 +27,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -37,9 +37,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) +from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings @@ -103,7 +105,11 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - image_feature_size * num_images) - return SequenceData(token_ids) + return SequenceData(token_ids), { + "image": + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_fuyu( @@ -119,15 +125,15 @@ def dummy_image_for_fuyu( def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) + seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) mm_data = dummy_image_for_fuyu(num_images, image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, - data: Image.Image): + data: List[Image.Image]): image_encoding = image_processor.preprocess(data, return_tensors="pt") batch_images = torch.stack([img[0] for img in image_encoding["images"] ]).unsqueeze(1) @@ -158,8 +164,10 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): model_config = ctx.model_config image_data = 
multi_modal_data["image"] new_multi_modal_data = {} + image_list = image_data if isinstance(image_data, list) else [image_data] + # process image data - if isinstance(image_data, Image.Image): + if is_list_of(image_list, Image.Image): # Fuyu's image_processor can also finish token padding image_processor: FuyuImageProcessor = cached_get_image_processor( model_config.model) @@ -171,7 +179,7 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): ]) new_multi_modal_data["image"] = image_patches - elif isinstance(image_data, torch.Tensor): + elif is_list_of(image_list, torch.Tensor): raise NotImplementedError("Embeddings input is not supported yet") else: raise TypeError(f"Invalid image type: {type(image_data)}") @@ -198,12 +206,13 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs): def input_mapper_for_fuyu(ctx: InputContext, data: object): model_config = ctx.model_config - if isinstance(data, Image.Image): + data_list = data if isinstance(data, list) else [data] + if is_list_of(data_list, Image.Image): # Fuyu's image_processor can also finish token padding image_processor: FuyuImageProcessor = cached_get_image_processor( model_config.model) - model_image_input = _fuyu_image_preprocess(image_processor, data) + model_image_input = _fuyu_image_preprocess(image_processor, data_list) data = torch.stack([ image_patch[0] for image_patch in model_image_input["image_patches"] diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py new file mode 100644 index 0000000000000..43242fe370ba2 --- /dev/null +++ b/vllm/model_executor/models/h2ovl.py @@ -0,0 +1,401 @@ +# adapted from https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/modeling_h2ovl_chat.py +# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/image_process.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- +from functools import partial +from typing import List, Optional, Tuple + +import torch +from PIL import Image +from transformers import PretrainedConfig + +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.utils import is_list_of + +from .intern_vit import InternVisionModel +from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel, + InternVLInputPipeline, build_transform, + find_closest_aspect_ratio, get_internvl_num_patches) + + +# modified to include blocks generated in second pass +def calculate_num_blocks( + orig_width: int, + orig_height: int, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, + prior_aspect_ratio=None, +) -> Tuple[int, int, int, Tuple[int, int]]: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for n in range(min_num, max_num + 1) + for i in range(1, n + 1) for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # if prior_aspect_ratio is provided, filter the target ratios + if prior_aspect_ratio is not None: + target_ratios = [ + ratio for ratio in target_ratios if 
prior_aspect_ratio[0] % + ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0 + ] + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, + target_ratios, orig_width, + orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # add thumbnail image if num_blocks > 1 + if use_thumbnail and blocks > 1: + blocks += 1 + return blocks, target_width, target_height, target_aspect_ratio + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +# refactored to handle prior_aspect_ratio as optional +def dynamic_preprocess( + image: Image.Image, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, + prior_aspect_ratio: Optional[Tuple[int, int]] = None, +) -> Tuple[List[Image.Image], Tuple[int, int]]: + orig_width, orig_height = image.size + + # calculate the number of blocks based on prior aspect ratio if available + blocks, target_width, target_height, target_aspect_ratio = ( + calculate_num_blocks( + orig_width, + orig_height, + min_num, + max_num, + image_size, + use_thumbnail=False, + prior_aspect_ratio=prior_aspect_ratio, + )) + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images, target_aspect_ratio + + +def load_image( + image: Image.Image, + input_size=448, + min_num=1, + max_num=6, + use_thumbnail=True, + prior_aspect_ratio: Optional[Tuple[int, int]] = None, +) -> Tuple[torch.Tensor, Tuple[int, int]]: + transform = build_transform(input_size=input_size) + images, target_aspect_ratio = dynamic_preprocess( + image, + image_size=input_size, + use_thumbnail=use_thumbnail, + min_num=min_num, + max_num=max_num, + prior_aspect_ratio=prior_aspect_ratio, + ) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values, target_aspect_ratio + + +# refactored to use the combined load_image function +def image_to_pixel_values( + image: Image.Image, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, + use_MSAC: bool, +) -> torch.Tensor: + # when MSAC is turned on, we need to process the image twice + if use_MSAC: + # first pass + pixel_values, target_aspect_ratio = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=True, + ) + # second pass + pixel_values2, _ = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + prior_aspect_ratio=target_aspect_ratio, + ) + # combine pixel values + pixel_values = torch.cat( + [pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0) + + else: + pixel_values, _ = load_image( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + use_thumbnail=use_thumbnail, + ) + + return pixel_values + + +def 
image_to_pixel_values_wrapper(hf_config: PretrainedConfig, + max_dynamic_patch: Optional[int] = None, + use_MSAC: Optional[bool] = None): + image_size = hf_config.vision_config.image_size + min_num = hf_config.min_dynamic_patch + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + if use_MSAC is None: + use_MSAC = hf_config.use_msac + use_thumbnail = hf_config.use_thumbnail + return partial( + image_to_pixel_values, + input_size=image_size, + min_num=min_num, + max_num=max_dynamic_patch, + use_thumbnail=use_thumbnail, + use_MSAC=use_MSAC, + ) + + +def get_max_internvl_image_tokens(ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None): + """ + Calculate the maximum number of tokens with/without MSAC and thumbnail + """ + hf_config = ctx.get_hf_config() + use_thumbnail = hf_config.use_thumbnail + use_MSAC = hf_config.use_msac + + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + + num_patches = get_internvl_num_patches(hf_config) + + coefficient = 2 if use_MSAC else 1 + num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0) + + return num_blocks * num_patches + + +class H2OVLInputPipeline(InternVLInputPipeline): + """ + Input pipeline for processing image and text data for the H2OVL model. + """ + + def input_processor( + self, + ctx: InputContext, + inputs: DecoderOnlyInputs, + *, + max_dynamic_patch: Optional[int] = None, + ) -> DecoderOnlyInputs: + # get multi_modal_data + multi_modal_data = inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config() + use_MSAC = hf_config.use_msac + + image_data = multi_modal_data["image"] + num_patches = get_internvl_num_patches(hf_config) + + image_pixel_values_mapper = image_to_pixel_values_wrapper( + hf_config, max_dynamic_patch=max_dynamic_patch) + + # single image + if isinstance(image_data, Image.Image): + pixel_values = image_pixel_values_mapper(image_data, + use_MSAC=use_MSAC) + num_blocks = pixel_values.shape[0] + image_feature_sizes = [num_blocks * num_patches] + pixel_values = pixel_values.unsqueeze(0) + + # multi images + elif is_list_of(image_data, Image.Image): + # Do not use MSAC for multi images + image_feature_sizes = [] + pixel_values = [ + image_pixel_values_mapper(image, use_MSAC=False) + for image in image_data + ] + for pixel_value in pixel_values: + num_blocks = pixel_value.shape[0] + image_feature_sizes.append(num_blocks * num_patches) + + # image embeddings as input + elif isinstance(image_data, torch.Tensor): + _, image_feature_size, _ = image_data.shape + image_feature_sizes = [image_feature_size] + pixel_values = None + + # multi-image image embeddings + elif is_list_of(image_data, torch.Tensor): + + image_feature_sizes = [] + for image_embed in image_data: + _, image_feature_size, _ = image_embed.shape + image_feature_sizes.append(image_feature_size) + pixel_values = None + + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + + prompt = inputs.get("prompt") + prompt_token_ids = inputs["prompt_token_ids"] + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) + + new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, + num_patches) + new_prompt_token_ids = tokenizer.encode(new_prompt) + + # Wrap image processing in input_processor to avoid duplication 
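The token-budget calculation in the new h2ovl.py doubles the dynamic tile count when MSAC is enabled and adds one tile for the thumbnail. A small sketch of that arithmetic, mirroring the `coefficient` logic above (the 256 tokens-per-tile figure is only an example value):

    def max_h2ovl_image_tokens(max_dynamic_patch: int, use_thumbnail: bool,
                               use_msac: bool, tokens_per_tile: int) -> int:
        # MSAC tiles the image twice, so the dynamic tiles count double;
        # the thumbnail tile is added only once.
        coefficient = 2 if use_msac else 1
        num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)
        return num_blocks * tokens_per_tile

    assert max_h2ovl_image_tokens(6, True, True, 256) == 13 * 256   # 3328
    assert max_h2ovl_image_tokens(6, True, False, 256) == 7 * 256   # 1792
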
+ image_token_id = tokenizer.encode( + self.img_context_token, + add_special_tokens=False, + return_tensors="pt", + )[0] + + # Update multi_modal_data to return + if pixel_values is not None: + multi_modal_data = { + "image": { + "pixel_values": pixel_values, + "image_token_id": image_token_id, + } + } + else: + multi_modal_data = {"image": {"image_embeds": image_data}} + + return token_inputs( + prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + + def input_mapper( + self, + ctx: InputContext, + data: object, + *, + max_dynamic_patch: Optional[int] = None, + ) -> MultiModalInputs: + + # NOTE: Preprocessing for the image data is done in the + # 'input_processor' function during actual inference. + if isinstance(data, dict): + return MultiModalInputs(data) + + # The section below is only used with dummy data during + # memory profiling. + hf_config = ctx.get_hf_config() + + image_pixel_values_mapper = image_to_pixel_values_wrapper( + hf_config, max_dynamic_patch) + + if isinstance(data, Image.Image): + pixel_values = image_pixel_values_mapper(data) + pixel_values = pixel_values.unsqueeze(0) + + elif is_list_of(data, Image.Image): + hf_config.use_msac = False + pixel_values = [image_pixel_values_mapper(img) for img in data] + + else: + return MultiModalInputs({"image_embeds": data}) + model_config = ctx.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + image_token_id = tokenizer.encode( + self.img_context_token, + add_special_tokens=False, + return_tensors="pt", + )[0] + + return MultiModalInputs({ + "pixel_values": pixel_values, + "image_token_id": image_token_id + }) + + +input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) +@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data) +@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) +class H2OVLChatModel(InternVLChatModel): + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + is_mono: bool, + prefix: str, + ): + if not is_mono: + vision_feature_layer = config.select_layer + if vision_feature_layer < 0: + num_hidden_layers = (config.vision_config.num_hidden_layers + + vision_feature_layer + 1) + else: + num_hidden_layers = vision_feature_layer + 1 + + return InternVisionModel( + config.vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers, + prefix=prefix, + ) + else: + msg = "Monolith mode is not applicable to H2OVL" + raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 1c1fde5b30983..d2ec0ff6e74c6 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -17,8 +17,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.quantization import (AWQConfig, QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -379,7 +379,7 @@ def dummy_data( model_config.tokenizer, 
trust_remote_code=model_config.trust_remote_code) - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( hf_config.vision_config, seq_len, num_images, @@ -398,7 +398,7 @@ def dummy_data( image_height_override=max_image_height, ) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 27055e7ced865..7fbd59ebd98fd 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -10,7 +10,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -111,7 +112,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, image_feature_size = get_max_llava_image_tokens(ctx) if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_images, @@ -120,9 +121,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_clip(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -131,9 +132,9 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, PixtralVisionConfig): - seq_data = dummy_seq_data_for_pixtral_hf( + seq_data, ranges = dummy_seq_data_for_pixtral_hf( vision_config, seq_len, num_images, @@ -142,7 +143,7 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index e8540d85ff565..7a2c95594ddcd 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext) from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -180,7 +181,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, max_feat_height, max_feat_width = pinpoint if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_images, @@ -195,9 +196,9 @@ def 
dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_height_override=max_feat_height, ) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -212,7 +213,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int, image_height_override=max_feat_height, ) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -605,7 +606,6 @@ def forward( :class:`LlavaNextImageInputs` """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: image_input = self._parse_and_validate_image_input(**kwargs) @@ -617,9 +617,14 @@ def forward( self.language_model.model.get_input_embeddings, lambda _: self._process_image_input(image_input), ) - input_ids = None else: - inputs_embeds = None + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index b8051d5fc6ae2..b755e2347f6ed 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -11,8 +11,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -108,33 +108,35 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, video_feature_size = frames_per_video * tokens_per_frame if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, + mm_key="video", ) pil_frame = dummy_image_for_clip(vision_config, num_images=1) np_frame = np.array(pil_frame["image"]) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data = {"video": mm_data_per_video} - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, + mm_key="video", ) pil_frame = dummy_image_for_siglip(vision_config, num_images=1) np_frame = np.array(pil_frame["image"]) mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) mm_data = {"video": mm_data_per_video} - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -145,6 +147,12 @@ def input_processor_for_llava_next_video(ctx: InputContext, 
multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "video" not in multi_modal_data: return inputs + + if "multi_modal_placeholders" in inputs and "video" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return inputs + video_data = multi_modal_data["video"] model_config = ctx.model_config @@ -160,7 +168,7 @@ def input_processor_for_llava_next_video(ctx: InputContext, tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -170,7 +178,8 @@ def input_processor_for_llava_next_video(ctx: InputContext, return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"video": ranges}) elif is_list_of(video_data, np.ndarray): raise NotImplementedError( diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index a0cf208a65f36..f410d64577a77 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -15,8 +15,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -218,31 +218,31 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) if isinstance(vision_config, CLIPVisionConfig): - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, - ) + mm_key="video") mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames, num_videos=num_videos) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) elif isinstance(vision_config, SiglipVisionConfig): - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_videos, image_token_id=hf_config.video_token_index, image_feature_size_override=video_feature_size, - ) + mm_key="video") mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames, num_videos=num_videos) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -320,7 +320,7 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -330,7 +330,8 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - 
multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"video": ranges}) elif is_list_of(video_data, np.ndarray): video_feature_size = [] diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 4917c33136069..c1f714bb25680 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -36,8 +36,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, @@ -131,16 +131,22 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): class Resampler2_5(BaseResampler): - def __init__( - self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - max_size: Tuple[int, int] = (70, 70), - ) -> None: - super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer) + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + quant_config=quant_config, + prefix=prefix) self.max_size = max_size self._set_2d_pos_cache(self.max_size) @@ -277,7 +283,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data) def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): @@ -404,7 +410,10 @@ def __init__( self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else self.vpm.embeddings.embed_dim) self.embed_dim = self.config.hidden_size - self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) + self.resampler = self.init_resampler(self.embed_dim, + self.vision_dim, + quant_config=quant_config, + prefix="resampler") self.resampler.to(device="cuda", dtype=param_dtype) # TODO: why is there _KEYS_TO_MODIFY_MAPPING? 
lm_head should be in llm self.lm_head = ParallelLMHead(config.vocab_size, @@ -564,8 +573,13 @@ def forward( vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None + output = self.llm( - input_ids=None, + input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, @@ -661,7 +675,11 @@ def init_vision_module( ) -> nn.Module: raise NotImplementedError - def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: raise NotImplementedError def get_vision_embedding( @@ -738,16 +756,21 @@ def init_vision_module( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_tokens(input_ids) - def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: with set_default_torch_dtype(torch.float16): - resampler = Resampler2( - embed_dim=embed_dim, - num_heads=embed_dim // 128, - grid_size=int(math.sqrt(self.config.query_num)), - kv_dim=vision_dim, - adaptive=False, - do_post_projection=True, - ) + resampler = Resampler2(embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int( + math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=False, + do_post_projection=True, + quant_config=quant_config, + prefix=prefix) return resampler @@ -820,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ".k_proj.", ".v_proj.", ".o_proj.", + # vision encoder + ".fc1.", + ".fc2.", + # Currently, vllm does not support BNB quantization for the `out_proj` + # of the resampler, so it's necessary to distinguish between the + # vision encoder and the resampler's out_proj. The same applies to + # MiniCPMV2_6. + ".self_attn.out_proj.", # vision encoder out_proj + # resampler + ".kv_proj.", ] # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] + column_parallel_weights_modules = [ + ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2." 
+ ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -872,14 +907,18 @@ def init_vision_module( model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: with set_default_torch_dtype(torch.float16): - resampler = Resampler2_5( - num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - ) + resampler = Resampler2_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) return resampler def get_vision_embedding( @@ -962,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ".k_proj.", ".v_proj.", ".o_proj.", + # vision encoder + ".fc1.", + ".fc2.", + ".self_attn.out_proj.", + # resampler + ".kv_proj.", ] # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] + column_parallel_weights_modules = [ + ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2." + ] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), @@ -1014,15 +1061,19 @@ def init_vision_module( model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + def init_resampler(self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: with set_default_torch_dtype(torch.float16): # The resampler in 2.6 remains consistent with the one in 2.5. 
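The `column_parallel_weights_modules` lists extended here for MiniCPM-V 2.5/2.6 tell the BitsAndBytes loader which weights to slice along the last dimension; other sharded weights are sliced along dim 0, as in the loader changes earlier in this patch. A minimal sketch of that per-rank slicing, with an illustrative helper name:

    import torch

    def shard_for_rank(weight: torch.Tensor, tp_rank: int, tp_size: int,
                       column_parallel: bool) -> torch.Tensor:
        if column_parallel:
            # Column-parallel modules (e.g. .down_proj., .o_proj.,
            # .self_attn.out_proj., .fc2. above) are sliced along the last dim.
            chunk = weight.size(-1) // tp_size
            return weight[..., tp_rank * chunk:(tp_rank + 1) * chunk]
        # Everything else sharded by the loader is sliced along dim 0.
        chunk = weight.size(0) // tp_size
        return weight[tp_rank * chunk:(tp_rank + 1) * chunk]

    w = torch.randn(8, 6)
    assert shard_for_rank(w, 1, 2, column_parallel=True).shape == (8, 3)
    assert shard_for_rank(w, 1, 2, column_parallel=False).shape == (4, 6)
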
- resampler = Resampler2_5( - num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - ) + resampler = Resampler2_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) return resampler def get_vision_embedding( diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 5cf5272cae878..a03155ac32a61 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -36,7 +36,7 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, EncoderDecoderInputs, InputContext) from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm @@ -176,13 +176,14 @@ def dummy_image(num_images: int, ): def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - return dummy_decoder_seq_data(seq_len, num_images), None + return DummyData(dummy_decoder_seq_data(seq_len, num_images)) def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) + return DummyData(dummy_encoder_seq_data(ctx, num_images), + dummy_image(num_images)) def _prepare_aspect_ratio_attention_mask( @@ -1055,9 +1056,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ".k_proj.", ".v_proj.", ".o_proj.", + ".fc1.", + ".fc2.", + # The `multi_modal_projector` is at the top level of the model, + # so we can't add a dot in front of it. + "multi_modal_projector." 
] # in TP, these weights are partitioned along the column dimension (dim=-1) - column_parallel_weights_modules = [".down_proj.", ".o_proj."] + column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."] bitsandbytes_stacked_params_mapping = { # shard_name, weight_name, index "q_proj": ("qkv_proj", 0), diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 3c34227767e05..ba798833e26a9 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -15,6 +15,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.selector import _Backend +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -713,6 +714,7 @@ def forward( return image_features +@support_torch_compile class MolmoModel(nn.Module): def __init__( @@ -1141,7 +1143,6 @@ def forward( **kwargs: object, ) -> SamplerOutput: if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: image_input = self._parse_and_validate_image_input(**kwargs) @@ -1156,10 +1157,13 @@ def forward( image_input["image_input_idx"], image_input["seq_len"], ) - - input_ids = None else: - inputs_embeds = None + inputs_embeds = self.model.embed_tokens(input_ids) + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8e29c6079b994..4b6061e113cb2 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -7,8 +7,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -58,7 +58,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, vision_config = hf_config.vision_config num_images = mm_counts["image"] - seq_data = dummy_seq_data_for_siglip( + seq_data, ranges = dummy_seq_data_for_siglip( vision_config, seq_len, num_images, @@ -66,7 +66,7 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, ) mm_data = dummy_image_for_siglip(vision_config, num_images) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) def input_processor_for_paligemma(ctx: InputContext, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4928e447d5b9e..5b477a8ed5f49 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -28,8 +28,8 @@ from vllm.attention import AttentionMetadata from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig, PoolerConfig) -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.pooler import Pooler, PoolingType from 
vllm.model_executor.layers.quantization import QuantizationConfig @@ -380,7 +380,7 @@ def dummy_data_for_phi3v(ctx: InputContext, image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) - seq_data = dummy_seq_data_for_clip( + seq_data, ranges = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, seq_len, num_images, @@ -394,7 +394,7 @@ def dummy_data_for_phi3v(ctx: InputContext, image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, ) - return seq_data, mm_data + return DummyData(seq_data, mm_data, ranges) @lru_cache diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 6b53bf5660096..051454c49bff8 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -17,8 +17,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig @@ -28,7 +28,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges) from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_list_of @@ -81,7 +82,12 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, ) mm_data = {"image": num_images * [image]} - return seq_data, mm_data + mm_placeholders = { + "image": + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } + return DummyData(seq_data, mm_data, mm_placeholders) def input_mapper_for_pixtral(ctx: InputContext, @@ -630,13 +636,13 @@ def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int: def dummy_seq_data_for_pixtral_hf( - hf_config: PixtralVisionConfig, - seq_len: int, - num_images: int, - *, - image_token_id: int, - image_feature_size_override: Optional[int] = None, -): + hf_config: PixtralVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, + mm_key: str = "image"): if image_feature_size_override is None: image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config) else: @@ -645,7 +651,11 @@ def dummy_seq_data_for_pixtral_hf( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + mm_key: + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_pixtral_hf( diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 61665768eacf5..b2b5c70182135 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -23,8 +23,8 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import (INPUT_REGISTRY, 
DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm @@ -810,7 +810,7 @@ def dummy_data_for_qwen( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], -) -> Tuple[SequenceData, Optional[Dict]]: +) -> DummyData: """Build dummy data for warming up Qwen models; this will only contain text matching the defaults for VLLM unless the model has a visual config. @@ -829,7 +829,7 @@ def dummy_data_for_qwen( if not hasattr(hf_config, "visual"): seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) mm_data = None - return seq_data, mm_data + return DummyData(seq_data, mm_data) # We have a visual component - use images to warm up num_images = mm_counts["image"] @@ -861,7 +861,7 @@ def dummy_data_for_qwen( # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) mm_data = {"image": image if num_images == 1 else [image] * num_images} - return seq_data, mm_data + return DummyData(seq_data, mm_data) class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 3d049eeb920b7..6114548bda42c 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -31,8 +31,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -44,6 +44,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal, SupportsPP @@ -85,7 +86,8 @@ def forward(self, audio_features): def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_audios = mm_counts["audio"] - max_llm_audio_tokens = get_max_qwen2_audio_audio_tokens(ctx) * num_audios + max_tokens_per_audio = get_max_qwen2_audio_audio_tokens(ctx) + max_llm_audio_tokens = max_tokens_per_audio * num_audios if seq_len - max_llm_audio_tokens - 2 < 0: raise RuntimeError( f"Qwen2-Audio cannot process {num_audios} audios in a prompt, " @@ -99,7 +101,12 @@ def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int, (0, seq_len - max_llm_audio_tokens), ) dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.) 
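The Qwen2-Audio dummy data sizes the silent waveform as `max_llm_audio_tokens * 2 * 2 * 160` samples. At 16 kHz with a hop length of 160 that works out to one audio token per 40 ms of audio. A sketch of the arithmetic; reading the `2 * 2` factor as two 2x reductions in the audio tower is an assumption, and the 750-token example is illustrative:

    import numpy as np

    SAMPLE_RATE = 16_000
    HOP_LENGTH = 160       # 16_000 / 160 = 100 mel frames per second
    DOWNSAMPLE = 2 * 2     # assumed: two 2x reductions inside the audio tower

    def dummy_audio_for_tokens(num_audio_tokens: int) -> np.ndarray:
        # One output audio token covers DOWNSAMPLE mel frames, i.e.
        # DOWNSAMPLE * HOP_LENGTH waveform samples (40 ms at 16 kHz).
        return np.zeros(num_audio_tokens * DOWNSAMPLE * HOP_LENGTH)

    audio = dummy_audio_for_tokens(750)
    assert audio.shape[0] == 750 * 640              # 480_000 samples
    assert audio.shape[0] / SAMPLE_RATE == 30.0     # 30 s of silence
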
- return dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios} + return DummyData( + dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, { + "audio": + consecutive_placeholder_ranges(num_items=num_audios, + item_size=max_tokens_per_audio) + }) def get_processor( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1e12c2332b65e..d801903f8f9fe 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -44,8 +44,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, - token_inputs) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -744,9 +744,10 @@ def dummy_data_for_qwen2_vl( dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), color=0) - return dummy_seqdata, { - "image": dummy_image if num_images == 1 else [dummy_image] * num_images - } + return DummyData(dummy_seqdata, { + "image": + dummy_image if num_images == 1 else [dummy_image] * num_images + }) def _get_llm_num_vision_tokens( diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f50ceaccb1bbe..3a929f5cb5195 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -128,6 +128,7 @@ def add_embedding_models(base_models, embedding_models): "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 @@ -482,4 +483,4 @@ def _run() -> None: if __name__ == "__main__": - _run() + _run() \ No newline at end of file diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 2e7ae32055aaf..acaf4afdecfe5 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -23,6 +23,7 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -61,6 +62,7 @@ def dummy_seq_data_for_siglip( *, image_token_id: int, image_feature_size_override: Optional[int] = None, + mm_key: str = "image", ): if image_feature_size_override is None: image_feature_size = get_siglip_image_feature_size(hf_config) @@ -70,7 +72,11 @@ def dummy_seq_data_for_siglip( return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), - ) + ), { + mm_key: + consecutive_placeholder_ranges(num_items=num_images, + item_size=image_feature_size) + } def dummy_image_for_siglip( @@ -122,6 +128,11 @@ def input_processor_for_siglip( if multi_modal_data is None or "image" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "image" 
in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return inputs + tokenizer = cached_get_tokenizer(model_config.tokenizer) if image_feature_size_override is None: @@ -135,7 +146,7 @@ def input_processor_for_siglip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -144,11 +155,10 @@ def input_processor_for_siglip( ) # NOTE: Create a defensive copy of the original inputs - return token_inputs( - prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - ) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": ranges}) # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index f08e4aa355086..749750fc9c16e 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,7 +2,6 @@ """PyTorch Ultravox model.""" import math -from array import array from functools import cached_property, lru_cache from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union, cast) @@ -17,27 +16,27 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import DecoderOnlyInputs, token_inputs -from vllm.inputs.registry import InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, + InputContext, token_inputs) from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs, NestedTensors +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, + NestedTensors) from vllm.multimodal.utils import (cached_get_tokenizer, + consecutive_placeholder_ranges, repeat_and_pad_placeholder_tokens) -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SequenceData) +from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, merge_multimodal_embeddings) + init_vllm_registered_model, + merge_multimodal_embeddings_from_map) _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -46,13 +45,13 @@ class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] data: NestedTensors - """Shape: `(batch_size, num_audios, 80, M)""" + """Shape: `(batch_size, num_audios, 80, M)`""" class UltravoxAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] data: NestedTensors - """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + """Shape: 
`(batch_size, num_audios, audio_feature_size, hidden_size)`""" UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, @@ -79,17 +78,16 @@ def dummy_seq_data_for_ultravox( seq_len: int, audio_count: int, ): - audio_placeholder = array( - VLLM_TOKEN_ID_ARRAY_TYPE, - [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx) + audio_length = min(get_ultravox_max_audio_tokens(ctx), + seq_len // audio_count) - # Add a separator between each chunk. - audio_token_ids = (audio_placeholder + - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count - other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - len(audio_token_ids)) - - return SequenceData(audio_token_ids + other_token_ids) + return SequenceData.from_prompt_token_counts( + (_AUDIO_PLACEHOLDER_TOKEN, audio_length * audio_count), + (0, seq_len - audio_length * audio_count)), { + "audio": + consecutive_placeholder_ranges(num_items=audio_count, + item_size=audio_length) + } def dummy_audio_for_ultravox( @@ -107,10 +105,10 @@ def dummy_data_for_ultravox( mm_counts: Mapping[str, int], ): audio_count = mm_counts["audio"] - seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) + seq_data, ranges = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) mm_dict = dummy_audio_for_ultravox(ctx, audio_count) - return (seq_data, mm_dict) + return DummyData(seq_data, mm_dict, ranges) def input_mapper_for_ultravox(ctx: InputContext, data: object): @@ -164,6 +162,11 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): if multi_modal_data is None or "audio" not in multi_modal_data: return inputs + if "multi_modal_placeholders" in inputs and "audio" in inputs[ + "multi_modal_placeholders"]: + # The inputs already have placeholders. + return inputs + feature_extractor = whisper_feature_extractor(ctx) audios = multi_modal_data["audio"] if not isinstance(audios, list): @@ -197,7 +200,7 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( tokenizer, inputs.get("prompt"), inputs["prompt_token_ids"], @@ -208,7 +211,8 @@ def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs): # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"audio": ranges}) class StackAudioFrames(nn.Module): @@ -472,9 +476,9 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, audio_embeddings, - _AUDIO_PLACEHOLDER_TOKEN) + merge_multimodal_embeddings_from_map( + inputs_embeds, audio_embeddings, + attn_metadata.multi_modal_placeholder_index_maps["audio"]) input_ids = None else: inputs_embeds = None diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 0aecb5d151a45..fee97e8922a76 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -18,7 +18,7 @@ from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import ModelRegistry -from vllm.multimodal.base 
import NestedTensors +from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -252,6 +252,7 @@ def init_vllm_registered_model( return build_model( model_class, + None, hf_config, cache_config, quant_config, @@ -326,6 +327,22 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: _embedding_count_expression(inner) for inner in embeddings) +def merge_multimodal_embeddings_from_map( + inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, + placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided + placeholder map . + + Note: + This updates ``inputs_embeds`` in place. + """ + flattened_embeddings = _flatten_embeddings(multimodal_embeddings) + inputs_embeds[placeholder_map.dest] = flattened_embeddings[ + placeholder_map.src] + return inputs_embeds + + def _merge_multimodal_embeddings( inputs_embeds: torch.Tensor, is_multimodal: torch.Tensor, diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 489e1e51f05cb..53da2badb9b98 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,6 +1,7 @@ from .base import (BatchedTensorInputs, MultiModalDataBuiltins, - MultiModalDataDict, MultiModalInputs, MultiModalPlugin, - NestedTensors) + MultiModalDataDict, MultiModalInputs, + MultiModalPlaceholderDict, MultiModalPlaceholderMap, + MultiModalPlugin, NestedTensors) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -17,6 +18,8 @@ "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalInputs", + "MultiModalPlaceholderDict", + "MultiModalPlaceholderMap", "MultiModalPlugin", "NestedTensors", "MULTIMODAL_REGISTRY", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 84e71cbf60df7..6b10d0c609f13 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,8 +1,9 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, - TypedDict, TypeVar, Union, cast, final) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, + NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar, + Union, cast, final) import numpy as np import torch @@ -11,12 +12,15 @@ from torch import nn from typing_extensions import TypeAlias -from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, json_map_leaves, resolve_mm_processor_kwargs) +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.sequence import SequenceGroupMetadata + logger = init_logger(__name__) NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] @@ -151,6 +155,30 @@ class MultiModalDataBuiltins(TypedDict, total=False): Read more on that :ref:`here `. """ + +class PlaceholderRange(TypedDict): + """ + Placeholder location information for multi-modal data. + + For example: + Prompt: AAAA BBBB What is in these images? 
+ Images A and B will have: + A: { "offset": 0, "length": 4 } + B: { "offset": 5, "length": 4 } + """ + + offset: int + """The start index of the placeholder in the prompt.""" + + length: int + """The length of the placeholder.""" + + +MultiModalPlaceholderDict = Mapping[str, List[PlaceholderRange]] +""" +A dictionary containing placeholder ranges. +""" + MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], MultiModalInputs] """ @@ -243,7 +271,7 @@ def wrapper(model_cls: N) -> N: return wrapper - def map_input(self, model_config: ModelConfig, + def map_input(self, model_config: "ModelConfig", data: MultiModalData[object], mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: """ @@ -332,7 +360,7 @@ def wrapper(model_cls: N) -> N: return wrapper - def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens for profiling the memory usage of a model. @@ -366,3 +394,179 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: self._validate_max_multimodal_tokens(max_mm_tokens) return max_mm_tokens + + +class MultiModalPlaceholderMap: + """ + Relates multi-modal embeddings to their corresponding placeholders. + """ + + class IndexMap(NamedTuple): + src: List[int] + dest: List[int] + + src_ranges: List[range] + """ + The indices of the multi-modal embeddings that will replace the + corresponding placeholder embeddings pointed to by ``dest_ranges``. + """ + + src_len: int + """ + The total number of flattened multi-modal embeddings. + """ + + dest_ranges: List[range] + """ + The indices of the placeholder embeddings that will be replaced by the + multimodal embeddings. + """ + + dest_len: int + """ + The total number of embeddings in the destination tensor. + """ + + def __init__(self): + self.src_ranges = [] + self.src_len = 0 + self.dest_ranges = [] + self.dest_len = 0 + + @classmethod + def from_seq_group( + cls, seq_group: "SequenceGroupMetadata", positions: range + ) -> Tuple[Optional[MultiModalDataDict], Dict[str, + "MultiModalPlaceholderMap"]]: + """ + Returns the multi-modal items that intersect with the portion of a + prompt (``seq_group``) represented by ``positions``, as well as a + ``MultiModalPlaceholderMap`` that relates the multi-modal embedding + vectors to their corresponding placeholders. + + Consider the following scenarios: + + Prompt: |AAAA BBBB What's in these images?| + Positions: |.................................| + + images = [A, B] + src_ranges = [(0, 4), (4, 8)] + dest_ranges = [(0, 4), (5, 9)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ..... | + + images = [A, B] + src_ranges = [(2, 4), (4, 6)] + dest_ranges = [(0, 2), (3, 5)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | ......... 
| + + images = [B] + src_ranges = [(0, 4)] + dest_ranges = [(0, 4)] + + Prompt: |AAAA BBBB What's in these images?| + Positions: | .......................| + + images = [] + src_ranges = [] + dest_ranges = [] + """ + if (not seq_group.multi_modal_data + or not seq_group.multi_modal_placeholders): + return seq_group.multi_modal_data, {} + + mm_data = {**seq_group.multi_modal_data} + placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict( + MultiModalPlaceholderMap) + + for modality, placeholders in seq_group.multi_modal_placeholders.items( + ): + mm_items = mm_data.pop(modality) + if not isinstance(mm_items, list): + mm_items = [mm_items] + + if positions: + intersecting_items = placeholder_maps[ + modality].append_items_from_seq_group( + positions, mm_items, placeholders) + + if intersecting_items: + mm_data[modality] = intersecting_items + + return mm_data, placeholder_maps + + def append_items_from_seq_group( + self, positions: range, multi_modal_items: List[_T], + multi_modal_placeholders: List[PlaceholderRange]) -> List[_T]: + """ + Adds the multi-modal items that intersect ```positions`` to this + placeholder map and returns the intersecting items. + """ + intersecting_items = [] + + if len(multi_modal_items) != len(multi_modal_placeholders): + raise ValueError( + "Multi-modal placeholders and items must have the same length." + ) + for placeholder_dict, mm_item in zip(multi_modal_placeholders, + multi_modal_items): + placeholder = range( + placeholder_dict["offset"], + placeholder_dict["offset"] + placeholder_dict["length"]) + intersection = range(max(positions.start, placeholder.start), + min(positions.stop, placeholder.stop)) + + if not intersection: + # Skip this multi-modal item. + continue + + token_embedding_range = range(intersection.start - positions.start, + intersection.stop - positions.start) + + multimodal_embedding_range = range( + intersection.start - placeholder.start + self.src_len, + intersection.stop - placeholder.start + self.src_len) + + intersecting_items.append(mm_item) + self.dest_ranges.append(token_embedding_range) + self.src_ranges.append(multimodal_embedding_range) + self.src_len += len(placeholder) + + self.dest_len += len(positions) + return intersecting_items + + def extend(self, other: "MultiModalPlaceholderMap"): + """ + Adds the placeholders from another ``MultiModalPlaceholderMap`` to this + instance based on the source and destination tensors being + concatenated. + """ + + self.src_ranges.extend( + range(self.src_len + r.start, self.src_len + r.stop) + for r in other.src_ranges) + self.src_len += other.src_len + self.dest_ranges.extend( + range(self.dest_len + r.start, self.dest_len + r.stop) + for r in other.dest_ranges) + self.dest_len += other.dest_len + + def index_map(self) -> "IndexMap": + """ + Finalizes the placeholder map into lists of indices that can be used to + index the source and destination tensors. 
+ """ + + src_indices = [i for r in self.src_ranges for i in r] + dest_indices = [i for r in self.dest_ranges for i in r] + + if len(src_indices) != len(dest_indices): + raise ValueError( + f"The number of source ({len(src_indices)}) and destination " + f"indices ({len(dest_indices)}) must be the same.") + + return MultiModalPlaceholderMap.IndexMap(src=src_indices, + dest=dest_indices) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 5f74bcea65ce2..3f6bb6c8338d2 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,11 +1,10 @@ from functools import lru_cache -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional import torch from PIL import Image from transformers.image_processing_base import BatchFeature -from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.processor import get_image_processor @@ -13,6 +12,9 @@ from .base import MultiModalData, MultiModalInputs, MultiModalPlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) cached_get_image_processor = lru_cache(get_image_processor) @@ -26,7 +28,7 @@ def get_data_key(self) -> str: def _get_hf_image_processor( self, - model_config: ModelConfig, + model_config: "ModelConfig", mm_processor_kwargs: Optional[Dict[str, Any]] = None, ): if mm_processor_kwargs is None: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 5e9b8bd518de3..bce2f4c6abe5b 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -1,8 +1,7 @@ import functools from collections import UserDict -from typing import Any, Dict, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence -from vllm.config import ModelConfig from vllm.logger import init_logger from .audio import AudioPlugin @@ -11,6 +10,9 @@ from .image import ImagePlugin from .video import VideoPlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) @@ -20,7 +22,7 @@ class _MultiModalLimits(UserDict): when attempting to access a model that does not exist. """ - def __getitem__(self, key: ModelConfig) -> Dict[str, int]: + def __getitem__(self, key: "ModelConfig") -> Dict[str, int]: try: return super().__getitem__(key) except KeyError as exc: @@ -98,7 +100,7 @@ def register_image_input_mapper( def map_input( self, - model_config: ModelConfig, + model_config: "ModelConfig", data: MultiModalDataDict, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> MultiModalInputs: @@ -139,7 +141,7 @@ def map_input( return MultiModalInputs(merged_dict) - def create_input_mapper(self, model_config: ModelConfig): + def create_input_mapper(self, model_config: "ModelConfig"): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ @@ -177,7 +179,7 @@ def register_max_image_tokens( """ return self.register_max_multimodal_tokens("image", max_mm_tokens) - def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens for profiling the memory usage of a model. 
@@ -195,7 +197,7 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: def init_mm_limits_per_prompt( self, - model_config: ModelConfig, + model_config: "ModelConfig", ) -> None: """ Initialize the maximum number of multi-modal input instances for each @@ -231,7 +233,7 @@ def init_mm_limits_per_prompt( def get_mm_limits_per_prompt( self, - model_config: ModelConfig, + model_config: "ModelConfig", ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3c801464383ad..c5ff552e06099 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -10,7 +10,7 @@ from vllm.connections import global_http_connection from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT from vllm.logger import init_logger -from vllm.multimodal.base import MultiModalDataDict +from vllm.multimodal.base import MultiModalDataDict, PlaceholderRange from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer logger = init_logger(__name__) @@ -258,7 +258,7 @@ def repeat_and_pad_placeholder_tokens( repeat_count: Union[int, List[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, -) -> Tuple[Optional[str], List[int]]: +) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]: if isinstance(repeat_count, int): repeat_count = [repeat_count] @@ -301,6 +301,7 @@ def repeat_and_pad_placeholder_tokens( new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] + placeholder_ranges: List[PlaceholderRange] = [] placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: @@ -310,6 +311,10 @@ def repeat_and_pad_placeholder_tokens( pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) + placeholder_ranges.append({ + "offset": len(new_token_ids), + "length": len(replacement_ids) + }) new_token_ids.extend(replacement_ids) placeholder_token_idx += 1 @@ -320,4 +325,14 @@ def repeat_and_pad_placeholder_tokens( else: new_token_ids.append(token) - return new_prompt, new_token_ids + return new_prompt, new_token_ids, placeholder_ranges + + +def consecutive_placeholder_ranges(num_items: int, + item_size: int) -> List[PlaceholderRange]: + """Returns a list of consecutive PlaceholderRanges of a fixed size""" + + return [ + PlaceholderRange(offset=i * item_size, length=item_size) + for i in range(num_items) + ] diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index c3235c4acb6fd..6c2c6720f4276 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -1,18 +1,19 @@ from functools import lru_cache -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import numpy as np -from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import is_list_of from .base import MultiModalData, MultiModalInputs from .image import ImagePlugin +if TYPE_CHECKING: + from vllm.config import ModelConfig + logger = init_logger(__name__) cached_get_video_processor = lru_cache(get_video_processor) @@ -38,7 +39,7 @@ def get_data_key(self) -> str: def _get_hf_video_processor( self, - model_config: ModelConfig, + model_config: "ModelConfig", mm_processor_kwargs: Optional[Dict[str, Any]] = None, ): if 
mm_processor_kwargs is None: @@ -56,7 +57,10 @@ def _default_input_mapper( ) -> MultiModalInputs: model_config = ctx.model_config - if isinstance(data, np.ndarray) or is_list_of(data, np.ndarray): + if isinstance(data, list) and len(data) == 1: + data = data[0] + + if isinstance(data, np.ndarray): video_processor = self._get_hf_video_processor( model_config, mm_processor_kwargs, diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 35dbe22abf7ff..31fe3f1fcbfe4 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,10 +1,12 @@ import torch import vllm.envs as envs -from vllm.utils import print_warning_once +from vllm.logger import init_logger from .interface import Platform, PlatformEnum +logger = init_logger(__name__) + class OpenVinoPlatform(Platform): _enum = PlatformEnum.OPENVINO @@ -27,5 +29,5 @@ def is_openvino_gpu(self) -> bool: @classmethod def is_pin_memory_available(self) -> bool: - print_warning_once("Pin memory is not supported on OpenViNO.") + logger.warning("Pin memory is not supported on OpenViNO.") return False diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 4338cbc37f6c1..3336569f59467 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,8 +1,14 @@ import logging -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import vllm.envs as envs -from vllm.compilation.config import CompilationConfig + +if TYPE_CHECKING: + from vllm.compilation.config import CompilationConfig + from vllm.config import VllmConfig +else: + CompilationConfig = None + VllmConfig = None logger = logging.getLogger(__name__) @@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]): def get_compilation_config() -> Optional[CompilationConfig]: return _compilation_config + + +_vllm_config: Optional[VllmConfig] = None + + +def set_vllm_config(config: Optional[VllmConfig]): + global _vllm_config + _vllm_config = config + + +def get_vllm_config() -> Optional[VllmConfig]: + return _vllm_config diff --git a/vllm/sequence.py b/vllm/sequence.py index ff59f333f00b4..44a9257c9a4c1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,7 +6,8 @@ from collections import defaultdict from dataclasses import dataclass, field from functools import cached_property, reduce -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, + Mapping, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Union, cast @@ -15,13 +16,13 @@ from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams if TYPE_CHECKING: from vllm.inputs import SingletonInputs - from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -256,7 +257,8 @@ def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @output_token_ids.setter - def output_token_ids(self, new_output_token_ids: List[int]) -> None: + def output_token_ids(self, + new_output_token_ids: GenericSequence[int]) -> None: self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, new_output_token_ids) self._update_cached_all_tokens() @@ -485,7 +487,7 @@ def 
prompt_token_ids(self) -> List[int]: return cast(List[int], self.inputs.get(prompt_token_ids_key)) @property - def multi_modal_data(self) -> "MultiModalDataDict": + def multi_modal_data(self) -> MultiModalDataDict: inputs = self.inputs if (inputs.get("multi_modal_data") @@ -495,11 +497,15 @@ def multi_modal_data(self) -> "MultiModalDataDict": ) return cast( - "MultiModalDataDict", + MultiModalDataDict, (inputs.get("multi_modal_data") or inputs.get("encoder_multi_modal_data") or {}), ) + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + return self.inputs.get("multi_modal_placeholders") or {} + @property def mm_processor_kwargs(self) -> Dict[str, Any]: return self.inputs.get("mm_processor_kwargs") or {} @@ -728,9 +734,13 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: if self.encoder_seq is not None else None) @property - def multi_modal_data(self) -> "MultiModalDataDict": + def multi_modal_data(self) -> MultiModalDataDict: return self.first_seq.multi_modal_data + @property + def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: + return self.first_seq.multi_modal_placeholders + @property def mm_processor_kwargs(self) -> Dict[str, Any]: return self.first_seq.mm_processor_kwargs @@ -946,6 +956,7 @@ class SequenceGroupMetadata( # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. multi_modal_data: Optional[Any] = None + multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[List[int]] = None @@ -1164,7 +1175,7 @@ def get_all_seq_ids_and_request_ids( sequence ids. """ seq_ids: List[int] = [] - request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set) for sg in seq_group_metadata_list: for seq_id in sg.seq_data: seq_ids.append(seq_id) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 3aa999fcb9ebb..17cc0ad1a4a3a 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -17,9 +17,6 @@ "Draft model speculative decoding currently only supports" "CUDA and ROCm flash attention backend.") from err -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.multimodal import MultiModalInputs from vllm.sequence import ExecuteModelRequest, IntermediateTensors @@ -49,40 +46,13 @@ class TP1DraftModelRunner(ModelRunner): any broadcasting inside execute_model). """ - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, - ): - if return_hidden_states: + def __init__(self, *args, **kwargs): + if kwargs.get("return_hidden_states"): raise ValueError( "return_hidden_states is not supported for TP1DraftModelRunner." 
) - super().__init__( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - load_config=load_config, - lora_config=lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - return_hidden_states=return_hidden_states, - observability_config=observability_config, - ) + super().__init__(*args, **kwargs) def _update_sampling_metadata(self, sampling_metadata, num_seqs, num_queries): diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index a777e5c3f22a7..debb3b2d5ec30 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -21,7 +21,7 @@ class NGramWorker(NonLLMProposerWorkerBase): def __init__(self, *args, **kwargs): # Get local_rank/vocab_size from kwargs attribute self.local_rank = kwargs["local_rank"] - self.vocab_size = kwargs["model_config"].get_vocab_size() + self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size() # Lazy initialization list. self._proposer: Top1Proposer diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9f7ef2f8d851c..a402181b13db8 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,10 +1,11 @@ +import copy from collections import defaultdict from functools import cached_property from typing import Any, Dict, List, Optional, Set, Tuple, Type import torch -from vllm.config import ParallelConfig, SpeculativeConfig +from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler @@ -45,8 +46,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": """Helper method that is the entrypoint for Executors which use WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. """ - assert "speculative_config" in kwargs - speculative_config: SpeculativeConfig = kwargs.get("speculative_config") + vllm_config: VllmConfig = kwargs.get("vllm_config") + speculative_config: SpeculativeConfig = vllm_config.speculative_config assert speculative_config is not None draft_worker_kwargs = kwargs.copy() @@ -58,14 +59,16 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": target_worker.model_runner.disable_logprobs =\ speculative_config.disable_logprobs + draft_worker_config = copy.deepcopy(vllm_config) + draft_worker_config.model_config = speculative_config.draft_model_config + draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa + # TODO allow draft-model specific load config. + # Override draft-model specific worker args. draft_worker_kwargs.update( - model_config=speculative_config.draft_model_config, - parallel_config=speculative_config.draft_parallel_config, + vllm_config=draft_worker_config, ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, - # TODO allow draft-model specific load config. 
- #load_config=load_config, ) spec_decode_worker = SpecDecodeWorker.create_worker( @@ -134,29 +137,27 @@ def create_worker( draft_worker_kwargs.pop("ngram_prompt_lookup_max")) ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + draft_model_config = draft_worker_kwargs["vllm_config"].model_config + draft_parallel_config: ParallelConfig = draft_worker_kwargs[ + 'vllm_config'].parallel_config if ngram_prompt_lookup_max > 0: proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) else: - draft_parallel_config: ParallelConfig = draft_worker_kwargs[ - 'parallel_config'] draft_tp = draft_parallel_config.tensor_parallel_size target_tp = scorer_worker.parallel_config.tensor_parallel_size - if draft_worker_kwargs[ - "model_config"].hf_config.model_type == "mlp_speculator": + if draft_model_config.hf_config.model_type == "mlp_speculator": proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) - elif draft_worker_kwargs[ - "model_config"].hf_config.model_type == "medusa": + elif draft_model_config.hf_config.model_type == "medusa": proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: draft_worker_kwargs[ "model_runner_cls"] = TP1DraftModelRunner else: - if draft_worker_kwargs[ - "model_config"].hf_config.model_type == "eagle": + if draft_model_config.hf_config.model_type == "eagle": raise NotImplementedError( "EAGLE does not support TP > 1 yet") @@ -190,8 +191,8 @@ def create_worker( "[Speculative Decoding] Disabling MQA scorer as the " "MQA is only available with flash attn backend.") - if "model_config" in draft_worker_kwargs and \ - draft_worker_kwargs["model_config"].max_model_len < \ + if draft_model_config and \ + draft_model_config.max_model_len < \ scorer_worker.model_config.max_model_len: disable_mqa_scorer = True logger.info( diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py index 2bb7af7d7c600..e61cde5b17f20 100644 --- a/vllm/spec_decode/target_model_runner.py +++ b/vllm/spec_decode/target_model_runner.py @@ -1,8 +1,6 @@ from typing import List, Optional -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.sequence import SequenceGroupMetadata from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) @@ -20,35 +18,21 @@ class TargetModelRunner(ModelRunner): requested or not. """ - def __init__(self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None): + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + return_hidden_states: bool = False, + ): # An internal boolean member variable to indicate if token log # probabilities are needed or not. 
self.disable_logprobs = True super().__init__( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - load_config=load_config, - lora_config=lora_config, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, - observability_config=observability_config, ) def prepare_model_input( diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 9bd2531d7a15c..08697274854e0 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -19,6 +19,7 @@ # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, + H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -52,6 +53,7 @@ "medusa": MedusaConfig, "eagle": EAGLEConfig, "exaone": ExaoneConfig, + "h2ovl_chat": H2OVLChatConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index f0d79197a82c5..d1e19c9a33c24 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -6,6 +6,7 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -22,6 +23,7 @@ "DbrxConfig", "MPTConfig", "RWConfig", + "H2OVLChatConfig", "InternVLChatConfig", "JAISConfig", "MedusaConfig", @@ -33,4 +35,4 @@ "NVLM_D_Config", "SolarConfig", "UltravoxConfig", -] +] \ No newline at end of file diff --git a/vllm/transformers_utils/configs/h2ovl.py b/vllm/transformers_utils/configs/h2ovl.py new file mode 100644 index 0000000000000..b94c5b77e4b7f --- /dev/null +++ b/vllm/transformers_utils/configs/h2ovl.py @@ -0,0 +1,13 @@ +# Adapted from +# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py +# -------------------------------------------------------- +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- + +from .internvl import InternVLChatConfig + + +class H2OVLChatConfig(InternVLChatConfig): + model_type = "h2ovl_chat" diff --git a/vllm/utils.py b/vllm/utils.py index 5488719cc99b0..a742ec8d76908 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -80,8 +80,8 @@ "currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " - "currently supported with encoder/" +STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers and Flash-Attention are the only " + "backends currently supported with encoder/" "decoder models.") STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " @@ -1551,7 +1551,14 @@ def direct_register_custom_op( """ if is_in_doc_build(): return - schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + import torch.library + if 
hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, + mutates_args=mutates_args) + else: + # for pytorch 2.4 + import torch._custom_op.impl + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str) my_lib.impl(op_name, op_func, "CUDA") diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 072e52bcd686a..64cc18149d6c5 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -2,11 +2,9 @@ from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union) -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, SchedulerConfig, + VllmConfig) from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, @@ -35,17 +33,7 @@ class LLMEngine: def __init__( self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - decoding_config: Optional[DecodingConfig], - observability_config: Optional[ObservabilityConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], + vllm_config: VllmConfig, executor_class: Type[GPUExecutor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -53,6 +41,22 @@ def __init__( input_registry: InputRegistry = INPUT_REGISTRY, use_cached_outputs: bool = False, ) -> None: + + # TODO: remove the local variables and use self.* throughout the class. + model_config = self.model_config = vllm_config.model_config + cache_config = self.cache_config = vllm_config.cache_config + lora_config = self.lora_config = vllm_config.lora_config + parallel_config = self.parallel_config = vllm_config.parallel_config + scheduler_config = self.scheduler_config = vllm_config.scheduler_config + device_config = self.device_config = vllm_config.device_config + speculative_config = self.speculative_config = vllm_config.speculative_config # noqa + load_config = self.load_config = vllm_config.load_config + decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + ) + prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + ) + # Override the configs for V1. 
# FIXME if usage_context == UsageContext.LLM_CLASS: @@ -112,18 +116,6 @@ def __init__( model_config.mm_processor_kwargs, ) - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.load_config = load_config - self.decoding_config = decoding_config or DecodingConfig() - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config or ObservabilityConfig( - ) self.log_stats = log_stats assert not self.model_config.skip_tokenizer_init @@ -154,18 +146,7 @@ def __init__( # Request id -> RequestOutput self.request_outputs: Dict[str, RequestOutput] = {} - self.model_executor = executor_class( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - speculative_config=speculative_config, - load_config=load_config, - prompt_adapter_config=prompt_adapter_config, - observability_config=self.observability_config, - ) + self.model_executor = executor_class(vllm_config=vllm_config) assert self.model_config.task != "embedding" self._initialize_kv_caches() @@ -203,7 +184,7 @@ def from_engine_args( executor_class = cls._get_executor_cls(engine_config) # Create the LLM engine. engine = cls( - **engine_config.to_dict(), + vllm_config=engine_config, executor_class=executor_class, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, @@ -497,7 +478,7 @@ def get_lora_config(self) -> LoRAConfig: return self.lora_config @classmethod - def _get_executor_cls(cls, engine_config: EngineConfig): + def _get_executor_cls(cls, engine_config: VllmConfig): return GPUExecutor def is_tracing_enabled(self) -> bool: diff --git a/vllm/v1/executor/gpu_executor.py b/vllm/v1/executor/gpu_executor.py index c780c7031c3d6..f71fa16b16e27 100644 --- a/vllm/v1/executor/gpu_executor.py +++ b/vllm/v1/executor/gpu_executor.py @@ -1,10 +1,7 @@ import os from typing import Optional, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.v1.outputs import ModelRunnerOutput @@ -15,29 +12,18 @@ class GPUExecutor: - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - speculative_config: Optional[SpeculativeConfig], - prompt_adapter_config: Optional[PromptAdapterConfig], - observability_config: Optional[ObservabilityConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.model_config = 
vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config self.worker = self._create_worker() self.worker.initialize() @@ -56,19 +42,10 @@ def _create_worker( distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) return Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, + vllm_config=self.vllm_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - speculative_config=self.speculative_config, - prompt_adapter_config=self.prompt_adapter_config, - observability_config=self.observability_config, ) def determine_num_available_blocks(self) -> Tuple[int, int]: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 28614377b27b9..9ef36f2e6b212 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import Dict import torch @@ -16,7 +16,6 @@ class SamplingMetadata: no_top_p: bool no_top_k: bool - generators: List[Optional[torch.Generator]] - no_generator: bool + generators: Dict[int, torch.Generator] max_num_logprobs: int diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 157c4dd6d771e..927f274541c4d 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -1,5 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import List, Optional +from typing import Dict import torch import torch.nn as nn @@ -84,22 +84,21 @@ def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor: def random_sample( self, probs: torch.Tensor, - generators: List[Optional[torch.Generator]], - no_generator: bool, + generators: Dict[int, torch.Generator], ) -> torch.Tensor: q = torch.empty_like(probs) # NOTE(woosuk): To batch-process the requests without their own seeds, # which is the common case, we first assume that every request does # not have its own seed. Then, we overwrite the values for the requests # that have their own seeds. - q.exponential_() - if not no_generator: - assert len(generators) == probs.shape[0] + if len(generators) != probs.shape[0]: + # This might still be done here unnecessarily if there are greedies + q.exponential_() + if generators: # TODO(woosuk): This can be slow because we handle each request # one by one. Optimize this. 
- for i, generator in enumerate(generators): - if generator is not None: - q[i].exponential_(generator=generator) + for i, generator in generators.items(): + q[i].exponential_(generator=generator) return probs.div_(q).argmax(dim=-1).view(-1) def sample( @@ -112,13 +111,11 @@ def sample( if sampling_metadata.all_greedy: return self.greedy_sample(probs) if sampling_metadata.all_random: - return self.random_sample(probs, sampling_metadata.generators, - sampling_metadata.no_generator) + return self.random_sample(probs, sampling_metadata.generators) greedy_sampled = self.greedy_sample(probs) random_sampled = self.random_sample(probs, - sampling_metadata.generators, - sampling_metadata.no_generator) + sampling_metadata.generators) sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, greedy_sampled, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e84645ac7a4ae..ae4239f8e1fab 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -7,9 +7,7 @@ import torch.distributed import torch.nn as nn -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -33,26 +31,25 @@ class GPUModelRunner: def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, + vllm_config: VllmConfig, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config - + # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + model_config = self.model_config + cache_config = self.cache_config + scheduler_config = self.scheduler_config + parallel_config = self.parallel_config self.device = self.device_config.device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype @@ -131,13 +128,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Add new requests to the cached states. 
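# Illustrative sketch, not part of the diff: request generators are now kept in
# a Dict[int, torch.Generator] keyed by request index, and a request with a
# seed gets a freshly seeded torch.Generator (built below from
# sampling_params.seed). Seeding is what makes the sampler's
# q[i].exponential_(generator=...) call reproducible for that request. The seed
# and tensor size here are hypothetical.
import torch

generator = torch.Generator(device="cpu")
generator.manual_seed(1234)
q = torch.empty(8)
q.exponential_(generator=generator)  # identical draws on every run of this snippet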
for req_data in scheduler_output.scheduled_new_reqs: req_id = req_data.req_id + sampling_params = req_data.sampling_params + if sampling_params.seed is not None: + generator = torch.Generator(device=self.device) + generator.manual_seed(sampling_params.seed) + else: + generator = None + self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=req_data.prompt_token_ids, prompt=req_data.prompt, multi_modal_data=req_data.multi_modal_data, - sampling_params=req_data.sampling_params, - generator=None, # TODO + sampling_params=sampling_params, + generator=generator, block_ids=req_data.block_ids, num_computed_tokens=req_data.num_computed_tokens, output_token_ids=[], @@ -345,11 +349,9 @@ def execute_model( else: # Ignore the sampled token from the partial request. # Rewind the generator state as if the token was not sampled. - generator = self.input_batch.generators[i] + generator = self.input_batch.generators.get(i) if generator is not None: - offset = generator.get_offset() - generator = generator.set_offset(offset - 1) - self.input_batch.generators[i] = generator + generator.set_offset(generator.get_offset() - 1) if sampler_output.logprob_token_ids is None: logprob_token_ids = None @@ -372,13 +374,7 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 with patch("vllm.model_executor.layers.sampler.Sampler", Sampler): - self.model = get_model(model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -503,8 +499,8 @@ def __init__( self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: Set[str] = set() - self.generators: List[Optional[torch.Generator]] = [None - ] * max_num_reqs + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} self.num_logprobs: Dict[str, int] = {} self.prompt_logprob_reqs: Set[str] = set() @@ -518,8 +514,9 @@ def add_request( req_index = self.num_reqs assert req_index < self.max_num_reqs - self.req_ids[req_index] = request.req_id - self.req_id_to_index[request.req_id] = req_index + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. num_prompt_tokens = len(request.prompt_token_ids) @@ -537,27 +534,24 @@ def add_request( sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_index) - elif sampling_params.sampling_type == SamplingType.RANDOM: - self.random_reqs.add(req_index) - elif sampling_params.sampling_type == SamplingType.RANDOM_SEED: - # TODO(woosuk): Support per-request random seed. 
- raise NotImplementedError("Per-request seed is not supported yet.") + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) self.top_p_cpu[req_index] = sampling_params.top_p if sampling_params.top_p < 1: - self.top_p_reqs.add(req_index) + self.top_p_reqs.add(req_id) self.top_k_cpu[req_index] = sampling_params.top_k if sampling_params.top_k > 0: - self.top_k_reqs.add(req_index) + self.top_k_reqs.add(req_id) self.generators[req_index] = request.generator num_logprobs = sampling_params.logprobs if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[request.req_id] = num_logprobs + self.num_logprobs[req_id] = num_logprobs if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_index) + self.prompt_logprob_reqs.add(req_id) def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) @@ -569,7 +563,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) - self.generators[req_index] = None + self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) return req_index @@ -621,7 +615,9 @@ def condense(self, empty_req_indices: List[int]) -> None: last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.generators[empty_index] = self.generators[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -645,8 +641,7 @@ def make_sampling_metadata( top_k=self.top_k[:self.num_reqs], no_top_p=self.no_top_p, no_top_k=self.no_top_k, - generators=self.generators[:self.num_reqs], - no_generator=self.no_generator, + generators=self.generators, max_num_logprobs=self.max_num_logprobs, ) @@ -670,16 +665,9 @@ def no_top_p(self) -> bool: def no_top_k(self) -> bool: return len(self.top_k_reqs) == 0 - @property - def no_generator(self) -> bool: - return len(self.generators) == 0 - @property def max_num_logprobs(self) -> int: - if self.num_logprobs: - return max(self.num_logprobs.values()) - else: - return 0 + return max(self.num_logprobs.values()) if self.num_logprobs else 0 @property def no_logprob(self) -> bool: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 8c5ca2ec35666..c8192b7f86eb0 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,10 +6,7 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -30,48 +27,35 @@ class Worker: def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - speculative_config: Optional[SpeculativeConfig] = None, - lora_config: Optional[LoRAConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - 
observability_config: Optional[ObservabilityConfig] = None, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + + # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.speculative_config = speculative_config - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = GPUModelRunner( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=lora_config, - ) + self.model_runner = GPUModelRunner(vllm_config) def initialize(self): if self.device_config.device.type == "cuda": diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 5032896600b3b..fdd72a452f2ad 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,5 +1,6 @@ import dataclasses import weakref +from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -7,16 +8,14 @@ from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) from vllm.transformers_utils.config import uses_mrope @@ -148,9 +147,18 @@ def build(self) -> ModelInputForCPU: query_lens=seq_lens, ) - def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, - computed_len: int, + def _compute_multi_modal_input(self, seq_group: SequenceGroupMetadata, + seq_data: SequenceData, computed_len: int, mm_processor_kwargs: Dict[str, Any]): + + # NOTE: mm_data only includes the subset of multi-modal items that + # intersect with the current prefill positions. 
+ mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + seq_group, range(computed_len, len(seq_data.get_token_ids()))) + + if not mm_data: + return + mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) # special processing for mrope position deltas. @@ -179,7 +187,7 @@ def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, context_len=computed_len, ) seq_data.mrope_position_delta = mrope_position_delta - return mm_kwargs, mrope_positions + return mm_kwargs, placeholder_maps, mrope_positions def _prepare_prompt( self, @@ -194,6 +202,9 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -210,11 +221,15 @@ def _prepare_prompt( input_tokens.extend(prompt_tokens) # Token ids mrope_positions = None - if (mm_data := seq_group_metadata.multi_modal_data): - mm_kwargs, mrope_positions = self._compute_multi_modal_input( - seq_data, mm_data, computed_len, + if seq_group_metadata.multi_modal_data: + mm_kwargs, placeholder_maps, mrope_positions = self \ + ._compute_multi_modal_input( + seq_group_metadata, seq_data, computed_len, seq_group_metadata.mm_processor_kwargs) multi_modal_inputs_list.append(mm_kwargs) + for modality, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[modality].extend( + placeholder_map) # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt @@ -264,6 +279,11 @@ def _prepare_prompt( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + multi_modal_placeholder_maps.items() + } attn_metadata = self.attn_backend.make_metadata( is_prompt=True, @@ -275,6 +295,7 @@ def _prepare_prompt( num_decode_tokens=0, block_tables=torch.tensor([]), slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) @@ -366,6 +387,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_decode_seq_len=max_decode_seq_len, @@ -388,29 +410,18 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config + ModelRunnerBase.__init__(self, vllm_config) # Currently, CPU worker doesn't support chunked prefill. 
assert self.scheduler_config.chunked_prefill_enabled is False - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config - self.load_config = load_config + model_config = self.model_config + cache_config = self.cache_config + self.is_driver_worker = is_driver_worker self.device = self.device_config.device @@ -442,13 +453,7 @@ def model_is_mrope(self) -> bool: return uses_mrope(self.model_config.hf_config) def load_model(self) -> None: - self.model = get_model(model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + self.model = get_model(vllm_config=self.vllm_config) def make_model_input_from_broadcasted_tensor_dict( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index ab93471b5af74..3778707ae07e8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,9 +6,8 @@ import vllm.envs as envs from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, PromptAdapterConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -18,7 +17,8 @@ from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) logger = init_logger(__name__) @@ -121,31 +121,19 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + WorkerBase.__init__(self, vllm_config=vllm_config) + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
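Across the worker and runner hunks above, the long per-subsystem constructor signatures collapse into a single VllmConfig argument, and the base class fans it back out into attributes. As a reading aid, here is a minimal sketch of that pattern for a hypothetical worker subclass (MyWorker is invented for illustration; WorkerBase.__init__ is the one added further down in this diff):

    from vllm.config import VllmConfig
    from vllm.worker.worker_base import WorkerBase


    class MyWorker(WorkerBase):  # hypothetical subclass, for illustration only

        def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int,
                     distributed_init_method: str) -> None:
            # WorkerBase.__init__ (see the worker_base.py hunk below) unpacks
            # vllm_config into self.model_config, self.cache_config,
            # self.parallel_config, self.scheduler_config, and so on.
            WorkerBase.__init__(self, vllm_config=vllm_config)
            self.local_rank = local_rank
            self.rank = rank
            self.distributed_init_method = distributed_init_method
            # Sub-configs are then read from the attributes, e.g.:
            self.block_size = self.cache_config.block_size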
@@ -166,15 +154,8 @@ def __init__( if self._is_encoder_decoder_model(): ModelRunnerClass = CPUEncoderDecoderModelRunner self.model_runner: CPUModelRunner = ModelRunnerClass( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=self.load_config, - lora_config=self.lora_config, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, - prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a7f5b2d4fdd1f..ff288d5ca1512 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -3,9 +3,7 @@ import torch -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger @@ -36,29 +34,13 @@ class EmbeddingModelRunner( def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, ): - super().__init__(model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=lora_config, + super().__init__(vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - observability_config=observability_config) + is_driver_worker=is_driver_worker) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 6a00444f5098b..90a43196084ea 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -11,14 +11,13 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, get_global_forced_attn_backend, global_force_attn_backend) -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import ModelConfig, VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.utils import get_architecture_class_name from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, MultiModalRegistry) from vllm.sampling_params import SamplingParams @@ -36,6 +35,11 @@ logger = init_logger(__name__) +# The Mllama model has PagedAttention specific logic because of which it +# can only be run with the XFORMERS backend +# TODO Make Mllama model work with Flash Attention backend. 
+_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"] + @dataclasses.dataclass(frozen=True) class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): @@ -79,17 +83,9 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): @@ -101,17 +97,10 @@ def __init__( models) but these arguments are present here for compatibility with the base-class constructor. ''' - - self._maybe_force_supported_attention_backend() + self._maybe_force_supported_attention_backend(vllm_config.model_config) super().__init__( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - lora_config=None, + vllm_config=vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, ) @@ -119,7 +108,12 @@ def __init__( # Crash for unsupported encoder/scenarios assert_enc_dec_mr_supported_scenario(self) - def _maybe_force_supported_attention_backend(self): + def _is_xformers_only_encoder_decoder_model(self, + model: ModelConfig) -> bool: + return get_architecture_class_name( + model) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS + + def _maybe_force_supported_attention_backend(self, model: ModelConfig): ''' Force vLLM to use the XFormers attention backend, which is currently the only supported option. @@ -135,22 +129,26 @@ def raise_backend_err(): is_forced_by_global = maybe_global_forced_backend is not None is_forced_by_env_var = maybe_env_var_forced_backend is not None - if not (is_forced_by_global or is_forced_by_env_var): + if not (is_forced_by_global or is_forced_by_env_var) \ + and self._is_xformers_only_encoder_decoder_model(model): # The user has not already specified an attention backend # override - logger.info("EncoderDecoderModelRunner requires " - "XFormers backend; overriding backend " - "auto-selection and forcing XFormers.") + logger.info( + "Encoder-Decoder Model Architecture %s requires XFormers " + "backend; overriding backend auto-selection and " + "forcing XFormers.", get_architecture_class_name(model)) global_force_attn_backend(_Backend.XFORMERS) elif is_forced_by_global: # Backend override enforced by global variable takes # precedence over vLLM backend environment variable. 
- if maybe_global_forced_backend != _Backend.XFORMERS: + if maybe_global_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: raise_backend_err() elif is_forced_by_env_var: # Backend override enforced by vLLM backend # environment variable - if maybe_env_var_forced_backend != _Backend.XFORMERS: + if maybe_env_var_forced_backend not in\ + [_Backend.XFORMERS, _Backend.FLASH_ATTN]: raise_backend_err() def _list_to_int32_tensor( @@ -306,13 +304,12 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - decoder_seq_data, decoder_dummy_multi_modal_data \ - = self.input_registry.dummy_data_for_profiling( - self.model_config, + decoder_dummy_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry, is_encoder_data=False) - encoder_seq_data, encoder_dummy_multi_modal_data \ + encoder_dummy_data \ = self.input_registry.dummy_data_for_profiling( self.model_config, seq_len, @@ -320,26 +317,31 @@ def profile_run(self) -> None: is_encoder_data=True) # Having more tokens is over-conservative but otherwise fine - assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( + assert len( + decoder_dummy_data.seq_data.prompt_token_ids + ) >= seq_len, ( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(decoder_seq_data.prompt_token_ids)}") + f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}" + ) - assert decoder_dummy_multi_modal_data is None or \ - encoder_dummy_multi_modal_data is None, ( + assert decoder_dummy_data.multi_modal_data is None or \ + encoder_dummy_data.multi_modal_data is None, ( "Multi-modal data can't be provided in both encoder and decoder" ) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: decoder_seq_data}, + seq_data={group_id: decoder_dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, - encoder_seq_data=encoder_seq_data, + encoder_seq_data=encoder_dummy_data.seq_data, cross_block_table=None, - multi_modal_data=decoder_dummy_multi_modal_data - or encoder_dummy_multi_modal_data, - ) + multi_modal_data=decoder_dummy_data.multi_modal_data + or encoder_dummy_data.multi_modal_data, + multi_modal_placeholders=decoder_dummy_data. + multi_modal_placeholders + or encoder_dummy_data.multi_modal_placeholders) seqs.append(seq) # Run the model with the dummy inputs. 
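The enc_dec_model_runner changes above narrow the forced-XFormers behavior to architectures that actually need it and otherwise also accept FLASH_ATTN. A rough free-standing sketch of that gating, using only names visible in this diff (the helper function itself is hypothetical and ignores the explicit-override branches):

    from vllm.attention.selector import _Backend, global_force_attn_backend
    from vllm.config import ModelConfig
    from vllm.model_executor.model_loader.utils import get_architecture_class_name

    _XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]


    def maybe_force_xformers(model_config: ModelConfig) -> None:
        # Only architectures with XFormers/PagedAttention-specific code paths
        # are pinned to XFORMERS; other encoder-decoder models keep backend
        # auto-selection, and explicit overrides may be either XFORMERS or
        # FLASH_ATTN per the relaxed checks above.
        if get_architecture_class_name(
                model_config) in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS:
            global_force_attn_backend(_Backend.XFORMERS)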
@@ -528,6 +530,7 @@ def _prepare_encoder_model_input_tensors( attn_metadata.encoder_seq_lens, attn_metadata.encoder_seq_lens_tensor, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_start_loc, attn_metadata.cross_slot_mapping, attn_metadata.cross_block_tables, ) = ( @@ -535,6 +538,7 @@ def _prepare_encoder_model_input_tensors( encoder_seq_lens, encoder_seq_lens_tensor, max_encoder_seq_len, + encoder_seq_start_loc, cross_slot_mapping_tensor, cross_block_tables, ) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 891637dafbb14..328dab598f8ef 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,9 +20,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context from vllm.compilation.levels import CompilationLevel -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture @@ -40,7 +38,8 @@ from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalRegistry) + MultiModalInputs, MultiModalPlaceholderMap, + MultiModalRegistry) from vllm.platforms import current_platform from vllm.prompt_adapter.layers import PromptAdapterMapping from vllm.prompt_adapter.request import PromptAdapterRequest @@ -242,6 +241,8 @@ def __init__( # Multi-modal inputs. multi_modal_inputs: Optional[MultiModalInputs] = None, + multi_modal_placeholder_maps: Optional[Dict[ + str, MultiModalPlaceholderMap]] = None, # Whether the prefix cache is hit (prefill only). prefix_cache_hit: bool = False, @@ -361,6 +362,7 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.multi_modal_inputs = multi_modal_inputs + self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.prefix_cache_hit = prefix_cache_hit self.n_seqs = len(self.seq_ids) @@ -635,7 +637,12 @@ def _compute_prompt_adapter_input( def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata): """If multi-modal data is given, add it to the input.""" - mm_data = seq_group_metadata.multi_modal_data + # NOTE: mm_data only includes the subset of multi-modal items that + # intersect with the current prefill positions. + positions = inter_data.input_positions[0] + mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( + seq_group_metadata, + range(positions[0], positions[0] + len(positions))) if not mm_data: return @@ -643,6 +650,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, mm_data, mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) inter_data.multi_modal_inputs = mm_kwargs + inter_data.multi_modal_placeholder_maps = placeholder_maps # special processing for mrope position deltas. 
if self.runner.model_is_mrope: @@ -945,32 +953,20 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config + + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config + cache_config = self.cache_config + self.is_driver_worker = is_driver_worker - self.prompt_adapter_config = prompt_adapter_config self.return_hidden_states = return_hidden_states - self.observability_config = observability_config self.device = self.device_config.device self.pin_memory = is_pin_memory_available() @@ -1055,13 +1051,7 @@ def __init__( def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: - self.model = get_model(model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -1255,7 +1245,7 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = self.input_registry \ + dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) @@ -1263,12 +1253,13 @@ def profile_run(self) -> None: seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, - multi_modal_data=dummy_multi_modal_data, + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders, ) seqs.append(seq) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 86883cf152449..9e529f86b46bb 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -9,6 +9,7 @@ import torch from torch import is_tensor +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform @@ -46,9 +47,8 @@ def _init_attn_metadata_from_tensor_dict( # Extract the fields used to create AttentionMetadata. 
valid_attn_kwargs = {} for field in dataclasses.fields(attn_backend.get_metadata_cls()): - val = tensor_dict.pop(field.name, None) - if val is not None: - valid_attn_kwargs[field.name] = val + if field.name in tensor_dict: + valid_attn_kwargs[field.name] = tensor_dict.pop(field.name) attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) tensor_dict["attn_metadata"] = attn_metadata @@ -221,6 +221,22 @@ class ModelRunnerBase(ABC, Generic[T]): ModelRunnerInputBase subclass. """ + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + # Map of request_id -> generator used for seeded random sampling generators: Dict[str, torch.Generator] = {} diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index be2f0d79154d6..3ee0fb4dc943e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -304,6 +304,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): # mypy: enable-error-code=type-var def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) # Check attention backend support. diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index bf66f32d7d244..1f982fe103366 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -27,17 +27,9 @@ def __init__(self, *args, **kwargs): # for multi-step model, wrap the model runner with MultiStepModelRunner self.model_runner = MultiStepModelRunner( base_model_runner, - base_model_runner.model_config, - base_model_runner.parallel_config, - base_model_runner.scheduler_config, - base_model_runner.device_config, - base_model_runner.cache_config, - load_config=base_model_runner.load_config, - lora_config=self.lora_config, + vllm_config=base_model_runner.vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=base_model_runner.is_driver_worker, - prompt_adapter_config=base_model_runner.prompt_adapter_config, - observability_config=base_model_runner.observability_config, ) pipeline_parallel_size = self.parallel_config.pipeline_parallel_size diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index b8c760c4b5396..2da22cbfc7cb5 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -7,8 +7,7 @@ from torch import nn from transformers_neuronx.config import GenerationConfig -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -57,20 +56,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, + vllm_config: VllmConfig, ): - 
self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - + ModelRunnerBase.__init__(self, vllm_config) + model_config = self.model_config if model_config is not None and model_config.get_sliding_window(): logger.warning("Sliding window is not supported on Neuron. " "The model will run without sliding window.") - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device self.pin_memory = is_pin_memory_available() diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index fff14d6402b44..3f6269684ac93 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -4,15 +4,15 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): @@ -21,20 +21,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config + WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -44,7 +36,7 @@ def __init__( init_cached_hf_modules() self.model_runner: NeuronModelRunner = NeuronModelRunner( - model_config, parallel_config, scheduler_config, device_config) + vllm_config=vllm_config) self.is_driver_worker = True def init_device(self) -> None: diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index a164fbe3393c4..c9c87ea748081 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -1,4 +1,5 @@ -from typing import List, NamedTuple, Optional, Tuple +from collections import defaultdict +from typing import Dict, List, NamedTuple, Optional, Tuple import openvino as ov import torch @@ -6,16 +7,15 @@ from vllm.attention import get_attn_backend from vllm.attention.backends.openvino import OpenVINOAttentionMetadata -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.openvino import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs) + MultiModalInputs, MultiModalPlaceholderMap) from vllm.sequence import SequenceGroupMetadata +from 
vllm.worker.model_runner_base import ModelRunnerBase logger = init_logger(__name__) @@ -38,33 +38,21 @@ def empty(cls, device): multi_modal_kwargs={}) -class OpenVINOModelRunner: +class OpenVINOModelRunner(ModelRunnerBase): def __init__( self, ov_core: ov.Core, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, *args, **kwargs, ): self.ov_core = ov_core - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.multimodal_config = multimodal_config - self.load_config = load_config + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + cache_config = self.cache_config + model_config = self.model_config self.is_driver_worker = is_driver_worker self.device = self.device_config.device @@ -115,6 +103,9 @@ def _prepare_model_input( past_lens: List[int] = [] query_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) subsequence_begins: List[int] = [] block_indices: List[int] = [] @@ -168,15 +159,6 @@ def _prepare_model_input( and self.sliding_window is None and is_prompt) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - mm_processor_kwargs=seq_group_metadata. - mm_processor_kwargs, - ) - multi_modal_inputs_list.append(mm_kwargs) - block_table = seq_group_metadata.block_tables[seq_id] # TODO(sang): Combine chunked prefill and prefix caching by # only allowing multiple of block_size chunk size. @@ -220,7 +202,8 @@ def _prepare_model_input( query_lens.append(query_len) input_tokens.extend(tokens) - input_positions.extend(list(range(computed_len, seq_len))) + positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) past_lens.append(computed_len) subsequence_begins.append(subsequence_begins[-1] + query_len) @@ -233,6 +216,22 @@ def _prepare_model_input( ), "seq_len: {}, computed_len: {}, query_len: {}".format( seq_len, computed_len, query_len) + if seq_group_metadata.multi_modal_data: + # NOTE: mm_data only includes the subset of multi-modal + # items that intersect with the current prefill positions. + mm_data, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) + + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. 
+ mm_processor_kwargs) + multi_modal_inputs_list.append(mm_kwargs) + + for modality, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[modality].extend( + placeholder_map, ) + max_query_len = max(query_lens) assert max_query_len > 0, "query_lens: {}".format(query_lens) @@ -261,12 +260,19 @@ def _prepare_model_input( max_context_len, dtype=torch.int32, device=self.device) # type: ignore + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + multi_modal_placeholder_maps.items() + } + attn_metadata = self.attn_backend.make_openvino_metadata( past_lens=past_lens_tensor, subsequence_begins=subsequence_begins_tensor, block_indices=block_indices_tensor, block_indices_begins=block_indices_begins_tensor, max_context_len=max_context_len_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) @@ -350,3 +356,9 @@ def execute_model( sampling_metadata=sampling_metadata, ) return output + + def prepare_model_input(self, *args, **kwargs): + raise NotImplementedError + + def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs): + raise NotImplementedError diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index a420d390c1ae4..205f8a337ce6c 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -7,9 +7,8 @@ import vllm.envs as envs from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -22,7 +21,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.worker.openvino_model_runner import OpenVINOModelRunner -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) @@ -212,33 +211,19 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): def __init__( self, ov_core: ov.Core, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - multimodal_config: Optional[MultiModalConfig] = None, kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, is_driver_worker: bool = False, ) -> None: self.ov_core = ov_core - self.model_config = model_config - self.parallel_config = parallel_config + WorkerBase.__init__(self, vllm_config) self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.multimodal_config = multimodal_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
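The CPU, OpenVINO, and XPU runner hunks all add the same bookkeeping: per-modality MultiModalPlaceholderMap objects are accumulated over the prefill positions of each sequence group and then flattened into index maps for the attention metadata. A simplified sketch of that aggregation, assuming the call signatures shown in this diff (the helper below is hypothetical):

    from collections import defaultdict
    from typing import Dict, List

    from vllm.multimodal import MultiModalPlaceholderMap
    from vllm.sequence import SequenceGroupMetadata


    def collect_placeholder_index_maps(seq_groups: List[SequenceGroupMetadata],
                                       prefill_ranges: List[range]):
        placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
            MultiModalPlaceholderMap)
        for seq_group_metadata, positions_range in zip(seq_groups,
                                                       prefill_ranges):
            if not seq_group_metadata.multi_modal_data:
                continue
            # Only multi-modal items intersecting the current prefill
            # positions are returned.
            _, maps = MultiModalPlaceholderMap.from_seq_group(
                seq_group_metadata, positions_range)
            for modality, placeholder_map in maps.items():
                placeholder_maps[modality].extend(placeholder_map)
        # Flattened per-modality index maps, as passed to
        # multi_modal_placeholder_index_maps in the attention metadata.
        return {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in placeholder_maps.items()
        }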
@@ -250,14 +235,7 @@ def __init__( init_cached_hf_modules() self.model_runner = OpenVINOModelRunner( self.ov_core, - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=self.load_config, - lora_config=self.lora_config, - multimodal_config=self.multimodal_config, + vllm_config=self.vllm_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, ) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 87ced7818a676..a721186137328 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -12,8 +12,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model @@ -90,20 +89,10 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, is_driver_worker: bool = False, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + ModelRunnerBase.__init__(self, vllm_config=vllm_config) self.is_driver_worker = is_driver_worker self.block_size = self.cache_config.block_size @@ -148,15 +137,7 @@ def load_model(self) -> None: "vllm.model_executor.layers.vocab_parallel_embedding." "get_tensor_model_parallel_rank", return_value=xm_tp_rank): - model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - parallel_config=self.parallel_config, - cache_config=self.cache_config, - scheduler_config=self.scheduler_config, - lora_config=None, - ) + model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() self.model = ModelWrapper(model) @@ -184,6 +165,7 @@ def _dummy_run( num_prefill_tokens=batch_size * seq_len, num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, block_tables=None, context_lens=None, ) @@ -216,6 +198,7 @@ def _dummy_run( num_prefill_tokens=0, num_decode_tokens=batch_size * seq_len, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, block_tables=block_tables, context_lens=context_lens, ) @@ -360,6 +343,7 @@ def _prepare_prompt( num_prefill_tokens=0, # NOTE: This is not used. 
num_decode_tokens=0, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, block_tables=None, context_lens=None, ) @@ -429,6 +413,7 @@ def _prepare_decode( num_prefill_tokens=0, num_decode_tokens=batch_size, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, block_tables=block_tables, context_lens=context_lens, ) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index de6f7ab0072fd..096cb23416909 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -6,8 +6,7 @@ import torch_xla.runtime as xr import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -16,7 +15,8 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size from vllm.worker.tpu_model_runner import TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerInput) + LoraNotSupportedWorkerBase, WorkerBase, + WorkerInput) logger = init_logger(__name__) @@ -25,24 +25,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, is_driver_worker: bool, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config + WorkerBase.__init__(self, vllm_config=vllm_config) self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -56,13 +46,7 @@ def __init__( self.cache_config.cache_dtype] self.model_runner: TPUModelRunner = TPUModelRunner( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config, - is_driver_worker=is_driver_worker) + vllm_config=vllm_config, is_driver_worker=is_driver_worker) def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index fd30962e5d6bb..8928936b4f9fc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,10 +7,7 @@ import torch.distributed import vllm.envs as envs -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) @@ -27,7 +24,8 @@ from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner -from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, + WorkerInput) logger = init_logger(__name__) @@ -42,46 +40,31 @@ class Worker(LocalOrDistributedWorkerBase): def __init__( self, - model_config: 
ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, - observability_config: Optional[ObservabilityConfig] = None, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config + WorkerBase.__init__(self, vllm_config) self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.load_config = load_config - self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker - if parallel_config and is_driver_worker: - assert rank % parallel_config.tensor_parallel_size == 0, \ + if is_driver_worker: + assert rank % self.parallel_config.tensor_parallel_size == 0, \ "Driver worker should be rank 0 of tensor parallel group." if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.observability_config = observability_config # Return hidden states from target model if the draft model is an # mlp_speculator + speculative_config = self.speculative_config + model_config = self.model_config speculative_args = {} if speculative_config is None \ or (speculative_config.draft_model_config.model == model_config.model) \ @@ -97,17 +80,9 @@ def __init__( elif self._is_encoder_decoder_model(): ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=load_config, - lora_config=self.lora_config, + vllm_config=self.vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - prompt_adapter_config=prompt_adapter_config, - observability_config=observability_config, **speculative_args, ) # Uninitialized cache engine. Will be initialized by diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6ba4f272315ce..cf8a4946a71c4 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -7,7 +7,7 @@ import torch -from vllm.config import ObservabilityConfig +from vllm.config import ObservabilityConfig, VllmConfig from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -29,6 +29,22 @@ class WorkerBase(ABC): communicate request metadata to other workers. 
""" + def __init__( + self, + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + @abstractmethod def init_device(self) -> None: """Initialize device state, such as loading the model or other on-device diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 75a6de3b24ba4..bae8b469767b2 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -1,6 +1,7 @@ import dataclasses import time import weakref +from collections import defaultdict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar) @@ -9,9 +10,7 @@ import torch.nn as nn from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) +from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -19,7 +18,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalInputs, MultiModalRegistry) + MultiModalInputs, MultiModalPlaceholderMap, + MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad @@ -161,6 +161,9 @@ def _prepare_prompt( slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] + multi_modal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -179,7 +182,21 @@ def _prepare_prompt( # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) + positions_range = range(computed_len, seq_len) + input_positions.extend(list(positions_range)) + + if seq_group_metadata.multi_modal_data: + # NOTE: mm_data only includes the subset of multi-modal items + # that intersect with the current prefill positions. 
+ mm_data, placeholder_maps = MultiModalPlaceholderMap \ + .from_seq_group(seq_group_metadata, positions_range) + + mm_kwargs = self.runner.multi_modal_input_mapper(mm_data) + multi_modal_inputs_list.append(mm_kwargs) + + for modality, placeholder_map in placeholder_maps.items(): + multi_modal_placeholder_maps[modality].extend( + placeholder_map) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -220,6 +237,11 @@ def _prepare_prompt( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + multi_modal_placeholder_maps.items() + } max_seqlen = max(seq_lens) tmp = [0] @@ -230,6 +252,7 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, seq_lens=seq_lens, seqlen_q=seqlen_q, max_seqlen=max_seqlen, @@ -313,6 +336,7 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, seq_lens=seq_lens, seqlen_q=torch.tensor([]), max_seqlen=0, @@ -337,33 +361,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], + vllm_config: VllmConfig, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, - observability_config: Optional[ObservabilityConfig] = None, input_registry: InputRegistry = INPUT_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.load_config = load_config + + ModelRunnerBase.__init__(self, vllm_config=vllm_config) + model_config = self.model_config + cache_config = self.cache_config self.is_driver_worker = is_driver_worker - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config - if self.observability_config is not None: - print(f"observability_config is {self.observability_config}") self.return_hidden_states = return_hidden_states self.device = self.device_config.device @@ -396,15 +405,7 @@ def __init__( def load_model(self) -> None: with DeviceMemoryProfiler() as m: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - load_config=self.load_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config, - ) + self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -450,7 +451,7 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = self.input_registry \ + dummy_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) @@ -458,12 +459,12 @@ 
def profile_run(self) -> None: seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, lora_request=None, - multi_modal_data=dummy_multi_modal_data, - ) + multi_modal_data=dummy_data.multi_modal_data, + multi_modal_placeholders=dummy_data.multi_modal_placeholders) seqs.append(seq) # Run the model with the dummy inputs. diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index c1d836bb0d318..1295666055b04 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -8,10 +8,7 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -19,7 +16,7 @@ from vllm.platforms import current_platform from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase from vllm.worker.xpu_model_runner import XPUModelRunner logger = init_logger(__name__) @@ -36,53 +33,32 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, - lora_config: Optional[LoRAConfig] = None, - speculative_config: Optional[SpeculativeConfig] = None, - prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, - observability_config: Optional[ObservabilityConfig] = None, ) -> None: + WorkerBase.__init__(self, vllm_config=vllm_config) + device_config = self.device_config + parallel_config = self.parallel_config assert device_config.device_type == "xpu" assert current_platform.is_xpu() - self.model_config = model_config - self.parallel_config = parallel_config self.parallel_config.rank = rank - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.load_config = load_config + self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker - self.observability_config = observability_config if parallel_config and is_driver_worker: assert rank % parallel_config.tensor_parallel_size == 0, \ "Driver worker should be rank 0 of tensor parallel group." self.model_runner = XPUModelRunner( # type: ignore - model_config, - parallel_config, - scheduler_config, - device_config, - cache_config, - load_config=self.load_config, - lora_config=self.lora_config, + vllm_config=vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - observability_config=self.observability_config, ) # Uninitialized cache engine. Will be initialized by # initialize_cache.
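Looking back at the v1 sampler and gpu_model_runner hunks at the top of this section: per-request seeds are handled by giving each seeded request its own torch.Generator and re-drawing only that request's exponential noise. A minimal, self-contained sketch of the sampling trick (tensor sizes and the seed value are illustrative):

    from typing import Dict

    import torch


    def random_sample(probs: torch.Tensor,
                      generators: Dict[int, torch.Generator]) -> torch.Tensor:
        # Exponential-race (Gumbel-max style) sampling: with q ~ Exp(1),
        # argmax(probs / q) is distributed according to probs.
        q = torch.empty_like(probs)
        q.exponential_()
        # Re-draw the noise for seeded requests with their own generators so
        # their samples are reproducible.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
        return probs.div_(q).argmax(dim=-1)


    # Mirroring _update_states above: build a seeded generator per request.
    generator = torch.Generator()
    generator.manual_seed(1234)
    probs = torch.rand(2, 8).softmax(dim=-1)
    sampled = random_sample(probs, {0: generator})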