Skip to content

Commit

Permalink
Add non-chunked-prefill variants of the LoRA tests and a clarifying comment in the scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
aurickq committed Oct 23, 2024
1 parent 05b54fe commit 7653aab
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
7 changes: 5 additions & 2 deletions tests/lora/test_chatglm3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

Expand Down Expand Up @@ -37,14 +39,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


def test_chatglm3_lora(chatglm3_lora_files):
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_chatglm3_lora(chatglm3_lora_files, enable_chunked_prefill):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
enable_chunked_prefill=True)
enable_chunked_prefill=enable_chunked_prefill)

expected_lora_output = [
"SELECT count(*) FROM singer",
Expand Down
5 changes: 3 additions & 2 deletions tests/lora/test_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:


@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_gemma_lora(gemma_lora_files, enable_chunked_prefill):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enable_chunked_prefill=True)
enable_chunked_prefill=enable_chunked_prefill)

expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n",
Expand Down
6 changes: 4 additions & 2 deletions tests/lora/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:


@pytest.mark.parametrize("tp_size", [1, 2, 4])
def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_llama_lora(sql_lora_files, tp_size, enable_chunked_prefill,
num_gpus_available):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

Expand All @@ -47,7 +49,7 @@ def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=tp_size,
enable_chunked_prefill=True)
enable_chunked_prefill=enable_chunked_prefill)

expected_no_lora_output = [
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
Expand Down
3 changes: 3 additions & 0 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def is_empty(self) -> bool:
and not self.blocks_to_swap_out and not self.blocks_to_copy)

def _sort_by_lora_ids(self):
# Sort sequence groups so that (1) all prefills come before all decodes
# (required by chunked prefill), and (2) all LoRAs are grouped together
# for improved performance.
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (not g.seq_group.is_prefill(), g.seq_group.
Expand Down

0 comments on commit 7653aab

Please sign in to comment.