From 7653aabe5ce70d28aa3b49edd4bc54d4aba6eb06 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 23 Oct 2024 19:54:57 +0000 Subject: [PATCH] add some non-chunked-prefill tests, comment --- tests/lora/test_chatglm3.py | 7 +++++-- tests/lora/test_gemma.py | 5 +++-- tests/lora/test_llama.py | 6 ++++-- vllm/core/scheduler.py | 3 +++ 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/lora/test_chatglm3.py b/tests/lora/test_chatglm3.py index dbc0c7773912c..30647c1705e0b 100644 --- a/tests/lora/test_chatglm3.py +++ b/tests/lora/test_chatglm3.py @@ -1,5 +1,7 @@ from typing import List +import pytest + import vllm from vllm.lora.request import LoRARequest @@ -37,14 +39,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts -def test_chatglm3_lora(chatglm3_lora_files): +@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) +def test_chatglm3_lora(chatglm3_lora_files, enable_chunked_prefill): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, enable_lora=True, max_loras=4, max_lora_rank=64, trust_remote_code=True, - enable_chunked_prefill=True) + enable_chunked_prefill=enable_chunked_prefill) expected_lora_output = [ "SELECT count(*) FROM singer", diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index a566e747799fa..992366d25a013 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -32,12 +32,13 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: @pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm") -def test_gemma_lora(gemma_lora_files): +@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) +def test_gemma_lora(gemma_lora_files, enable_chunked_prefill): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, enable_lora=True, max_loras=4, - enable_chunked_prefill=True) + enable_chunked_prefill=enable_chunked_prefill) expected_lora_output = [ "more important than knowledge.\nAuthor: Albert Einstein\n", diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index cdbc647bd9f8e..f8533b76e3f64 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -38,7 +38,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: @pytest.mark.parametrize("tp_size", [1, 2, 4]) -def test_llama_lora(sql_lora_files, tp_size, num_gpus_available): +@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) +def test_llama_lora(sql_lora_files, tp_size, enable_chunked_prefill, + num_gpus_available): if num_gpus_available < tp_size: pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") @@ -47,7 +49,7 @@ def test_llama_lora(sql_lora_files, tp_size, num_gpus_available): max_num_seqs=16, max_loras=4, tensor_parallel_size=tp_size, - enable_chunked_prefill=True) + enable_chunked_prefill=enable_chunked_prefill) expected_no_lora_output = [ "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c93d260b2a50f..f6ee8886b2cfc 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -151,6 +151,9 @@ def is_empty(self) -> bool: and not self.blocks_to_swap_out and not self.blocks_to_copy) def _sort_by_lora_ids(self): + # Sort sequence groups so that (1) all prefills come before all decodes + # (required by chunked prefill), and (2) all LoRAs are grouped together + # for improved performance. self.scheduled_seq_groups = sorted( self.scheduled_seq_groups, key=lambda g: (not g.seq_group.is_prefill(), g.seq_group.