From d5c5154fcf4c5d65551c98e458cbb027e5f4b672 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Tue, 10 Dec 2024 21:09:20 -0500
Subject: [PATCH] [Misc] LoRA + Chunked Prefill (#9057)

---
 tests/lora/test_chatglm3_tp.py  |  9 ++++++---
 tests/lora/test_gemma.py        |  3 ++-
 tests/lora/test_llama_tp.py     |  6 +++++-
 tests/lora/test_long_context.py |  3 ++-
 tests/lora/test_minicpmv.py     |  3 ++-
 tests/lora/test_minicpmv_tp.py  |  2 ++
 tests/lora/test_mixtral.py      |  1 +
 tests/lora/test_phi.py          |  3 ++-
 tests/lora/test_quant_model.py  |  9 ++++++---
 vllm/config.py                  |  3 ++-
 vllm/core/scheduler.py          | 15 ++++++++++++---
 vllm/worker/model_runner.py     | 12 +++++++-----
 12 files changed, 49 insertions(+), 20 deletions(-)

diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py
index f17464573459f..49a527b99ac16 100644
--- a/tests/lora/test_chatglm3_tp.py
+++ b/tests/lora/test_chatglm3_tp.py
@@ -53,7 +53,8 @@ def test_chatglm3_lora(chatglm3_lora_files):
                    max_loras=4,
                    max_lora_rank=64,
                    tensor_parallel_size=1,
-                   trust_remote_code=True)
+                   trust_remote_code=True,
+                   enable_chunked_prefill=True)
 
     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
     for i in range(len(EXPECTED_LORA_OUTPUT)):
@@ -73,7 +74,8 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
                    max_lora_rank=64,
                    tensor_parallel_size=4,
                    trust_remote_code=True,
-                   fully_sharded_loras=False)
+                   fully_sharded_loras=False,
+                   enable_chunked_prefill=True)
 
     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
     for i in range(len(EXPECTED_LORA_OUTPUT)):
@@ -93,7 +95,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
                    max_lora_rank=64,
                    tensor_parallel_size=4,
                    trust_remote_code=True,
-                   fully_sharded_loras=True)
+                   fully_sharded_loras=True,
+                   enable_chunked_prefill=True)
     output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
     for i in range(len(EXPECTED_LORA_OUTPUT)):
         assert output1[i] == EXPECTED_LORA_OUTPUT[i]
diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py
index 15ec66b0f5502..5ae705e474ec6 100644
--- a/tests/lora/test_gemma.py
+++ b/tests/lora/test_gemma.py
@@ -37,7 +37,8 @@ def test_gemma_lora(gemma_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
                    enable_lora=True,
-                   max_loras=4)
+                   max_loras=4,
+                   enable_chunked_prefill=True)
 
     expected_lora_output = [
         "more important than knowledge.\nAuthor: Albert Einstein\n",
diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py
index d3ca7f878191a..dfeac380951d8 100644
--- a/tests/lora/test_llama_tp.py
+++ b/tests/lora/test_llama_tp.py
@@ -78,7 +78,8 @@ def test_llama_lora(sql_lora_files):
                    enable_lora=True,
                    max_num_seqs=16,
                    max_loras=4,
-                   tensor_parallel_size=1)
+                   tensor_parallel_size=1,
+                   enable_chunked_prefill=True)
     generate_and_test(llm, sql_lora_files)
 
 
@@ -120,6 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
         max_num_seqs=16,
         max_loras=4,
         tensor_parallel_size=4,
+        enable_chunked_prefill=True,
     )
     generate_and_test(llm, sql_lora_files)
 
@@ -135,6 +137,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
         max_loras=4,
         tensor_parallel_size=4,
         fully_sharded_loras=True,
+        enable_chunked_prefill=True,
     )
     generate_and_test(llm, sql_lora_files)
 
@@ -151,5 +154,6 @@ def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
         tensor_parallel_size=4,
         fully_sharded_loras=True,
         enable_lora_bias=True,
+        enable_chunked_prefill=True,
     )
     generate_and_test(llm, sql_lora_files)
diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py
index eada902c891f7..e7a34f2ced7ed 100644
--- a/tests/lora/test_long_context.py
+++ b/tests/lora/test_long_context.py
@@ -124,7 +124,8 @@ def lora_llm(long_context_infos):
         tensor_parallel_size=4,
         # FIXME enable async output processor
         disable_async_output_proc=True,
-        distributed_executor_backend="mp")
+        distributed_executor_backend="mp",
+        enable_chunked_prefill=True)
     yield llm
     del llm
 
diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py
index 2c45ce5141f7d..1f3de9edc0d0f 100644
--- a/tests/lora/test_minicpmv.py
+++ b/tests/lora/test_minicpmv.py
@@ -67,7 +67,8 @@ def test_minicpmv_lora(minicpmv_lora_files):
         max_loras=4,
         max_lora_rank=64,
         trust_remote_code=True,
-        gpu_memory_utilization=0.97  # This model is pretty big for CI gpus
+        gpu_memory_utilization=0.97,  # This model is pretty big for CI gpus
+        enable_chunked_prefill=True,
     )
     output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
     for i in range(len(EXPECTED_OUTPUT)):
diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py
index ba29e562e58ec..930f177953a5f 100644
--- a/tests/lora/test_minicpmv_tp.py
+++ b/tests/lora/test_minicpmv_tp.py
@@ -69,6 +69,7 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
         tensor_parallel_size=2,
         trust_remote_code=True,
         fully_sharded_loras=fully_sharded,
+        enable_chunked_prefill=True,
     )
 
     output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
@@ -89,6 +90,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=fully_sharded,
+        enable_chunked_prefill=True,
     )
     output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
     for i in range(len(EXPECTED_OUTPUT)):
diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py
index dddc299da446b..150221dfce6ab 100644
--- a/tests/lora/test_mixtral.py
+++ b/tests/lora/test_mixtral.py
@@ -47,6 +47,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
         max_loras=4,
         distributed_executor_backend="ray",
         tensor_parallel_size=tp_size,
+        enable_chunked_prefill=True,
     )
 
     expected_lora_output = [
diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py
index 733eff48a9bf3..5a3fcb8d690d9 100644
--- a/tests/lora/test_phi.py
+++ b/tests/lora/test_phi.py
@@ -53,7 +53,8 @@ def test_phi2_lora(phi2_lora_files):
                    max_model_len=1024,
                    enable_lora=True,
                    max_loras=2,
-                   enforce_eager=True)
+                   enforce_eager=True,
+                   enable_chunked_prefill=True)
 
     expected_lora_output = [
         "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;",  # noqa: E501
diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py
index 5432fa4ad0d3a..026269667b473 100644
--- a/tests/lora/test_quant_model.py
+++ b/tests/lora/test_quant_model.py
@@ -84,7 +84,8 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
         tensor_parallel_size=tp_size,
         gpu_memory_utilization=0.2,  #avoid OOM
         quantization=model.quantization,
-        trust_remote_code=True)
+        trust_remote_code=True,
+        enable_chunked_prefill=True)
 
     if model.quantization is None:
         expected_no_lora_output = [
@@ -176,7 +177,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
         tensor_parallel_size=1,
         gpu_memory_utilization=0.2,  #avoid OOM
         quantization=model.quantization,
-        trust_remote_code=True)
+        trust_remote_code=True,
+        enable_chunked_prefill=True)
     output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
 
     del llm_tp1
@@ -189,7 +191,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
         max_loras=4,
         tensor_parallel_size=2,
         gpu_memory_utilization=0.2,  #avoid OOM
-        quantization=model.quantization)
+        quantization=model.quantization,
+        enable_chunked_prefill=True)
     output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
 
     del llm_tp2
diff --git a/vllm/config.py b/vllm/config.py
index 5fb9563fcf3a3..c66ddbb47f22e 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1698,7 +1698,8 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
         # Reminder: Please update docs/source/usage/compatibility_matrix.rst
         # If the feature combo become valid
         if scheduler_config.chunked_prefill_enabled:
-            raise ValueError("LoRA is not supported with chunked prefill yet.")
+            logger.warning("LoRA with chunked prefill is still experimental "
+                           "and may be unstable.")
 
 
 @dataclass
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index d23009dae01ee..94c62743883ec 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -166,9 +166,18 @@ def is_empty(self) -> bool:
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
     def _sort_by_lora_ids(self):
-        self.scheduled_seq_groups = sorted(
-            self.scheduled_seq_groups,
-            key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
+        assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)
+
+        def key_fn(group: ScheduledSequenceGroup):
+            key = (group.seq_group.lora_int_id, group.seq_group.request_id)
+            if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
+                # Sort sequence groups so that all prefills come before all
+                # decodes as required by chunked prefill.
+                return (not group.seq_group.is_prefill(), *key)
+            return key
+
+        self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
+                                           key=key_fn)
 
     @property
     def lora_requests(self) -> Set[LoRARequest]:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 1bc5f65c7127f..551b84435fdc0 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -622,11 +622,13 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
         inter_data.lora_requests.add(seq_group_metadata.lora_request)
         query_len = inter_data.query_lens[seq_idx]
         inter_data.lora_index_mapping.append([lora_id] * query_len)
-        inter_data.lora_prompt_mapping.append(
-            [lora_id] *
-            (query_len if seq_group_metadata.sampling_params
-             and seq_group_metadata.sampling_params.prompt_logprobs is not None
-             else 1))
+        sampling_params = seq_group_metadata.sampling_params
+        if sampling_params and sampling_params.prompt_logprobs is not None:
+            inter_data.lora_prompt_mapping.append([lora_id] * query_len)
+        elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
+            inter_data.lora_prompt_mapping.append([lora_id])
+        else:
+            inter_data.lora_prompt_mapping.append([])
 
     def _compute_prompt_adapter_input(
             self, inter_data: InterDataForSeqGroup,
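Example: with this change, LoRA can be served while chunked prefill is enabled. The following is a minimal sketch mirroring the flag combination exercised by the updated tests; the model name and adapter path are placeholders.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Minimal sketch: enable LoRA together with chunked prefill.
# The model name and adapter path below are placeholders.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    enable_lora=True,
    max_loras=4,
    max_lora_rank=64,
    enable_chunked_prefill=True,
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# Each request may carry its own adapter via lora_request.
outputs = llm.generate(
    ["Write a SQL query that lists all customers."],
    sampling_params,
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_lora_adapter"),
)
for output in outputs:
    print(output.outputs[0].text)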