diff --git a/vllm/config.py b/vllm/config.py
index eae6f909e3933..ac2b256dba815 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2318,6 +2318,7 @@ class VllmConfig:
     quant_config: Optional[QuantizationConfig] = None
     compilation_config: CompilationConfig = field(default=None,
                                                   init=True)  # type: ignore
+    model_configs: List[ModelConfig] = field(default=None, init=True)  # type: ignore
 
     @staticmethod
     def _get_quantization_config(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index a4975cece9a81..71db41e0ffe47 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -229,6 +229,7 @@ def __init__(
         input_registry: InputRegistry = INPUT_REGISTRY,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         use_cached_outputs: bool = False,
+        model: Optional[str] = None,
    ) -> None:

        self.model_config = vllm_config.model_config
@@ -1442,7 +1443,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
-
            # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
diff --git a/vllm/engine/mm_arg_utils.py b/vllm/engine/mm_arg_utils.py
new file mode 100644
index 0000000000000..9ed775441bae1
--- /dev/null
+++ b/vllm/engine/mm_arg_utils.py
@@ -0,0 +1,1206 @@
+import argparse
+import dataclasses
+import json
+from dataclasses import dataclass
+from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
+                    Tuple, Type, Union, cast, get_args)
+
+import torch
+
+import vllm.envs as envs
+from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
+                         DecodingConfig, DeviceConfig, HfOverrides, LoadConfig,
+                         LoadFormat, LoRAConfig, ModelConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
+                         VllmConfig)
+from vllm.executor.executor_base import ExecutorBase
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.platforms import current_platform
+from vllm.transformers_utils.utils import check_gguf_file
+from vllm.utils import FlexibleArgumentParser, StoreBoolean
+
+if TYPE_CHECKING:
+    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
+
+logger = init_logger(__name__)
+
+ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
+
+DEVICE_OPTIONS = [
+    "auto",
+    "cuda",
+    "neuron",
+    "cpu",
+    "openvino",
+    "tpu",
+    "xpu",
+    "hpu",
+]
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
+
+
+def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
+    """Parses a string containing comma-separated key [str] to value [int]
+    pairs into a dictionary.
+
+    Args:
+        val: String value to be parsed.
+
+    Returns:
+        Dictionary with parsed values.
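+
+    Example:
+        >>> nullable_kvs("image=16,video=2")
+        {'image': 16, 'video': 2}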
+ """ + if len(val) == 0: + return None + + out_dict: Dict[str, int] = {} + for item in val.split(","): + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts + + try: + parsed_value = int(value) + except ValueError as exc: + msg = f"Failed to parse value of item {key}={value}" + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value + + return out_dict + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: Optional[Union[str, List[str]]] = 'facebook/opt-125m' + served_model_name: Optional[Union[str, List[str]]] = None + tokenizer: Optional[str] = None + task: TaskOption = "auto" + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + allowed_local_media_path: str = "" + download_dir: Optional[str] = None + load_format: str = 'auto' + config_format: ConfigFormat = ConfigFormat.AUTO + dtype: str = 'auto' + kv_cache_dtype: str = 'auto' + quantization_param_path: Optional[str] = None + seed: int = 0 + max_model_len: Optional[int] = None + worker_use_ray: bool = False + # Note: Specifying a custom executor backend by passing a class + # is intended for expert use only. The API may change without + # notice. + distributed_executor_backend: Optional[Union[str, + Type[ExecutorBase]]] = None + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + max_parallel_loading_workers: Optional[int] = None + # NOTE(kzawora): default block size for Gaudi should be 128 + # smaller sizes still work, but very inefficiently + block_size: int = 16 if not current_platform.is_hpu() else 128 + enable_prefix_caching: bool = False + disable_sliding_window: bool = False + use_v2_block_manager: bool = True + swap_space: float = 4 # GiB + cpu_offload_gb: float = 0 # GiB + gpu_memory_utilization: float = 0.90 + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_logprobs: int = 20 # Default value for OpenAI Chat Completions API + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + rope_scaling: Optional[Dict[str, Any]] = None + rope_theta: Optional[float] = None + hf_overrides: Optional[HfOverrides] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + enforce_eager: Optional[bool] = None + max_seq_len_to_capture: int = 8192 + disable_custom_all_reduce: bool = False + tokenizer_pool_size: int = 0 + # Note: Specifying a tokenizer pool by passing a class + # is intended for expert use only. The API may change without + # notice. 
+ tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" + tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None + limit_mm_per_prompt: Optional[Mapping[str, int]] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None + enable_lora: bool = False + enable_lora_bias: bool = False + max_loras: int = 1 + max_lora_rank: int = 16 + enable_prompt_adapter: bool = False + max_prompt_adapters: int = 1 + max_prompt_adapter_token: int = 0 + fully_sharded_loras: bool = False + lora_extra_vocab_size: int = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None + lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' + max_cpu_loras: Optional[int] = None + device: str = 'auto' + num_scheduler_steps: int = 1 + multi_step_stream_outputs: bool = True + ray_workers_use_nsight: bool = False + num_gpu_blocks_override: Optional[int] = None + num_lookahead_slots: int = 0 + model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None + preemption_mode: Optional[str] = None + + scheduler_delay_factor: float = 0.0 + enable_chunked_prefill: Optional[bool] = None + + guided_decoding_backend: str = 'outlines' + # Speculative decoding configuration. + speculative_model: Optional[str] = None + speculative_model_quantization: Optional[str] = None + speculative_draft_tensor_parallel_size: Optional[int] = None + num_speculative_tokens: Optional[int] = None + speculative_disable_mqa_scorer: Optional[bool] = False + speculative_max_model_len: Optional[int] = None + speculative_disable_by_batch_size: Optional[int] = None + ngram_prompt_lookup_max: Optional[int] = None + ngram_prompt_lookup_min: Optional[int] = None + spec_decoding_acceptance_method: str = 'rejection_sampler' + typical_acceptance_sampler_posterior_threshold: Optional[float] = None + typical_acceptance_sampler_posterior_alpha: Optional[float] = None + qlora_adapter_name_or_path: Optional[str] = None + disable_logprobs_during_spec_decoding: Optional[bool] = None + + otlp_traces_endpoint: Optional[str] = None + collect_detailed_traces: Optional[str] = None + disable_async_output_proc: bool = False + scheduling_policy: Literal["fcfs", "priority"] = "fcfs" + + override_neuron_config: Optional[Dict[str, Any]] = None + override_pooler_config: Optional[PoolerConfig] = None + compilation_config: Optional[CompilationConfig] = None + models = None + + def __post_init__(self): + if isinstance(self.model, str): + self.models = [self.model] + else: + self.models = self.model + self.model = self.models[0] + if not self.tokenizer: + self.tokenizer = self.model + + # Setup plugins + from vllm.plugins import load_general_plugins + load_general_plugins() + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Shared CLI arguments for vLLM engine.""" + + # Model arguments + parser.add_argument( + '--model', + '--names-list', + nargs="*", + type=str, + default=EngineArgs.model, + help='Name or path of the huggingface model to use.') + + parser.add_argument( + '--task', + default=EngineArgs.task, + choices=get_args(TaskOption), + help='The task to use the model for. Each vLLM instance only ' + 'supports one task, even if the same model can be used for ' + 'multiple tasks. 
When the model only supports one task, "auto" ' + 'can be used to select it; otherwise, you must specify explicitly ' + 'which task to use.') + parser.add_argument( + '--tokenizer', + type=nullable_str, + default=EngineArgs.tokenizer, + help='Name or path of the huggingface tokenizer to use. ' + 'If unspecified, model name or path will be used.') + parser.add_argument( + '--skip-tokenizer-init', + action='store_true', + help='Skip initialization of tokenizer and detokenizer') + parser.add_argument( + '--revision', + type=nullable_str, + default=None, + help='The specific model version to use. It can be a branch ' + 'name, a tag name, or a commit id. If unspecified, will use ' + 'the default version.') + parser.add_argument( + '--code-revision', + type=nullable_str, + default=None, + help='The specific revision to use for the model code on ' + 'Hugging Face Hub. It can be a branch name, a tag name, or a ' + 'commit id. If unspecified, will use the default version.') + parser.add_argument( + '--tokenizer-revision', + type=nullable_str, + default=None, + help='Revision of the huggingface tokenizer to use. ' + 'It can be a branch name, a tag name, or a commit id. ' + 'If unspecified, will use the default version.') + parser.add_argument( + '--tokenizer-mode', + type=str, + default=EngineArgs.tokenizer_mode, + choices=['auto', 'slow', 'mistral'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer.') + parser.add_argument('--trust-remote-code', + action='store_true', + help='Trust remote code from huggingface.') + parser.add_argument( + '--allowed-local-media-path', + type=str, + help="Allowing API requests to read local images or videos " + "from directories specified by the server file system. " + "This is a security risk. " + "Should only be enabled in trusted environments.") + parser.add_argument('--download-dir', + type=nullable_str, + default=EngineArgs.download_dir, + help='Directory to download and load the weights, ' + 'default to the default cache dir of ' + 'huggingface.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[f.value for f in LoadFormat], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. 
See the Tensorize vLLM Model script in the Examples ' + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') + parser.add_argument( + '--config-format', + default=EngineArgs.config_format, + choices=[f.value for f in ConfigFormat], + help='The format of the model config to load.\n\n' + '* "auto" will try to load the config in hf format ' + 'if available else it will try to load in mistral format ') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='Data type for model weights and activations.\n\n' + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + 'BF16 precision for BF16 models.\n' + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.') + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default=EngineArgs.kv_cache_dtype, + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=nullable_str, + default=None, + help='Path to the JSON file containing the KV cache ' + 'scaling factors. This should generally be supplied, when ' + 'KV cache dtype is FP8. Otherwise, KV cache scaling factors ' + 'default to 1.0, which may cause accuracy issues. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version ' + 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' + 'supported for common inference criteria.') + parser.add_argument('--max-model-len', + type=int, + default=EngineArgs.max_model_len, + help='Model context length. If unspecified, will ' + 'be automatically derived from the model config.') + parser.add_argument( + '--guided-decoding-backend', + type=str, + default='outlines', + choices=['outlines', 'lm-format-enforcer'], + help='Which engine will be used for guided decoding' + ' (JSON schema / regex etc) by default. Currently support ' + 'https://github.com/outlines-dev/outlines and ' + 'https://github.com/noamgat/lm-format-enforcer.' + ' Can be overridden per request via guided_decoding_backend' + ' parameter.') + # Parallel arguments + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=EngineArgs.distributed_executor_backend, + help='Backend to use for distributed model ' + 'workers, either "ray" or "mp" (multiprocessing). If the product ' + 'of pipeline_parallel_size and tensor_parallel_size is less than ' + 'or equal to the number of GPUs available, "mp" will be used to ' + 'keep processing on a single host. Otherwise, this will default ' + 'to "ray" if Ray is installed and fail otherwise. 
Note that tpu ' + 'and hpu only support Ray for distributed inference.') + + parser.add_argument( + '--worker-use-ray', + action='store_true', + help='Deprecated, use --distributed-executor-backend=ray.') + parser.add_argument('--pipeline-parallel-size', + '-pp', + type=int, + default=EngineArgs.pipeline_parallel_size, + help='Number of pipeline stages.') + parser.add_argument('--tensor-parallel-size', + '-tp', + type=int, + default=EngineArgs.tensor_parallel_size, + help='Number of tensor parallel replicas.') + parser.add_argument( + '--max-parallel-loading-workers', + type=int, + default=EngineArgs.max_parallel_loading_workers, + help='Load model sequentially in multiple batches, ' + 'to avoid RAM OOM when using tensor ' + 'parallel and large models.') + parser.add_argument( + '--ray-workers-use-nsight', + action='store_true', + help='If specified, use nsight to profile Ray workers.') + # KV cache arguments + parser.add_argument('--block-size', + type=int, + default=EngineArgs.block_size, + choices=[8, 16, 32, 64, 128], + help='Token block size for contiguous chunks of ' + 'tokens. This is ignored on neuron devices and ' + 'set to max-model-len') + + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='Enables automatic prefix caching.') + parser.add_argument('--disable-sliding-window', + action='store_true', + help='Disables sliding window, ' + 'capping to sliding window size') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='[DEPRECATED] block manager v1 has been ' + 'removed and SelfAttnBlockSpaceManager (i.e. ' + 'block manager v2) is now the default. ' + 'Setting this flag to True or False' + ' has no effect on vLLM behavior.') + parser.add_argument( + '--num-lookahead-slots', + type=int, + default=EngineArgs.num_lookahead_slots, + help='Experimental scheduling config necessary for ' + 'speculative decoding. This will be replaced by ' + 'speculative config in the future; it is present ' + 'to enable correctness tests until then.') + + parser.add_argument('--seed', + type=int, + default=EngineArgs.seed, + help='Random seed for operations.') + parser.add_argument('--swap-space', + type=float, + default=EngineArgs.swap_space, + help='CPU swap space size (GiB) per GPU.') + parser.add_argument( + '--cpu-offload-gb', + type=float, + default=0, + help='The space in GiB to offload to CPU, per GPU. ' + 'Default is 0, which means no offloading. Intuitively, ' + 'this argument can be seen as a virtual way to increase ' + 'the GPU memory size. For example, if you have one 24 GB ' + 'GPU and set this to 10, virtually you can think of it as ' + 'a 34 GB GPU. Then you can load a 13B model with BF16 weight, ' + 'which requires at least 26GB GPU memory. Note that this ' + 'requires fast CPU-GPU interconnect, as part of the model is ' + 'loaded from CPU memory to GPU memory on the fly in each ' + 'model forward pass.') + parser.add_argument( + '--gpu-memory-utilization', + type=float, + default=EngineArgs.gpu_memory_utilization, + help='The fraction of GPU memory to be used for the model ' + 'executor, which can range from 0 to 1. For example, a value of ' + '0.5 would imply 50%% GPU memory utilization. If unspecified, ' + 'will use the default value of 0.9. 
This is a global gpu memory ' + 'utilization limit, for example if 50%% of the gpu memory is ' + 'already used before vLLM starts and --gpu-memory-utilization is ' + 'set to 0.9, then only 40%% of the gpu memory will be allocated ' + 'to the model executor.') + parser.add_argument( + '--num-gpu-blocks-override', + type=int, + default=None, + help='If specified, ignore GPU profiling result and use this number' + ' of GPU blocks. Used for testing preemption.') + parser.add_argument('--max-num-batched-tokens', + type=int, + default=EngineArgs.max_num_batched_tokens, + help='Maximum number of batched tokens per ' + 'iteration.') + parser.add_argument('--max-num-seqs', + type=int, + default=EngineArgs.max_num_seqs, + help='Maximum number of sequences per iteration.') + parser.add_argument( + '--max-logprobs', + type=int, + default=EngineArgs.max_logprobs, + help=('Max number of log probs to return logprobs is specified in' + ' SamplingParams.')) + parser.add_argument('--disable-log-stats', + action='store_true', + help='Disable logging statistics.') + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=nullable_str, + choices=[*QUANTIZATION_METHODS, None], + default=EngineArgs.quantization, + help='Method used to quantize the weights. If ' + 'None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') + parser.add_argument( + '--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration in JSON format. ' + 'For example, {"rope_type":"dynamic","factor":2.0}') + parser.add_argument('--rope-theta', + default=None, + type=float, + help='RoPE theta. Use with `rope_scaling`. In ' + 'some cases, changing the RoPE theta improves the ' + 'performance of the scaled model.') + parser.add_argument('--hf-overrides', + type=json.loads, + default=EngineArgs.hf_overrides, + help='Extra arguments for the HuggingFace config. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary.') + parser.add_argument('--enforce-eager', + action='store_true', + help='Always use eager-mode PyTorch. If False, ' + 'will use eager mode and CUDA graph in hybrid ' + 'for maximal performance and flexibility.') + parser.add_argument('--max-seq-len-to-capture', + type=int, + default=EngineArgs.max_seq_len_to_capture, + help='Maximum sequence length covered by CUDA ' + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + 'Additionally for encoder-decoder models, if the ' + 'sequence length of the encoder input is larger ' + 'than this, we fall back to the eager mode.') + parser.add_argument('--disable-custom-all-reduce', + action='store_true', + default=EngineArgs.disable_custom_all_reduce, + help='See ParallelConfig.') + parser.add_argument('--tokenizer-pool-size', + type=int, + default=EngineArgs.tokenizer_pool_size, + help='Size of tokenizer pool to use for ' + 'asynchronous tokenization. If 0, will ' + 'use synchronous tokenization.') + parser.add_argument('--tokenizer-pool-type', + type=str, + default=EngineArgs.tokenizer_pool_type, + help='Type of tokenizer pool to use for ' + 'asynchronous tokenization. Ignored ' + 'if tokenizer_pool_size is 0.') + parser.add_argument('--tokenizer-pool-extra-config', + type=nullable_str, + default=EngineArgs.tokenizer_pool_extra_config, + help='Extra config for tokenizer pool. 
' + 'This should be a JSON string that will be ' + 'parsed into a dictionary. Ignored if ' + 'tokenizer_pool_size is 0.') + + # Multimodal related configs + parser.add_argument( + '--limit-mm-per-prompt', + type=nullable_kvs, + default=EngineArgs.limit_mm_per_prompt, + # The default value is given in + # MultiModalRegistry.init_mm_limits_per_prompt + help=('For each multimodal plugin, limit how many ' + 'input instances to allow for each prompt. ' + 'Expects a comma-separated list of items, ' + 'e.g.: `image=16,video=2` allows a maximum of 16 ' + 'images and 2 videos per prompt. Defaults to 1 for ' + 'each modality.')) + parser.add_argument( + '--mm-processor-kwargs', + default=None, + type=json.loads, + help=('Overrides for the multimodal input mapping/processing, ' + 'e.g., image processor. For example: {"num_crops": 4}.')) + + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='If True, enable handling of LoRA adapters.') + parser.add_argument('--enable-lora-bias', + action='store_true', + help='If True, enable bias for LoRA adapters.') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='Max number of LoRAs in a single batch.') + parser.add_argument('--max-lora-rank', + type=int, + default=EngineArgs.max_lora_rank, + help='Max LoRA rank.') + parser.add_argument( + '--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help=('Maximum size of extra vocabulary that can be ' + 'present in a LoRA adapter (added to the base ' + 'model vocabulary).')) + parser.add_argument( + '--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16'], + help=('Data type for LoRA. If auto, will default to ' + 'base model dtype.')) + parser.add_argument( + '--long-lora-scaling-factors', + type=nullable_str, + default=EngineArgs.long_lora_scaling_factors, + help=('Specify multiple scaling factors (which can ' + 'be different from base model scaling factor ' + '- see eg. Long LoRA) to allow for multiple ' + 'LoRA adapters trained with those scaling ' + 'factors to be used at the same time. If not ' + 'specified, only adapters trained with the ' + 'base model scaling factor are allowed.')) + parser.add_argument( + '--max-cpu-loras', + type=int, + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' + 'Must be >= than max_loras. ' + 'Defaults to max_loras.')) + parser.add_argument( + '--fully-sharded-loras', + action='store_true', + help=('By default, only half of the LoRA computation is ' + 'sharded with tensor parallelism. ' + 'Enabling this will use the fully sharded layers. 
' + 'At high sequence length, max rank or ' + 'tensor parallel size, this is likely faster.')) + parser.add_argument('--enable-prompt-adapter', + action='store_true', + help='If True, enable handling of PromptAdapters.') + parser.add_argument('--max-prompt-adapters', + type=int, + default=EngineArgs.max_prompt_adapters, + help='Max number of PromptAdapters in a batch.') + parser.add_argument('--max-prompt-adapter-token', + type=int, + default=EngineArgs.max_prompt_adapter_token, + help='Max number of PromptAdapters tokens') + parser.add_argument("--device", + type=str, + default=EngineArgs.device, + choices=DEVICE_OPTIONS, + help='Device type for vLLM execution.') + parser.add_argument('--num-scheduler-steps', + type=int, + default=1, + help=('Maximum number of forward steps per ' + 'scheduler call.')) + + parser.add_argument( + '--multi-step-stream-outputs', + action=StoreBoolean, + default=EngineArgs.multi_step_stream_outputs, + nargs="?", + const="True", + help='If False, then multi-step will stream outputs at the end ' + 'of all steps') + parser.add_argument( + '--scheduler-delay-factor', + type=float, + default=EngineArgs.scheduler_delay_factor, + help='Apply a delay (of delay factor multiplied by previous ' + 'prompt latency) before scheduling next prompt.') + parser.add_argument( + '--enable-chunked-prefill', + action=StoreBoolean, + default=EngineArgs.enable_chunked_prefill, + nargs="?", + const="True", + help='If set, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens.') + + parser.add_argument( + '--speculative-model', + type=nullable_str, + default=EngineArgs.speculative_model, + help= + 'The name of the draft model to be used in speculative decoding.') + # Quantization settings for speculative model. + parser.add_argument( + '--speculative-model-quantization', + type=nullable_str, + choices=[*QUANTIZATION_METHODS, None], + default=EngineArgs.speculative_model_quantization, + help='Method used to quantize the weights of speculative model. ' + 'If None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') + parser.add_argument( + '--num-speculative-tokens', + type=int, + default=EngineArgs.num_speculative_tokens, + help='The number of speculative tokens to sample from ' + 'the draft model in speculative decoding.') + parser.add_argument( + '--speculative-disable-mqa-scorer', + action='store_true', + help= + 'If set to True, the MQA scorer will be disabled in speculative ' + ' and fall back to batch expansion') + parser.add_argument( + '--speculative-draft-tensor-parallel-size', + '-spec-draft-tp', + type=int, + default=EngineArgs.speculative_draft_tensor_parallel_size, + help='Number of tensor parallel replicas for ' + 'the draft model in speculative decoding.') + + parser.add_argument( + '--speculative-max-model-len', + type=int, + default=EngineArgs.speculative_max_model_len, + help='The maximum sequence length supported by the ' + 'draft model. 
Sequences over this length will skip ' + 'speculation.') + + parser.add_argument( + '--speculative-disable-by-batch-size', + type=int, + default=EngineArgs.speculative_disable_by_batch_size, + help='Disable speculative decoding for new incoming requests ' + 'if the number of enqueue requests is larger than this value.') + + parser.add_argument( + '--ngram-prompt-lookup-max', + type=int, + default=EngineArgs.ngram_prompt_lookup_max, + help='Max size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument( + '--ngram-prompt-lookup-min', + type=int, + default=EngineArgs.ngram_prompt_lookup_min, + help='Min size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument( + '--spec-decoding-acceptance-method', + type=str, + default=EngineArgs.spec_decoding_acceptance_method, + choices=['rejection_sampler', 'typical_acceptance_sampler'], + help='Specify the acceptance method to use during draft token ' + 'verification in speculative decoding. Two types of acceptance ' + 'routines are supported: ' + '1) RejectionSampler which does not allow changing the ' + 'acceptance rate of draft tokens, ' + '2) TypicalAcceptanceSampler which is configurable, allowing for ' + 'a higher acceptance rate at the cost of lower quality, ' + 'and vice versa.') + + parser.add_argument( + '--typical-acceptance-sampler-posterior-threshold', + type=float, + default=EngineArgs.typical_acceptance_sampler_posterior_threshold, + help='Set the lower bound threshold for the posterior ' + 'probability of a token to be accepted. This threshold is ' + 'used by the TypicalAcceptanceSampler to make sampling decisions ' + 'during speculative decoding. Defaults to 0.09') + + parser.add_argument( + '--typical-acceptance-sampler-posterior-alpha', + type=float, + default=EngineArgs.typical_acceptance_sampler_posterior_alpha, + help='A scaling factor for the entropy-based threshold for token ' + 'acceptance in the TypicalAcceptanceSampler. Typically defaults ' + 'to sqrt of --typical-acceptance-sampler-posterior-threshold ' + 'i.e. 0.3') + + parser.add_argument( + '--disable-logprobs-during-spec-decoding', + action=StoreBoolean, + default=EngineArgs.disable_logprobs_during_spec_decoding, + nargs="?", + const="True", + help='If set to True, token log probabilities are not returned ' + 'during speculative decoding. If set to False, log probabilities ' + 'are returned according to the settings in SamplingParams. If ' + 'not specified, it defaults to True. Disabling log probabilities ' + 'during speculative decoding reduces latency by skipping logprob ' + 'calculation in proposal sampling, target sampling, and after ' + 'accepted tokens are determined.') + + parser.add_argument('--model-loader-extra-config', + type=nullable_str, + default=EngineArgs.model_loader_extra_config, + help='Extra config for model loader. ' + 'This will be passed to the model loader ' + 'corresponding to the chosen load_format. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary.') + parser.add_argument( + '--ignore-patterns', + action="append", + type=str, + default=[], + help="The pattern(s) to ignore when loading the model." 
+ "Default to `original/**/*` to avoid repeated loading of llama's " + "checkpoints.") + parser.add_argument( + '--preemption-mode', + type=str, + default=None, + help='If \'recompute\', the engine performs preemption by ' + 'recomputing; If \'swap\', the engine performs preemption by ' + 'block swapping.') + + parser.add_argument( + "--served-model-name", + nargs="+", + type=str, + default=None, + help="The model name(s) used in the API. If multiple " + "names are provided, the server will respond to any " + "of the provided names. The model name in the model " + "field of a response will be the first name in this " + "list. If not specified, the model name will be the " + "same as the `--model` argument. Noted that this name(s) " + "will also be used in `model_name` tag content of " + "prometheus metrics, if multiple names provided, metrics " + "tag will take the first one.") + parser.add_argument('--qlora-adapter-name-or-path', + type=str, + default=None, + help='Name or path of the QLoRA adapter.') + + parser.add_argument( + '--otlp-traces-endpoint', + type=str, + default=None, + help='Target URL to which OpenTelemetry traces will be sent.') + parser.add_argument( + '--collect-detailed-traces', + type=str, + default=None, + help="Valid choices are " + + ",".join(ALLOWED_DETAILED_TRACE_MODULES) + + ". It makes sense to set this only if --otlp-traces-endpoint is" + " set. If set, it will collect detailed traces for the specified " + "modules. This involves use of possibly costly and or blocking " + "operations and hence might have a performance impact.") + + parser.add_argument( + '--disable-async-output-proc', + action='store_true', + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") + + parser.add_argument( + '--scheduling-policy', + choices=['fcfs', 'priority'], + default="fcfs", + help='The scheduling policy to use. "fcfs" (first come first served' + ', i.e. requests are handled in order of arrival; default) ' + 'or "priority" (requests are handled based on given ' + 'priority (lower value means earlier handling) and time of ' + 'arrival deciding any ties).') + + parser.add_argument( + '--override-neuron-config', + type=json.loads, + default=None, + help="Override or set neuron device configuration. " + "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'") + parser.add_argument( + '--override-pooler-config', + type=PoolerConfig.from_json, + default=None, + help="Override or set the pooling method in the embedding model. " + "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'") + + parser.add_argument('--compilation-config', + '-O', + type=CompilationConfig.from_cli, + default=None, + help='torch.compile configuration for the model.' + 'When it is a number (0, 1, 2, 3), it will be ' + 'interpreted as the optimization level.\n' + 'NOTE: level 0 is the default level without ' + 'any optimization. level 1 and 2 are for internal ' + 'testing only. level 3 is the recommended level ' + 'for production.\n' + 'To specify the full compilation config, ' + 'use a JSON string.') + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. 
+        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return engine_args
+
+    def create_model_configs(self) -> List[ModelConfig]:
+        return [self.create_model_config(model) for model in self.models]
+
+    def create_model_config(self, model: Optional[str] = None) -> ModelConfig:
+        return ModelConfig(
+            model=model if model is not None else self.model,
+            task=self.task,
+            # Use the per-model checkpoint as its tokenizer; fall back to the
+            # default tokenizer (set in __post_init__) when no model is given.
+            tokenizer=cast(str,
+                           model if model is not None else self.tokenizer),
+            tokenizer_mode=self.tokenizer_mode,
+            trust_remote_code=self.trust_remote_code,
+            allowed_local_media_path=self.allowed_local_media_path,
+            dtype=self.dtype,
+            seed=self.seed,
+            revision=self.revision,
+            code_revision=self.code_revision,
+            rope_scaling=self.rope_scaling,
+            rope_theta=self.rope_theta,
+            hf_overrides=self.hf_overrides,
+            tokenizer_revision=self.tokenizer_revision,
+            max_model_len=self.max_model_len,
+            quantization=self.quantization,
+            quantization_param_path=self.quantization_param_path,
+            enforce_eager=self.enforce_eager,
+            max_seq_len_to_capture=self.max_seq_len_to_capture,
+            max_logprobs=self.max_logprobs,
+            disable_sliding_window=self.disable_sliding_window,
+            skip_tokenizer_init=self.skip_tokenizer_init,
+            served_model_name=self.served_model_name,
+            limit_mm_per_prompt=self.limit_mm_per_prompt,
+            use_async_output_proc=not self.disable_async_output_proc,
+            config_format=self.config_format,
+            mm_processor_kwargs=self.mm_processor_kwargs,
+            override_neuron_config=self.override_neuron_config,
+            override_pooler_config=self.override_pooler_config,
+        )
+
+    def create_load_config(self) -> LoadConfig:
+        return LoadConfig(
+            load_format=self.load_format,
+            download_dir=self.download_dir,
+            model_loader_extra_config=self.model_loader_extra_config,
+            ignore_patterns=self.ignore_patterns,
+        )
+
+    def create_engine_config(self) -> VllmConfig:
+        # gguf file needs a specific model loader and doesn't use hf_repo
+        if check_gguf_file(self.model):
+            self.quantization = self.load_format = "gguf"
+
+        # bitsandbytes quantization needs a specific model loader
+        # so we make sure the quant method and the load format are consistent
+        if (self.quantization == "bitsandbytes" or
+                self.qlora_adapter_name_or_path is not None) and \
+                self.load_format != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes quantization and QLoRA adapter only support "
+                f"'bitsandbytes' load format, but got {self.load_format}")
+
+        if (self.load_format == "bitsandbytes" or
+                self.qlora_adapter_name_or_path is not None) and \
+                self.quantization != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes load format and QLoRA adapter only support "
+                f"'bitsandbytes' quantization, but got {self.quantization}")
+
+        assert self.cpu_offload_gb >= 0, (
+            "CPU offload space must be non-negative"
+            f", but got {self.cpu_offload_gb}")
+
+        device_config = DeviceConfig(device=self.device)
+        model_config = self.create_model_config()
+        model_configs = self.create_model_configs()
+
+        if model_config.is_multimodal_model:
+            if self.enable_prefix_caching:
+                logger.warning(
+                    "--enable-prefix-caching is currently not "
+                    "supported for multimodal models and has been disabled.")
+            self.enable_prefix_caching = False
+
+        cache_config = CacheConfig(
+            # neuron needs block_size = max_model_len
+            block_size=self.block_size if self.device != "neuron" else
+            (self.max_model_len if self.max_model_len is not None else 0),
+            gpu_memory_utilization=self.gpu_memory_utilization,
+            swap_space=self.swap_space,
+            cache_dtype=self.kv_cache_dtype,
+
is_attention_free=model_config.is_attention_free, + num_gpu_blocks_override=self.num_gpu_blocks_override, + sliding_window=model_config.get_sliding_window(), + enable_prefix_caching=self.enable_prefix_caching, + cpu_offload_gb=self.cpu_offload_gb, + ) + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + worker_use_ray=self.worker_use_ray, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=self.disable_custom_all_reduce, + tokenizer_pool_config=TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), + ray_workers_use_nsight=self.ray_workers_use_nsight, + distributed_executor_backend=self.distributed_executor_backend) + + max_model_len = model_config.max_model_len + use_long_context = max_model_len > 32768 + if self.enable_chunked_prefill is None: + # If not explicitly set, enable chunked prefill by default for + # long context (> 32K) models. This is to avoid OOM errors in the + # initial memory profiling phase. + + # Chunked prefill is currently disabled for multimodal models by + # default. + if use_long_context and not model_config.is_multimodal_model: + is_gpu = device_config.device_type == "cuda" + use_sliding_window = (model_config.get_sliding_window() + is not None) + use_spec_decode = self.speculative_model is not None + if (is_gpu and not use_sliding_window and not use_spec_decode + and not self.enable_lora + and not self.enable_prompt_adapter + and model_config.task != "embedding"): + self.enable_chunked_prefill = True + logger.warning( + "Chunked prefill is enabled by default for models with " + "max_model_len > 32K. Currently, chunked prefill might " + "not work with some features or models. If you " + "encounter any issues, please disable chunked prefill " + "by setting --enable-chunked-prefill=False.") + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = False + + if not self.enable_chunked_prefill and use_long_context: + logger.warning( + "The model has a long context length (%s). This may cause OOM " + "errors during the initial memory profiling phase, or result " + "in low performance due to small KV cache space. Consider " + "setting --max-model-len to a smaller value.", max_model_len) + elif self.enable_chunked_prefill and model_config.task == "embedding": + msg = "Chunked prefill is not supported for embedding models" + raise ValueError(msg) + + speculative_config = SpeculativeConfig.maybe_create_spec_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + target_dtype=self.dtype, + speculative_model=self.speculative_model, + speculative_model_quantization = \ + self.speculative_model_quantization, + speculative_draft_tensor_parallel_size = \ + self.speculative_draft_tensor_parallel_size, + num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer, + speculative_disable_by_batch_size=self. + speculative_disable_by_batch_size, + speculative_max_model_len=self.speculative_max_model_len, + enable_chunked_prefill=self.enable_chunked_prefill, + disable_log_stats=self.disable_log_stats, + ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, + draft_token_acceptance_method=\ + self.spec_decoding_acceptance_method, + typical_acceptance_sampler_posterior_threshold=self. 
+ typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=self. + typical_acceptance_sampler_posterior_alpha, + disable_logprobs=self.disable_logprobs_during_spec_decoding, + ) + + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if self.num_scheduler_steps > 1: + if speculative_config is not None: + raise ValueError("Speculative decoding is not supported with " + "multi-step (--num-scheduler-steps > 1)") + if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: + raise ValueError("Multi-Step Chunked-Prefill is not supported " + "for pipeline-parallel-size > 1") + + # make sure num_lookahead_slots is set the higher value depending on + # if we are using speculative decoding or multi-step + num_lookahead_slots = max(self.num_lookahead_slots, + self.num_scheduler_steps - 1) + num_lookahead_slots = num_lookahead_slots \ + if speculative_config is None \ + else speculative_config.num_lookahead_slots + + if not self.use_v2_block_manager: + logger.warning( + "[DEPRECATED] Block manager v1 has been removed, " + "and setting --use-v2-block-manager to True or False has " + "no effect on vLLM behavior. Please remove " + "--use-v2-block-manager in your engine argument. " + "If your use case is not supported by " + "SelfAttnBlockSpaceManager (i.e. block manager v2)," + " please file an issue with detailed information.") + + scheduler_config = SchedulerConfig( + task=model_config.task, + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + num_lookahead_slots=num_lookahead_slots, + delay_factor=self.scheduler_delay_factor, + enable_chunked_prefill=self.enable_chunked_prefill, + is_multimodal_model=model_config.is_multimodal_model, + preemption_mode=self.preemption_mode, + num_scheduler_steps=self.num_scheduler_steps, + multi_step_stream_outputs=self.multi_step_stream_outputs, + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER + and parallel_config.use_ray), + policy=self.scheduling_policy) + lora_config = LoRAConfig( + bias_enabled=self.enable_lora_bias, + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None + + if self.qlora_adapter_name_or_path is not None and \ + self.qlora_adapter_name_or_path != "": + if self.model_loader_extra_config is None: + self.model_loader_extra_config = {} + self.model_loader_extra_config[ + "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + + load_config = self.create_load_config() + + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ + if self.enable_prompt_adapter else None + + decoding_config = DecodingConfig( + guided_decoding_backend=self.guided_decoding_backend) + + detailed_trace_modules = [] + if self.collect_detailed_traces is not None: + detailed_trace_modules = self.collect_detailed_traces.split(",") + for m in detailed_trace_modules: + if m not in ALLOWED_DETAILED_TRACE_MODULES: + raise ValueError( + f"Invalid module {m} in collect_detailed_traces. 
" + f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}") + observability_config = ObservabilityConfig( + otlp_traces_endpoint=self.otlp_traces_endpoint, + collect_model_forward_time="model" in detailed_trace_modules + or "all" in detailed_trace_modules, + collect_model_execute_time="worker" in detailed_trace_modules + or "all" in detailed_trace_modules, + ) + + return VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config, + compilation_config=self.compilation_config, + model_configs=model_configs, + ) + + +@dataclass +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous vLLM engine.""" + disable_log_requests: bool = False + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser, + async_args_only: bool = False) -> FlexibleArgumentParser: + if not async_args_only: + parser = EngineArgs.add_cli_args(parser) + parser.add_argument('--disable-log-requests', + action='store_true', + help='Disable logging requests.') + return parser + + +# These functions are used by sphinx to build the documentation +def _engine_args_parser(): + return EngineArgs.add_cli_args(FlexibleArgumentParser()) + + +def _async_engine_args_parser(): + return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(), + async_args_only=True) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 34c161e9395ae..46b95986fa96c 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -27,6 +27,7 @@ class RPCProcessRequest: prompt: PromptType params: Union[SamplingParams, PoolingParams] request_id: str + model: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None @@ -39,6 +40,7 @@ def __init__( inputs: PromptType, params: Union[SamplingParams, PoolingParams], request_id: str, + model: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -52,6 +54,7 @@ def __init__( prompt: PromptType, params: Union[SamplingParams, PoolingParams], request_id: str, + model: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -68,6 +71,7 @@ def __init__( prompt: Optional[PromptType] = None, params: Optional[Union[SamplingParams, PoolingParams]] = None, request_id: Optional[str] = None, + model: Optional[str] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -85,6 +89,7 @@ def __init__( self.prompt = prompt self.params = params self.request_id = request_id + self.model = model self.lora_request = lora_request self.trace_headers = trace_headers self.prompt_adapter_request = prompt_adapter_request diff --git a/vllm/engine/multiprocessing/mm_client.py b/vllm/engine/multiprocessing/mm_client.py new file mode 100644 index 0000000000000..99c963ab0792e --- /dev/null +++ b/vllm/engine/multiprocessing/mm_client.py @@ -0,0 +1,671 @@ +import asyncio +import copy +import 
pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union, cast, overload) + +import cloudpickle +import psutil +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm import PoolingParams +from vllm.config import DecodingConfig, ModelConfig, VllmConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.engine.arg_utils import AsyncEngineArgs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.async_llm_engine import ( + build_guided_decoding_logits_processor_async) +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) +from vllm.engine.protocol import EngineClient +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import deprecate_kwargs + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MMLLMEngineClient(EngineClient): + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. + + MQLLMEngine and MMLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MMLLMEngineClient is through the generate() + method. On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: VllmConfig, + engine_pid: int): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + # FIXME: use model_configs. + self.model_configs = engine_config.model_configs + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. 
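+        # One tokenizer group and one InputPreprocessor are built per model
+        # config; get_tokenizer_mm() later selects the entry whose
+        # tokenizer_id matches the requested model name.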
+ self.tokenizers = [] + self.input_preprocessors = [] + for model_config in self.model_configs: + self.tokenizers.append( + init_tokenizer_from_configs( + model_config=model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + enable_lora=bool(engine_config.lora_config), + )) + self.input_preprocessors.append( + InputPreprocessor(model_config, self.tokenizers[-1])) + + # Send RPCGenerateRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for acking heartbeats. + self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + + # Loop to handle output of the LLMEngine periodically. + # Started after the MQLLMEngine is ready so that we can + # build the Client in an executor to enable clean shutdown. + self.output_loop: Optional[asyncio.Task] = None + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + self._engine_process = psutil.Process(engine_pid) + + @staticmethod + def is_unsupported_config(engine_args: AsyncEngineArgs): + # Pipeline parallel not yet supported + return engine_args.pipeline_parallel_size > 1 + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually checks to ensure the engine process + is still alive. + """ + try: + while True: + # Check if the engine process is running: + if not self._engine_process.is_running() or ( + self._engine_process.status() == psutil.STATUS_ZOMBIE): + # NB: is_running() returns True for zombies + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) " + "died.")) + break + + if await self.heartbeat_socket.poll(timeout=timeout): + # Heartbeat received- check the message + await self._check_success( + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) + + logger.debug("Heartbeat successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except psutil.NoSuchProcess: + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) died.")) + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. 
+ if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MPLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + # If engine is errored, no matter the type of exception + # it will no longer be able to receive new requests, + # therefore we have to inform that the current + # processed requests failed as well. Send back a dead + # engine error give this feedback and also give a + # 'hint' to the server to shutdown next. + exception = self.dead_error + + if request_id is None: + # If request_id is None, then the engine raised an + # exception for a batch, and we may not know the + # request that caused it, neither if it was actually + # caused by any of them (e.g. CUDA OOM). Therefore we + # broadcast the same exception for all requests. + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + else: + # Put each output into the appropriate steam. + for request_output in request_outputs: + queue = self.output_queues.get( + request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Start output_loop + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + self.health_loop = asyncio.create_task( + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. + if self.health_loop is not None: + self.health_loop.cancel() + if self.output_loop is not None: + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. 
+ await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + #TODO:check usage + async def get_input_preprocessor(self) -> InputPreprocessor: + return self.input_preprocessor + + #TODO:check usage + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_tokenizer_mm(self, model, lora_request: Optional[LoRARequest] = None): + for tokenizer in self.tokenizers: + if tokenizer.tokenizer_id == model: + return await tokenizer.get_lora_tokenizer_async(lora_request) + raise ValueError(f"Tokenizer for model {model} not found.") + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + #TODO:check usage + async def get_model_config(self) -> ModelConfig: + return self.model_configs[0] + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + """ + Ignore do_log_stats (handled on MQLLMEngine polling) + """ + pass + + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. 
+ """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return ENGINE_DEAD_ERROR(self._errored_with) + + @overload # DEPRECATED + def generate( + self, + *, + inputs: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @overload + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def generate( + self, + prompt: Optional[PromptType] = None, + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + model: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None # DEPRECATED + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: Priority of the request (lower means earlier handling). + Any priority other than 0 will lead to an error if the + scheduling policy is not "priority". + """ + if inputs is not None: + prompt = inputs + assert (prompt is not None and sampling_params is not None + and request_id is not None) + + return self._process_request(prompt, sampling_params, request_id, model, + lora_request, trace_headers, + prompt_adapter_request, priority) + + @overload # DEPRECATED + def encode( + self, + *, + inputs: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @overload + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... 
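+    # NOTE: Unlike generate(), encode() below does not accept a `model`
+    # argument yet, so embedding requests are sent with model=None
+    # (see _process_request).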
+
+    @deprecate_kwargs(
+        "inputs",
+        additional_message="Please use the 'prompt' parameter instead.",
+    )
+    def encode(
+        self,
+        prompt: Optional[PromptType] = None,
+        pooling_params: Optional[PoolingParams] = None,
+        request_id: Optional[str] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
+        *,
+        inputs: Optional[PromptType] = None  # DEPRECATED
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+        """Generate outputs for a request from an embedding model.
+
+        Generate outputs for a request. This method is a coroutine. It adds the
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
+
+        Args:
+            prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+                for more details about the format of each input.
+            pooling_params: The pooling parameters of the request.
+            request_id: The unique id of the request.
+            lora_request: LoRA request to use for generation, if any.
+            trace_headers: OpenTelemetry trace headers.
+
+        Yields:
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
+            for the request.
+        """
+        if inputs is not None:
+            prompt = inputs
+        assert (prompt is not None and pooling_params is not None
+                and request_id is not None)
+
+        return cast(
+            AsyncGenerator[EmbeddingRequestOutput, None],
+            self._process_request(prompt,
+                                  pooling_params,
+                                  request_id,
+                                  lora_request=lora_request,
+                                  trace_headers=trace_headers,
+                                  priority=priority))
+
+    async def _process_request(
+        self,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        request_id: str,
+        model: Optional[str] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
+            EmbeddingRequestOutput, None]]:
+        """Send an RPCGenerateRequest to the RPCServer and stream responses."""
+
+        # If already dead, error out.
+        if self._errored_with is not None:
+            raise ENGINE_DEAD_ERROR(self._errored_with)
+
+        # Constructing guided decoding logits processors is expensive, so we do
+        # it here to avoid contending with cpu resources and the GIL on the
+        # backend process.
+        if isinstance(params, SamplingParams) and \
+            params.guided_decoding is not None:
+            params = await \
+                build_guided_decoding_logits_processor_async(
+                    sampling_params=params,
+                    tokenizer=await self.get_tokenizer(lora_request),
+                    default_guided_backend=(self.decoding_config.guided_decoding_backend
+                        if self.decoding_config
+                        else DecodingConfig.guided_decoding_backend),
+                )
+
+        # 1) Create output queue for this request.
+        queue: asyncio.Queue[Union[RequestOutput,
+                                   BaseException]] = asyncio.Queue()
+        self.output_queues[request_id] = queue
+
+        try:
+            # 2) Detach logits processors so that they can be pickled
+            # separately (may require cloudpickle which is slower)
+            if isinstance(params, SamplingParams) and params.logits_processors:
+                # Defensive shallow copy
+                params = copy.copy(params)
+                logits_processors = params.logits_processors
+                params.logits_processors = None
+                lp_bytes = cloudpickle.dumps(logits_processors)
+            else:
+                lp_bytes = None
+
+            request_bytes = pickle.dumps(
+                RPCProcessRequest(
+                    prompt=prompt,
+                    params=params,
+                    request_id=request_id,
+                    model=model,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=priority,
+                ))
+
+            # 3) Send the RPCGenerateRequest to the MQLLMEngine.
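+            # The payload is one or two zmq frames: the pickled
+            # RPCProcessRequest, plus an optional cloudpickle frame carrying
+            # the detached logits processors.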
+ parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) diff --git a/vllm/engine/multiprocessing/mm_engine.py b/vllm/engine/multiprocessing/mm_engine.py new file mode 100644 index 0000000000000..6f4729f398a34 --- /dev/null +++ b/vllm/engine/multiprocessing/mm_engine.py @@ -0,0 +1,387 @@ +import pickle +import signal +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq + +from vllm import SamplingParams +from vllm.engine.mm_arg_utils import AsyncEngineArgs +from vllm.engine.llm_engine import LLMEngine +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) +# yapf: enable +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MMLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The :class:`LLMEngine` generate or encode process is kicked off when a new + RPCProcessRequest is received by the input_socket. + + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()`, and sends the RequestOutputs back over + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. 
+ """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. + kwargs['use_cached_outputs'] = True + + # get configs from args and kwargs, determine how many models to load + vllm_config = kwargs.get('vllm_config') + models_load = [model_config.model for model_config in vllm_config.model_configs ] + self.engines = [] + + for i, model in enumerate(models_load): + vllm_config.model_config = vllm_config.model_configs[i] + self.engines.append(LLMEngine(model=model, *args, **kwargs)) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + for engine in self.engines: + engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + # Setup plugins for each process + from vllm.plugins import load_general_plugins + load_general_plugins() + + engine_config = engine_args.create_engine_config() + executor_class = LLMEngine._get_executor_cls(engine_config) + + use_async_sockets = engine_config.model_config.use_async_output_proc + + return cls(ipc_path=ipc_path, + use_async_sockets=use_async_sockets, + vllm_config=engine_config, + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. 
+ self.ctx.destroy(linger=0) + del self.engines + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engines[0].is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + while True: + if not any(engine.has_unfinished_requests() for engine in self.engines): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + # When there's no work, check on engine health and send + # health status back to client + self._health_check() + for engine in self.engines: + engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + try: + res = [] + for engine in self.engines: + res += engine.step() + return res + except SystemExit: + raise + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCProcessRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + assert isinstance(request.params, SamplingParams) + lprocs = cloudpickle.loads(frames[1].buffer) + request.params.logits_processors = lprocs + self._handle_process_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUProfileRequest): + if request == RPCUProfileRequest.START_PROFILE: + self.start_profile() + else: + self.stop_profile() + else: + raise ValueError("Unknown RPCRequest Type: " + f"{type(request)}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e + + # FIXME: add model field in RPCProcessRequest, and dispatch to the correct engine + def _handle_process_request(self, request: RPCProcessRequest): + """Handle RPCProcessRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + for engine in self.engines: + if engine.model_config.model == request.model: + engine.add_request( 
+                        request_id=request_id,
+                        prompt=request.prompt,
+                        params=request.params,
+                        lora_request=request.lora_request,
+                        trace_headers=request.trace_headers,
+                        prompt_adapter_request=request.prompt_adapter_request,
+                        priority=request.priority)
+
+            if self.log_requests:
+                logger.info("Added request %s.", request.request_id)
+
+        except Exception as e:
+            # We do not set self._errored = True here, since the error
+            # is due to an issue adding this request to the engine,
+            # rather than an issue with the engine itself.
+            is_errored = self._errored_with is not None
+            rpc_err = RPCError(request_id=request_id,
+                               is_engine_errored=is_errored,
+                               exception=e)
+            self._send_outputs(rpc_err)
+
+            # Remove the request from whichever engine may have accepted it
+            # (aborting an unknown request id is a no-op).
+            for engine in self.engines:
+                engine.abort_request(request_id)
+
+    # FIXME: add model field in RPCAbortRequest, and dispatch to the correct engine
+    def _handle_abort_request(self, request: RPCAbortRequest):
+        # Until RPCAbortRequest carries a model id, abort on every engine;
+        # unknown request ids are ignored.
+        for engine in self.engines:
+            engine.abort_request(request.request_id)
+        if self.log_requests:
+            logger.info("Aborted request %s.", request.request_id)
+
+    def _health_check(self):
+        # Send unhealthy if engine has already errored
+        if self._errored_with is not None:
+            self._send_unhealthy(self._errored_with)
+        try:
+            for engine in self.engines:
+                engine.check_health()
+            self._send_healthy()
+        except Exception as e:
+            self._set_errored(e)
+            self._send_unhealthy(e)
+
+    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
+        """Send List of RequestOutput to RPCClient."""
+        if outputs:
+            try:
+                from ray.exceptions import RayTaskError
+
+                # RayTaskError might not be picklable here. We need to unpack
+                # the underlying exception as the real exception in the output.
+                if (isinstance(outputs, RPCError)
+                        and isinstance(outputs.exception, RayTaskError)):
+                    outputs.exception = outputs.exception.cause
+            except ImportError:
+                pass
+
+            output_bytes = pickle.dumps(outputs)
+            self.output_socket.send_multipart((output_bytes, ), copy=False)
+
+    def _send_healthy(self):
+        """Send HEALTHY message to RPCClient."""
+        if not self.heartbeat_socket.closed:
+            self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
+
+    def _send_unhealthy(self, error: BaseException):
+        """Send UNHEALTHY message to RPCClient."""
+        if not self.heartbeat_socket.closed:
+            error_bytes = pickle.dumps(error)
+            self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
+
+    def _async_socket_engine_callback(self,
+                                      request_outputs: REQUEST_OUTPUTS_T):
+        """Callback used by engine to make socket handling async with GPU."""
+        self._send_outputs(request_outputs)
+        self.handle_new_input()
+
+    def _set_errored(self, e: BaseException):
+        """Log and set errored status if this is the first issue."""
+        if self._errored_with is None:
+            self._errored_with = e
+
+    # Profiling is only enabled for engine (model) 0.
+    def start_profile(self) -> None:
+        if type(self.engines[0].model_executor) is GPUExecutor:
+            self.engines[0].model_executor.start_profile()
+        else:
+            self.engines[0].model_executor._run_workers("start_profile")
+
+    def stop_profile(self) -> None:
+        if type(self.engines[0].model_executor) is GPUExecutor:
+            self.engines[0].model_executor.stop_profile()
+        else:
+            self.engines[0].model_executor._run_workers("stop_profile")
+
+
+def signal_handler(*_) -> None:
+    raise KeyboardInterrupt("MQLLMEngine terminated")
+
+
+def run_mm_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
+                  ipc_path: str, engine_alive):
+    try:
+        engine = MMLLMEngine.from_engine_args(engine_args=engine_args,
+                                              usage_context=usage_context,
+                                              ipc_path=ipc_path)
+
+        signal.signal(signal.SIGTERM,
signal_handler) + + engine.start() + + except BaseException as e: + logger.exception(e) + engine_alive.value = False + raise e diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 24c206a1261f2..9cf56d0bf1c2f 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -9,7 +9,7 @@ import ssl from typing import List, Optional, Sequence, Union, get_args -from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.mm_arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, diff --git a/vllm/entrypoints/openai/mm_api_server.py b/vllm/entrypoints/openai/mm_api_server.py new file mode 100644 index 0000000000000..99db253c9da33 --- /dev/null +++ b/vllm/entrypoints/openai/mm_api_server.py @@ -0,0 +1,649 @@ +import asyncio +import importlib +import inspect +import multiprocessing +import os +import re +import signal +import socket +import tempfile +import uuid +from argparse import Namespace +from contextlib import asynccontextmanager +from functools import partial +from http import HTTPStatus +from typing import AsyncIterator, Optional, Set, Tuple + +import uvloop +from fastapi import APIRouter, FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State +from starlette.routing import Mount +from typing_extensions import assert_never + +import vllm.envs as envs +from vllm.config import ModelConfig +from vllm.engine.mm_arg_utils import AsyncEngineArgs +from vllm.engine.multiprocessing.mm_client import MMLLMEngineClient +from vllm.engine.multiprocessing.mm_engine import run_mm_engine +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import load_chat_template +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, + LoadLoraAdapterRequest, + TokenizeRequest, + TokenizeResponse, + UnloadLoraAdapterRequest) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path, + is_valid_ipv6_address) +from vllm.version import __version__ as VLLM_VERSION + +if envs.VLLM_USE_V1: + from vllm.v1.engine.async_llm import AsyncLLMEngine # type: ignore +else: + from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore + +TIMEOUT_KEEP_ALIVE = 5 # seconds + 
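For orientation before the server wiring below: the OpenAI endpoints in this file ultimately call into the `MMLLMEngineClient` added earlier, passing the request's `model` field through `generate(..., model=...)` and `get_tokenizer_mm(...)`. A minimal usage sketch of that client-level flow, assuming an already set-up client; the helper name `generate_on` and the literal values are illustrative, not part of this patch:

```python
# Hedged sketch only: exercises the per-model routing exposed by
# MMLLMEngineClient (generate(model=...) and get_tokenizer_mm()).
import uuid

from vllm import SamplingParams


async def generate_on(client, prompt: str, model: str) -> str:
    """Illustrative helper; `client` is an MMLLMEngineClient after setup()."""
    # Per-model tokenizer lookup, mirroring what the serving layer does.
    tokenizer = await client.get_tokenizer_mm(model)
    assert tokenizer is not None

    params = SamplingParams(max_tokens=32)
    request_id = uuid.uuid4().hex

    text = ""
    # `model=` selects which co-hosted engine handles this request.
    async for output in client.generate(prompt, params, request_id,
                                        model=model):
        if output.finished:
            text = output.outputs[0].text
    return text
```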
+prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger('vllm.entrypoints.openai.api_server') + +_running_tasks: Set[asyncio.Task] = set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + try: + if app.state.log_stats: + engine_client: EngineClient = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10.) + await engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del app.state + + +@asynccontextmanager +async def build_async_engine_client( + args: Namespace) -> AsyncIterator[EngineClient]: + + # Context manager to handle engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + engine_args = AsyncEngineArgs.from_cli_args(args) + + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing) as engine: + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[EngineClient]: + """ + Create EngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + # Fall back + # TODO: fill out feature matrix. + if (MMLLMEngineClient.is_unsupported_config(engine_args) + or envs.VLLM_USE_V1 or disable_frontend_multiprocessing): + + engine_config = engine_args.create_engine_config() + uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), + "uses_ray", False) + + build_engine = partial(AsyncLLMEngine.from_engine_args, + engine_args=engine_args, + engine_config=engine_config, + usage_context=UsageContext.OPENAI_API_SERVER) + if uses_ray: + # Must run in main thread with ray for its signal handlers to work + engine_client = build_engine() + else: + engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_engine) + + yield engine_client + if hasattr(engine_client, "shutdown"): + engine_client.shutdown() + return + + # Otherwise, use the multiprocessing AsyncLLMEngine. + else: + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + global prometheus_multiproc_dir + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ[ + "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + else: + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + # Select random path for IPC. + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) + + # Start RPCServer in separate process (holds the LLMEngine). + # the current process might have CUDA context, + # so we need to spawn a new process + context = multiprocessing.get_context("spawn") + + # The Process can raise an exception during startup, which may + # not actually result in an exitcode being reported. 
As a result + # we use a shared variable to communicate the information. + engine_alive = multiprocessing.Value('b', True, lock=False) + engine_process = context.Process(target=run_mm_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path, engine_alive)) + engine_process.start() + engine_pid = engine_process.pid + assert engine_pid is not None, "Engine process failed to start." + logger.info("Started engine process with PID %d", engine_pid) + + # Build RPCClient, which conforms to EngineClient Protocol. + engine_config = engine_args.create_engine_config() + build_client = partial(MMLLMEngineClient, ipc_path, engine_config, + engine_pid) + mq_engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_client) + try: + while True: + try: + await mq_engine_client.setup() + break + except TimeoutError: + if (not engine_process.is_alive() + or not engine_alive.value): + raise RuntimeError( + "Engine process failed to start. See stack " + "trace for the root cause.") from None + + yield mq_engine_client # type: ignore[misc] + finally: + # Ensure rpc server process was terminated + engine_process.terminate() + + # Close all open connections to the backend + mq_engine_client.close() + + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() + + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import multiprocess + multiprocess.mark_process_dead(engine_process.pid) + + +router = APIRouter() + + +def mount_metrics(app: FastAPI): + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. 
+ # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) + if prometheus_multiproc_dir_path is not None: + logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + else: + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app()) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") + app.routes.append(metrics_route) + + +def base(request: Request) -> OpenAIServing: + # Reuse the existing instance + return tokenization(request) + + +def chat(request: Request) -> Optional[OpenAIServingChat]: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> Optional[OpenAIServingCompletion]: + return request.app.state.openai_serving_completion + + +def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: + return request.app.state.openai_serving_embedding + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health") +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.post("/tokenize") +async def tokenize(request: TokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + generator = await handler.create_tokenize(request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/detokenize") +async def detokenize(request: DetokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + generator = await handler.create_detokenize(request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.get("/v1/models") +async def show_available_models(raw_request: Request): + handler = base(raw_request) + + models = await handler.show_available_models() + return JSONResponse(content=models.model_dump()) + + +@router.get("/version") +async def show_version(): + ver = {"version": VLLM_VERSION} + return JSONResponse(content=ver) + + +@router.post("/v1/chat/completions") +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + handler = chat(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Chat Completions API") + + generator = await handler.create_chat_completion(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ChatCompletionResponse): + 
return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/completions") +async def create_completion(request: CompletionRequest, raw_request: Request): + handler = completion(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Completions API") + + generator = await handler.create_completion(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/embeddings") +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + handler = embedding(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Embeddings API") + + generator = await handler.create_embedding(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, EmbeddingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning( + "Torch Profiler is enabled in the API server. This should ONLY be " + "used for local development!") + + @router.post("/start_profile") + async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + @router.post("/stop_profile") + async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "Lora dynamic loading & unloading is enabled in the API server. 
" + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter") + async def load_lora_adapter(request: LoadLoraAdapterRequest, + raw_request: Request): + for route in [chat, completion, embedding]: + handler = route(raw_request) + if handler is not None: + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter") + async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + raw_request: Request): + for route in [chat, completion, embedding]: + handler = route(raw_request) + if handler is not None: + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + +def build_app(args: Namespace) -> FastAPI: + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + mount_metrics(app) + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + chat = app.state.openai_serving_chat + err = chat.create_error_response(message=str(exc)) + return JSONResponse(err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST) + + if token := envs.VLLM_API_KEY or args.api_key: + + @app.middleware("http") + async def authentication(request: Request, call_next): + root_path = "" if args.root_path is None else args.root_path + if request.method == "OPTIONS": + return await call_next(request) + if not request.url.path.startswith(f"{root_path}/v1"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + token: + return JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return await call_next(request) + + @app.middleware("http") + async def add_request_id(request: Request, call_next): + request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex + response = await call_next(request) + response.headers["X-Request-Id"] = request_id + return response + + for middleware in args.middleware: + module_path, object_name = middleware.rsplit(".", 1) + imported = getattr(importlib.import_module(module_path), object_name) + if inspect.isclass(imported): + app.add_middleware(imported) + elif inspect.iscoroutinefunction(imported): + app.middleware("http")(imported) + else: + raise ValueError(f"Invalid middleware {middleware}. 
" + f"Must be a function or a class.") + + return app + + +def init_app_state( + engine_client: EngineClient, + model_config: ModelConfig, + state: State, + args: Namespace, +) -> None: + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [model for model in args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=path) + for name, path in zip(served_model_names, args.model) + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + + resolved_chat_template = load_chat_template(args.chat_template) + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + + state.openai_serving_chat = OpenAIServingChat( + engine_client, + model_config, + base_model_paths, + args.response_role, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + ) if model_config.task == "generate" else None + state.openai_serving_completion = OpenAIServingCompletion( + engine_client, + model_config, + base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) if model_config.task == "generate" else None + state.openai_serving_embedding = OpenAIServingEmbedding( + engine_client, + model_config, + base_model_paths, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) if model_config.task == "embedding" else None + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, + model_config, + base_model_paths, + lora_modules=args.lora_modules, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + ) + + +def create_server_socket(addr: Tuple[str, int]) -> socket.socket: + family = socket.AF_INET + if is_valid_ipv6_address(addr[0]): + family = socket.AF_INET6 + + sock = socket.socket(family=family, type=socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addr) + + return sock + + +async def run_server(args, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valide_tool_parses = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valide_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valide_tool_parses)} }})") + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + sock.close() + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/scripts.py for CLI entrypoints. + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 54ca0463bcab1..db048bf04c2da 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -123,7 +123,7 @@ async def create_chat_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer_mm(request.model, lora_request) tool_parser = self.tool_parser diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 936aae8f1c267..44ecf1798a997 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -99,7 +99,7 @@ async def create_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - tokenizer = await self.engine_client.get_tokenizer(lora_request) + tokenizer = await self.engine_client.get_tokenizer_mm(request.model, lora_request) request_prompts, engine_prompts = self._preprocess_completion( request, @@ -148,6 +148,7 @@ async def create_completion( engine_prompt, sampling_params, request_id_item, + request.model, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers,
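Taken together, the request path added in this change keys everything off the OpenAI `model` field: the serving layer resolves the tokenizer with `get_tokenizer_mm(request.model, ...)`, the RPC message carries `model`, and `MMLLMEngine` matches it against each engine's `model_config.model`. A small sketch of that dispatch step under those assumptions; `pick_engine` and its fallback behaviour are illustrative, not part of the patch:

```python
from typing import List, Optional

from vllm.engine.llm_engine import LLMEngine


def pick_engine(engines: List[LLMEngine], model: Optional[str]) -> LLMEngine:
    """Return the co-hosted engine whose model id matches `model`.

    Mirrors the lookup in MMLLMEngine._handle_process_request; falling back
    to the first engine when no model id is given is an illustrative choice.
    """
    if model is None:
        return engines[0]
    for engine in engines:
        if engine.model_config.model == model:
            return engine
    raise ValueError(f"No co-hosted engine is serving model {model!r}")
```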