Skip to content

Commit

Permalink
[Misc] LoRA + Chunked Prefill (#9057)
Browse files Browse the repository at this point in the history
  • Loading branch information
aurickq authored Dec 11, 2024
1 parent 9a93973 commit d5c5154
Show file tree
Hide file tree
Showing 12 changed files with 49 additions and 20 deletions.
9 changes: 6 additions & 3 deletions tests/lora/test_chatglm3_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def test_chatglm3_lora(chatglm3_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True)
trust_remote_code=True,
enable_chunked_prefill=True)

output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
Expand All @@ -73,7 +74,8 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=False)
fully_sharded_loras=False,
enable_chunked_prefill=True)

output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
Expand All @@ -93,7 +95,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=True)
fully_sharded_loras=True,
enable_chunked_prefill=True)
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
Expand Down
3 changes: 2 additions & 1 deletion tests/lora/test_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4)
max_loras=4,
enable_chunked_prefill=True)

expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n",
Expand Down
6 changes: 5 additions & 1 deletion tests/lora/test_llama_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def test_llama_lora(sql_lora_files):
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
tensor_parallel_size=1,
enable_chunked_prefill=True)
generate_and_test(llm, sql_lora_files)


Expand Down Expand Up @@ -120,6 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)

Expand All @@ -135,6 +137,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)

Expand All @@ -151,5 +154,6 @@ def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_lora_bias=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)
3 changes: 2 additions & 1 deletion tests/lora/test_long_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def lora_llm(long_context_infos):
tensor_parallel_size=4,
# FIXME enable async output processor
disable_async_output_proc=True,
distributed_executor_backend="mp")
distributed_executor_backend="mp",
enable_chunked_prefill=True)
yield llm
del llm

Expand Down
3 changes: 2 additions & 1 deletion tests/lora/test_minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def test_minicpmv_lora(minicpmv_lora_files):
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
gpu_memory_utilization=0.97 # This model is pretty big for CI gpus
gpu_memory_utilization=0.97, # This model is pretty big for CI gpus
enable_chunked_prefill=True,
)
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
Expand Down
2 changes: 2 additions & 0 deletions tests/lora/test_minicpmv_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
tensor_parallel_size=2,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
enable_chunked_prefill=True,
)

output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
Expand All @@ -89,6 +90,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
enable_chunked_prefill=True,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
Expand Down
1 change: 1 addition & 0 deletions tests/lora/test_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
enable_chunked_prefill=True,
)

expected_lora_output = [
Expand Down
3 changes: 2 additions & 1 deletion tests/lora/test_phi.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def test_phi2_lora(phi2_lora_files):
max_model_len=1024,
enable_lora=True,
max_loras=2,
enforce_eager=True)
enforce_eager=True,
enable_chunked_prefill=True)

expected_lora_output = [
"SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501
Expand Down
9 changes: 6 additions & 3 deletions tests/lora/test_quant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
tensor_parallel_size=tp_size,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
trust_remote_code=True)
trust_remote_code=True,
enable_chunked_prefill=True)

if model.quantization is None:
expected_no_lora_output = [
Expand Down Expand Up @@ -176,7 +177,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
tensor_parallel_size=1,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization,
trust_remote_code=True)
trust_remote_code=True,
enable_chunked_prefill=True)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

del llm_tp1
Expand All @@ -189,7 +191,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
max_loras=4,
tensor_parallel_size=2,
gpu_memory_utilization=0.2, #avoid OOM
quantization=model.quantization)
quantization=model.quantization,
enable_chunked_prefill=True)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

del llm_tp2
Expand Down
3 changes: 2 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,7 +1698,8 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if scheduler_config.chunked_prefill_enabled:
raise ValueError("LoRA is not supported with chunked prefill yet.")
logger.warning("LoRA with chunked prefill is still experimental "
"and may be unstable.")


@dataclass
Expand Down
15 changes: 12 additions & 3 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,18 @@ def is_empty(self) -> bool:
and not self.blocks_to_swap_out and not self.blocks_to_copy)

def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
assert 0 <= self.num_prefill_groups <= len(self.scheduled_seq_groups)

def key_fn(group: ScheduledSequenceGroup):
key = (group.seq_group.lora_int_id, group.seq_group.request_id)
if 0 < self.num_prefill_groups < len(self.scheduled_seq_groups):
# Sort sequence groups so that all prefills come before all
# decodes as required by chunked prefill.
return (not group.seq_group.is_prefill(), *key)
return key

self.scheduled_seq_groups = sorted(self.scheduled_seq_groups,
key=key_fn)

@property
def lora_requests(self) -> Set[LoRARequest]:
Expand Down
12 changes: 7 additions & 5 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,11 +622,13 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
inter_data.lora_requests.add(seq_group_metadata.lora_request)
query_len = inter_data.query_lens[seq_idx]
inter_data.lora_index_mapping.append([lora_id] * query_len)
inter_data.lora_prompt_mapping.append(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs is not None
else 1))
sampling_params = seq_group_metadata.sampling_params
if sampling_params and sampling_params.prompt_logprobs is not None:
inter_data.lora_prompt_mapping.append([lora_id] * query_len)
elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
inter_data.lora_prompt_mapping.append([lora_id])
else:
inter_data.lora_prompt_mapping.append([])

def _compute_prompt_adapter_input(
self, inter_data: InterDataForSeqGroup,
Expand Down

0 comments on commit d5c5154

Please sign in to comment.