From 43b05fa314e90e551d87211e8bdde2e2bb5a0bdc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 11:18:18 -0800 Subject: [PATCH 1/3] [torch.compile][misc] fix comments (#10993) Signed-off-by: youkaichao --- vllm/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 164622b5af34e..38cf642b23cda 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2177,8 +2177,8 @@ class CompilationConfig(BaseModel): TODO: move outside cudagraph logic into compilation. torch.compile will handle cudagraph capture logic in the future. - cudagraph_capture_sizes: sizes to capture cudagraph. - - None: capture sizes are inferred from compilation context. - - List[int]: capture sizes are specified. + - None (default): capture sizes are inferred from vllm config. + - List[int]: capture sizes are specified as given. - cudagraph_num_of_warmups: number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. Only after that, the execution will be recorded, and the recorded From 46004e83a2e0b908f28099d93171bfb4934e4722 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 17:28:27 -0800 Subject: [PATCH 2/3] [misc] clean up and unify logging (#10999) Signed-off-by: youkaichao --- vllm/config.py | 73 ++++++++++++++++++--------------------- vllm/engine/llm_engine.py | 54 ++--------------------------- 2 files changed, 37 insertions(+), 90 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 38cf642b23cda..7fbe04eaaf4f8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2579,45 +2579,40 @@ def __post_init__(self): self.instance_id = random_uuid()[:5] def __str__(self): - return ("model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ - (self.model_config.model, self.speculative_config, - self.model_config.tokenizer, - self.model_config.skip_tokenizer_init, - self.model_config.tokenizer_mode, - self.model_config.revision, - self.model_config.override_neuron_config, - self.model_config.tokenizer_revision, - self.model_config.trust_remote_code, - self.model_config.dtype, - self.model_config.max_model_len, - self.load_config.download_dir, - self.load_config.load_format, - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.disable_custom_all_reduce, - self.model_config.quantization, - self.model_config.enforce_eager, - self.cache_config.cache_dtype, - self.model_config.quantization_param_path, - self.device_config.device, self.decoding_config, - self.observability_config, self.model_config.seed, - self.model_config.served_model_name, - self.scheduler_config.num_scheduler_steps, - self.cache_config.enable_prefix_caching, - self.model_config.use_async_output_proc, - self.model_config.mm_processor_kwargs) + return ( + f"model={self.model_config.model!r}," + f" speculative_config={self.speculative_config!r}," + f" tokenizer={self.model_config.tokenizer!r}, " + f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}," + f" tokenizer_mode={self.model_config.tokenizer_mode}, " + f"revision={self.model_config.revision}, " + f"override_neuron_config={self.model_config.override_neuron_config}," + f" tokenizer_revision={self.model_config.tokenizer_revision}, " + f"trust_remote_code={self.model_config.trust_remote_code}, " + f"dtype={self.model_config.dtype}, " + f"max_seq_len={self.model_config.max_model_len}," + f" download_dir={self.load_config.download_dir!r}, " + f"load_format={self.load_config.load_format}, " + f"tensor_parallel_size={self.parallel_config.tensor_parallel_size}," + f" pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa + f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " + f"enforce_eager={self.model_config.enforce_eager}, " + f"kv_cache_dtype={self.cache_config.cache_dtype}, " + f"quantization_param_path={self.model_config.quantization_param_path}," + f" device_config={self.device_config.device}, " + f"decoding_config={self.decoding_config!r}, " + f"observability_config={self.observability_config!r}, " + f"seed={self.model_config.seed}, " + f"served_model_name={self.model_config.served_model_name}, " + f"num_scheduler_steps={self.scheduler_config.num_scheduler_steps}, " + f"multi_step_stream_outputs={self.scheduler_config.multi_step_stream_outputs}, " # noqa + f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " + f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa + f"use_async_output_proc={self.model_config.use_async_output_proc}, " + f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, " + f"pooler_config={self.model_config.pooler_config!r}," + f" compilation_config={self.compilation_config!r}") _current_vllm_config: Optional[VllmConfig] = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 26a8c94099a11..560f84a008291 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -247,60 +247,12 @@ def __init__( ) logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, chunked_prefill_enabled=%s " - "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " - "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s, pooler_config=%r," - "compilation_config=%r", + "Initializing an LLM engine (v%s) with config: %r," + "use_cached_outputs=%s, ", VLLM_VERSION, - self.model_config.model, - self.speculative_config, - self.model_config.tokenizer, - self.model_config.skip_tokenizer_init, - self.model_config.tokenizer_mode, - self.model_config.revision, - self.model_config.override_neuron_config, - self.model_config.tokenizer_revision, - self.model_config.trust_remote_code, - self.model_config.dtype, - self.model_config.max_model_len, - self.load_config.download_dir, - self.load_config.load_format, - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.disable_custom_all_reduce, - self.model_config.quantization, - self.model_config.enforce_eager, - self.cache_config.cache_dtype, - self.model_config.quantization_param_path, - self.device_config.device, - self.decoding_config, - self.observability_config, - self.model_config.seed, - self.model_config.served_model_name, - self.scheduler_config.num_scheduler_steps, - self.scheduler_config.chunked_prefill_enabled, - self.scheduler_config.multi_step_stream_outputs, - self.cache_config.enable_prefix_caching, - self.model_config.use_async_output_proc, + vllm_config, use_cached_outputs, - self.model_config.mm_processor_kwargs, - self.model_config.pooler_config, - vllm_config.compilation_config, ) - # TODO(woosuk): Print more configs in debug mode. self.log_stats = log_stats self.use_cached_outputs = use_cached_outputs From af7c4a92e654684066e61518d6ed90feda983635 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sun, 8 Dec 2024 22:29:16 -0800 Subject: [PATCH 3/3] [Doc][V1] Add V1 support column for multimodal models (#10998) Signed-off-by: Roger Wang --- docs/source/models/supported_models.rst | 26 ++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c9b3fa8485ff1..4e5b10967e3bb 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -495,7 +495,7 @@ Text Generation --------------- .. list-table:: - :widths: 25 25 15 25 5 5 + :widths: 25 25 15 20 5 5 5 :header-rows: 1 * - Architecture @@ -504,47 +504,55 @@ Text Generation - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + - V1 * - :code:`AriaForConditionalGeneration` - Aria - T + I - :code:`rhymes-ai/Aria` - - ✅︎ + - * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - T + I\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - - ✅︎ + - * - :code:`ChameleonForConditionalGeneration` - Chameleon - T + I - :code:`facebook/chameleon-7b` etc. - - ✅︎ + - * - :code:`FuyuForCausalLM` - Fuyu - T + I - :code:`adept/fuyu-8b` etc. - - ✅︎ + - * - :code:`ChatGLMModel` - GLM-4V - T + I - :code:`THUDM/glm-4v-9b` etc. - ✅︎ - ✅︎ + - * - :code:`H2OVLChatModel` - H2OVL - T + I\ :sup:`E+` - :code:`h2oai/h2ovl-mississippi-800m`, :code:`h2oai/h2ovl-mississippi-2b`, etc. - - ✅︎ + - * - :code:`Idefics3ForConditionalGeneration` - Idefics3 - T + I - :code:`HuggingFaceM4/Idefics3-8B-Llama3` etc. - ✅︎ + - - * - :code:`InternVLChatModel` - InternVL 2.5, Mono-InternVL, InternVL 2.0 @@ -552,96 +560,112 @@ Text Generation - :code:`OpenGVLab/InternVL2_5-4B`, :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, etc. - - ✅︎ + - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - T + I\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - - ✅︎ + - ✅︎ * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - T + I\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - - ✅︎ + - * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - T + V - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - - ✅︎ + - * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - T + I\ :sup:`+` + V\ :sup:`+` - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - - ✅︎ + - * - :code:`MiniCPMV` - MiniCPM-V - T + I\ :sup:`E+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - ✅︎ - ✅︎ + - * - :code:`MllamaForConditionalGeneration` - Llama 3.2 - T + I\ :sup:`+` - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - - + - * - :code:`MolmoForCausalLM` - Molmo - T + I - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. - - ✅︎ + - ✅︎ * - :code:`NVLM_D_Model` - NVLM-D 1.0 - T + I\ :sup:`E+` - :code:`nvidia/NVLM-D-72B`, etc. - - ✅︎ + - ✅︎ * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - T + I\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - - ✅︎ + - * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - T + I\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - - ✅︎ + - ✅︎ * - :code:`PixtralForConditionalGeneration` - Pixtral - T + I\ :sup:`+` - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc. - - ✅︎ + - ✅︎ * - :code:`QWenLMHeadModel` - Qwen-VL - T + I\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - ✅︎ - ✅︎ + - * - :code:`Qwen2AudioForConditionalGeneration` - Qwen2-Audio - T + A\ :sup:`+` - :code:`Qwen/Qwen2-Audio-7B-Instruct` - - ✅︎ + - * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - T + I\ :sup:`E+` + V\ :sup:`E+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - ✅︎ - ✅︎ + - * - :code:`UltravoxModel` - Ultravox - T + A\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - - ✅︎ + - | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality.