diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f678436dd05e1..427dc14513d45 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -146,7 +146,9 @@ steps: source_file_dependencies: - vllm/ - tests/test_regression - command: pytest -v -s test_regression.py + commands: + - pip install modelscope + - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c2e1c37218651..23f08bfa9756e 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -12,201 +12,249 @@ Alongside each architecture, we include some popular models that use it. Decoder-only Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. list-table:: - :widths: 25 25 50 5 + :widths: 25 25 50 5 5 :header-rows: 1 * - Architecture - Models - Example HuggingFace Models - :ref:`LoRA ` + - :ref:`PP ` * - :code:`AquilaForCausalLM` - Aquila, Aquila2 - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc. - ✅︎ + - ✅︎ * - :code:`ArcticForCausalLM` - Arctic - :code:`Snowflake/snowflake-arctic-base`, :code:`Snowflake/snowflake-arctic-instruct`, etc. - + - ✅︎ * - :code:`BaiChuanForCausalLM` - Baichuan2, Baichuan - :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc. - ✅︎ + - ✅︎ * - :code:`BloomForCausalLM` - BLOOM, BLOOMZ, BLOOMChat - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - + - ✅︎ * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. - ✅︎ + - ✅︎ * - :code:`CohereForCausalLM` - Command-R - :code:`CohereForAI/c4ai-command-r-v01`, etc. - - + - ✅︎ + - ✅︎ * - :code:`DbrxForCausalLM` - DBRX - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc. - + - ✅︎ * - :code:`DeciLMForCausalLM` - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - + - ✅︎ * - :code:`DeepseekForCausalLM` - DeepSeek - :code:`deepseek-ai/deepseek-llm-67b-base`, :code:`deepseek-ai/deepseek-llm-7b-chat` etc. - + - ✅︎ * - :code:`DeepseekV2ForCausalLM` - DeepSeek-V2 - :code:`deepseek-ai/DeepSeek-V2`, :code:`deepseek-ai/DeepSeek-V2-Chat` etc. - + - ✅︎ * - :code:`ExaoneForCausalLM` - EXAONE-3 - :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. - ✅︎ + - ✅︎ * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - + - ✅︎ * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. - ✅︎ + - ✅︎ * - :code:`Gemma2ForCausalLM` - Gemma2 - :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc. - ✅︎ + - ✅︎ * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. - + - ✅︎ * - :code:`GPTBigCodeForCausalLM` - StarCoder, SantaCoder, WizardCoder - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc. - ✅︎ + - ✅︎ * - :code:`GPTJForCausalLM` - GPT-J - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc. - + - ✅︎ * - :code:`GPTNeoXForCausalLM` - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. - + - ✅︎ * - :code:`GraniteForCausalLM` - PowerLM - :code:`ibm/PowerLM-3b` etc. 
- ✅︎ + - ✅︎ * - :code:`GraniteMoeForCausalLM` - PowerMoE - :code:`ibm/PowerMoE-3b` etc. - ✅︎ + - ✅︎ * - :code:`InternLMForCausalLM` - InternLM - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc. - ✅︎ + - ✅︎ * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. - + - ✅︎ * - :code:`JAISLMHeadModel` - Jais - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. - + - ✅︎ * - :code:`JambaForCausalLM` - Jamba - :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc. - ✅︎ + - * - :code:`LlamaForCausalLM` - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc. - ✅︎ + - ✅︎ * - :code:`MiniCPMForCausalLM` - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. - - + - ✅︎ + - ✅︎ * - :code:`MiniCPM3ForCausalLM` - MiniCPM3 - :code:`openbmb/MiniCPM3-4B`, etc. - - + - ✅︎ + - ✅︎ * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. - ✅︎ + - ✅︎ * - :code:`MixtralForCausalLM` - Mixtral-8x7B, Mixtral-8x7B-Instruct - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc. - ✅︎ + - ✅︎ * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - + - ✅︎ * - :code:`NemotronForCausalLM` - Nemotron-3, Nemotron-4, Minitron - :code:`nvidia/Minitron-8B-Base`, :code:`mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. - ✅︎ - * - :code:`OLMoEForCausalLM` - - OLMoE - - :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc. - - + - ✅︎ * - :code:`OLMoForCausalLM` - OLMo - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc. - + - ✅︎ + * - :code:`OLMoEForCausalLM` + - OLMoE + - :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc. + - ✅︎ + - ✅︎ * - :code:`OPTForCausalLM` - OPT, OPT-IML - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. - + - ✅︎ * - :code:`OrionForCausalLM` - Orion - :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc. - + - ✅︎ * - :code:`PhiForCausalLM` - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - ✅︎ + - ✅︎ * - :code:`Phi3ForCausalLM` - Phi-3 - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, :code:`microsoft/Phi-3-medium-128k-instruct`, etc. - - + - ✅︎ + - ✅︎ * - :code:`Phi3SmallForCausalLM` - Phi-3-Small - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. - + - ✅︎ * - :code:`PhiMoEForCausalLM` - Phi-3.5-MoE - :code:`microsoft/Phi-3.5-MoE-instruct`, etc. - - + - ✅︎ + - ✅︎ * - :code:`PersimmonForCausalLM` - Persimmon - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. - + - ✅︎ * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - + - ✅︎ * - :code:`Qwen2ForCausalLM` - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. - ✅︎ + - ✅︎ * - :code:`Qwen2MoeForCausalLM` - Qwen2MoE - :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. 
- + - ✅︎ * - :code:`StableLmForCausalLM` - StableLM - :code:`stabilityai/stablelm-3b-4e1t`, :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - + - ✅︎ * - :code:`Starcoder2ForCausalLM` - Starcoder2 - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc. - + - ✅︎ * - :code:`SolarForCausalLM` - - EXAONE-3 + - Solar Pro - :code:`upstage/solar-pro-preview-instruct`, etc. - - + - ✅︎ + - ✅︎ * - :code:`XverseForCausalLM` - - Xverse + - XVERSE - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. - - + - ✅︎ + - ✅︎ .. note:: Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. @@ -217,7 +265,7 @@ Multimodal Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. list-table:: - :widths: 25 25 25 25 5 + :widths: 25 25 25 25 5 5 :header-rows: 1 * - Architecture @@ -225,86 +273,103 @@ Multimodal Language Models - Modalities - Example HuggingFace Models - :ref:`LoRA ` + - :ref:`PP ` * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - Image\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - + - ✅︎ * - :code:`ChameleonForConditionalGeneration` - Chameleon - Image - :code:`facebook/chameleon-7b` etc. - + - ✅︎ * - :code:`FuyuForCausalLM` - Fuyu - Image - :code:`adept/fuyu-8b` etc. - + - ✅︎ * - :code:`InternVLChatModel` - InternVL2 - Image\ :sup:`E+` - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. - + - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - Image\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - + - ✅︎ * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - Image\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - + - ✅︎ * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - Video - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - + - ✅︎ * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - Image\ :sup:`+` / Video - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - + - ✅︎ * - :code:`MiniCPMV` - MiniCPM-V - Image\ :sup:`+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - - + - ✅︎ + - ✅︎ * - :code:`MllamaForConditionalGeneration` - Llama 3.2 - Image - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - + - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - Image\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - + - ✅︎ * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - Image\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - + - ✅︎ * - :code:`PixtralForConditionalGeneration` - Pixtral - Image\ :sup:`+` - :code:`mistralai/Pixtral-12B-2409` - + - ✅︎ * - :code:`QWenLMHeadModel` - Qwen-VL - Image\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - + - ✅︎ * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - Image\ :sup:`E+` / Video\ :sup:`+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - + - ✅︎ * - :code:`UltravoxModel` - Ultravox - Audio\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - + - ✅︎ | :sup:`E` Pre-computed embeddings can be inputted for this modality. 
| :sup:`+` Multiple items can be inputted per text prompt for this modality. diff --git a/requirements-test.txt b/requirements-test.txt index 9c6fadb88865a..37c3bd8ba8794 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,8 +10,8 @@ pytest-shard awscli einops # required for MPT, qwen-vl and Mamba httpx -librosa # required for audio test -opencv-python # required for video test +librosa # required for audio tests +opencv-python # required for video tests peft requests ray[adag]==2.35 diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 2e8e83c3d271b..1f62cdc7e06a8 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -6,6 +6,8 @@ to fail. """ import os +from dataclasses import dataclass +from typing import List, NamedTuple, Optional import pytest @@ -18,49 +20,256 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" +class ParallelSetup(NamedTuple): + tp_size: int + pp_size: int + eager_mode: bool + chunked_prefill: bool + + +@dataclass +class PPTestSettings: + parallel_setups: List[ParallelSetup] + distributed_backends: List[str] + trust_remote_code: bool + tokenizer_mode: Optional[str] + + @staticmethod + def detailed( + *, + tp_base: int = 1, + pp_base: int = 2, + trust_remote_code: bool = False, + tokenizer_mode: Optional[str] = None, + ): + return PPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=2 * tp_base, + pp_size=pp_base, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=2 * tp_base, + pp_size=pp_base, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp", "ray"], + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + ) + + @staticmethod + def fast( + *, + tp_base: int = 1, + pp_base: int = 2, + trust_remote_code: bool = False, + tokenizer_mode: Optional[str] = None, + ): + return PPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp"], + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + ) + + def iter_params(self, model_name: str): + for parallel_setup in self.parallel_setups: + for distributed_backend in self.distributed_backends: + yield (model_name, parallel_setup, distributed_backend, + self.trust_remote_code, self.tokenizer_mode) + + +# yapf: disable +GENERATION_MODEL_SETTINGS = { + # [DETAILED TESTS] + "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), + # [FAST TESTS] + # Uses Llama + # "BAAI/AquilaChat-7B": PPTestSettings.fast(), + # TODO: Test on larger GPU + # "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True), + "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "bigscience/bloomz-1b1": PPTestSettings.fast(), + "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True), + "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501 + # TODO: Test on larger GPU + # "databricks/dbrx-instruct": PPTestSettings.fast(), + 
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True), + "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(), + "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(), + "tiiuae/falcon-7b": PPTestSettings.fast(), + "google/gemma-2b": PPTestSettings.fast(), + "google/gemma-2-9b": PPTestSettings.fast(), + "gpt2": PPTestSettings.fast(), + "bigcode/starcoder": PPTestSettings.fast(), + "EleutherAI/gpt-j-6b": PPTestSettings.fast(), + "EleutherAI/pythia-12b": PPTestSettings.fast(), + "ibm/PowerLM-3b": PPTestSettings.fast(), + "ibm/PowerMoE-3b": PPTestSettings.fast(), + # Uses Llama + # "internlm/internlm-chat-7b": PPTestSettings.fast(), + "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), + "core42/jais-13b-chat": PPTestSettings.fast(), + # TODO: Implement PP + # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(), + "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True), + "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True), + # Uses Llama + # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(), + "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4), + "mosaicml/mpt-7b": PPTestSettings.fast(), + "nvidia/Minitron-8B-Base": PPTestSettings.fast(), + "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(), + "allenai/OLMo-1B-hf": PPTestSettings.fast(), + "facebook/opt-iml-max-1.3b": PPTestSettings.fast(), + "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True), + "microsoft/phi-2": PPTestSettings.fast(), + "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(), + "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + # FIXME: https://github.com/vllm-project/vllm/issues/8553 + # "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "adept/persimmon-8b-chat": PPTestSettings.fast(), + "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True), + "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(), + "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(), + "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), + "bigcode/starcoder2-3b": PPTestSettings.fast(), + "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), + # FIXME: Cannot load tokenizer in latest transformers version + # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), +} + +EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated] + # [FAST TESTS] + # Uses Llama + # "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), +} + +MULTIMODAL_MODEL_SETTINGS = { + # [FAST TESTS] + "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(), + "facebook/chameleon-7b": PPTestSettings.fast(), + "adept/fuyu-8b": PPTestSettings.fast(), + "OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True), + "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(), + "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(), + "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(), + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(), + "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True), + # TODO: Implement PP + # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), + "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # 
noqa: E501 + "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), + "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), + "fixie-ai/ultravox-v0_3": PPTestSettings.fast(), +} + +CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated] + # [FAST TESTS] + # TODO: Implement PP + # "facebook/bart-base": PPTestSettings.fast(), +} +# yapf: enable + +MODEL_SETTINGS = { + **GENERATION_MODEL_SETTINGS, + **EMBEDDING_MODEL_SETTINGS, + **MULTIMODAL_MODEL_SETTINGS, +} + +# You can update this on your local machine to run specific tests +TEST_MODELS = [ + "meta-llama/Meta-Llama-3-8B", + "facebook/chameleon-7b", + "OpenGVLab/InternVL2-1B", + "microsoft/Phi-3-vision-128k-instruct", + "mistralai/Pixtral-12B-2409", + "fixie-ai/ultravox-v0_3", +] + + @pytest.mark.parametrize( - ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " - "MODEL_NAME, DIST_BACKEND"), + ("model_name", "parallel_setup", "distributed_backend", + "trust_remote_code", "tokenizer_mode"), [ - (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - # NOTE: InternVL2 multi-node tests are flaky, - # use mp backend to skip the multi-node tests - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"), - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"), - (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"), - (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"), - # TP only models - (2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"), + params for model_name, settings in MODEL_SETTINGS.items() + for params in settings.iter_params(model_name) + if model_name in TEST_MODELS ], ) @fork_new_process_for_each_test -def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, - TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND): - if VLLM_MULTI_NODE and DIST_BACKEND == "mp": +def test_compare_tp(model_name: str, parallel_setup: ParallelSetup, + distributed_backend: str, trust_remote_code: bool, + tokenizer_mode: Optional[str], num_gpus_available): + tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup + + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} GPUs to run the test") + if VLLM_MULTI_NODE and distributed_backend == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") - pp_args = [ + common_args = [ # use half precision for speed and memory savings in CI environment "--dtype", "float16", "--max-model-len", - "8192", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + + if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2 + and chunked_prefill): + # Test Ray ADAG for a subset of the tests + pp_env = { + "VLLM_USE_RAY_COMPILED_DAG": "1", + "VLLM_USE_RAY_SPMD_WORKER": "1", + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", + } + # Temporary. 
Currently when zeromq + SPMD is used, it does not properly + # terminate because of aDAG issue. + common_args.append("--disable-frontend-multiprocessing") + else: + pp_env = None + + pp_args = [ + *common_args, "--pipeline-parallel-size", - str(PP_SIZE), + str(pp_size), "--tensor-parallel-size", - str(TP_SIZE), + str(tp_size), "--distributed-executor-backend", - DIST_BACKEND, + distributed_backend, ] # compare without pipeline parallelism @@ -69,41 +278,15 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, # schedule all workers in a node other than the head node, # which can cause the test to fail. tp_args = [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--max-model-len", - "8192", + *common_args, "--tensor-parallel-size", - str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI. + str(tp_size), "--distributed-executor-backend", "mp", ] - if CHUNKED_PREFILL: - pp_args.append("--enable-chunked-prefill") - tp_args.append("--enable-chunked-prefill") - if EAGER_MODE: - pp_args.append("--enforce-eager") - tp_args.append("--enforce-eager") - if TRUST_REMOTE_CODE: - pp_args.append("--trust-remote-code") - tp_args.append("--trust-remote-code") - pp_env = None - if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2 - and CHUNKED_PREFILL): - # Test Ray ADAG for a subset of the tests - pp_env = { - "VLLM_USE_RAY_COMPILED_DAG": "1", - "VLLM_USE_RAY_SPMD_WORKER": "1", - "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", - } - # Temporary. Currently when zeromq + SPMD is used, it does not properly - # terminate because of aDAG issue. - pp_args.append("--disable-frontend-multiprocessing") - tp_args.append("--disable-frontend-multiprocessing") try: - compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env) + compare_two_settings(model_name, pp_args, tp_args, pp_env) except Exception: if pp_env is None: raise diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index b058e2755c245..ee5c9e8ccb196 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -1,9 +1,55 @@ +import warnings + import pytest +import torch.cuda from vllm.model_executor.models import _MODELS, ModelRegistry +from vllm.platforms import current_platform + +from ..utils import fork_new_process_for_each_test -@pytest.mark.parametrize("model_cls", _MODELS) -def test_registry_imports(model_cls): +@pytest.mark.parametrize("model_arch", _MODELS) +def test_registry_imports(model_arch): # Ensure all model classes can be imported successfully - ModelRegistry.resolve_model_cls([model_cls]) + ModelRegistry.resolve_model_cls(model_arch) + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [ + ("LlamaForCausalLM", False, False), + ("MllamaForConditionalGeneration", True, False), + ("LlavaForConditionalGeneration", True, True), +]) +def test_registry_is_multimodal(model_arch, is_mm, init_cuda): + assert ModelRegistry.is_multimodal_model(model_arch) is is_mm + + if init_cuda and current_platform.is_cuda_alike(): + assert not torch.cuda.is_initialized() + + ModelRegistry.resolve_model_cls(model_arch) + if not torch.cuda.is_initialized(): + warnings.warn( + "This model no longer initializes CUDA on import. 
" + "Please test using a different one.", + stacklevel=2) + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [ + ("MLPSpeculatorPreTrainedModel", False, False), + ("DeepseekV2ForCausalLM", True, False), + ("Qwen2VLForConditionalGeneration", True, True), +]) +def test_registry_is_pp(model_arch, is_pp, init_cuda): + assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp + + if init_cuda and current_platform.is_cuda_alike(): + assert not torch.cuda.is_initialized() + + ModelRegistry.resolve_model_cls(model_arch) + if not torch.cuda.is_initialized(): + warnings.warn( + "This model no longer initializes CUDA on import. " + "Please test using a different one.", + stacklevel=2) diff --git a/tests/utils.py b/tests/utils.py index 49bd4f236f658..8c8a7c4bf0c70 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,6 @@ import pytest import requests from openai.types.completion import Completion -from transformers import AutoTokenizer from typing_extensions import ParamSpec from tests.models.utils import TextTextLogprobs @@ -24,6 +23,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.model_executor.model_loader.loader import get_model_loader from vllm.platforms import current_platform +from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import (FlexibleArgumentParser, GB_bytes, cuda_device_count_stateless, get_open_port, is_hip) @@ -181,15 +181,26 @@ def compare_two_settings(model: str, env2: The second set of environment variables to pass to the API server. """ - trust_remote_code = "--trust-remote-code" - if trust_remote_code in arg1 or trust_remote_code in arg2: - tokenizer = AutoTokenizer.from_pretrained(model, - trust_remote_code=True) - else: - tokenizer = AutoTokenizer.from_pretrained(model) + trust_remote_code = False + for args in (arg1, arg2): + if "--trust-remote-code" in args: + trust_remote_code = True + break + + tokenizer_mode = "auto" + for args in (arg1, arg2): + if "--tokenizer-mode" in args: + tokenizer_mode = args[args.index("--tokenizer-mode") + 1] + break + + tokenizer = get_tokenizer( + model, + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + ) prompt = "Hello, my name is" - token_ids = tokenizer(prompt)["input_ids"] + token_ids = tokenizer(prompt).input_ids results = [] for args, env in ((arg1, env1), (arg2, env2)): with RemoteOpenAIServer(model, diff --git a/vllm/config.py b/vllm/config.py index 05d5f4998d74d..7b3996dc90b94 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -31,28 +31,7 @@ logger = init_logger(__name__) _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 -_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 - -_PP_SUPPORTED_MODELS = [ - "AquilaForCausalLM", - "AquilaModel", - "DeepseekV2ForCausalLM", - "GPT2LMHeadModel", - "InternLM2ForCausalLM", - "InternLMForCausalLM", - "InternVLChatModel", - "JAISLMHeadModel", - "LlamaForCausalLM", - "LLaMAForCausalLM", - "MistralForCausalLM", - "MixtralForCausalLM", - "NemotronForCausalLM", - "Phi3ForCausalLM", - "Qwen2ForCausalLM", - "Qwen2MoeForCausalLM", - "QWenLMHeadModel", - "Qwen2VLForConditionalGeneration", -] +_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 class ModelConfig: @@ -228,16 +207,14 @@ def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] ) -> Optional["MultiModalConfig"]: architectures = getattr(self.hf_config, "architectures", []) - if any( - ModelRegistry.is_multimodal_model(arch) - for arch in architectures): + if 
ModelRegistry.is_multimodal_model(architectures): return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) - else: - if limit_mm_per_prompt: - raise ValueError( - "limit_mm_per_prompt is only supported for multimodal " - "models.") - return None + + if limit_mm_per_prompt: + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") + + return None def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() @@ -249,8 +226,7 @@ def _verify_tokenizer_mode(self) -> None: def _verify_embedding_mode(self) -> None: architectures = getattr(self.hf_config, "architectures", []) - self.embedding_mode = any( - ModelRegistry.is_embedding_model(arch) for arch in architectures) + self.embedding_mode = ModelRegistry.is_embedding_model(architectures) def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) @@ -417,17 +393,17 @@ def verify_with_parallel_config( f"({tensor_parallel_size}).") pipeline_parallel_size = parallel_config.pipeline_parallel_size - architectures = getattr(self.hf_config, "architectures", []) - if not all(arch in _PP_SUPPORTED_MODELS - for arch in architectures) and pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is only supported for the following " - f" architectures: {_PP_SUPPORTED_MODELS}.") + if pipeline_parallel_size > 1: + architectures = getattr(self.hf_config, "architectures", []) + if not ModelRegistry.is_pp_supported_model(architectures): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") - if pipeline_parallel_size > 1 and self.use_async_output_proc: - logger.warning("Async output processor is not supported with " - "pipeline parallelism currently. Disabling it.") - self.use_async_output_proc = False + if self.use_async_output_proc: + logger.warning("Async output processor is not supported with " + "pipeline parallelism currently. 
Disabling it.") + self.use_async_output_proc = False def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 3a57db0d04fab..2f9cb2b760a82 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,12 +1,18 @@ -import functools import importlib -from typing import Dict, List, Optional, Tuple, Type +import string +import subprocess +import sys +import uuid +from functools import lru_cache, partial +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import torch.nn as nn from vllm.logger import init_logger from vllm.utils import is_hip +from .interfaces import supports_multimodal, supports_pp + logger = init_logger(__name__) _GENERATION_MODELS = { @@ -152,19 +158,25 @@ class ModelRegistry: @staticmethod - @functools.lru_cache(maxsize=128) - def _get_model(model_arch: str): - module_name, model_cls_name = _MODELS[model_arch] - module = importlib.import_module( - f"vllm.model_executor.models.{module_name}") - return getattr(module, model_cls_name, None) + def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: + module_relname, cls_name = _MODELS[model_arch] + return f"vllm.model_executor.models.{module_relname}", cls_name @staticmethod - def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: - if model_arch in _OOT_MODELS: - return _OOT_MODELS[model_arch] + @lru_cache(maxsize=128) + def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]: if model_arch not in _MODELS: return None + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + module = importlib.import_module(module_name) + return getattr(module, cls_name, None) + + @staticmethod + def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] + if is_hip(): if model_arch in _ROCM_UNSUPPORTED_MODELS: raise ValueError( @@ -175,11 +187,24 @@ def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: "Model architecture %s is partially supported by ROCm: %s", model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) - return ModelRegistry._get_model(model_arch) + return None + + @staticmethod + def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return model + + return ModelRegistry._try_get_model_stateful(model_arch) @staticmethod def resolve_model_cls( - architectures: List[str]) -> Tuple[Type[nn.Module], str]: + architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + for arch in architectures: model_cls = ModelRegistry._try_load_model_cls(arch) if model_cls is not None: @@ -200,21 +225,99 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): "Model architecture %s is already registered, and will be " "overwritten by the new model class %s.", model_arch, model_cls.__name__) - global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls @staticmethod - def is_embedding_model(model_arch: str) -> bool: - return model_arch in _EMBEDDING_MODELS + @lru_cache(maxsize=128) + def _check_stateless( + func: Callable[[Type[nn.Module]], bool], + model_arch: str, + *, + default: Optional[bool] = None, + ) 
-> bool: + """ + Run a boolean function against a model and return the result. + + If the model is not found, returns the provided default value. + + If the model is not already imported, the function is run inside a + subprocess to avoid initializing CUDA for the main program. + """ + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return func(model) + + if model_arch not in _MODELS and default is not None: + return default + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + + valid_name_characters = string.ascii_letters + string.digits + "._" + if any(s not in valid_name_characters for s in module_name): + raise ValueError(f"Unsafe module name detected for {model_arch}") + if any(s not in valid_name_characters for s in cls_name): + raise ValueError(f"Unsafe class name detected for {model_arch}") + if any(s not in valid_name_characters for s in func.__module__): + raise ValueError(f"Unsafe module name detected for {func}") + if any(s not in valid_name_characters for s in func.__name__): + raise ValueError(f"Unsafe class name detected for {func}") + + err_id = uuid.uuid4() + + stmts = ";".join([ + f"from {module_name} import {cls_name}", + f"from {func.__module__} import {func.__name__}", + f"assert {func.__name__}({cls_name}), '{err_id}'", + ]) + + result = subprocess.run([sys.executable, "-c", stmts], + capture_output=True) + + if result.returncode != 0: + err_lines = [line.decode() for line in result.stderr.splitlines()] + if err_lines and err_lines[-1] != f"AssertionError: {err_id}": + err_str = "\n".join(err_lines) + raise RuntimeError( + "An unexpected error occurred while importing the model in " + f"another process. Error log:\n{err_str}") + + return result.returncode == 0 + + @staticmethod + def is_embedding_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return any(arch in _EMBEDDING_MODELS for arch in architectures) + + @staticmethod + def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_mm = partial(ModelRegistry._check_stateless, + supports_multimodal, + default=False) + + return any(is_mm(arch) for arch in architectures) @staticmethod - def is_multimodal_model(model_arch: str) -> bool: + def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_pp = partial(ModelRegistry._check_stateless, + supports_pp, + default=False) - # TODO: find a way to avoid initializing CUDA prematurely to - # use `supports_multimodal` to determine if a model is multimodal - # model_cls = ModelRegistry._try_load_model_cls(model_arch) - # from vllm.model_executor.models.interfaces import supports_multimodal - return model_arch in _MULTIMODAL_MODELS + return any(is_pp(arch) for arch in architectures) __all__ = [ diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index efa044d0b5e92..30b1f1cce1fcc 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -1,12 +1,12 @@ """Inference-only Snowflake Arctic model.""" -from typing import Iterable, List, Optional, Tuple 
+from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger @@ -18,8 +18,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig, DeepSpeedFPParameter) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -32,6 +31,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.arctic import ArcticConfig +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + logger = init_logger(__name__) @@ -364,6 +367,7 @@ def __init__( config: ArcticConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -372,15 +376,16 @@ def __init__( self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size) - self.layers = nn.ModuleList([ - ArcticDecoderLayer(config, - layer_idx, - cache_config, - quant_config=quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: ArcticDecoderLayer(config, int( + prefix.split(".")[-1]), cache_config, quant_config), + prefix=f"{prefix}.layers") self._attn_implementation = config._attn_implementation self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -388,17 +393,25 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states = layer(positions, hidden_states, kv_caches[i], + hidden_states = layer(positions, hidden_states, + kv_caches[i - self.start_layer], attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) return hidden_states -class ArcticForCausalLM(nn.Module): +class ArcticForCausalLM(nn.Module, SupportsPP): def __init__(self, config: ArcticConfig, @@ -422,6 +435,8 @@ def __init__(self, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -430,9 +445,9 @@ def forward( 
kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -503,6 +518,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -512,6 +529,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -522,6 +541,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -532,6 +553,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index bdd76b11384c2..54ed548ba8bc7 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -19,7 +19,7 @@ # limitations under the License. 
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -27,7 +27,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -35,8 +35,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,7 +44,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -255,7 +256,8 @@ def __init__(self, config: PretrainedConfig, position_embedding: str, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -265,12 +267,16 @@ def __init__(self, config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, cache_config, - quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BaiChuanDecoderLayer(config, position_embedding, + cache_config, quant_config), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -278,23 +284,34 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) hidden_states, _ = self.norm(hidden_states, residual) 
return hidden_states -class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ @@ -335,6 +352,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -343,9 +362,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -394,6 +413,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -402,6 +423,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -413,7 +436,7 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -431,7 +454,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index b28d7699afa01..ca0cbef5cbf48 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -11,7 +12,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -19,7 +20,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) @@ -475,7 +476,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) 
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) -class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): +class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: Blip2Config, @@ -508,6 +509,16 @@ def __init__(self, self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -600,7 +611,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[SamplerOutput, IntermediateTensors]: """Run forward pass for BLIP-2. One key thing to understand is the `input_ids` already accounts for the @@ -631,26 +642,32 @@ def forward( See also: :class:`Blip2ImageInputs` """ - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - BLIP2_IMAGE_TOKEN_ID) - + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None - - hidden_states = self.language_model.model(input_ids, - positions, - kv_caches, - attn_metadata, - inputs_embeds=inputs_embeds) + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + BLIP2_IMAGE_TOKEN_ID) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 831b3f20457a9..b2c9e221690b3 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -17,7 +17,7 @@ # limitations under the License. 
"""Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -25,15 +25,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,6 +40,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) @@ -222,6 +225,7 @@ def __init__( config: BloomConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.embed_dim = config.hidden_size @@ -235,13 +239,16 @@ def __init__( self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks - self.h = nn.ModuleList([ - BloomBlock(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: BloomBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h") # Final Layer Norm self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -249,22 +256,29 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(hidden_states) - for i in range(len(self.h)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.word_embeddings(input_ids) + hidden_states = self.word_embeddings_layernorm(hidden_states) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.h[i] hidden_states = layer( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states -class BloomForCausalLM(nn.Module): +class BloomForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -284,6 +298,8 @@ def __init__( self.logits_processor = 
LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -292,9 +308,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -321,6 +337,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if not name.startswith("transformer."): name = "transformer." + name + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] if "query_key_value" in name: diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 973e47f5f0ccd..03c7419f6f6af 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,6 +1,6 @@ from functools import cached_property from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, - Tuple, TypedDict) + Tuple, TypedDict, Union) import torch import torch.nn.functional as F @@ -10,7 +10,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -33,7 +33,9 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import print_warning_once -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) # These configs are not part of the model config but the preprocessor # and processor files, so we hardcode them in the model file for now. 
@@ -822,6 +824,7 @@ def __init__( config: ChameleonConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -835,14 +838,20 @@ def __init__( config.vocabulary_map) decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \ else ChameleonSwinDecoderLayer - self.layers = nn.ModuleList([ - decoder_layer(config=config, - cache_config=cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer(config=config, + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.vqmodel = ChameleonVQVAE(config.vq_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -865,22 +874,33 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -889,7 +909,8 @@ def forward( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) @INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) -class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): +class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__( self, @@ -914,6 +935,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: @@ -956,22 +979,26 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs, - ) -> torch.Tensor: - - image_input = self._parse_and_validate_image_input(**kwargs) + ) -> Union[torch.Tensor, IntermediateTensors]: - if image_input is not None: - assert self.model.vqmodel is not None - image_tokens = self.model.get_image_tokens(image_input["data"].to( - self.config.torch_dtype)) - 
image_token_id = self.model.vocabulary_mapping.image_token_id - special_image_mask = input_ids == image_token_id - image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) - input_ids = input_ids.masked_scatter(special_image_mask, - image_tokens) + if intermediate_tensors is not None: + input_ids = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + assert self.model.vqmodel is not None + image_tokens = self.model.get_image_tokens( + image_input["data"].to(self.config.torch_dtype)) + image_token_id = self.model.vocabulary_mapping.image_token_id + special_image_mask = input_ids == image_token_id + image_tokens = image_tokens.to(input_ids.device, + input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, + image_tokens) hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -1039,6 +1066,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -1060,11 +1089,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue else: name = remapped_kv_scale_name + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) if use_default_weight_loading and name in params_dict: + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 35f1ed5ef5d33..879795c0d5955 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -10,15 +10,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -28,14 +27,16 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import 
(is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class GLMAttention(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -126,7 +127,7 @@ class GLMMLP(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -169,7 +170,7 @@ class GLMBlock(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -240,9 +241,10 @@ class GLMTransformer(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -251,10 +253,11 @@ def __init__( self.num_layers = config.num_layers # Transformer layers. - self.layers = nn.ModuleList([ - GLMBlock(config, cache_config, quant_config) - for i in range(self.num_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + self.num_layers, + lambda prefix: GLMBlock(config, cache_config, quant_config), + prefix=f"{prefix}.layers", + ) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -269,16 +272,16 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - for i in range(self.num_layers): + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( hidden_states=hidden_states, position_ids=position_ids, - kv_cache=kv_caches[i], + kv_cache=kv_caches[i - self.start_layer], attn_metadata=attn_metadata, ) # Final layer norm. - if self.post_layer_norm: + if get_pp_group().is_last_rank and self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) return hidden_states @@ -288,7 +291,7 @@ class ChatGLMModel(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -305,6 +308,9 @@ def __init__( self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size, quant_config=quant_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -312,8 +318,12 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - inputs_embeds = self.embedding(input_ids) + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + inputs_embeds = self.embedding(input_ids) + else: + inputs_embeds = intermediate_tensors["hidden_states"] # Run encoder. 
hidden_states = self.encoder( @@ -322,10 +332,13 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) return hidden_states -class ChatGLMForCausalLM(nn.Module, SupportsLoRA): +class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] @@ -362,6 +375,8 @@ def __init__( self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -370,9 +385,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -402,6 +417,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 649dc798d22dc..a0b8ff3a85c98 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -20,7 +20,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -29,14 +29,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -47,7 +46,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) @torch.compile @@ -82,7 +83,7 @@ class CohereMLP(nn.Module): def __init__( self, - config, + config: CohereConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -256,6 +257,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", 
): super().__init__() self.config = config @@ -265,12 +267,16 @@ def __init__( self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - CohereDecoderLayer(config, cache_config, quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CohereDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.norm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -278,23 +284,34 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class CohereForCausalLM(nn.Module, SupportsLoRA): +class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -337,6 +354,8 @@ def __init__( quant_config, lora_config=lora_config) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) @torch.no_grad() def forward( @@ -346,9 +365,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -393,6 +412,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -405,6 +426,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 397a46a486f72..b0b07e9c03a9d 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,20 +1,19 @@ # coding=utf-8 -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -24,6 +23,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dbrx import DbrxConfig +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class DbrxRouter(nn.Module): """A Router implementation for DBRX that returns logits for each expert @@ -296,22 +299,27 @@ def __init__( config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.wte = VocabParallelEmbedding( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList([ - DbrxBlock(config, cache_config, quant_config) - for _ in range(config.n_layers) - ]) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: DbrxBlock(config, cache_config, quant_config), + prefix=f"{prefix}.blocks", + ) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): # Remove the bias term in Linear and LayerNorm. 
module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) def forward( self, @@ -319,21 +327,28 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.wte(input_ids) - for i in range(len(self.blocks)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): block = self.blocks[i] hidden_states = block( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) return hidden_states -class DbrxForCausalLM(nn.Module): +class DbrxForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -359,6 +374,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -367,9 +384,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -401,11 +418,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, weight_name) break else: + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index 65b409a2a15a0..7ed2b96e65c49 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -29,11 +29,12 @@ from transformers import LlamaConfig from vllm.config import CacheConfig, LoRAConfig -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM +from .utils import is_pp_missing_parameter + class DeciLMForCausalLM(LlamaForCausalLM): """ @@ -91,6 +92,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -99,6 +102,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 61cc917ab6207..5b4db8f258711 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,7 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul @@ -40,8 +40,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -50,6 +49,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class DeepseekMLP(nn.Module): @@ -329,6 +332,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -338,14 +342,17 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, - layer_idx, - cache_config, - quant_config=quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekDecoderLayer(config, + int(prefix.split(".")[-1]), + cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -353,19 +360,29 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, 
residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class DeepseekForCausalLM(nn.Module): +class DeepseekForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -384,6 +401,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -392,9 +411,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -439,6 +458,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if (("mlp.experts." in name or "mlp.shared_experts." in name) and name not in params_dict): continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -451,6 +472,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if (("mlp.experts." in name or "mlp.shared_experts." in name) and name not in params_dict): continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 8cbd9435ec7ca..702be7b7f5ed9 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only DeepseekV2 model.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -40,8 +40,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -50,7 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class DeepseekV2MLP(nn.Module): @@ -439,6 +440,9 @@ def __init__( self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -447,7 +451,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: hidden_states = self.embed_tokens(input_ids) residual = None @@ -472,7 +476,7 @@ def forward( return hidden_states -class DeepseekV2ForCausalLM(nn.Module): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -492,6 +496,8 @@ def __init__( quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -500,7 +506,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 4a1c367de3f62..dfb8fe55d2fb8 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -38,8 +38,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -53,8 +52,9 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig from vllm.utils import is_hip -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class 
ExaoneGatedMLP(nn.Module): @@ -354,6 +354,10 @@ def __init__( else: self.ln_f = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -397,7 +401,7 @@ def forward( return hidden_states -class ExaoneForCausalLM(nn.Module, SupportsLoRA): +class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -477,6 +481,9 @@ def __init__( else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + def forward( self, input_ids: torch.Tensor, @@ -506,24 +513,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - "residual": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index b474d35baf89d..a20dd93cee18c 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import get_act_fn @@ -36,8 +36,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -47,6 +46,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import RWConfig +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + FalconConfig = Union[HF_FalconConfig, RWConfig] @@ -333,6 +336,7 @@ def __init__( config: FalconConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -347,35 +351,45 @@ def __init__( ) # Transformer blocks - self.h = nn.ModuleList([ - FalconDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: FalconDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.h") # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + 
make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, - input_ids: torch.LongTensor, + input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.word_embeddings(input_ids) - for i in range(len(self.h)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.word_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.h[i] hidden_states = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states -class FalconForCausalLM(nn.Module): +class FalconForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -403,6 +417,8 @@ def __init__( ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -412,12 +428,8 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.transformer( - input_ids, - positions, - kv_caches, - attn_metadata, - ) + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -454,6 +466,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] if "query_key_value" in name: output_dim = getattr(param, "output_dim", None) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 87b88da0dc05c..835931746fd4b 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -41,8 +41,9 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) -from .interfaces import SupportsMultiModal -from .utils import flatten_bn, merge_multimodal_embeddings +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (flatten_bn, group_weights_with_prefix, + merge_multimodal_embeddings) # Cannot find the following 2 numbers from hf config. 
_IMAGE_TOKEN_ID = 71011 @@ -217,7 +218,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu) @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) -class FuyuForCausalLM(nn.Module, SupportsMultiModal): +class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: FuyuConfig, @@ -242,6 +243,12 @@ def __init__(self, self.language_model = PersimmonForCausalLM(config.text_config, cache_config=cache_config, quant_config=quant_config) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: @@ -297,23 +304,29 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ): - image_input = self._parse_and_validate_image_input(**kwargs) + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.embed_tokens(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.embed_tokens( + input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) - else: - inputs_embeds = None + else: + inputs_embeds = None hidden_states = self.language_model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states @@ -336,34 +349,16 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - param = params_dict[name] - - if "query_key_value" in name: - # copy from vllm/model_executor/models/bloom.py - # NOTE: Fuyu's fused QKV's output_dim has the shape of - # (num_heads * 3 * head_size), while the - # required shape is (3 * num_heads * head_size). - # Thus, we need weight conversion. 
- output_dim = getattr(param, "output_dim", None) - num_heads = self.config.num_attention_heads - if output_dim is not None: - loaded_weight_shape = loaded_weight.shape - loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) - loaded_weight = loaded_weight.reshape(loaded_weight_shape) + # prepare weight iterators for components + weights_group = group_weights_with_prefix(weights) + # load vision embeddings + vision_params_dict = dict(self.vision_embed_tokens.named_parameters()) + for name, loaded_weight in weights_group["vision_embed_tokens"]: + param = vision_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + # load llm backbone + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 36fd389831282..ca419891f69db 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -15,7 +15,7 @@ # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" from functools import lru_cache -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -23,7 +23,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm @@ -31,8 +31,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -41,7 +40,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) logger = init_logger(__name__) @@ -245,6 +246,7 @@ def __init__( config: GemmaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -253,10 +255,11 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config + ), + prefix=f"{prefix}.layers") self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -265,6 +268,9 @@ def __init__( # See 
https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -275,29 +281,38 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.normalizer + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) - hidden_states *= self.normalizer - residual = None - for i in range(len(self.layers)): + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class GemmaForCausalLM(nn.Module, SupportsLoRA): +class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -339,6 +354,8 @@ def __init__( self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -347,9 +364,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -388,6 +405,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -400,6 +419,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index f9d9f9e7567c8..9fddaac3a0837 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -22,7 +22,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm @@ -30,8 +30,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -40,7 +39,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) logger = init_logger(__name__) @@ -244,6 +245,7 @@ def __init__( config: Gemma2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -252,10 +254,11 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[ + -1]), config, cache_config, quant_config), + prefix=f"{prefix}.layers") self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -264,6 +267,9 @@ def __init__( # See https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -271,25 +277,36 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - hidden_states *= self.normalizer + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if 
get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + hidden_states *= self.normalizer - residual = None - for i in range(len(self.layers)): + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class Gemma2ForCausalLM(nn.Module, SupportsLoRA): +class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -338,6 +355,8 @@ def __init__( self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -346,9 +365,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -387,6 +406,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -399,6 +420,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fb5a297661ddc..975502340e5f9 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -32,8 +32,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,7 +40,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import is_pp_missing_parameter, make_layers +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class GPT2Attention(nn.Module): @@ -204,6 +205,9 @@ def __init__( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h") self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) def forward( self, @@ -234,7 +238,7 @@ def forward( return hidden_states -class GPT2LMHeadModel(nn.Module): +class GPT2LMHeadModel(nn.Module, SupportsPP): def __init__( self, @@ -256,6 +260,8 @@ def __init__( self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -264,7 +270,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -286,16 +292,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index fe5ec10827608..6c4a04667c5da 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -26,14 +26,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,7 +40,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class GPTBigCodeAttention(nn.Module): @@ -194,6 +195,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -207,11 +209,15 @@ def __init__( self.embed_dim, org_num_embeddings=config.vocab_size) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.h = nn.ModuleList([ - GPTBigCodeBlock(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h", + ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) def forward( self, @@ -219,20 +225,28 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = intermediate_tensors["hidden_states"] - for i in range(len(self.h)): + for i in range(self.start_layer, self.end_layer): layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states -class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): +class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = {"c_attn": 
["c_attn"]} supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] @@ -272,6 +286,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -280,9 +296,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -311,6 +327,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 664d775c8ba40..d40bf8c88ee19 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -24,14 +24,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -40,6 +39,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class GPTJAttention(nn.Module): @@ -178,6 +181,7 @@ def __init__( config: GPTJConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -186,11 +190,15 @@ def __init__( config.vocab_size, self.embed_dim, ) - self.h = nn.ModuleList([ - GPTJBlock(config, cache_config, quant_config) - for _ in range(config.n_layer) - ]) + self.start_layer, self.end_layer, self.h = make_layers( + config.n_layer, + lambda prefix: GPTJBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h", + ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) def forward( self, @@ -198,21 +206,27 @@ 
def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.wte(input_ids) - for i in range(len(self.h)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.h[i] hidden_states = layer( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states -class GPTJForCausalLM(nn.Module): +class GPTJForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -233,6 +247,8 @@ def __init__( ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -241,9 +257,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -283,6 +299,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -291,6 +309,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 5f6f1e3880547..23a1ca06cc69e 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -24,14 +24,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -40,6 +39,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class GPTNeoXAttention(nn.Module): @@ -191,6 +194,7 @@ def __init__( config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -199,12 +203,16 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - GPTNeoXLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GPTNeoXLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers", + ) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -212,21 +220,27 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_in(input_ids) - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_in(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layer_norm(hidden_states) return hidden_states -class GPTNeoXForCausalLM(nn.Module): +class GPTNeoXForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -247,6 +261,8 @@ def __init__( self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.gpt_neox.make_empty_intermediate_tensors) def forward( self, @@ -255,9 +271,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: 
AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -288,6 +304,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using OpenRLHF may include # these tensors in the checkpoint. Skip them. continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] if "query_key_value" in name: diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 48d43b204fc51..dcf4f5b27704a 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import is_hip -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -311,13 +311,13 @@ def forward( else: hidden_states = self.get_input_embeddings(input_ids) residual = None + + hidden_states *= self.config.embedding_multiplier else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - hidden_states *= self.config.embedding_multiplier - for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( @@ -337,7 +337,7 @@ def forward( return hidden_states -class GraniteForCausalLM(nn.Module, SupportsLoRA): +class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 1cf2577d24937..5266951794a80 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -46,7 +46,7 @@ from vllm.sequence import IntermediateTensors from . import mixtral -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import make_layers @@ -307,7 +307,7 @@ def forward( return hidden_states -class GraniteMoeForCausalLM(nn.Module, SupportsLoRA): +class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False packed_modules_mapping = { diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 069948f812253..298174fa05965 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,11 +1,17 @@ -from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, - Union, overload, runtime_checkable) +import inspect +from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, + Protocol, Type, Union, overload, runtime_checkable) +import torch from typing_extensions import TypeIs -from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig + from vllm.sequence import IntermediateTensors + logger = init_logger(__name__) @@ -22,7 +28,7 @@ class SupportsMultiModal(Protocol): MRO of your model class. 
""" - def __init__(self, *, multimodal_config: MultiModalConfig) -> None: + def __init__(self, *, multimodal_config: "MultiModalConfig") -> None: ... @@ -32,7 +38,7 @@ def __init__(self, *, multimodal_config: MultiModalConfig) -> None: class _SupportsMultiModalType(Protocol): supports_multimodal: Literal[True] - def __call__(self, *, multimodal_config: MultiModalConfig) -> None: + def __call__(self, *, multimodal_config: "MultiModalConfig") -> None: ... @@ -75,7 +81,7 @@ class SupportsLoRA(Protocol): embedding_padding_modules: ClassVar[List[str]] # lora_config is None when LoRA is not enabled - def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: ... @@ -90,7 +96,7 @@ class _SupportsLoRAType(Protocol): embedding_modules: Dict[str, str] embedding_padding_modules: List[str] - def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: ... @@ -145,6 +151,132 @@ def _supports_lora( return isinstance(model, SupportsLoRA) +@runtime_checkable +class SupportsPP(Protocol): + """The interface required for all models that support pipeline parallel.""" + + supports_pp: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pipeline parallel. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + """Called when PP rank > 0 for profiling purposes.""" + ... + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + """ + Accept :class:`IntermediateTensors` when PP rank > 0. + + Return :class:`IntermediateTensors` only for the last PP rank. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsPPType(Protocol): + supports_pp: Literal[True] + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + ... + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + ... + + +@overload +def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: + ... + + +@overload +def supports_pp(model: object) -> TypeIs[SupportsPP]: + ... 
+ + +def supports_pp( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + supports_attributes = _supports_pp_attributes(model) + supports_inspect = _supports_pp_inspect(model) + + if supports_attributes and not supports_inspect: + logger.warning( + "The model (%s) sets `supports_pp=True`, but does not accept " + "`intermediate_tensors` in its `forward` method", model) + + if not supports_attributes: + pp_attrs = ("make_empty_intermediate_tensors", ) + missing_attrs = tuple(attr for attr in pp_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_pp", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_pp=True`, " + "but is missing PP-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all PP-specific attributes, " + "but does not set `supports_pp=True`.", model) + + return supports_attributes and supports_inspect + + +def _supports_pp_attributes( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + if isinstance(model, type): + return isinstance(model, _SupportsPPType) + + return isinstance(model, SupportsPP) + + +def _supports_pp_inspect( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + forward_params = inspect.signature(model_forward).parameters + return "intermediate_tensors" in forward_params + + @runtime_checkable class HasInnerState(Protocol): """The interface required for all models that has inner state.""" @@ -158,7 +290,7 @@ class HasInnerState(Protocol): def __init__(self, *, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + scheduler_config: Optional["SchedulerConfig"] = None) -> None: ... @@ -168,7 +300,7 @@ class _HasInnerStateType(Protocol): def __init__(self, *, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + scheduler_config: Optional["SchedulerConfig"] = None) -> None: ... 
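The new `SupportsPP` protocol and `supports_pp()` check above boil down to three obligations for a model: expose a `make_empty_intermediate_tensors` callable, accept `intermediate_tensors` in `forward`, and return intermediate tensors from every rank except the last. The standalone sketch below illustrates that hand-off pattern only; it is not vLLM code. `ToyIntermediateTensors` and `ToyPPStage` are invented stand-ins, and `toy_make_empty_intermediate_tensors_factory` merely mirrors what the `make_empty_intermediate_tensors_factory` helper imported from `.utils` appears to do, judging from the per-model `make_empty_intermediate_tensors` methods this diff deletes.

from typing import Dict, List, Optional, Union

import torch
from torch import nn


class ToyIntermediateTensors:
    """Stand-in for vllm.sequence.IntermediateTensors: a named bag of tensors
    passed between pipeline ranks."""

    def __init__(self, tensors: Dict[str, torch.Tensor]) -> None:
        self.tensors = tensors

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]


def toy_make_empty_intermediate_tensors_factory(keys: List[str],
                                                hidden_size: int):
    """Mirror of the .utils helper (an assumption): zeroed placeholders so
    ranks > 0 can be profiled before real activations arrive."""

    def make_empty(batch_size: int, dtype: torch.dtype,
                   device: torch.device) -> ToyIntermediateTensors:
        return ToyIntermediateTensors({
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        })

    return make_empty


class ToyPPStage(nn.Module):
    """One pipeline stage holding a contiguous slice of a toy decoder."""

    def __init__(self, vocab_size: int, hidden_size: int, num_layers: int, *,
                 is_first_rank: bool, is_last_rank: bool) -> None:
        super().__init__()
        self.is_first_rank = is_first_rank
        self.is_last_rank = is_last_rank
        # Only the first rank owns the embedding and only the last owns the
        # final norm, analogous to the get_pp_group() branches in the diff.
        self.embed = (nn.Embedding(vocab_size, hidden_size)
                      if is_first_rank else None)
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.ln_f = nn.LayerNorm(hidden_size) if is_last_rank else None
        self.make_empty_intermediate_tensors = (
            toy_make_empty_intermediate_tensors_factory(["hidden_states"],
                                                        hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        intermediate_tensors: Optional[ToyIntermediateTensors],
    ) -> Union[torch.Tensor, ToyIntermediateTensors]:
        if self.is_first_rank:
            hidden_states = self.embed(input_ids)
        else:
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers:
            hidden_states = torch.relu(layer(hidden_states))
        if not self.is_last_rank:
            # Hand the activations to the next rank instead of finishing.
            return ToyIntermediateTensors({"hidden_states": hidden_states})
        return self.ln_f(hidden_states)


# Exercise the hand-off in-process with two "ranks".
first = ToyPPStage(100, 16, 2, is_first_rank=True, is_last_rank=False)
last = ToyPPStage(100, 16, 2, is_first_rank=False, is_last_rank=True)
tokens = torch.randint(0, 100, (4,))
handoff = first(tokens, None)   # non-last rank returns ToyIntermediateTensors
hidden = last(tokens, handoff)  # last rank returns a plain tensor
empty = last.make_empty_intermediate_tensors(4, torch.float32,
                                             torch.device("cpu"))
print(hidden.shape, empty["hidden_states"].shape)

The toy run stands in for what the real engine does across processes: the scheduler feeds each non-first rank the previous rank's tensors as `intermediate_tensors`, and `make_empty_intermediate_tensors` supplies correctly shaped dummies during memory profiling.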
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 11a8431a5e7f7..f6cde44e9d83d 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -18,8 +18,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -28,6 +27,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -266,7 +266,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: IntermediateTensors = None, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: @@ -297,7 +297,7 @@ def forward( return hidden_states -class InternLM2ForCausalLM(nn.Module): +class InternLM2ForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -325,7 +325,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: IntermediateTensors, + intermediate_tensors: Optional[IntermediateTensors], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e84990a2ab109..816e93818f2ee 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -5,9 +5,9 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import re -from functools import partial -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, - Tuple, TypedDict, Union) +from functools import cached_property, partial +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -17,7 +17,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -32,7 +31,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) @@ -123,7 +122,7 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, return blocks, target_width, target_height -def calculate_num_blocks_wrapper(hf_config: Dict[str, Any], +def calculate_num_blocks_wrapper(hf_config: PretrainedConfig, max_dynamic_patch: Optional[int] = None): if max_dynamic_patch is None: 
max_dynamic_patch = hf_config.max_dynamic_patch @@ -183,7 +182,7 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int, return pixel_values -def image_to_pixel_values_wrapper(hf_config: Dict[str, Any], +def image_to_pixel_values_wrapper(hf_config: PretrainedConfig, max_dynamic_patch: Optional[int] = None): image_size = hf_config.vision_config.image_size min_num = hf_config.min_dynamic_patch @@ -197,7 +196,7 @@ def image_to_pixel_values_wrapper(hf_config: Dict[str, Any], use_thumbnail=use_thumbnail) -def get_internvl_num_patches(hf_config: Dict[str, Any]): +def get_internvl_num_patches(hf_config: PretrainedConfig): vision_config = hf_config.vision_config downsample_ratio = hf_config.downsample_ratio image_size = vision_config.image_size @@ -362,7 +361,7 @@ def dummy_data_for_internvl(ctx: InputContext, @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) -class InternVLChatModel(nn.Module, SupportsMultiModal): +class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: PretrainedConfig, @@ -408,10 +407,12 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + @cached_property + def sampler(self): if hasattr(self.language_model, "sampler"): - self.sampler = self.language_model.sampler - else: - self.sampler = Sampler() + return self.language_model.sampler + + return Sampler() def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() @@ -515,18 +516,22 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None and get_pp_group().is_first_rank: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.img_context_token_id) + ) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is not None: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.img_context_token_id) + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index b0fbb7e9829e0..c5e5393442e30 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -33,8 +33,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -43,7 +42,9 @@ from vllm.sequence import IntermediateTensors 
from vllm.transformers_utils.configs import JAISConfig -from .utils import is_pp_missing_parameter, make_layers +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class SwiGLUActivation(nn.Module): @@ -244,6 +245,9 @@ def __init__( ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) def forward( self, @@ -279,7 +283,7 @@ def forward( return hidden_states -class JAISLMHeadModel(nn.Module): +class JAISLMHeadModel(nn.Module, SupportsPP): def __init__( self, @@ -304,6 +308,8 @@ def __init__( self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, scale=self.output_logits_scale) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -326,16 +332,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: torch.Tensor, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5ff31e3833ec9..bbb965e614fba 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,8 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -51,8 +50,9 @@ from vllm.sequence import IntermediateTensors from vllm.utils import is_hip -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class LlamaMLP(nn.Module): @@ -72,12 +72,15 @@ def __init__( output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -161,12 +164,14 @@ def __init__( rope_scaling=rope_scaling, is_neox_style=is_neox_style, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) def forward( self, @@ -248,12 +253,10 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -295,12 +298,17 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -326,13 +334,9 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -344,17 +348,10 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module, SupportsLoRA): +class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] } # LoRA specific attributes @@ -364,7 +361,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ] embedding_modules = { "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", + "lm_head": "output_embeddings" } embedding_padding_modules = ["lm_head"] bitsandbytes_stacked_params_mapping = { @@ -420,10 +417,12 @@ def __init__( self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), quant_config=quant_config, ) if config.tie_word_embeddings: @@ -436,6 +435,8 @@ def __init__( self.sampler = Sampler() else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -458,28 +459,11 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> 
Optional[SamplerOutput]: + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -513,7 +497,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py index 8f1c77da50d96..ce05d8e3911bf 100644 --- a/vllm/model_executor/models/llama_embedding.py +++ b/vllm/model_executor/models/llama_embedding.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -8,10 +8,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import PoolerOutput +from vllm.sequence import IntermediateTensors, PoolerOutput +from .interfaces import SupportsPP +from .utils import is_pp_missing_parameter -class LlamaEmbeddingModel(nn.Module): + +class LlamaEmbeddingModel(nn.Module, SupportsPP): """A model that uses Llama with additional embedding functionalities. This class encapsulates the LlamaModel and provides an interface for @@ -29,6 +32,8 @@ def __init__( super().__init__() self.model = LlamaModel(**kwargs) self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -36,10 +41,12 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: return self.model.forward(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds) + attn_metadata, intermediate_tensors, + inputs_embeds) def pooler( self, @@ -73,6 +80,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -81,6 +90,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 69eb177a7dea8..a62231b628cb9 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -11,7 +12,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -21,7 +22,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_max_clip_image_tokens, input_processor_for_clip) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) @@ -198,7 +199,7 @@ def _init_vision_tower(hf_config: LlavaConfig): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava) -class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): +class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: LlavaConfig, @@ -220,6 +221,16 @@ def __init__(self, self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -315,7 +326,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LLaVA-1.5. 
One key thing to understand is the `input_ids` already accounts for the @@ -351,26 +362,30 @@ def forward( See also: :class:`LlavaImageInputs` """ - image_input = self._parse_and_validate_image_input(**kwargs) + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) - input_ids = None - else: - inputs_embeds = None + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 4341cc38bdd28..efad800d7d760 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -13,7 +14,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -23,7 +24,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_image_feature_size, get_clip_patch_grid_length, input_processor_for_clip) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .llava import LlavaMultiModalProjector from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_siglip_image_feature_size, @@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) -class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): +class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: LlavaNextConfig, @@ -300,6 +302,8 @@ def __init__(self, # TODO: Optionally initializes this for supporting embeddings. 
self.vision_tower = _init_vision_tower(config) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, @@ -308,8 +312,15 @@ def __init__(self, self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) - self.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size)) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -542,7 +553,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-NeXT. One key thing to understand is the `input_ids` already accounts for the @@ -587,26 +598,30 @@ def forward( See also: :class:`LlavaNextImageInputs` """ - image_input = self._parse_and_validate_image_input(**kwargs) + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) - input_ids = None - else: - inputs_embeds = None + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 397a6cce5af2c..44b3073b46358 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -1,4 +1,5 @@ import math +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -12,9 +13,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -25,7 +25,7 @@ from vllm.utils import is_list_of from .clip import dummy_image_for_clip, 
dummy_seq_data_for_clip -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip) from .utils import (group_weights_with_prefix, init_vllm_registered_model, @@ -267,7 +267,8 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: "video", get_max_llava_next_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video) -class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): +class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: LlavaNextVideoConfig, @@ -281,13 +282,23 @@ def __init__(self, # Initialize the vision tower only up to the required feature layer self.vision_tower = _init_vision_tower(config) + self.vision_resampler = LlavaNextVideoPooler(config) self.multi_modal_projector = LlavaNextMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) - self.vision_resampler = LlavaNextVideoPooler(config) + + self.make_empty_intermediate_tensors = ( + self.language_model.model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() def _validate_video_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] @@ -397,34 +408,36 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-NeXT-Video. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values_videos: Pixels in each frames for each input videos. 
""" - video_input = self._parse_and_validate_video_input(**kwargs) - - # merge video embeddings into input embeddings - if video_input is not None: - video_embeddings = self._process_video_pixels(video_input) - inputs_embeds = self.language_model \ - .model.get_input_embeddings(input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, video_embeddings, - self.config.video_token_index) - + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + video_input = self._parse_and_validate_video_input(**kwargs) + if video_input is not None: + video_embeddings = self._process_video_pixels(video_input) + inputs_embeds = self.language_model \ + .model.get_input_embeddings(input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, video_embeddings, + self.config.video_token_index) + + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 9099d4f88222d..af957e35d8089 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -1,4 +1,5 @@ import math +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -17,9 +18,8 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -31,7 +31,7 @@ from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, dummy_video_for_clip, get_clip_image_feature_size, get_clip_patch_grid_length, input_processor_for_clip) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, dummy_video_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) @@ -414,7 +414,8 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: "video", get_max_llava_onevision_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision) -class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): +class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: LlavaOnevisionConfig, @@ -434,6 +435,16 @@ def __init__(self, self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) + self.make_empty_intermediate_tensors = ( + self.language_model.model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def _validate_image_sizes(self, data: 
torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -805,39 +816,42 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-Onevision. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values_videos: Pixels in each frames for each input videos. """ - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - # merge video embeddings into input embeddings - if modalities: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - if "images" in modalities: - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - if "videos" in modalities: - video_input = modalities["videos"] - video_embeddings = self._process_video_pixels(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, video_embeddings, - self.config.video_token_index) + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if modalities: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + if "images" in modalities: + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + if "videos" in modalities: + video_input = modalities["videos"] + video_embeddings = self._process_video_pixels(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, video_embeddings, + self.config.video_token_index) + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 963ad7553fe1d..6bba1594c270f 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -22,7 +22,7 @@ # limitations under the License. 
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -30,7 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul @@ -41,8 +41,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -52,7 +51,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class MiniCPMMoE(nn.Module): @@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -365,15 +367,24 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self._init_layers() + self._init_layers(prefix, config, cache_config, quant_config) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size)) - def _init_layers(self): - self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(self.config, self.cache_config, - self.quant_config) - for _ in range(self.config.num_hidden_layers) - ]) + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniCPMDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -387,27 +398,36 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None 
else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] - for i in range(len(self.layers)): + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states = self.norm(hidden_states) return hidden_states -class MiniCPMForCausalLM(nn.Module, SupportsLoRA): +class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -470,6 +490,8 @@ def __init__( self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def _init_model(self): self.model = MiniCPMModel(config=self.config, @@ -484,7 +506,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -548,6 +570,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -557,6 +581,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -568,6 +594,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index a048a3dba0415..c37bc5ad7c38f 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -26,6 +26,7 @@ import torch from torch import nn +from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig @@ -34,19 +35,20 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, MiniCPMForCausalLM, MiniCPMModel) +from .utils import make_layers + class MiniCPM3Attention(nn.Module): def __init__( self, - config, + config: PretrainedConfig, hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -199,12 +201,18 @@ def _init_attn_block(self): class MiniCPM3Model(MiniCPMModel): - def _init_layers(self): - self.layers = nn.ModuleList([ - MiniCPM3DecoderLayer(self.config, self.cache_config, - self.quant_config) - for _ in range(self.config.num_hidden_layers) - ]) + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniCPM3DecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") class MiniCPM3ForCausalLM(MiniCPMForCausalLM): diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 0e0e86f2fe503..6d0fa34f299ad 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -59,7 +58,8 @@ from vllm.sequence import IntermediateTensors, SequenceData from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .utils import is_pp_missing_parameter _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", @@ -337,7 +337,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): return MultiModalInputs(batch_data) -class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): +class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): """ The abstract class of MiniCPMV can only be inherited, but cannot be instantiated. 
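# NOTE (illustrative sketch, not part of the diff): `make_layers` replaces the
# plain `nn.ModuleList([...])` construction in MiniCPMModel / MiniCPM3Model
# above and in most files below. The helper below is an assumed approximation:
# the even-split placement policy and the PPMissingLayerStub placeholder are
# for illustration only, and the real helper in
# vllm/model_executor/models/utils.py reads the pipeline rank from the
# distributed state rather than taking it as arguments.
from typing import Callable, Tuple

import torch.nn as nn


class PPMissingLayerStub(nn.Module):
    """Placeholder for a layer owned by another pipeline rank (assumed)."""

    def forward(self, *args, **kwargs):
        raise RuntimeError("this layer lives on a different pipeline rank")


def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str,
    pp_rank: int,
    pp_size: int,
) -> Tuple[int, int, nn.ModuleList]:
    per_rank = (num_layers + pp_size - 1) // pp_size
    start_layer = pp_rank * per_rank
    end_layer = min(start_layer + per_rank, num_layers)
    layers = nn.ModuleList(
        layer_fn(f"{prefix}.{i}") if start_layer <= i < end_layer
        else PPMissingLayerStub()
        for i in range(num_layers))
    return start_layer, end_layer, layers


# e.g. rank 1 of 4 for a 32-layer model owns global layers 8..15:
start, end, layers = make_layers_sketch(
    32, lambda prefix: nn.Linear(8, 8), "model.layers", pp_rank=1, pp_size=4)
assert (start, end) == (8, 16)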
@@ -374,6 +374,9 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.llm.make_empty_intermediate_tensors) + def get_embedding( self, input_ids: torch.Tensor, @@ -498,9 +501,12 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: Any, ) -> torch.Tensor: - image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) + if intermediate_tensors is not None: + vlm_embeddings = None + else: + image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) - vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) output = self.llm( input_ids=None, @@ -557,6 +563,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + if is_pp_missing_parameter( + name.replace(weight_name, param_name), self): + continue param = params_dict[name.replace(weight_name, param_name)] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -564,6 +573,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: use_default_weight_loading = True if use_default_weight_loading: + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 10cbfcf6432b3..f93ba0875c8b1 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
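# NOTE (illustrative sketch, not part of the diff): for the multimodal models
# in this change (MiniCPM-V above, PaliGemma and Phi-3-Vision further down),
# the PP adjustment in forward() is that the vision tower and embedding merge
# only run when intermediate_tensors is None, i.e. on the first pipeline rank;
# later ranks operate purely on the hidden states handed to them. Every
# callable argument below is a stand-in, not a real vLLM API.
from typing import Any, Callable, Optional

import torch


def multimodal_pp_forward(
    input_ids: Optional[torch.Tensor],
    pixel_values: Optional[torch.Tensor],
    intermediate_tensors: Optional[dict],
    encode_images: Callable[[torch.Tensor], torch.Tensor],
    merge_embeddings: Callable[[torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
    run_llm: Callable[..., Any],
) -> Any:
    if intermediate_tensors is not None:
        # Non-first rank: never touches pixel inputs.
        inputs_embeds = None
    else:
        vision_embeds = (encode_images(pixel_values)
                         if pixel_values is not None else None)
        inputs_embeds = merge_embeddings(input_ids, vision_embeds)
        input_ids = None
    return run_llm(input_ids=input_ids,
                   inputs_embeds=inputs_embeds,
                   intermediate_tensors=intermediate_tensors)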
"""Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -36,8 +36,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -47,8 +46,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class MixtralMoE(nn.Module): @@ -276,6 +276,9 @@ def __init__( prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -284,7 +287,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: hidden_states = self.embed_tokens(input_ids) residual = None @@ -306,7 +309,7 @@ def forward( return hidden_states -class MixtralForCausalLM(nn.Module, SupportsLoRA): +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -365,6 +368,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -373,7 +378,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -387,20 +392,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 68471f6ac77d1..63e2c60a84271 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -31,7 +31,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.layernorm import RMSNorm @@ -39,8 +39,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -49,6 +48,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class MixtralMLP(nn.Module): @@ -296,6 +299,7 @@ def __init__( config: MixtralConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -305,13 +309,15 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, - cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MixtralDecoderLayer( + config, cache_config, quant_config=quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -319,19 +325,30 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class MixtralForCausalLM(nn.Module): +class MixtralForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False def __init__( @@ -351,6 +368,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = 
LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -359,9 +378,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -400,6 +419,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -412,6 +433,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if ("block_sparse_moe.experts." in name and name not in params_dict): continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 0fcbf06e1a060..e3d3937b13fa0 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,22 +1,21 @@ # coding=utf-8 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -25,6 +24,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.mpt import MPTConfig +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + def _get_alibi_slopes( total_num_heads: int, @@ -208,6 +211,7 @@ def __init__( config: MPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() assert config.embedding_fraction == 1.0 @@ -217,10 +221,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList([ - MPTBlock(config, cache_config, quant_config) - for _ in range(config.n_layers) - ]) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: MPTBlock(config, cache_config, quant_config), + prefix=f"{prefix}.blocks") self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -228,6 +232,9 @@ def __init__( 
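# NOTE (illustrative sketch, not part of the diff): mixtral.py above deletes
# its hand-written make_empty_intermediate_tensors() and instead registers
# make_empty_intermediate_tensors_factory(keys, hidden_size) in the model's
# __init__, and mixtral_quant.py (and the remaining files) follow the same
# pattern. A plausible sketch of such a factory, with a plain dict standing
# in for vLLM's IntermediateTensors wrapper:
from typing import Dict, List

import torch


def make_empty_intermediate_tensors_factory_sketch(keys: List[str],
                                                   hidden_size: int):

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> Dict[str, torch.Tensor]:
        # Zero-filled hand-off buffers, one per registered key.
        return {
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        }

    return make_empty_intermediate_tensors


# Residual-stream models register two buffers; MPT/OLMo/OPT-style models
# further below register only ["hidden_states"]:
make_empty = make_empty_intermediate_tensors_factory_sketch(
    ["hidden_states", "residual"], hidden_size=4096)
buffers = make_empty(batch_size=2, dtype=torch.float16, device="cpu")
assert set(buffers) == {"hidden_states", "residual"}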
module.bias, nn.Parameter): # Remove the bias term in Linear and LayerNorm. module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) def forward( self, @@ -235,21 +242,29 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.wte(input_ids) - for i in range(len(self.blocks)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): block = self.blocks[i] hidden_states = block( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) return hidden_states -class MPTForCausalLM(nn.Module): +class MPTForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -266,6 +281,8 @@ def __init__( self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -274,9 +291,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -302,6 +319,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
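# NOTE (illustrative sketch, not part of the diff): MPTForCausalLM above, like
# every other top-level *ForCausalLM in this change, gains the SupportsPP
# marker interface and re-exports its inner model's
# make_empty_intermediate_tensors, so the engine can allocate pipeline
# hand-off buffers without knowing model internals. ToyModel/ToyForCausalLM
# below are made up for illustration; vLLM's real SupportsPP interface is not
# reproduced here.
import torch
import torch.nn as nn


class ToyModel(nn.Module):

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size

    def make_empty_intermediate_tensors(self, batch_size: int,
                                        dtype: torch.dtype,
                                        device: torch.device) -> dict:
        return {
            "hidden_states": torch.zeros((batch_size, self.hidden_size),
                                         dtype=dtype, device=device)
        }


class ToyForCausalLM(nn.Module):  # in vLLM this would also inherit SupportsPP

    def __init__(self) -> None:
        super().__init__()
        self.model = ToyModel(hidden_size=16)
        # The wrapper simply forwards the inner model's factory.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)


lm = ToyForCausalLM()
buffers = lm.make_empty_intermediate_tensors(2, torch.float32, "cpu")
assert buffers["hidden_states"].shape == (2, 16)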
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index e9ff12de2094e..14515e16e34ac 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -34,8 +34,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -46,8 +45,9 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) # The architecture is pretty similar to Llama, with these changes: # - There is no gate_proj, just up_proj @@ -328,6 +328,9 @@ def __init__( eps=config.norm_eps) else: self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -372,7 +375,7 @@ def forward( return hidden_states -class NemotronForCausalLM(nn.Module, SupportsLoRA): +class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -440,6 +443,8 @@ def __init__( self.sampler = Sampler() else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -470,20 +475,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 97749725dd132..5ca7c66f5407d 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
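# NOTE (illustrative sketch, not part of the diff): the
# `if is_pp_missing_parameter(name, self): continue` guards inserted into
# every load_weights() above and below skip checkpoint tensors whose owning
# layer lives on a different pipeline rank, so each rank only materialises
# its own shard of the weights. The simplified check below is an assumption;
# the real helper in models/utils.py tracks the layer names that make_layers
# left unallocated on this rank.
import torch.nn as nn


def is_pp_missing_parameter_sketch(name: str, model: nn.Module) -> bool:
    # A weight is "missing" on this rank when no locally instantiated module
    # provides a parameter with that name.
    return name not in dict(model.named_parameters())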
"""Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,14 +29,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,6 +44,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class OlmoAttention(nn.Module): """ @@ -223,19 +226,24 @@ class OlmoModel(nn.Module): def __init__(self, config: OlmoConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - OlmoDecoderLayer(config, cache_config, quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config + ), + prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, bias=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -243,34 +251,41 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. """ - # Get embeddings of input. - # shape: (batch_size, seq_len, d_model) - inputs_embeds = self.embed_tokens(input_ids) + if get_pp_group().is_first_rank: + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + inputs_embeds = self.embed_tokens(input_ids) - # embed positions - hidden_states = inputs_embeds + # embed positions + hidden_states = inputs_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] # Apply blocks one-by-one. 
- for layer_idx, decoder_layer in enumerate(self.layers): + for i in range(self.start_layer, self.end_layer): # shape: (batch_size, seq_len, d_model) - hidden_states = decoder_layer( + hidden_states = self.layers[i]( positions, hidden_states, - kv_caches[layer_idx], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) # Apply final layer norm. # shape: (batch_size, seq_len or 1, d_model) hidden_states = self.norm(hidden_states) return hidden_states -class OlmoForCausalLM(nn.Module): +class OlmoForCausalLM(nn.Module, SupportsPP): """ Extremely barebones HF model wrapper. """ @@ -294,6 +309,8 @@ def __init__(self, ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -302,12 +319,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, ) return hidden_states @@ -358,6 +376,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -366,6 +386,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index c76e5e86c89d8..a1ba80e0d7108 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
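# NOTE (illustrative sketch, not part of the diff): a detail shared by all the
# rewritten forward loops (OLMo above is typical): `self.layers[i]` is indexed
# with the *global* layer id, while kv_caches only contains entries for the
# layers owned by this rank, hence `kv_caches[i - self.start_layer]`. Tiny
# illustration of the two index spaces, with an assumed layout:
start_layer, end_layer = 8, 16  # global layer ids owned by this rank
kv_caches = [f"cache_for_layer_{i}" for i in range(start_layer, end_layer)]

for i in range(start_layer, end_layer):
    local_idx = i - start_layer
    assert kv_caches[local_idx] == f"cache_for_layer_{i}"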
"""Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -18,15 +18,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -36,6 +35,10 @@ from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class OlmoeMoE(nn.Module): """A tensor-parallel MoE implementation for Olmoe that shards each expert @@ -243,6 +246,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -252,34 +256,54 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - OlmoeDecoderLayer(config, - layer_idx, - cache_config, - quant_config=quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoeDecoderLayer(config, int( + prefix.split(".")[-1]), cache_config, quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=1e-5) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class 
OlmoeForCausalLM(nn.Module): +class OlmoeForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False @@ -299,6 +323,9 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + def forward( self, input_ids: torch.Tensor, @@ -306,9 +333,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -363,6 +390,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue if name not in params_dict: continue @@ -376,6 +406,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -388,6 +421,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 47bc8adc3bc14..727dd65acc749 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
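# NOTE (illustrative sketch, not part of the diff): OLMoE above (and
# Phi-3-Small later) needs a per-layer index inside the decoder-layer
# constructor. With make_layers the constructor only receives the parameter
# prefix, so the index is recovered from it, e.g. "model.layers.7" -> 7:
def layer_idx_from_prefix(prefix: str) -> int:
    return int(prefix.split(".")[-1])


assert layer_idx_from_prefix("model.layers.7") == 7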
"""Inference-only OPT model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -25,15 +25,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,6 +40,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class OPTLearnedPositionalEmbedding(nn.Embedding): @@ -189,6 +192,7 @@ def __init__( config: OPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -232,10 +236,10 @@ def __init__( else: self.final_layer_norm = None - self.layers = nn.ModuleList([ - OPTDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OPTDecoderLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -246,19 +250,28 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) - pos_embeds = self.embed_positions(positions) - if self.project_in is not None: - inputs_embeds, _ = self.project_in(inputs_embeds) - hidden_states = inputs_embeds + pos_embeds - - for i in range(len(self.layers)): + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + pos_embeds = self.embed_positions(positions) + if self.project_in is not None: + inputs_embeds, _ = self.project_in(inputs_embeds) + hidden_states = inputs_embeds + pos_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) if self.final_layer_norm is not None: hidden_states = self.final_layer_norm(hidden_states) if self.project_out is 
not None: @@ -276,6 +289,9 @@ def __init__( ): super().__init__() self.decoder = OPTDecoder(config, cache_config, quant_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.decoder.get_input_embeddings(input_ids) @@ -286,20 +302,22 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: return self.decoder(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds=inputs_embeds) -class OPTForCausalLM(nn.Module): +class OPTForCausalLM(nn.Module, SupportsPP): def __init__( self, - config, + config: OPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -314,6 +332,8 @@ def __init__( config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -322,9 +342,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -365,6 +385,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -373,6 +395,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index b01ce87adfa46..0913193f73a48 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -4,7 +4,7 @@ # Copyright (c) OrionStar Inc. 
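# NOTE (illustrative sketch, not part of the diff): putting the pieces
# together, a single-process simulation of how two pipeline stages hand
# hidden states to each other through the new forward() signatures. In vLLM
# the IntermediateTensors are communicated between ranks by the engine; here
# the hand-off is an ordinary Python variable, purely to show the control
# flow that the OPT/MPT/OLMo rewrites above enable.
import torch
import torch.nn as nn

hidden_size, seq_len = 16, 4
embed_tokens = nn.Embedding(100, hidden_size)
layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(4))
norm = nn.LayerNorm(hidden_size)


def run_stage(start_layer, end_layer, is_first_rank, is_last_rank,
              input_ids=None, intermediate_tensors=None):
    if is_first_rank:
        hidden_states = embed_tokens(input_ids)
    else:
        hidden_states = intermediate_tensors["hidden_states"]
    for i in range(start_layer, end_layer):
        hidden_states = layers[i](hidden_states)
    if not is_last_rank:
        return {"hidden_states": hidden_states}
    return norm(hidden_states)


input_ids = torch.randint(0, 100, (1, seq_len))
handoff = run_stage(0, 2, is_first_rank=True, is_last_rank=False,
                    input_ids=input_ids)
final_hidden = run_stage(2, 4, is_first_rank=False, is_last_rank=True,
                         intermediate_tensors=handoff)
assert final_hidden.shape == (1, seq_len, hidden_size)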
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -12,14 +12,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -28,6 +27,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class OrionMLP(nn.Module): @@ -210,6 +213,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -219,11 +223,18 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - OrionDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OrionDecoderLayer( + config, + cache_config, + quant_config, + ), + prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -231,23 +242,34 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states = self.norm(hidden_states) return hidden_states -class OrionForCausalLM(nn.Module): +class OrionForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -266,6 +288,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight 
self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -274,9 +298,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -321,6 +345,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -329,6 +355,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8130eb54753ea..93032b4095917 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -9,9 +9,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.gemma import GemmaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -19,7 +18,7 @@ from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) from .utils import group_weights_with_prefix, merge_multimodal_embeddings @@ -129,7 +128,8 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) -class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): +class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: PaliGemmaConfig, @@ -149,12 +149,15 @@ def __init__(self, self.quant_config = quant_config self.language_model = GemmaForCausalLM(config.text_config, cache_config, quant_config) - self.unpadded_vocab_size = config.text_config.vocab_size logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.text_config.vocab_size, - logit_scale) - self.sampler = Sampler() + self.language_model.logits_processor.scale *= logit_scale + 
+ self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size @@ -239,32 +242,36 @@ def forward(self, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object) -> SamplerOutput: - - parsed_image_input = self._parse_and_validate_image_input(**kwargs) + **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + parsed_image_input = self._parse_and_validate_image_input(**kwargs) - if parsed_image_input is not None: - vision_embeddings = self._process_image_input(parsed_image_input) - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa - vision_embeddings = vision_embeddings * (self.config.hidden_size** - -0.5) + if parsed_image_input is not None: + vision_embeddings = self._process_image_input( + parsed_image_input) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa + vision_embeddings = vision_embeddings * ( + self.config.hidden_size**-0.5) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) - input_ids = None - else: - inputs_embeds = None + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index fda0602110a0b..b625d19f6447d 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
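# NOTE (illustrative sketch, not part of the diff): PaliGemma above no longer
# builds its own LogitsProcessor and Sampler. It folds `logit_scale` into the
# wrapped GemmaForCausalLM's existing logits processor (scale *= logit_scale)
# and exposes the language model's sampler through a read-only property. Toy
# version of that delegation, with made-up classes in place of the vLLM ones:
import torch
import torch.nn as nn


class ToyLanguageModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.logits_scale = 1.0
        self.sampler = lambda logits: logits.argmax(dim=-1)


class ToyVisionLanguageModel(nn.Module):

    def __init__(self, logit_scale: float) -> None:
        super().__init__()
        self.language_model = ToyLanguageModel()
        # Fold the multimodal logit scale into the LM's own scale instead of
        # keeping a second processor around.
        self.language_model.logits_scale *= logit_scale

    @property
    def sampler(self):
        return self.language_model.sampler


vlm = ToyVisionLanguageModel(logit_scale=0.5)
assert vlm.language_model.logits_scale == 0.5
assert vlm.sampler(torch.tensor([[0.1, 0.9]])).item() == 1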
"""Inference-only persimmon model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -28,14 +28,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,6 +43,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class PersimmonMLP(nn.Module): @@ -211,20 +214,23 @@ class PersimmonModel(nn.Module): def __init__(self, config: PersimmonConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - PersimmonDecoderLayer(config, - cache_config=cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PersimmonDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -232,24 +238,31 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) else: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): hidden_states = self.layers[i]( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) return hidden_states -class PersimmonForCausalLM(nn.Module): +class PersimmonForCausalLM(nn.Module, SupportsPP): def __init__(self, 
config: PersimmonConfig, @@ -266,6 +279,8 @@ def __init__(self, bias=False) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -281,6 +296,7 @@ def forward( positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states @@ -312,6 +328,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] if "query_key_value" in name: diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 15c21cfa2d8a8..c90fe2e0ab9ea 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -35,7 +35,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -43,14 +43,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -59,7 +58,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class PhiAttention(nn.Module): @@ -196,18 +197,22 @@ class PhiModel(nn.Module): def __init__(self, config: PhiConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - PhiLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PhiLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -215,23 
+220,31 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(self.config.num_hidden_layers): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layernorm(hidden_states) return hidden_states -class PhiForCausalLM(nn.Module, SupportsLoRA): +class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -274,6 +287,8 @@ def __init__( quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -282,9 +297,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states @@ -325,6 +340,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -335,6 +352,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue # pylint: disable=E1136 + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index afc6fe9844ad6..4cfeb3bb3496f 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -7,14 +7,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -23,6 +22,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + def load_column_parallel_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor): @@ -301,20 +304,25 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.mup_embedding_multiplier = config.mup_embedding_multiplier - self.layers = nn.ModuleList([ - Phi3SmallDecoderLayer(config, layer_idx, cache_config, - quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Phi3SmallDecoderLayer(config, + int(prefix.split('.')[-1]), + cache_config, quant_config), + prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def get_input_embeddings(self): return self.embed_tokens @@ -327,30 +335,37 @@ def forward( input_ids: torch.LongTensor, positions: Optional[torch.LongTensor], kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata = None, - ): - hidden_states = self.embed_tokens(input_ids) - if (self.mup_embedding_multiplier is not None - and self.mup_embedding_multiplier > 0.0): - hidden_states = hidden_states * self.mup_embedding_multiplier - for i in 
range(len(self.layers)): + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + if (self.mup_embedding_multiplier is not None + and self.mup_embedding_multiplier > 0.0): + hidden_states = hidden_states * self.mup_embedding_multiplier + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) return hidden_states -class Phi3SmallForCausalLM(nn.Module): +class Phi3SmallForCausalLM(nn.Module, SupportsPP): _tied_weights_keys = ["lm_head.weight"] def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -372,6 +387,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) # tokens in tiktoken but not used if hasattr(config, 'dummy_token_indices'): @@ -419,12 +436,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: output_hidden_states = self.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, ) output_hidden_states = output_hidden_states return output_hidden_states @@ -447,6 +465,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 245381518a7f8..ebfffb25360cd 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -15,7 +15,7 @@ # limitations under the License. 
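The Phi-3-Small forward above, like the Phi forward before it, indexes its KV caches with `i - self.start_layer`: each pipeline rank only allocates caches for its own contiguous slice of layers, so the global layer index has to be shifted into local cache space. A minimal, framework-free sketch of that bookkeeping (the helper names `split_layers` and `run_rank` are illustrative, not vLLM APIs):

```python
# Illustrative sketch only: shows why a PP rank indexes its KV caches with
# `i - start_layer`. The helper names below are hypothetical.
from typing import List, Tuple


def split_layers(num_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]:
    """Assign a contiguous [start, end) layer range to each rank."""
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    # Give the last rank any remainder so all layers are covered.
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end


def run_rank(pp_rank: int, pp_size: int, num_layers: int) -> List[str]:
    start, end = split_layers(num_layers, pp_rank, pp_size)
    # Each rank allocates caches only for its own layers ...
    kv_caches = [f"cache_for_layer_{i}" for i in range(start, end)]
    used = []
    for i in range(start, end):
        # ... so the global layer index must be shifted into local cache space.
        used.append(kv_caches[i - start])
    return used


if __name__ == "__main__":
    # With 32 layers over 2 ranks, rank 1 owns layers 16..31 but its cache
    # list is still indexed 0..15.
    print(run_rank(1, 2, 32)[:2])  # ['cache_for_layer_16', 'cache_for_layer_17']
```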
import itertools import re -from functools import lru_cache +from functools import cached_property, lru_cache from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -29,13 +29,11 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token @@ -43,8 +41,9 @@ from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip -from .interfaces import SupportsMultiModal -from .utils import flatten_bn, merge_multimodal_embeddings +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (flatten_bn, group_weights_with_prefix, + merge_multimodal_embeddings) logger = init_logger(__name__) @@ -295,6 +294,37 @@ def add_image_newline(self, image_features_hd): dim=2).reshape(num_images, -1, hid_dim) return image_features_hd_newline + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # prepare weight iterators for components + weights_group = group_weights_with_prefix(weights) + + # load vision encoder + self.img_processor.load_weights(weights_group["img_processor"]) + + # load glb_GN + for name, loaded_weight in weights_group["glb_GN"]: + assert name == "" + param = self.glb_GN + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load sub_GN + for name, loaded_weight in weights_group["sub_GN"]: + assert name == "" + param = self.sub_GN + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load mlp projector + mlp_params_dict = dict(self.img_projection.named_parameters()) + for name, loaded_weight in weights_group["img_projection"]: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): @@ -508,7 +538,7 @@ def input_processor_for_phi3v(ctx: InputContext, @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) -class Phi3VForCausalLM(nn.Module, SupportsMultiModal): +class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: PretrainedConfig, @@ -521,17 +551,21 @@ def __init__(self, self.multimodal_config = multimodal_config self.image_token_id = _IMAGE_TOKEN_ID - self.model = LlamaModel(config, cache_config, quant_config) - # TODO: Optionally initializes this for 
supporting embeddings. self.vision_embed_tokens = Phi3HDImageEmbedding(config) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + + self.language_model = LlamaForCausalLM(config, cache_config, + quant_config) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -631,24 +665,29 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object): - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.model.get_input_embeddings(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) + input_ids = None + else: + inputs_embeds = None - hidden_states = self.model(input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states @@ -657,66 +696,38 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] + hf_to_vllm_mapping = { + "model.vision_embed_tokens.": "vision_embed_tokens.", + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + } - # TODO(ChristopherCho): This is a temporary fix to load - # the vision weights with CLIPVisionModel.load_weights() - vision_weights = [] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - # Skip loading the img_processor weights since they are - # loaded separately. 
- if "vision_embed_tokens.img_processor" in name: - vision_weights.append((name, loaded_weight)) - continue - - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - # We use regex to extract the sub-module name - # from "model.vision_embed_tokens.img_processor.*" - vision_weights = [ - (re.search(r"vision_embed_tokens\.img_processor\.(.*)", - n).group(1), w) for n, w in vision_weights - ] - self.vision_embed_tokens.img_processor.load_weights(vision_weights) + def hf_to_vllm_name(key: str) -> str: + for hf_name, vllm_name in hf_to_vllm_mapping.items(): + if key.startswith(hf_name): + return key.replace(hf_name, vllm_name, 1) + + return key + + vllm_weights = {hf_to_vllm_name(k): v for k, v in weights} + + # prepare weight iterators for components + weights_group = group_weights_with_prefix(vllm_weights.items()) + + # load vision embeddings and encoder + self.vision_embed_tokens.load_weights( + weights_group["vision_embed_tokens"]) + + # load llm backbone + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 487d9fc2f4337..a9c815916ed59 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only PhiMoE model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,7 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, @@ -46,7 +46,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class PhiMoEConfig(PretrainedConfig): @@ -435,6 +437,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -448,33 +451,56 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.layers = nn.ModuleList([ - PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PhiMoEDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) return hidden_states -class PhiMoEForCausalLM(nn.Module, SupportsLoRA): +class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -537,6 +563,9 @@ def __init__( config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + def forward( self, input_ids: torch.Tensor, @@ -544,9 +573,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = 
None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -589,6 +618,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -599,6 +631,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( @@ -613,6 +648,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index aa92e62a30d3f..c8957dcae6b16 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, fields +from functools import cached_property from itertools import tee from typing import Iterable, List, Mapping, Optional, Tuple, Union @@ -16,7 +17,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -25,7 +26,7 @@ from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .utils import init_vllm_registered_model @@ -126,7 +127,8 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) @INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral) -class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): +class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: PretrainedConfig, @@ -155,6 +157,16 @@ def __init__(self, self.vision_language_adapter = VisionLanguageAdapter( self.vision_args, dim=config.text_config.hidden_size) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def forward( self, 
input_ids: torch.Tensor, @@ -163,32 +175,36 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for pixtral. TODO """ - image_input = self._parse_and_validate_image_input(**kwargs) + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.vision_args.image_token_id) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.vision_args.image_token_id) - input_ids = None - else: - inputs_embeds = None + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 761c1370b9776..fd8a27eec3b9a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -31,15 +31,13 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -47,7 +45,9 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of -from .utils import flatten_bn, is_pp_missing_parameter, make_layers +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (flatten_bn, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) logger = init_logger(__name__) @@ -568,6 +568,9 @@ def __init__( lambda prefix: QWenBlock(config, cache_config, quant_config), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) self.visual = VisionTransformer(**config.visual, quant_config=quant_config) if hasattr( config, "visual") else None @@ -580,7 +583,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], pixel_values: Optional[QwenImageInputs], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, 
IntermediateTensors]: img_pos = None # If pixel / visual embeddings are provided, this is a visual model if pixel_values is not None and self.visual is not None: @@ -860,7 +863,7 @@ def dummy_data_for_qwen( @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(nn.Module, SupportsMultiModal): +class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__( self, @@ -881,6 +884,8 @@ def __init__( self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def _get_image_input_type( self, @@ -912,33 +917,26 @@ def _get_image_input_type( ) return None - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor: - pixel_values = self._get_image_input_type(pixel_values) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + pixel_values: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + pixel_values = None + else: + pixel_values = self._get_image_input_type(pixel_values) + hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, pixel_values) return hidden_states - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 5e6737ad7fa47..04c1a224c981c 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
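The Qwen change above (and the qwen2/qwen2_moe changes below) deletes the per-model `make_empty_intermediate_tensors` boilerplate in favour of a shared factory keyed by tensor names and hidden size. A minimal sketch of what that factory produces, with a plain dict standing in for vLLM's `IntermediateTensors`:

```python
# Sketch of a make_empty_intermediate_tensors factory: a closure that builds
# zero-filled placeholders for the tensors a pipeline stage expects to receive.
# A plain dict stands in for vLLM's IntermediateTensors container here.
from typing import Callable, Dict, List

import torch


def make_empty_intermediate_tensors_factory(
    keys: List[str],
    hidden_size: int,
) -> Callable[[int, torch.dtype, torch.device], Dict[str, torch.Tensor]]:

    def make_empty_intermediate_tensors(
            batch_size: int,
            dtype: torch.dtype,
            device: torch.device) -> Dict[str, torch.Tensor]:
        return {
            key: torch.zeros((batch_size, hidden_size), dtype=dtype, device=device)
            for key in keys
        }

    return make_empty_intermediate_tensors


if __name__ == "__main__":
    # Decoder stacks that carry a residual stream register both tensors;
    # stacks like Phi register only "hidden_states".
    factory = make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size=64)
    empty = factory(batch_size=2, dtype=torch.float32, device=torch.device("cpu"))
    print({k: tuple(v.shape) for k, v in empty.items()})
```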
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -37,8 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -48,8 +47,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class Qwen2MLP(nn.Module): @@ -253,6 +253,9 @@ def __init__( prefix=f"{prefix}.layers", ) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -269,7 +272,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -298,7 +301,7 @@ def forward( return hidden_states -class Qwen2ForCausalLM(nn.Module, SupportsLoRA): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -357,6 +360,8 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -365,7 +370,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -379,20 +384,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index d80064601d993..d4475b7ca27af 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -42,8 +42,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -53,7 +52,9 @@ from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once -from .utils import is_pp_missing_parameter, make_layers +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class Qwen2MoeMLP(nn.Module): @@ -338,6 +339,9 @@ def __init__( prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -346,7 +350,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: hidden_states = self.embed_tokens(input_ids) residual = None @@ -368,7 +372,7 @@ def forward( return hidden_states -class Qwen2MoeForCausalLM(nn.Module): +class Qwen2MoeForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False @@ -389,6 +393,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -397,7 +403,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -411,20 +417,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c82e8ed6ed1e0..fd8e2436c1e1f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -55,7 +55,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from 
vllm.model_executor.models.qwen2 import Qwen2Model from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalInputs) @@ -68,6 +67,7 @@ from vllm.transformers_utils.processor import get_processor from vllm.utils import is_cpu +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) @@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext, "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: Qwen2VLConfig, @@ -1027,7 +1028,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Qwen2-VL. Args: @@ -1047,41 +1048,43 @@ def forward( video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. `None` if no videos are passed. """ - - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if (image_input is None - and video_input is None) or not get_pp_group().is_first_rank: + if intermediate_tensors is not None: + input_ids = None inputs_embeds = None else: - if getattr(self.config, "rope_scaling", {}).get("type", - None) == "mrope": - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - - inputs_embeds = self.model.embed_tokens(input_ids) - - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - - input_ids = None + if image_input is None and video_input is None: + inputs_embeds = None + else: + rope_scaling = getattr(self.config, "rope_scaling", {}) + if rope_scaling.get("type", None) == "mrope": + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + + inputs_embeds = self.model.embed_tokens(input_ids) + + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + + input_ids = None hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 
cd99538378412..743a81f8f9e95 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -246,7 +246,7 @@ class SiglipParallelAttention(nn.Module): def __init__( self, - config, + config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -312,7 +312,7 @@ class SiglipMLP(nn.Module): def __init__( self, - config, + config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 16e576d0ac29c..b9298ed031144 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -26,6 +26,7 @@ import torch from torch import nn +from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig @@ -37,8 +38,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -47,14 +47,14 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.model_executor.models.utils import (PPMissingLayer, - is_pp_missing_parameter, - make_layers) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import is_hip +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class SolarMLP(nn.Module): @@ -98,7 +98,7 @@ class SolarAttention(nn.Module): def __init__( self, - config, + config: PretrainedConfig, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -187,7 +187,7 @@ class SolarDecoderLayer(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -267,7 +267,7 @@ class SolarModel(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -304,6 +304,10 @@ def __init__( else: self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -368,7 +372,7 @@ def forward( return hidden_states -class SolarForCausalLM(nn.Module, SupportsLoRA): +class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -406,7 +410,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -448,6 
+452,9 @@ def __init__( else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + def forward( self, input_ids: torch.Tensor, @@ -474,24 +481,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - "residual": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 6236426dcd4e1..083a48588d01a 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -19,7 +19,7 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -27,14 +27,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -43,6 +42,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class StablelmMLP(nn.Module): @@ -194,19 +197,25 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '') -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: StablelmDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers", + ) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], 
+ config.hidden_size)) def forward( self, @@ -214,21 +223,28 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) return hidden_states -class StablelmForCausalLM(nn.Module): +class StablelmForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -247,6 +263,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -255,9 +273,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -302,6 +320,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -310,6 +330,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index d3a3a83c8437f..81dd7c4daa5e9 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
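The StableLM loader above, like every other `load_weights` touched here, now skips parameters owned by other pipeline ranks via `is_pp_missing_parameter`. A simplified sketch of that check, assuming the only rank-local parameters are decoder layers inside `[start_layer, end_layer)` (the real helper also accounts for `PPMissingLayer` placeholders such as embeddings and the LM head):

```python
# Simplified stand-in for is_pp_missing_parameter: decide whether a checkpoint
# name refers to a decoder layer this pipeline rank does not own.
import re
from typing import Iterable, List, Tuple

_LAYER_RE = re.compile(r"\blayers\.(\d+)\.")


def is_pp_missing_parameter(name: str, start_layer: int, end_layer: int) -> bool:
    match = _LAYER_RE.search(name)
    if match is None:
        # Not a per-layer weight; assume this rank keeps it (simplification).
        return False
    layer_idx = int(match.group(1))
    return not (start_layer <= layer_idx < end_layer)


def load_weights(weights: Iterable[Tuple[str, object]],
                 start_layer: int, end_layer: int) -> List[str]:
    loaded = []
    for name, _tensor in weights:
        if is_pp_missing_parameter(name, start_layer, end_layer):
            continue  # owned by another pipeline rank
        loaded.append(name)
    return loaded


if __name__ == "__main__":
    ckpt = [
        ("model.layers.0.self_attn.qkv_proj.weight", None),
        ("model.layers.30.mlp.gate_up_proj.weight", None),
        ("model.norm.weight", None),
    ]
    # A rank owning layers [0, 16) keeps layer 0 and (in this simplification)
    # the final norm, but skips layer 30.
    print(load_weights(ckpt, 0, 16))
```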
""" PyTorch Starcoder2 model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -26,14 +26,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -42,6 +41,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class Starcoder2Attention(nn.Module): @@ -195,7 +198,8 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -204,13 +208,16 @@ def __init__(self, # TODO: consider padding_idx (currently removed) self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, - cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Starcoder2DecoderLayer( + config, cache_config, quant_config=quant_config), + prefix=f"{prefix}.layers", + ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -218,17 +225,25 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states = layer(positions, hidden_states, kv_caches[i], + hidden_states = layer(positions, hidden_states, + kv_caches[i - self.start_layer], attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) return hidden_states -class Starcoder2ForCausalLM(nn.Module): +class Starcoder2ForCausalLM(nn.Module, SupportsPP): def __init__(self, config: Starcoder2Config, @@ -255,6 +270,8 @@ def 
__init__(self, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -263,9 +280,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -302,6 +319,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -309,6 +328,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: if self.config.tie_word_embeddings and "lm_head.weight" in name: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 71808eb4c2719..daa6e72dd1002 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,7 +3,7 @@ import math from array import array -from functools import lru_cache +from functools import cached_property, lru_cache from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union, cast) @@ -22,12 +22,10 @@ from vllm.inputs.registry import InputContext from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, @@ -37,9 +35,12 @@ from vllm.multimodal.base import MultiModalInputs, NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from .interfaces import SupportsMultiModal, SupportsPP + _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -323,7 +324,7 @@ def forward( "audio", get_ultravox_max_audio_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) -class UltravoxModel(nn.Module, SupportsMultiModal): +class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: UltravoxConfig, @@ -353,6 +354,16 @@ def __init__(self, revision=None, prefix="language_model.")) + 
self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: audio_input = input_features.to(self.audio_tower.dtype) @@ -425,7 +436,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[torch.Tensor], - **kwargs) -> SamplerOutput: + **kwargs) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Ultravox One key thing to understand is the `input_ids` already accounts for the @@ -438,18 +449,22 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Args: audio_features: A batch of audio inputs [B, N, 80, M]. """ - audio_input = self._parse_and_validate_audio_input(**kwargs) - if audio_input is not None: - audio_embeddings = self._process_audio_input(audio_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, audio_embeddings, - _AUDIO_PLACEHOLDER_TOKEN) + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is not None: + audio_embeddings = self._process_audio_input(audio_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, audio_embeddings, + _AUDIO_PLACEHOLDER_TOKEN) + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f6218bad4ef1e..761f0406b1333 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -24,7 +24,7 @@ class WeightsGroup(UserDict): when attempting to access a weight component that does not exist. """ - def __getitem__(self, key: str) -> int: + def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]: try: return super().__getitem__(key) except KeyError as exc: @@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], def group_weights_with_prefix( - weights: Iterable[Tuple[str, torch.Tensor]] -) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]: + weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup: """ Helper function to group weights with prefix """ @@ -183,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, class LayerFn(Protocol): - def __call__( - self, - prefix="", - ) -> torch.nn.Module: + def __call__(self, prefix: str) -> torch.nn.Module: ... 
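The tightened `LayerFn` protocol above is what all the `lambda prefix: ...DecoderLayer(...)` factories in this diff satisfy: `make_layers` calls the factory once per owned layer with a dotted prefix such as `model.layers.17` and returns `(start_layer, end_layer, layers)`. A simplified sketch under the assumption of an even split across ranks and placeholder modules for layers hosted elsewhere:

```python
# Simplified sketch of a make_layers-style helper: build only the layers this
# pipeline rank owns, keep placeholders elsewhere so global indexing works,
# and pass each real layer a dotted prefix like "model.layers.17".
from typing import Callable, Tuple

import torch.nn as nn


class PPMissingLayer(nn.Module):
    """Placeholder for a layer hosted on another pipeline rank."""


def make_layers(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str,
    pp_rank: int,
    pp_size: int,
) -> Tuple[int, int, nn.ModuleList]:
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    layers = nn.ModuleList([
        layer_fn(f"{prefix}.{idx}") if start <= idx < end else PPMissingLayer()
        for idx in range(num_layers)
    ])
    return start, end, layers


if __name__ == "__main__":
    def factory(layer_prefix: str) -> nn.Module:
        # Phi3SmallDecoderLayer-style factories recover the global layer index
        # from the prefix via int(layer_prefix.split(".")[-1]).
        return nn.Linear(8, 8)  # stand-in for a decoder layer

    start, end, layers = make_layers(8, factory, "model.layers",
                                     pp_rank=1, pp_size=2)
    print(start, end, sum(isinstance(m, PPMissingLayer) for m in layers))  # 4 8 4
```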
@@ -319,8 +315,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( - batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> IntermediateTensors: return IntermediateTensors({ key: torch.zeros((batch_size, hidden_size), dtype=dtype, @@ -342,8 +340,14 @@ def __init__(self, llm: nn.Module, name: str) -> None: self.model_name = name setattr(self, name, llm) - def forward(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name)(*args, **kwargs) + def __getattr__(self, key: str): + llm = super().__getattr__(self.model_name) + if key == self.model_name: + return llm - def embed_tokens(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name).embed_tokens(*args, **kwargs) + return getattr(llm, key) + + # We need to explicitly override this + def __call__(self, *args: Any, **kwargs: Any) -> Any: + llm = super().__getattr__(self.model_name) + return llm(*args, **kwargs) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 24cc3728f85e4..3bded82033c08 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -28,15 +28,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,7 +44,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class XverseMLP(nn.Module): @@ -227,6 +228,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -240,11 +242,16 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.layers = nn.ModuleList([ - XverseDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + 
config.num_hidden_layers, + lambda prefix: XverseDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -252,23 +259,32 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class XverseForCausalLM(nn.Module, SupportsLoRA): +class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -317,6 +333,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -325,9 +343,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -368,6 +386,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -376,6 +396,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader)
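Finally, the `utils.py` wrapper change earlier in this diff drops the hand-written `forward`/`embed_tokens` passthroughs in favour of attribute delegation. A minimal sketch of that pattern, mirroring the diff's wrapper, which also shows why `__call__` must be overridden explicitly (dunder lookups go through the class, so `nn.Module.__call__` would otherwise bypass the delegation):

```python
# Sketch of the delegation pattern from the utils.py change: forward every
# unknown attribute to the wrapped model; override __call__ explicitly so
# calls reach the wrapped model instead of nn.Module.__call__.
from typing import Any

import torch
import torch.nn as nn


class ModelWrapper(nn.Module):

    def __init__(self, llm: nn.Module, name: str) -> None:
        super().__init__()
        self.model_name = name
        setattr(self, name, llm)

    def __getattr__(self, key: str) -> Any:
        # nn.Module.__getattr__ resolves registered submodules/parameters;
        # anything else is delegated to the wrapped model.
        llm = super().__getattr__(self.model_name)
        if key == self.model_name:
            return llm
        return getattr(llm, key)

    # Dunder methods are looked up on the type, so this must be explicit.
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        llm = super().__getattr__(self.model_name)
        return llm(*args, **kwargs)


if __name__ == "__main__":
    inner = nn.Linear(4, 2)
    wrapper = ModelWrapper(inner, "language_model")
    x = torch.randn(3, 4)
    # Both attribute access and calls reach the wrapped module.
    print(wrapper.out_features, wrapper(x).shape)  # 2 torch.Size([3, 2])
```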