diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index 5be9a553dddd4..416fe344a36ea 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -1,9 +1,27 @@
steps:
- - label: "Build wheel - CUDA {{matrix.cuda_version}}"
+ - label: "Build wheel - CUDA 12.1"
agents:
queue: cpu_queue
commands:
- - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ."
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ."
+ - "mkdir artifacts"
+ - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
+ # rename the files to change linux -> manylinux1
+ - "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done"
+ - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/"
+ - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
+ env:
+ DOCKER_BUILDKIT: "1"
+
+ - block: "Build CUDA 11.8 wheel"
+ key: block-build-cu118-wheel
+
+ - label: "Build wheel - CUDA 11.8"
+ depends_on: block-build-cu118-wheel
+ agents:
+ queue: cpu_queue
+ commands:
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ."
- "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
# rename the files to change linux -> manylinux1
@@ -12,8 +30,3 @@ steps:
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
env:
DOCKER_BUILDKIT: "1"
- matrix:
- setup:
- cuda_version:
- - "11.8.0"
- - "12.1.0"
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 6f38cd313f115..6e83c887f89b6 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -148,8 +148,9 @@ steps:
- python3 cpu_offload.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- - python3 llava_example.py
+ - python3 offline_inference_vision_language.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
+ - python3 offline_inference_encoder_decoder.py
- label: Models Test # 1hr10min
source_file_dependencies:
@@ -289,6 +290,7 @@ steps:
commands:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
+ - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
- pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
diff --git a/Dockerfile.openvino b/Dockerfile.openvino
index c84dea419e58a..06ca4638dfeb9 100644
--- a/Dockerfile.openvino
+++ b/Dockerfile.openvino
@@ -21,7 +21,7 @@ COPY setup.py /workspace/vllm/
# install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
# build vLLM with OpenVINO backend
-RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
+RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
COPY examples/ /workspace/vllm/examples
COPY benchmarks/ /workspace/vllm/benchmarks
diff --git a/docs/source/conf.py b/docs/source/conf.py
index f1eb8524d4e9c..ded6742ea2e5c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -112,6 +112,8 @@ def setup(app):
"tensorizer",
"pynvml",
"outlines",
+ "gguf",
+ "lark",
]
for mock_target in autodoc_mock_imports:
diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst
index 62256df091a44..d8f27c4328a58 100644
--- a/docs/source/getting_started/openvino-installation.rst
+++ b/docs/source/getting_started/openvino-installation.rst
@@ -57,7 +57,7 @@ Install from source
.. code-block:: console
- $ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
+ $ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
.. _openvino_backend_performance_tips:
diff --git a/docs/source/models/spec_decode.rst b/docs/source/models/spec_decode.rst
index be901fa881b12..d3c196faff25d 100644
--- a/docs/source/models/spec_decode.rst
+++ b/docs/source/models/spec_decode.rst
@@ -14,17 +14,17 @@ Speculative decoding is a technique which improves inter-token latency in memory
Speculating with a draft model
------------------------------
-The following code configures vLLM to use speculative decoding with a draft model, speculating 5 tokens at a time.
+The following code configures vLLM in offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time.
.. code-block:: python
from vllm import LLM, SamplingParams
-
+
prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
-
+
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
@@ -33,12 +33,56 @@ The following code configures vLLM to use speculative decoding with a draft mode
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
-
+
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+To perform the same with an online mode, launch the server:
+
+.. code-block:: bash
+
+ python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
+ --seed 42 -tp 1 --speculative_model facebook/opt-125m --use-v2-block-manager \
+ --num_speculative_tokens 5 --gpu_memory_utilization 0.8
+
+Then use a client:
+
+.. code-block:: python
+
+ from openai import OpenAI
+
+ # Modify OpenAI's API key and API base to use vLLM's API server.
+ openai_api_key = "EMPTY"
+ openai_api_base = "http://localhost:8000/v1"
+
+ client = OpenAI(
+ # defaults to os.environ.get("OPENAI_API_KEY")
+ api_key=openai_api_key,
+ base_url=openai_api_base,
+ )
+
+ models = client.models.list()
+ model = models.data[0].id
+
+ # Completion API
+ stream = False
+ completion = client.completions.create(
+ model=model,
+ prompt="The future of AI is",
+ echo=False,
+ n=1,
+ stream=stream,
+ )
+
+ print("Completion results:")
+ if stream:
+ for c in completion:
+ print(c)
+ else:
+ print(completion)
+
Speculating by matching n-grams in the prompt
---------------------------------------------
@@ -48,12 +92,12 @@ matching n-grams in the prompt. For more information read `this thread. `_ or
+For more information see `this blog `_ or
`this technical report `_.
.. code-block:: python
@@ -100,9 +144,9 @@ For more information see `this blog
diff --git a/requirements-openvino.txt b/requirements-openvino.txt
--- a/requirements-openvino.txt
+++ b/requirements-openvino.txt
-cmake >= 3.21
-ninja # For faster builds.
-psutil
-sentencepiece # Required for LLaMA tokenizer.
-numpy < 2.0.0
-requests
-tqdm
-py-cpuinfo
-transformers < 4.43
-tokenizers >= 0.19.1 # Required for Llama 3.
-fastapi
-aiohttp
-openai
-uvicorn[standard]
-pydantic >= 2.0 # Required for OpenAI server.
-pillow # Required for image processing
-prometheus_client >= 0.18.0
-prometheus-fastapi-instrumentator >= 7.0.0
-tiktoken >= 0.6.0 # Required for DBRX tokenizer
-lm-format-enforcer == 0.10.3
-outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
-typing_extensions
-filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
-pyzmq
-gguf == 0.9.1
+-r requirements-common.txt
# OpenVINO dependencies
torch >= 2.1.2
-openvino ~= 2024.3.0.dev
-openvino-tokenizers[transformers] ~= 2024.3.0.0.dev
-optimum-intel[openvino] >= 1.18.1
+openvino ~= 2024.3.0
+optimum-intel[openvino] >= 1.18.2
diff --git a/setup.py b/setup.py
index b146299f8269d..f6e005879aeff 100644
--- a/setup.py
+++ b/setup.py
@@ -272,7 +272,7 @@ def _build_custom_ops() -> bool:
def _build_core_ext() -> bool:
- return not _is_neuron() and not _is_tpu()
+ return not _is_neuron() and not _is_tpu() and not _is_openvino()
def get_hipcc_rocm_version():
diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py
index 495a123c351d7..a3c9d5c6e0898 100644
--- a/tests/async_engine/api_server_async_engine.py
+++ b/tests/async_engine/api_server_async_engine.py
@@ -1,5 +1,5 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
-from typing import Any, Dict
+from typing import Any, Dict, Iterable
import uvicorn
from fastapi.responses import JSONResponse, Response
@@ -18,9 +18,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_aborts = 0
- async def abort(self, request_id: str) -> None:
- await super().abort(request_id)
- self._num_aborts += 1
+ async def _engine_abort(self, request_ids: Iterable[str]):
+ ids = list(request_ids)
+ self._num_aborts += len(ids)
+ await super()._engine_abort(ids)
def testing_stats(self) -> Dict[str, Any]:
return {"num_aborted_requests": self._num_aborts}
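The wrapper above now counts aborts by overriding the internal `_engine_abort` hook, which receives an iterable of request ids, rather than the public `abort` method. Below is a rough sketch of how such a counter is typically read back by a test, exposing it over an extra HTTP route; the `/stats` route and the stand-in engine class are assumptions for illustration, since only `testing_stats()` appears in this hunk.

```python
# Illustrative sketch, not part of this patch: surfacing an abort counter
# to tests via an extra route. The stand-in class mimics the wrapper above.
from typing import Any, Dict

from fastapi import FastAPI
from fastapi.responses import JSONResponse


class _EngineWithStats:
    """Stand-in for AsyncLLMEngineWithStats: tracks aborted request ids."""

    def __init__(self) -> None:
        self._num_aborts = 0

    def testing_stats(self) -> Dict[str, Any]:
        return {"num_aborted_requests": self._num_aborts}


app = FastAPI()
engine = _EngineWithStats()


@app.get("/stats")
def stats() -> JSONResponse:
    # A test can poll this endpoint and assert that client disconnects were
    # propagated to the engine as aborts.
    return JSONResponse(engine.testing_stats())
```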
diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py
index aea8a7fed6e33..4df6c02973284 100644
--- a/tests/async_engine/test_chat_template.py
+++ b/tests/async_engine/test_chat_template.py
@@ -1,22 +1,16 @@
-import os
-import pathlib
-
import pytest
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer
-chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
- __file__))).parent.parent / "examples/template_chatml.jinja"
+from ..utils import VLLM_PATH
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT = [
- ("facebook/opt-125m", None, True,
- "HelloHi there!What is the capital of"),
- ("facebook/opt-125m", None, False,
- "HelloHi there!What is the capital of"),
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=add_generation_prompt)
# Call the function and get the result
- result = tokenizer.apply_chat_template(
+ result = apply_chat_template(
+ tokenizer,
conversation=mock_request.messages,
- tokenize=False,
+ chat_template=mock_request.chat_template or template_content,
add_generation_prompt=mock_request.add_generation_prompt,
- chat_template=mock_request.chat_template or template_content)
+ )
# Test assertion
assert result == expected_output, (
diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py
index 5ecd770ede836..0d53b39e7ce1c 100644
--- a/tests/async_engine/test_openapi_server_ray.py
+++ b/tests/async_engine/test_openapi_server_ray.py
@@ -1,10 +1,12 @@
import openai # use the official client for correctness check
import pytest
-from ..utils import RemoteOpenAIServer
+from ..utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
@pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
"--max-model-len",
"2048",
"--enforce-eager",
- "--engine-use-ray"
+ "--engine-use-ray",
+ "--chat-template",
+ str(chatml_jinja_path),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
- completion_tokens=10, prompt_tokens=13, total_tokens=23)
+ completion_tokens=10, prompt_tokens=55, total_tokens=65)
message = choice.message
assert message.content is not None and len(message.content) >= 10
diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py
index 7b1f4a9e1eb2f..c66bdd5f9003d 100644
--- a/tests/async_engine/test_request_tracker.py
+++ b/tests/async_engine/test_request_tracker.py
@@ -10,23 +10,23 @@ async def test_request_tracker():
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
- new, finished = tracker.get_new_and_finished_requests()
+ new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(new) == 1
assert new[0]["request_id"] == "1"
- assert not finished
+ assert not aborted
assert not stream_1.finished
stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
- new, finished = tracker.get_new_and_finished_requests()
+ new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
- assert not finished
+ assert not aborted
assert not stream_2.finished
assert not stream_3.finished
@@ -36,9 +36,9 @@ async def test_request_tracker():
assert not tracker.new_requests_event.is_set()
tracker.abort_request("1")
- new, finished = tracker.get_new_and_finished_requests()
- assert len(finished) == 1
- assert "1" in finished
+ new, aborted = tracker.get_new_and_aborted_requests()
+ assert len(aborted) == 1
+ assert "1" in aborted
assert not new
assert stream_1.finished
@@ -46,9 +46,9 @@ async def test_request_tracker():
tracker.abort_request("4")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
- new, finished = tracker.get_new_and_finished_requests()
- assert len(finished) == 1
- assert "4" in finished
+ new, aborted = tracker.get_new_and_aborted_requests()
+ assert len(aborted) == 1
+ assert "4" in aborted
assert not new
assert stream_4.finished
@@ -57,10 +57,9 @@ async def test_request_tracker():
tracker.process_request_output(
RequestOutput("2", "output", [], [], [], finished=True))
await tracker.wait_for_new_requests()
- new, finished = tracker.get_new_and_finished_requests()
+ new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
- assert len(finished) == 1
- assert "2" in finished
+ assert not aborted
assert len(new) == 1
assert new[0]["request_id"] == "5"
assert stream_2.finished
diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py
index 180b926637ecb..f0f09ee63c0e6 100644
--- a/tests/basic_correctness/test_cpu_offload.py
+++ b/tests/basic_correctness/test_cpu_offload.py
@@ -22,11 +22,28 @@ def test_cpu_offload_fp8():
["--cpu-offload-gb", "2"])
-@pytest.mark.skipif(not is_quant_method_supported("awq"),
- reason="awq is not supported on this GPU type.")
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
+ reason="gptq_marlin is not supported on this GPU type.")
+def test_cpu_offload_gptq():
+ # Test GPTQ Marlin
+ compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", [],
+ ["--cpu-offload-gb", "1"])
+ # Test GPTQ
+ compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
+ ["--quantization", "gptq"],
+ ["--quantization", "gptq", "--cpu-offload-gb", "1"])
+
+
+@pytest.mark.skipif(not is_quant_method_supported("awq_marlin"),
+ reason="awq_marlin is not supported on this GPU type.")
def test_cpu_offload_awq():
- compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [],
- ["--cpu-offload-gb", "2"])
+ # Test AWQ Marlin
+ compare_two_settings("Qwen/Qwen2-1.5B-Instruct-AWQ", [],
+ ["--cpu-offload-gb", "1"])
+ # Test AWQ
+ compare_two_settings("Qwen/Qwen2-1.5B-Instruct-AWQ",
+ ["--quantization", "awq"],
+ ["--quantization", "awq", "--cpu-offload-gb", "1"])
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
diff --git a/tests/conftest.py b/tests/conftest.py
index c7a349f1e9e2a..c0bf9897c97f2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,9 +10,11 @@
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
- AutoTokenizer, BatchEncoding, BatchFeature)
+from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
+ AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
+ BatchFeature)
+from tests.models.utils import DecoderPromptType
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
@@ -21,9 +23,11 @@
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
- is_cpu)
+ is_cpu, to_enc_dec_tuple_list,
+ zip_enc_dec_prompt_lists)
logger = init_logger(__name__)
@@ -120,6 +124,40 @@ def example_prompts() -> List[str]:
return prompts
+@pytest.fixture
+def example_encoder_decoder_prompts() \
+ -> Dict[DecoderPromptType,
+ Tuple[List[str], List[Optional[str]]]]:
+ '''
+    Returns, for each decoder prompt type (None, empty string, or a custom
+    string), a list of zipped (encoder prompt, decoder prompt) pairs, wherein
+    each pair of same-index entries corresponds to one (encoder prompt,
+    decoder prompt) tuple.
+
+    Returns:
+
+    * Dictionary mapping DecoderPromptType to a list of
+      (encoder prompt, decoder prompt) pairs; the custom decoder prompts
+      are the encoder prompts in reverse order.
+ '''
+
+ encoder_prompts = []
+ for filename in _TEST_PROMPTS:
+ encoder_prompts += _read_prompts(filename)
+
+ custom_decoder_prompts = encoder_prompts[::-1]
+ empty_str_decoder_prompts = [""] * len(encoder_prompts)
+ none_decoder_prompts = [None] * len(encoder_prompts)
+
+    # Map each decoder prompt type to its zipped (encoder, decoder) prompt list
+ return {
+ DecoderPromptType.NONE:
+ zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
+ DecoderPromptType.EMPTY_STR:
+ zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
+ DecoderPromptType.CUSTOM:
+ zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
+ }
+
+
@pytest.fixture
def example_long_prompts() -> List[str]:
prompts = []
@@ -152,6 +190,7 @@ def __init__(
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_vision_model: bool = False,
+ is_encoder_decoder_model: bool = False,
) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -168,6 +207,8 @@ def __init__(
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
+ elif is_encoder_decoder_model:
+ auto_cls = AutoModelForSeq2SeqLM
else:
auto_cls = AutoModelForCausalLM
@@ -314,6 +355,44 @@ def generate_greedy_logprobs(
all_logprobs.append(seq_logprobs)
return all_logprobs
+ def _hidden_states_to_logprobs(
+ self,
+ hidden_states,
+ num_logprobs,
+ ) -> Tuple[List[Dict[int, float]], int]:
+ seq_logprobs: List[torch.Tensor] = []
+ output_len = len(hidden_states)
+ for _, hidden_state in enumerate(hidden_states):
+ last_hidden_states = hidden_state[-1][0]
+ logits = torch.matmul(
+ last_hidden_states,
+ self.model.get_output_embeddings().weight.t(),
+ )
+ if getattr(self.model.get_output_embeddings(), "bias",
+ None) is not None:
+ logits += self.model.get_output_embeddings().bias.unsqueeze(0)
+ logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+ seq_logprobs.append(logprobs)
+
+ # convert to dict
+ seq_logprobs_lst: List[Dict[int, float]] = []
+ for tok_idx, tok_logprobs in enumerate(seq_logprobs):
+ # drop prompt logprobs
+ if tok_idx == 0:
+ tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
+ topk = tok_logprobs.topk(num_logprobs)
+
+ tok_logprobs_dct = {}
+ for token_id, logprob in zip(topk.indices[0], topk.values[0]):
+ tok_logprobs_dct[token_id.item()] = logprob.item()
+
+ seq_logprobs_lst.append(tok_logprobs_dct)
+
+ return (
+ seq_logprobs_lst,
+ output_len,
+ )
+
def generate_greedy_logprobs_limit(
self,
prompts: List[str],
@@ -346,37 +425,66 @@ def generate_greedy_logprobs_limit(
**kwargs,
)
- seq_logprobs: List[torch.Tensor] = []
- for _, hidden_states in enumerate(output.hidden_states):
- last_hidden_states = hidden_states[-1][0]
- logits = torch.matmul(
- last_hidden_states,
- self.model.get_output_embeddings().weight.t(),
- )
- if getattr(self.model.get_output_embeddings(), "bias",
- None) is not None:
- logits += self.model.get_output_embeddings(
- ).bias.unsqueeze(0)
- logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
- seq_logprobs.append(logprobs)
+ (
+ seq_logprobs_lst,
+ output_len,
+ ) = self._hidden_states_to_logprobs(output.hidden_states,
+ num_logprobs)
+
+ all_logprobs.append(seq_logprobs_lst)
+ seq_ids = output.sequences[0]
+ output_len = len(seq_logprobs_lst)
+ output_ids = seq_ids[-output_len:]
+ all_output_ids.append(output_ids.tolist())
+ all_output_strs.append(self.tokenizer.decode(output_ids))
- # convert to dict
- seq_logprobs_lst: List[Dict[int, float]] = []
- for tok_idx, tok_logprobs in enumerate(seq_logprobs):
- # drop prompt logprobs
- if tok_idx == 0:
- tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
- topk = tok_logprobs.topk(num_logprobs)
+ outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+ return [(output_ids, output_str, output_logprobs)
+ for output_ids, output_str, output_logprobs in outputs]
+
+ def generate_encoder_decoder_greedy_logprobs_limit(
+ self,
+ encoder_decoder_prompts: Tuple[List[str], List[str]],
+ max_tokens: int,
+ num_logprobs: int,
+ **kwargs: Any,
+ ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+ '''
+ Greedy logprobs generation for vLLM encoder/decoder models
+ '''
- tok_logprobs_dct = {}
- for token_id, logprob in zip(topk.indices[0], topk.values[0]):
- tok_logprobs_dct[token_id.item()] = logprob.item()
+ all_logprobs: List[List[Dict[int, float]]] = []
+ all_output_ids: List[List[int]] = []
+ all_output_strs: List[str] = []
- seq_logprobs_lst.append(tok_logprobs_dct)
+ for (encoder_prompt,
+ decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
+ encoder_input_ids = self.wrap_device(
+ self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
+ decoder_input_ids = (
+ None if decoder_prompt is None else self.wrap_device(
+ self.tokenizer(decoder_prompt,
+ return_tensors="pt").input_ids))
+
+ output = self.model.generate(
+ encoder_input_ids,
+ decoder_input_ids=decoder_input_ids,
+ use_cache=True,
+ do_sample=False,
+ max_new_tokens=max_tokens,
+ output_hidden_states=True,
+ return_dict_in_generate=True,
+ **kwargs,
+ )
+
+ (
+ seq_logprobs_lst,
+ output_len,
+ ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
+ num_logprobs)
all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
- output_len = len(seq_logprobs_lst)
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
@@ -416,7 +524,7 @@ def __init__(
block_size: int = 16,
enable_chunked_prefill: bool = False,
swap_space: int = 4,
- enforce_eager: bool = False,
+ enforce_eager: Optional[bool] = False,
**kwargs,
) -> None:
self.model = LLM(
@@ -465,6 +573,19 @@ def generate(
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
+ def _final_steps_generate_w_logprobs(
+ self,
+ req_outputs: List[RequestOutput],
+ ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+ outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
+ for req_output in req_outputs:
+ for sample in req_output.outputs:
+ output_str = sample.text
+ output_ids = sample.token_ids
+ output_logprobs = sample.logprobs
+ outputs.append((output_ids, output_str, output_logprobs))
+ return outputs
+
def generate_w_logprobs(
self,
prompts: List[str],
@@ -483,14 +604,21 @@ def generate_w_logprobs(
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
- outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
- for req_output in req_outputs:
- for sample in req_output.outputs:
- output_str = sample.text
- output_ids = sample.token_ids
- output_logprobs = sample.logprobs
- outputs.append((output_ids, output_str, output_logprobs))
- return outputs
+ return self._final_steps_generate_w_logprobs(req_outputs)
+
+ def generate_encoder_decoder_w_logprobs(
+ self,
+ encoder_decoder_prompts: Tuple[List[str], List[str]],
+ sampling_params: SamplingParams,
+ ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+ '''
+ Logprobs generation for vLLM encoder/decoder models
+ '''
+
+ assert sampling_params.logprobs is not None
+ req_outputs = self.model.generate(encoder_decoder_prompts,
+ sampling_params=sampling_params)
+ return self._final_steps_generate_w_logprobs(req_outputs)
def generate_greedy(
self,
@@ -523,6 +651,26 @@ def generate_greedy_logprobs(
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
+ def generate_encoder_decoder_greedy_logprobs(
+ self,
+ encoder_decoder_prompts: Tuple[List[str], List[str]],
+ max_tokens: int,
+ num_logprobs: int,
+ ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+        '''
+        Greedy logprobs generation for vLLM encoder/decoder models
+        '''
+        greedy_logprobs_params = SamplingParams(temperature=0.0,
+                                                use_beam_search=False,
+                                                max_tokens=max_tokens,
+                                                logprobs=num_logprobs)
+
+ outputs = self.generate_encoder_decoder_w_logprobs(
+ encoder_decoder_prompts, greedy_logprobs_params)
+
+ return [(output_ids, output_str, output_logprobs)
+ for output_ids, output_str, output_logprobs in outputs]
+
def generate_beam_search(
self,
prompts: List[str],
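The new fixtures and runner methods above lean on `zip_enc_dec_prompt_lists` and `to_enc_dec_tuple_list` from `vllm.utils`, whose definitions are outside this patch. The sketch below shows the data shape those helpers are assumed to produce and consume; the `ExplicitEncoderDecoderPrompt` name and dict keys are assumptions chosen to mirror the `encoder_prompt`/`decoder_prompt` fields used elsewhere in the diff.

```python
# Illustrative sketch, not part of this patch: the assumed shape of the
# prompt-pairing helpers used by the fixture and HF runner loop above.
from typing import List, Optional, Tuple, TypedDict


class ExplicitEncoderDecoderPrompt(TypedDict):
    encoder_prompt: str
    decoder_prompt: Optional[str]


def zip_enc_dec_prompt_lists(
    enc_prompts: List[str],
    dec_prompts: List[Optional[str]],
) -> List[ExplicitEncoderDecoderPrompt]:
    # Pair up same-index encoder/decoder prompts.
    return [
        ExplicitEncoderDecoderPrompt(encoder_prompt=e, decoder_prompt=d)
        for e, d in zip(enc_prompts, dec_prompts)
    ]


def to_enc_dec_tuple_list(
    prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[str, Optional[str]]]:
    # Inverse view: (encoder prompt, decoder prompt) tuples, as iterated by
    # generate_encoder_decoder_greedy_logprobs_limit above.
    return [(p["encoder_prompt"], p["decoder_prompt"]) for p in prompts]
```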
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index 447e8f8a586f6..11168d2423b0e 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -9,33 +9,11 @@
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
-from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
+from vllm.sequence import SequenceGroup, SequenceStatus
-from .utils import create_dummy_prompt
-
-
-def get_sequence_groups(scheduler_output):
- return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
-
-
-def append_new_token(out, token_id: int):
- seq_groups = get_sequence_groups(out)
- for seq_group in seq_groups:
- for seq in seq_group.get_seqs():
- seq.append_token_id(token_id, {token_id: Logprob(token_id)})
-
-
-def schedule_and_update_computed_tokens(scheduler):
- metas, out = scheduler.schedule()
- for s, meta in zip(out.scheduled_seq_groups, metas):
- s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
- return metas, out
-
-
-def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
- seq_group.update_num_computed_tokens(token_chunk_size)
- for seq in seq_group.get_seqs():
- seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+from .utils import (append_new_token, append_new_token_seq_group,
+ create_dummy_prompt, get_sequence_groups,
+ schedule_and_update_computed_tokens)
def test_scheduler_add_seq_group():
diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py
new file mode 100644
index 0000000000000..50c047f30b80d
--- /dev/null
+++ b/tests/core/test_scheduler_encoder_decoder.py
@@ -0,0 +1,99 @@
+from typing import List
+
+import pytest # noqa
+
+from vllm.config import CacheConfig, SchedulerConfig
+from vllm.core.scheduler import Scheduler
+from vllm.sequence import SequenceGroup
+
+from .utils import (append_new_token, create_dummy_prompt_encoder_decoder,
+ get_sequence_groups, schedule_and_update_computed_tokens)
+
+
+def test_scheduler_schedule_simple_encoder_decoder():
+ '''
+ Test basic scheduler functionality in the context
+ of an encoder/decoder model. Focus on testing
+    enc/dec-specific functionality, since tests already
+ exist for decoder-only functionality
+
+ Test behavior:
+ * Construct Scheduler
+ * Construct dummy encoder/decoder sequence groups
+ * Add dummy seq groups to scheduler backlog
+ * Schedule the next seq group & validate:
+ * Cross-attn block tables
+ * Updated states of seq groups
+ * Number of batched tokens
+ * Number of blocks to copy/swap-in/swap-out
+ * Number of scheduled seq groups
+ * Repeat for both prefill- and decode-phase
+ * Abort scheduled seq groups
+ * Assert that aborted seq groups no longer appear in
+ cross-attention block table
+ '''
+
+ block_size = 4
+ num_seq_group = 4
+ max_model_len = 16
+ scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
+ cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+ cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group
+ cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group
+ scheduler = Scheduler(scheduler_config, cache_config, None)
+ running: List[SequenceGroup] = []
+
+ # Add seq groups to scheduler.
+ req_id_list = []
+ for i in range(num_seq_group):
+ req_id = str(i)
+ req_id_list.append(req_id)
+ _, _, seq_group = create_dummy_prompt_encoder_decoder(
+ req_id, block_size, block_size, block_size)
+ scheduler.add_seq_group(seq_group)
+ running.append(seq_group)
+
+ # Schedule seq groups prefill.
+ num_tokens = block_size * num_seq_group
+ seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
+ # - Verify that sequence group cross-attention block tables are
+ # registered with the block manager
+ assert all([(req_id in scheduler.block_manager.cross_block_tables)
+ for req_id in req_id_list])
+ # - Validate sequence-group status
+ assert set(get_sequence_groups(out)) == set(running)
+ # - Validate number of batched tokens
+ assert out.num_batched_tokens == num_tokens
+ # - Validate there are no remaining blocks to swap
+ assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+ and not out.blocks_to_swap_out)
+ # - Validate all seq groups were scheduled
+ assert len(seq_group_meta_list) == num_seq_group
+ append_new_token(out, 1)
+
+ # Schedule seq groups decode.
+ seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
+ # - Verify that sequence group metadata includes encoder attention
+ # and cross-attention metadata
+ assert all([
+ not ((seq_group_meta.encoder_seq_data is None) or
+ (seq_group_meta.cross_block_table is None))
+ for seq_group_meta in seq_group_meta_list
+ ])
+ # - Validate sequence-group status
+ assert set(get_sequence_groups(out)) == set(running)
+ # - Validate there is one batched token per seq group
+ assert out.num_batched_tokens == num_seq_group
+ # - Validate there are no remaining blocks to swap
+ assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+ and not out.blocks_to_swap_out)
+ # - Validate that all seq groups were scheduled
+ assert len(seq_group_meta_list) == num_seq_group
+ append_new_token(out, 1)
+
+ # Abort sequences
+ for req_id in req_id_list:
+ scheduler.abort_seq_group(req_id)
+ # - Verify that sequence group cross-attention block tables are
+ # NO LONGER registered with the block manager
+ assert req_id not in scheduler.block_manager.cross_block_tables
diff --git a/tests/core/utils.py b/tests/core/utils.py
index f249f4b59a2ee..45a8e74e85324 100644
--- a/tests/core/utils.py
+++ b/tests/core/utils.py
@@ -53,27 +53,30 @@ def create_dummy_prompt_encoder_decoder(
block_size = decoder_prompt_length
# Create dummy prompt sequence with tokens 0...block_size-1
- # and prompt "0 ... block_size".
+ # and prompt "0 ... block_size". Note that the prompt string
+ # doesn't actually match the tokens
decoder_prompt_tokens = list(range(decoder_prompt_length))
decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens])
+ encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
+ encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
+
+ inputs = {
+ "prompt": decoder_prompt_str,
+ "prompt_token_ids": decoder_prompt_tokens,
+ "encoder_prompt": encoder_prompt_str,
+ "encoder_prompt_token_ids": encoder_prompt_tokens,
+ "multi_modal_data": None,
+ }
decoder_prompt = Sequence(int(request_id),
- inputs={
- "prompt": decoder_prompt_str,
- "prompt_token_ids": decoder_prompt_tokens,
- "multi_modal_data": None,
- },
- block_size=block_size)
+ inputs=inputs,
+ block_size=block_size,
+ from_decoder_prompt=True)
- encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
- encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
encoder_prompt = Sequence(int(request_id),
- inputs={
- "prompt": encoder_prompt_str,
- "prompt_token_ids": encoder_prompt_tokens,
- "multi_modal_data": None,
- },
- block_size=block_size)
+ inputs=inputs,
+ block_size=block_size,
+ from_decoder_prompt=False)
seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(
@@ -139,17 +142,21 @@ def create_seq_group_encoder_decoder(
prompt_token_ids = [0] * seq_prompt_len
+ inputs = {
+ "prompt": "",
+ "prompt_token_ids": prompt_token_ids,
+ "encoder_prompt": "",
+ "encoder_prompt_token_ids": prompt_token_ids,
+ "multi_modal_data": None,
+ }
+
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
- seq = Sequence(
- seq_id=seq_id_start + seq_id_offset,
- inputs={
- "prompt": "",
- "prompt_token_ids": prompt_token_ids,
- "multi_modal_data": None,
- },
- block_size=16,
- )
+ # Construct decoder input sequences
+ seq = Sequence(seq_id=seq_id_start + seq_id_offset,
+ inputs=inputs,
+ block_size=16,
+ from_decoder_prompt=True)
for i in range(output_len):
seq.append_token_id(
@@ -158,16 +165,11 @@ def create_seq_group_encoder_decoder(
)
seqs.append(seq)
- # Encoder sequence
- encoder_seq = Sequence(
- seq_id=seq_id_start + len(seq_output_lens),
- inputs={
- "prompt": "",
- "prompt_token_ids": prompt_token_ids,
- "multi_modal_data": None,
- },
- block_size=16,
- )
+ # Encoder input sequence
+ encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
+ inputs=inputs,
+ block_size=16,
+ from_decoder_prompt=False)
return SequenceGroup(request_id=request_id,
seqs=seqs,
@@ -177,4 +179,31 @@ def create_seq_group_encoder_decoder(
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
- return (seq_len + block_size - 1) // block_size
\ No newline at end of file
+ return (seq_len + block_size - 1) // block_size
+
+
+# Helper functions for scheduler tests
+
+
+def get_sequence_groups(scheduler_output):
+ return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
+
+
+def append_new_token(out, token_id: int):
+ seq_groups = get_sequence_groups(out)
+ for seq_group in seq_groups:
+ for seq in seq_group.get_seqs():
+ seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
+def schedule_and_update_computed_tokens(scheduler):
+ metas, out = scheduler.schedule()
+ for s, meta in zip(out.scheduled_seq_groups, metas):
+ s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
+ return metas, out
+
+
+def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
+ seq_group.update_num_computed_tokens(token_chunk_size)
+ for seq in seq_group.get_seqs():
+ seq.append_token_id(token_id, {token_id: Logprob(token_id)})
diff --git a/tests/distributed/test_basic_distributed_correctness_enc_dec.py b/tests/distributed/test_basic_distributed_correctness_enc_dec.py
new file mode 100644
index 0000000000000..69eae62ca7320
--- /dev/null
+++ b/tests/distributed/test_basic_distributed_correctness_enc_dec.py
@@ -0,0 +1,101 @@
+"""For encoder/decoder models only:
+Compare the outputs of HF and distributed vLLM when using greedy sampling.
+
+Run:
+```sh
+cd $VLLM_PATH/tests
+
+pytest distributed/test_basic_distributed_correctness_enc_dec.py
+```
+"""
+
+import pytest
+
+from tests.models.utils import DecoderPromptType
+from vllm.utils import cuda_device_count_stateless
+
+from ..models.utils import check_logprobs_close
+from ..utils import fork_new_process_for_each_test
+
+
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
+ reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("model, distributed_executor_backend", [
+ ("facebook/bart-large-cnn", "ray"),
+ ("facebook/bart-large-cnn", "mp"),
+])
+@fork_new_process_for_each_test
+def test_models(
+ model: str,
+ distributed_executor_backend: str,
+ hf_runner,
+ vllm_runner,
+ example_encoder_decoder_prompts,
+) -> None:
+ '''
+ Test vLLM BART inference on more than one GPU, comparing
+ outputs against HF as a baseline.
+
+ Fork a new process for each test, to prevent CUDA from
+ being re-initialized by successive tests within the same
+ process.
+
+ Arguments:
+
+ * model: the HF ID of the specific BART variant under test
+ * distributed_executor_backend
+ * hf_runner: HuggingFace (HF) test model runner
+ * vllm_runner: vLLM test model runner
+ * example_encoder_decoder_prompts: test fixture which provides a
+ dictionary of dummy prompts
+ '''
+
+ dtype = "float"
+ max_tokens = 64
+ num_logprobs = 5
+
+ # Example inputs with non-trivial (i.e. not None/empty) encoder &
+ # decoder prompts.
+ test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM]
+
+ # NOTE: take care of the order. run vLLM first, and then run HF.
+ # vLLM needs a fresh new process without cuda initialization.
+ # if we run HF first, the cuda initialization will be done and it
+ # will hurt multiprocessing backend with fork method (the default method).
+ with vllm_runner(
+ model,
+ dtype=dtype,
+ tensor_parallel_size=2,
+ distributed_executor_backend=distributed_executor_backend,
+ enforce_eager=True,
+ ) as vllm_model:
+ vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+ test_prompts, max_tokens, num_logprobs)
+
+ # Configuration settings for HF baseline
+ hf_kwargs = {
+ "top_k": None,
+ "num_beams": 1,
+ "repetition_penalty": 1.0,
+ "top_p": 1.0,
+ "length_penalty": 1.0,
+ "early_stopping": False,
+ "no_repeat_ngram_size": None,
+ "min_length": 0
+ }
+
+ with hf_runner(model, dtype=dtype,
+ is_encoder_decoder_model=True) as hf_model:
+ hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+ test_prompts,
+ max_tokens,
+ num_logprobs,
+ **hf_kwargs,
+ ))
+
+ check_logprobs_close(
+ outputs_0_lst=hf_outputs,
+ outputs_1_lst=vllm_outputs,
+ name_0="hf",
+ name_1="vllm",
+ )
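The test's pass criterion comes from `check_logprobs_close`, imported from `tests/models/utils.py` but not shown in this patch. The sketch below approximates the comparison it is assumed to perform: generated tokens may diverge only if each run's token still appears in the other run's reported top-`num_logprobs` candidates at that position, after which later positions are not compared.

```python
# Illustrative sketch, not part of this patch: an approximation of the
# check performed by check_logprobs_close; the real helper lives in
# tests/models/utils.py and its exact signature is not shown here.
from typing import Dict, List, Tuple


def check_logprobs_close_sketch(
    outputs_0_lst: List[Tuple[List[int], str, List[Dict[int, float]]]],
    outputs_1_lst: List[Tuple[List[int], str, List[Dict[int, float]]]],
    name_0: str,
    name_1: str,
) -> None:
    for prompt_idx, (out_0, out_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        ids_0, _, logprobs_0 = out_0
        ids_1, _, logprobs_1 = out_1
        for pos, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue
            # Tokens may differ only if each model's choice is still within
            # the other's reported top-k logprobs at this position.
            assert tok_0 in logprobs_1[pos] and tok_1 in logprobs_0[pos], (
                f"{name_0} and {name_1} disagree at prompt {prompt_idx}, "
                f"position {pos}: {tok_0} vs {tok_1}")
            # Once the sequences diverge, later positions are not comparable.
            break
```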
diff --git a/tests/entrypoints/openai/test_mp_crash.py b/tests/entrypoints/openai/test_mp_crash.py
new file mode 100644
index 0000000000000..7dc595a7be351
--- /dev/null
+++ b/tests/entrypoints/openai/test_mp_crash.py
@@ -0,0 +1,35 @@
+from typing import Any
+
+import pytest
+
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.api_server import build_async_engine_client
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.utils import FlexibleArgumentParser
+
+
+def crashing_from_engine_args(
+ cls,
+ engine_args: Any = None,
+ start_engine_loop: Any = None,
+ usage_context: Any = None,
+ stat_loggers: Any = None,
+) -> "AsyncLLMEngine":
+ raise Exception("foo")
+
+
+@pytest.mark.asyncio
+async def test_mp_crash_detection(monkeypatch):
+
+ with pytest.raises(RuntimeError) as excinfo, monkeypatch.context() as m:
+ m.setattr(AsyncLLMEngine, "from_engine_args",
+ crashing_from_engine_args)
+ parser = FlexibleArgumentParser(
+ description="vLLM's remote OpenAI server.")
+ parser = make_arg_parser(parser)
+ args = parser.parse_args([])
+
+ async with build_async_engine_client(args):
+ pass
+ assert "The server process died before responding to the readiness probe"\
+ in str(excinfo.value)
diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py
index 5272ac4065f1d..9f9a4cd972c51 100644
--- a/tests/entrypoints/openai/test_oot_registration.py
+++ b/tests/entrypoints/openai/test_oot_registration.py
@@ -9,6 +9,11 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port
+from ...utils import VLLM_PATH, RemoteOpenAIServer
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
+
class MyOPTForCausalLM(OPTForCausalLM):
@@ -21,12 +26,25 @@ def compute_logits(self, hidden_states: torch.Tensor,
return logits
-def server_function(port):
+def server_function(port: int):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
- sys.argv = ["placeholder.py"] + \
- ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
- f"--dtype float32 --api-key token-abc123 --port {port}").split()
+
+ sys.argv = ["placeholder.py"] + [
+ "--model",
+ "facebook/opt-125m",
+ "--gpu-memory-utilization",
+ "0.10",
+ "--dtype",
+ "float32",
+ "--api-key",
+ "token-abc123",
+ "--port",
+ str(port),
+ "--chat-template",
+ str(chatml_jinja_path),
+ ]
+
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, ))
server.start()
- MAX_SERVER_START_WAIT_S = 60
- client = OpenAI(
- base_url=f"http://localhost:{port}/v1",
- api_key="token-abc123",
- )
- now = time.time()
- while True:
- try:
- completion = client.chat.completions.create(
- model="facebook/opt-125m",
- messages=[{
- "role": "system",
- "content": "You are a helpful assistant."
- }, {
- "role": "user",
- "content": "Hello!"
- }],
- temperature=0,
- )
- break
- except OpenAIError as e:
- if "Connection error" in str(e):
- time.sleep(3)
- if time.time() - now > MAX_SERVER_START_WAIT_S:
- raise RuntimeError("Server did not start in time") from e
- else:
- raise e
- server.kill()
+
+ try:
+ client = OpenAI(
+ base_url=f"http://localhost:{port}/v1",
+ api_key="token-abc123",
+ )
+ now = time.time()
+ while True:
+ try:
+ completion = client.chat.completions.create(
+ model="facebook/opt-125m",
+ messages=[{
+ "role": "system",
+ "content": "You are a helpful assistant."
+ }, {
+ "role": "user",
+ "content": "Hello!"
+ }],
+ temperature=0,
+ )
+ break
+ except OpenAIError as e:
+ if "Connection error" in str(e):
+ time.sleep(3)
+ if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
+ msg = "Server did not start in time"
+ raise RuntimeError(msg) from e
+ else:
+ raise e
+ finally:
+ server.terminate()
+
generated_text = completion.choices[0].message.content
+ assert generated_text is not None
# make sure only the first token is generated
rest = generated_text.replace("<s>", "")
assert rest == ""
diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index d9404e6442616..a20a741c27f74 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -3,9 +3,9 @@
import pytest
import torch
-from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL,
- override_backend_env_variable)
+from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
+from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@pytest.mark.parametrize(
diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py
index f25e7d480b6b3..b550a7fdd84f0 100644
--- a/tests/kernels/test_encoder_decoder_attn.py
+++ b/tests/kernels/test_encoder_decoder_attn.py
@@ -4,8 +4,6 @@
* E2E test of Encoder attention + Decoder self-attention +
Encoder/decoder cross-attention (collectively
"encoder/decoder attention")
-* Confirm enc/dec models will fail for chunked prefill
-* Confirm enc/dec models will fail for prefix caching
"""
@@ -15,19 +13,22 @@
import torch
from tests.kernels.utils import *
-from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor
-from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.backends.abstract import AttentionBackend, AttentionType
+from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
+ AttentionType)
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
+from vllm.attention.selector import (_Backend,
+ global_force_attn_backend_context_manager)
from vllm.utils import is_hip
+# List of supported backends for encoder/decoder models
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
+
HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16]
BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]
-BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL]
CUDA_DEVICE = "cuda:0"
MAX_DEC_SEQ_LENS = [128]
@@ -724,57 +725,92 @@ def _run_encoder_decoder_cross_attention_test(
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
-def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
- batch_size: int, block_size: int, max_dec_seq_len: int,
- max_enc_seq_len: int, monkeypatch):
+def test_encoder_only(
+ num_heads: int,
+ head_size: int,
+ attn_backend: _Backend,
+ batch_size: int,
+ block_size: int,
+ max_dec_seq_len: int,
+ max_enc_seq_len: int,
+):
+ '''
+ End-to-end encoder-only attention test:
+
+ * Construct fake test vectors for (1) encoder attention
+ * Construct (1) attention metadata structure with prefill-phase
+ encoder attention, and (2) an analogous attention metadata
+ structure but for decode-phase
+ * Test & validate encoder attention against ideal output
+
+ No KV cache is required for encoder-only attention.
+
+ Note on ROCm/HIP: currently encoder/decoder models are not supported on
+ AMD GPUs, therefore this test simply is skipped if is_hip().
+
+ This test globally forces an override of the usual backend
+ auto-selection process, forcing the specific backend-under-test
+ to be utilized.
+
+ Arguments:
+
+ * num_heads
+ * head_size,
+ * attn_backend: The attention backend to employ for testing
+ * batch_size
+ * block_size: KV cache block size
+ * max_dec_seq_len: max length of decoder input sequences
+ * max_enc_seq_len: max length of encoder input sequences
+ '''
# Force Attention wrapper backend
- override_backend_env_variable(monkeypatch, backend_name)
+ with global_force_attn_backend_context_manager(attn_backend):
- # Note: KV cache size of 4096 is arbitrary & chosen intentionally
- # to be more than necessary, since exceeding the kv cache size
- # is not part of this test
- test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
- block_size, max_dec_seq_len, max_enc_seq_len, 4096)
+ # Note: KV cache size of 4096 is arbitrary & chosen intentionally
+ # to be more than necessary, since exceeding the kv cache size
+ # is not part of this test
+ test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+ batch_size, block_size, max_dec_seq_len,
+ max_enc_seq_len, 4096)
- # Attention scale factor, attention backend instance, attention wrapper
- # instance, KV cache init
- test_rsrcs = _make_test_resources(test_pt)
+ # Attention scale factor, attention backend instance, attention wrapper
+ # instance, KV cache init
+ test_rsrcs = _make_test_resources(test_pt)
- # Construct encoder attention test params (only used
- # during prefill)
+ # Construct encoder attention test params (only used
+ # during prefill)
- enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
+ enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
- # Shared prefill metadata structure
+ # Shared prefill metadata structure
- prephase_attn_metadata: AttentionMetadata = make_test_metadata(
- test_rsrcs.attn_backend,
- True,
- None,
- decoder_test_params=None,
- encoder_test_params=enc_test_params,
- cross_test_params=None,
- device=CUDA_DEVICE)
+ prephase_attn_metadata: AttentionMetadata = make_test_metadata(
+ test_rsrcs.attn_backend,
+ True,
+ None,
+ decoder_test_params=None,
+ encoder_test_params=enc_test_params,
+ cross_test_params=None,
+ device=CUDA_DEVICE)
- # PREFILL: encoder attention
+ # PREFILL: encoder attention
- enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
- test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
+ enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
+ test_rsrcs.attn, enc_test_params, prephase_attn_metadata))
- # - Is encoder attention result correct?
- assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
+ # - Is encoder attention result correct?
+ assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("backend_name", BACKEND_NAMES)
+@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@@ -782,12 +818,11 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str,
def test_e2e_enc_dec_attn(
num_heads: int,
head_size: int,
- backend_name: str,
+ attn_backend: _Backend,
batch_size: int,
block_size: int,
max_dec_seq_len: int,
max_enc_seq_len: int,
- monkeypatch,
) -> None:
'''
End-to-end encoder/decoder test:
@@ -820,8 +855,9 @@ def test_e2e_enc_dec_attn(
cross-attention K/Vs are allowed to differ in seq len, as is often the case
for cross-attention.
- This test utilizes PyTest monkey patching to force the attention backend
- via an environment variable.
+ This test globally forces an override of the usual backend
+ auto-selection process, forcing the specific backend-under-test
+ to be utilized.
Note on ROCm/HIP: currently encoder/decoder models are not supported on
AMD GPUs, therefore this test simply is skipped if is_hip().
@@ -830,124 +866,136 @@ def test_e2e_enc_dec_attn(
all prefill-phase attention operations (encoder, decoder, enc/dec cross),
and a single one shared by all decode-phase attention operations
(decoder & enc/dec cross.) This is intended to reflect the behavior
- of ModelRunner, which constructs a single attention metadata structure for
- each prefill or decode run. A realistic scenario would rely on the
- attention backend to utilize the appropriate attention metadata fields
- according to the value of attn_metadata.attention_type. Thus, this test is
- organized so as to confirm that the backend-under-test can handle a
- shared prefill attention metadata structure & a shared decode attention
- metadata structure.
- '''
-
- # Force Attention wrapper backend
- override_backend_env_variable(monkeypatch, backend_name)
-
- # Note: KV cache size of 4096 is arbitrary & chosen intentionally
- # to be more than necessary, since exceeding the kv cache size
- # is not part of this test
- test_pt = TestPoint(num_heads, head_size, backend_name, batch_size,
- block_size, max_dec_seq_len, max_enc_seq_len, 4096)
-
- # Attention scale factor, attention backend instance, attention wrapper
- # instance, KV cache init
- test_rsrcs = _make_test_resources(test_pt)
+ of EncoderDecoderModelRunner, which constructs a single attention metadata
+ structure for each prefill or decode run. A realistic scenario would rely
+ on the attention backend to utilize the appropriate attention metadata
+ fields according to the value of attn_metadata.attention_type. Thus,
+ this test is organized so as to confirm that the backend-under-test can
+    handle a shared prefill attention metadata structure & a shared decode
+ attention metadata structure.
- # Construct encoder attention test params (only used
- # during prefill)
-
- enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
-
- # Construct Decoder self-attention prefill-phase & decode-phase
- # test params, including query/key/value tensors, decoder self-attention
- # memory-mapping. cross_block_base_addr is the uppermost address in the
- # decoder self-attention block-table, i.e. a base address which the
- # encoder/decoder cross-attention block-table may build downward toward.
-
- (
- dec_qkv,
- prephase_dec_test_params,
- decphase_dec_test_params,
- cross_block_base_addr,
- ) = _decoder_attn_setup(test_pt, test_rsrcs)
+ Arguments:
- # Construct encoder/decoder cross-attention prefill-phase & decode-phase
- # test params, including key/value tensors, cross-attention memory-mapping
+ * num_heads
+ * head_size,
+ * attn_backend: The attention backend to employ for testing
+ * batch_size
+ * block_size: KV cache block size
+ * max_dec_seq_len: max length of decoder input sequences
+ * max_enc_seq_len: max length of encoder input sequences
+ '''
- (
- prephase_cross_test_params,
- decphase_cross_test_params,
- ) = _enc_dec_cross_attn_setup_reuses_query(
- dec_qkv,
- enc_test_params,
- prephase_dec_test_params,
- test_pt,
- test_rsrcs,
- block_base_addr=cross_block_base_addr)
-
- # Shared prefill metadata structure
- assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
- prephase_attn_metadata: AttentionMetadata = make_test_metadata(
- test_rsrcs.attn_backend,
- True,
- prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
- decoder_test_params=prephase_dec_test_params,
- encoder_test_params=enc_test_params,
- cross_test_params=prephase_cross_test_params,
- device=CUDA_DEVICE)
-
- # PREFILL: encoder attention
-
- enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
- enc_test_params,
- prephase_attn_metadata)
-
- # - Is encoder attention result correct?
- assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
-
- # PREFILL: decoder self-attention test
-
- prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
- test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
-
- # - Is prefill decoder self-attention correct?
- assert_actual_matches_ideal(prephase_dec_test_params,
- prephase_dec_pckd_act_out)
-
- # PREFILL: encoder/decoder cross-attention test
-
- prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
- test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
- prephase_attn_metadata)
-
- # - Is prefill encoder/decoder cross-attention correct?
- assert_actual_matches_ideal(prephase_cross_test_params,
- prephase_cross_pckd_act_out)
-
- # DECODE: build decode-phase attention metadata
-
- decphase_attn_metadata: AttentionMetadata = make_test_metadata(
- test_rsrcs.attn_backend,
- False,
- dec_qkv.q_seq_lens,
- decoder_test_params=decphase_dec_test_params,
- encoder_test_params=enc_test_params,
- cross_test_params=decphase_cross_test_params,
- device=CUDA_DEVICE)
-
- # DECODE: decoder self-attention test
-
- decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
- test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
-
- # - Is decode-phase decoder self-attention correct?
- assert_actual_matches_ideal(decphase_dec_test_params,
- decphase_dec_pckd_act_out)
-
- # DECODE: encoder/decoder cross-attention test
-
- decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
- test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
-
- # - Is decode-phase encoder/decoder cross-attention correct?
- assert_actual_matches_ideal(decphase_cross_test_params,
- decphase_cross_pckd_act_out)
+ # Force Attention wrapper backend
+ with global_force_attn_backend_context_manager(attn_backend):
+
+ # Note: KV cache size of 4096 is arbitrary & chosen intentionally
+ # to be more than necessary, since exceeding the kv cache size
+ # is not part of this test
+ test_pt = TestPoint(num_heads, head_size, attn_backend.name,
+ batch_size, block_size, max_dec_seq_len,
+ max_enc_seq_len, 4096)
+
+ # Attention scale factor, attention backend instance, attention wrapper
+ # instance, KV cache init
+ test_rsrcs = _make_test_resources(test_pt)
+
+ # Construct encoder attention test params (only used
+ # during prefill)
+
+ enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)
+
+ # Construct Decoder self-attention prefill-phase & decode-phase
+ # test params, including query/key/value tensors, decoder self-attention
+ # memory-mapping. cross_block_base_addr is the uppermost address in the
+ # decoder self-attention block-table, i.e. a base address which the
+ # encoder/decoder cross-attention block-table may build downward toward.
+
+ (
+ dec_qkv,
+ prephase_dec_test_params,
+ decphase_dec_test_params,
+ cross_block_base_addr,
+ ) = _decoder_attn_setup(test_pt, test_rsrcs)
+
+ # Construct encoder/decoder cross-attention prefill-phase
+ # & decode-phase test params, including key/value tensors,
+ # cross-attention memory-mapping
+
+ (
+ prephase_cross_test_params,
+ decphase_cross_test_params,
+ ) = _enc_dec_cross_attn_setup_reuses_query(
+ dec_qkv,
+ enc_test_params,
+ prephase_dec_test_params,
+ test_pt,
+ test_rsrcs,
+ block_base_addr=cross_block_base_addr)
+
+ # Shared prefill metadata structure
+ assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
+ prephase_attn_metadata: AttentionMetadata = make_test_metadata(
+ test_rsrcs.attn_backend,
+ True,
+ prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
+ decoder_test_params=prephase_dec_test_params,
+ encoder_test_params=enc_test_params,
+ cross_test_params=prephase_cross_test_params,
+ device=CUDA_DEVICE)
+
+ # PREFILL: encoder attention
+
+ enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn,
+ enc_test_params,
+ prephase_attn_metadata)
+
+ # - Is encoder attention result correct?
+ assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
+
+ # PREFILL: decoder self-attention test
+
+ prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
+ test_rsrcs, prephase_dec_test_params, prephase_attn_metadata)
+
+ # - Is prefill decoder self-attention correct?
+ assert_actual_matches_ideal(prephase_dec_test_params,
+ prephase_dec_pckd_act_out)
+
+ # PREFILL: encoder/decoder cross-attention test
+
+ prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
+ test_rsrcs, prephase_dec_test_params, prephase_cross_test_params,
+ prephase_attn_metadata)
+
+ # - Is prefill encoder/decoder cross-attention correct?
+ assert_actual_matches_ideal(prephase_cross_test_params,
+ prephase_cross_pckd_act_out)
+
+ # DECODE: build decode-phase attention metadata
+
+ decphase_attn_metadata: AttentionMetadata = make_test_metadata(
+ test_rsrcs.attn_backend,
+ False,
+ dec_qkv.q_seq_lens,
+ decoder_test_params=decphase_dec_test_params,
+ encoder_test_params=enc_test_params,
+ cross_test_params=decphase_cross_test_params,
+ device=CUDA_DEVICE)
+
+ # DECODE: decoder self-attention test
+
+ decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
+ test_rsrcs, decphase_dec_test_params, decphase_attn_metadata)
+
+ # - Is decode-phase decoder self-attention correct?
+ assert_actual_matches_ideal(decphase_dec_test_params,
+ decphase_dec_pckd_act_out)
+
+ # DECODE: encoder/decoder cross-attention test
+
+ decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
+ test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata)
+
+ # - Is decode-phase encoder/decoder cross-attention correct?
+ assert_actual_matches_ideal(decphase_cross_test_params,
+ decphase_cross_pckd_act_out)
diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index 6c5eff00de44c..0d3edc5d2aaf7 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window,
soft_cap=soft_cap,
)
- assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
+ assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 23d627820d247..e942336ff7fdc 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -8,24 +8,10 @@
import pytest
import torch
-from vllm.attention.backends.abstract import (AttentionBackend,
- AttentionMetadata, AttentionType)
+from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersBackend
-from vllm.utils import make_tensor_with_pad
-
-# String name of register which may be set in order to
-# force auto-selection of attention backend by Attention
-# wrapper
-STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
-
-# Possible string values of STR_BACKEND_ENV_VAR
-# register, corresponding to possible backends
-STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
-STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
-STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
-STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
-STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
-STR_INVALID_VAL: str = "INVALID"
+from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
+ make_tensor_with_pad)
class QKVInputs(NamedTuple):
diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py
new file mode 100644
index 0000000000000..9c26b7163ff62
--- /dev/null
+++ b/tests/models/test_bart.py
@@ -0,0 +1,153 @@
+"""Compare the outputs of HF and vLLM for BART models using greedy sampling.
+
+Run `pytest tests/models/test_bart.py`.
+"""
+from vllm.utils import is_cpu
+
+if not is_cpu():
+ # CPU backend is not currently supported with encoder/decoder models
+ # skip test definitions entirely to avoid importing GPU kernel libs
+ # (xFormers, etc.)
+
+ import pytest
+
+ from tests.models.utils import DecoderPromptType
+
+ from .utils import check_logprobs_close
+
+ MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
+
+ DECODER_PROMPT_TYPES = ([
+ DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
+ DecoderPromptType.NONE
+ ])
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
+ @pytest.mark.parametrize("max_tokens", [64])
+ @pytest.mark.parametrize("num_logprobs", [5])
+ @pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
+ def test_models(
+ hf_runner,
+ vllm_runner,
+ example_encoder_decoder_prompts,
+ model: str,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+ decoder_prompt_type: DecoderPromptType,
+ ) -> None:
+ '''
+ Test the vLLM BART model for a variety of encoder/decoder input prompts,
+ by validating it against HuggingFace (HF) BART.
+
+ Arguments:
+
+ * hf_runner: HuggingFace (HF) test model runner
+ * vllm_runner: vLLM test model runner
+ * example_encoder_decoder_prompts: test fixture which provides a
+ dictionary of dummy prompts
+ * model: the HF ID of the specific BART variant under test
+ * dtype: the tensor datatype to employ
+ * max_tokens
+ * num_logprobs
+ * decoder_prompt_type: key into the example_encoder_decoder_prompts
+ dictionary; selects specific encoder/decoder
+ prompt scenarios to test
+
+ A note on using HF BART as a baseline for validating vLLM BART,
+ specifically when the decoder prompt is None.
+
+ The HF GenerationMixin's default behavior is to force the first
+ decoded token to be <BOS> if the prompt does not already contain
+ <BOS> (this is accomplished using a logit
+ processor setting.)
+
+ So when we use HF BART as our baseline for comparison, note that
+ when the user provides a request with a None decoder prompt
+ (i.e. a singleton encoder prompt, or else an explicit encoder/
+ decoder prompt with the decoder sub-prompt set to None), HF and
+ vLLM handle this in different ways:
+
+ * HF will (1) tokenize the None prompt as an empty token-list,
+ (2) append <decoder-start-token> to the beginning, yielding
+ [<decoder-start-token>], (3) pass this token list to the model, and
+ then (4) after computing logits during prefill, override the model
+ logits & force <BOS> to be the first generated token.
+
+ * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
+ start-token to the beginning, yielding [<decoder-start-token><BOS>],
+ (3) pass these tokens to the model & proceed with generation.
+
+ The net effect is that compared to vLLM, the list of HF *decoded* tokens
+ will contain one more initial <BOS> than the vLLM generated tokens,
+ because vLLM's <BOS> token is injected into the prompt rather than into
+ the generated output. This is in spite of the fact that overall, the
+ complete sequences (prompt + decoded tokens) produced by vLLM will match
+ HF.
+
+ So when we use HF decoded token output to validate vLLM's decoded token
+ output, the testing process must account for the difference in decoded
+ token sequences between vLLM and HF specifically in the
+ decoder-prompt-is-None case.
+
+ One option is to disable the logit processor feature that forces the
+ <BOS> token to be decoded (forced_bos_token_id = None), eliminating
+ the problem entirely. However this is not "normal" BART usage.
+
+ The other option is - only in the decoder-prompt-is-None case - to
+ discard the first decoded token from the HF output before comparing it
+ to vLLM.
+
+ To that end, when testing the scenario where the decoder prompt is None
+ (and only in that one scenario), this test skips the first HF decoded
+ token during the process of validating the vLLM decoded output.
+ '''
+
+ test_case_prompts = example_encoder_decoder_prompts[
+ decoder_prompt_type]
+
+ # Configuration settings for HF baseline
+ hf_kwargs = {
+ "top_k": None,
+ "num_beams": 1,
+ "repetition_penalty": 1.0,
+ "top_p": 1.0,
+ "length_penalty": 1.0,
+ "early_stopping": False,
+ "no_repeat_ngram_size": None,
+ "min_length": 0
+ }
+
+ with hf_runner(model, dtype=dtype,
+ is_encoder_decoder_model=True) as hf_model:
+ hf_outputs = (
+ hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+ test_case_prompts,
+ max_tokens,
+ num_logprobs,
+ **hf_kwargs,
+ ))
+
+ # Note: currently encoder/decoder models are only compatible with
+ # enforce_eager=True. Normally this is not a problem because
+ # for encoder/decoder models vLLM will
+ # default to enforce_eager=True if enforce_eager
+ # is left unspecified. However, the
+ # VllmRunner test fixture (which wraps around the LLM class) defaults to
+ # enforce_eager=False (a behavior which a number of already-existing
+ # decoder-only unit tests expect), so when testing an encoder/decoder
+ # model we must explicitly specify enforce_eager=True in the VllmRunner
+ # constructor.
+ with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
+ vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+ test_case_prompts, max_tokens, num_logprobs)
+
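+ # A hypothetical illustration of the skip described in the docstring
+ # above: with a None decoder prompt, HF might decode
+ # [<BOS>, "The", "cat", ...] while vLLM decodes ["The", "cat", ...],
+ # so dropping the first HF token aligns the two outputs.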
+ hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
+ else 0)
+
+ check_logprobs_close(outputs_0_lst=hf_outputs,
+ outputs_1_lst=vllm_outputs,
+ name_0="hf",
+ name_1="vllm",
+ num_outputs_0_skip_tokens=hf_skip_tokens)
diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py
index 66cb8dda248db..6aa0189648d72 100644
--- a/tests/models/test_internvl.py
+++ b/tests/models/test_internvl.py
@@ -5,6 +5,7 @@
import torch
from huggingface_hub import snapshot_download
from PIL.Image import Image
+from transformers import AutoConfig
from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
IMG_START,
@@ -26,10 +27,15 @@
# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
+DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
models = [
- snapshot_download("OpenGVLab/InternVL2-1B"),
- snapshot_download("OpenGVLab/InternVL2-2B"),
- # snapshot_download("OpenGVLab/InternVL2-4B"), # broken
+ snapshot_download("OpenGVLab/InternVL2-1B",
+ allow_patterns=DOWNLOAD_PATTERN),
+ snapshot_download("OpenGVLab/InternVL2-2B",
+ allow_patterns=DOWNLOAD_PATTERN),
+ # Broken due to outdated implementation of Phi-3
+ # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
+ # snapshot_download("OpenGVLab/InternVL2-4B"),
]
@@ -41,8 +47,17 @@ def __init__(self, hf_runner: HfRunner):
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
+ self.config = AutoConfig.from_pretrained(hf_runner.model_name)
+ self.vision_config = self.config.vision_config
+ self.use_thumbnail = self.config.use_thumbnail
+ self.min_num = self.config.min_dynamic_patch
+ self.max_num = self.config.max_dynamic_patch
+ self.image_size = self.vision_config.image_size
+
def __call__(self, text: str, images: Image, **kwargs):
- pixel_values = image_to_pixel_values(images).to(self.dtype)
+ pixel_values = image_to_pixel_values(images, self.image_size,
+ self.min_num, self.max_num,
+ self.use_thumbnail).to(self.dtype)
num_patches_list = [pixel_values.shape[0]]
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 425f57ef9b966..d96301b853c85 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -1,4 +1,5 @@
import warnings
+from enum import Enum
from typing import Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import SampleLogprobs
@@ -45,11 +46,27 @@ def check_logprobs_close(
outputs_1_lst: Sequence[TokensTextLogprobs],
name_0: str,
name_1: str,
+ num_outputs_0_skip_tokens: int = 0,
warn_on_mismatch: bool = True,
):
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
+
+ Arguments:
+
+ * outputs_0_lst: First sequence to compare
+ * outputs_1_lst: Second sequence to compare
+ * name_0: sequence #0 name
+ * name_1: sequence #1 name
+ * num_outputs_0_skip_tokens: If > 0, specifies the number of initial
+ sequence #0 tokens & logprobs to discard
+ before comparison, i.e. all
+ of sequence #1 will be compared to
+ sequence #0 beginning at index
+ num_outputs_0_skip_tokens
+ * warn_on_mismatch: Issue a warning if there is token-wise or text-wise
+ mismatch between the two sequences
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
@@ -65,6 +82,15 @@ def check_logprobs_close(
if logprobs_1 is None:
logprobs_1 = [None] * len(output_ids_1)
+ # Skip specified number of initial sequence #0 tokens
+ # & logprobs, leaving output text as-is for simplicity
+ # (text mismatches may generate warnings but do not
+ # cause the test to fail.)
+ if num_outputs_0_skip_tokens < 0:
+ raise ValueError("num_outputs_0_skip_tokens must be non-negative")
+ output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
+ logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
+
# Loop through generated tokens.
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
@@ -110,3 +136,13 @@ def check_logprobs_close(
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
+
+
+class DecoderPromptType(Enum):
+ '''
+ For encoder/decoder models only - the decoder-prompt scenario
+ exercised by a test case: a custom decoder prompt, no decoder
+ prompt (None), or an empty-string decoder prompt.
+ '''
+ CUSTOM = 1
+ NONE = 2
+ EMPTY_STR = 3
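+
+
+# A usage sketch (assumed; mirrors tests/models/test_bart.py): the enum
+# value keys into the example_encoder_decoder_prompts fixture, e.g.
+#   prompts = example_encoder_decoder_prompts[DecoderPromptType.NONE]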
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index bd79da84a7764..2ea340779b819 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -9,7 +9,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
- CompressedTensorsWNA16)
+ CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType)
@@ -109,7 +109,7 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
- assert qkv_proj.weight_packed.pack_factor == pack_factor
+ assert qkv_proj.scheme.pack_factor == pack_factor
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@@ -140,13 +140,17 @@ def test_compressed_tensors_fp8(vllm_runner):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
- assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
- assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+ assert isinstance(
+ qkv_proj.scheme,
+ (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
+
assert qkv_proj.input_scale.dtype is torch.float32
- assert qkv_proj.weight_scale.dtype is torch.float32
- # should be scalars after processing
- assert len(qkv_proj.input_scale.shape) == 0
- assert len(qkv_proj.weight_scale.shape) == 0
+
+ if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
+ assert len(qkv_proj.input_scale.shape) == 0
+ assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+ assert qkv_proj.weight_scale.dtype is torch.float32
+ assert len(qkv_proj.weight_scale.shape) == 0
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index a020f7bf37262..ebb06ed20f249 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -9,6 +9,7 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
Fp8LinearMethod)
+from vllm.platforms import current_platform
MODELS = [
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
@@ -20,7 +21,12 @@
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
-def test_model_load_and_run(vllm_runner, model_id: str):
+@pytest.mark.parametrize("force_marlin", [False, True])
+def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
+ monkeypatch) -> None:
+ if force_marlin:
+ monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
+
with vllm_runner(model_id) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
@@ -61,7 +67,12 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
-def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
+@pytest.mark.parametrize("force_marlin", [False, True])
+def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
+ monkeypatch) -> None:
+ if force_marlin:
+ monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")
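+ # When the Marlin fp8 fallback is forced, weights are not expected to
+ # remain in float8_e4m3fn, which is why the capability check below
+ # also requires `not force_marlin`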
+
with vllm_runner("facebook/opt-125m",
quantization="fp8",
kv_cache_dtype=kv_cache_dtype) as llm:
@@ -75,9 +86,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0
- capability = torch.cuda.get_device_capability()
+ capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
- if capability >= 89:
+ if capability >= 89 and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 8203b5d2f960d..8d22c20bb1977 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,6 +2,7 @@
import os
import socket
import sys
+from functools import partial
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)
@@ -37,11 +38,11 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
yield f"item from iterator {idx}"
await asyncio.sleep(0.1)
except asyncio.CancelledError:
- pass
+ print(f"iterator {idx} cancelled")
iterators = [mock_async_iterator(i) for i in range(3)]
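+ # partial(asyncio.sleep, 0, result=False) is an async callable that
+ # always returns False, i.e. is_cancelled never reports cancellation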
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
- *iterators)
+ *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
diff --git a/tests/utils.py b/tests/utils.py
index 666694299d397..e3d04cc638a95 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -50,7 +50,7 @@ def _nvml():
class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
- MAX_SERVER_START_WAIT_S = 120 # wait for server to start for 120 seconds
+ MAX_START_WAIT_S = 120 # wait for server to start for 120 seconds
def __init__(
self,
@@ -85,7 +85,7 @@ def __init__(
stdout=sys.stdout,
stderr=sys.stderr)
self._wait_for_server(url=self.url_for("health"),
- timeout=self.MAX_SERVER_START_WAIT_S)
+ timeout=self.MAX_START_WAIT_S)
def __enter__(self):
return self
@@ -266,8 +266,9 @@ def compare_two_settings(model: str,
arg1_results = results[:n]
arg2_results = results[n:]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
- assert arg1_result == arg2_result, \
- f"Results for {model=} are not the same with {arg1=} and {arg2=}"
+ assert arg1_result == arg2_result, (
+ f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
+ f"{arg1_result=} != {arg2_result=}")
def init_test_distributed_environment(
diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py
new file mode 100644
index 0000000000000..8a2e9b81580fc
--- /dev/null
+++ b/tests/worker/test_encoder_decoder_model_runner.py
@@ -0,0 +1,480 @@
+from typing import List
+
+import pytest
+import torch
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.utils import is_cpu
+from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
+
+# CUDA graph scenarios to test
+#
+# Currently CUDA graph is not supported for encoder/decoder models,
+# so only eager mode is exercised
+ENFORCE_EAGER = [True]
+
+BATCH_SIZES = [1, 4, 16, 64, 256]
+
+
+def _create_model_runner(model: str, *args,
+ **kwargs) -> EncoderDecoderModelRunner:
+ engine_args = EngineArgs(model, *args, **kwargs)
+ engine_config = engine_args.create_engine_config()
+ model_runner = EncoderDecoderModelRunner(
+ model_config=engine_config.model_config,
+ parallel_config=engine_config.parallel_config,
+ scheduler_config=engine_config.scheduler_config,
+ device_config=engine_config.device_config,
+ cache_config=engine_config.cache_config,
+ load_config=engine_config.load_config,
+ lora_config=engine_config.lora_config,
+ prompt_adapter_config=engine_config.prompt_adapter_config,
+ is_driver_worker=True,
+ )
+ return model_runner
+
+
+@pytest.mark.skipif(condition=is_cpu(),
+ reason="CPU backend is currently "
+ "unsupported for encoder/ "
+ "decoder models")
+@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
+def test_empty_seq_group(enforce_eager, ):
+ """Verify prepare prompt and decode returns empty output
+ for empty seq group list"""
+
+ model_runner = _create_model_runner(
+ "facebook/bart-base",
+ seed=0,
+ dtype="float16",
+ max_num_batched_tokens=100000,
+ max_num_seqs=100000,
+ enable_chunked_prefill=False,
+ enforce_eager=enforce_eager,
+ )
+ seq_group_metadata_list: List[SequenceGroupMetadata] = []
+ model_input = model_runner._prepare_model_input_tensors(
+ seq_group_metadata_list)
+ (
+ input_tokens,
+ input_positions,
+ encoder_input_tokens,
+ encoder_input_positions,
+ attn_metadata,
+ return_seq_lens,
+ ) = (
+ model_input.input_tokens,
+ model_input.input_positions,
+ model_input.encoder_input_tokens,
+ model_input.encoder_input_positions,
+ model_input.attn_metadata,
+ model_input.seq_lens,
+ )
+ assert input_tokens is None
+ assert input_positions is None
+ assert encoder_input_tokens is None
+ assert encoder_input_positions is None
+ assert attn_metadata is None
+ assert return_seq_lens is None
+
+
+@pytest.mark.skipif(condition=is_cpu(),
+ reason="CPU backend is currently "
+ "unsupported for encoder/ "
+ "decoder models")
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
+def test_prepare_prompt(
+ batch_size,
+ enforce_eager,
+):
+ '''
+ Test the ability of the encoder/decoder model runner subclass to
+ produce prefill-phase model inputs & attention metadata.
+
+ Test behavior:
+
+ * Instantiate BART base model & enc/dec model runner
+ * Construct sequence-group metadata for dummy prompts
+ * Test that encoder attention, decoder self-attention,
+ and encoder/decoder cross-attention inputs are correct
+
+ Arguments:
+
+ * batch_size
+ * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
+ '''
+
+ model_runner = _create_model_runner(
+ "facebook/bart-base",
+ seed=0,
+ dtype="float16",
+ max_num_batched_tokens=100000,
+ max_num_seqs=100000,
+ enable_chunked_prefill=False,
+ enforce_eager=enforce_eager,
+ )
+
+ seq_lens: List[int] = []
+ encoder_seq_lens: List[int] = []
+ seq_group_metadata_list: List[SequenceGroupMetadata] = []
+ block_tables = {0: [1]}
+ cross_block_table = [2]
+ for i in range(batch_size):
+ # make sure all tokens fit into one block
+ seq_len = i % (model_runner.block_size - 1) + 1
+ seq_lens.append(seq_len)
+ seq_data = SequenceData(list(range(seq_len)))
+ encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
+ encoder_seq_lens.append(encoder_seq_len)
+ encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
+ seq_group_metadata = SequenceGroupMetadata(
+ request_id=f"test_{i}",
+ is_prompt=True,
+ seq_data={0: seq_data},
+ sampling_params=SamplingParams(temperature=0),
+ block_tables=block_tables,
+ encoder_seq_data=encoder_seq_data,
+ cross_block_table=cross_block_table,
+ )
+ assert seq_group_metadata.token_chunk_size == seq_data.get_len()
+ seq_group_metadata_list.append(seq_group_metadata)
+
+ # Build
+ # * Decoder model inputs
+ # * Decoder self-attention KV caching data structures
+ # * Encoder model inputs
+ # * Encoder/decoder cross-attention KV caching data structures
+ model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+
+ input_tokens = model_input.input_tokens
+ input_positions = model_input.input_positions
+ attn_metadata = model_input.attn_metadata
+ return_seq_lens = model_input.seq_lens
+ slot_mapping = attn_metadata.slot_mapping
+ encoder_input_tokens = model_input.encoder_input_tokens
+ encoder_input_positions = model_input.encoder_input_positions
+ cross_slot_mapping = attn_metadata.cross_slot_mapping
+ assert return_seq_lens == seq_lens
+ assert len(slot_mapping) == len(input_tokens)
+ assert len(cross_slot_mapping) == len(encoder_input_tokens)
+
+ # Verify input metadata is correct for prompts.
+ # - Decoder attention metadata
+ device = model_runner.device
+ assert attn_metadata.num_prefills > 0
+ assert attn_metadata.num_decode_tokens == 0
+ assert torch.equal(attn_metadata.seq_lens_tensor,
+ torch.tensor(seq_lens, device=device, dtype=torch.int))
+ assert attn_metadata.seq_lens == seq_lens
+ assert attn_metadata.max_prefill_seq_len == max(seq_lens)
+ assert attn_metadata.max_decode_seq_len == 0
+ # - Encoder attention metadata
+ assert attn_metadata.encoder_seq_lens == encoder_seq_lens
+ assert torch.equal(
+ attn_metadata.encoder_seq_lens_tensor,
+ torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
+ assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
+ assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
+
+ # Test decoder subquery start locs.
+ start_idx = 0
+ start_loc = [start_idx]
+ for seq_len in seq_lens:
+ start_idx += seq_len
+ start_loc.append(start_idx)
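+ # e.g. seq_lens == [2, 3] yields query_start_loc == [0, 2, 5]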
+ assert torch.equal(
+ attn_metadata.query_start_loc,
+ torch.tensor(start_loc, dtype=torch.int32, device=device),
+ )
+
+ # Test decoder seq start locs & context lengths
+
+ assert torch.equal(
+ attn_metadata.seq_start_loc,
+ torch.tensor(start_loc, dtype=torch.int32, device=device),
+ )
+ assert torch.equal(
+ attn_metadata.context_lens_tensor,
+ torch.zeros(attn_metadata.context_lens_tensor.shape[0],
+ dtype=torch.int,
+ device=device),
+ )
+
+ # Verify block tables are correct for prompts
+ # - Decoder self-attention
+ expected = torch.tensor(
+ [[] for _ in range(len(seq_group_metadata_list))],
+ dtype=torch.int32,
+ device=model_runner.device,
+ )
+ assert torch.equal(
+ attn_metadata.block_tables,
+ expected,
+ )
+ # - Encoder/decoder cross-attention
+ assert torch.equal(
+ attn_metadata.cross_block_tables,
+ expected,
+ )
+
+ # Cuda graph should not be used for prefill.
+ assert attn_metadata.use_cuda_graph is False
+
+ # Verify the lengths of input tokens & positions
+ # - Decoder
+ assert len(input_tokens) == sum(seq_lens)
+ assert len(input_positions) == sum(seq_lens)
+ # -- An indirect check that model_input.input_tokens
+ # and model_input.input_positions are correct -
+ # by design of the test, the input tokens are
+ # equal to the input position values, so if
+ # the model_input data structure has the correct
+ # values then these two should be equal
+ assert torch.equal(
+ input_tokens,
+ input_positions,
+ )
+ # - Encoder
+ assert len(encoder_input_tokens) == sum(encoder_seq_lens)
+ # -- An indirect check that model_input.encoder_input_tokens
+ # and model_input.encoder_input_positions are correct -
+ # by design of the test, the input tokens are
+ # equal to the input position values, so if
+ # the model_input data structure has the correct
+ # values then these two should be equal
+ assert torch.equal(
+ encoder_input_tokens,
+ encoder_input_positions,
+ )
+
+ # Test that vLLM sampling infrastructure chooses the correct
+ # sequence positions at which to sample (i.e. the end of
+ # each sequence) in the prefill phase
+
+ expected_selected_token_indices = []
+ selected_token_start_idx = 0
+ for seq_len in seq_lens:
+ # Compute the index offset of the final token in each
+ # prompt (recall that the prompts are concatenated)
+ expected_selected_token_indices.append(selected_token_start_idx +
+ seq_len - 1)
+ selected_token_start_idx += seq_len
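+ # e.g. for seq_lens == [2, 3] the concatenated prompt tokens occupy
+ # indices 0..4, so sampling is expected at indices [1, 4]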
+
+ sampling_metadata = model_input.sampling_metadata
+ actual = sampling_metadata.selected_token_indices
+ expected = torch.tensor(
+ expected_selected_token_indices,
+ device=actual.device,
+ dtype=actual.dtype,
+ )
+ assert torch.equal(actual, expected)
+
+
+@pytest.mark.skipif(condition=is_cpu(),
+ reason="CPU backend is currently "
+ "unsupported for encoder/ "
+ "decoder models")
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
+def test_prepare_decode(
+ batch_size,
+ enforce_eager,
+):
+ '''
+ Test the ability of the encoder/decoder model runner subclass to
+ produce decode-phase model inputs & attention metadata.
+
+ Test behavior:
+
+ * Instantiate BART base model & enc/dec model runner
+ * Construct sequence-group metadata for dummy prompts
+ * Test that encoder attention, decoder self-attention,
+ and encoder/decoder cross-attention inputs are correct
+
+ Arguments:
+
+ * batch_size
+ * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
+ '''
+
+ model_runner = _create_model_runner(
+ "facebook/bart-base",
+ seed=0,
+ dtype="float16",
+ max_num_batched_tokens=100000,
+ max_num_seqs=100000,
+ enable_chunked_prefill=False,
+ enforce_eager=enforce_eager,
+ )
+
+ seq_lens: List[int] = []
+ encoder_seq_lens: List[int] = []
+ seq_group_metadata_list: List[SequenceGroupMetadata] = []
+ block_tables = {0: [1]}
+ cross_block_table = [2]
+ for i in range(batch_size):
+ # make sure all tokens fit into one block
+ seq_len = i % (model_runner.block_size - 1) + 1
+ seq_lens.append(seq_len)
+ seq_data = SequenceData(list(range(seq_len)))
+ encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
+ encoder_seq_lens.append(encoder_seq_len)
+ encoder_seq_data = SequenceData(list(range(encoder_seq_len)))
+ seq_group_metadata = SequenceGroupMetadata(
+ request_id=f"test_{i}",
+ is_prompt=False,
+ seq_data={0: seq_data},
+ sampling_params=SamplingParams(temperature=0),
+ block_tables=block_tables,
+ encoder_seq_data=encoder_seq_data,
+ cross_block_table=cross_block_table,
+ )
+ assert seq_group_metadata.token_chunk_size == 1
+ seq_group_metadata_list.append(seq_group_metadata)
+
+ # Build
+ # * Decoder model inputs
+ # * Decoder self-attention KV caching data structures
+ # * Encoder model inputs
+ # * Encoder/decoder cross-attention KV caching data structures
+ model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+ input_tokens = model_input.input_tokens
+ input_positions = model_input.input_positions
+ attn_metadata = model_input.attn_metadata
+ return_seq_lens = model_input.seq_lens
+ slot_mapping = attn_metadata.slot_mapping
+ encoder_input_tokens = model_input.encoder_input_tokens
+ encoder_input_positions = model_input.encoder_input_positions
+ cross_slot_mapping = attn_metadata.cross_slot_mapping
+ assert return_seq_lens == seq_lens
+ assert len(slot_mapping) == len(input_tokens)
+ assert len(cross_slot_mapping) == len(encoder_input_tokens)
+
+ # Verify input metadata is correct for decode phase.
+ # - Decoder attention metadata
+ device = model_runner.device
+ assert attn_metadata.num_prefills == 0
+ assert attn_metadata.num_decode_tokens > 0
+ assert torch.equal(attn_metadata.seq_lens_tensor,
+ torch.tensor(seq_lens, device=device, dtype=torch.int))
+ assert attn_metadata.seq_lens == seq_lens
+ assert attn_metadata.max_prefill_seq_len == 0
+ assert attn_metadata.max_decode_seq_len == max(seq_lens)
+ # - Encoder attention metadata
+ assert attn_metadata.encoder_seq_lens == encoder_seq_lens
+ assert torch.equal(
+ attn_metadata.encoder_seq_lens_tensor,
+ torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
+ assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
+ assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)
+
+ # Test decoder subquery start locs.
+ start_idx = 0
+ start_loc = [start_idx]
+ for seq_len in seq_lens:
+ start_idx += 1
+ start_loc.append(start_idx)
+ assert torch.equal(
+ attn_metadata.query_start_loc,
+ torch.tensor(start_loc, dtype=torch.int32, device=device),
+ )
+
+ # Test decoder seq start locs. Note that for normal prefill it is
+ # equivalent to query_start_loc.
+ start_idx = 0
+ seq_start_loc = [start_idx]
+ for seq_len in seq_lens:
+ start_idx += seq_len
+ seq_start_loc.append(start_idx)
+
+ # Test seq_start_loc and context lengths
+
+ assert torch.equal(
+ attn_metadata.seq_start_loc,
+ torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
+ )
+ assert torch.equal(
+ attn_metadata.context_lens_tensor,
+ torch.tensor([seq_len - 1 for seq_len in seq_lens],
+ dtype=torch.int,
+ device=device))
+
+ # Verify block tables are correct for the decode phase
+ # - Decoder self-attention
+ expected = torch.tensor(
+ [block_tables[0] for _ in range(len(seq_group_metadata_list))],
+ dtype=torch.int32,
+ device=model_runner.device)
+ assert torch.equal(
+ attn_metadata.block_tables,
+ expected,
+ )
+ # - Encoder/decoder cross-attention
+ expected = torch.tensor(
+ [cross_block_table for _ in range(len(seq_group_metadata_list))],
+ dtype=torch.int32,
+ device=model_runner.device)
+ assert torch.equal(
+ attn_metadata.cross_block_tables,
+ expected,
+ )
+
+ # CUDA graph is currently not supported for encoder/decoder models.
+ assert attn_metadata.use_cuda_graph is False
+
+ # Verify the lengths of input tokens & positions
+ # - Decoder
+ assert len(input_tokens) == len(seq_lens)
+ assert len(input_positions) == len(seq_lens)
+ # -- An indirect check that model_input.input_tokens
+ # and model_input.input_positions are correct -
+ # by design of the test, the input tokens are
+ # equal to the input position values, so if
+ # the model_input data structure has the correct
+ # values then these two should be equal
+ assert torch.equal(
+ input_tokens,
+ input_positions,
+ )
+ # - Encoder
+ assert len(encoder_input_tokens) == 0
+ assert len(encoder_input_positions) == 0
+ # -- An indirect check that model_input.encoder_input_tokens
+ # and model_input.encoder_input_positions are correct -
+ # by design of the test, the input tokens are
+ # equal to the input position values, so if
+ # the model_input data structure has the correct
+ # values then these two should be equal
+ assert torch.equal(
+ encoder_input_tokens,
+ encoder_input_positions,
+ )
+
+ # Test that vLLM sampling infrastructure chooses the correct
+ # sequence positions at which to sample (i.e. the end of
+ # each sequence) in the decode phase
+
+ expected_selected_token_indices = []
+ selected_token_start_idx = 0
+ for seq_len in seq_lens:
+ # Compute the index offset of the final token in each
+ # sequence's decoded outputs; since a single token is
+ # decoded per iteration per sequence, then the length
+ # of the decoded tokens for a given sequence is 1 and
+ # the final index offset into a given sequence's
+ # generated tokens is 0 (i.e. the expected sampling index
+ # for a given sequence is just `selected_token_start_idx`)
+ expected_selected_token_indices.append(selected_token_start_idx)
+ selected_token_start_idx += 1
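+ # e.g. for a batch of 3 sequences, exactly one token is sampled per
+ # sequence, so the expected indices are simply [0, 1, 2]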
+
+ sampling_metadata = model_input.sampling_metadata
+ actual = sampling_metadata.selected_token_indices
+ expected = torch.tensor(
+ expected_selected_token_indices,
+ device=actual.device,
+ dtype=actual.dtype,
+ )
+ assert torch.equal(actual, expected)
diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py
index 44bfae44cfddd..4643d316d48b7 100644
--- a/vllm/attention/__init__.py
+++ b/vllm/attention/__init__.py
@@ -1,6 +1,7 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
- AttentionMetadataBuilder)
+ AttentionMetadataBuilder,
+ AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
@@ -8,6 +9,7 @@
"Attention",
"AttentionBackend",
"AttentionMetadata",
+ "AttentionType",
"AttentionMetadataBuilder",
"Attention",
"get_attn_backend",
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 2c21502dcf407..ecf964fa49d9b 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
-from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
+from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 8fcd85585a18f..d5c8d6a376961 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -1,6 +1,8 @@
import enum
+import os
+from contextlib import contextmanager
from functools import lru_cache
-from typing import Optional, Type
+from typing import Generator, Optional, Type
import torch
@@ -8,7 +10,8 @@
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
+from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino,
+ is_tpu, is_xpu)
logger = init_logger(__name__)
@@ -24,6 +27,66 @@ class _Backend(enum.Enum):
IPEX = enum.auto()
+def backend_name_to_enum(backend_name: str) -> _Backend:
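+ # e.g. backend_name_to_enum("XFORMERS") -> _Backend.XFORMERS; an
+ # unrecognized name such as "INVALID" raises ValueError below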
+ assert backend_name is not None
+
+ backend_members = _Backend.__members__
+ if backend_name not in backend_members:
+ raise ValueError(f"Invalid attention backend '{backend_name}'. "
+ f"Available backends: {', '.join(backend_members)} "
+ "(case-sensitive).")
+
+ return _Backend[backend_name]
+
+
+def get_env_variable_attn_backend() -> Optional[_Backend]:
+ '''
+ Get the backend override specified by the vLLM attention
+ backend environment variable, if one is specified.
+
+ Returns:
+
+ * _Backend enum value if an override is specified
+ * None otherwise
+ '''
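+ # e.g. VLLM_ATTENTION_BACKEND=XFORMERS -> _Backend.XFORMERS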
+ backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
+ return (None
+ if backend_name is None else backend_name_to_enum(backend_name))
+
+
+# Global state allows a particular choice of backend
+# to be forced, overriding the logic which auto-selects
+# a backend based on system & workload configuration
+# (default behavior if this variable is None)
+#
+# THIS SELECTION TAKES PRECEDENCE OVER THE
+# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE
+forced_attn_backend: Optional[_Backend] = None
+
+
+def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
+ '''
+ Force all attention operations to use a specified backend.
+
+ Passing `None` for the argument re-enables automatic
+ backend selection.
+
+ Arguments:
+
+ * attn_backend: backend selection (None to revert to auto)
+ '''
+ global forced_attn_backend
+ forced_attn_backend = attn_backend
+
+
+def get_global_forced_attn_backend() -> Optional[_Backend]:
+ '''
+ Get the currently-forced choice of attention backend,
+ or None if auto-selection is currently enabled.
+ '''
+ return forced_attn_backend
+
+
@lru_cache(maxsize=None)
def get_attn_backend(
num_heads: int,
@@ -101,16 +164,20 @@ def which_attn_to_use(
# Default case.
selected_backend = _Backend.FLASH_ATTN
- # Check the environment variable and override if specified
- backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
- if backend_by_env_var is not None:
- backend_members = _Backend.__members__
- if backend_by_env_var not in backend_members:
- raise ValueError(
- f"Invalid attention backend '{backend_by_env_var}'. "
- f"Available backends: {', '.join(backend_members)} "
- "(case-sensitive).")
- selected_backend = _Backend[backend_by_env_var]
+ # Check whether a particular choice of backend was
+ # previously forced.
+ #
+ # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
+ # ENVIRONMENT VARIABLE.
+ backend_by_global_setting: Optional[_Backend] = (
+ get_global_forced_attn_backend())
+ if backend_by_global_setting is not None:
+ selected_backend = backend_by_global_setting
+ else:
+ # Check the environment variable and override if specified
+ backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+ if backend_by_env_var is not None:
+ selected_backend = backend_name_to_enum(backend_by_env_var)
if is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
@@ -193,3 +260,35 @@ def which_attn_to_use(
selected_backend = _Backend.XFORMERS
return selected_backend
+
+
+@contextmanager
+def global_force_attn_backend_context_manager(
+ attn_backend: _Backend) -> Generator[None, None, None]:
+ '''
+ Globally force a vLLM attention backend override within a
+ context manager, reverting the global attention backend
+ override to its prior state upon exiting the context
+ manager.
+
+ Arguments:
+
+ * attn_backend: attention backend to force
+
+ Returns:
+
+ * Generator
+ '''
+
+ # Save the current state of the global backend override (if any)
+ original_value = get_global_forced_attn_backend()
+
+ # Globally force the new backend override
+ global_force_attn_backend(attn_backend)
+
+ # Yield control back to the enclosed code block
+ try:
+ yield
+ finally:
+ # Revert the original global backend override, if any
+ global_force_attn_backend(original_value)
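+
+
+# A minimal usage sketch (assumed; the encoder/decoder attention tests in
+# tests/kernels use this pattern):
+#
+#     with global_force_attn_backend_context_manager(_Backend.XFORMERS):
+#         ...  # backend auto-selection is overridden inside the block
+#     # the prior override (or auto-selection) is restored on exit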
diff --git a/vllm/config.py b/vllm/config.py
index 3cc197f3d655f..ec6d587e7925b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -12,7 +12,8 @@
from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config
-from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
+from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
+ cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
print_warning_once)
@@ -87,6 +88,9 @@ class ModelConfig:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
+ If None, the user did not specify, so default to False -
+ except for encoder/decoder models, which currently require
+ eager mode.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
@@ -121,7 +125,7 @@ def __init__(
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
- enforce_eager: bool = False,
+ enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
@@ -160,6 +164,34 @@ def __init__(
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+ # Choose a default enforce_eager value if the user did not specify
+ # a value (enforce_eager is None)
+ if getattr(self.hf_config, 'is_encoder_decoder', False):
+ if self.enforce_eager is None:
+ # *Only for encoder/decoder models* and
+ # *only if enforce_eager is unset*, override
+ # to enforce_eager=True
+ #
+ # Add a logger message since it is *somewhat* non-intuitive that
+ # enforce_eager is True when the user has not specified its
+ # value.
+ logger.info("Forcing enforce_eager == True because "
+ "enforce_eager setting was unspecified and "
+ "CUDAGraph is not supported with encoder/ "
+ "decoder models.")
+ self.enforce_eager = True
+
+ if not self.enforce_eager:
+ # Eager mode explicitly disabled by user for an encoder/
+ # decoder model; however CUDAGRAPH + encoder/decoder is
+ # not currently supported
+ raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
+ elif self.enforce_eager is None:
+ # *Only for decoder-only models*, enforce_eager
+ # defaults to False if unset. This is intuitive
+ # so no logging message needed.
+ self.enforce_eager = False
+
if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py
index 2c412a8f472e0..28839437c33c5 100644
--- a/vllm/core/block/utils.py
+++ b/vllm/core/block/utils.py
@@ -1,15 +1,7 @@
"""Block manager utils."""
from vllm.sequence import SequenceGroup
-
-# Exception strings for non-implemented block manager enc/dec scenarios
-
-STR_NOT_IMPL_ENC_DEC_SWA = \
- "Sliding window attention for encoder/decoder models " + \
- "is not currently supported."
-
-STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
- "Prefix caching for encoder/decoder models " + \
- "is not currently supported."
+from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
+ STR_NOT_IMPL_ENC_DEC_SWA)
def _get_block_mgr_sliding_window_attr(block_mgr):
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 11d020be0c940..f60463107be44 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -392,6 +392,19 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq)
+ self._free_seq_group_cross_attn_blocks(aborted_group)
+
+ def _free_seq_group_cross_attn_blocks(
+ self,
+ seq_group: SequenceGroup,
+ ) -> None:
+ """
+ Free a sequence group from a cross-attention block table.
+ Has no effect on decoder-only models.
+ """
+ if seq_group.is_encoder_decoder():
+ self.block_manager.free_cross(seq_group)
+
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
@@ -963,6 +976,17 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}
+ if seq_group.is_encoder_decoder():
+ # Encoder associated with SequenceGroup
+ encoder_seq_data = seq_group.get_encoder_seq().data
+ # Block table for cross-attention
+ # Also managed at SequenceGroup level
+ cross_block_table = self.block_manager.get_cross_block_table(
+ seq_group)
+ else:
+ encoder_seq_data = None
+ cross_block_table = None
+
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
@@ -1001,6 +1025,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
+ encoder_seq_data=encoder_seq_data,
+ cross_block_table=cross_block_table,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
@@ -1032,6 +1058,8 @@ def free_finished_seq_groups(self) -> None:
remaining: Deque[SequenceGroup] = deque()
for seq_group in self.running:
if seq_group.is_finished():
+ # Free cross-attention block table, if it exists
+ self._free_seq_group_cross_attn_blocks(seq_group)
# Add the finished requests to the finished requests list.
# This list will be used to update the Mamba cache in the
# next step.
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index a4f30808d32e1..479dc95a8b667 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -11,7 +11,8 @@
gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless, is_full_nvlink
+from vllm.platforms import current_platform
+from vllm.utils import cuda_device_count_stateless
try:
assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
@@ -113,7 +114,10 @@ def __init__(self,
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
- full_nvlink = is_full_nvlink(physical_device_ids)
+ assert current_platform.is_cuda()
+ from vllm.platforms.cuda import CudaPlatform
+ cuda_platform: CudaPlatform = current_platform
+ full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 935a509cdb7ce..b6d2ea463940f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -69,7 +69,7 @@ class EngineArgs:
rope_theta: Optional[float] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
- enforce_eager: bool = False
+ enforce_eager: Optional[bool] = None
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 1ad9c1c026618..f4f05808b7417 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1,7 +1,7 @@
import asyncio
import time
from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
+from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
@@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task,
"actual cause.") from e
+STOP_ITERATION = Exception() # Sentinel
+
+
class AsyncStream:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
- that can be iterated over asynchronously."""
+ that can be iterated over asynchronously via an async generator."""
- def __init__(self, request_id: str) -> None:
+ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
self.request_id = request_id
+ self._cancel = cancel
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
@@ -77,22 +81,30 @@ def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
return
self._queue.put_nowait(item)
- def finish(self) -> None:
- self._queue.put_nowait(StopAsyncIteration())
- self._finished = True
+ def finish(self, cancelled: bool = False) -> None:
+ if not self._finished:
+ self._finished = True
+ self._queue.put_nowait(
+ asyncio.CancelledError if cancelled else STOP_ITERATION)
@property
def finished(self) -> bool:
return self._finished
- def __aiter__(self):
- return self
-
- async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
- result = await self._queue.get()
- if isinstance(result, Exception):
- raise result
- return result
+ async def generator(
+ self
+ ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+ try:
+ while not self._finished:
+ result = await self._queue.get()
+ if isinstance(result, Exception):
+ if result == STOP_ITERATION:
+ return
+ raise result
+ yield result
+ except GeneratorExit:
+ self._cancel(self.request_id)
+ raise asyncio.CancelledError from None
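+
+ # Typical consumption (see AsyncLLMEngine.generate/encode below):
+ #
+ #     async for output in stream.generator():
+ #         ...  # closing the generator cancels the request via _cancel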
class RequestTracker:
@@ -100,7 +112,7 @@ class RequestTracker:
def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
- self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+ self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = asyncio.Event()
@@ -131,15 +143,21 @@ def process_request_output(self,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
request_id = request_output.request_id
+ finished = request_output.finished
+ if finished:
+ stream = self._request_streams.pop(request_id, None)
+ else:
+ stream = self._request_streams.get(request_id)
# Guard against a KeyError which can occur if the request was aborted
# while the output was generated
- if (stream := self._request_streams.get(request_id)) is not None:
+ if stream is not None:
stream.put(request_output)
- if request_output.finished:
- if verbose:
- logger.info("Finished request %s.", request_id)
- self.abort_request(request_id)
+ if finished:
+ stream.finish()
+
+ if verbose and finished:
+ logger.info("Finished request %s.", request_id)
def process_exception(self,
request_id: str,
@@ -162,7 +180,8 @@ def add_request(self,
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
- stream = AsyncStream(request_id)
+ abort_request = partial(self.abort_request, verbose=verbose)
+ stream = AsyncStream(request_id, abort_request)
self._new_requests.put_nowait((stream, {
"request_id": request_id,
**engine_add_request_kwargs
@@ -175,36 +194,36 @@ def add_request(self,
return stream
- def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+ def abort_request(self,
+ request_id: str,
+ *,
+ cancelled: bool = False,
+ verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info("Aborted request %s.", request_id)
- self._finished_requests.put_nowait(request_id)
+ self._aborted_requests.put_nowait(request_id)
- if request_id not in self._request_streams or self._request_streams[
- request_id].finished:
- # The request has already finished or been aborted.
- return
-
- self._request_streams[request_id].finish()
+ stream = self._request_streams.pop(request_id, None)
+ if stream is not None:
+ stream.finish(cancelled=cancelled)
- def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
+ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[Dict] = []
finished_requests: Set[str] = set()
- while not self._finished_requests.empty():
- request_id = self._finished_requests.get_nowait()
+ while not self._aborted_requests.empty():
+ request_id = self._aborted_requests.get_nowait()
finished_requests.add(request_id)
- self._request_streams.pop(request_id, None)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
# The request has already been aborted.
- stream.finish()
+ stream.finish(cancelled=True)
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
@@ -556,8 +575,8 @@ async def engine_step(self, virtual_engine: int) -> bool:
Returns True if there are in-progress requests."""
- new_requests, finished_requests = (
- self._request_tracker.get_new_and_finished_requests())
+ new_requests, aborted_requests = (
+ self._request_tracker.get_new_and_aborted_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
@@ -576,8 +595,8 @@ async def engine_step(self, virtual_engine: int) -> bool:
verbose=self.log_requests,
)
- if finished_requests:
- await self._engine_abort(finished_requests)
+ if aborted_requests:
+ await self._engine_abort(aborted_requests)
if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore
@@ -666,6 +685,8 @@ async def run_engine_loop(self):
raise
await asyncio.sleep(0)
+ # This method does not need to be async, but kept that way
+ # for backwards compatibility.
async def add_request(
self,
request_id: str,
@@ -675,7 +696,7 @@ async def add_request(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
- ) -> AsyncStream:
+ ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
@@ -686,20 +707,17 @@ async def add_request(
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
- if arrival_time is None:
- arrival_time = time.time()
-
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
params=params,
- arrival_time=arrival_time,
+ arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
- return stream
+ return stream.generator()
async def generate(
self,
@@ -709,7 +727,7 @@ async def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
- ) -> AsyncIterator[RequestOutput]:
+ ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
@@ -774,7 +792,7 @@ async def generate(
>>> # Process and return the final output
>>> ...
"""
- async for output in self._process_request(
+ async for output in await self.add_request(
request_id,
inputs,
sampling_params,
@@ -791,7 +809,7 @@ async def encode(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
- ) -> AsyncIterator[EmbeddingRequestOutput]:
+ ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
@@ -852,7 +870,7 @@ async def encode(
>>> # Process and return the final output
>>> ...
"""
- async for output in self._process_request(
+ async for output in await self.add_request(
request_id,
inputs,
pooling_params,
@@ -861,37 +879,6 @@ async def encode(
):
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
- async def _process_request(
- self,
- request_id: str,
- inputs: PromptInputs,
- params: Union[SamplingParams, PoolingParams],
- *,
- lora_request: Optional[LoRARequest] = None,
- trace_headers: Optional[Mapping[str, str]] = None,
- prompt_adapter_request: Optional[PromptAdapterRequest] = None,
- ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
- """Common logic to process requests with SamplingParams or
- PoolingParams."""
- arrival_time = time.time()
-
- stream = await self.add_request(
- request_id,
- inputs,
- params,
- arrival_time=arrival_time,
- lora_request=lora_request,
- trace_headers=trace_headers,
- prompt_adapter_request=prompt_adapter_request,
- )
-
- try:
- async for request_output in stream:
- yield request_output
- except (Exception, asyncio.CancelledError) as e:
- self._abort(request_id)
- raise e
-
async def abort(self, request_id: str) -> None:
"""Abort a request.
@@ -920,6 +907,7 @@ def _abort(self, request_id: str) -> None:
request_id: The unique id of the request.
"""
self._request_tracker.abort_request(request_id,
+ cancelled=True,
verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig:
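Since add_request now returns the stream's generator, generate() and encode() simply delegate to it, and per-request aborts are wired through the AsyncStream cancellation callback. A minimal usage sketch (model name illustrative, error handling elided):

    import asyncio

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def main() -> None:
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))
        params = SamplingParams(max_tokens=16)

        final = None
        try:
            # generate() is an async generator; no manual stream handling.
            async for output in engine.generate("Hello, my name is", params,
                                                request_id="demo-1"):
                final = output
        except asyncio.CancelledError:
            # Cancelling this task aborts the request inside the engine.
            return
        print(final.outputs[0].text if final else "")

    asyncio.run(main())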
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 3747f93b16cd1..75c6d7e6c9b21 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -3,7 +3,7 @@
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
-from typing import Set, Type, TypeVar, Union
+from typing import Set, Tuple, Type, TypeVar, Union
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -22,7 +22,8 @@
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
+from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
+ get_prompt_type)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@@ -42,7 +43,8 @@
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
-from vllm.utils import Counter
+from vllm.utils import (Counter, is_embedding_model_config,
+ is_encoder_decoder_model_config)
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -502,8 +504,19 @@ def _verify_args(self) -> None:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
- def _get_eos_token_id(
- self, lora_request: Optional[LoRARequest]) -> Optional[int]:
+ def _get_bos_token_id(self,
+ lora_request: Optional[LoRARequest] = None
+ ) -> Optional[int]:
+ if self.tokenizer is None:
+ logger.warning("Using None for BOS token id because tokenizer "
+ "is not initialized")
+ return None
+
+ return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
+
+ def _get_eos_token_id(self,
+ lora_request: Optional[LoRARequest] = None
+ ) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
@@ -511,6 +524,32 @@ def _get_eos_token_id(
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
+ def _get_decoder_start_token_id(self, ) -> Optional[int]:
+ '''
+ Obtain the decoder start token id employed by an encoder/decoder
+ model. Returns None for non-encoder/decoder models or if the
+ model config is unavailable.
+ '''
+
+ if not self.is_encoder_decoder_model():
+ logger.warning("Using None for decoder start token id because "
+ "this is not an encoder/decoder model.")
+ return None
+
+ if (self.model_config is None or self.model_config.hf_config is None):
+ logger.warning("Using None for decoder start token id because "
+ "model config is not available.")
+ return None
+
+ dec_start_token_id = getattr(self.model_config.hf_config,
+ 'decoder_start_token_id', None)
+ if dec_start_token_id is None:
+ logger.warning("Falling back on for decoder start token id "
+ "because decoder start token id is not available.")
+ dec_start_token_id = self._get_bos_token_id()
+
+ return dec_start_token_id
+
def _add_processed_request(
self,
request_id: str,
@@ -529,6 +568,16 @@ def _add_processed_request(
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
+ encoder_seq = None
+ if 'encoder_prompt_token_ids' in processed_inputs:
+ encoder_seq = Sequence(seq_id,
+ processed_inputs,
+ block_size,
+ eos_token_id,
+ lora_request,
+ prompt_adapter_request,
+ from_decoder_prompt=False)
+
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
@@ -538,7 +587,8 @@ def _add_processed_request(
arrival_time=arrival_time,
lora_request=lora_request,
trace_headers=trace_headers,
- prompt_adapter_request=prompt_adapter_request)
+ prompt_adapter_request=prompt_adapter_request,
+ encoder_seq=encoder_seq)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
@@ -546,7 +596,8 @@ def _add_processed_request(
params,
arrival_time=arrival_time,
lora_request=lora_request,
- prompt_adapter_request=prompt_adapter_request)
+ prompt_adapter_request=prompt_adapter_request,
+ encoder_seq=encoder_seq)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
@@ -562,36 +613,362 @@ def _add_processed_request(
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
- def process_model_inputs(
+ _LLMInputComponentsType = Tuple[str, List[int], ]
+
+ def _prepare_decoder_input_ids_for_generation(
+ self,
+ decoder_input_ids: Optional[List[int]] = None,
+ ) -> List[int]:
+ """
+ Prepares `decoder_input_ids` for generation with encoder-decoder models.
+
+ Based on
+
+ https://github.com/huggingface/transformers/blob/
+ 4037a2b5b1278736e566aec12e169100275545ea/
+ src/transformers/generation/utils.py
+
+ specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
+
+ Arguments:
+
+ * decoder_input_ids: input token ids to preprocess
+
+ Returns:
+
+ * Processed token list
+ """
+
+ decoder_start_token_id: Optional[int] = (
+ self._get_decoder_start_token_id())
+ assert decoder_start_token_id is not None
+
+ if decoder_input_ids is None:
+ # no decoder prompt input ->
+ # use decoder_start_token_id as decoder_input_ids
+ (decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
+
+ if (len(decoder_input_ids) == 0
+ or decoder_input_ids[0] != decoder_start_token_id):
+ decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
+
+ return decoder_input_ids
+
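A worked restatement of the rule above: the decoder start token is prepended unless it is already first, and a missing decoder prompt falls back to the default BOS-only prompt (token ids below are illustrative):

    from typing import List, Optional

    def prepare_decoder_input_ids(decoder_input_ids: Optional[List[int]],
                                  decoder_start_token_id: int,
                                  bos_token_id: int) -> List[int]:
        # Standalone restatement of the logic, for illustration only.
        if decoder_input_ids is None:
            decoder_input_ids = [bos_token_id]  # default decoder prompt
        if not decoder_input_ids or decoder_input_ids[0] != decoder_start_token_id:
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
        return decoder_input_ids

    # e.g. decoder_start=2, bos=0 (illustrative values)
    assert prepare_decoder_input_ids(None, 2, 0) == [2, 0]
    assert prepare_decoder_input_ids([5, 6], 2, 0) == [2, 5, 6]
    assert prepare_decoder_input_ids([2, 5], 2, 0) == [2, 5]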
+ def _tokenize_prompt(
+ self,
+ prompt: str,
+ request_id: Optional[str] = None,
+ lora_request: Optional[str] = None,
+ ) -> List[int]:
+ '''
+ Wrapper around application of the model's
+ tokenizer.
+
+ Arguments:
+
+ * prompt
+ * request_id
+ * lora_request
+
+ Returns:
+
+ * prompt token ids
+ '''
+
+ tokenizer = self.get_tokenizer_group("prompts must be None if "
+ "skip_tokenizer_init is True")
+
+ prompt_token_ids = tokenizer.encode(request_id=request_id,
+ prompt=prompt,
+ lora_request=lora_request)
+
+ return prompt_token_ids
+
+ def _extract_single_prompt_for_enc_dec_input(
+ self,
+ inputs: Optional[PromptInputs],
+ request_id: Optional[str] = None,
+ ptype: Optional[str] = None,
+ is_encoder_prompt: bool = False,
+ ) -> Tuple[Optional[str], List[int]]:
+ '''
+ Only for encoder/decoder models:
+ Extract prompt & prompt_token_ids from any single
+ encoder or decoder input prompt. For encoder input prompts
+ in particular, also extract multi-modal data.
+
+ This function handles the following scenarios:
+ 1. The user supplied a singleton encoder prompt
+ & the prompt/prompt-token-ids must be extracted.
+ 2. The user supplied an explicit encoder/decoder
+ prompt & the prompt/prompt-token-ids must be
+ extracted from either the encoder or the decoder prompt.
+
+ For decoder prompts in particular (scenario 2), special
+ processing is applied to the returned decoder token ids.
+
+ Arguments:
+
+ * request_id
+ * ptype: str representation of the input prompt type.
+ If `ptype` is `None`, assume that the prompt
+ type is unknown and must be inferred. This is the
+ case for ExplicitEncoderDecoder sub-prompts.
+ * inputs: single encoder or decoder input prompt
+ * is_encoder_prompt: True if encoder input prompt.
+ If False, decoder prompt tokens
+ are preprocessed.
+
+ Returns:
+
+ * prompt
+ * prompt_token_ids
+ '''
+ prompt_token_ids = None
+ ptype = (get_prompt_type(inputs) if ptype is None else ptype)
+
+ if inputs is None:
+ prompt = None
+ elif ptype == 'str':
+ prompt = inputs
+ prompt_token_ids = self._tokenize_prompt(
+ prompt,
+ request_id=request_id,
+ )
+ elif ptype == 'TokensPrompt':
+ prompt = None
+ prompt_token_ids = inputs['prompt_token_ids']
+ else:
+ prompt = inputs['prompt']
+ prompt_token_ids = self._tokenize_prompt(
+ prompt,
+ request_id=request_id,
+ )
+
+ if not is_encoder_prompt:
+ # Apply special pre-processing to
+ # decoder prompts
+ prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
+ prompt_token_ids, ))
+
+ assert prompt_token_ids is not None
+
+ return (
+ prompt,
+ prompt_token_ids,
+ )
+
+ def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
+ '''
+ Specifically for encoder/decoder models:
+ generate a default decoder prompt for when
+ the user specifies only the encoder prompt.
+
+ Encoder/decoder models utilize the decoder
+ prompt in different ways; as new models are
+ added, it is intended that this function
+ will be extended to produce differing
+ default decoder prompts, depending on the
+ model variety.
+
+ Absent a special case, the default behavior
+ of this method is to mirror the behavior of
+ the HuggingFace (HF) GenerationMixin for a None
+ decoder prompt, which is to employ a logit processor
+ setting to force the first decoded token to be <BOS>.
+ Here, this behavior is approximated by having the
+ "default" decoder prompt be <BOS>.
+
+ However, it is possible that in the future
+ other models may have different or more
+ complex logic for the default decoder prompt.
+ This motivates having a special helper method
+ for default decoder prompts.
+
+ Returns:
+
+ * prompt_token_ids
+ '''
+
+ bos_token_id = self._get_bos_token_id()
+ assert bos_token_id is not None
+ prompt_token_ids: List[int] = [bos_token_id]
+ return prompt_token_ids
+
+ def _process_encoder_decoder_prompt(
+ self,
+ inputs: PromptInputs,
+ request_id: Optional[str] = None,
+ ) -> LLMInputs:
+ '''
+ For encoder/decoder models only:
+ Process an input prompt
+ into an `LLMInputs` instance.
+
+ There are two types of input prompts:
+ singleton prompts which carry only the
+ encoder prompt, and explicit encoder/decoder
+ prompts which carry both the encoder and the
+ decoder prompts as member variables.
+
+ This function handles the following scenarios:
+ * Singleton encoder prompt: extract encoder prompt
+ token ids & infer default decoder prompt token ids
+ * Explicit encoder/decoder prompt: extract encoder
+ and decoder prompt token ids
+
+ Note that for Explicit encoder/decoder prompts,
+ each sub-prompt (encoder or decoder prompt) can
+ have any possible singleton type; thus this
+ method relies on helper functions to obtain
+ token ids for the sub-prompts.
+
+ Arguments:
+
+ * inputs: an input prompt
+ * request_id
+
+ Returns:
+
+ * `LLMInputs` instance
+ '''
+
+ ptype = get_prompt_type(inputs)
+
+ # Obtain encoder and decoder prompt tokens. Note
+ # that, no matter what, the decoder
+ # prompt type is unknown.
+ if ptype == "ExplicitEncoderDecoder":
+ # If input is explicit encoder/decoder prompt,
+ # then it remains to be determined what type
+ # of encoder prompt we have
+ extracted_encoder_prompt = inputs.get('encoder_prompt')
+ encoder_ptype = None
+ # Extract decoder prompt from explicit
+ # encoder/decoder prompt
+ extracted_decoder_prompt = inputs.get('decoder_prompt')
+ else:
+ # If input is singleton encoder prompt, then
+ # we know the encoder prompt type
+ extracted_encoder_prompt = inputs
+ encoder_ptype = ptype
+ # Decoder prompt is always unknown if
+ # encoder/decoder prompt is not explicit
+ extracted_decoder_prompt = None
+
+ # Invoke helper function to obtain encoder
+ # prompt and prompt token ids, either from
+ # singleton encoder prompt or from the
+ # encoder sub-prompt of an explicit
+ # encoder/decoder prompt.
+ (
+ encoder_prompt,
+ encoder_prompt_token_ids,
+ ) = self._extract_single_prompt_for_enc_dec_input(
+ extracted_encoder_prompt,
+ request_id=request_id,
+ ptype=encoder_ptype,
+ is_encoder_prompt=True,
+ )
+
+ # Invoke helper method to obtain
+ # decoder prompt and prompt token ids.
+ #
+ # The helper method will detect the decoder
+ # prompt type.
+ #
+ # Helper method will also apply special
+ # preprocessing unique to decoder prompts.
+ (
+ decoder_prompt,
+ decoder_prompt_token_ids,
+ ) = self._extract_single_prompt_for_enc_dec_input(
+ extracted_decoder_prompt,
+ request_id=request_id,
+ ptype=None,
+ is_encoder_prompt=False,
+ )
+
+ return LLMInputs(
+ prompt_token_ids=decoder_prompt_token_ids,
+ prompt=decoder_prompt,
+ encoder_prompt_token_ids=encoder_prompt_token_ids,
+ encoder_prompt=encoder_prompt,
+ )
+
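The two input shapes accepted by this path, shown side by side; both end up as a single LLMInputs carrying encoder and decoder token ids. A sketch against the public API (model name illustrative):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/bart-large-cnn")  # illustrative encoder/decoder model
    params = SamplingParams(max_tokens=32)

    # 1) Singleton prompt: treated as the encoder prompt; the decoder prompt
    #    is synthesized by the engine (decoder start token / BOS).
    out1 = llm.generate("vLLM is a fast inference engine.", params)

    # 2) Explicit encoder/decoder prompt: both sub-prompts supplied, each in
    #    any singleton schema (str, TextPrompt or TokensPrompt).
    out2 = llm.generate(
        {
            "encoder_prompt": "vLLM is a fast inference engine.",
            "decoder_prompt": {"prompt_token_ids": [2]},
        },
        params,
    )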
+ def _process_decoder_only_prompt(
self,
- request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
+ request_id: Optional[str] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
+ '''
+ For decoder-only models:
+ Process an input prompt
+ into an `LLMInputs` instance.
+
+ Arguments:
+
+ * inputs: input prompt
+ * lora_request
+ * request_id
+ * prompt_adapter_request
+
+ Returns:
+
+ * `LLMInputs` instance
+ '''
+
if isinstance(inputs, str):
inputs = {"prompt": inputs}
+ prompt = inputs.get("prompt")
if "prompt_token_ids" not in inputs:
- tokenizer = self.get_tokenizer_group("prompts must be None if "
- "skip_tokenizer_init is True")
-
- prompt_token_ids = tokenizer.encode(request_id=request_id,
- prompt=inputs["prompt"],
- lora_request=lora_request)
+ prompt_token_ids = self._tokenize_prompt(
+ prompt,
+ request_id=request_id,
+ lora_request=lora_request,
+ )
else:
prompt_token_ids = inputs["prompt_token_ids"]
if prompt_adapter_request:
- prompt_token_ids = \
- [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
- + prompt_token_ids
+ prompt_token_ids = (
+ [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ + prompt_token_ids)
+
+ return LLMInputs(prompt_token_ids=prompt_token_ids,
+ prompt=prompt,
+ multi_modal_data=inputs.get("multi_modal_data"))
+
+ def process_model_inputs(
+ self,
+ request_id: str,
+ inputs: PromptInputs,
+ lora_request: Optional[LoRARequest] = None,
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+ ) -> LLMInputs:
- llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
- prompt=inputs.get("prompt"),
- multi_modal_data=inputs.get("multi_modal_data"))
+ if self.is_encoder_decoder_model():
+ # Encoder-decoder model requires special mapping of
+ # input prompts to encoder & decoder
- return self.input_processor(llm_inputs)
+ model_inputs = self._process_encoder_decoder_prompt(
+ inputs,
+ request_id=request_id,
+ )
+ else:
+ # Decoder-only operation
+ model_inputs = self._process_decoder_only_prompt(
+ inputs,
+ request_id=request_id,
+ lora_request=lora_request,
+ prompt_adapter_request=prompt_adapter_request,
+ )
+
+ return self.input_processor(model_inputs)
def add_request(
self,
@@ -676,6 +1053,7 @@ def _create_sequence_group_with_sampling(
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+ encoder_seq: Optional[Sequence] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@@ -701,7 +1079,8 @@ def _create_sequence_group_with_sampling(
sampling_params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
- prompt_adapter_request=prompt_adapter_request)
+ prompt_adapter_request=prompt_adapter_request,
+ encoder_seq=encoder_seq)
return seq_group
@@ -713,6 +1092,7 @@ def _create_sequence_group_with_pooling(
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
+ encoder_seq: Optional[Sequence] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
@@ -724,7 +1104,8 @@ def _create_sequence_group_with_pooling(
arrival_time=arrival_time,
lora_request=lora_request,
pooling_params=pooling_params,
- prompt_adapter_request=prompt_adapter_request)
+ prompt_adapter_request=prompt_adapter_request,
+ encoder_seq=encoder_seq)
return seq_group
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
@@ -1214,3 +1595,9 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
+
+ def is_encoder_decoder_model(self):
+ return is_encoder_decoder_model_config(self.model_config)
+
+ def is_embedding_model(self):
+ return is_embedding_model_config(self.model_config)
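The two helpers delegated to here come from vllm.utils and are not shown in this diff; a plausible approximation of what they check (an assumption, not the actual implementation):

    # Hypothetical approximations of the vllm.utils helpers used above.
    def is_encoder_decoder_model_config(model_config) -> bool:
        # Assumed: the HF config advertises encoder/decoder architectures
        # via `is_encoder_decoder`.
        return model_config is not None and getattr(
            model_config.hf_config, "is_encoder_decoder", False)

    def is_embedding_model_config(model_config) -> bool:
        # Assumed: embedding mode is a flag on the vLLM ModelConfig.
        return model_config is not None and getattr(
            model_config, "embedding_mode", False)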
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index fc94ef6662e0a..e05c01fa8d6c3 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -1,4 +1,4 @@
-from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
+from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
runtime_checkable)
from transformers import PreTrainedTokenizer
@@ -30,7 +30,7 @@ def is_stopped(self) -> bool:
def errored(self) -> bool:
...
- async def generate(
+ def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
@@ -38,17 +38,17 @@ async def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
- ) -> AsyncIterator[RequestOutput]:
+ ) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request"""
- async def encode(
+ def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
- ) -> AsyncIterator[EmbeddingRequestOutput]:
+ ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
async def abort(self, request_id: str) -> None:
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index d5fac5557104d..f6e8a417b648c 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -20,7 +20,8 @@
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, random_uuid
+from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
+ random_uuid)
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
@@ -53,6 +54,8 @@ async def generate(request: Request) -> Response:
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
+ results_generator = iterate_with_cancellation(
+ results_generator, is_cancelled=request.is_disconnected)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -69,12 +72,11 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
# Non-streaming case
final_output = None
- async for request_output in results_generator:
- if await request.is_disconnected():
- # Abort the request if the client disconnects.
- await engine.abort(request_id)
- return Response(status_code=499)
- final_output = request_output
+ try:
+ async for request_output in results_generator:
+ final_output = request_output
+ except asyncio.CancelledError:
+ return Response(status_code=499)
assert final_output is not None
prompt = final_output.prompt
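iterate_with_cancellation is imported from vllm.utils and not shown in this diff; a minimal sketch of such a helper, assuming it polls is_cancelled between outputs and raises CancelledError once the client has disconnected:

    import asyncio
    from typing import AsyncGenerator, Awaitable, Callable, TypeVar

    T = TypeVar("T")

    async def iterate_with_cancellation_sketch(
        iterator: AsyncGenerator[T, None],
        is_cancelled: Callable[[], Awaitable[bool]],
    ) -> AsyncGenerator[T, None]:
        # Hypothetical restatement for illustration; the real helper lives in
        # vllm.utils and may differ (e.g. by racing the two awaits).
        async for item in iterator:
            if await is_cancelled():
                await iterator.aclose()
                raise asyncio.CancelledError
            yield item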
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 072450a6146ee..12634c3261856 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1,8 +1,9 @@
import codecs
from dataclasses import dataclass
from functools import lru_cache
-from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
- final)
+from pathlib import Path
+from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
+ cast, final)
# yapf conflicts with isort for this block
# yapf: disable
@@ -22,6 +23,7 @@
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
+from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@@ -69,13 +71,17 @@ class ChatMessageParseResult:
mm_futures: List[Awaitable[MultiModalDataDict]]
-def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+def load_chat_template(
+ chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
+ if isinstance(chat_template, Path):
+ raise
+
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
@@ -208,3 +214,28 @@ def parse_chat_messages(
mm_futures.extend(parse_result.mm_futures)
return conversation, mm_futures
+
+
+def apply_chat_template(
+ tokenizer: AnyTokenizer,
+ conversation: List[ConversationMessage],
+ chat_template: Optional[str],
+ *,
+ tokenize: bool = False, # Different from HF's default
+ **kwargs: Any,
+) -> str:
+ if chat_template is None and tokenizer.chat_template is None:
+ raise ValueError(
+ "As of transformers v4.44, default chat template is no longer "
+ "allowed, so you must provide a chat template if the tokenizer "
+ "does not define one.")
+
+ prompt = tokenizer.apply_chat_template(
+ conversation=conversation,
+ chat_template=chat_template,
+ tokenize=tokenize,
+ **kwargs,
+ )
+ assert isinstance(prompt, str)
+
+ return prompt
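Example call against the wrapper added above (tokenizer name illustrative; per the transformers v4.44 note, the tokenizer must define a chat template if none is passed):

    from transformers import AutoTokenizer

    from vllm.entrypoints.chat_utils import apply_chat_template

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    conversation = [{"role": "user", "content": "What is vLLM?"}]

    prompt = apply_chat_template(
        tokenizer,
        conversation=conversation,
        chat_template=None,  # fall back to the tokenizer's own template
        add_generation_prompt=True,
    )
    print(prompt)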
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 262cba79e5712..eaa1572094936 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -121,12 +121,21 @@ def __init__(
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
cpu_offload_gb: float = 0,
- enforce_eager: bool = False,
+ enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
+ '''
+ LLM constructor.
+
+ Note: if enforce_eager is unset (enforce_eager is None)
+ it defaults to False for decoder-only models and True
+ for encoder/decoder models, since encoder/decoder models
+ do not currently support CUDAGraph.
+ '''
+
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
removed_vision_keys = ("image_token_id", "image_feature_size",
@@ -297,8 +306,8 @@ def generate(
"""
if self.llm_engine.model_config.embedding_mode:
raise ValueError(
- "LLM.generate() is only supported for generation models "
- "(XForCausalLM).")
+ "LLM.generate() is only supported for (conditional) generation "
+ "models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
@@ -631,3 +640,9 @@ def _run_engine(
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
+
+ def _is_encoder_decoder_model(self):
+ return self.llm_engine.is_encoder_decoder_model()
+
+ def _is_embedding_model(self):
+ return self.llm_engine.is_embedding_model()
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 0c57ccb11d2cf..1a0addfedc55f 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -43,7 +43,7 @@
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, get_open_port
+from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
@@ -106,19 +106,32 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
+ # Select random path for IPC.
+ rpc_path = get_open_zmq_ipc_path()
+ logger.info("Multiprocessing frontend to use %s for RPC Path.",
+ rpc_path)
+
# Start RPCServer in separate process (holds the AsyncLLMEngine).
- port = get_open_port(envs.VLLM_RPC_PORT)
rpc_server_process = Process(target=run_rpc_server,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
- port))
+ rpc_path))
rpc_server_process.start()
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
- async_engine_client = AsyncEngineRPCClient(port)
- await async_engine_client.setup()
+ async_engine_client = AsyncEngineRPCClient(rpc_path)
try:
+ while True:
+ try:
+ await async_engine_client.setup()
+ break
+ except TimeoutError as e:
+ if not rpc_server_process.is_alive():
+ raise RuntimeError(
+ "The server process died before "
+ "responding to the readiness probe") from e
+
yield async_engine_client
finally:
# Ensure rpc server process was terminated
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 76318a1271229..70467bd879690 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description=(
"A Jinja template to use for this conversion. "
- "If this is not passed, the model's default chat template will be "
- "used instead."),
+ "As of transformers v4.44, default chat template is no longer "
+ "allowed, so you must provide a chat template if the tokenizer "
+ "does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index b88c40203e330..64a20b33d8f3e 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -1,5 +1,5 @@
from contextlib import contextmanager
-from typing import Any, AsyncIterator, Mapping, Optional
+from typing import Any, AsyncGenerator, Mapping, Optional
import cloudpickle
import zmq
@@ -18,12 +18,15 @@
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+# Time to wait before checking if the server process is alive.
+SERVER_START_TIMEOUT_MS = 1000
+
class AsyncEngineRPCClient:
- def __init__(self, port: int):
+ def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context()
- self.path = f"tcp://localhost:{port}"
+ self.rpc_path = rpc_path
async def setup(self):
"""Setup the client before it starts sending server requests."""
@@ -59,10 +62,19 @@ def socket(self):
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
try:
- socket.connect(self.path)
+ socket.connect(self.rpc_path)
yield socket
finally:
- socket.close()
+ # linger == 0 means discard unsent messages
+ # when the socket is closed. This is necessary
+ # because otherwise self.context.destroy() will
+ # wait for 30 seconds until unsent messages are
+ # received, which is impossible if the server
+ # crashed. In the absence of a server crash we
+ # always expect a response before closing the
+ # socket anyway.
+ # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
+ socket.close(linger=0)
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
@@ -86,14 +98,19 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
return data
- async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
- error_message: str):
+ async def _send_one_way_rpc_request(self,
+ request: RPC_REQUEST_TYPE,
+ error_message: str,
+ timeout: Optional[int] = None):
"""Send one-way RPC request to trigger an action."""
with self.socket() as socket:
# Ping RPC Server with request.
await socket.send(cloudpickle.dumps(request))
# Await acknowledgement from RPCServer.
+ if timeout is not None and await socket.poll(timeout=timeout) == 0:
+ raise TimeoutError(f"server didn't reply within {timeout} ms")
+
response = cloudpickle.loads(await socket.recv())
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
@@ -118,7 +135,8 @@ async def wait_for_server(self):
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
- error_message="Unable to start RPC Server.")
+ error_message="Unable to start RPC Server.",
+ timeout=SERVER_START_TIMEOUT_MS)
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
@@ -203,45 +221,47 @@ async def generate(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
- ) -> AsyncIterator[RequestOutput]:
+ ) -> AsyncGenerator[RequestOutput, None]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
- with self.socket() as socket:
-
- # Send RPCGenerateRequest to the RPCServer.
- await socket.send_multipart([
- cloudpickle.dumps(
- RPCGenerateRequest(
- inputs=inputs,
- sampling_params=sampling_params,
- request_id=request_id,
- lora_request=lora_request,
- trace_headers=trace_headers,
- prompt_adapter_request=prompt_adapter_request))
- ])
-
- # Stream back the results from the RPC Server.
- while True:
- message = await socket.recv()
- request_output = cloudpickle.loads(message)
-
- if isinstance(request_output, Exception):
- # On exception, check if the server is still healthy.
- # Use this to set the sync `is_running` and `errored`
- # properties.
- try:
- await self.check_health()
- except Exception:
- self._errored = True
- # NB: do before raising here so that the flag is set
- # by the time the caller receives this exception
- raise request_output
-
- if request_output.finished:
- break
- yield request_output
-
- yield request_output
+ finished = False
+ try:
+ with self.socket() as socket:
+
+ # Send RPCGenerateRequest to the RPCServer.
+ await socket.send_multipart([
+ cloudpickle.dumps(
+ RPCGenerateRequest(
+ inputs=inputs,
+ sampling_params=sampling_params,
+ request_id=request_id,
+ lora_request=lora_request,
+ trace_headers=trace_headers,
+ prompt_adapter_request=prompt_adapter_request))
+ ])
+
+ # Stream back the results from the RPC Server.
+ while not finished:
+ message = await socket.recv()
+ request_output = cloudpickle.loads(message)
+
+ if isinstance(request_output, Exception):
+ # On exception, check if the server is still healthy.
+ # Use this to set the sync `is_running` and `errored`
+ # properties.
+ try:
+ await self.check_health()
+ except Exception:
+ self._errored = True
+ # NB: do before raising here so that the flag is set
+ # by the time the caller receives this exception
+ raise request_output
+
+ finished = request_output.finished
+ yield request_output
+ finally:
+ if not finished:
+ await self.abort(request_id)
async def check_health(self) -> None:
"""Raise if unhealthy"""
@@ -265,6 +285,6 @@ async def check_health(self) -> None:
"f{health_message}")
async def encode(self, *args,
- **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
+ **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index 60bb23b9bde05..617c9b7070e2c 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -20,7 +20,7 @@
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
- usage_context: UsageContext, port: int):
+ usage_context: UsageContext, rpc_path: str):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context)
@@ -30,9 +30,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs,
# Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER)
- # Note numeric form of localhost should be used for zmq bind(),
- # see https://stackoverflow.com/a/8958414
- self.socket.bind(f"tcp://127.0.0.1:{port}")
+ self.socket.bind(rpc_path)
def cleanup(self):
"""Cleanup all resources."""
@@ -213,6 +211,6 @@ def signal_handler() -> None:
def run_rpc_server(async_engine_args: AsyncEngineArgs,
- usage_context: UsageContext, port: int):
- server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
+ usage_context: UsageContext, rpc_path: str):
+ server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
asyncio.run(run_server(server))
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index d215754993e82..2167b967b14b5 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,3 +1,4 @@
+import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
@@ -9,6 +10,7 @@
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
+ apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
@@ -29,7 +31,7 @@
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
-from vllm.utils import random_uuid
+from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__)
@@ -98,16 +100,15 @@ async def create_chat_completion(
tool.model_dump() for tool in request.tools
]
- prompt = tokenizer.apply_chat_template(
+ prompt = apply_chat_template(
+ tokenizer,
conversation=conversation,
- tokenize=False,
+ chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
- chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
- assert isinstance(prompt, str)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
@@ -176,18 +177,20 @@ async def create_chat_completion(
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
+ if raw_request:
+ result_generator = iterate_with_cancellation(
+ result_generator, raw_request.is_disconnected)
+
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer)
- else:
- try:
- return await self.chat_completion_full_generator(
- request, raw_request, result_generator, request_id,
- conversation, tokenizer)
- except ValueError as e:
- # TODO: Use a vllm-specific Validation Error
- return self.create_error_response(str(e))
+ try:
+ return await self.chat_completion_full_generator(
+ request, result_generator, request_id, conversation, tokenizer)
+ except ValueError as e:
+ # TODO: Use a vllm-specific Validation Error
+ return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
@@ -422,7 +425,6 @@ async def chat_completion_stream_generator(
async def chat_completion_full_generator(
self,
request: ChatCompletionRequest,
- raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
@@ -433,12 +435,12 @@ async def chat_completion_full_generator(
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
- async for res in result_generator:
- if raw_request is not None and await raw_request.is_disconnected():
- # Abort the request if the client disconnects.
- await self.async_engine_client.abort(request_id)
- return self.create_error_response("Client disconnected")
- final_res = res
+ try:
+ async for res in result_generator:
+ final_res = res
+ except asyncio.CancelledError:
+ return self.create_error_response("Client disconnected")
+
assert final_res is not None
choices: List[ChatCompletionResponseChoice] = []
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index edc83d83fbba7..f4c91ce046847 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -1,3 +1,4 @@
+import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
@@ -84,7 +85,7 @@ async def create_completion(self, request: CompletionRequest,
created_time = int(time.time())
# Schedule the request and get the result generator.
- generators: List[AsyncIterator[RequestOutput]] = []
+ generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
(
lora_request,
@@ -144,7 +145,8 @@ async def create_completion(self, request: CompletionRequest,
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
- int, RequestOutput]] = merge_async_iterators(*generators)
+ int, RequestOutput]] = merge_async_iterators(
+ *generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
@@ -156,7 +158,6 @@ async def create_completion(self, request: CompletionRequest,
# Streaming response
if stream:
return self.completion_stream_generator(request,
- raw_request,
result_generator,
request_id,
created_time,
@@ -168,10 +169,6 @@ async def create_completion(self, request: CompletionRequest,
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try:
async for i, res in result_generator:
- if await raw_request.is_disconnected():
- # Abort the request if the client disconnects.
- await self.async_engine_client.abort(f"{request_id}-{i}")
- return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
@@ -194,6 +191,8 @@ async def create_completion(self, request: CompletionRequest,
model_name,
tokenizer,
)
+ except asyncio.CancelledError:
+ return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -214,7 +213,6 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
async def completion_stream_generator(
self,
request: CompletionRequest,
- raw_request: Request,
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
request_id: str,
created_time: int,
@@ -230,12 +228,6 @@ async def completion_stream_generator(
try:
async for prompt_idx, res in result_generator:
- # Abort the request if the client disconnects.
- if await raw_request.is_disconnected():
- await self.async_engine_client.abort(
- f"{request_id}-{prompt_idx}")
- raise StopAsyncIteration()
-
for output in res.outputs:
i = output.index + prompt_idx * num_choices
# TODO(simon): optimize the performance by avoiding full
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index e61c82f9a8a6c..28dbaecfd6819 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -1,6 +1,7 @@
+import asyncio
import base64
import time
-from typing import AsyncIterator, List, Optional, Tuple, cast
+from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast
import numpy as np
from fastapi import Request
@@ -92,7 +93,7 @@ async def create_embedding(self, request: EmbeddingRequest,
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
- generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
+ generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
(
lora_request,
@@ -138,17 +139,14 @@ async def create_embedding(self, request: EmbeddingRequest,
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
- int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
+ int, EmbeddingRequestOutput]] = merge_async_iterators(
+ *generators, is_cancelled=raw_request.is_disconnected)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
- if await raw_request.is_disconnected():
- # Abort the request if the client disconnects.
- await self.async_engine_client.abort(f"{request_id}-{i}")
- return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for final_res in final_res_batch:
@@ -160,6 +158,8 @@ async def create_embedding(self, request: EmbeddingRequest,
response = request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
+ except asyncio.CancelledError:
+ return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 5b6b979b9b9e7..1aeabb7a7d729 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -2,7 +2,9 @@
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
+from vllm.entrypoints.chat_utils import (apply_chat_template,
+ load_chat_template,
+ parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@@ -70,12 +72,12 @@ async def create_tokenize(
logger.warning(
"Multi-modal inputs are ignored during tokenization")
- prompt = tokenizer.apply_chat_template(
- add_generation_prompt=request.add_generation_prompt,
+ prompt = apply_chat_template(
+ tokenizer,
conversation=conversation,
- tokenize=False,
- chat_template=self.chat_template)
- assert isinstance(prompt, str)
+ chat_template=self.chat_template,
+ add_generation_prompt=request.add_generation_prompt,
+ )
else:
prompt = request.prompt
diff --git a/vllm/envs.py b/vllm/envs.py
index 61bccd7926e26..26d0c33707fea 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -1,10 +1,11 @@
import os
+import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None
- VLLM_RPC_PORT: int = 5570
+ VLLM_RPC_BASE_PATH: str = tempfile.gettempdir()
VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None
@@ -52,6 +53,7 @@
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
+ VLLM_TEST_FORCE_FP8_MARLIN: bool = False
def get_default_cache_root():
@@ -143,10 +145,10 @@ def get_default_config_root():
lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None,
- # used when the frontend api server is running in multi-processing mode,
- # to communicate with the backend engine process over ZMQ.
- 'VLLM_RPC_PORT':
- lambda: int(os.getenv('VLLM_RPC_PORT', '5570')),
+ # path used for ipc when the frontend api server is running in
+ # multi-processing mode to communicate with the backend engine process.
+ 'VLLM_RPC_BASE_PATH':
+ lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()),
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
@@ -347,6 +349,13 @@ def get_default_config_root():
lambda:
(os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
("1", "true")),
+
+ # If set, forces FP8 Marlin to be used for FP8 quantization regardless
+ # of the hardware support for FP8 compute.
+ "VLLM_TEST_FORCE_FP8_MARLIN":
+ lambda:
+ (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
+ ("1", "true")),
}
# end-env-vars-definition
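get_open_zmq_ipc_path is imported from vllm.utils by the API server above but not shown in this diff; presumably it combines VLLM_RPC_BASE_PATH with a unique suffix, roughly as follows (an assumption, not the actual implementation):

    import uuid

    import vllm.envs as envs

    def get_open_zmq_ipc_path_sketch() -> str:
        # Hypothetical: one unique ipc socket file per frontend under the
        # configurable base directory (defaults to the system temp dir).
        return f"ipc://{envs.VLLM_RPC_BASE_PATH}/{uuid.uuid4()}"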
diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py
index b13d9acf93d3b..e22b88f2fc38a 100644
--- a/vllm/inputs/__init__.py
+++ b/vllm/inputs/__init__.py
@@ -1,5 +1,7 @@
-from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
- TextPrompt, TokensPrompt, parse_and_batch_prompt)
+from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
+ ParsedTokens, PromptInputs, SingletonPromptInputs,
+ TextPrompt, TokensPrompt, get_prompt_type,
+ is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
@@ -12,7 +14,18 @@
"""
__all__ = [
- "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
- "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
- "InputContext", "InputRegistry"
+ "ParsedText",
+ "ParsedTokens",
+ "parse_and_batch_prompt",
+ "TextPrompt",
+ "TokensPrompt",
+ "PromptInputs",
+ "LLMInputs",
+ "INPUT_REGISTRY",
+ "InputContext",
+ "InputRegistry",
+ "get_prompt_type",
+ "is_valid_encoder_decoder_llm_inputs",
+ "ExplicitEncoderDecoderPrompt",
+ "SingletonPromptInputs",
]
diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py
index 4443e6c70fe5b..86c2901dc4c80 100644
--- a/vllm/inputs/data.py
+++ b/vllm/inputs/data.py
@@ -92,15 +92,114 @@ class TokensPrompt(TypedDict):
"""
-PromptInputs = Union[str, TextPrompt, TokensPrompt]
+SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
"""
-The inputs to the LLM, which can take one of the following forms:
+Set of possible schemas for a single LLM input:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
+
+Note that "singleton" is as opposed to a data structure
+which encapsulates multiple prompts, i.e. of the sort
+which may be utilized for encoder/decoder models when
+the user desires to express both the encoder & decoder
+prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
+
+A prompt of type SingletonPromptInputs may be employed
+as (1) input to a decoder-only model, (2) input to
+the encoder of an encoder/decoder model, in the scenario
+where the decoder-prompt is not specified explicitly, or
+(3) as a member of a larger data structure encapsulating
+more than one prompt, i.e. ExplicitEncoderDecoderPrompt
"""
+class ExplicitEncoderDecoderPrompt(TypedDict):
+ """Represents an encoder/decoder model input prompt,
+ comprising an explicit encoder prompt and a
+ decoder prompt.
+
+ The encoder and decoder prompts, respectively,
+ may be formatted according to any of the
+ SingletonPromptInputs schemas, and are not
+ required to have the same schema.
+
+ Only the encoder prompt may have multi-modal data.
+
+ Note that an ExplicitEncoderDecoderPrompt may not
+ be used as an input to a decoder-only model,
+ and that the `encoder_prompt` and `decoder_prompt`
+ fields of this data structure must themselves
+ be SingletonPromptInputs instances.
+ """
+
+ encoder_prompt: SingletonPromptInputs
+
+ decoder_prompt: SingletonPromptInputs
+
+
+PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
+"""
+Set of possible schemas for an LLM input, including
+both decoder-only and encoder/decoder input types:
+
+- A text prompt (:class:`str` or :class:`TextPrompt`)
+- A tokenized prompt (:class:`TokensPrompt`)
+- A single data structure containing both an encoder and a decoder prompt
+ (:class:`ExplicitEncoderDecoderPrompt`)
+"""
+
+
+def _has_required_keys(
+ d: dict,
+ required_keys: set,
+) -> bool:
+ return required_keys.issubset(d.keys())
+
+
+def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
+ """
+ Get the type-name of the prompt argument instance, given that
+ isinstance() cannot apply to TypedDict subclasses directly.
+ If the prompt is None, return 'None' as the type name.
+
+ Arguments:
+
+ * prompt: LLM input prompt or None
+
+ Returns:
+
+ * String representation of prompt type
+ """
+
+ if prompt is None:
+ return 'None'
+
+ required_keys_dict = {
+ 'TextPrompt': {'prompt'},
+ 'TokensPrompt': {'prompt_token_ids'},
+ 'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
+ }
+
+ if isinstance(prompt, dict):
+ for (ptype, required_keys) in required_keys_dict.items():
+ # Ignore type checking in the conditional below because type
+ # checker does not understand that is_dict(prompt) narrows
+ # down the possible types
+ if _has_required_keys(
+ prompt, # type: ignore
+ required_keys):
+ return ptype
+
+ raise ValueError(f"Invalid prompt {prompt}, valid types are "
+ "required_keys_dict={required_keys_dict}")
+
+ if isinstance(prompt, str):
+ return "str"
+
+ raise ValueError(f"Invalid prompt {prompt}")
+
+
class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
@@ -114,8 +213,29 @@ class LLMInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""
+ encoder_prompt_token_ids: NotRequired[List[int]]
+ """The token IDs of the encoder prompt."""
+
+ encoder_prompt: NotRequired[Optional[str]]
+ """
+ The original encoder prompt text corresponding to the token IDs, if
+ available.
+ """
+
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
+
+
+def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
+ """
+ Return True if the LLMInputs instance has the correct configuration
+ for encoder/decoder.
+ """
+
+ # True if encoder prompt token ids field exists &
+ # is not None
+ return ('encoder_prompt_token_ids' in inputs
+ and inputs['encoder_prompt_token_ids'] is not None)
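A quick illustration of the helpers defined in this file, using only types exported above (values are illustrative):

    from vllm.inputs import (ExplicitEncoderDecoderPrompt, LLMInputs,
                             get_prompt_type,
                             is_valid_encoder_decoder_llm_inputs)

    assert get_prompt_type("Hello") == "str"
    assert get_prompt_type({"prompt": "Hello"}) == "TextPrompt"
    assert get_prompt_type({"prompt_token_ids": [1, 2, 3]}) == "TokensPrompt"

    enc_dec: ExplicitEncoderDecoderPrompt = {
        "encoder_prompt": "Hello",
        "decoder_prompt": {"prompt_token_ids": [2]},
    }
    assert get_prompt_type(enc_dec) == "ExplicitEncoderDecoder"

    inputs = LLMInputs(prompt_token_ids=[2],
                       encoder_prompt_token_ids=[10, 11, 12])
    assert is_valid_encoder_decoder_llm_inputs(inputs)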
diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py
index fb98f4a6b46f4..5c767e22de4d0 100644
--- a/vllm/model_executor/__init__.py
+++ b/vllm/model_executor/__init__.py
@@ -1,7 +1,11 @@
+from vllm.model_executor.parameter import (BasevLLMParameter,
+ PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
__all__ = [
"SamplingMetadata",
"set_random_seed",
+ "BasevLLMParameter",
+ "PackedvLLMParameter",
]
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index cd53c2b916211..646839ff303ee 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -13,10 +13,14 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.parameter import (BasevLLMParameter,
+ PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
+WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"]
+
def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -288,6 +292,7 @@ def __init__(self,
if output_sizes is None:
output_sizes = [output_size]
+
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
@@ -295,7 +300,9 @@ def __init__(self,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
- weight_loader=self.weight_loader,
+ weight_loader=(
+ self.weight_loader_v2 if self.quant_method.__class__.__name__
+ in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
if bias:
self.bias = Parameter(
@@ -337,6 +344,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
+ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+ param.load_column_parallel_weight(loaded_weight=loaded_weight)
+
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
@@ -527,6 +537,62 @@ def weight_loader(self,
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
+ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
+ loaded_weight: torch.Tensor):
+ """
+ Handle special case for models where MLP layers are already
+ fused on disk. In this case, we have no shard id. This function
+ determines the shard id by splitting these layers and then calls
+ the weight loader using the shard id.
+
+ An example of a model with these fused layers:
+ https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
+ """
+
+ current_shard_offset = 0
+ shard_offsets: List[Tuple[int, int, int]] = []
+ for i, output_size in enumerate(self.output_sizes):
+ shard_offsets.append((i, current_shard_offset, output_size))
+ current_shard_offset += output_size
+
+ for shard_id, shard_offset, shard_size in shard_offsets:
+ # Special case for Quantization.
+ # If quantized, we need to adjust the offset and size to account
+ # for the packing.
+ if isinstance(param, PackedvLLMParameter
+ ) and param.packed_dim == param.output_dim:
+ param.adjust_shard_indexes_for_packing(
+ shard_size=shard_size, shard_offset=shard_offset)
+
+ loaded_weight_shard = loaded_weight.narrow(param.output_dim,
+ shard_offset,
+ shard_size)
+ self.weight_loader_v2(param, loaded_weight_shard, shard_id)
+
+ def weight_loader_v2(self,
+ param: BasevLLMParameter,
+ loaded_weight: torch.Tensor,
+ loaded_shard_id: Optional[int] = None):
+ param_data = param.data
+ if loaded_shard_id is None:
+ if param.output_dim is None:
+ assert param_data.shape == loaded_weight.shape
+ param_data.copy_(loaded_weight)
+ return
+ self._load_fused_module_from_checkpoint(param, loaded_weight)
+ return
+
+ assert loaded_shard_id < len(self.output_sizes)
+
+ tp_size = get_tensor_model_parallel_world_size()
+ shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+ shard_size = self.output_sizes[loaded_shard_id] // tp_size
+
+ param.load_merged_column_weight(loaded_weight=loaded_weight,
+ shard_id=loaded_shard_id,
+ shard_offset=shard_offset,
+ shard_size=shard_size)
+
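The fused-checkpoint handling above slices the on-disk tensor back into per-projection shards before delegating to the per-shard loader; a small standalone sketch of that slicing step (sizes are illustrative):

    import torch

    # Illustrative gate_up_proj checkpoint: two fused projections of 11008 rows.
    output_sizes = [11008, 11008]
    fused = torch.randn(sum(output_sizes), 4096)
    output_dim = 0  # dimension along which the projections are concatenated

    shards = []
    offset = 0
    for shard_id, size in enumerate(output_sizes):
        # Same narrow() call the loader uses, one slice per sub-projection.
        shards.append((shard_id, fused.narrow(output_dim, offset, size)))
        offset += size

    assert [s.shape[0] for _, s in shards] == output_sizes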
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
@@ -598,6 +664,82 @@ def __init__(self,
quant_config=quant_config,
prefix=prefix)
+ def _get_shard_offset_mapping(self, loaded_shard_id: str):
+ shard_offset_mapping = {
+ "q": 0,
+ "k": self.num_heads * self.head_size,
+ "v": (self.num_heads + self.num_kv_heads) * self.head_size,
+ "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
+ }
+ return shard_offset_mapping.get(loaded_shard_id)
+
+ def _get_shard_size_mapping(self, loaded_shard_id: str):
+ shard_size_mapping = {
+ "q": self.num_heads * self.head_size,
+ "k": self.num_kv_heads * self.head_size,
+ "v": self.num_kv_heads * self.head_size,
+ }
+ return shard_size_mapping.get(loaded_shard_id)
+
+ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
+ loaded_weight: torch.Tensor):
+ """
+ Handle special case for models where QKV layers are already
+ fused on disk. In this case, we have no shard id. This function
+        determines the shard ids by splitting the fused weight along the
+        output dimension and then calls the weight loader once per shard.
+
+ An example of a model with these fused layers:
+ https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
+ """
+ shard_offsets = [
+ # (shard_id, shard_offset, shard_size)
+ ("q", 0, self.total_num_heads * self.head_size),
+ ("k", self.total_num_heads * self.head_size,
+ self.total_num_kv_heads * self.head_size),
+ ("v",
+ (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
+ self.total_num_kv_heads * self.head_size),
+ ]
+
+ for shard_id, shard_offset, shard_size in shard_offsets:
+ # Special case for Quantization.
+ # If quantized, we need to adjust the offset and size to account
+ # for the packing.
+ if isinstance(param, PackedvLLMParameter
+ ) and param.packed_dim == param.output_dim:
+ param.adjust_shard_indexes_for_packing(
+ shard_size=shard_size, shard_offset=shard_offset)
+
+ loaded_weight_shard = loaded_weight.narrow(param.output_dim,
+ shard_offset,
+ shard_size)
+ self.weight_loader_v2(param, loaded_weight_shard, shard_id)
+
+ def weight_loader_v2(self,
+ param: BasevLLMParameter,
+ loaded_weight: torch.Tensor,
+ loaded_shard_id: Optional[str] = None):
+ param_data = param.data
+ if loaded_shard_id is None: # special case for certain models
+ if param.output_dim is None:
+ assert param_data.shape == loaded_weight.shape
+ param_data.copy_(loaded_weight)
+ return
+ self._load_fused_module_from_checkpoint(param, loaded_weight)
+ return
+
+ assert loaded_shard_id in ["q", "k", "v"]
+
+ shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
+ shard_size = self._get_shard_size_mapping(loaded_shard_id)
+
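+        # The parameter performs the TP-aware narrow and copy for the
+        # selected q/k/v shard.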
+ param.load_qkv_weight(loaded_weight=loaded_weight,
+ num_heads=self.num_kv_head_replicas,
+ shard_id=loaded_shard_id,
+ shard_offset=shard_offset,
+ shard_size=shard_size)
+
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
@@ -798,6 +940,7 @@ def __init__(self,
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None
+
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
@@ -805,7 +948,9 @@ def __init__(self,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
- weight_loader=self.weight_loader,
+ weight_loader=(
+ self.weight_loader_v2 if self.quant_method.__class__.__name__
+ in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
@@ -850,6 +995,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
+ def weight_loader_v2(self, param: BasevLLMParameter,
+ loaded_weight: torch.Tensor):
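+        # v2 path: the parameter shards itself along the input dimension.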
+ param.load_row_parallel_weight(loaded_weight=loaded_weight)
+
def forward(self, input_):
if self.input_is_parallel:
input_parallel = input_
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 39d00bd5733ff..ae75781927381 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -19,6 +19,8 @@
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
+__all__ = ["CompressedTensorsLinearMethod"]
+
class CompressedTensorsConfig(QuantizationConfig):
@@ -146,18 +148,15 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
if weight_quant is None or input_quant is None:
return False
- # Confirm we have floating points.
- if not (weight_quant.type == QuantizationType.FLOAT
- and input_quant.type == QuantizationType.FLOAT):
- return False
-
# Confirm weight scheme is supported.
+ is_floating_point = (weight_quant.type == QuantizationType.FLOAT
+ and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
- if not (is_symmetric_weight and is_static_weight
+ if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight):
return False
@@ -169,11 +168,7 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel,
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
- if not (is_symmetric_activation and is_per_tensor_activation):
- return False
-
- # All conditions satisfied.
- return True
+ return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -230,6 +225,7 @@ def _get_scheme_from_parts(
group_size=weight_quant.group_size)
# Detect If Activation Quantization.
+ # TODO @dsikka: clean-up conditions
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
@@ -237,7 +233,8 @@ def _get_scheme_from_parts(
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
- is_static_input_scheme=(not input_quant.dynamic))
+ is_static_input_scheme=(input_quant
+ and not input_quant.dynamic))
else:
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
index b7ba29ddc9840..2e8d520eacc81 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
@@ -2,11 +2,10 @@
import torch
import torch.nn.functional as F
-from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import ModelWeightParameter
__all__ = ["CompressedTensorsUnquantized"]
@@ -24,7 +23,9 @@ def get_min_capability(cls) -> int:
return 70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- pass
+ # required by torch.compile to be torch.nn.Parameter
+ layer.weight = torch.nn.Parameter(layer.weight.data,
+ requires_grad=False)
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
@@ -32,14 +33,15 @@ def create_weights(self, layer: torch.nn.Module,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
- weight = Parameter(torch.empty(sum(output_partition_sizes),
- input_size_per_partition,
- dtype=params_dtype),
- requires_grad=False)
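+        # ModelWeightParameter bundles the sharding metadata (input_dim,
+        # output_dim) and loader that set_weight_attrs used to attach.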
+ weight = ModelWeightParameter(data=torch.empty(
+ sum(output_partition_sizes),
+ input_size_per_partition,
+ dtype=params_dtype),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader)
- set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
- set_weight_attrs(weight, {"weight_loader": weight_loader})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
index c1adfdb2980b6..9ad61a64e406c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -8,7 +8,10 @@
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+ ChannelQuantScaleParameter,
+ GroupQuantScaleParameter,
+ PackedvLLMParameter)
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"]
@@ -45,7 +48,12 @@ def get_min_capability(cls) -> int:
return 80
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- pass
+ # required by torch.compile to be torch.nn.Parameter
+ layer.weight_packed = Parameter(layer.weight_packed.data,
+ requires_grad=False)
+ layer.scale_packed = Parameter(layer.scale_packed.data,
+ requires_grad=False)
+ layer.meta = Parameter(layer.meta.data, requires_grad=False)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
@@ -56,79 +64,65 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)
- qweight = Parameter(
- torch.empty(
- input_size_per_partition // self.tile_size // 2,
- output_size_per_partition * self.tile_size // pack_factor,
- dtype=torch.int32,
- ),
- requires_grad=False,
- )
- set_weight_attrs(
- qweight,
- {
- "input_dim": 0,
- "output_dim": 1,
- "packed_dim": 1,
- "pack_factor": pack_factor,
- "marlin_tile_size": self.tile_size,
- "weight_loader": weight_loader
- },
- )
-
- layer.register_parameter("weight_packed", qweight)
+ qweight = PackedvLLMParameter(data=torch.empty(
+ input_size_per_partition // self.tile_size // 2,
+ output_size_per_partition * self.tile_size // pack_factor,
+ dtype=torch.int32,
+ ),
+ input_dim=0,
+ output_dim=1,
+ packed_dim=1,
+ packed_factor=pack_factor,
+ marlin_tile_size=self.tile_size,
+ weight_loader=weight_loader)
input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size)
- scales = Parameter(
+ weight_scale_args = {
+ "data":
torch.empty(
input_groups,
output_size_per_partition,
dtype=params_dtype,
),
- requires_grad=False,
- )
- set_weight_attrs(
- scales,
- {
- "output_dim": 1,
- "input_dim": None if input_groups == 1 else 0,
- "weight_loader": weight_loader
- },
- )
- layer.register_parameter("scale_packed", scales)
-
- weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
- requires_grad=False)
+ "weight_loader":
+ weight_loader
+ }
+
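+        # Grouped quantization uses a per-group scale sharded along the input
+        # dim; channelwise quantization uses one scale per output channel.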
+ if self.group_size is not None:
+ scales = GroupQuantScaleParameter(output_dim=1,
+ input_dim=0,
+ **weight_scale_args)
+ else:
+ scales = ChannelQuantScaleParameter(output_dim=1,
+ **weight_scale_args)
+
+ weight_shape = BasevLLMParameter(data=torch.empty(2,
+ dtype=torch.int64),
+ weight_loader=weight_loader)
+
+ meta = PackedvLLMParameter(data=torch.empty(
+ input_size_per_partition // 8 // 2 // 2,
+ output_size_per_partition * 2,
+ dtype=torch.int16,
+ ),
+ input_dim=0,
+ output_dim=1,
+ packed_dim=1,
+ packed_factor=1,
+ marlin_tile_size=2,
+ weight_loader=weight_loader)
+ layer.register_parameter("weight_packed", qweight)
layer.register_parameter("weight_shape", weight_shape)
- set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
-
- meta = Parameter(
- torch.empty(
- input_size_per_partition // 8 // 2 // 2,
- output_size_per_partition * 2,
- dtype=torch.int16,
- ),
- requires_grad=False,
- )
- set_weight_attrs(
- meta,
- {
- "input_dim": 0,
- "packed_dim": 1,
- "pack_factor": 1,
- "output_dim": 1,
- "marlin_tile_size": 2,
- "weight_loader": weight_loader
- },
- )
+ layer.register_parameter("scale_packed", scales)
layer.register_parameter("meta", meta)
max_workspace_size = (
output_size_per_partition //
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
+
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
requires_grad=False)
layer.workspace = workspace
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
index eeb7c042e1d1f..3d55d55cc390d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
@@ -9,9 +9,10 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- convert_to_channelwise, create_per_channel_scale_param,
- create_per_tensor_scale_param)
-from vllm.model_executor.utils import set_weight_attrs
+ convert_to_channelwise)
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+ ModelWeightParameter,
+ PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A16Fp8"]
@@ -40,11 +41,19 @@ def process_weights_after_loading(self, layer) -> None:
layer.logical_widths)
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
requires_grad=False)
+ else:
+ # required by torch.compile to be torch.nn.Parameter
+ layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+ requires_grad=False)
# Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(),
requires_grad=False)
+ if self.is_static_input_scheme:
+ # required by torch.compile to be torch.nn.Parameter
+ layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
+ requires_grad=False)
prepare_fp8_layer_for_marlin(layer, strategy="channel")
def create_weights(self, layer: torch.nn.Module, input_size: int,
@@ -60,35 +69,39 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
layer.orig_dtype = params_dtype
# WEIGHT
- weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
- input_size_per_partition,
- dtype=torch.float8_e4m3fn),
- requires_grad=False)
+ weight = ModelWeightParameter(data=torch.empty(
+ output_size_per_partition,
+ input_size_per_partition,
+ dtype=torch.float8_e4m3fn),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader)
layer.register_parameter("weight", weight)
- set_weight_attrs(weight, {
- "input_dim": 1,
- "output_dim": 0,
- "weight_loader": weight_loader,
- })
# WEIGHT SCALE
- layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
- weight_scale = create_per_channel_scale_param(
- output_partition_sizes, **layer_kwargs)
+ weight_scale = ChannelQuantScaleParameter(
+ data=torch.empty((sum(output_partition_sizes), 1),
+ dtype=torch.float32),
+ output_dim=0,
+ weight_loader=weight_loader)
elif self.strategy == QuantizationStrategy.TENSOR:
- weight_scale = create_per_tensor_scale_param(
- output_partition_sizes, **layer_kwargs)
+ weight_scale = PerTensorScaleParameter(data=torch.empty(
+ len(output_partition_sizes), dtype=torch.float32),
+ weight_loader=weight_loader)
else:
raise ValueError(
f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}")
+
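+        # Pre-fill the scales with the smallest fp32 value, matching the
+        # initialization used by the w8a8 fp8 scheme.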
+ weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (to deal with converted checkpoints)
if self.is_static_input_scheme:
- input_scale = create_per_tensor_scale_param(
- output_partition_sizes, **layer_kwargs)
+ input_scale = PerTensorScaleParameter(data=torch.empty(
+ len(output_partition_sizes), dtype=torch.float32),
+ weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index cc9d71db140c2..8a3d24e2fd258 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -8,10 +8,10 @@
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- apply_fp8_linear, create_per_channel_scale_param,
- create_per_tensor_scale_param, cutlass_fp8_supported,
- requantize_with_max_scale)
-from vllm.model_executor.utils import set_weight_attrs
+ apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale)
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+ ModelWeightParameter,
+ PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A8Fp8"]
@@ -46,6 +46,9 @@ def process_weights_after_loading(self, layer) -> None:
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
+ # required by torch.compile to be torch.nn.Parameter
+ layer.weight_scale = Parameter(layer.weight_scale.data,
+ requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
@@ -66,32 +69,40 @@ def create_weights(self, layer: torch.nn.Module,
layer.logical_widths = output_partition_sizes
# WEIGHT
- weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
- input_size_per_partition,
- dtype=torch.float8_e4m3fn),
- requires_grad=False)
+ weight = ModelWeightParameter(data=torch.empty(
+ output_size_per_partition,
+ input_size_per_partition,
+ dtype=torch.float8_e4m3fn),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader)
layer.register_parameter("weight", weight)
- set_weight_attrs(weight, {
- "input_dim": 1,
- "output_dim": 0,
- "weight_loader": weight_loader,
- })
# WEIGHT SCALE
- layer_kwargs = {"weight_loader": weight_loader}
+ # TODO: update create_xxx_parameter functions to return
+ # the newly added parameters
if self.strategy == QuantizationStrategy.CHANNEL:
- weight_scale = create_per_channel_scale_param(
- output_partition_sizes, **layer_kwargs)
+ weight_scale = ChannelQuantScaleParameter(
+ data=torch.empty((sum(output_partition_sizes), 1),
+ dtype=torch.float32),
+ output_dim=0,
+ weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
- weight_scale = create_per_tensor_scale_param(
- output_partition_sizes, **layer_kwargs)
+ weight_scale = PerTensorScaleParameter(data=torch.empty(
+ len(output_partition_sizes), dtype=torch.float32),
+ weight_loader=weight_loader)
+
+ # min requirement for fp8 kernels
+ weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
- input_scale = create_per_tensor_scale_param(
- output_partition_sizes, **layer_kwargs)
+ input_scale = PerTensorScaleParameter(data=torch.empty(
+ len(output_partition_sizes), dtype=torch.float32),
+ weight_loader=weight_loader)
+ input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
index 3a80863d3abbe..078380f159291 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -8,9 +8,11 @@
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- apply_int8_linear, convert_to_channelwise, create_per_channel_scale_param,
- create_per_tensor_scale_param)
-from vllm.model_executor.utils import set_weight_attrs
+ apply_int8_linear, convert_to_channelwise)
+from vllm.model_executor.parameter import (BasevLLMParameter,
+ ChannelQuantScaleParameter,
+ ModelWeightParameter,
+ PerTensorScaleParameter)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
@@ -39,7 +41,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
self.logical_widths)
layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
-
+ else:
+ layer.weight_scale = Parameter(layer.weight_scale.data,
+ requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
@@ -55,32 +59,35 @@ def create_weights(self, layer: torch.nn.Module,
self.logical_widths = output_partition_sizes
# WEIGHT
- weight = Parameter(torch.empty(sum(output_partition_sizes),
- input_size_per_partition,
- dtype=torch.int8),
- requires_grad=False)
+ weight = ModelWeightParameter(data=torch.empty(
+ sum(output_partition_sizes),
+ input_size_per_partition,
+ dtype=torch.int8),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader)
+
layer.register_parameter("weight", weight)
- set_weight_attrs(weight, {
- "input_dim": 1,
- "output_dim": 0,
- "weight_loader": weight_loader,
- })
# WEIGHT SCALE
- layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
- weight_scale = create_per_channel_scale_param(
- output_partition_sizes, **layer_kwargs)
+ weight_scale = ChannelQuantScaleParameter(
+ data=torch.empty((sum(output_partition_sizes), 1),
+ dtype=torch.float32),
+ output_dim=0,
+ weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
- weight_scale = create_per_tensor_scale_param(
- output_partition_sizes, **layer_kwargs)
+ weight_scale = PerTensorScaleParameter(data=torch.empty(
+ len(output_partition_sizes), dtype=torch.float32),
+ weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
- input_scale = create_per_tensor_scale_param(
- output_partition_sizes, **layer_kwargs)
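+            # Static int8 activation quantization uses a single scalar input
+            # scale (collapsed to its max in process_weights_after_loading).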
+ input_scale = BasevLLMParameter(data=torch.empty(
+ 1, dtype=torch.float32),
+ weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index b8880f7ac136f..94699c27d5cee 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -1,7 +1,6 @@
from typing import Callable, List, Optional
import torch
-from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -10,7 +9,10 @@
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+ ChannelQuantScaleParameter,
+ GroupQuantScaleParameter,
+ PackedvLLMParameter)
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"]
@@ -30,17 +32,12 @@ def __init__(self,
self.pack_factor = 32 // num_bits
self.strategy = strategy
+ self.group_size = -1 if group_size is None else group_size
- self.group_size: int
- if group_size is None:
- if self.strategy != "channel":
- raise ValueError(
- "Marlin kernels require group quantization or "
- "channelwise quantization, but found no group "
- "size and strategy is not channelwise.")
- self.group_size = -1
- else:
- self.group_size = group_size
+ if self.group_size == -1 and self.strategy != "channel":
+ raise ValueError("Marlin kernels require group quantization or "
+ "channelwise quantization, but found no group "
+ "size and strategy is not channelwise.")
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
raise ValueError(
@@ -63,11 +60,12 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
+
output_size_per_partition = sum(output_partition_sizes)
# If group_size is -1, we are in channelwise case.
channelwise = (self.group_size == -1)
- group_size = input_size if channelwise else self.group_size
+ group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
@@ -79,60 +77,51 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
input_size=input_size,
group_size=group_size)
- weight_scale_dim = None
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
- weight_scale_dim = 1
scales_and_zp_size = input_size_per_partition // group_size
- weight = Parameter(
- torch.empty(
- output_size_per_partition,
- input_size_per_partition // self.pack_factor,
- dtype=torch.int32,
- ),
- requires_grad=False,
- )
-
- set_weight_attrs(
- weight, {
- "input_dim": 1,
- "output_dim": 0,
- "packed_dim": 1,
- "pack_factor": self.pack_factor,
- "weight_loader": weight_loader
- })
- layer.register_parameter("weight_packed", weight)
-
- weight_scale = Parameter(
+ weight = PackedvLLMParameter(input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader,
+ packed_factor=self.pack_factor,
+ packed_dim=1,
+ data=torch.empty(
+ output_size_per_partition,
+ input_size_per_partition //
+ self.pack_factor,
+ dtype=torch.int32,
+ ))
+
+ weight_scale_args = {
+ "weight_loader":
+ weight_loader,
+ "data":
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
- ),
- requires_grad=False,
- )
-
- set_weight_attrs(
- weight_scale, {
- "weight_loader": weight_loader,
- "input_dim": weight_scale_dim,
- "output_dim": 0
- })
- layer.register_parameter("weight_scale", weight_scale)
+ )
+ }
+ if self.group_size == -1:
+ weight_scale = ChannelQuantScaleParameter(output_dim=0,
+ **weight_scale_args)
+ else:
+ weight_scale = GroupQuantScaleParameter(output_dim=0,
+ input_dim=1,
+ **weight_scale_args)
# A 2D array defining the original shape of the weights
# before packing
- weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
- requires_grad=False)
+ weight_shape = BasevLLMParameter(data=torch.empty(2,
+ dtype=torch.int64),
+ weight_loader=weight_loader)
+ layer.register_parameter("weight_packed", weight)
+ layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
- set_weight_attrs(weight_shape, {
- "weight_loader": weight_loader,
- "ignore_warning": True,
- })
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
@@ -154,10 +143,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
+ # Update for kernel
+ layer.weight_packed = torch.nn.Parameter(
+ layer.weight_packed.t().contiguous(), requires_grad=False)
+ layer.weight_scale = torch.nn.Parameter(
+ layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
- layer.weight_packed.t().contiguous(),
+ layer.weight_packed,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
@@ -166,7 +160,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Permute scales from compressed-tensors format to marlin format.
marlin_scales = marlin_permute_scales(
- layer.weight_scale.squeeze().t().contiguous(),
+ layer.weight_scale,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=layer.group_size)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index c829cb836ee4c..cdd2413f5b2c4 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -4,6 +4,7 @@
from torch.nn import Module
from torch.nn.parameter import Parameter
+import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
@@ -118,7 +119,7 @@ def __init__(self, quant_config: Fp8Config):
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
- self.use_marlin = capability < 89
+ self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
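+        # VLLM_TEST_FORCE_FP8_MARLIN lets tests exercise the marlin w8a16
+        # fallback even on GPUs that support fp8 natively.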
def create_weights(
self,
@@ -174,6 +175,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
+ # If using marlin (w8a16), kernel uses channelwise weights,
+ # so extend the weight scales to be channelwise.
+ if self.use_marlin:
+ assert weight_scale.numel() == 1
+ weight_scale = convert_to_channelwise(
+ weight_scale.expand(len(layer.logical_widths)),
+ layer.logical_widths)
+
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 510c9dd49ef03..aa04fcf8310bf 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -204,13 +204,7 @@ def create_weights(
layer.exllama_state = exllama_state
- def apply(self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
- qweight = layer.qweight
- out_shape = x.shape[:-1] + (qweight.shape[-1], )
- reshaped_x = x.reshape(-1, x.shape[-1])
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if layer.exllama_state == ExllamaState.UNINITIALIZED:
@@ -222,6 +216,14 @@ def apply(self,
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
+
+ def apply(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
+ reshaped_x = x.reshape(-1, x.shape[-1])
+
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 4a11b14971076..066102f3a01c0 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -251,7 +251,6 @@ def create_weights(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
- device="meta",
),
requires_grad=False,
)
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index ebb77a802d5cb..0f91b92665c28 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -83,7 +83,16 @@
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
}
-_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
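+# Encoder/decoder (conditional generation) architectures.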
+_CONDITIONAL_GENERATION_MODELS = {
+ "BartModel": ("bart", "BartForConditionalGeneration"),
+ "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
+}
+
+_MODELS = {
+ **_GENERATION_MODELS,
+ **_EMBEDDING_MODELS,
+ **_CONDITIONAL_GENERATION_MODELS
+}
# Architecture -> type.
# out of tree models
diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py
new file mode 100644
index 0000000000000..5066e991f9003
--- /dev/null
+++ b/vllm/model_executor/models/bart.py
@@ -0,0 +1,996 @@
+# Derived from BART implementation posted on HuggingFace; license below:
+#
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BART model."""
+import math
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import BartConfig
+from transformers.utils import logging
+
+from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.config import CacheConfig, LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors, SamplerOutput
+
+logger = logging.get_logger(__name__)
+
+
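+# Infer (batch_size, seq_len) from either flattened 1-D or batched 2-D
+# input_ids.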
+def get_bsz_seq_len(input_ids):
+ shp = input_ids.shape
+ ndim = len(shp)
+ if ndim == 1:
+ return 1, input_ids.numel()
+ else:
+ return shp[:2]
+
+
+class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int):
+ # Bart is set up so that if padding_idx is
+ # specified then offset the embedding ids by 2
+ # and adjust num_embeddings appropriately.
+ # Other models don't have this hack
+ self.offset = 2
+ super().__init__(num_embeddings + self.offset, embedding_dim)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ attn_type: AttentionType,
+ ) -> torch.Tensor:
+        """`positions` shape is expected to be [bsz x seqlen]."""
+
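+        # Positional embeddings are only defined for the encoder and decoder
+        # token streams; the cross-attention path never queries them.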
+ assert attn_type != AttentionType.ENCODER_DECODER
+
+ return super().forward(positions + self.offset)
+
+
+class BartScaledWordEmbedding(VocabParallelEmbedding):
+ """
+    This module overrides VocabParallelEmbedding's forward
+    by multiplying the output by the embedding scale.
+ """
+
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ embed_scale: float = 1.0):
+ super().__init__(num_embeddings, embedding_dim)
+ self.embed_scale = embed_scale
+
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return super().forward(input_ids) * self.embed_scale
+
+
+class BartParallelLMHead(ParallelLMHead):
+ """
+    This module overrides ParallelLMHead's forward
+    by dividing the output by the embedding scale,
+    effectively the inverse of BartScaledWordEmbedding.
+ """
+
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ embed_scale: float = 1.0):
+ super().__init__(num_embeddings, embedding_dim)
+ self.embed_scale = embed_scale
+
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return super().forward(input_ids) / self.embed_scale
+
+
+class BartEncoderAttention(nn.Module):
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ bias: bool = True,
+ config: Optional[BartConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.d_model = config.d_model
+ self.embed_dim = embed_dim
+ self.total_num_heads = num_heads
+ self.total_num_kv_heads = self.total_num_heads
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(f"embed_dim must be divisible by num_heads "
+ f"(got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads}).")
+ self.scaling = self.head_dim**-0.5
+
+ self.qkv_proj = QKVParallelLinear(
+ self.d_model,
+ self.d_model // self.total_num_heads,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=bias,
+ quant_config=quant_config,
+ )
+
+ self.out_proj = RowParallelLinear(
+ embed_dim,
+ embed_dim,
+ bias=bias,
+ quant_config=quant_config,
+ )
+
+ tp_world_size = get_tensor_model_parallel_world_size()
+ assert self.total_num_heads % tp_world_size == 0
+ self.num_heads = self.total_num_heads // tp_world_size
+
+ if self.total_num_kv_heads >= tp_world_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_world_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_world_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+
+ self.attn = Attention(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config)
+
+ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
+ attn_metadata: AttentionMetadata) -> torch.Tensor:
+ """Input shape: Batch x Time x Channel"""
+
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ attn_output = self.attn(q,
+ k,
+ v,
+ kv_cache,
+ attn_metadata,
+ attn_type=AttentionType.ENCODER)
+
+ output, _ = self.out_proj(attn_output)
+ return output
+
+
+class BartDecoderSelfAttention(nn.Module):
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ bias: bool = True,
+ config: Optional[BartConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.d_model = config.d_model
+ self.embed_dim = embed_dim
+ self.total_num_heads = num_heads
+ self.total_num_kv_heads = self.total_num_heads
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(f"embed_dim must be divisible by num_heads "
+ f"(got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads}).")
+ self.scaling = self.head_dim**-0.5
+
+ self.qkv_proj = QKVParallelLinear(
+ self.d_model,
+ self.d_model // self.total_num_heads,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=bias,
+ quant_config=quant_config,
+ )
+
+ self.out_proj = RowParallelLinear(
+ embed_dim,
+ embed_dim,
+ bias=bias,
+ quant_config=quant_config,
+ )
+
+ tp_world_size = get_tensor_model_parallel_world_size()
+ assert self.total_num_heads % tp_world_size == 0
+ self.num_heads = self.total_num_heads // tp_world_size
+
+ if self.total_num_kv_heads >= tp_world_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_world_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_world_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+
+ self.attn = Attention(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config)
+
+ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
+ attn_metadata: AttentionMetadata) -> torch.Tensor:
+ """Input shape: Batch x Time x Channel"""
+
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ attn_output = self.attn(q,
+ k,
+ v,
+ kv_cache,
+ attn_metadata,
+ attn_type=AttentionType.DECODER)
+
+ output, _ = self.out_proj(attn_output)
+ return output
+
+
+class BartCrossAttention(nn.Module):
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ bias: bool = True,
+ config: Optional[BartConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.d_model = config.d_model
+ self.embed_dim = embed_dim
+ self.total_num_heads = num_heads
+ self.total_num_kv_heads = self.total_num_heads
+ self.head_dim = embed_dim // num_heads
+ self.config = config
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(f"embed_dim must be divisible by num_heads "
+ f"(got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads}).")
+ self.scaling = self.head_dim**-0.5
+
+ self.qkv_proj = QKVParallelLinear(
+ self.d_model,
+ self.d_model // self.total_num_heads,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=bias,
+ quant_config=quant_config,
+ )
+
+ self.out_proj = RowParallelLinear(
+ embed_dim,
+ embed_dim,
+ bias=bias,
+ quant_config=quant_config,
+ )
+
+ tp_world_size = get_tensor_model_parallel_world_size()
+ assert self.total_num_heads % tp_world_size == 0
+ self.num_heads = self.total_num_heads // tp_world_size
+
+ if self.total_num_kv_heads >= tp_world_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_world_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_world_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+
+ self.attn = Attention(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config)
+
+ def forward(
+ self,
+ decoder_hidden_states: torch.Tensor,
+ kv_cache: torch.Tensor,
+ attn_metadata: AttentionMetadata,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Input shape: Batch x Time x Channel"""
+
+ # (afeldman-nm 2024/07/22) TODO:
+ # Need a more efficient solution for q/k/v
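+        # The fused qkv_proj is applied to both streams; only q is kept from
+        # the decoder projection and only k/v from the encoder projection.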
+ qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
+ q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
+ dim=-1)
+ if encoder_hidden_states is None:
+ k = None
+ v = None
+ else:
+ qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
+ _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
+ dim=-1)
+
+ attn_output = self.attn(q,
+ k,
+ v,
+ kv_cache,
+ attn_metadata,
+ attn_type=AttentionType.ENCODER_DECODER)
+
+ output, _ = self.out_proj(attn_output)
+ return output
+
+
+class BartEncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: BartConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = BartEncoderAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ config=config,
+ cache_config=cache_config,
+ quant_config=quant_config)
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.activation_fn = get_act_fn(config.activation_function,
+ quant_config)
+
+ ffn_hidden_size = self.embed_dim
+ ffn_intermediate_size = config.encoder_ffn_dim
+ ffn_has_bias = True
+ self.fc1 = ColumnParallelLinear(
+ ffn_hidden_size,
+ ffn_intermediate_size,
+ bias=ffn_has_bias,
+ quant_config=quant_config,
+ )
+ self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
+ self.fc2 = RowParallelLinear(
+ ffn_intermediate_size,
+ ffn_hidden_size,
+ bias=ffn_has_bias,
+ quant_config=quant_config,
+ )
+
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
+ attn_metadata: AttentionMetadata) -> torch.Tensor:
+ r"""
+ Args:
+            hidden_states:
+ torch.Tensor of *encoder* input embeddings.
+ kv_cache:
+ Layer-wise list of KV cache tensors
+ attn_metadata:
+ vLLM Attention metadata structure
+ Returns:
+ Encoder layer output torch.Tensor
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn(hidden_states=hidden_states,
+ kv_cache=kv_cache,
+ attn_metadata=attn_metadata)
+
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ residual = hidden_states
+ fc1_out, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(fc1_out)
+
+ hidden_states, _ = self.fc2(hidden_states)
+
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any()
+ or torch.isnan(hidden_states).any()):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states,
+ min=-clamp_value,
+ max=clamp_value)
+
+ return hidden_states
+
+
+class BartDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: BartConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = BartDecoderSelfAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ config=config,
+ cache_config=cache_config,
+ quant_config=quant_config)
+ self.activation_fn = get_act_fn(config.activation_function,
+ quant_config)
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ '''
+ afeldman-nm: personally I would call this "cross-attention",
+ however I left the name as "encoder_attn" to maintain consistency
+ with the name of the pretrained weights.
+ '''
+ self.encoder_attn = BartCrossAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ config=config,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ ffn_hidden_size = self.embed_dim
+ ffn_intermediate_size = config.encoder_ffn_dim
+ ffn_has_bias = True
+ self.fc1 = ColumnParallelLinear(
+ ffn_hidden_size,
+ ffn_intermediate_size,
+ bias=ffn_has_bias,
+ quant_config=quant_config,
+ )
+ self.fc2 = RowParallelLinear(
+ ffn_intermediate_size,
+ ffn_hidden_size,
+ bias=ffn_has_bias,
+ quant_config=quant_config,
+ )
+
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ decoder_hidden_states: torch.Tensor,
+ kv_cache: torch.Tensor,
+ attn_metadata: AttentionMetadata,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ r"""
+ Args:
+            decoder_hidden_states:
+                torch.Tensor of *decoder* input embeddings.
+            kv_cache:
+                KV cache tensor
+            attn_metadata:
+                vLLM Attention metadata structure
+            encoder_hidden_states:
+                torch.Tensor of *encoder* output embeddings.
+ Returns:
+ Decoder layer output torch.Tensor
+ """
+ residual = decoder_hidden_states
+
+ # Self Attention
+ hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
+ kv_cache=kv_cache,
+ attn_metadata=attn_metadata)
+
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Cross-Attention Block
+
+ residual = hidden_states
+
+ hidden_states = self.encoder_attn(
+ decoder_hidden_states=hidden_states,
+ kv_cache=kv_cache,
+ attn_metadata=attn_metadata,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ hidden_states = residual + hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ fc1_out, _ = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(fc1_out)
+
+ hidden_states, _ = self.fc2(hidden_states)
+
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return hidden_states
+
+
+class BartEncoder(nn.Module):
+ """
+    Transformer encoder consisting of *config.encoder_layers*
+    self-attention layers. Each layer is a [`BartEncoderLayer`].
+    Args:
+        config: BartConfig
+        embed_tokens (nn.Embedding): token embedding
+ """
+
+ def __init__(self,
+ config: BartConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ lora_config: Optional[LoRAConfig] = None,
+ embed_tokens: Optional[nn.Embedding] = None):
+ super().__init__()
+
+ self.cache_config = cache_config
+ self.quant_config = quant_config
+ self.lora_config = lora_config
+ embed_dim = config.d_model
+ self.max_source_positions = config.max_position_embeddings
+ embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+ self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
+ embed_dim,
+ embed_scale=embed_scale)
+
+ if embed_tokens is not None:
+ self.embed_tokens.weight = embed_tokens.weight
+
+ self.embed_positions = BartLearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ embed_dim,
+ )
+ self.layers = nn.ModuleList(
+ [BartEncoderLayer(config,cache_config,quant_config) \
+ for _ in range(config.encoder_layers)])
+
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
+
+ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+ kv_caches: List[torch.Tensor],
+ attn_metadata: AttentionMetadata) -> torch.Tensor:
+ r"""
+ Args:
+            input_ids:
+                Indices of *encoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            positions:
+                Positions of *encoder* input sequence tokens.
+            kv_caches:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                vLLM Attention metadata structure
+        Returns:
+            Encoder output torch.Tensor
+ """
+ # retrieve input_ids and inputs_embeds
+
+ input_ids = input_ids.view(-1, input_ids.shape[-1])
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ embed_pos = self.embed_positions(
+ positions,
+ AttentionType.ENCODER,
+ )
+ embed_pos = embed_pos.to(inputs_embeds.device)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.layernorm_embedding(hidden_states)
+
+ for idx, encoder_layer in enumerate(self.layers):
+ hidden_states = encoder_layer(
+ hidden_states=hidden_states,
+ kv_cache=kv_caches[idx],
+ attn_metadata=attn_metadata,
+ )
+
+ return hidden_states
+
+
+class BartDecoder(nn.Module):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers.
+    Each layer is a [`BartDecoderLayer`].
+    Args:
+        config: BartConfig
+        embed_tokens (nn.Embedding): token embedding
+ """
+
+ def __init__(
+ self,
+ config: BartConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ lora_config: Optional[LoRAConfig] = None,
+ embed_tokens: Optional[nn.Embedding] = None,
+ ):
+ super().__init__()
+ self.cache_config = cache_config
+ self.quant_config = quant_config
+ self.lora_config = lora_config
+ self.max_target_positions = config.max_position_embeddings
+ embed_scale = math.sqrt(
+ config.d_model) if config.scale_embedding else 1.0
+
+ self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
+ config.d_model,
+ embed_scale=embed_scale)
+
+ if embed_tokens is not None:
+ self.embed_tokens.weight = embed_tokens.weight
+
+ self.embed_positions = BartLearnedPositionalEmbedding(
+ config.max_position_embeddings,
+ config.d_model,
+ )
+
+ self.layers = nn.ModuleList(
+ [BartDecoderLayer(config,cache_config,quant_config) \
+ for _ in range(config.decoder_layers)])
+
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
+
+ def forward(self, decoder_input_ids: torch.Tensor,
+ decoder_positions: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor],
+ kv_caches: List[torch.Tensor],
+ attn_metadata: AttentionMetadata) -> torch.Tensor:
+ r"""
+ Args:
+            decoder_input_ids:
+                Indices of *decoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            decoder_positions:
+ Positions of *decoder* input sequence tokens.
+ encoder_hidden_states:
+ Tensor of encoder output embeddings
+ kv_caches:
+ Layer-wise list of KV cache tensors
+ attn_metadata:
+ vLLM Attention metadata structure
+ Returns:
+ Decoder output torch.Tensor
+ """
+
+ inputs_embeds = self.embed_tokens(decoder_input_ids)
+
+ # embed positions
+ embed_pos = self.embed_positions(
+ decoder_positions,
+ AttentionType.DECODER,
+ )
+ embed_pos = embed_pos.to(inputs_embeds.device)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = self.layernorm_embedding(hidden_states)
+
+ # decoder layers
+
+ for idx, decoder_layer in enumerate(self.layers):
+ hidden_states = decoder_layer(
+ decoder_hidden_states=hidden_states,
+ kv_cache=kv_caches[idx],
+ attn_metadata=attn_metadata,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ return hidden_states
+
+
+class BartModel(nn.Module):
+ _tied_weights_keys = [
+ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+ ]
+
+ def __init__(self,
+ config: BartConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ lora_config: Optional[LoRAConfig] = None):
+ super().__init__()
+
+ self.config = config
+
+ self.padding_idx = config.pad_token_id
+ lora_vocab = (lora_config.lora_extra_vocab_size *
+ (lora_config.max_loras or 1)) if lora_config else 0
+ self.vocab_size = config.vocab_size + lora_vocab
+ self.org_vocab_size = config.vocab_size
+
+ self.encoder = BartEncoder(config,
+ cache_config,
+ quant_config=quant_config)
+ self.decoder = BartDecoder(config,
+ cache_config,
+ quant_config=quant_config)
+
+ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+ encoder_input_ids: torch.Tensor,
+ encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
+ attn_metadata: AttentionMetadata) -> torch.Tensor:
+ r"""
+ Args:
+            input_ids:
+                Indices of *decoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            positions:
+                Positions of *decoder* input sequence tokens.
+            encoder_input_ids:
+ Indices of *encoder* input sequence tokens in the vocabulary.
+ encoder_positions:
+ Positions of *encoder* input sequence tokens.
+ kv_caches:
+ Layer-wise list of KV cache tensors
+ attn_metadata:
+ vLLM Attention metadata structure
+ Returns:
+ Model output torch.Tensor
+ """
+
+ encoder_hidden_states = None
+
+ if encoder_input_ids.numel() > 0:
+ # Run encoder attention if a non-zero number of encoder tokens
+ # are provided as input
+ encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
+ positions=encoder_positions,
+ kv_caches=kv_caches,
+ attn_metadata=attn_metadata)
+
+ # decoder outputs consists of
+ # (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ decoder_input_ids=input_ids,
+ decoder_positions=positions,
+ encoder_hidden_states=encoder_hidden_states,
+ kv_caches=kv_caches,
+ attn_metadata=attn_metadata)
+
+ return decoder_outputs
+
+
+class BartForConditionalGeneration(nn.Module):
+ base_model_prefix = "model"
+
+ def __init__(self,
+ config: BartConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ lora_config: Optional[LoRAConfig] = None):
+
+ super().__init__()
+ self.config = config
+ self.model = BartModel(config,
+ cache_config,
+ quant_config,
+ lora_config=lora_config)
+
+ self.unpadded_vocab_size = config.vocab_size
+ if lora_config:
+ self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+ embed_scale = math.sqrt(
+ config.d_model) if config.scale_embedding else 1.0
+
+ self.lm_head = BartParallelLMHead(config.vocab_size,
+ config.d_model,
+ embed_scale=embed_scale)
+
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+ config.vocab_size)
+ self.sampler = Sampler()
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ encoder_input_ids: torch.Tensor,
+ encoder_positions: torch.Tensor,
+ kv_caches: List[torch.Tensor],
+ attn_metadata: AttentionMetadata,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ ) -> torch.Tensor:
+ r"""
+ Args:
+            input_ids:
+                torch.Tensor of *decoder* input token ids.
+            positions:
+                torch.Tensor of *decoder* position indices.
+            encoder_input_ids:
+                torch.Tensor of *encoder* input token ids.
+            encoder_positions:
+                torch.Tensor of *encoder* position indices.
+ kv_caches:
+ Layer-wise list of KV cache tensors
+ attn_metadata:
+ vLLM Attention metadata structure
+ Returns:
+ Output torch.Tensor
+ """
+ return self.model(input_ids, positions, encoder_input_ids,
+ encoder_positions, kv_caches, attn_metadata)
+
+ def compute_logits(self, hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata) -> torch.Tensor:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: Optional[torch.Tensor],
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ stacked_params_mapping = {
+ "q_proj": {
+ "param_name": "qkv_proj",
+ "shard_id": "q",
+ },
+ "k_proj": {
+ "param_name": "qkv_proj",
+ "shard_id": "k",
+ },
+ "v_proj": {
+ "param_name": "qkv_proj",
+ "shard_id": "v",
+ },
+ }
+
+ params_mapping = {
+ "beta": "bias",
+ "gamma": "weight",
+ "LayerNorm": "layernorm",
+ }
+
+ def _rename_key(self, key: str):
+ prefix = f"{self.base_model_prefix}."
+ key = key[len(prefix):] if key.startswith(prefix) else key
+
+ for src, dst in self.params_mapping.items():
+ key = key.replace(src, dst)
+
+ return key
+
+ def _rename_stacked_param(
+ self,
+ name: str,
+ ) -> Tuple[str, Optional[str]]:
+ for key, mapping in self.stacked_params_mapping.items():
+ if key in name:
+ name = name.replace(key, mapping["param_name"])
+ return name, mapping["shard_id"]
+ return name, None
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+ model_params_dict = dict(self.model.named_parameters())
+ top_params_dict = dict(self.named_parameters())
+
+ weights_tuple_list = list(weights)
+
+ shared_embedding_weight = None
+ shared_embedding_shard_id = None
+
+ for name, loaded_weight in weights_tuple_list:
+
+ name = self._rename_key(name)
+ name, shard_id = self._rename_stacked_param(name)
+
+ if ('shared.weight' in name
+ or 'encoder.embed_tokens.weight' in name
+ or 'decoder.embed_tokens.weight' in name
+ or 'lm_head.weight' in name):
+ assert shared_embedding_weight is None, (
+ "Conflicting embedding weights.")
+ shared_embedding_weight = loaded_weight
+ shared_embedding_shard_id = shard_id
+ else:
+ # Skip the specific downstream task weight.
+ if name.startswith('cls.'):
+ continue
+ # use Pooler instead.
+ if name.startswith('pooler.'):
+ continue
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in model_params_dict:
+ continue
+
+ param = model_params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ if shard_id:
+ weight_loader(param, loaded_weight, shard_id)
+ else:
+ weight_loader(param, loaded_weight)
+
+ # Assign shared weight values
+ encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
+ encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
+ default_weight_loader)
+
+ decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
+ decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
+ default_weight_loader)
+
+ lm_head_in_param = top_params_dict['lm_head.weight']
+ lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
+ default_weight_loader)
+
+ assert shared_embedding_weight is not None
+
+ if shared_embedding_shard_id:
+ encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
+ shared_embedding_shard_id)
+ decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
+ shared_embedding_shard_id)
+ lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
+ shared_embedding_shard_id)
+ else:
+ encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
+ decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
+ lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
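# Editor's note: a minimal, self-contained sketch (not part of the diff) of how the
# BartForConditionalGeneration weight-loading helpers above rewrite HF checkpoint
# names before looking them up in the vLLM parameter dict. The checkpoint names
# used below are illustrative examples only.

from typing import Optional, Tuple

PARAMS_MAPPING = {"beta": "bias", "gamma": "weight", "LayerNorm": "layernorm"}
STACKED_PARAMS_MAPPING = {
    "q_proj": ("qkv_proj", "q"),
    "k_proj": ("qkv_proj", "k"),
    "v_proj": ("qkv_proj", "v"),
}


def rename(name: str, base_model_prefix: str = "model") -> Tuple[str, Optional[str]]:
    # Mirror _rename_key: strip the "model." prefix, then apply substring renames.
    prefix = f"{base_model_prefix}."
    if name.startswith(prefix):
        name = name[len(prefix):]
    for src, dst in PARAMS_MAPPING.items():
        name = name.replace(src, dst)
    # Mirror _rename_stacked_param: fold separate q/k/v projections into the
    # fused qkv_proj parameter and report which shard the weight belongs to.
    for src, (dst, shard_id) in STACKED_PARAMS_MAPPING.items():
        if src in name:
            return name.replace(src, dst), shard_id
    return name, None


# ('decoder.layers.0.self_attn.qkv_proj.weight', 'k')
print(rename("model.decoder.layers.0.self_attn.k_proj.weight"))
# ('shared.weight', None) -- load_weights later copies this single tensor into
# encoder.embed_tokens, decoder.embed_tokens and lm_head.
print(rename("model.shared.weight"))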
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 8850fd7c6763b..49f9a4c85f2d0 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -38,9 +38,6 @@
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
-MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
-MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
-
class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
@@ -84,11 +81,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
return best_ratio
-def calculate_num_blocks(orig_width: int,
- orig_height: int,
- min_num=1,
- max_num=6,
- image_size=448):
+def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
+ max_num: int,
+ image_size: int) -> Tuple[int, int, int]:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
@@ -110,11 +105,9 @@ def calculate_num_blocks(orig_width: int,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def dynamic_preprocess(image,
- min_num=1,
- max_num=6,
- image_size=448,
- use_thumbnail=False):
+def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
+ image_size: int,
+ use_thumbnail: int) -> List[Image.Image]:
orig_width, orig_height = image.size
blocks, target_width, target_height = calculate_num_blocks(
@@ -138,12 +131,14 @@ def dynamic_preprocess(image,
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
-def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
+def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
+ max_num: int, use_thumbnail: bool) -> torch.Tensor:
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image,
+ min_num=min_num,
+ max_num=max_num,
image_size=input_size,
- use_thumbnail=True,
- max_num=max_num)
+ use_thumbnail=use_thumbnail)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
@@ -159,12 +154,18 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def get_max_internvl_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PretrainedConfig)
vision_config = hf_config.vision_config
+
+ use_thumbnail = hf_config.use_thumbnail
+ max_dynamic_patch = hf_config.max_dynamic_patch
+ if use_thumbnail:
+ max_dynamic_patch += 1
+ downsample_ratio = hf_config.downsample_ratio
+
image_size = vision_config.image_size
patch_size = vision_config.patch_size
- downsample_ratio = hf_config.downsample_ratio
num_patches = get_internvl_num_patches(image_size, patch_size,
downsample_ratio)
- return num_patches * 7
+ return num_patches * max_dynamic_patch
def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
@@ -176,21 +177,27 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
hf_config = ctx.get_hf_config(PretrainedConfig)
vision_config = hf_config.vision_config
+ image_size = vision_config.image_size
+ patch_size = vision_config.patch_size
+ downsample_ratio = hf_config.downsample_ratio
+ num_patches = get_internvl_num_patches(image_size, patch_size,
+ downsample_ratio)
+
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
- num_blocks, _, _ = calculate_num_blocks(width, height)
+ min_num = hf_config.min_dynamic_patch
+ max_num = hf_config.max_dynamic_patch
+ num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
+ max_num, image_size)
+ # add thumbnail image if num_blocks > 1
+ if hf_config.use_thumbnail and num_blocks > 1:
+ num_blocks += 1
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
- image_size = vision_config.image_size
- patch_size = vision_config.patch_size
- downsample_ratio = hf_config.downsample_ratio
- num_patches = get_internvl_num_patches(image_size, patch_size,
- downsample_ratio)
-
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
@@ -198,8 +205,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
- image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
- 1) * num_patches + IMG_END
+ image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
new_prompt = prompt.replace('', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt)
@@ -209,8 +215,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def input_mapper_for_internvl(ctx: InputContext, data: object):
+ hf_config = ctx.get_hf_config(PretrainedConfig)
+
+ use_thumbnail = hf_config.use_thumbnail
+ min_num = hf_config.min_dynamic_patch
+ max_num = hf_config.max_dynamic_patch
+ image_size = hf_config.vision_config.image_size
+
if isinstance(data, Image.Image):
- data = image_to_pixel_values(data)
+ data = image_to_pixel_values(data,
+ image_size,
+ min_num,
+ max_num,
+ use_thumbnail=use_thumbnail)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
@@ -240,10 +257,17 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
add_special_tokens=False)[0],
image_feature_size_override=image_feature_size,
)
+
+ image_size = vision_config.image_size
+ min_num = hf_config.min_dynamic_patch
+ max_num = hf_config.max_dynamic_patch
+ max_image_width = max_num * image_size
+ max_image_height = min_num * image_size
+
mm_data = dummy_image_for_clip(
vision_config,
- image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
- image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+ image_width_override=max_image_width,
+ image_height_override=max_image_height,
)
return seq_data, mm_data
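# Editor's note: an illustrative sketch (not part of the diff) of the InternVL
# token-budget arithmetic made config-driven above. The numbers below
# (448px tiles, 256 context tokens per tile, max_dynamic_patch=6,
# use_thumbnail=True) are assumed example values in the style of the InternVL2
# configs, not read from any particular checkpoint.

num_patches_per_tile = 256   # assumed result of get_internvl_num_patches(...) for one 448px tile
max_dynamic_patch = 6        # hf_config.max_dynamic_patch
use_thumbnail = True         # hf_config.use_thumbnail

# get_max_internvl_image_tokens: the thumbnail, when enabled, adds one extra tile.
max_tiles = max_dynamic_patch + (1 if use_thumbnail else 0)
max_image_tokens = num_patches_per_tile * max_tiles
print(max_image_tokens)      # 1792 -- equal to the old hard-coded `num_patches * 7`
                             # for these defaults, but now it follows the model config.

# input_processor_for_internvl expands a concrete image to
#   IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
# where num_blocks comes from calculate_num_blocks(width, height, min_num,
# max_num, image_size) and likewise gains +1 when use_thumbnail is set and
# more than one block is used.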
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
new file mode 100644
index 0000000000000..10239843b3222
--- /dev/null
+++ b/vllm/model_executor/parameter.py
@@ -0,0 +1,277 @@
+from typing import Callable, Optional, Union
+
+import torch
+from torch.nn import Parameter
+
+from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.logger import init_logger
+
+__all__ = [
+ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
+ "ModelWeightParameter", "ChannelQuantScaleParameter",
+ "GroupQuantScaleParameter"
+]
+
+logger = init_logger(__name__)
+
+
+class BasevLLMParameter(Parameter):
+ """
+    Base parameter for vLLM linear layers. Extends torch.nn.Parameter
+    by taking in a linear weight loader. The loaded weight is copied
+    into the parameter when the provided weight loader is called.
+ """
+
+ def __new__(cls, data: torch.Tensor, **kwargs):
+
+ return super().__new__(cls, data=data, requires_grad=False)
+
+ def __init__(self, data: torch.Tensor, weight_loader: Callable):
+ """
+ Initialize the BasevLLMParameter
+
+ :param data: torch tensor with the parameter data
+ :param weight_loader: weight loader callable
+
+        :returns: a torch.nn.Parameter
+ """
+
+ self._weight_loader = weight_loader
+
+ @property
+ def weight_loader(self):
+ return self._weight_loader
+
+ def _assert_and_load(self, loaded_weight: torch.Tensor):
+ assert self.data.shape == loaded_weight.shape
+ self.data.copy_(loaded_weight)
+
+ def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
+ self._assert_and_load(loaded_weight)
+
+ def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
+ self._assert_and_load(loaded_weight)
+
+ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+ self._assert_and_load(loaded_weight)
+
+ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+ self._assert_and_load(loaded_weight)
+
+
+class _ColumnvLLMParameter(BasevLLMParameter):
+ """
+ Private class defining weight loading functionality
+ (load_merged_column_weight, load_qkv_weight)
+ for parameters being loaded into linear layers with column
+ parallelism. This includes QKV and MLP layers which are
+ not already fused on disk. Requires an output dimension
+ to be defined. Called within the weight loader of
+ each of the column parallel linear layers.
+ """
+
+ def __init__(self, output_dim: int, **kwargs):
+ self._output_dim = output_dim
+ super().__init__(**kwargs)
+
+ @property
+ def output_dim(self):
+ return self._output_dim
+
+ def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
+ tp_rank = get_tensor_model_parallel_rank()
+ shard_size = self.data.shape[self.output_dim]
+ loaded_weight = loaded_weight.narrow(self.output_dim,
+ tp_rank * shard_size, shard_size)
+ assert self.data.shape == loaded_weight.shape
+ self.data.copy_(loaded_weight)
+
+ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+
+ shard_offset = kwargs.get("shard_offset")
+ shard_size = kwargs.get("shard_size")
+ if isinstance(
+ self,
+ PackedvLLMParameter) and self.packed_dim == self.output_dim:
+ shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+ shard_offset=shard_offset, shard_size=shard_size)
+
+ param_data = self.data
+
+ tp_rank = get_tensor_model_parallel_rank()
+ param_data = param_data.narrow(self.output_dim, shard_offset,
+ shard_size)
+ loaded_weight = loaded_weight.narrow(self.output_dim,
+ tp_rank * shard_size, shard_size)
+ assert param_data.shape == loaded_weight.shape
+ param_data.copy_(loaded_weight)
+
+ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+
+ shard_offset = kwargs.get("shard_offset")
+ shard_size = kwargs.get("shard_size")
+ shard_id = kwargs.get("shard_id")
+ num_heads = kwargs.get("num_heads")
+
+ if isinstance(
+ self,
+ PackedvLLMParameter) and self.output_dim == self.packed_dim:
+ shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+ shard_offset=shard_offset, shard_size=shard_size)
+
+ param_data = self.data
+ tp_rank = get_tensor_model_parallel_rank()
+ shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+ param_data = param_data.narrow(self.output_dim, shard_offset,
+ shard_size)
+ loaded_weight = loaded_weight.narrow(self.output_dim,
+ shard_id * shard_size, shard_size)
+
+ assert param_data.shape == loaded_weight.shape
+ param_data.copy_(loaded_weight)
+
+
+class ModelWeightParameter(_ColumnvLLMParameter):
+ """
+ Parameter class for linear layer weights. Extends the
+ _ColumnvLLMParameter by adding loading functionality
+ for linear layers with row parallel functionality.
+ Requires an input dimension to be defined.
+ """
+
+ def __init__(self, input_dim: int, **kwargs):
+ self._input_dim = input_dim
+ super().__init__(**kwargs)
+
+ @property
+ def input_dim(self):
+ return self._input_dim
+
+ def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
+ tp_rank = get_tensor_model_parallel_rank()
+ shard_size = self.data.shape[self.input_dim]
+ loaded_weight = loaded_weight.narrow(self.input_dim,
+ tp_rank * shard_size, shard_size)
+
+ if len(loaded_weight.shape) == 0:
+ loaded_weight = loaded_weight.reshape(1)
+
+ assert self.data.shape == loaded_weight.shape
+ self.data.copy_(loaded_weight)
+
+
+class GroupQuantScaleParameter(ModelWeightParameter):
+ """
+ Parameter class for weight scales loaded for weights with
+ grouped quantization. Equivalent to ModelWeightParameter.
+ """
+ pass
+
+
+class ChannelQuantScaleParameter(_ColumnvLLMParameter):
+ """
+ Parameter class for weight scales loaded for weights with
+ channel-wise quantization. Equivalent to _ColumnvLLMParameter.
+ """
+ pass
+
+
+class PerTensorScaleParameter(BasevLLMParameter):
+ """
+ Parameter class for scales where the number of scales is
+ equivalent to the number of logical matrices in fused linear
+ layers (e.g. for QKV, there are 3 scales loaded from disk).
+ This is relevant to weights with per-tensor quantization.
+    Adds functionality to map the scales to a shard during
+ weight loading.
+
+ Note: additional parameter manipulation may be handled
+ for each quantization config specifically, within
+ process_weights_after_loading
+ """
+
+ def __init__(self, **kwargs):
+ self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
+ super().__init__(**kwargs)
+
+ def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+ if isinstance(shard_id, int):
+ return shard_id
+
+ assert isinstance(shard_id, str)
+ assert shard_id in self.qkv_idxs
+ return self.qkv_idxs[shard_id]
+
+ def load_merged_column_weight(self, *args, **kwargs):
+ self._load_into_shard_id(*args, **kwargs)
+
+ def load_qkv_weight(self, *args, **kwargs):
+ self._load_into_shard_id(*args, **kwargs)
+
+ def load_column_parallel_weight(self, *args, **kwargs):
+ self._load_into_shard_id(*args, **kwargs)
+
+ def _load_into_shard_id(self, loaded_weight: torch.Tensor,
+ shard_id: Union[str, int], **kwargs):
+ """
+ Slice the parameter data based on the shard id for
+ loading.
+ """
+
+ param_data = self.data
+ shard_id = self._shard_id_as_int(shard_id)
+
+ # AutoFP8 scales do not have a shape
+ # compressed-tensors scales do have a shape
+ if len(loaded_weight.shape) != 0:
+ assert loaded_weight.shape[0] == 1
+ loaded_weight = loaded_weight[0]
+
+ param_data = param_data[shard_id]
+ assert param_data.shape == loaded_weight.shape
+ param_data.copy_(loaded_weight)
+
+
+class PackedvLLMParameter(ModelWeightParameter):
+ """
+ Parameter for model weights which are packed on disk.
+ Example: GPTQ Marlin weights are int4 or int8, packed into int32.
+ Extends the ModelWeightParameter to take in the
+ packed factor, the packed dimension, and optionally, marlin
+ tile size for marlin kernels. Adjusts the shard_size and
+ shard_offset for fused linear layers model weight loading
+ by accounting for packing and optionally, marlin tile size.
+ """
+
+ def __init__(self,
+ packed_factor: int,
+ packed_dim: int,
+ marlin_tile_size: Optional[int] = None,
+ **kwargs):
+ self._packed_factor = packed_factor
+ self._packed_dim = packed_dim
+ self._marlin_tile = marlin_tile_size
+ super().__init__(**kwargs)
+
+ @property
+ def packed_dim(self):
+ return self._packed_dim
+
+ @property
+ def packed_factor(self):
+ return self._packed_factor
+
+ @property
+ def marlin_tile(self):
+ return self._marlin_tile
+
+ def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset):
+ return shard_size * self.marlin_tile, shard_offset * self.marlin_tile
+
+ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
+ shard_size = shard_size // self.packed_factor
+ shard_offset = shard_offset // self.packed_factor
+ if self.marlin_tile is not None:
+ return self._adjust_shard_indexes_for_marlin(
+ shard_size, shard_offset)
+ return shard_size, shard_offset
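# Editor's note: a standalone arithmetic sketch (not part of the diff) of how
# PackedvLLMParameter.adjust_shard_indexes_for_packing rescales fused-layer shard
# bounds. All numbers are made up for illustration.

from typing import Optional, Tuple


def adjust_for_packing(shard_size: int, shard_offset: int, packed_factor: int,
                       marlin_tile_size: Optional[int] = None) -> Tuple[int, int]:
    # Packing: e.g. eight int4 values per int32 element -> divide by the packed factor.
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    # Marlin kernels additionally scale the packed dimension by the tile size.
    if marlin_tile_size is not None:
        return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
    return shard_size, shard_offset


# A 4096-wide k_proj shard starting at offset 4096 inside a fused qkv weight,
# stored as int4 packed into int32 (packed_factor=8):
print(adjust_for_packing(4096, 4096, packed_factor=8))                        # (512, 512)
# The same shard for a GPTQ-Marlin layout with a tile size of 16:
print(adjust_for_packing(4096, 4096, packed_factor=8, marlin_tile_size=16))   # (8192, 8192)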
diff --git a/vllm/outputs.py b/vllm/outputs.py
index b1cb1cd07fbb1..040f770814576 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -70,12 +70,20 @@ class RequestOutput:
Args:
request_id: The unique ID of the request.
prompt: The prompt string of the request.
+ For encoder/decoder models, this is the
+ decoder input prompt.
prompt_token_ids: The token IDs of the prompt.
+ For encoder/decoder models, this is the
+ decoder input prompt token ids.
prompt_logprobs: The log probabilities to return per prompt token.
outputs: The output sequences of the request.
finished: Whether the whole request is finished.
metrics: Metrics associated with the request.
lora_request: The LoRA request that was used to generate the output.
+ encoder_prompt: The encoder prompt string of the request;
+ None if decoder-only
+ encoder_prompt_token_ids: The token IDs of the encoder prompt;
+ None if decoder-only
"""
def __init__(
@@ -88,6 +96,8 @@ def __init__(
finished: bool,
metrics: Optional[RequestMetrics] = None,
lora_request: Optional[LoRARequest] = None,
+ encoder_prompt: Optional[str] = None,
+ encoder_prompt_token_ids: Optional[List[int]] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@@ -97,6 +107,8 @@ def __init__(
self.finished = finished
self.metrics = metrics
self.lora_request = lora_request
+ self.encoder_prompt = encoder_prompt
+ self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@@ -137,6 +149,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
+ encoder_prompt = seq_group.encoder_prompt
+ encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished()
finished_time = time.time() if finished else None
@@ -148,12 +162,16 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
outputs,
finished,
seq_group.metrics,
- lora_request=seq_group.lora_request)
+ lora_request=seq_group.lora_request,
+ encoder_prompt=encoder_prompt,
+ encoder_prompt_token_ids=encoder_prompt_token_ids)
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
+ f"encoder_prompt={self.encoder_prompt!r}, "
+ f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"outputs={self.outputs}, "
f"finished={self.finished}, "
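# Editor's note: a hedged usage sketch (not part of the diff) of the new
# encoder_prompt fields on RequestOutput. It assumes that LLM.generate() accepts
# the ExplicitEncoderDecoderPrompt dict introduced elsewhere in this change and
# that "facebook/bart-large-cnn" loads as an encoder/decoder model; enforce_eager
# is set because CUDA graphs are not yet supported for encoder/decoder models.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn", enforce_eager=True)
outputs = llm.generate(
    {"encoder_prompt": "An article to summarize ...", "decoder_prompt": ""},
    SamplingParams(temperature=0.0, max_tokens=32),
)

for out in outputs:
    # For decoder-only models both encoder_* fields are simply None.
    print("encoder prompt:", out.encoder_prompt)
    print("decoder prompt:", out.prompt)
    print("generated:     ", out.outputs[0].text)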
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 02ba227460e3f..a7e760cc16408 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -4,12 +4,21 @@
import os
from functools import lru_cache, wraps
-from typing import Tuple
+from typing import List, Tuple
import pynvml
+from vllm.logger import init_logger
+
from .interface import Platform, PlatformEnum
+logger = init_logger(__name__)
+
+# NVML utils
+# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
+# all the related functions work on real physical device ids.
+# the major benefit of using NVML is that it will not initialize CUDA
+
def with_nvml_context(fn):
@@ -47,3 +56,29 @@ class CudaPlatform(Platform):
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_capability(physical_device_id)
+
+ @staticmethod
+ @with_nvml_context
+ def is_full_nvlink(physical_device_ids: List[int]) -> bool:
+ """
+        Query whether the given GPUs are fully connected by NVLink (1 hop).
+ """
+ handles = [
+ pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
+ ]
+ for i, handle in enumerate(handles):
+ for j, peer_handle in enumerate(handles):
+ if i < j:
+ try:
+ p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+ handle, peer_handle,
+ pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+ if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+ return False
+ except pynvml.NVMLError as error:
+ logger.error(
+ "NVLink detection failed. This is normal if your"
+ " machine has no NVLink equipped.",
+ exc_info=error)
+ return False
+ return True
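# Editor's note: a small usage sketch (not part of the diff) for the NVLink probe
# relocated onto CudaPlatform. It assumes an NVIDIA host with pynvml available;
# on machines without NVLink the helper logs the NVML error and returns False.

from vllm.platforms.cuda import CudaPlatform

# NVML is unaffected by CUDA_VISIBLE_DEVICES, so these are physical device ids.
physical_ids = [0, 1]
if CudaPlatform.is_full_nvlink(physical_ids):
    print("GPUs 0 and 1 are fully connected via NVLink (1 hop)")
else:
    print("GPUs 0 and 1 are not fully NVLink-connected; P2P may fall back to PCIe")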
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 2598325439ebf..1c1e5f16b5172 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -224,6 +224,9 @@ def _verify_args(self) -> None:
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, "
f"got {self.top_k}.")
+ if not isinstance(self.top_k, int):
+ raise TypeError(
+ f"top_k must be an integer, got {type(self.top_k).__name__}")
if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0, 1], got "
f"{self.min_p}.")
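# Editor's note: an illustrative sketch (not part of the diff) of the added top_k
# type check, assuming SamplingParams runs _verify_args at construction time as it
# does for the existing validators.

from vllm import SamplingParams

SamplingParams(top_k=-1)   # OK: -1 disables top-k sampling
SamplingParams(top_k=40)   # OK

try:
    SamplingParams(top_k=0)     # rejected by the existing range check
except ValueError as e:
    print("ValueError:", e)

try:
    SamplingParams(top_k=40.0)  # now rejected by the new isinstance check
except TypeError as e:
    print("TypeError:", e)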
diff --git a/vllm/scripts.py b/vllm/scripts.py
index 403b22239aed0..f45bfe06047de 100644
--- a/vllm/scripts.py
+++ b/vllm/scripts.py
@@ -14,7 +14,7 @@
from vllm.utils import FlexibleArgumentParser
-def registrer_signal_handlers():
+def register_signal_handlers():
def signal_handler(sig, frame):
sys.exit(0)
@@ -31,7 +31,7 @@ def serve(args: argparse.Namespace) -> None:
def interactive_cli(args: argparse.Namespace) -> None:
- registrer_signal_handlers()
+ register_signal_handlers()
base_url = args.url
api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 7ef9387c611f8..6347855333822 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -7,10 +7,11 @@
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
- Union)
+ Union, cast)
import torch
+from vllm.inputs import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -244,24 +245,38 @@ def __repr__(self) -> str:
class Sequence:
"""Stores the data, status, and block information of a sequence.
+ The sequence is constructed from the LLMInputs instance passed
+ in through the `inputs` constructor argument.
+
+ For encoder/decoder models, LLMInputs encapsulates both a
+ decoder and encoder prompt, creating an ambiguity about which
+ prompt to construct the sequence from. The `from_decoder_prompt`
+ constructor argument signals whether to construct the Sequence
+ from the LLMInputs decoder prompt, or encoder prompt.
+
Args:
seq_id: The ID of the sequence.
inputs: The inputs of the sequence.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
+ eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
+        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
+                             (True) or encoder prompt (False). Must be True
+                             for decoder-only models.
"""
def __init__(
- self,
- seq_id: int,
- inputs: "LLMInputs",
- block_size: int,
- eos_token_id: Optional[int] = None,
- lora_request: Optional[LoRARequest] = None,
- prompt_adapter_request: Optional[PromptAdapterRequest] = None
+ self,
+ seq_id: int,
+ inputs: "LLMInputs",
+ block_size: int,
+ eos_token_id: Optional[int] = None,
+ lora_request: Optional[LoRARequest] = None,
+ prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+ from_decoder_prompt: bool = True,
) -> None:
self.seq_id = seq_id
self.inputs = inputs
@@ -269,6 +284,36 @@ def __init__(
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
+ self.from_decoder_prompt = from_decoder_prompt
+ self._prompt: Optional[str] = None
+ self._prompt_token_ids: Optional[List[int]] = None
+
+ # For decoder-only models, a Sequence is constructed
+ # from an LLMInputs instance (the `inputs` arg.)
+ #
+ # For encoder/decoder models the same `inputs`
+ # instance could be utilized to construct either an
+ # encoder sequence or a decoder sequence, because
+ # `LLMInputs` has both decoder- and encoder-oriented
+ # member variables (i.e. it encapsulates both an encoder
+ # and a decoder prompt.) The decision of which type of sequence
+ # to generate is determined by the `from_decoder_prompt` argument.
+ #
+        # When constructing an encoder sequence
+ # (`from_decoder_prompt` False) it matters that
+ # the `LLMInputs` instance stored in `inputs` is valid
+ # in the sense that its encoder-related member variables are
+ # populated; below, an exception is raised if this is
+ # not the case.
+ #
+ # When constructing a decoder sequence (`from_decoder_prompt` True)
+ # it does not matter whether `inputs` has its encoder-related
+ # member variables populated.
+ if not (from_decoder_prompt
+ or is_valid_encoder_decoder_llm_inputs(inputs)):
+ raise ValueError("Cannot extract encoder input prompt from "
+ f"invalid input {inputs}; did you forget the "
+ "encoder input prompt fields?")
self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
@@ -289,11 +334,35 @@ def n_blocks(self) -> int:
@property
def prompt(self) -> Optional[str]:
- return self.inputs.get("prompt")
+ if self._prompt is not None:
+ # Reuse precomputed prompt string
+ return self._prompt
+
+ # Select decoder or encoder input prompt str,
+ # as appropriate
+ prompt_key: str = ("prompt"
+ if self.from_decoder_prompt else "encoder_prompt")
+
+ # Cache prompt
+ self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
+ return self._prompt
@property
def prompt_token_ids(self) -> List[int]:
- return self.inputs["prompt_token_ids"]
+ if self._prompt_token_ids is not None:
+ # Reuse precomputed prompt token ids
+ return self._prompt_token_ids
+
+ # Select decoder or encoder input prompt
+ # token ids, as appropriate
+ prompt_token_ids_key: str = ("prompt_token_ids"
+ if self.from_decoder_prompt else
+ "encoder_prompt_token_ids")
+
+ # Cache computed prompt token ids
+ self._prompt_token_ids = cast(List[int],
+ self.inputs.get(prompt_token_ids_key))
+ return self._prompt_token_ids
@property
def multi_modal_data(self) -> "MultiModalDataDict":
@@ -472,6 +541,22 @@ def prompt_token_ids(self) -> List[int]:
# We use the prompt of an arbitrary sequence.
return self.seqs[0].prompt_token_ids
+ @property
+ def encoder_prompt(self) -> Optional[str]:
+ # There are either 0 or 1 encoder sequences
+ # If one is present, its prompt is distinct
+ # from the decoder's.
+ return (self.encoder_seq.prompt
+ if self.encoder_seq is not None else None)
+
+ @property
+ def encoder_prompt_token_ids(self) -> Optional[List[int]]:
+ # There are either 0 or 1 encoder sequences
+ # If one is present, its prompt token ids are
+ # distinct from the decoder's.
+ return (self.encoder_seq.prompt_token_ids
+ if self.encoder_seq is not None else None)
+
@property
def multi_modal_data(self) -> "MultiModalDataDict":
# All sequences in the group should have the same multi-modal data.
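# Editor's note: a standalone sketch (not part of the diff) of the prompt-selection
# rule added to Sequence. The plain dict below stands in for an encoder/decoder
# LLMInputs instance and the token ids are arbitrary example values; the real class
# reads exactly these keys.

from typing import Dict, List, Optional, Union

inputs: Dict[str, Union[str, List[int]]] = {
    "prompt": "<decoder prompt>",
    "prompt_token_ids": [2, 0],
    "encoder_prompt": "<encoder prompt>",
    "encoder_prompt_token_ids": [0, 713, 16, 41, 1566, 2],
}


def select_prompt(inputs: dict, from_decoder_prompt: bool) -> Optional[str]:
    # Mirrors Sequence.prompt: decoder sequences read "prompt",
    # encoder sequences read "encoder_prompt".
    key = "prompt" if from_decoder_prompt else "encoder_prompt"
    return inputs.get(key)


def select_prompt_token_ids(inputs: dict,
                            from_decoder_prompt: bool) -> Optional[List[int]]:
    # Mirrors Sequence.prompt_token_ids with the corresponding *_token_ids keys.
    key = ("prompt_token_ids"
           if from_decoder_prompt else "encoder_prompt_token_ids")
    return inputs.get(key)


print(select_prompt(inputs, from_decoder_prompt=True))    # decoder prompt
print(select_prompt(inputs, from_decoder_prompt=False))   # encoder prompt
print(select_prompt_token_ids(inputs, from_decoder_prompt=False))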
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index bf26d889d1388..25e4c41592c68 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -12,12 +12,12 @@
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async
+from .tokenizer_group import AnyTokenizer
+
logger = init_logger(__name__)
-def get_cached_tokenizer(
- tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties.
This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs,
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+) -> AnyTokenizer:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
if VLLM_USE_MODELSCOPE:
diff --git a/vllm/utils.py b/vllm/utils.py
index 51bd72977a226..1fd395c04ca24 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1,5 +1,6 @@
import argparse
import asyncio
+import contextlib
import datetime
import enum
import gc
@@ -11,12 +12,14 @@
import threading
import uuid
import warnings
+from asyncio import FIRST_COMPLETED, ensure_future
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
-from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
+from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union, overload)
+from uuid import uuid4
import numpy as np
import numpy.typing as npt
@@ -27,10 +30,93 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
+ SingletonPromptInputs)
from vllm.logger import enable_trace_function_call, init_logger
logger = init_logger(__name__)
+# Exception strings for non-implemented encoder/decoder scenarios
+
+STR_NOT_IMPL_ENC_DEC_SWA = \
+ "Sliding window attention for encoder/decoder models " + \
+ "is not currently supported."
+
+STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
+ "Prefix caching for encoder/decoder models " + \
+ "is not currently supported."
+
+STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
+ "Chunked prefill for encoder/decoder models " + \
+ "is not currently supported."
+
+STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
+ "Models with logits_soft_cap "
+ "require FlashInfer backend, which is "
+ "currently not supported for encoder/decoder "
+ "models.")
+
+STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
+                             "supported with encoder/decoder "
+                             "models.")
+
+STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
+ "currently supported with "
+ "encoder/decoder models.")
+
+STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
+ "supported with encoder/decoder "
+ "models.")
+
+STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
+ "currently supported with encoder/"
+ "decoder models.")
+
+STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
+ "currently supported with encoder/"
+ "decoder models.")
+
+STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
+ "currently supported with encoder/"
+ "decoder models.")
+
+STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
+ "currently supported with encoder/"
+ "decoder models.")
+
+# Collect all of the enc/dec error strings in one dict so that callers
+# can import a single name rather than each of the strings above
+STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
+ "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
+ "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
+ "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
+ STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
+ "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
+ "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
+ "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
+ "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
+ "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
+ "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
+ "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
+ "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
+}
+
+# Constants related to forcing the attention backend selection
+
+# Name of the environment variable which may be set in order to
+# force a specific attention backend, overriding auto-selection
+# by the Attention wrapper
+STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
+
+# Possible values of the STR_BACKEND_ENV_VAR environment variable,
+# corresponding to the supported attention backends
+STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
+STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
+STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
+STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
+STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
+STR_INVALID_VAL: str = "INVALID"
+
STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
@@ -290,63 +376,74 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
return _async_wrapper
-class ProducerFinished:
- pass
+async def iterate_with_cancellation(
+ iterator: AsyncGenerator[T, None],
+ is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[T, None]:
+    """Convert an async iterator into one that polls the provided function
+ at least once per second to check for client cancellation.
+ """
+ # Can use anext() in python >= 3.10
+ awaits = [ensure_future(iterator.__anext__())]
+ while True:
+ done, pending = await asyncio.wait(awaits, timeout=1)
+ if await is_cancelled():
+ with contextlib.suppress(BaseException):
+ awaits[0].cancel()
+ await iterator.aclose()
+ raise asyncio.CancelledError("client cancelled")
+ if done:
+ try:
+ item = await awaits[0]
+ awaits[0] = ensure_future(iterator.__anext__())
+ yield item
+ except StopAsyncIteration:
+ # we are done
+ return
-def merge_async_iterators(
- *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
+
+async def merge_async_iterators(
+ *iterators: AsyncGenerator[T, None],
+ is_cancelled: Callable[[], Awaitable[bool]],
+) -> AsyncGenerator[Tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
- """
- queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
- Exception]] = asyncio.Queue()
-
- producers = len(iterators)
-
- async def producer(i: int, iterator: AsyncIterator[T]):
- try:
- async for item in iterator:
- await queue.put((i, item))
- except Exception as e:
- await queue.put(e)
- # Signal to the consumer that we've finished
- await queue.put(ProducerFinished())
-
- _tasks = [
- asyncio.create_task(producer(i, iterator))
- for i, iterator in enumerate(iterators)
- ]
-
- async def consumer():
- remaining = producers
- try:
- while remaining or not queue.empty():
- # we think there is a race condition here
- item = await queue.get()
-
- if isinstance(item, ProducerFinished):
- # Signal that a producer finished- not a real item
- remaining -= 1
- continue
- if isinstance(item, Exception):
- raise item
- yield item
- except (Exception, asyncio.CancelledError) as e:
- for task in _tasks:
- if sys.version_info >= (3, 9):
- # msg parameter only supported in Python 3.9+
- task.cancel(e)
- else:
- task.cancel()
- raise e
- await asyncio.gather(*_tasks)
+ It also polls the provided function at least once per second to check
+ for client cancellation.
+ """
- return consumer()
+ # Can use anext() in python >= 3.10
+ awaits = {
+ ensure_future(pair[1].__anext__()): pair
+ for pair in enumerate(iterators)
+ }
+ try:
+ while awaits:
+ done, pending = await asyncio.wait(awaits.keys(),
+ return_when=FIRST_COMPLETED,
+ timeout=1)
+ if await is_cancelled():
+ raise asyncio.CancelledError("client cancelled")
+ for d in done:
+ pair = awaits.pop(d)
+ try:
+ item = await d
+ i, it = pair
+ awaits[ensure_future(it.__anext__())] = pair
+ yield i, item
+ except StopAsyncIteration:
+ pass
+ finally:
+ # Cancel any remaining iterators
+ for f, (_, it) in awaits.items():
+ with contextlib.suppress(BaseException):
+ f.cancel()
+ await it.aclose()
def get_ip() -> str:
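# Editor's note: a usage sketch (not part of the diff) for the rewritten
# merge_async_iterators. The toy generators and the is_cancelled callable stand in
# for per-request output streams and the HTTP disconnect check used by the server.

import asyncio

from vllm.utils import merge_async_iterators


async def numbers(prefix: str, n: int):
    for i in range(n):
        await asyncio.sleep(0.01)
        yield f"{prefix}{i}"


async def never_cancelled() -> bool:
    return False


async def main():
    merged = merge_async_iterators(numbers("a", 3), numbers("b", 2),
                                   is_cancelled=never_cancelled)
    # Items arrive as (iterator_index, item) pairs, interleaved as produced.
    async for idx, item in merged:
        print(idx, item)


asyncio.run(main())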
@@ -388,10 +485,13 @@ def get_distributed_init_method(ip: str, port: int) -> str:
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
-def get_open_port(port: Optional[int] = None) -> int:
- if port is None:
- # Default behavior here is to return a port for multi-gpu communication
- port = envs.VLLM_PORT
+def get_open_zmq_ipc_path() -> str:
+ base_rpc_path = envs.VLLM_RPC_BASE_PATH
+ return f"ipc://{base_rpc_path}/{uuid4()}"
+
+
+def get_open_port() -> int:
+ port = envs.VLLM_PORT
if port is not None:
while True:
try:
@@ -938,56 +1038,6 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
-# NVML utils
-# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
-# all the related functions work on real physical device ids.
-# the major benefit of using NVML is that it will not initialize CUDA
-
-try:
- import pynvml
-except ImportError:
- # For non-NV devices
- pynvml = None
-
-
-def with_nvml_context(fn):
-
- @wraps(fn)
- def wrapper(*args, **kwargs):
- if pynvml is not None:
- pynvml.nvmlInit()
- try:
- return fn(*args, **kwargs)
- finally:
- if pynvml is not None:
- pynvml.nvmlShutdown()
-
- return wrapper
-
-
-@with_nvml_context
-def is_full_nvlink(device_ids: List[int]) -> bool:
- """
- query if the set of gpus are fully connected by nvlink (1 hop)
- """
- handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
- for i, handle in enumerate(handles):
- for j, peer_handle in enumerate(handles):
- if i < j:
- try:
- p2p_status = pynvml.nvmlDeviceGetP2PStatus(
- handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
- if p2p_status != pynvml.NVML_P2P_STATUS_OK:
- return False
- except pynvml.NVMLError as error:
- logger.error(
- "NVLink detection failed. This is normal if your"
- " machine has no NVLink equipped.",
- exc_info=error)
- return False
- return True
-
-
#From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):
@@ -1029,3 +1079,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)
+
+
+def is_encoder_decoder_model_config(model_config) -> bool:
+ '''
+ Extract the HF encoder/decoder model flag from the ModelConfig instance.
+ Return False if model_config is None.
+ '''
+ return model_config is not None and \
+ getattr(model_config.hf_config,
+ "is_encoder_decoder",
+ False)
+
+
+def is_embedding_model_config(model_config) -> bool:
+ '''
+ Extract the embedding model flag from the ModelConfig instance.
+ Return False if model_config is None.
+ '''
+ return model_config is not None and \
+ model_config.embedding_mode
+
+
+def build_explicit_enc_dec_prompt(
+ encoder_prompt: SingletonPromptInputs,
+ decoder_prompt: SingletonPromptInputs,
+) -> ExplicitEncoderDecoderPrompt:
+ return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
+ decoder_prompt=decoder_prompt)
+
+
+def zip_enc_dec_prompt_lists(
+ enc_prompt_list: List[SingletonPromptInputs],
+ dec_prompt_list: List[SingletonPromptInputs],
+) -> List[ExplicitEncoderDecoderPrompt]:
+ return [
+ build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
+ for (encoder_prompt,
+ decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
+ ]
+
+
+def to_enc_dec_tuple_list(
+ enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
+) -> List[Tuple[PromptInputs, PromptInputs]]:
+ return [(enc_dec_prompt['encoder_prompt'],
+ enc_dec_prompt['decoder_prompt'])
+ for enc_dec_prompt in enc_dec_prompts]
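# Editor's note: a short sketch (not part of the diff) of the new encoder/decoder
# prompt helpers added at the end of vllm/utils.py. The prompt strings are arbitrary.

from vllm.utils import (build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
                        zip_enc_dec_prompt_lists)

single = build_explicit_enc_dec_prompt(
    encoder_prompt="The tower is 324 metres tall.",
    decoder_prompt="",
)
# {'encoder_prompt': 'The tower is 324 metres tall.', 'decoder_prompt': ''}
print(single)

batch = zip_enc_dec_prompt_lists(
    ["first encoder prompt", "second encoder prompt"],
    ["", ""],
)
# Back to a list of (encoder_prompt, decoder_prompt) tuples, e.g. for test helpers.
print(to_enc_dec_tuple_list(batch))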
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
new file mode 100644
index 0000000000000..d9b323f2af09e
--- /dev/null
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -0,0 +1,472 @@
+import dataclasses
+from typing import Any, Dict, List, Optional, Tuple, Type, cast
+
+import torch
+import torch.distributed
+
+from vllm.attention.backends.abstract import (AttentionBackend,
+ AttentionMetadata)
+from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
+ get_global_forced_attn_backend,
+ global_force_attn_backend)
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+ ModelConfig, MultiModalConfig, ParallelConfig,
+ PromptAdapterConfig, SchedulerConfig)
+from vllm.inputs import INPUT_REGISTRY
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
+ SequenceGroupMetadata)
+from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
+from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase,
+ ModelInputForGPUBuilder,
+ ModelInputForGPUWithSamplingMetadata)
+from vllm.worker.model_runner_base import (
+ _add_attn_metadata_broadcastable_dict,
+ _add_sampling_metadata_broadcastable_dict)
+from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass(frozen=True)
+class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
+ """
+ Used by the EncoderDecoderModelRunner.
+ """
+ encoder_input_tokens: Optional[torch.Tensor] = None
+ encoder_input_positions: Optional[torch.Tensor] = None
+
+ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+ tensor_dict = {
+ "input_tokens": self.input_tokens,
+ "input_positions": self.input_positions,
+ "encoder_input_tokens": self.encoder_input_tokens,
+ "encoder_input_positions": self.encoder_input_positions,
+ "virtual_engine": self.virtual_engine,
+ "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
+ "finished_requests_ids": self.finished_requests_ids,
+ }
+ _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+ _add_sampling_metadata_broadcastable_dict(tensor_dict,
+ self.sampling_metadata)
+ return tensor_dict
+
+ @classmethod
+ def from_broadcasted_tensor_dict(
+ cls,
+ tensor_dict: Dict[str, Any],
+ attn_backend: Optional["AttentionBackend"] = None,
+ ) -> "EncoderDecoderModelInput":
+ return cast(
+ EncoderDecoderModelInput,
+ super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
+
+
+class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
+ _model_input_cls: Type[EncoderDecoderModelInput] = (
+ EncoderDecoderModelInput)
+ _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
+
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig,
+ device_config: DeviceConfig,
+ cache_config: CacheConfig,
+ load_config: LoadConfig,
+ lora_config: Optional[LoRAConfig],
+ kv_cache_dtype: Optional[str] = "auto",
+ is_driver_worker: bool = False,
+ prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+ multimodal_config: Optional[MultiModalConfig] = None,
+ ):
+ '''
+ EncoderDecoderModelRunner constructor.
+
+        `lora_config`, `multimodal_config`, and `prompt_adapter_config` are
+ unused (since these features are not yet supported for encoder/decoder
+ models) but these arguments are present here for compatibility with
+ the base-class constructor.
+ '''
+
+ self._maybe_force_supported_attention_backend()
+
+ super().__init__(
+ model_config,
+ parallel_config,
+ scheduler_config,
+ device_config,
+ cache_config,
+ load_config,
+ lora_config=None,
+ kv_cache_dtype=kv_cache_dtype,
+ is_driver_worker=is_driver_worker,
+ )
+
+        # Raise for unsupported encoder/decoder scenarios
+ assert_enc_dec_mr_supported_scenario(self)
+
+ def _maybe_force_supported_attention_backend(self):
+ '''
+ Force vLLM to use the XFormers attention backend,
+ which is currently the only supported option.
+ '''
+
+ def raise_backend_err():
+ # The user has specified an attention backend override
+ # which is invalid for encoder/decoder models
+ raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
+
+ maybe_env_var_forced_backend = get_env_variable_attn_backend()
+ maybe_global_forced_backend = get_global_forced_attn_backend()
+ is_forced_by_global = maybe_global_forced_backend is not None
+ is_forced_by_env_var = maybe_env_var_forced_backend is not None
+
+ if not (is_forced_by_global or is_forced_by_env_var):
+ # The user has not already specified an attention backend
+ # override
+ logger.info("EncoderDecoderModelRunner requires "
+ "XFormers backend; overriding backend "
+ "auto-selection and forcing XFormers.")
+ global_force_attn_backend(_Backend.XFORMERS)
+ elif is_forced_by_global:
+ # Backend override enforced by global variable takes
+ # precedence over vLLM backend environment variable.
+ if maybe_global_forced_backend != _Backend.XFORMERS:
+ raise_backend_err()
+ elif is_forced_by_env_var:
+ # Backend override enforced by vLLM backend
+ # environment variable
+ if maybe_env_var_forced_backend != _Backend.XFORMERS:
+ raise_backend_err()
+
+ def _list_to_int32_tensor(
+ self,
+ _list: List[int],
+ ) -> torch.Tensor:
+ return torch.tensor(_list, dtype=torch.int32, device=self.device)
+
+ def _list_to_long_tensor(
+ self,
+ _list: List[int],
+ ) -> torch.Tensor:
+ return torch.tensor(_list, dtype=torch.long, device=self.device)
+
+ def _empty_int32_tensor(self) -> torch.Tensor:
+ return self._list_to_int32_tensor([])
+
+ def _empty_long_tensor(self) -> torch.Tensor:
+ return self._list_to_long_tensor([])
+
+ @torch.inference_mode()
+ def execute_model(
+ self,
+ model_input: EncoderDecoderModelInput,
+ kv_caches: List[torch.Tensor],
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ num_steps: int = 1,
+ ) -> Optional[List[PoolerOutput]]:
+ if num_steps > 1:
+ raise ValueError("num_steps > 1 is not supported in "
+ "EncoderDecoderModelRunner")
+
+ model_executable = self.model
+
+ seqlen_agnostic_kwargs = {
+ "finished_requests_ids": model_input.finished_requests_ids,
+ "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
+ } if self.has_seqlen_agnostic else {}
+ hidden_or_intermediate_states = model_executable(
+ input_ids=model_input.input_tokens,
+ positions=model_input.input_positions,
+ encoder_input_ids=model_input.encoder_input_tokens,
+ encoder_positions=model_input.encoder_input_positions,
+ kv_caches=kv_caches,
+ attn_metadata=model_input.attn_metadata,
+ intermediate_tensors=intermediate_tensors,
+ **seqlen_agnostic_kwargs)
+
+ logits = self.model.compute_logits(hidden_or_intermediate_states,
+ model_input.sampling_metadata)
+
+ if not self.is_driver_worker:
+ return []
+
+ # Sample the next token.
+ output: SamplerOutput = self.model.sample(
+ logits=logits,
+ sampling_metadata=model_input.sampling_metadata,
+ )
+
+ return [output]
+
+ def make_model_input_from_broadcasted_tensor_dict(
+ self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
+ return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
+ tensor_dict,
+ attn_backend=self.attn_backend,
+ )
+
+ def prepare_model_input(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ virtual_engine: int = 0,
+ finished_requests_ids: Optional[List[str]] = None
+ ) -> EncoderDecoderModelInput:
+ """Prepare the model input based on a given sequence group, including
+ metadata for the sampling step.
+
+ Since chunked prefill is not supported for encoder/decoder models,
+ `input_tokens` is assumed to be either entirely prefill tokens or
+ entirely decode tokens.
+
+ """
+ model_input = self._prepare_model_input_tensors(
+ seq_group_metadata_list, finished_requests_ids)
+
+ (
+ attn_metadata,
+ encoder_input_tokens_tensor,
+ encoder_input_positions_tensor,
+ ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
+ model_input))
+
+ # Inject attn_metadata encoder/cross-attention fields &
+ # encoder input tokens/positions into model_input.
+ # Frozen dataclass fields cannot be modified, so use
+ # dataclasses.replace to construct a new model input
+ # instance.
+ model_input = dataclasses.replace(
+ model_input,
+ attn_metadata=attn_metadata,
+ encoder_input_tokens=encoder_input_tokens_tensor,
+ encoder_input_positions=encoder_input_positions_tensor,
+ )
+
+ sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
+ model_input.seq_lens,
+ model_input.query_lens,
+ self.device,
+ self.pin_memory)
+ is_prompt = (seq_group_metadata_list[0].is_prompt
+ if seq_group_metadata_list else None)
+ return dataclasses.replace(model_input,
+ sampling_metadata=sampling_metadata,
+ is_prompt=is_prompt,
+ virtual_engine=virtual_engine)
+
+ @torch.inference_mode()
+ def profile_run(self) -> None:
+ # Enable top-k sampling to reflect the accurate memory usage.
+ sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
+ max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+ max_num_seqs = self.scheduler_config.max_num_seqs
+
+        # Profile memory usage with max_num_seqs sequences and the total
+ # number of tokens equal to max_num_batched_tokens.
+ seqs: List[SequenceGroupMetadata] = []
+
+ model_config = self.model_config
+
+ batch_size = 0
+ for group_id in range(max_num_seqs):
+ seq_len = (max_num_batched_tokens // max_num_seqs +
+ (group_id < max_num_batched_tokens % max_num_seqs))
+ batch_size += seq_len
+
+ seq_data, _ = INPUT_REGISTRY \
+ .dummy_data_for_profiling(model_config, seq_len)
+
+ # Having more tokens is over-conservative but otherwise fine
+ assert len(seq_data.prompt_token_ids) >= seq_len, (
+ f"Expected at least {seq_len} dummy tokens for profiling, "
+ f"but got: {len(seq_data.prompt_token_ids)}")
+
+ seq = SequenceGroupMetadata(
+ request_id=str(group_id),
+ is_prompt=True,
+ seq_data={group_id: seq_data},
+ sampling_params=sampling_params,
+ block_tables=None,
+ encoder_seq_data=seq_data,
+ cross_block_table=None,
+ )
+ seqs.append(seq)
+
+ # Run the model with the dummy inputs.
+ num_layers = self.model_config.get_num_layers(self.parallel_config)
+ kv_caches = [None] * num_layers
+ finished_requests_ids = [seq.request_id for seq in seqs]
+ model_input = self.prepare_model_input(
+ seqs, finished_requests_ids=finished_requests_ids)
+ intermediate_tensors = None
+ self.execute_model(model_input, kv_caches, intermediate_tensors)
+ torch.cuda.synchronize()
+ return
+
+ def _prepare_encoder_model_input_tensors(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ model_input: EncoderDecoderModelInput,
+ ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
+ Optional[torch.Tensor]]:
+ """Helper method to prepare the encoder- and cross-attn-related
+ model inputs based on a given sequence group. These additional inputs
+ are used to augment an already-computed `EncoderDecoderModelInput`
+ data structure which already has decoder-related model inputs
+ populated.
+
+ Sets the following attn_metadata fields:
+ * `num_encoder_tokens`
+ * `encoder_seq_lens`
+ * `encoder_seq_lens_tensor`
+ * `max_encoder_seq_len`
+ * `cross_slot_mapping`
+ * `cross_block_tables`
+
+ Constructs a new model inputs data structure, based on
+        (1) the existing fields in the `model_input` argument,
+ and (2) the following additional fields which are
+ computed (or in the case of `attn_metadata`, updated)
+ by this function:
+ * attn_metadata
+ * encoder_input_tokens
+ * encoder_input_positions
+
+ Arguments:
+
+ * seq_group_metadata_list: list of sequence groups for which to
+ compute inputs
+        * model_input: model input data structure with decoder-oriented
+ fields already computed.
+
+ Return:
+
+        * Updated attention metadata, plus encoder input tokens and
+          positions tensors
+ """
+
+ if len(seq_group_metadata_list) == 0:
+ return (model_input.attn_metadata, None, None)
+
+ # Since we are not supporting chunked prefill either the entire
+ # batch is prefill or it is decode
+ is_prompt = seq_group_metadata_list[0].is_prompt
+
+ # Build encoder inputs
+ encoder_seq_lens: List[int] = []
+ if is_prompt:
+ # Prefill phase.
+ cross_block_tables = self._empty_int32_tensor().view(
+ len(seq_group_metadata_list), -1)
+
+ # Extract input tokens/positions, cross-attention slot-mapping,
+ # & seq len from each sequence group metadata
+ (
+ encoder_input_tokens,
+ encoder_input_positions,
+ cross_slot_mapping,
+ ) = (
+ [],
+ [],
+ [],
+ )
+ for seq_group_metadata in seq_group_metadata_list:
+ # Build seq lens
+ seq_len = seq_group_metadata.encoder_seq_data.get_len()
+ token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
+ encoder_seq_lens.append(seq_len)
+
+ # Build slot mapping
+ is_profile_run = (seq_group_metadata.block_tables is None)
+ if is_profile_run:
+ # During memory profiling, the block tables are not
+ # initialized yet. In this case, we just use a dummy
+ # slot mapping.
+ # In embeddings, the block tables are {seq_id: None}.
+ cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
+ else:
+ for i in range(0, seq_len):
+ block_number = seq_group_metadata.cross_block_table[
+ i // self.block_size]
+ block_offset = i % self.block_size
+ slot = block_number * self.block_size + block_offset
+ cross_slot_mapping.append(slot)
+
+ # Build encoder input tokens
+ encoder_input_tokens.extend(token_ids)
+ encoder_input_positions.extend(list(range(0, seq_len)))
+
+ # Convert tokens/positions & cross-attention
+ # slot-mapping to encoder input tensors
+ encoder_input_tokens_tensor = self._list_to_long_tensor(
+ encoder_input_tokens)
+ encoder_input_positions_tensor = self._list_to_long_tensor(
+ encoder_input_positions)
+ cross_slot_mapping_tensor = self._list_to_long_tensor(
+ cross_slot_mapping)
+
+ else:
+ # Decode phase.
+ encoder_input_tokens_tensor = self._empty_long_tensor()
+ encoder_input_positions_tensor = self._empty_long_tensor()
+ cross_slot_mapping_tensor = self._empty_long_tensor()
+
+ # Extract cross-attention block tables &
+ # seq len from each sequence group metadata.
+ # Cross-attention block tables are empty
+ # during vLLM memory profiling.
+ cross_block_tables = []
+ for seq_group_metadata in seq_group_metadata_list:
+ encoder_seq_lens.append(
+ seq_group_metadata.encoder_seq_data.get_len())
+ cross_block_table = seq_group_metadata.cross_block_table
+ cross_block_tables.append([] if (
+ cross_block_table is None) else cross_block_table)
+
+ # Convert cross-attention block tables to encoder input tensor
+ cross_block_tables = make_tensor_with_pad(
+ cross_block_tables,
+ max_len=max(
+ len(block_table) for block_table in cross_block_tables),
+ pad=0,
+ dtype=torch.int32,
+ device=self.device,
+ )
+
+ # Compute encoder sequence lengths & encoder
+ # sequence starting offset tensors
+ max_encoder_seq_len = max(encoder_seq_lens, default=0)
+ encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
+ encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
+ 1,
+ dtype=torch.int32,
+ device=self.device)
+ torch.cumsum(encoder_seq_lens_tensor,
+ dim=0,
+ dtype=encoder_seq_start_loc.dtype,
+ out=encoder_seq_start_loc[1:])
+
+ # Update attention metadata with encoder-oriented attributes
+ attn_metadata = model_input.attn_metadata
+ assert attn_metadata is not None
+ attn_metadata.num_encoder_tokens = sum(encoder_seq_lens)
+ attn_metadata.encoder_seq_lens = encoder_seq_lens
+ attn_metadata.encoder_seq_lens_tensor = encoder_seq_lens_tensor
+ attn_metadata.max_encoder_seq_len = max_encoder_seq_len
+ attn_metadata.cross_slot_mapping = cross_slot_mapping_tensor
+ attn_metadata.cross_block_tables = cross_block_tables
+
+ return (attn_metadata, encoder_input_tokens_tensor,
+ encoder_input_positions_tensor)
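
The two non-obvious pieces of the hunk above are the cross-attention slot mapping (token index to physical KV-cache slot via the block table) and the `encoder_seq_start_loc` prefix sum. The sketch below reproduces just those two computations in isolation; the helper names `compute_slot_mapping` and `compute_seq_start_loc` are hypothetical stand-ins for illustration only, not functions from the vLLM codebase.

```python
# Minimal sketch of the cross-attention slot-mapping and
# encoder_seq_start_loc computations shown above. The helper names are
# hypothetical and exist only for illustration.
from typing import List

import torch


def compute_slot_mapping(block_table: List[int], seq_len: int,
                         block_size: int) -> List[int]:
    # Token index i lives in logical block i // block_size; the block
    # table maps that logical block to a physical block number, and the
    # slot is the physical block start plus the in-block offset.
    slots = []
    for i in range(seq_len):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slots.append(block_number * block_size + block_offset)
    return slots


def compute_seq_start_loc(seq_lens: List[int]) -> torch.Tensor:
    # Prefix sum of sequence lengths with a leading zero, matching the
    # cumsum-into-[1:] pattern used for encoder_seq_start_loc.
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
    start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    torch.cumsum(seq_lens_tensor, dim=0, dtype=start_loc.dtype,
                 out=start_loc[1:])
    return start_loc


if __name__ == "__main__":
    # A 5-token encoder sequence with block_size=4 spanning physical
    # blocks 7 and 2 -> slots [28, 29, 30, 31, 8].
    print(compute_slot_mapping([7, 2], seq_len=5, block_size=4))
    # Sequence lengths [5, 3] -> start offsets [0, 5, 8].
    print(compute_seq_start_loc([5, 3]))
```
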
diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py
new file mode 100644
index 0000000000000..8df3c8bc5408b
--- /dev/null
+++ b/vllm/worker/utils.py
@@ -0,0 +1,56 @@
+'''
+Worker-related helper functions.
+'''
+
+from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
+from vllm.worker.model_runner import GPUModelRunnerBase
+
+
+def assert_enc_dec_mr_supported_scenario(
+ enc_dec_mr: GPUModelRunnerBase) -> None:
+ '''
+ Assert that the provided encoder/decoder model runner instance reflects
+ a supported scenario.
+ '''
+
+ if enc_dec_mr.cache_config.enable_prefix_caching:
+ raise NotImplementedError(
+ STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE'])
+
+ if enc_dec_mr.sliding_window is not None:
+ raise NotImplementedError(
+ STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA'])
+
+ if enc_dec_mr.scheduler_config.chunked_prefill_enabled:
+ raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
+ 'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL'])
+
+ if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping',
+ None) is not None:
+ raise NotImplementedError(
+ STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
+ )
+
+ if enc_dec_mr.lora_config is not None:
+ raise NotImplementedError(
+ STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])
+
+ if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
+ raise NotImplementedError(
+ STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
+
+ if enc_dec_mr.multimodal_config is not None:
+ raise NotImplementedError(
+ STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
+
+ if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
+ raise NotImplementedError(
+ STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
+
+ if not enc_dec_mr.model_config.enforce_eager:
+ raise NotImplementedError(
+ STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH'])
+
+ if enc_dec_mr.prompt_adapter_config is not None:
+ raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
+ 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])
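
The new `vllm/worker/utils.py` centralizes the unsupported-feature guards behind a single assertion helper keyed into `STR_NOT_IMPL_ENC_DEC_ERR_STRS`. The snippet below is only an illustrative sketch of that pattern; the message table and the `_raise_if_unsupported` helper are hypothetical stand-ins, not the real vLLM strings or API.

```python
# Illustrative sketch of the centralized "not implemented" guard pattern
# used by assert_enc_dec_mr_supported_scenario. The table and helper are
# hypothetical stand-ins; the real messages live in
# vllm.utils.STR_NOT_IMPL_ENC_DEC_ERR_STRS.
_EXAMPLE_ERR_STRS = {
    'STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE':
    'Prefix caching is not currently supported with encoder/decoder models.',
    'STR_NOT_IMPL_ENC_DEC_SWA':
    'Sliding-window attention is not currently supported with '
    'encoder/decoder models.',
}


def _raise_if_unsupported(feature_enabled: bool, err_key: str) -> None:
    # One consistent NotImplementedError per unsupported feature, keyed
    # into the shared error-string table.
    if feature_enabled:
        raise NotImplementedError(_EXAMPLE_ERR_STRS[err_key])


# A disabled feature passes silently; an enabled one fails fast at
# model-runner construction time.
_raise_if_unsupported(False, 'STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE')
try:
    _raise_if_unsupported(True, 'STR_NOT_IMPL_ENC_DEC_SWA')
except NotImplementedError as exc:
    print(exc)
```
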
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 9e2cfff435cf6..ad6f6750ff980 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -19,8 +19,11 @@
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
+from vllm.utils import (is_embedding_model_config,
+ is_encoder_decoder_model_config)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
+from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
@@ -85,8 +88,10 @@ def __init__(
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
- elif self.model_config.embedding_mode:
+ elif self._is_embedding_model():
ModelRunnerClass = EmbeddingModelRunner
+ elif self._is_encoder_decoder_model():
+ ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config,
parallel_config,
@@ -107,6 +112,12 @@ def __init__(
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
+ def _is_encoder_decoder_model(self):
+ return is_encoder_decoder_model_config(self.model_config)
+
+ def _is_embedding_model(self):
+ return is_embedding_model_config(self.model_config)
+
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
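
The `worker.py` hunk above routes model-runner construction through two new predicates. As a rough sketch of that dispatch, under stand-in names (the config stub and runner classes below are illustrative only, not the real vLLM types; the real checks are `_is_embedding_model()` and `_is_encoder_decoder_model()` on the worker):

```python
# Rough sketch of the runner-class selection added to Worker.__init__.
from dataclasses import dataclass
from typing import Type


class _BaseRunner: ...


class _DefaultRunner(_BaseRunner): ...


class _EmbeddingRunner(_BaseRunner): ...


class _EncoderDecoderRunner(_BaseRunner): ...


@dataclass
class _ModelConfigStub:
    embedding_mode: bool = False
    is_encoder_decoder: bool = False


def select_runner_cls(cfg: _ModelConfigStub) -> Type[_BaseRunner]:
    # Mirrors the precedence in the hunk: an explicit model_runner_cls
    # override (not modeled here) wins first, then embedding models,
    # then encoder/decoder models, then the default ModelRunner.
    if cfg.embedding_mode:
        return _EmbeddingRunner
    if cfg.is_encoder_decoder:
        return _EncoderDecoderRunner
    return _DefaultRunner


assert select_runner_cls(_ModelConfigStub()) is _DefaultRunner
assert select_runner_cls(
    _ModelConfigStub(is_encoder_decoder=True)) is _EncoderDecoderRunner
```
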