From 8b5963185512eb7799f12240570e0ac7e7462a88 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 6 Dec 2024 10:34:29 -0500 Subject: [PATCH 01/18] [Core] Support Lark grammars for XGrammar (#10870) Signed-off-by: mgoin --- .../guided_decoding/__init__.py | 8 - .../guided_decoding/xgrammar_decoding.py | 17 +- .../guided_decoding/xgrammar_utils.py | 162 ++++++++++++++++++ 3 files changed, 178 insertions(+), 9 deletions(-) create mode 100644 vllm/model_executor/guided_decoding/xgrammar_utils.py diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index a81377341e095..e631aec928ec5 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -73,14 +73,6 @@ def maybe_backend_fallback( "Falling back to use outlines instead.") guided_params.backend = "outlines" - # xgrammar only supports EBNF grammars and uses the GBNF format - # https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md - elif (guided_params.grammar is not None - and "::=" not in guided_params.grammar): - logger.warning("xgrammar only supports EBNF grammars. 
" - "Falling back to use outlines instead.") - guided_params.backend = "outlines" - # xgrammar doesn't support some JSON schema features elif (guided_params.json is not None and has_xgrammar_unsupported_json_features(guided_params.json)): diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 8287cd6cf3aa0..b59a2269d2cd5 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -14,6 +14,9 @@ except ImportError: pass +from vllm.model_executor.guided_decoding.xgrammar_utils import ( + convert_lark_to_gbnf, grammar_is_likely_lark) + if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -152,7 +155,19 @@ def from_guided_params(cls, tokenizer_hash=tokenizer_hash, max_threads=max_threads) elif guided_params.grammar: - return cls(grammar_str=guided_params.grammar, + # XGrammar only supports GBNF grammars, so we must convert Lark + if grammar_is_likely_lark(guided_params.grammar): + try: + grammar_str = convert_lark_to_gbnf(guided_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to GBNF. " + "Please either use GBNF grammar directly or specify" + " --guided-decoding-backend=outlines.\n" + f"Conversion error: {str(e)}") from e + else: + grammar_str = guided_params.grammar + return cls(grammar_str=grammar_str, vocab_size=model_config.hf_config.vocab_size, encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, diff --git a/vllm/model_executor/guided_decoding/xgrammar_utils.py b/vllm/model_executor/guided_decoding/xgrammar_utils.py new file mode 100644 index 0000000000000..12b42245f4e3d --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_utils.py @@ -0,0 +1,162 @@ +import re + + +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. 
+ + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for Lark-style rule definitions + if ':' in line and '::=' not in line: + return True + + # Look for Lark-specific features + if any(pattern in line for pattern in ['?start:', '|', '~']): + return True + + return False + + +def convert_lark_to_gbnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to GBNF format. + + GBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in GBNF format + + Examples: + >>> print(convert_lark_to_gbnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + defined_rules = set() + referenced_rules = set() + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + 
text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] + first_rule = None + + for line_num, line in enumerate(lines, 1): + if not line or line.startswith('|'): + continue + + if ':' in line: + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + if first_rule is None: + first_rule = name + if name == 'start': + first_rule = 'start' + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. " + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") + + # Add root rule + output_lines.append(f"root ::= {first_rule}") + + # Second pass: Process rule definitions and alternatives + current_rule = None + current_definition = [] + + for line_num, line in enumerate(lines, 1): + if not line: + continue + + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Process new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + check_quotes(definition, f"rule '{current_rule}'", line_num) + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + referenced_rules.update(extract_references(definition)) + current_definition = [definition.strip()] + + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) + + except ValueError as e: + 
raise ValueError(f"Error on line {line_num}: {str(e)}") from e + + # Add final rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Validate all rules are defined + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + + return '\n'.join(output_lines) From 74062740416db8572627dda1f87925268ba2f1d3 Mon Sep 17 00:00:00 2001 From: Sam Stoelinga Date: Fri, 6 Dec 2024 09:03:56 -0800 Subject: [PATCH 02/18] [Doc] add KubeAI to serving integrations (#10837) Signed-off-by: Sam Stoelinga --- docs/source/serving/deploying_with_kubeai.rst | 17 +++++++++++++++++ docs/source/serving/integrations.rst | 1 + 2 files changed, 18 insertions(+) create mode 100644 docs/source/serving/deploying_with_kubeai.rst diff --git a/docs/source/serving/deploying_with_kubeai.rst b/docs/source/serving/deploying_with_kubeai.rst new file mode 100644 index 0000000000000..ec3c065320fd9 --- /dev/null +++ b/docs/source/serving/deploying_with_kubeai.rst @@ -0,0 +1,17 @@ +.. _deploying_with_kubeai: + +Deploying with KubeAI +===================== + +`KubeAI `_ is a Kubernetes operator that enables you to deploy and manage AI models on Kubernetes. It provides a simple and scalable way to deploy vLLM in production. Functionality such as scale-from-zero, load based autoscaling, model caching, and much more is provided out of the box with zero external dependencies. + + +Please see the Installation Guides for environment specific instructions: + +* `Any Kubernetes Cluster `_ +* `EKS `_ +* `GKE `_ + +Once you have KubeAI installed, you can +`configure text generation models `_ +using vLLM. 
\ No newline at end of file diff --git a/docs/source/serving/integrations.rst b/docs/source/serving/integrations.rst index f39997e0e44d9..0dd505a739863 100644 --- a/docs/source/serving/integrations.rst +++ b/docs/source/serving/integrations.rst @@ -6,6 +6,7 @@ Integrations run_on_sky deploying_with_kserve + deploying_with_kubeai deploying_with_triton deploying_with_bentoml deploying_with_cerebrium From c05cfb67da12f84bd142ba51cca98e59139bea42 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 11:25:20 -0800 Subject: [PATCH 03/18] [misc] fix typo (#10960) Signed-off-by: youkaichao --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index a5e2702035a5c..fe4c85441fced 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2082,7 +2082,7 @@ class KVTransferConfig(BaseModel): @classmethod def from_cli(cls, cli_value: str) -> "KVTransferConfig": - """Parse the CLI value for the compilation config.""" + """Parse the CLI value for the kv cache transfer config.""" return KVTransferConfig.model_validate_json(cli_value) def model_post_init(self, __context: Any) -> None: From dcdc3fafe535178037ef0a58f53607b2fb3e4190 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 6 Dec 2024 11:25:47 -0800 Subject: [PATCH 04/18] [ci] fix broken tests (#10956) Signed-off-by: youkaichao --- vllm/worker/model_runner.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4388b3c1ee164..1bc5f65c7127f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1782,6 +1782,9 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: kv_caches: vLLM's paged memory """ + if self.vllm_config.kv_transfer_config is None: + return False + prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling @@ -1789,9 +1792,6 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: # check 
if the current run is prefill is_prefill_run = prefill_meta is not None - if self.vllm_config.kv_transfer_config is None: - return False - return self.vllm_config.kv_transfer_config.is_kv_consumer and ( not is_profile_run) and is_prefill_run @@ -1807,6 +1807,9 @@ def need_send_kv(self, model_input, kv_caches) -> bool: kv_caches: vLLM's paged memory """ + if self.vllm_config.kv_transfer_config is None: + return False + prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling @@ -1814,9 +1817,6 @@ def need_send_kv(self, model_input, kv_caches) -> bool: # check if the current run is prefill is_prefill_run = prefill_meta is not None - if self.vllm_config.kv_transfer_config is None: - return False - return self.vllm_config.kv_transfer_config.is_kv_producer and ( not is_profile_run) and is_prefill_run From 69d357ba125a8c4243c25d7d9162f1c93cfddd1f Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 6 Dec 2024 21:30:23 -0500 Subject: [PATCH 05/18] [Core] Cleanup startup logging a bit (#10961) Signed-off-by: Russell Bryant --- vllm/engine/arg_utils.py | 1 + vllm/entrypoints/openai/api_server.py | 8 ++++---- vllm/plugins/__init__.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0b304658f012c..ccd9fac225cba 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -433,6 +433,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'capping to sliding window size') parser.add_argument('--use-v2-block-manager', action='store_true', + default=True, help='[DEPRECATED] block manager v1 has been ' 'removed and SelfAttnBlockSpaceManager (i.e. ' 'block manager v2) is now the default. 
' diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6bc31ef83ded4..c7bc30040279c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -175,8 +175,8 @@ async def build_async_engine_client_from_engine_args( # Select random path for IPC. ipc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for IPC Path.", - ipc_path) + logger.debug("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, @@ -249,8 +249,8 @@ def mount_metrics(app: FastAPI): prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) if prometheus_multiproc_dir_path is not None: - logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) + logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index ae6e5c0a3481f..17f604ea0e202 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -57,7 +57,7 @@ def load_general_plugins(): discovered_plugins = entry_points(group='vllm.general_plugins') if len(discovered_plugins) == 0: - logger.info("No plugins found.") + logger.debug("No plugins found.") return logger.info("Available plugins:") for plugin in discovered_plugins: From acf092d34802b187f27daa8e1626f67552bde193 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 7 Dec 2024 12:08:54 +0800 Subject: [PATCH 06/18] [Bugfix] Fix test-pipeline.yaml (#10973) Signed-off-by: Jee Jee Li --- .buildkite/test-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index bf0de3f69f14e..936e284d9675a 100644 --- a/.buildkite/test-pipeline.yaml +++ 
b/.buildkite/test-pipeline.yaml @@ -237,7 +237,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore lora/test_long_context.py lora/test_chatglm3_tp.py lora/test_llama_tp.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py parallelism: 4 - label: "PyTorch Fullgraph Smoke Test" # 9min From 955fa9533afde0d232e73f079d72239c8a87c636 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 7 Dec 2024 16:50:58 +0800 Subject: [PATCH 07/18] [3/N] Support and implement merged input processor for LLaVA model (#10676) Signed-off-by: DarkLight1337 Co-authored-by: Roger Wang --- tests/multimodal/test_mapper.py | 49 +-- tests/multimodal/test_processing.py | 277 +++++++++++----- .../vllm_add_dummy_model/my_llava.py | 12 +- vllm/inputs/registry.py | 42 ++- vllm/model_executor/models/llava.py | 219 +++++------- vllm/multimodal/base.py | 51 ++- vllm/multimodal/processing.py | 313 +++++++++++------- vllm/multimodal/registry.py | 67 +++- vllm/v1/engine/mm_input_mapper.py | 1 + vllm/v1/engine/processor.py | 16 +- 10 files changed, 626 insertions(+), 421 deletions(-) diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py index 13ad4a7966b9d..71832acbd17b8 100644 --- a/tests/multimodal/test_mapper.py +++ b/tests/multimodal/test_mapper.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from transformers import CLIPImageProcessor, LlavaNextImageProcessor +from transformers import LlavaNextImageProcessor from vllm.config import ModelConfig from vllm.multimodal import MultiModalRegistry @@ -14,49 +14,6 @@ def mm_registry(): return MultiModalRegistry() -@pytest.mark.parametrize("dtype", ["half", "float"]) -@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) -def 
test_clip_image_processor(image_assets, mm_registry, dtype, size_factor): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - - hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) - assert isinstance(hf_processor, CLIPImageProcessor) - - model_config = ModelConfig( - model=MODEL_NAME, - task="auto", - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype=dtype, - revision=None, - limit_mm_per_prompt={"image": 1}, - ) - - mm_registry.init_mm_limits_per_prompt(model_config) - - for asset in image_assets: - image = rescale_image_size(asset.pil_image, size_factor) - - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ) - vllm_result = mm_registry.map_input( - model_config, - {"image": image}, - ) - - assert hf_result.keys() == vllm_result.keys() - for key, hf_tensor in hf_result.items(): - hf_arr: np.ndarray = hf_tensor.numpy() - vllm_arr: np.ndarray = vllm_result[key].numpy() - - assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" - assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" - - @pytest.mark.parametrize("dtype", ["half", "float"]) @pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) def test_llava_next_image_processor(image_assets, mm_registry, dtype, @@ -107,7 +64,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype, (2, 1, False), (2, 2, True)], ) def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" model_config = ModelConfig( model=MODEL_NAME, @@ -138,7 +95,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): # NOTE: We don't test zero images since the HF processor doesn't support it @pytest.mark.parametrize("num_images", [1, 2]) def test_image_mapper_multi(image_assets, mm_registry, num_images): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" model_config = 
ModelConfig( model=MODEL_NAME, diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index b2367060c6c1b..ae668d1dd56c8 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -3,50 +3,15 @@ import pytest from transformers import BatchFeature -from vllm.multimodal.processing import (PromptReplacement, find_text_matches, - find_token_matches, iter_token_matches, - iter_token_runs, replace_text_matches) +from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, + find_text_matches, find_token_matches, + iter_placeholders, iter_token_matches, + replace_text_matches, + replace_token_matches) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby -# yapf: disable -@pytest.mark.parametrize( - ("token_ids", "expected"), - [ - ([], []), - ( - [32000, 32000, 32000], - [{ "token_id": 32000, "start_idx": 0, "length": 3 }], - ), - ( - [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], - [ - { "token_id": 9833, "start_idx": 0, "length": 1 }, - { "token_id": 28747, "start_idx": 1, "length": 1 }, - { "token_id": 32000, "start_idx": 2, "length": 3 }, - { "token_id": 9833, "start_idx": 5, "length": 1 }, - { "token_id": 28747, "start_idx": 6, "length": 1 }, - { "token_id": 32000, "start_idx": 7, "length": 2 }, - { "token_id": 918, "start_idx": 9, "length": 1 }, - ], - ), - ], -) -# yapf: enable -def test_iter_token_runs(token_ids, expected): - result = list(iter_token_runs(token_ids)) - - # Only displayed on error - print("result:", result) - - # Manually constructed results - assert [item._asdict() for item in result] == expected - - # Invariants - assert sum(run_info.length for run_info in result) == len(token_ids) - - # yapf: disable @pytest.mark.parametrize( ("token_ids", "match_ids", "expected"), @@ -170,13 +135,11 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): # Should not be used since there is nothing to 
convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - result = find_token_matches( - prompt, - [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], - ) + prompt_repls = [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + result = find_token_matches(prompt, prompt_repls) # Only displayed on error print("result:", result) @@ -279,13 +242,11 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - result = find_text_matches( - prompt, - [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], - ) + prompt_repls = [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + result = find_text_matches(prompt, prompt_repls) # Only displayed on error print("result:", result) @@ -303,7 +264,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # yapf: disable @pytest.mark.parametrize( - ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"), + ("prompt", "target_by_key", "repl_by_key"), [ ( "Image:Image:!", @@ -322,49 +283,201 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Test multiple repl_count "pattern_3": ("?", 2), }, - { - # Test no replacement - 0: "Image:Image:!", - # Test single replacement - 1: "Image:??", - # Test repeated replacement - 2: "??", - }, ), ] ) +@pytest.mark.parametrize( + ("mm_count", "expected"), + [ + (0, "Image:Image:!"), + (1, "Image:??"), + (2, "??"), + ] +) # yapf: enable def test_find_replace_text( prompt, target_by_key, repl_by_key, - expected_by_mm_count, + mm_count, + expected, ): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - matches = find_text_matches( + 
prompt_repls = [ + PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + matches = find_text_matches(prompt, prompt_repls) + + result = replace_text_matches( prompt, - [ - PromptReplacement(target, *repl_by_key[key]) \ - .bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), ) - result_by_mm_count = { - mm_count: replace_text_matches( - prompt, - matches, - {key: list(range(mm_count)) - for key in repl_by_key}, - BatchFeature(), - ) - for mm_count in expected_by_mm_count - } # Only displayed on error print("matches:", matches) - print("result_by_mm_count:", result_by_mm_count) + print("result:", result) + + # Manually constructed results + assert result == expected + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "repl_by_key"), + [ + # Tokenized test cases of `test_find_replace_text` + # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf + ( + [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + { + # We use `` before `Image:` to test matches that + # occur out of order + "pattern_1": [32000], + "pattern_2": [9833, 28747], + "pattern_3": [918], + }, + { + # Test whether target is confused with repl_unit + "pattern_1": ([32000, 32000], 1), + # Test empty repl_unit + "pattern_2": ([], 1), + # Test multiple repl_count + "pattern_3": ([1550], 2), + }, + ), + ] +) +@pytest.mark.parametrize( + ("mm_count", "expected"), + [ + (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), + (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]), + (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]), + ] +) +# yapf: enable +def test_find_replace_tokens( + prompt, + target_by_key, + repl_by_key, + mm_count, + expected, +): + # Should not be used since there is nothing to convert to tokens + mock_tokenizer = cast(AnyTokenizer, object()) + + prompt_repls = [ + 
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + matches = find_token_matches(prompt, prompt_repls) + + result = replace_token_matches( + prompt, + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), + ) + + # Only displayed on error + print("matches:", matches) + print("result:", result) + + # Manually constructed results + assert result == expected + + +# yapf: disable +@pytest.mark.parametrize( + "repl_by_key", + [ + { + "pattern_1": ([32000, 32000], 1), + "pattern_2": ([], 1), + "pattern_3": ([1550], 2), + }, + ], +) +@pytest.mark.parametrize( + ("prompt", "expected"), + [ + ( + [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=6, + unit=[32000, 32000], + unit_count=1, + ), + ], + ), + ( + [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=1, + unit=[32000, 32000], + unit_count=1, + ), + _PlaceholderInfo( + modality="pattern_1", + start_idx=5, + unit=[32000, 32000], + unit_count=1, + ), + _PlaceholderInfo( + modality="pattern_3", + start_idx=7, + unit=[1550], + unit_count=2, + ), + ], + ), + ( + [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=1, + unit=[32000, 32000], + unit_count=2, + ), + _PlaceholderInfo( + modality="pattern_3", + start_idx=6, + unit=[1550], + unit_count=2, + ), + ], + ), + ] +) +def test_iter_placeholders( + repl_by_key, + prompt, + expected, +): + # Should not be used since there is nothing to convert to tokens + mock_tokenizer = cast(AnyTokenizer, object()) + + prompt_repls = [ + PromptReplacement([], *repl).bind(key, mock_tokenizer) + for key, repl in repl_by_key.items() + ] + + result = list(iter_placeholders(prompt_repls, prompt)) + + # Only displayed on error + print("result:", result) # Manually constructed results - assert 
result_by_mm_count == expected_by_mm_count + assert result == expected diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 3ebd7864b8fc8..f2fc0755cae01 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -2,19 +2,17 @@ import torch -from vllm.inputs import INPUT_REGISTRY from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - dummy_data_for_llava, - get_max_llava_image_tokens, - input_processor_for_llava) + create_metadata_for_llava, + dummy_mm_kwargs_for_llava, + get_max_llava_image_tokens) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava, + dummy_mm_kwargs_for_llava) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 85ab4355cc2e4..646554c72481a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -232,19 +232,35 @@ def dummy_data_for_profiling( """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - if is_encoder_data: - dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.utils import cached_get_tokenizer + + if mm_registry.has_processor(model_config): + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, 
+ ) + processor = mm_registry.create_processor(model_config, tokenizer) + + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_max_tokens = mm_registry.get_max_tokens_by_modality( + model_config) + + dummy_data = processor.get_dummy_data(seq_len, mm_counts, + mm_max_tokens) else: - dummy_factory = self._get_dummy_data_factory(model_cls) - mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - dummy_factory, overrides=model_config.mm_processor_kwargs) + model_cls, _ = get_model_architecture(model_config) + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + dummy_factory, overrides=model_config.mm_processor_kwargs) - dummy_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) + dummy_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = dummy_data.seq_data.prompt_token_ids @@ -257,7 +273,9 @@ def dummy_data_for_profiling( raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " f"but found {len(num_tokens)} tokens instead.") - if dummy_data.multi_modal_data is not None: + + if (dummy_data.multi_modal_data is not None and + not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)): for k, v in dummy_data.multi_modal_data.items(): num_items = len(v) if isinstance(v, list) else 1 num_expected = mm_counts[k] diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index d375c1c9da2a9..953b89f1842af 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,17 +1,19 @@ from 
functools import cached_property +from types import MethodType from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) import torch import torch.nn as nn -from PIL import Image -from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, - PretrainedConfig, SiglipVisionConfig) +from PIL.Image import Image +from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, + PixtralVisionConfig, PretrainedConfig, + ProcessorMixin, SiglipVisionConfig) +from transformers.models.pixtral import PixtralProcessor from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -19,21 +21,20 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.processing import (InputProcessingContext, + ModalityProcessingMetadata, + MultiModalProcessingMetadata, + MultiModalProcessor, PromptReplacement) from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, - dummy_seq_data_for_clip, get_max_clip_image_tokens, - input_processor_for_clip) + get_max_clip_image_tokens) from .interfaces import SupportsMultiModal, SupportsPP from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, - dummy_seq_data_for_pixtral_hf, - get_max_pixtral_hf_image_tokens, - input_processor_for_pixtral_hf) + get_max_pixtral_hf_image_tokens) from .siglip import (SiglipVisionModel, 
dummy_image_for_siglip, - dummy_seq_data_for_siglip, get_max_siglip_image_tokens, - input_processor_for_siglip) + get_max_siglip_image_tokens) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -113,102 +114,86 @@ def get_max_llava_image_tokens(ctx: InputContext): raise ValueError(f"Unexpected select feature strategy: {strategy}") -def dummy_data_for_llava(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext, + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config num_images = mm_counts["image"] - image_feature_size = get_max_llava_image_tokens(ctx) - if isinstance(vision_config, CLIPVisionConfig): - seq_data, ranges = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_clip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_clip(vision_config, num_images) elif isinstance(vision_config, SiglipVisionConfig): - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_siglip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_siglip(vision_config, num_images) elif isinstance(vision_config, PixtralVisionConfig): - seq_data, ranges = dummy_seq_data_for_pixtral_hf( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_pixtral_hf(vision_config, num_images) 
+ else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + hf_processor = ctx.get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt") + is_pixtral = isinstance(hf_processor, PixtralProcessor) + return MultiModalKwargs( + **hf_inputs, + is_pixtral=torch.tensor(is_pixtral), + ) -def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - model_config = ctx.model_config +def create_metadata_for_llava( + ctx: InputProcessingContext) -> MultiModalProcessingMetadata: hf_config = ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config + image_token_id = hf_config.image_token_index + + def get_repl_count( + mm_items: list[Image], + hf_inputs: BatchFeature, + item_idx: int, + ) -> int: + return get_max_llava_image_tokens(ctx) + + return { + "image": + ModalityProcessingMetadata(prompt_repls=[ + PromptReplacement(target=[image_token_id], + repl_unit=[image_token_id], + repl_count=get_repl_count), + ]), + } - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_feature_size = get_max_llava_image_tokens(ctx) - elif is_list_of(image_data, Image.Image): - image_feature_size = [get_max_llava_image_tokens(ctx) - ] * len(image_data) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[1] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - if isinstance(vision_config, CLIPVisionConfig): - return input_processor_for_clip( - model_config, - vision_config, - inputs, - 
image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, SiglipVisionConfig): - return input_processor_for_siglip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, PixtralVisionConfig): - # We ignore image_feature_size_override since we have non-uniform - # image sizes for Pixtral - return input_processor_for_pixtral_hf( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - ) +class LlavaProcessor(MultiModalProcessor): - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): + if getattr(hf_processor, "__is_patched__", False): + return # Already patched + + image_processor = hf_processor.image_processor # type: ignore + orig_preprocess = image_processor.preprocess + + def preprocess(__self, *args, **kwargs): + hf_inputs = orig_preprocess(*args, **kwargs) + hf_inputs["is_pixtral"] = torch.tensor(True) + return hf_inputs + + image_processor.preprocess = MethodType(preprocess, image_processor) + + hf_processor.__is_patched__ = True # type: ignore + + def _get_hf_processor(self) -> ProcessorMixin: + hf_processor = self.ctx.get_hf_processor() + + if isinstance(hf_processor, PixtralProcessor): + self._patch_pixtral_processor(hf_processor) + + return hf_processor + + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + return dummy_mm_kwargs_for_llava(self.ctx, mm_counts) class LlavaLikeConfig(Protocol): @@ -291,10 +276,11 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) 
-@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor( + ctx=ctx, + metadata=create_metadata_for_llava(ctx), +)) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes bitsandbytes_stacked_params_mapping = { @@ -367,38 +353,10 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: return data - def _validate_image_sizes(self, images: List[torch.Tensor], - sizes: List[torch.Tensor]) -> List[torch.Tensor]: - if not isinstance(sizes, list): - sizes = [sizes] - - total_images = sum(size.numel() // 2 for size in sizes) - if total_images != len(images): - raise ValueError("Mismatch in number of images. " - f"Expected {total_images}, got {len(images)}") - img_idx = 0 - for size in sizes: - # Flatten the size tensor to a list of (height, width) pairs - size = size.view(-1, 2).tolist() - for expected_h, expected_w in size: - if img_idx >= len(images): - raise ValueError("Ran out of images before sizes. " - f"{img_idx} >= {len(images)}") - img = images[img_idx] - if img.shape[-2:] != (expected_h, expected_w): - raise ValueError( - "Image size mismatch. Expected " - f"{(expected_h, expected_w)}, got {img.shape[-2:]}") - if img.shape[-3] != 3: - raise ValueError("Image channel mismatch. Expected 3, " - f"got {img.shape[-3]}") - img_idx += 1 - return images - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - image_sizes = kwargs.pop("image_sizes", None) + is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False])) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: @@ -409,9 +367,8 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. 
" f"Got type: {type(pixel_values)}") - # Case for models like PixtralHF that have dynamic image sizes - # so we need to produce a list of tensors - if image_sizes is not None: + assert isinstance(is_pixtral, torch.Tensor) + if is_pixtral.any(): images = pixel_values def flatten_to_3d_tensors(item): @@ -434,7 +391,7 @@ def flatten_to_3d_tensors(item): return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_image_sizes(images, image_sizes), + data=images, ) return LlavaImagePixelInputs( diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index f93722523728d..7dba94b885b6d 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -226,16 +226,16 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture + from vllm.model_executor.models import supports_multimodal model_cls, _ = get_model_architecture(model_config) - if model_cls not in self._input_mappers: + if not supports_multimodal(model_cls): return 0 max_mm_tokens = self._max_mm_tokens.get(model_cls) if max_mm_tokens is None: - raise KeyError(f"No maximum number of multi-modal tokens is given " - f"for model class {model_cls.__name__} in {self}.") + return 0 if callable(max_mm_tokens): mm_processor_kwargs = get_allowed_kwarg_only_overrides( @@ -326,26 +326,47 @@ def from_seq_group( src_ranges = [] dest_ranges = [] """ - if (not seq_group.multi_modal_data - or not seq_group.multi_modal_placeholders): - return seq_group.multi_modal_data, {} + seq_mm_data = seq_group.multi_modal_data + seq_mm_placeholders = seq_group.multi_modal_placeholders + + if not seq_mm_data or not seq_mm_placeholders: + return seq_mm_data, {} + + # For merged processor, we directly use mm_kwargs as mm_data + if isinstance(seq_mm_data, MultiModalKwargs): + placeholder_maps = dict[str, MultiModalPlaceholderMap]() + + for modality, placeholders in seq_mm_placeholders.items(): + placeholder_map = 
MultiModalPlaceholderMap() + + if positions: + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) + + placeholder_maps[modality] = placeholder_map + + return seq_mm_data, placeholder_maps - mm_data = {**seq_group.multi_modal_data} - placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict( + mm_data = {**seq_mm_data} + placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( MultiModalPlaceholderMap) - for ( - modality, - placeholders, - ) in seq_group.multi_modal_placeholders.items(): + for modality, placeholders in seq_mm_placeholders.items(): mm_items = mm_data.pop(modality) if not isinstance(mm_items, list): mm_items = [mm_items] if positions: - intersecting_items = placeholder_maps[ - modality].append_items_from_seq_group( - positions, mm_items, placeholders) + intersecting_items = placeholder_maps[modality] \ + .append_items_from_seq_group( + positions, + mm_items, + placeholders, + ) if intersecting_items: mm_data[modality] = intersecting_items diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 28c8dda581982..4a1737991534f 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,14 +3,13 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass from functools import lru_cache -from itertools import groupby from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union -import numpy as np -from transformers import BatchFeature +import torch +from transformers import BatchFeature, ProcessorMixin from typing_extensions import TypeAlias, TypedDict -from vllm.inputs import InputProcessingContext +from vllm.inputs import DummyData, InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import flatten_2d_lists, full_groupby, is_list_of @@ -256,63 
+255,6 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]: return multi_data -class _TokenRun(NamedTuple): - token_id: int - - start_idx: int - length: int - - -def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]: - """ - Yield the starting index and length of each run of tokens that are the same. - """ - start_idx = 0 - - for token_id, it in groupby(token_ids): - length = sum(1 for _ in it) - yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length) - - start_idx += length - - -class _PlaceholderInfo(NamedTuple): - modality: str - offset: int - length: int - - def to_range(self) -> PlaceholderRange: - return PlaceholderRange(offset=self.offset, length=self.length) - - -def iter_placeholders( - prompt_repls: Sequence[_BoundPromptReplacement[Any]], - token_ids: list[int], - *, - min_placeholder_count: int, -) -> Iterable[_PlaceholderInfo]: - """Yield each set of placeholder tokens found in :code:`token_ids`.""" - placeholder_ids_by_modality = { - modality: { - token_id - for prompt_repl in repls - for token_id in prompt_repl.repl_unit.token_ids - } - for modality, repls in full_groupby_modality(prompt_repls) - } - - for run_info in iter_token_runs(token_ids): - if run_info.length > min_placeholder_count: - for (modality, - placeholder_ids) in placeholder_ids_by_modality.items(): - if run_info.token_id in placeholder_ids: - yield _PlaceholderInfo( - modality=modality, - offset=run_info.start_idx, - length=run_info.length, - ) - - class _TokenMatch(NamedTuple): start_idx: int end_idx: int @@ -353,13 +295,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: raise NotImplementedError + @property @abstractmethod - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> _S: + def repl_unit(self) -> _S: raise NotImplementedError def __repr__(self) -> str: @@ -380,15 +318,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end_idx - def get_repl( - self, - 
mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> list[int]: - prompt_repl = self.prompt_repl - count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) - return prompt_repl.repl_unit.token_ids * count + @property + def repl_unit(self) -> list[int]: + return self.prompt_repl.repl_unit.token_ids @dataclass(repr=False) @@ -404,15 +336,26 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end() - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> str: - prompt_repl = self.prompt_repl - count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) - return prompt_repl.repl_unit.text * count + @property + def repl_unit(self) -> str: + return self.prompt_repl.repl_unit.text + + +class _PlaceholderInfo(NamedTuple): + modality: str + start_idx: int + unit: list[int] + unit_count: int + + @property + def length(self) -> int: + return len(self.unit) * self.unit_count + + def to_range(self) -> PlaceholderRange: + return PlaceholderRange( + offset=self.start_idx, + length=self.length, + ) def find_token_matches( @@ -447,15 +390,17 @@ def _resolve_matches( Resolve :code:`matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. 
""" - num_matches_by_idx = np.zeros(len(prompt), dtype=int) + seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \ + = [None] * len(prompt) + for match in matches: - num_matches_by_idx[match.start_idx:match.end_idx] += 1 + for idx in range(match.start_idx, match.end_idx): + if seen_matches[idx] is not None: + raise ValueError("Found overlapping matches " + f"({seen_matches[idx]} and {match}) " + f"at index={idx} of prompt={prompt}") - duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1) - if len(duplicate_matches_idxs) > 0: - raise ValueError("Unable to find a unique replacement " - f"at indices={duplicate_matches_idxs} " - f"of prompt={prompt}") + seen_matches[idx] = match return sorted(matches, key=lambda x: x.start_idx) @@ -480,9 +425,12 @@ def _replace_matches( start_idx = match.start_idx end_idx = match.end_idx - repl_ids = match.get_repl(mm_items, hf_inputs, item_idx) + repl_unit = match.repl_unit + repl_info = match.prompt_repl + repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx) - out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids) + out_seqs.append(prompt[prev_end_idx:start_idx] + + repl_unit * repl_count) prev_end_idx = end_idx next_idx_by_modality[modality] += 1 @@ -531,7 +479,57 @@ def replace_text_matches( return "".join(texts) -class MultiModalProcessor: +def _merge_placeholder_matches( + matches: Iterable[_PromptReplacementTokenMatch], +) -> Iterable[_PromptReplacementTokenMatch]: + current_match = None + + for match in sorted(matches, key=lambda x: x.start_idx): + if current_match is None: + current_match = match + elif (current_match.prompt_repl == match.prompt_repl + and current_match.end_idx == match.start_idx): + current_match = _PromptReplacementTokenMatch( + current_match.prompt_repl, + match=_TokenMatch(current_match.start_idx, match.end_idx), + ) + else: + yield current_match + current_match = match + + if current_match is not None: + yield current_match + + +def iter_placeholders( + prompt_repls: 
Sequence[_BoundPromptReplacement[Any]], + prompt: list[int], + *, + min_unit_count: int = 1, +) -> Iterable[_PlaceholderInfo]: + """Yield each set of placeholder tokens found in :code:`token_ids`.""" + if min_unit_count <= 0: + raise ValueError("`min_unit_count` must be a positive integer") + + matches = (_PromptReplacementTokenMatch(prompt_repl, match) + for prompt_repl in prompt_repls + if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0 + for match in iter_token_matches(prompt, repl_unit)) + + for match in _merge_placeholder_matches(matches): + unit = match.repl_unit + placeholder = _PlaceholderInfo( + modality=match.modality, + start_idx=match.start_idx, + unit=unit, + unit_count=(match.end_idx - match.start_idx) // len(unit), + ) + + if placeholder.unit_count >= min_unit_count: + yield placeholder + + +class MultiModalProcessor(ABC): """ Helper class to process multi-modal inputs to be used in vLLM. """ @@ -546,6 +544,12 @@ def __init__( self.ctx = ctx self.metadata = metadata + def _get_hf_processor(self) -> ProcessorMixin: + return self.ctx.get_hf_processor() + + def _get_tokenizer(self) -> AnyTokenizer: + return self.ctx.tokenizer + def __call__( self, prompt: str, @@ -562,13 +566,13 @@ def _find_placeholders( # To avoid false positives from multi-input when detecting # whether placeholder tokens have been inserted, in case # the target sequence is a subset of the replacement tokens - min_placeholder_count: int = 16, + min_unit_count: int = 16, ) -> list[_PlaceholderInfo]: return list( iter_placeholders( all_prompt_repls, new_token_ids, - min_placeholder_count=min_placeholder_count, + min_unit_count=min_unit_count, )) def _apply_hf_processor( @@ -577,19 +581,49 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - hf_processor = self.ctx.get_hf_processor() + hf_processor = self._get_hf_processor() + + processor_data = dict[str, Any]() + passthrough_data = dict[str, Any]() + for k, v in 
mm_data.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion + if k in ("image", "video", "audio"): + if isinstance(v, torch.Tensor) and v.ndim == 3: + # Pass through embedding inputs (single) + passthrough_data[f"{k}_embeds"] = [v] + elif is_list_of(v, torch.Tensor) and v[0].ndim == 2: + # Pass through embedding inputs (multi) + passthrough_data[f"{k}_embeds"] = v + else: + # Map keys to plural form, e.g.: image -> images + processor_data[f"{k}s"] = v + else: + processor_data[k] = v + + try: + hf_inputs = hf_processor( + text=prompt, # type: ignore + **processor_data, + **mm_processor_kwargs, + return_tensors="pt", + ) + except Exception as exc: + data = dict(text=prompt, **processor_data) - return hf_processor( - text=prompt, # type: ignore - **mm_data, - **mm_processor_kwargs, - ) + raise RuntimeError( + f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={mm_processor_kwargs}") from exc + + hf_inputs.update(passthrough_data) + + return hf_inputs def _bind_prompt_replacements( self, mm_data: MultiModalDataDict, ) -> list[_BoundPromptReplacement[Any]]: - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() return [ prompt_repl.bind(modality, tokenizer) @@ -604,7 +638,7 @@ def _apply_prompt_replacements( token_ids: list[int], prompt_repls: Sequence[_BoundPromptReplacement[Any]], ) -> tuple[list[int], str, list[_PlaceholderInfo]]: - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() mm_items = to_multi_format(mm_data) token_matches = find_token_matches(token_ids, prompt_repls) @@ -620,7 +654,7 @@ def _apply_prompt_replacements( # of the search text in the prompt, we instead perform string # replacement on the decoded token IDs, then encode them back. 
if all( - len(matches) >= len(mm_data[modality]) + len(matches) >= len(mm_items[modality]) for modality, matches in full_groupby_modality(token_matches) ): # yapf: disable token_ids = replace_token_matches( @@ -648,15 +682,6 @@ def _apply_prompt_replacements( placeholders = self._find_placeholders(matched_repls, token_ids) - # Sanity check - assert len(placeholders) == len(matched_repls), dict( - # Log this information for easier debugging - text=text, - token_ids=token_ids, - placeholders=placeholders, - matched_repls=matched_repls, - ) - return token_ids, text, placeholders def apply( @@ -678,7 +703,7 @@ def apply( 3. Extract information about the placeholder tokens from the processed token IDs. """ - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() hf_inputs = self._apply_hf_processor(prompt_text, mm_data, mm_processor_kwargs) @@ -717,3 +742,59 @@ def apply( mm_kwargs=mm_kwargs, mm_placeholders=mm_placeholders, ) + + @abstractmethod + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + """ + Build the input that corresponds to `mm_max_tokens` in + :meth:`get_dummy_data`. 
+ """ + raise NotImplementedError + + def get_dummy_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_max_tokens: Mapping[str, int], + ) -> DummyData: + # Avoid circular import + from vllm.sequence import SequenceData + + tokenizer = self._get_tokenizer() + + mm_placeholders = dict[str, _PlaceholderInfo]() + offset = 0 + + for modality, max_tokens in mm_max_tokens.items(): + if max_tokens == 0: + continue + + metadata = self.metadata[modality] + repl = metadata.prompt_repls[0].bind(modality, tokenizer) + repl_token_ids = repl.repl_unit.token_ids + + placeholders = _PlaceholderInfo( + modality=modality, + start_idx=offset, + unit=repl_token_ids, + unit_count=max_tokens // len(repl_token_ids), + ) + + mm_placeholders[modality] = placeholders + offset += placeholders.length + + prompt_token_ids = flatten_2d_lists( + [p.unit * p.unit_count for p in mm_placeholders.values()]) + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + + return DummyData( + seq_data=SequenceData.from_seqs(prompt_token_ids), + multi_modal_data=self._get_dummy_mm_kwargs(mm_counts), + multi_modal_placeholders={ + modality: [p.to_range()] + for modality, p in mm_placeholders.items() + }, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b73daee98bd80..f51da8972d15b 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -15,7 +15,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import MultiModalProcessor +from .processing import MultiModalProcessingMetadata, MultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -200,9 +200,12 @@ def register_max_image_tokens( """ return self.register_max_multimodal_tokens("image", max_mm_tokens) - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + def get_max_tokens_by_modality( + 
self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: """ - Get the maximum number of multi-modal tokens + Get the maximum number of tokens from each modality for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. @@ -212,9 +215,23 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ limits_per_plugin = self._limits_by_model[model_config] - return sum((limits_per_plugin[key] * - plugin.get_max_multimodal_tokens(model_config)) - for key, plugin in self._plugins.items()) + return { + key: (limits_per_plugin[key] * + plugin.get_max_multimodal_tokens(model_config)) + for key, plugin in self._plugins.items() + } + + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + return sum(self.get_max_tokens_by_modality(model_config).values()) def init_mm_limits_per_prompt( self, @@ -270,7 +287,8 @@ def register_processor( factory: MultiModalProcessorFactory, ): """ - Register a multi-modal processor to a model class. + Register a multi-modal processor to a model class. The processor + is constructed lazily, hence a factory method should be passed. When the model receives multi-modal data, the provided function is invoked to transform the data into a dictionary of model inputs. @@ -293,6 +311,41 @@ def wrapper(model_cls: N) -> N: return wrapper + def register_processor_by_metadata( + self, + metadata_factory: Callable[[InputProcessingContext], + MultiModalProcessingMetadata], + get_dummy_mm_kwargs: Callable[ + [InputProcessingContext, Mapping[str, int]], MultiModalKwargs], + ): + """ + Convenience method to register a multi-modal processor to a model class + according to a function that constructs its metadata. 
+ + When the model receives multi-modal data, the provided function is + invoked to transform the data into a dictionary of model inputs. + + See also: + - :ref:`input_processing_pipeline` + - :ref:`enabling_multimodal_inputs` + """ + + class ConcreteMultiModalProcessor(MultiModalProcessor): + + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + return get_dummy_mm_kwargs(self.ctx, mm_counts) + + def factory(ctx: InputProcessingContext): + return ConcreteMultiModalProcessor( + ctx=ctx, + metadata=metadata_factory(ctx), + ) + + return self.register_processor(factory) + def has_processor(self, model_config: "ModelConfig") -> bool: """ Test whether a multi-modal processor is defined for a specific model. diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 594c973678235..45882f8f076d4 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -12,6 +12,7 @@ def __init__( model_config: ModelConfig, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): + self.model_config = model_config self.mm_registry = mm_registry self.multi_modal_input_mapper = mm_registry.create_input_mapper( model_config) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7a1ea2530abda..120fc64969552 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -7,7 +7,8 @@ from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, + MultiModalRegistry) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -101,10 +102,15 @@ def process_inputs( self.generation_config_fields, eos_token_id) # Preprocess multi-modal data 
- mm_inputs = self.mm_input_mapper.process_inputs( - decoder_inputs.multi_modal_data, - decoder_inputs.mm_processor_kwargs) if len( - decoder_inputs.multi_modal_data) > 0 else None + if len(decoder_inputs.multi_modal_data) == 0: + mm_inputs = None + elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs): + mm_inputs = [decoder_inputs.multi_modal_data] + else: + mm_inputs = self.mm_input_mapper.process_inputs( + decoder_inputs.multi_modal_data, + decoder_inputs.mm_processor_kwargs, + ) # Make Request for Detokenizer. detokenizer_request = DetokenizerRequest( From f13cf9ad5049e386f766014877dee78d2f438799 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Sat, 7 Dec 2024 04:03:44 -0500 Subject: [PATCH 08/18] [Build] Fix for the Wswitch-bool clang warning (#10060) Signed-off-by: Gregory Shtrasberg --- csrc/attention/paged_attention_v1.cu | 11 ++++------- csrc/attention/paged_attention_v2.cu | 11 ++++------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 741cd0c82dc89..cb1a069942069 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -140,13 +140,10 @@ void paged_attention_v1_launcher( blocksparse_block_size, blocksparse_head_sliding_step); #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - switch (is_block_sparse) { \ - case true: \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - break; \ - case false: \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - break; \ + if (is_block_sparse) { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu 
index 6de8d0bdd5b8d..c457bdb89008e 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -147,13 +147,10 @@ void paged_attention_v2_launcher( blocksparse_head_sliding_step); #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - switch (is_block_sparse) { \ - case true: \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - break; \ - case false: \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - break; \ + if (is_block_sparse) { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes From b26b4cd03c5468c68c3ce328ea6498a5d816870d Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 7 Dec 2024 18:33:49 +0800 Subject: [PATCH 09/18] [Misc][LoRA] Refactor and clean MergedQKVParallelLinearWithLora implementation (#10958) Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/lora/layers.py | 323 ++++++++------------------------------------ 1 file changed, 60 insertions(+), 263 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 473e4bedf3d60..3e9c2ceb83eac 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -542,10 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): Both slices must have the same size. 
""" - def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: super().__init__(base_layer) # There are two LoRA layers - self.n_slices = len(self.base_layer.output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices def create_lora_weights( self, @@ -559,15 +569,6 @@ def create_lora_weights( """ self.lora_config = lora_config - if not (len(self.base_layer.output_sizes) == self.n_slices == 2 - and self.base_layer.output_sizes[0] - == self.base_layer.output_sizes[1]): - raise ValueError( - "LoRAColumnParallelLinear2Slice requires 2 slices with " - "the same size.") - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - lora_a_output_size_per_partition = ( lora_config.max_lora_rank if not lora_config.fully_sharded_loras else divide(lora_config.max_lora_rank, self.tp_size)) @@ -585,22 +586,20 @@ def create_lora_weights( torch.zeros( max_loras, 1, - self.output_size // 2, + output_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(self.n_slices)) + ) for output_size in self.output_slices) if lora_config.bias_enabled: self.lora_bias_stacked = tuple( torch.zeros( max_loras, 1, - self.output_size // 2, + output_size, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(self.n_slices)) - self.output_dim = self.lora_b_stacked[0].shape[2] - self.output_slices = (self.output_dim, self.output_dim) + ) for output_size 
in self.output_slices) def slice_lora_a( self, lora_a: List[Union[torch.Tensor, None]] @@ -610,27 +609,21 @@ def slice_lora_a( def slice_lora_b( self, lora_b: List[Union[torch.Tensor, None]] ) -> List[Union[torch.Tensor, None]]: - #NOTE: lora_b contains 2 subloras, and each sublora could be None. - shard_size = self.output_dim - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_b = [ - lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None, - lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None, - ] + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size * + (shard_id + 1)] return lora_b def slice_bias( self, bias: List[Union[torch.Tensor, None]]) -> List[Union[torch.Tensor, None]]: - # NOTE : each bias could be None. - shard_size = self.output_dim - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = [ - bias[0][start_idx:end_idx] if bias[0] is not None else None, - bias[1][start_idx:end_idx] if bias[1] is not None else None - ] + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] return bias def set_lora( @@ -649,30 +642,25 @@ def set_lora( if lora_bias is not None: lora_bias = self.slice_bias(lora_bias) - if lora_a[0] is not None: - self.lora_a_stacked[0][ - index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( - lora_a[0].T, non_blocking=True) - self.lora_b_stacked[0][ - index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( - lora_b[0].T, non_blocking=True) - if lora_bias is not None and lora_bias[0] is not None: - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], - self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, :lora_bias[0].shape[0]].copy_( - 
lora_bias[0].T, non_blocking=True) - if lora_a[1] is not None: - self.lora_a_stacked[1][ - index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( - lora_a[1].T, non_blocking=True) - self.lora_b_stacked[1][ - index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( - lora_b[1].T, non_blocking=True) - if lora_bias is not None and lora_bias[1] is not None: + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], self.lora_bias_stacked) - self.lora_bias_stacked[1][index, 0, :lora_bias[1].shape[0]].copy_( - lora_bias[1].T, non_blocking=True) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) @classmethod @_not_fully_sharded_can_replace @@ -755,8 +743,8 @@ def can_replace_layer(cls, source_layer: nn.Module, packed_modules_list) == 1 -class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): - """ColumnParallelLinear layer that is composed of 3 sublayers (slices) +class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj). 
@@ -773,22 +761,6 @@ def __init__(self, base_layer: QKVParallelLinear) -> None: self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - """ - The main reason for overloading this function is to handle inconsistent - weight dimensions in qkv lora. - """ - self.lora_config = lora_config - - if not (len(self.base_layer.output_sizes) == self.n_slices == 3): - raise ValueError( - "LoRAColumnParallelLinear3Slice requires 3 slices.") - self.q_proj_shard_size = (self.base_layer.num_heads * self.base_layer.head_size) self.kv_proj_shard_size = (self.base_layer.num_kv_heads * @@ -796,203 +768,28 @@ def create_lora_weights( self.q_shard_id = self.tp_rank self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - # q, k, v - self.lora_a_stacked = ( - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - ) - self.lora_b_stacked = ( - torch.zeros( - max_loras, - 1, - self.q_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - lora_config.max_lora_rank, - 
dtype=lora_config.lora_dtype, - device=self.device, - ), - ) - if lora_config.bias_enabled: - self.lora_bias_stacked = ( - torch.zeros( - max_loras, - 1, - self.q_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - ) self.output_slices = ( self.q_proj_shard_size, self.kv_proj_shard_size, self.kv_proj_shard_size, ) - self.packed_indices: Optional[torch.Tensor] = None - self.standard_indices: Optional[torch.Tensor] = None - # lazily initialized. - self.indices: torch.Tensor - self.indices_len: List[int] - - def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: - return lora_a - - def slice_lora_b( - self, lora_b: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: - lora_b_q, lora_b_k, lora_b_v = None, None, None - if lora_b[0] is not None: - lora_b_q = lora_b[0][:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1), ] - if lora_b[1] is not None: - lora_b_k = lora_b[1][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1), ] - if lora_b[2] is not None: - lora_b_v = lora_b[2][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1), ] - lora_b = [lora_b_q, lora_b_k, lora_b_v] - return lora_b - - def slice_bias( - self, bias: List[Union[torch.Tensor, - None]]) -> List[Union[torch.Tensor, None]]: - bias_q, bias_k, bias_v = bias - if bias_q is not None: - bias_q = bias_q[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - if bias_k is not None: - bias_k = bias_k[self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - if bias_v is 
not None: - bias_v = bias_v[self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - bias = [bias_q, bias_k, bias_v] - return bias + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) - def set_lora( + def create_lora_weights( self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - lora_bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if lora_bias is not None: - lora_bias = self.slice_bias(lora_bias) - - if lora_b[0] is not None: - lora_b_q = lora_b[0] - self.lora_b_stacked[0][ - index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( - lora_b_q.T, non_blocking=True) - if lora_b[1] is not None: - lora_b_k = lora_b[1] - self.lora_b_stacked[1][ - index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( - lora_b_k.T, non_blocking=True) - if lora_b[2] is not None: - lora_b_v = lora_b[2] - self.lora_b_stacked[2][ - index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( - lora_b_v.T, non_blocking=True) - - if lora_a[0] is not None: - self.lora_a_stacked[0][ - index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( - lora_a[0].T, non_blocking=True) - if lora_a[1] is not None: - self.lora_a_stacked[1][ - index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( - lora_a[1].T, non_blocking=True) - if lora_a[2] is not None: - self.lora_a_stacked[2][ - index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( - lora_a[2].T, non_blocking=True) - - if lora_bias is not None: - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], - self.lora_bias_stacked) - if lora_bias[0] is not None: - self.lora_bias_stacked[0][index, - 0, :lora_bias[0].shape[0]].copy_( - lora_bias[0].T, - non_blocking=True) - if lora_bias[1] is not None: - self.lora_bias_stacked[1][index, - 0, :lora_bias[1].shape[0]].copy_( - lora_bias[1].T, - 
non_blocking=True) - if lora_bias[2] is not None: - self.lora_bias_stacked[2][index, - 0, :lora_bias[2].shape[0]].copy_( - lora_bias[2].T, - non_blocking=True) + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) @classmethod @_not_fully_sharded_can_replace From bf0e382e16065edebbbb414f7889d31523a569e1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 7 Dec 2024 22:22:52 +0800 Subject: [PATCH 10/18] [Model] Composite weight loading for multimodal Qwen2 (#10944) Signed-off-by: DarkLight1337 --- vllm/config.py | 10 +- vllm/model_executor/model_loader/loader.py | 4 +- vllm/model_executor/model_loader/utils.py | 10 +- vllm/model_executor/models/qwen2.py | 17 +- vllm/model_executor/models/qwen2_audio.py | 117 ++++---------- vllm/model_executor/models/qwen2_vl.py | 179 ++++++++++----------- vllm/model_executor/models/utils.py | 15 +- 7 files changed, 147 insertions(+), 205 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index fe4c85441fced..db7046ab2c22d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2472,7 +2472,15 @@ def _get_quantization_config( return quant_config return None - def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig": + def with_hf_config( + self, + hf_config: PretrainedConfig, + architectures: Optional[list[str]] = None, + ) -> "VllmConfig": + if architectures is not None: + hf_config = copy.deepcopy(hf_config) + hf_config.architectures = architectures + model_config = copy.deepcopy(self.model_config) model_config.hf_config = hf_config diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index a0ea0e5fad3c2..fdc4c6305bd5e 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ 
-101,12 +101,10 @@ def _initialize_model( vllm_config: VllmConfig, *, prefix: str = "", - architectures: Optional[list[str]] = None, ) -> nn.Module: """Initialize a model with the given configurations.""" model_config = vllm_config.model_config - model_class, _ = get_model_architecture(model_config, - architectures=architectures) + model_class, _ = get_model_architecture(model_config) signatures = inspect.signature(model_class.__init__) all_params = [param.name for param in signatures.parameters.values()] diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 864dd04e79921..cfb89e0f336bc 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -1,6 +1,6 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Optional, Tuple, Type +from typing import Tuple, Type import torch from torch import nn @@ -20,12 +20,8 @@ def set_default_torch_dtype(dtype: torch.dtype): def get_model_architecture( - model_config: ModelConfig, - *, - architectures: Optional[list[str]] = None, -) -> Tuple[Type[nn.Module], str]: - if architectures is None: - architectures = getattr(model_config.hf_config, "architectures", []) + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. 
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7d4cc4b69e614..3ce4eb5869f21 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -444,14 +444,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index a0605fee82aca..48a2d470414b9 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" -from functools import lru_cache +from functools import cached_property, lru_cache from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -34,12 +34,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import NestedTensors @@ -47,15 +42,11 @@ from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal, SupportsPP -from .utils import merge_multimodal_embeddings +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} - # # === Audio Inputs === # class Qwen2AudioInputs(TypedDict): @@ -281,25 +272,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config - self.language_model = Qwen2Model( - vllm_config=vllm_config.with_hf_config(config.text_config), - prefix=prefix) - self.unpadded_vocab_size = config.text_config.vocab_size - if config.text_config.tie_word_embeddings: - self.lm_head = self.language_model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.text_config.vocab_size, - config.text_config.hidden_size, - 
quant_config=quant_config) - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.text_config.vocab_size, - logit_scale) - self.sampler = get_sampler() + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + def _validate_and_reshape_mm_tensor(self, mm_input: Union[torch.Tensor, List[torch.Tensor]], @@ -414,72 +403,30 @@ def forward( multimodal_embeddings) input_ids = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", 
"q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if (self.config.text_config.tie_word_embeddings - and "lm_head.weight" in name): - continue - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name or 'audio' in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 27175dbae7483..cfc90cdab01e4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from functools import partial +from functools import cached_property, partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, Type, TypedDict, Union) @@ -40,7 +40,7 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.distributed import get_pp_group, parallel_state +from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) @@ -49,15 +49,12 @@ from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, @@ -69,9 +66,8 @@ from vllm.transformers_utils.processor import cached_get_processor from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import (PPMissingLayer, get_vit_attn_backend, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, maybe_prefix) +from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, + init_vllm_registered_model, maybe_prefix) logger = 
init_logger(__name__) @@ -506,6 +502,8 @@ def __init__( mlp_ratio: float = vision_config.mlp_ratio self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim self.patch_embed = Qwen2VisionPatchEmbed( patch_size=patch_size, @@ -595,6 +593,53 @@ def forward( x = self.merger(x) return x + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith("qkv.weight"): + visual_num_heads = self.num_heads + visual_embed_dim = self.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size, + visual_embed_dim) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) + elif name.endswith("qkv.bias"): + visual_num_heads = self.num_heads + visual_embed_dim = self.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1) + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + # === Vision input helpers === # @@ -1082,27 +1127,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = 
""): prefix=maybe_prefix(prefix, "visual"), ) - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + return get_sampler() def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ @@ -1261,7 +1300,7 @@ def get_input_embeddings( multimodal_embeddings: Optional[List[Tuple[NestedTensors, str]]] = None, ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: for embeddings, modality in multimodal_embeddings: if modality == "image": @@ -1330,7 +1369,7 @@ def forward( multimodal_embeddings) input_ids = None - hidden_states = self.model( + hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, @@ -1340,80 +1379,28 @@ def forward( ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: 
- logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "up_proj", 1), - ("gate_up_proj", "gate_proj", 0), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if "visual" in name and name.endswith("qkv.weight"): - visual_num_heads = self.config.vision_config.num_heads - visual_embed_dim = self.config.vision_config.embed_dim - head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size, - visual_embed_dim) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) - elif "visual" in name and name.endswith("qkv.bias"): - visual_num_heads = self.config.vision_config.num_heads - visual_embed_dim = self.config.vision_config.embed_dim - head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1) - try: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - except KeyError: - raise ValueError(f"Unexpected weight: {name}") from None - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 7a1e1f9bf2be4..5ec44955dbd80 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -17,7 +17,7 @@ from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import is_pin_memory_available +from vllm.utils import is_pin_memory_available, print_warning_once logger = init_logger(__name__) @@ -251,12 +251,15 @@ def init_vllm_registered_model( """ from vllm.model_executor.model_loader.loader import _initialize_model + if hf_config is None and architectures is not None: + # So that the architectures field is overridden + hf_config = vllm_config.model_config.hf_config + if hf_config is not None: - vllm_config = vllm_config.with_hf_config(hf_config) + vllm_config = vllm_config.with_hf_config(hf_config, + architectures=architectures) - return _initialize_model(vllm_config=vllm_config, - prefix=prefix, - architectures=architectures) + return _initialize_model(vllm_config=vllm_config, prefix=prefix) @overload @@ -592,7 +595,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend: if is_flash_attn_2_available(): selected_backend = _Backend.FLASH_ATTN else: - logger.warning( + 
print_warning_once( "Current `vllm-flash-attn` has a bug inside vision module, " "so we use xformers backend instead. You can run " "`pip install flash-attn` to use flash-attention backend.") From 1c768fe53713ef333d74a6645e6a59fb7516134f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 8 Dec 2024 00:58:02 +0800 Subject: [PATCH 11/18] [Doc] Explicitly state that InternVL 2.5 is supported (#10978) Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.rst | 4 ++-- examples/offline_inference_vision_language.py | 2 +- examples/offline_inference_vision_language_multi_image.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 5b416e04da745..d915def588e08 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -547,9 +547,9 @@ Text Generation - ✅︎ - * - :code:`InternVLChatModel` - - InternVL2 + - InternVL 2.5, Mono-InternVL, InternVL 2.0 - T + I\ :sup:`E+` - - :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. + - :code:`OpenGVLab/InternVL2_5-4B`, :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, etc. 
- - ✅︎ * - :code:`LlavaForConditionalGeneration` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index f08f22eec164a..56209c3c36ed4 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -223,7 +223,7 @@ def run_internvl(question: str, modality: str): # Stop tokens for InternVL # models variants may have different stop tokens # please refer to the model card for the correct "stop words": - # https://huggingface.co/OpenGVLab/InternVL2-2B#service + # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] return llm, prompt, stop_token_ids diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 788b604cfd4a0..928bbef54eab7 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -165,7 +165,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: # Stop tokens for InternVL # models variants may have different stop tokens # please refer to the model card for the correct "stop words": - # https://huggingface.co/OpenGVLab/InternVL2-2B#service + # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] From 39e227c7ae3149eb8345ea1a1ffee672ef76c09a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 8 Dec 2024 01:10:05 +0800 Subject: [PATCH 12/18] [Model] Update multi-modal processor to support Mantis(LLaVA) model (#10711) Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 2 + docs/source/models/supported_models.rst | 6 +- 
examples/offline_inference_vision_language.py | 17 +++++ requirements-test.in | 3 - .../vision_language/test_models.py | 30 +++++--- .../vision_language/vlm_utils/core.py | 20 ++++-- .../vision_language/vlm_utils/model_utils.py | 35 +++++++++- .../vision_language/vlm_utils/types.py | 19 ++++-- tests/models/registry.py | 1 + .../vllm_add_dummy_model/my_llava.py | 6 +- vllm/model_executor/models/llava.py | 68 ++++++++++++++++--- vllm/model_executor/models/registry.py | 1 + vllm/multimodal/processing.py | 4 +- vllm/multimodal/registry.py | 41 +---------- 14 files changed, 175 insertions(+), 78 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 936e284d9675a..8f57006214c88 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -362,6 +362,7 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s models/embedding/vision_language -m core_model @@ -377,6 +378,7 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index d915def588e08..c9b3fa8485ff1 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -555,7 +555,7 @@ Text Generation * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - 
T + I\ :sup:`E+` - - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. + - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - - ✅︎ * - :code:`LlavaNextForConditionalGeneration` @@ -664,6 +664,10 @@ Text Generation .. note:: vLLM currently only supports adding LoRA to the language backbone of multimodal models. +.. note:: + To use :code:`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo (:code:`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`) + and pass :code:`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. + .. note:: The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 56209c3c36ed4..c6a274ee5894b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -419,6 +419,22 @@ def run_aria(question: str, modality: str): return llm, prompt, stop_token_ids +# Mantis +def run_mantis(question: str, modality: str): + assert modality == "image" + + llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501 + prompt = llama3_template.format(f"{question}\n") + + llm = LLM( + model="TIGER-Lab/Mantis-8B-siglip-llama3", + max_model_len=4096, + hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, + ) + stop_token_ids = [128009] + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -441,6 +457,7 @@ def run_aria(question: str, modality: str): "glm4v": run_glm4v, "idefics3": run_idefics3, "aria": run_aria, + "mantis": run_mantis, } diff --git 
a/requirements-test.in b/requirements-test.in index 44972866ddc4b..c0b228148ab31 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -24,9 +24,6 @@ mistral_common[opencv] >= 1.5.0 # required for pixtral test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.4 # required for model evaluation test -# TODO: Add this after fully implementing llava(mantis) -# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test - # quantization bitsandbytes>=0.44.0 buildkite-test-collector==0.1.9 diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 924f19c4448b8..ed8f34a677f84 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -34,7 +34,7 @@ "dtype": "half", "max_tokens": 5, "tensor_parallel_size": 2, - "model_kwargs": {"device_map": "auto"}, + "hf_model_kwargs": {"device_map": "auto"}, "image_size_factors": [(.25, 0.5, 1.0)], "distributed_executor_backend": ( "ray", @@ -108,7 +108,7 @@ "cherry_blossom": "What is in the picture?", }), auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, @@ -151,7 +151,7 @@ "cherry_blossom": "Please infer the season with reason.", }), multi_image_prompt="Describe the two images shortly.", # noqa: E501 - postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"), + postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"), stop_str=["<|im_end|>"], image_size_factors=[(0.10, 0.15)], max_tokens=64, @@ -177,7 +177,7 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + 
postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), # For chameleon, we only compare the sequences @@ -281,7 +281,7 @@ prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 num_video_frames=16, max_model_len=16384, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values_videos" ), auto_cls=AutoModelForVision2Seq, @@ -306,6 +306,20 @@ vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))], ), + "mantis": VLMTestInfo( + models=["TIGER-Lab/Mantis-8B-siglip-llama3"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + max_model_len=4096, + postprocess_inputs=model_utils.cast_dtype_post_processor( + "pixel_values" + ), + vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501 + get_stop_token_ids=lambda tok: [128009], + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output, + patch_hf_runner=model_utils.mantis_patch_hf_runner, + ), "minicpmv_25": VLMTestInfo( models=["openbmb/MiniCPM-Llama3-V-2_5"], test_type=VLMTestType.IMAGE, @@ -342,7 +356,7 @@ # max_num_seqs=2, # task="generate", # # use eager mode for hf runner since phi3v didn't work with flash_attn - # model_kwargs={"_attn_implementation": "eager"}, + # hf_model_kwargs={"_attn_implementation": "eager"}, # use_tokenizer_eos=True, # vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output, # num_logprobs=10, @@ -373,7 +387,7 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( 
+ postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], @@ -438,7 +452,7 @@ test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=16384, max_num_seqs=2, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), auto_cls=AutoModelForVision2Seq, diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index 88349ef9a3a69..54b7b0733210f 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -3,9 +3,11 @@ import torch from PIL.Image import Image -from transformers import AutoTokenizer, BatchEncoding +from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase from transformers.models.auto.auto_factory import _BaseAutoModelClass +from vllm.config import TaskOption + from .....conftest import HfRunner, VllmRunner from .types import RunnerOutput @@ -28,13 +30,15 @@ def run_test( use_tokenizer_eos: bool, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], comparator: Callable[..., None], - get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]], + get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], + List[int]]], stop_str: Optional[List[str]], tokenizer_mode: str, limit_mm_per_prompt: Dict[str, int], - model_kwargs: Optional[Dict[str, Any]], + vllm_runner_kwargs: Optional[Dict[str, Any]], + hf_model_kwargs: Optional[Dict[str, Any]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], - task: str = "auto", + task: TaskOption = "auto", runner_mm_key: str = "images", distributed_executor_backend: Optional[str] = None, tensor_parallel_size: int = 1, @@ -58,6 +62,9 @@ def run_test( if stop_str: vllm_kwargs["stop"] = stop_str + if vllm_runner_kwargs is None: + vllm_runner_kwargs = {} + with 
vllm_runner(model, tokenizer_mode=tokenizer_mode, max_model_len=max_model_len, @@ -67,7 +74,8 @@ def run_test( tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=enforce_eager, - task=task) as vllm_model: + task=task, + **vllm_runner_kwargs) as vllm_model: for prompts, media in vllm_inputs: vllm_kwargs[runner_mm_key] = media vllm_output = vllm_model.generate_greedy_logprobs( @@ -78,7 +86,7 @@ def run_test( dtype=dtype, auto_cls=auto_cls, postprocess_inputs=postprocess_inputs, - model_kwargs=model_kwargs) + model_kwargs=hf_model_kwargs) # Some models need to patch things like the model processor, e.g., internvl if patch_hf_runner is not None: diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 15f15dd7d8030..3eca8fb9dcb1a 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs +def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, + model: str) -> RunnerOutput: + """Sanitize vllm output [mantis] to compare with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|eot_id|>" + + return output_ids, hf_output_str, out_logprobs + + def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [phi3v] to be comparable with hf output.""" @@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets): ####### postprocessors to run on HF BatchEncoding -def get_key_type_post_processor( +def cast_dtype_post_processor( hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: """Gets a handle to a post processor which converts a given key into a target data type.""" @@ 
-418,3 +428,26 @@ def _internvl_generate( ) return outputs + + +def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + from mantis.models.mllava import MLlavaProcessor + + hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name) + + orig_generate = hf_model.model.generate + tokenizer = hf_model.processor.tokenizer + + def _generate(self, *args, **kwargs): + return orig_generate( + *args, + **kwargs, + eos_token_id=[ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids("<|eot_id|>"), + ], + ) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + + return hf_model diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index d410fa8c653ce..e2e0c6390fcb9 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -7,9 +7,11 @@ import torch from PIL.Image import Image from pytest import MarkDecorator -from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding +from transformers import (AutoModelForCausalLM, BatchEncoding, + PreTrainedTokenizerBase) from transformers.models.auto.auto_factory import _BaseAutoModelClass +from vllm.config import TaskOption from vllm.sequence import SampleLogprobs from vllm.utils import identity @@ -66,7 +68,7 @@ class ImageSizeWrapper(NamedTuple): class VLMTestInfo(NamedTuple): """Holds the configuration for 1+ tests for one model architecture.""" - models: Union[List[str]] + models: List[str] test_type: Union[VLMTestType, Iterable[VLMTestType]] # Should be None only if this is a CUSTOM_INPUTS test @@ -92,18 +94,20 @@ class VLMTestInfo(NamedTuple): enforce_eager: bool = True max_model_len: int = 1024 max_num_seqs: int = 256 - task: str = "auto" + task: TaskOption = "auto" tensor_parallel_size: int = 1 + vllm_runner_kwargs: Optional[Dict[str, Any]] = None # Optional callable which gets a list of token IDs from 
the model tokenizer - get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None + get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], + List[int]]] = None # Optional list of strings to stop generation, useful when stop tokens are # not special tokens in the tokenizer stop_str: Optional[List[str]] = None # Exposed options for HF runner - model_kwargs: Optional[Dict[str, Any]] = None - # Indicates we should explicitly pass the EOS from the tokeniezr + hf_model_kwargs: Optional[Dict[str, Any]] = None + # Indicates we should explicitly pass the EOS from the tokenizer use_tokenizer_eos: bool = False auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM # Callable to pass to the HF runner to run on inputs; for now, we also pass @@ -164,6 +168,7 @@ def get_non_parametrized_runner_kwargs(self): "max_num_seqs": self.max_num_seqs, "task": self.task, "tensor_parallel_size": self.tensor_parallel_size, + "vllm_runner_kwargs": self.vllm_runner_kwargs, "hf_output_post_proc": self.hf_output_post_proc, "vllm_output_post_proc": self.vllm_output_post_proc, "auto_cls": self.auto_cls, @@ -171,8 +176,8 @@ def get_non_parametrized_runner_kwargs(self): "postprocess_inputs": self.postprocess_inputs, "comparator": self.comparator, "get_stop_token_ids": self.get_stop_token_ids, + "hf_model_kwargs": self.hf_model_kwargs, "stop_str": self.stop_str, - "model_kwargs": self.model_kwargs, "patch_hf_runner": self.patch_hf_runner, "tokenizer_mode": self.tokenizer_mode } diff --git a/tests/models/registry.py b/tests/models/registry.py index 461f453d8b1c3..a89518820045f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -176,6 +176,7 @@ class _HfExamplesInfo: "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": 
_HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501 "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True), "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index f2fc0755cae01..2f4194a63fc25 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -3,16 +3,14 @@ import torch from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - create_metadata_for_llava, - dummy_mm_kwargs_for_llava, + LlavaProcessor, get_max_llava_image_tokens) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava, - dummy_mm_kwargs_for_llava) +@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 953b89f1842af..65c6bd07bfff0 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -22,10 +22,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.processing import (InputProcessingContext, +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, ModalityProcessingMetadata, MultiModalProcessingMetadata, - MultiModalProcessor, PromptReplacement) + PromptReplacement) from 
vllm.sequence import IntermediateTensors from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -163,7 +164,13 @@ def get_repl_count( } -class LlavaProcessor(MultiModalProcessor): +class LlavaProcessor(BaseMultiModalProcessor): + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__( + ctx=ctx, + metadata=create_metadata_for_llava(ctx), + ) def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): if getattr(hf_processor, "__is_patched__", False): @@ -193,7 +200,30 @@ def _get_dummy_mm_kwargs( self, mm_counts: Mapping[str, int], ) -> MultiModalKwargs: - return dummy_mm_kwargs_for_llava(self.ctx, mm_counts) + hf_config = self.ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + if isinstance(vision_config, CLIPVisionConfig): + data = dummy_image_for_clip(vision_config, num_images) + elif isinstance(vision_config, SiglipVisionConfig): + data = dummy_image_for_siglip(vision_config, num_images) + elif isinstance(vision_config, PixtralVisionConfig): + data = dummy_image_for_pixtral_hf(vision_config, num_images) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + hf_processor = self._get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], + return_tensors="pt") + is_pixtral = isinstance(hf_processor, PixtralProcessor) + + return MultiModalKwargs( + **hf_inputs, + is_pixtral=torch.tensor(is_pixtral), + ) class LlavaLikeConfig(Protocol): @@ -277,10 +307,7 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor( - ctx=ctx, - metadata=create_metadata_for_llava(ctx), -)) +@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes 
specific attributes bitsandbytes_stacked_params_mapping = { @@ -559,3 +586,28 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class MantisProcessor(LlavaProcessor): + + def _get_hf_processor(self) -> ProcessorMixin: + try: + from mantis.models.mllava import MLlavaProcessor + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "You need to `pip install " + "git+https://github.com/TIGER-AI-Lab/Mantis.git` " + "to use this model") from exc + + processor = MLlavaProcessor.from_pretrained( + self.ctx.model_config.tokenizer) + assert isinstance(processor, ProcessorMixin) + return processor + + +# To use this model, please use +# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) +@MULTIMODAL_REGISTRY.register_processor(MantisProcessor) +class MantisForConditionalGeneration(LlavaForConditionalGeneration): + pass diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c66fbce018a62..e69596aa915b5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -152,6 +152,7 @@ "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 + "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 "MiniCPMV": ("minicpmv", "MiniCPMV"), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 4a1737991534f..c3a95d60e6fe6 100644 --- a/vllm/multimodal/processing.py 
+++ b/vllm/multimodal/processing.py @@ -529,9 +529,9 @@ def iter_placeholders( yield placeholder -class MultiModalProcessor(ABC): +class BaseMultiModalProcessor(ABC): """ - Helper class to process multi-modal inputs to be used in vLLM. + Abstract base class to process multi-modal inputs to be used in vLLM. """ def __init__( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index f51da8972d15b..6ab6c0fe2f12e 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -15,7 +15,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import MultiModalProcessingMetadata, MultiModalProcessor +from .processing import BaseMultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -26,7 +26,7 @@ N = TypeVar("N", bound=Type[nn.Module]) MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], - MultiModalProcessor] + BaseMultiModalProcessor] """ Constructs a :class:`MultiModalProcessor` instance from the context. @@ -311,41 +311,6 @@ def wrapper(model_cls: N) -> N: return wrapper - def register_processor_by_metadata( - self, - metadata_factory: Callable[[InputProcessingContext], - MultiModalProcessingMetadata], - get_dummy_mm_kwargs: Callable[ - [InputProcessingContext, Mapping[str, int]], MultiModalKwargs], - ): - """ - Convenience method to register a multi-modal processor to a model class - according to a function that constructs its metadata. - - When the model receives multi-modal data, the provided function is - invoked to transform the data into a dictionary of model inputs. 
- - See also: - - :ref:`input_processing_pipeline` - - :ref:`enabling_multimodal_inputs` - """ - - class ConcreteMultiModalProcessor(MultiModalProcessor): - - def _get_dummy_mm_kwargs( - self, - mm_counts: Mapping[str, int], - ) -> MultiModalKwargs: - return get_dummy_mm_kwargs(self.ctx, mm_counts) - - def factory(ctx: InputProcessingContext): - return ConcreteMultiModalProcessor( - ctx=ctx, - metadata=metadata_factory(ctx), - ) - - return self.register_processor(factory) - def has_processor(self, model_config: "ModelConfig") -> bool: """ Test whether a multi-modal processor is defined for a specific model. @@ -360,7 +325,7 @@ def create_processor( self, model_config: "ModelConfig", tokenizer: AnyTokenizer, - ) -> MultiModalProcessor: + ) -> BaseMultiModalProcessor: """ Create a multi-modal processor for a specific model and tokenizer. """ From c889d5888bf6bbfbe3f4ea55bf27ce84a239c3d0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 8 Dec 2024 01:20:49 +0800 Subject: [PATCH 13/18] [Doc] Explicitly state that PP isn't compatible with speculative decoding yet (#10975) Signed-off-by: DarkLight1337 --- docs/source/usage/spec_decode.rst | 3 +++ tests/distributed/test_pipeline_parallel.py | 16 +++++++++++++--- vllm/model_executor/models/exaone.py | 3 ++- vllm/model_executor/models/granite.py | 5 +++-- vllm/model_executor/models/llama.py | 3 ++- vllm/model_executor/models/nemotron.py | 4 +++- vllm/model_executor/models/solar.py | 3 ++- vllm/spec_decode/spec_decode_worker.py | 4 ++++ 8 files changed, 32 insertions(+), 9 deletions(-) diff --git a/docs/source/usage/spec_decode.rst b/docs/source/usage/spec_decode.rst index 67e8ede7654b7..f1f1917f974bb 100644 --- a/docs/source/usage/spec_decode.rst +++ b/docs/source/usage/spec_decode.rst @@ -8,6 +8,9 @@ Speculative decoding not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. The work to optimize it is ongoing and can be followed in `this issue. `_ +.. 
warning:: + Currently, speculative decoding in vLLM is not compatible with pipeline parallelism. + This document shows how to use `Speculative Decoding `_ with vLLM. Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference. diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 386877e0e0a2c..b818ca921fcb0 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -247,9 +247,19 @@ def _compare_tp( *, method: Literal["generate", "encode"], ): - tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup - multi_node_only, trust_remote_code, tokenizer_mode, \ - load_format, hf_overrides = test_options + ( + tp_size, + pp_size, + eager_mode, + chunked_prefill, + ) = parallel_setup + ( + multi_node_only, + trust_remote_code, + tokenizer_mode, + load_format, + hf_overrides, + ) = test_options if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 5ca26d53a17e7..0398f0943a70a 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -473,10 +473,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index bd2394e71c973..f9e0443b9a508 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -400,16 +400,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = 
self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - if hasattr(config, "logits_scaling"): logit_scale /= config.logits_scaling + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, scale=logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 31dfb235ae877..733b1bc7d80ac 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -540,10 +540,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index c7b4c22b6896b..34cb9981c167b 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -435,9 +435,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index f58710d215056..caae0b65d7d10 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -443,10 +443,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = 
LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index ced7f53827665..2689802161987 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -54,6 +54,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": speculative_config: SpeculativeConfig = vllm_config.speculative_config assert speculative_config is not None + if vllm_config.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError("Speculative decoding is currently " + "incompatible with pipeline parallelism") + draft_worker_kwargs = kwargs.copy() kwargs["model_runner_cls"] = TargetModelRunner From 78029b34ed1be46baf06f92c9e971ea1961d0867 Mon Sep 17 00:00:00 2001 From: zhou fan <1247714429@qq.com> Date: Sun, 8 Dec 2024 01:21:18 +0800 Subject: [PATCH 14/18] [BugFix][Kernel]: fix illegal memory access in causal_conv1d when conv_states is None (#10928) Signed-off-by: xffxff <1247714429@qq.com> --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 2 +- tests/kernels/test_causal_conv1d.py | 39 +++++++++++++---------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 498d069c05f0d..dd1e6de2e0180 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -424,7 +424,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), // (which occurs when `final_state_position` is a non-positivie index) // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it - if (final_state_position 
< 0 && seqlen > kWidth){ + if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){ input_t vals_load[kNElts] = {0}; if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ // chunk = n_chunks - 2, a segment of the final state sits in the last index diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index f9b11018288be..51be2425d7dd7 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -149,13 +149,14 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) +@pytest.mark.parametrize("has_initial_state", [True, False]) @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize( 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]) @pytest.mark.parametrize('dim', [64]) @pytest.mark.parametrize('batch', [1]) def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, - itype): + has_initial_state, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -167,11 +168,18 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None - initial_states = torch.randn(batch, - dim, - width - 1, - device=device, - dtype=itype) + if has_initial_state: + initial_states = torch.randn(batch, + dim, + width - 1, + device=device, + dtype=itype) + has_initial_state_tensor = torch.ones(batch, + dtype=torch.bool, + device=x.device) + else: + initial_states = None + has_initial_state_tensor = None x_ref = x.clone() weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None @@ -183,9 +191,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, 
silu_activation, bias, activation=activation, conv_states=initial_states, - has_initial_state=torch.ones(batch, - dtype=torch.bool, - device=x.device)) + has_initial_state=has_initial_state_tensor) out_ref, final_states_ref = causal_conv1d_ref( x_ref, weight_ref, @@ -193,11 +199,12 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, initial_states=initial_states_ref, return_final_states=True, activation=activation) - assert initial_states is not None and final_states_ref is not None - assert torch.allclose(initial_states, - final_states_ref, - rtol=rtol, - atol=atol) + if has_initial_state: + assert initial_states is not None and final_states_ref is not None + assert torch.allclose(initial_states, + final_states_ref, + rtol=rtol, + atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) causal_conv1d_opcheck_fn(x, @@ -205,9 +212,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, bias, activation=activation, conv_states=initial_states, - has_initial_state=torch.ones(batch, - dtype=torch.bool, - device=x.device)) + has_initial_state=has_initial_state_tensor) @pytest.mark.parametrize("itype", [torch.bfloat16]) From 1b62745b1d00153c5e99879edaf0c2d7ceb4e2c6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 7 Dec 2024 09:33:45 -0800 Subject: [PATCH 15/18] [core][executor] simplify instance id (#10976) Signed-off-by: youkaichao --- vllm/config.py | 7 ++++++- vllm/envs.py | 6 ------ vllm/executor/cpu_executor.py | 6 +----- vllm/executor/multiproc_gpu_executor.py | 5 +---- vllm/executor/ray_gpu_executor.py | 7 +------ vllm/executor/ray_hpu_executor.py | 7 +------ vllm/executor/ray_tpu_executor.py | 6 +----- vllm/executor/ray_xpu_executor.py | 6 +----- vllm/utils.py | 25 +++++++++---------------- vllm/worker/worker_base.py | 2 +- 10 files changed, 22 insertions(+), 55 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index db7046ab2c22d..d1c4f995ad015 100644 --- a/vllm/config.py +++ 
b/vllm/config.py @@ -27,7 +27,8 @@ get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - print_warning_once, resolve_obj_by_qualname) + print_warning_once, random_uuid, + resolve_obj_by_qualname) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -2408,6 +2409,7 @@ class VllmConfig: init=True) # type: ignore kv_transfer_config: KVTransferConfig = field(default=None, init=True) # type: ignore + instance_id: str = "" @staticmethod def get_graph_batch_size(batch_size: int) -> int: @@ -2573,6 +2575,9 @@ def __post_init__(self): current_platform.check_and_update_config(self) + if not self.instance_id: + self.instance_id = random_uuid()[:5] + def __str__(self): return ("model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " diff --git a/vllm/envs.py b/vllm/envs.py index 28797ac1e4af2..ab12a7b48dc53 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -8,7 +8,6 @@ VLLM_RPC_BASE_PATH: str = tempfile.gettempdir() VLLM_USE_MODELSCOPE: bool = False VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 - VLLM_INSTANCE_ID: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = False @@ -175,11 +174,6 @@ def get_default_config_root(): "VLLM_USE_MODELSCOPE": lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", - # Instance id represents an instance of the VLLM. All processes in the same - # instance should have the same instance id. 
- "VLLM_INSTANCE_ID": - lambda: os.environ.get("VLLM_INSTANCE_ID", None), - # Interval in seconds to log a warning message when the ring buffer is full "VLLM_RINGBUFFER_WARNING_INTERVAL": lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")), diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 6b4cb5a9a1d61..2816b5c5c1f88 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -10,8 +10,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_open_port, - get_vllm_instance_id, make_async) +from vllm.utils import get_distributed_init_method, get_open_port, make_async from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -31,9 +30,6 @@ def _init_executor(self) -> None: # Environment variables for CPU executor # - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() - # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index a6c05a71d2b6f..c450209f0eb91 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -16,7 +16,7 @@ from vllm.triton_utils.importing import HAS_TRITON from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, cuda_is_initialized, get_distributed_init_method, - get_open_port, get_vllm_instance_id, make_async, + get_open_port, make_async, update_environment_variables) if HAS_TRITON: @@ -37,9 +37,6 @@ def _init_executor(self) -> None: world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size - # Ensure that VLLM_INSTANCE_ID is set, 
to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() - # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6542b18ae70b1..6554cda6b637b 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -15,8 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, get_vllm_instance_id, - make_async) + get_ip, get_open_port, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -220,14 +219,10 @@ def sort_by_driver_then_worker_ip(worker): " environment variable, make sure it is unique for" " each node.") - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. 
all_args_to_update_environment_variables = [({ "CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id])), - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), **({ diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index a74328e5aa272..91c84d9214a60 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -15,8 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, get_vllm_instance_id, - make_async) + get_ip, get_open_port, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -196,12 +195,8 @@ def sort_by_driver_then_worker_ip(worker): "environment variable, make sure it is unique for" " each node.") - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. 
all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for (node_id, _) in worker_node_and_gpu_ids] diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index c227b5e283c68..3ee59397bf4c9 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - get_vllm_instance_id, make_async) + make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -144,12 +144,8 @@ def sort_by_driver_then_worker_ip(worker): for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): node_workers[node_id].append(i) - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for _ in worker_node_and_gpu_ids] diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 2b1cdc09b0a9f..61f5d6a65e999 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -5,7 +5,7 @@ from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync from vllm.executor.xpu_executor import XPUExecutor from vllm.logger import init_logger -from vllm.utils import get_vllm_instance_id, make_async +from vllm.utils import make_async logger = init_logger(__name__) @@ -17,12 +17,8 @@ def _get_env_vars_to_be_updated(self): worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. 
all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for (_, _) in worker_node_and_gpu_ids] diff --git a/vllm/utils.py b/vllm/utils.py index 6cee4847e57b4..1f19d9eacd16d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -24,9 +24,9 @@ from collections.abc import Iterable, Mapping from functools import lru_cache, partial, wraps from platform import uname -from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, - Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, - Type, TypeVar, Union, overload) +from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, + Dict, Generic, Hashable, List, Literal, Optional, + OrderedDict, Set, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 import numpy as np @@ -43,6 +43,9 @@ from vllm.logger import enable_trace_function_call, init_logger from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.config import VllmConfig + logger = init_logger(__name__) # Exception strings for non-implemented encoder/decoder scenarios @@ -335,17 +338,6 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) -@lru_cache(maxsize=None) -def get_vllm_instance_id() -> str: - """ - If the environment variable VLLM_INSTANCE_ID is set, return it. - Otherwise, return a random UUID. - Instance id represents an instance of the VLLM. All processes in the same - instance should have the same instance id. 
- """ - return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}" - - @lru_cache(maxsize=None) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 @@ -997,7 +989,7 @@ def find_nccl_library() -> str: return so_file -def enable_trace_function_call_for_thread() -> None: +def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None: """Set up function tracing for the current thread, if enabled via the VLLM_TRACE_FUNCTION environment variable """ @@ -1009,7 +1001,8 @@ def enable_trace_function_call_for_thread() -> None: filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" f"_thread_{threading.get_ident()}_" f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + log_path = os.path.join(tmp_dir, "vllm", + f"vllm-instance-{vllm_config.instance_id}", filename) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7c0bc5a678956..6d00102e0a324 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -439,7 +439,7 @@ def init_worker(self, *args, **kwargs): Here we inject some common logic before initializing the worker. Arguments are passed to the worker class constructor. 
""" - enable_trace_function_call_for_thread() + enable_trace_function_call_for_thread(self.vllm_config) # see https://github.com/NVIDIA/nccl/issues/1234 os.environ['NCCL_CUMEM_ENABLE'] = '0' From 7be15d9356a10c6ae3537565548e4f8bf46f35dd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 7 Dec 2024 12:06:08 -0800 Subject: [PATCH 16/18] [core][misc] remove use_dummy driver for _run_workers (#10920) Signed-off-by: youkaichao --- vllm/executor/ray_gpu_executor.py | 27 ++++++++++++--------------- vllm/executor/ray_hpu_executor.py | 28 ++++++++++++---------------- vllm/executor/ray_tpu_executor.py | 21 ++++++++++----------- vllm/executor/ray_xpu_executor.py | 11 +++++++++-- 4 files changed, 43 insertions(+), 44 deletions(-) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6554cda6b637b..4263fb27265f6 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -188,8 +188,14 @@ def sort_by_driver_then_worker_ip(worker): self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -329,7 +335,6 @@ def _run_workers( async_run_tensor_parallel_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: @@ -389,18 +394,10 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = [ - self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - ] - else: - assert self.driver_dummy_worker is not None - driver_worker_output = [ - ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) - ] + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] # Get the results of the ray workers. if self.workers: diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index 91c84d9214a60..f3025cb537ab8 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -163,9 +163,14 @@ def sort_by_driver_then_worker_ip(worker): # node will be placed first. self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) - # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -296,7 +301,6 @@ def _run_workers( async_run_tensor_parallel_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: @@ -356,18 +360,10 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = [ - self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - ] - else: - assert self.driver_dummy_worker is not None - driver_worker_output = [ - ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) - ] + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] # Get the results of the ray workers. if self.workers: diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 3ee59397bf4c9..5118c13934f0d 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -137,8 +137,14 @@ def sort_by_driver_then_worker_ip(worker): self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) # Get the set of TPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): @@ -199,7 +205,6 @@ def _run_workers( async_run_remote_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, @@ -241,14 +246,8 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) # Get the results of the ray workers. if self.workers: ray_worker_outputs = ray.get(ray_worker_outputs) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 61f5d6a65e999..d2086f5fef26c 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -1,6 +1,8 @@ import asyncio from typing import List, Optional +import ray + import vllm.envs as envs from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync from vllm.executor.xpu_executor import XPUExecutor @@ -14,8 +16,13 @@ class RayXPUExecutor(RayGPUExecutor, XPUExecutor): def _get_env_vars_to_be_updated(self): # Get the set of GPU IDs used on each node. 
- worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote())) # type: ignore # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ From fd57d2b5347e8fe6da9287553d4b5a3aaf2e6693 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 8 Dec 2024 03:05:21 -0800 Subject: [PATCH 17/18] [torch.compile] allow candidate compile sizes (#10984) Signed-off-by: youkaichao --- tests/engine/test_arg_utils.py | 8 +++---- vllm/config.py | 44 +++++++++++++++++----------------- vllm/engine/arg_utils.py | 5 +--- vllm/entrypoints/llm.py | 6 +---- 4 files changed, 28 insertions(+), 35 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index de78d41ad12eb..4e269de9fc40b 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -50,12 +50,12 @@ def test_compilation_config(): args = parser.parse_args(["-O=3"]) assert args.compilation_config.level == 3 - # set to json - args = parser.parse_args(["--compilation-config", '{"level": 3}']) + # set to string form of a dict + args = parser.parse_args(["--compilation-config", "{'level': 3}"]) assert args.compilation_config.level == 3 - # set to json - args = parser.parse_args(['--compilation-config={"level": 3}']) + # set to string form of a dict + args = parser.parse_args(["--compilation-config={'level': 3}"]) assert args.compilation_config.level == 3 diff --git a/vllm/config.py b/vllm/config.py index d1c4f995ad015..164622b5af34e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,3 +1,4 @@ +import ast import copy import enum import hashlib @@ -2191,14 +2192,10 @@ class CompilationConfig(BaseModel): - use_inductor: whether to use 
inductor compilation. - False: inductor compilation is not used. graph runs in eager. - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for different sizes specified - in inductor_compile_sizes, using configurations + is compiled. In addition, compile for cudagraph sizes that are + in candidate_compile_sizes, using configurations in inductor_compile_config. - - inductor_compile_sizes: sizes to compile for inductor. - - inductor_specialize_for_cudagraph_no_more_than: an optional integer - to specialize inductor for cudagraph sizes no more than the - specified size. It is useful when we want to specialize inductor - with a subset of cudagraph sizes. + - candidate_compile_sizes: sizes to compile for inductor. - inductor_compile_config: additional configurations for inductor. - None: use default configurations. - inductor_passes: additional passes for inductor. It is a dictionary @@ -2227,8 +2224,7 @@ class CompilationConfig(BaseModel): ]) use_inductor: bool = True - inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default=None) + candidate_compile_sizes: Optional[List[int]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) @@ -2294,7 +2290,9 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) - return CompilationConfig.model_validate_json(cli_value) + # do not use `eval`, it is dangerous and can execute arbitrary code + dict_value = ast.literal_eval(cli_value) + return CompilationConfig.model_validate(dict_value) def model_post_init(self, __context: Any) -> None: @@ -2355,18 +2353,20 @@ def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): logger.info(("cudagraph sizes specified by model runner" " %s is 
overridden by config %s"), sizes_to_specialize, self.cudagraph_capture_sizes) - if self.inductor_specialize_for_cudagraph_no_more_than is not None: - assert self.inductor_compile_sizes is None, ( - "inductor_compile_sizes should be None when " - "inductor_specialize_for_cudagraph_no_more_than is not None") - self.compile_sizes = [ - x for x in self.capture_sizes - if x <= self.inductor_specialize_for_cudagraph_no_more_than - ] - else: - if self.inductor_compile_sizes is None: - self.inductor_compile_sizes = [] - self.compile_sizes = self.inductor_compile_sizes + + if self.candidate_compile_sizes is None: + self.candidate_compile_sizes = [] + self.compile_sizes = [ + x for x in self.candidate_compile_sizes if x in self.capture_sizes + ] + ignored_sizes = [ + x for x in self.candidate_compile_sizes + if x not in self.capture_sizes + ] + if ignored_sizes: + logger.warning(("candidate_compile_sizes %s are ignored " + "because they are not cudagraph capture sizes."), + ignored_sizes) # sort to make sure cudagraph capture sizes are in descending order self.capture_sizes.sort(reverse=True) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ccd9fac225cba..96c11ec2b4f9e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -209,12 +209,9 @@ def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object - if isinstance(self.compilation_config, (int)): + if isinstance(self.compilation_config, (int, dict)): self.compilation_config = CompilationConfig.from_cli( str(self.compilation_config)) - elif isinstance(self.compilation_config, (dict)): - self.compilation_config = CompilationConfig.from_cli( - json.dumps(self.compilation_config)) # Setup plugins from vllm.plugins import load_general_plugins diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 65fa9873df28c..8de30ccd18a11 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ 
-1,5 +1,4 @@ import itertools -import json import warnings from contextlib import contextmanager from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, @@ -186,12 +185,9 @@ def __init__( kwargs["disable_log_stats"] = True if compilation_config is not None: - if isinstance(compilation_config, (int)): + if isinstance(compilation_config, (int, dict)): compilation_config_instance = CompilationConfig.from_cli( str(compilation_config)) - elif isinstance(compilation_config, (dict)): - compilation_config_instance = CompilationConfig.from_cli( - json.dumps(compilation_config)) else: compilation_config_instance = compilation_config else: From a11f3265282c712d1d9fa75368e2a8c40019fbb7 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sun, 8 Dec 2024 04:50:51 -0800 Subject: [PATCH 18/18] [V1] Initial support of multimodal models for V1 re-arch (#10699) Signed-off-by: Roger Wang --- vllm/engine/arg_utils.py | 16 +-- vllm/model_executor/models/interfaces.py | 5 + vllm/model_executor/models/internvl.py | 68 ++++++++++--- vllm/model_executor/models/molmo.py | 72 ++++++++++++-- vllm/model_executor/models/pixtral.py | 121 +++++++++++++++++------ vllm/model_executor/models/utils.py | 28 +++++- vllm/multimodal/inputs.py | 3 +- vllm/multimodal/utils.py | 10 +- vllm/v1/core/scheduler.py | 4 +- vllm/v1/engine/llm_engine.py | 24 ++++- vllm/v1/engine/mm_input_mapper.py | 2 +- 11 files changed, 284 insertions(+), 69 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 96c11ec2b4f9e..3db069ec64ee4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1050,9 +1050,12 @@ def create_engine_config(self, # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. - # Chunked prefill is currently disabled for multimodal models by - # default. 
- if use_long_context and not model_config.is_multimodal_model: + # For multimodal models, chunked prefill is disabled by default in + # V0, but enabled by design in V1 + if model_config.is_multimodal_model: + self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) + + elif use_long_context: is_gpu = device_config.device_type == "cuda" use_sliding_window = (model_config.get_sliding_window() is not None) @@ -1241,12 +1244,9 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None: Override the EngineConfig's configs based on the usage context for V1. """ assert envs.VLLM_USE_V1, "V1 is not enabled" - # TODO (ywang96): Enable APC by default when VLM supports it. if engine_config.model_config.is_multimodal_model: - logger.warning( - "Prefix caching is currently not supported for multimodal " - "models and has been disabled.") - engine_config.cache_config.enable_prefix_caching = False + # TODO (ywang96): Enable APC by default when VLM supports it. + assert not engine_config.cache_config.enable_prefix_caching @dataclass diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 01a381381ccec..c3979eab905db 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -36,6 +36,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]: """ Returns multimodal embeddings generated from multimodal kwargs to be merged with text embeddings. + + The output embeddings must be one of the following formats: + - A list or tuple of 2D tensors, where each tensor corresponds to + each input image. + - A single 3D tensor, with the batch dimension grouping the 2D tensors. """ ... 
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index d5a7781fecfc3..42c769f79e202 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,7 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict): Shape: `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` """ + patches_per_image: List[int] + """ + List of number of total patches for each image in the batch. + """ class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + data: NestedTensors + """ + A tensor of shape `(num_images, total_image_feature_size, hidden_size)` + or a list of tensors of shape `(total_image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. 
""" @@ -349,10 +355,32 @@ def input_processor( new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, num_patches) new_prompt_token_ids = tokenizer.encode(new_prompt) + img_context_token_id = tokenizer.encode(self.img_context_token, + add_special_tokens=False) + assert len(img_context_token_id) == 1, \ + (f"Invalid image token '{self.img_context_token}': A valid image " + f"token encodes to a single token ID, got {img_context_token_id}.") + img_context_token_id = img_context_token_id[0] + + # Get precise tracking of placeholder positions + token_idx = image_idx = 0 + placeholder_ranges = [] + while token_idx < len(new_prompt_token_ids): + if new_prompt_token_ids[token_idx] == img_context_token_id: + curr_image_featue_size = image_feature_sizes[image_idx] + placeholder_ranges.append( + PlaceholderRange(offset=token_idx, + length=curr_image_featue_size)) + image_idx += 1 + token_idx += curr_image_featue_size + else: + token_idx += 1 - return token_inputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs( + prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) def input_mapper( self, @@ -614,26 +642,46 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + patches_per_image = [] + for request_pixel_values in pixel_values: + for image_pixel_values in request_pixel_values: + patches_per_image.append(image_pixel_values.shape[0]) # We need to flatten (B, N, P) to (B*N*P), # so we call flatten_bn twice. 
return InternVLImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( flatten_bn(flatten_bn(pixel_values), concat=True)), - ) + patches_per_image=patches_per_image) raise AssertionError("This line should be unreachable.") def _process_image_input( self, image_input: InternVLImageInputs, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor]: if image_input["type"] == "image_embeds": return image_input["data"] assert self.vision_model is not None + image_embeds = self.extract_feature(image_input["data"]) + patches_per_image = image_input["patches_per_image"] + if len(patches_per_image) == 1: + image_embeds = image_embeds.unsqueeze(0) + return image_embeds + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. + feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in patches_per_image + ] + image_embeds = image_embeds.split(image_feature_sizes) return image_embeds def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -696,13 +744,11 @@ def forward( "inputs_embeds": inputs_embeds, } + # Only required if the model is mono-architecture if self.visual_token_mask is not None: - # overwrite visual_token_mask and img_context_token_id back to None, - # so that this doesn't need to depend on encoder output forward_kwargs.update( {"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None - self.img_context_token_id = None hidden_states = self.language_model.model(**forward_kwargs) return hidden_states diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index d1fcbd167c199..a328b5a2aeea7 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -37,7 +37,7 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import 
default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) @@ -46,12 +46,16 @@ from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, merge_multimodal_embeddings) # TODO: hard-coded for now. Consider making it configurable. VIT_LAYERS = [-2, -9] NUM_PREFIX_TOKENS = 1 ADDITIONAL_VOCAB_SIZE = 128 +DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066 +DEFAULT_IM_START_TOKEN_ID = 152067 +DEFAULT_IM_END_TOKEN_ID = 152064 +DEFAULT_IM_COL_TOKEN_ID = 152065 class MolmoImageInputs(TypedDict): @@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict): `(batch_size, num_crops, num_patch)` """ + image_start_end: Tuple[int, int] + """Starting and ending index of placeholder + tokens + """ + @dataclass class VisionBackboneConfig: @@ -918,6 +927,8 @@ def image_input_mapper_for_molmo( ctx: InputContext, data: object, ): + if isinstance(data, list): + data = data[0] return MultiModalKwargs(data) @@ -967,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int, if "image_masks" in out: dummy_imgdata["image_masks"] = out["image_masks"] dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) - return DummyData(dummy_seqdata, {"image": dummy_imgdata}) + size = 0 + offset = -1 + for i in range(len(token_ids)): + if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, + DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID, + DEFAULT_IM_COL_TOKEN_ID): + if offset < 0: + offset = i + size += 1 + dummy_imgdata["image_start_end"] = (offset, offset + size) + return DummyData(seq_data=dummy_seqdata, + multi_modal_data={"image": dummy_imgdata}, + 
multi_modal_placeholders={ + "image": + [PlaceholderRange(offset=offset, length=size)] + }) def pad_images( @@ -1055,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): if image_masks is not None: image_data["image_masks"] = image_masks - image_data["seq_len"] = torch.tensor(len(out["input_ids"]), + new_prompt_token_ids = out["input_ids"].tolist() + image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids), dtype=torch.long) multi_modal_data = dict(image=image_data) + size = 0 + offset = -1 + for i in range(len(new_prompt_token_ids)): + if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, + DEFAULT_IM_START_TOKEN_ID, + DEFAULT_IM_END_TOKEN_ID, + DEFAULT_IM_COL_TOKEN_ID): + if offset < 0: + offset = i + size += 1 + image_data["image_start_end"] = (offset, offset + size) prompt = inputs.get("prompt") if prompt is None: - prompt = tokenizer.decode(out["input_ids"]) + prompt = tokenizer.decode(new_prompt_token_ids) return token_inputs( - prompt_token_ids=out["input_ids"], + prompt_token_ids=new_prompt_token_ids, prompt=prompt, multi_modal_data=multi_modal_data, + multi_modal_placeholders={ + "image": [PlaceholderRange(offset=offset, length=size)] + }, ) @@ -1113,6 +1154,7 @@ def _parse_and_validate_image_input( ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) image_masks = kwargs.pop("image_masks", None) + image_start_end = kwargs.pop("image_start_end", None) if images is None: return None @@ -1130,6 +1172,7 @@ def _parse_and_validate_image_input( image_input_idx=image_input_idx, seq_len=seq_len, image_masks=image_masks, + image_start_end=image_start_end, ) def _process_image_input( @@ -1178,9 +1221,16 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # Note: In this original implementation from AI2, the final # vision_embeddings will be always be the same length - # of input embedddings, which is not very efficient. - # TODO(ywang96): see if this can be optimized. 
+ # of input embeddings. vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) + + # Split by the sizes of the input sequences. For each full embedding, + # extract the actual vision embeddings to be merged. + vision_embeddings = list(vision_embeddings.split(seq_len.tolist())) + for i in range(len(vision_embeddings)): + start, end = image_input['image_start_end'][i] + vision_embeddings[i] = vision_embeddings[i][start:end] + return vision_embeddings def get_input_embeddings( @@ -1190,7 +1240,11 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - inputs_embeds = inputs_embeds + multimodal_embeddings + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, [ + DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID, + DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID + ]) return inputs_embeds def forward( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 215727cadd954..c6786c363ab4a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -48,6 +48,9 @@ except ImportError: USE_XFORMERS_OPS = False +PIXTRAL_IMAGE_BREAK_ID = 12 +PIXTRAL_IMAGE_END_ID = 13 + def get_max_pixtral_image_tokens(ctx: InputContext): tokenizer = cached_get_tokenizer( @@ -68,7 +71,6 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - patch_size = mm_encoder.mm_config.image_patch_size image_token_id = mm_encoder.special_ids.img mm_config = ctx.model_config.multimodal_config @@ -78,8 +80,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, size = 256 image = Image.new("RGB", (size, size), color=0) - image_feature_size = (size**2) // (patch_size**2) - + encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image)) + 
image_feature_size = len(encoding.tokens) num_image_tokens = image_feature_size * num_images seq_data = SequenceData.from_prompt_token_counts( (image_token_id, num_image_tokens), @@ -101,14 +103,13 @@ def input_mapper_for_pixtral(ctx: InputContext, Args: ctx: Context of the loaded model. - data: data potentially containing image/image embeddings to be mapped - to pixel_values in .forward() for a visual QWenLMHeadModel model. + data: data potentially containing PIL images to be processed + and mapped to `images`. Returns: MultiModalKwargs containing the stacked normalized images tensor or image embeddings. """ - # Early exit if we have provided an image to a language only Qwen model model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) @@ -116,35 +117,67 @@ def input_mapper_for_pixtral(ctx: InputContext, data_list = data if isinstance(data, list) else [data] images = [] + image_tokens_list = [] for image_data in data_list: image = ImageChunk(image=image_data) encoding = tokenizer.instruct.mm_encoder(image) image = torch.from_numpy(encoding.image).to(device="cuda", dtype=torch.float16) images.append(image) + image_tokens_list.append(encoding.tokens) - return MultiModalKwargs({"images": images}) + image_tokens = torch.tensor([ + token_id for image_tokens in image_tokens_list + for token_id in image_tokens + ]) + return MultiModalKwargs({"images": images, "image_tokens": image_tokens}) def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is not None and "image" in multi_modal_data: - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - tokenizer_mode=ctx.model_config.tokenizer_mode) - - mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - image_token_id = mm_encoder.special_ids.img + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs - if 
image_token_id not in inputs['prompt_token_ids']: - raise ValueError( - f"You've passed {inputs=} without {image_token_id=}" - " Make sure to process your input via mistral_common's" - " tokenizer or pass a chat completion request. For more" - " For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + prompt_token_ids = inputs.get("prompt_token_ids") + prompt = inputs.get("prompt") + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) - return inputs + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + image_token_id = mm_encoder.special_ids.img + image_break_id = mm_encoder.special_ids.img_break + image_end_id = mm_encoder.special_ids.img_end + + if image_token_id not in inputs['prompt_token_ids']: + raise ValueError( + f"You've passed {inputs=} without {image_token_id=}" + " Make sure to process your input via mistral_common's" + " tokenizer or pass a chat completion request. For more" + " For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.") + + # Get precise tracking of placeholder positions + placeholder_ranges = [] + curr_offset = -1 + curr_length = 0 + for i in range(len(prompt_token_ids)): + if prompt_token_ids[i] in (image_token_id, image_break_id): + if curr_offset < 0: + curr_offset = i + curr_length += 1 + elif prompt_token_ids[i] == image_end_id: + curr_length += 1 + placeholder_ranges.append( + PlaceholderRange(offset=curr_offset, length=curr_length)) + curr_offset = -1 + curr_length = 0 + else: + pass + return token_inputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @@ -192,11 +225,29 @@ def sampler(self): return get_sampler() def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: - image_input = 
self._parse_and_validate_image_input(**kwargs) + image_input, image_tokens = self._parse_and_validate_image_input( + **kwargs) if image_input is None: return None + vision_embeddings = self._process_image_input(image_input) - return vision_embeddings + + # NOTE: We patch the outputs of the vision encoder with embeddings + # from `[IMG_BREAK]` and `[IMG_END]` tokens. + image_embeds = self.language_model.get_input_embeddings(image_tokens) + image_token_mask = image_tokens == self.vision_args.image_token_id + image_embeds[image_token_mask] = vision_embeddings + + # NOTE: Image embeddings are split into separate tensors for each image + # by the indices of `[IMG_END]` token. + split_indices = torch.where( + image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1 + if len(split_indices) <= 1: + # Do not split, return as tensor of shape [1, fs, hs] + return image_embeds.unsqueeze(0) + + image_embeds = image_embeds.tensor_split(split_indices.cpu()) + return image_embeds def get_input_embeddings( self, @@ -206,8 +257,10 @@ def get_input_embeddings( inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.vision_args.image_token_id) + input_ids, inputs_embeds, multimodal_embeddings, [ + self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID, + PIXTRAL_IMAGE_BREAK_ID + ]) return inputs_embeds def forward( @@ -245,10 +298,11 @@ def forward( def _parse_and_validate_image_input( self, images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], - torch.Tensor]] = None + torch.Tensor]] = None, + image_tokens: Optional[torch.Tensor] = None, ) -> Optional[List[torch.Tensor]]: if images is None: - return None + return None, None if isinstance(images, torch.Tensor): # if passed as batch take all images @@ -267,7 +321,16 @@ def _parse_and_validate_image_input( images = flatten_images - return images + if isinstance(image_tokens, 
torch.Tensor): + # image_tokens are batched + image_tokens = image_tokens.flatten() + elif isinstance(image_tokens, list): + # image_tokens are of different lengths thus passed as a list + image_tokens = torch.cat(image_tokens) + + assert image_tokens.dim() == 1 + + return images, image_tokens def _process_image_input(self, image_input: List[torch.Tensor]) -> torch.Tensor: diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 5ec44955dbd80..269b66806adf4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -409,16 +409,42 @@ def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, - placeholder_token_id: int, + placeholder_token_id: Union[int, List[int]], ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the positions in ``inputs_embeds`` corresponding to placeholder tokens in ``input_ids``. + + ``placeholder_token_id`` can be a list of token ids (e.g., token ids + of img_start, img_break, and img_end tokens) when needed: This means + the order of these tokens in the ``input_ids`` MUST MATCH the order of + their embeddings in ``multimodal_embeddings`` since we need to + slice-merge instead of individually scattering. + + For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where + - T is text token + - S is image start token + - I is image embedding token + - B is image break token + - E is image end token. + + Then the image embeddings (that correspond to I's) from vision encoder + must be padded with embeddings of S, B, and E in the same order of + input_ids for a correct embedding merge. Note: This updates ``inputs_embeds`` in place. 
""" + if isinstance(placeholder_token_id, list): + placeholder_token_id = torch.tensor(placeholder_token_id, + device=input_ids.device) + return _merge_multimodal_embeddings( + inputs_embeds, + torch.isin(input_ids, placeholder_token_id), + multimodal_embeddings, + ) + return _merge_multimodal_embeddings( inputs_embeds, (input_ids == placeholder_token_id), diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 640c7c04b8817..229a8fbdf5831 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -96,7 +96,8 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor, + Tuple[torch.Tensor, ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index d4333b7519b47..c898ca4e6573e 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens( return new_prompt, new_token_ids, placeholder_ranges -def consecutive_placeholder_ranges(num_items: int, - item_size: int) -> List[PlaceholderRange]: +def consecutive_placeholder_ranges( + num_items: int, + item_size: int, + initial_offset: int = 0) -> List[PlaceholderRange]: """Returns a list of consecutive PlaceholderRanges of a fixed size""" return [ - PlaceholderRange(offset=i * item_size, length=item_size) - for i in range(num_items) + PlaceholderRange(offset=initial_offset + i * item_size, + length=item_size) for i in range(num_items) ] diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f1f26f4e8d443..1203d35fc985f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -73,12 +73,12 @@ def __init__( # has the Transformer architecture (e.g., ViT). # FIXME(woosuk): Below are placeholder values. 
We need to calculate the # actual values from the configurations. - self.max_num_encoder_input_tokens = 2048 + self.max_num_encoder_input_tokens = 16384 # NOTE(woosuk): For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized and used, regardless of # the cache size. This is because the memory space for the encoder cache # is preallocated in the profiling run. - self.encoder_cache_manager = EncoderCacheManager(cache_size=2048) + self.encoder_cache_manager = EncoderCacheManager(cache_size=16384) def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 312c0242a45dd..994e68669108e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,5 +1,7 @@ from typing import Dict, List, Mapping, Optional, Type, Union +from typing_extensions import TypeVar + from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase @@ -12,7 +14,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer @@ -21,6 +24,8 @@ logger = init_logger(__name__) +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) + class LLMEngine: """Legacy LLMEngine for backwards compatibility.""" @@ -169,5 +174,18 @@ def start_profile(self): def stop_profile(self): self.engine_core.profile(False) - def get_tokenizer_group(self, group_type): - pass + def get_tokenizer_group( + self, + group_type: Type[_G] = 
BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") + + return tokenizer_group diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 45882f8f076d4..7ad6882b04520 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -33,7 +33,7 @@ def process_inputs( num_images = len(image_inputs) for i in range(num_images): mm_input = self.multi_modal_input_mapper( - {"image": [image_inputs[i]]}, + {"image": image_inputs[i]}, mm_processor_kwargs=mm_processor_kwargs, ) mm_inputs.append(mm_input)