diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index dd265b22bff54..6540e023c1ab0 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -495,7 +495,7 @@ Text Generation --------------- .. list-table:: - :widths: 25 25 15 25 5 5 + :widths: 25 25 15 20 5 5 5 :header-rows: 1 * - Architecture @@ -504,47 +504,55 @@ Text Generation - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + - V1 * - :code:`AriaForConditionalGeneration` - Aria - T + I - :code:`rhymes-ai/Aria` - - ✅︎ + - * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - T + I\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - - ✅︎ + - * - :code:`ChameleonForConditionalGeneration` - Chameleon - T + I - :code:`facebook/chameleon-7b` etc. - - ✅︎ + - * - :code:`FuyuForCausalLM` - Fuyu - T + I - :code:`adept/fuyu-8b` etc. - - ✅︎ + - * - :code:`ChatGLMModel` - GLM-4V - T + I - :code:`THUDM/glm-4v-9b` etc. - ✅︎ - ✅︎ + - * - :code:`H2OVLChatModel` - H2OVL - T + I\ :sup:`E+` - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. - - ✅︎ + - * - :code:`Idefics3ForConditionalGeneration` - Idefics3 - T + I - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc. - ✅︎ + - - * - :code:`InternVLChatModel` - InternVL 2.5, Mono-InternVL, InternVL 2.0 @@ -552,96 +560,112 @@ Text Generation - :code:`OpenGVLab/InternVL2_5-4B`, :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, etc. - - ✅︎ + - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - T + I\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - - ✅︎ + - ✅︎ * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - T + I\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - - ✅︎ + - * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - T + V - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - - ✅︎ + - * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - T + I\ :sup:`+` + V\ :sup:`+` - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - - ✅︎ + - * - :code:`MiniCPMV` - MiniCPM-V - T + I\ :sup:`E+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - ✅︎ - ✅︎ + - * - :code:`MllamaForConditionalGeneration` - Llama 3.2 - T + I\ :sup:`+` - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - - + - * - :code:`MolmoForCausalLM` - Molmo - T + I - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. - - ✅︎ + - ✅︎ * - :code:`NVLM_D_Model` - NVLM-D 1.0 - T + I\ :sup:`E+` - :code:`nvidia/NVLM-D-72B`, etc. - - ✅︎ + - ✅︎ * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - T + I\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - - ✅︎ + - * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - T + I\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - - ✅︎ + - ✅︎ * - :code:`PixtralForConditionalGeneration` - Pixtral - T + I\ :sup:`+` - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc. - - ✅︎ + - ✅︎ * - :code:`QWenLMHeadModel` - Qwen-VL - T + I\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - ✅︎ - ✅︎ + - * - :code:`Qwen2AudioForConditionalGeneration` - Qwen2-Audio - T + A\ :sup:`+` - :code:`Qwen/Qwen2-Audio-7B-Instruct` - - ✅︎ + - * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - T + I\ :sup:`E+` + V\ :sup:`E+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - ✅︎ - ✅︎ + - * - :code:`UltravoxModel` - Ultravox - T + A\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - - ✅︎ + - | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality. diff --git a/requirements-common.txt b/requirements-common.txt index 72fb020a82c4e..112528880c0ac 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 -xgrammar >= 0.1.5; platform_machine == "x86_64" +xgrammar >= 0.1.6; platform_machine == "x86_64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1206424ae1e3f..f002a8ff905b1 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -265,7 +265,13 @@ def configure_post_pass(self): def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + # when dynamo calls the backend, it means the bytecode + # transform and analysis are done compilation_counter.num_graphs_seen += 1 + from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time + logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) + self.compilation_configs.compilation_time += dynamo_time # we control the compilation process, each instance can only be # called once diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index a32dced57e5b3..938430fe2a501 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -145,6 +145,7 @@ def _support_torch_compile( def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. self.do_not_compile = \ @@ -157,9 +158,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_level=vllm_config.compilation_config.level) - if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: - start_monitoring_torch_compile(vllm_config.compilation_config) - cls.__init__ = __init__ def __call__(self, *args, **kwargs): @@ -186,6 +184,8 @@ def __call__(self, *args, **kwargs): raise ValueError( "Unsupported dynamic dimensions" f" {dims} for argument {k} with type {type(arg)}.") + # here, it is the starting point of the `torch.compile` process + start_monitoring_torch_compile(self.vllm_config.compilation_config) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index f718e46423212..3348674b09af2 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -1,14 +1,19 @@ +import time + from vllm.config import CompilationConfig, CompilationLevel from vllm.logger import init_logger logger = init_logger(__name__) +torch_compile_start_time: float = 0.0 + def start_monitoring_torch_compile(compilation_config: CompilationConfig): - pass + global torch_compile_start_time + torch_compile_start_time = time.time() def end_monitoring_torch_compile(compilation_config: CompilationConfig): if compilation_config.level == CompilationLevel.PIECEWISE: - logger.info("graph compilation takes %.2f s in total", + logger.info("torch.compile takes %.2f s in total", compilation_config.compilation_time) diff --git a/vllm/config.py b/vllm/config.py index c8d1cc9952515..b57e78723e86d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -518,11 +518,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid - if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"): + if not current_platform.is_async_output_supported(self.enforce_eager): logger.warning( - "Async output processing is only supported for CUDA, TPU, XPU " - "and HPU." - "Disabling it for other platforms.") + "Async output processing is not supported on the " + "current platform type %s.", current_platform.device_type) self.use_async_output_proc = False return @@ -532,16 +531,6 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - # Reminder: Please update docs/source/usage/compatibility_matrix.rst - # If the feature combo become valid - if device_config.device_type == "cuda" and self.enforce_eager: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. Since, enforce-eager is enabled, async output " - "processor cannot be used") - self.use_async_output_proc = not self.enforce_eager - return - # Async postprocessor is not necessary with embedding mode # since there is no token generation if self.task == "embedding": @@ -2199,8 +2188,8 @@ class CompilationConfig(BaseModel): TODO: move outside cudagraph logic into compilation. torch.compile will handle cudagraph capture logic in the future. - cudagraph_capture_sizes: sizes to capture cudagraph. - - None: capture sizes are inferred from compilation context. - - List[int]: capture sizes are specified. + - None (default): capture sizes are inferred from vllm config. + - List[int]: capture sizes are specified as given. - cudagraph_num_of_warmups: number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. Only after that, the execution will be recorded, and the recorded @@ -2601,45 +2590,40 @@ def __post_init__(self): self.instance_id = random_uuid()[:5] def __str__(self): - return ("model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ - (self.model_config.model, self.speculative_config, - self.model_config.tokenizer, - self.model_config.skip_tokenizer_init, - self.model_config.tokenizer_mode, - self.model_config.revision, - self.model_config.override_neuron_config, - self.model_config.tokenizer_revision, - self.model_config.trust_remote_code, - self.model_config.dtype, - self.model_config.max_model_len, - self.load_config.download_dir, - self.load_config.load_format, - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.disable_custom_all_reduce, - self.model_config.quantization, - self.model_config.enforce_eager, - self.cache_config.cache_dtype, - self.model_config.quantization_param_path, - self.device_config.device, self.decoding_config, - self.observability_config, self.model_config.seed, - self.model_config.served_model_name, - self.scheduler_config.num_scheduler_steps, - self.cache_config.enable_prefix_caching, - self.model_config.use_async_output_proc, - self.model_config.mm_processor_kwargs) + return ( + f"model={self.model_config.model!r}," + f" speculative_config={self.speculative_config!r}," + f" tokenizer={self.model_config.tokenizer!r}, " + f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}," + f" tokenizer_mode={self.model_config.tokenizer_mode}, " + f"revision={self.model_config.revision}, " + f"override_neuron_config={self.model_config.override_neuron_config}," + f" tokenizer_revision={self.model_config.tokenizer_revision}, " + f"trust_remote_code={self.model_config.trust_remote_code}, " + f"dtype={self.model_config.dtype}, " + f"max_seq_len={self.model_config.max_model_len}," + f" download_dir={self.load_config.download_dir!r}, " + f"load_format={self.load_config.load_format}, " + f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}," + f" pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " + f"enforce_eager={self.model_config.enforce_eager}, " + f"kv_cache_dtype={self.cache_config.cache_dtype}, " + f"quantization_param_path={self.model_config.quantization_param_path}," + f" device_config={self.device_config.device}, " + f"decoding_config={self.decoding_config!r}, " + f"observability_config={self.observability_config!r}, " + f"seed={self.model_config.seed}, " + f"served_model_name={self.model_config.served_model_name}, " + f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, " + f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, " # noqa + f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " + f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"use_async_output_proc={self.model_config.use_async_output_proc}, " + f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, " + f"pooler_config={self.model_config.pooler_config!r}," + f" compilation_config={self.compilation_config!r}") _current_vllm_config: Optional[VllmConfig] = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 26a8c94099a11..560f84a008291 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -247,60 +247,12 @@ def __init__( ) logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, chunked_prefill_enabled=%s " - "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " - "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s, pooler_config=%r," - "compilation_config=%r", + "Initializing an LLM engine (v%s) with config: %r," + "use_cached_outputs=%s, ", VLLM_VERSION, - self.model_config.model, - self.speculative_config, - self.model_config.tokenizer, - self.model_config.skip_tokenizer_init, - self.model_config.tokenizer_mode, - self.model_config.revision, - self.model_config.override_neuron_config, - self.model_config.tokenizer_revision, - self.model_config.trust_remote_code, - self.model_config.dtype, - self.model_config.max_model_len, - self.load_config.download_dir, - self.load_config.load_format, - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.disable_custom_all_reduce, - self.model_config.quantization, - self.model_config.enforce_eager, - self.cache_config.cache_dtype, - self.model_config.quantization_param_path, - self.device_config.device, - self.decoding_config, - self.observability_config, - self.model_config.seed, - self.model_config.served_model_name, - self.scheduler_config.num_scheduler_steps, - self.scheduler_config.chunked_prefill_enabled, - self.scheduler_config.multi_step_stream_outputs, - self.cache_config.enable_prefix_caching, - self.model_config.use_async_output_proc, + vllm_config, use_cached_outputs, - self.model_config.mm_processor_kwargs, - self.model_config.pooler_config, - vllm_config.compilation_config, ) - # TODO(woosuk): Print more configs in debug mode. self.log_stats = log_stats self.use_cached_outputs = use_cached_outputs diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 680ee74129739..e5142b985d1f2 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import psutil import torch @@ -37,6 +37,10 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_device_total_memory(cls, device_id: int = 0) -> int: return psutil.virtual_memory().total + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 846a1869da228..edaf377b501df 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,7 +4,7 @@ import os from functools import lru_cache, wraps -from typing import TYPE_CHECKING, Callable, List, TypeVar +from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar import pynvml import torch @@ -88,6 +88,16 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + @classmethod def is_full_nvlink(cls, device_ids: List[int]) -> bool: raise NotImplementedError diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 10aaa6d54962c..7f22bee3eaa74 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -20,6 +20,10 @@ class HpuPlatform(Platform): def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: return _Backend.HPU_ATTN + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @staticmethod def inference_mode(): return torch.no_grad() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 0be7df7941b8b..db06d2c18e681 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -6,11 +6,15 @@ import numpy as np import torch +from vllm.logger import init_logger + if TYPE_CHECKING: from vllm.config import VllmConfig else: VllmConfig = None +logger = init_logger(__name__) + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -147,6 +151,13 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: """Get the total memory of a device in bytes.""" raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + """ + Check if the current platform supports async output. + """ + raise NotImplementedError + @classmethod def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 87655ea198303..1e5c4bddfa24f 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from .interface import Platform, PlatformEnum @@ -18,6 +18,10 @@ class NeuronPlatform(Platform): def get_device_name(cls, device_id: int = 0) -> str: return "neuron" + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index 29b61e955d9ab..e0f8e8b4b49fe 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -37,6 +37,10 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: def get_device_name(self, device_id: int = 0) -> str: return "openvino" + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return False + @classmethod def inference_mode(self): return torch.inference_mode(mode=True) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3c14fbc179f69..66674e3ebe91f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,6 +1,6 @@ import os from functools import lru_cache -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -72,6 +72,16 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. Since, enforce-eager is enabled, async output " + "processor cannot be used") + return False + return True + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index b138f7e1c54c5..10d874349f36b 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -35,6 +35,10 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 9665786f4c499..11dbd04d55671 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -41,6 +41,10 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory + @classmethod + def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: + return True + @staticmethod def inference_mode(): return torch.no_grad() diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 4ef372fd8464b..0bcccda2bf329 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -65,7 +65,12 @@ def __init__( input_registry) # Detokenizer (converts EngineCoreOutputs --> RequestOutput). - self.detokenizer = Detokenizer(vllm_config.model_config.tokenizer) + self.detokenizer = Detokenizer( + tokenizer_name=vllm_config.model_config.tokenizer, + tokenizer_mode=vllm_config.model_config.tokenizer_mode, + trust_remote_code=vllm_config.model_config.trust_remote_code, + revision=vllm_config.model_config.tokenizer_revision, + ) # EngineCore (starts the engine in background process). self.engine_core = EngineCoreClient.make_client( diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py new file mode 100644 index 0000000000000..457784bb0287c --- /dev/null +++ b/vllm/v1/worker/gpu_input_batch.py @@ -0,0 +1,280 @@ +# Datastructures defining an input batch + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set + +import numpy as np +import torch + +from vllm.multimodal import MultiModalKwargs +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.v1.sample.metadata import SamplingMetadata + +if TYPE_CHECKING: + from vllm.multimodal.inputs import PlaceholderRange + + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + +class InputBatch: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.device = device + self.pin_memory = pin_memory + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + + self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), + dtype=np.int32) + self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) + + # Attention-related. + self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32) + self.block_table_cpu_tensor = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_cpu = self.block_table_cpu_tensor.numpy() + + # Sampling-related. + self.temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + + self.top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self.top_p_reqs: Set[str] = set() + + self.top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device="cpu", + pin_memory=pin_memory) + self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self.top_k_reqs: Set[str] = set() + + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def add_request( + self, + request: "CachedRequestState", + req_index: Optional[int] = None, + ) -> None: + if req_index is None: + req_index = self.num_reqs + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + # Copy the prompt token ids and output token ids. + num_prompt_tokens = len(request.prompt_token_ids) + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids + + self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens + num_blocks = len(request.block_ids) + self.block_table_cpu[req_index, :num_blocks] = request.block_ids + + sampling_params = request.sampling_params + self.temperature_cpu[req_index] = sampling_params.temperature + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + self.top_k_cpu[req_index] = sampling_params.top_k + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def remove_request(self, req_id: str) -> Optional[int]: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return None + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + return req_index + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def condense(self, empty_req_indices: List[int]) -> None: + if self.num_reqs == 0: + # The batched states are empty. + return + + # NOTE(woosuk): This function assumes that the empty_req_indices + # is sorted in descending order. + last_req_index = self.num_reqs + len(empty_req_indices) - 1 + while empty_req_indices: + # Find the largest non-empty index. + while last_req_index in empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = empty_req_indices.pop() + if empty_index >= last_req_index: + break + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + # TODO(woosuk): Optimize the copy of token_ids_cpu and + # block_table_cpu. + self.token_ids_cpu[empty_index] = self.token_ids_cpu[ + last_req_index] + self.num_computed_tokens_cpu[ + empty_index] = self.num_computed_tokens_cpu[last_req_index] + self.block_table_cpu[empty_index] = self.block_table_cpu[ + last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[ + last_req_index] + self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] + self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def make_sampling_metadata( + self, + skip_copy: bool = False, + ) -> SamplingMetadata: + if not skip_copy: + self.temperature[:self.num_reqs].copy_( + self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_p[:self.num_reqs].copy_( + self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) + self.top_k[:self.num_reqs].copy_( + self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) + return SamplingMetadata( + temperature=self.temperature[:self.num_reqs], + all_greedy=self.all_greedy, + all_random=self.all_random, + top_p=self.top_p[:self.num_reqs], + top_k=self.top_k[:self.num_reqs], + no_top_p=self.no_top_p, + no_top_k=self.no_top_k, + generators=self.generators, + max_num_logprobs=self.max_num_logprobs, + ) + + @property + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + @property + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + @property + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + @property + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + @property + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + @property + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + @property + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + @property + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9f1b37d3b8e8c..5658c37d39dbd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ import gc import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import numpy as np import torch @@ -15,16 +14,16 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalKwargs -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, cdiv, is_pin_memory_available) +from vllm.sampling_params import SamplingType +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, LayerBlockType, cdiv, + is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: - from vllm.multimodal.inputs import PlaceholderRange from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) @@ -609,269 +608,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: if batch_size <= size: return size return None - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - - # Attention-related. - self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. - num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. - last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0