[Core] Reduce TTFT with concurrent partial prefills #10235

Open · wants to merge 59 commits into main from prefill-slots

Changes from all commits (59 commits)
f97eacf
:bug: fix multi-chunked-prefill sampler bug
joerunde Nov 6, 2024
b50a6b8
🚧 add num_prefill_slots arg
prashantgupta24 Nov 8, 2024
7f23c04
:sparkles: start to write prefill slot logic
joerunde Nov 8, 2024
d271cc9
🎨 format
prashantgupta24 Nov 8, 2024
b2cb96f
:sparkles: update num tokens for prefill slots
joerunde Nov 8, 2024
c349ac0
♻️ add schedule_chunked_prefill logic
prashantgupta24 Nov 8, 2024
e20518d
♻️ change function name
prashantgupta24 Nov 8, 2024
6ba0e34
:sparkles: reserve incoming prefill slots
joerunde Nov 8, 2024
a7491cc
🎨 fix some typos
prashantgupta24 Nov 8, 2024
1ee6fea
:zap: finish awesome scheduler
joerunde Nov 9, 2024
517915a
:bug: fix the deadlocks
joerunde Nov 11, 2024
ed298c3
:memo: Add more docstrings
joerunde Nov 11, 2024
90e0c07
:bug: fix deadlock
joerunde Nov 11, 2024
1c92ac2
:construction: WIP scheduler tests
joerunde Nov 11, 2024
de95f62
:bug: fix prefix caching
joerunde Nov 12, 2024
41e20ca
:test_tube: add prefix caching test
joerunde Nov 12, 2024
4dc7310
✅ add second test iteration
prashantgupta24 Nov 12, 2024
8e3118e
✅ add llm engine test
prashantgupta24 Nov 12, 2024
b6ebec8
♻️ quicker budget check
prashantgupta24 Nov 12, 2024
7e93668
🎨 rename to max_num_partial_prefills
prashantgupta24 Nov 13, 2024
557bfe3
🎨 more renaming to max_num_partial_prefills + docstring updates
prashantgupta24 Nov 13, 2024
d3e94df
🎨 rename big to long
prashantgupta24 Nov 13, 2024
849baf6
♻️ add cli args for partial_prefill configs
prashantgupta24 Nov 13, 2024
beaf086
🎨 fix request word typo
prashantgupta24 Nov 13, 2024
672a50c
🎨 more docstring changes
prashantgupta24 Nov 13, 2024
a2751ff
🎨 forgot to add the new args to config
prashantgupta24 Nov 13, 2024
dff757d
🐛 fix range bug on partial_prefill_budget_lookup_list
prashantgupta24 Nov 13, 2024
86ffa04
🎨 add docstring to test function
prashantgupta24 Nov 13, 2024
3d39942
:construction: WIP move metadata to dataclass
joerunde Nov 13, 2024
dbb9ae8
🎨 wrap up PartialPrefillMetadata
prashantgupta24 Nov 13, 2024
4bac8ed
♻️ add some utility functions within partial_prefill_metadata
prashantgupta24 Nov 13, 2024
c44ca1f
🎨 change to long_prefill_token_threshold
prashantgupta24 Nov 13, 2024
38bad7a
🔥 remove commented code
prashantgupta24 Nov 13, 2024
0f3efa1
🐛 fix the big bug! (Thanks Joe)
prashantgupta24 Nov 13, 2024
3daf35f
:memo: docstings galore
joerunde Nov 13, 2024
241853a
🎨 fix typo
prashantgupta24 Nov 14, 2024
07b6d72
⏪ revert logging change
prashantgupta24 Nov 14, 2024
c4bdf37
✅ remove value error from test
prashantgupta24 Nov 14, 2024
7c8b400
✅ remove value error from test
prashantgupta24 Nov 14, 2024
21796fc
🎨 fix typo
prashantgupta24 Nov 14, 2024
d993861
✅ make test comprehensive
prashantgupta24 Nov 15, 2024
946d297
🎨 fix unused vars in test
prashantgupta24 Nov 15, 2024
5535515
🎨 some more comments
prashantgupta24 Nov 18, 2024
ba91ddf
🎨 fix merge conflict
prashantgupta24 Nov 20, 2024
bccf86f
🎨 fmt
prashantgupta24 Nov 20, 2024
75848c9
♻️ merge with main
prashantgupta24 Nov 22, 2024
1c80379
Merge branch 'main' into prefill-slots
prashantgupta24 Nov 22, 2024
4f1c322
🎨 fix fmt
prashantgupta24 Nov 22, 2024
cb8fc93
⏪ revert quick budget check
prashantgupta24 Nov 22, 2024
8a8a07f
🎨 fmt
prashantgupta24 Nov 22, 2024
90a53ab
♻️ merge with main
prashantgupta24 Nov 26, 2024
29a7ccd
Merge remote-tracking branch 'upstream/main' into prefill-slots
prashantgupta24 Nov 26, 2024
752ce1b
🎨 fmt
prashantgupta24 Nov 26, 2024
edc204e
Merge remote-tracking branch 'upstream/main' into prefill-slots
joerunde Dec 6, 2024
0206173
Merge remote-tracking branch 'upstream/main' into prefill-slots
joerunde Dec 9, 2024
80b72ef
Merge remote-tracking branch 'upstream/main' into prefill-slots
joerunde Dec 18, 2024
03525f2
:bug: fix index out of range
joerunde Dec 18, 2024
d5f5eb6
:recycle: naming updates
joerunde Dec 19, 2024
cb5361a
:bug: fix long prefill threshold init
joerunde Dec 19, 2024
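
The commits above add two scheduler knobs: max_num_partial_prefills (how many requests may be in partial prefill concurrently) and long_prefill_token_threshold (the prompt length above which a request is treated as a long prefill). Below is a minimal usage sketch based on the EngineArgs fields exercised by the tests in this diff; the long_prefill_token_threshold argument and the concrete values are illustrative assumptions, not part of the PR itself.

# Sketch only: long_prefill_token_threshold is taken from the commit history;
# its exact name and placement here are assumptions.
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams

engine_args = EngineArgs(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,       # partial prefills build on chunked prefill
    max_num_batched_tokens=64,         # per-step token budget shared by all requests
    max_num_seqs=8,
    max_num_partial_prefills=2,        # up to 2 requests may prefill concurrently
    long_prefill_token_threshold=500,  # assumed name: prompts above this count as "long"
)
engine = LLMEngine.from_engine_args(engine_args)

params = SamplingParams(temperature=0)
engine.add_request("short", "hello " * 10, params)
engine.add_request("long", "hello " * 400, params)

# Each step splits the 64-token budget across the concurrently prefilling
# requests rather than running a single prefill to completion first, which
# is what reduces time to first token (TTFT) for the short request.
outputs = engine.step()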
30 changes: 10 additions & 20 deletions tests/basic_correctness/test_chunked_prefill.py
@@ -7,7 +7,6 @@
Run `pytest tests/models/test_chunked_prefill.py`.
"""
import os
from contextlib import nullcontext

import pytest

@@ -232,7 +231,6 @@ def test_with_prefix_caching(

max_num_batched_tokens = max_num_seqs = chunk_size
outputs = {} # type: ignore
check_result = True
for enable in (True, False):
with vllm_runner(
model,
@@ -244,25 +242,17 @@
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
) as vllm_model:
# It should fail when prefix caching is enable and chunk
# size is not a multiple of block size (16).
should_fail = chunk_size % 16 != 0 and enable
check_result &= not should_fail
outputs[enable] = []
# Send the request one-by-one to ensure the cache is populated.
with pytest.raises(ValueError) if should_fail else nullcontext():
for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens)

# Check results only if we did not expect a failure.
if check_result:
check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)
for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens)

check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
298 changes: 296 additions & 2 deletions tests/core/test_chunked_prefill_scheduler.py
@@ -5,6 +5,9 @@

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, SequenceGroup

from .utils import create_dummy_prompt
@@ -14,7 +17,7 @@ def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


def append_new_token(seq_group, token_id: int):
def append_new_token(seq_group: SequenceGroup, token_id: int):
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})

@@ -121,6 +124,214 @@ def test_chunk():
assert out.num_batched_tokens == 57


def test_concurrent_chunking():
"""Verify prefills are chunked properly when
--max-num-partial-prefills is > 1"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)

# Verify both requests are chunked with half of max_num_batched_tokens each
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 32
assert seq_group_meta[1].token_chunk_size == 32
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64

# After one iteration, both should have 60 - 32 = 28 tokens left to prefill
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 28
assert seq_group_meta[1].token_chunk_size == 28
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 56


def test_concurrent_chunking_large_requests():
"""Verify large prefill requests are run one at a time"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
cache_config.num_gpu_blocks = 3200
scheduler = Scheduler(scheduler_config, cache_config, None)

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i),
prompt_length=1200, # Very large prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)

# Verify only a single request is chunked, and it gets all 64 tokens
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 1
assert seq_group_meta[0].token_chunk_size == 64
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64


def test_short_prompts_jump_long_prompts_in_queue():
"""Verify large prefill requests are punted behind smaller ones if
another large prefill request is already running"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
cache_config.num_gpu_blocks = 3200
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add 2 large seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i),
prompt_length=1200, # Very large prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
assert seq_group.is_prefill()

# Add 2 small seq groups behind them
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i + 2),
prompt_length=40, # Very small prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
assert seq_group.is_prefill()

# Verify one large req and 1 small req chunked
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens
assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens

# all 4 are prefilling
assert running[0].is_prefill()
assert running[1].is_prefill()
assert running[2].is_prefill()
assert running[3].is_prefill()

assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64

# In the second iteration, the first small request finishes its remaining
# 40 - 32 = 8 prefill tokens and moves to decode, and the second small
# request is scheduled for the first time.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# the newly scheduled small request gets 64 - (32 + 8) = 24 tokens
assert seq_group_meta[0].token_chunk_size == 24
assert seq_group_meta[1].token_chunk_size == 32  # large req still gets 32
# the first small request prefills its last 40 - 32 = 8 tokens
assert seq_group_meta[2].token_chunk_size == 8

# note that the first small request has already reached decode;
# the max_num_partial_prefills logic let it jump ahead of the large requests
assert running[0].is_prefill()
assert running[1].is_prefill()
assert not running[2].is_prefill()
assert running[3].is_prefill()

assert out.num_prefill_groups == 3
assert out.num_batched_tokens == 64
# the small seq group has a new token appended.
append_new_token(running[2], 1)

# In the third iteration, the first small request decodes while the second
# small request finishes its remaining 40 - 24 = 16 prefill tokens and
# moves to decode as well.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 32  # large req still gets 32
# the second small request prefills its last 40 - 24 = 16 tokens
assert seq_group_meta[1].token_chunk_size == 16
assert seq_group_meta[2].token_chunk_size == 1  # decode
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 49 # (32+16+1 decode)

# both small requests have now reached decode
assert running[0].is_prefill()
assert running[1].is_prefill()
assert not running[2].is_prefill()
assert not running[3].is_prefill()

# the small seq group has a new token appended.
append_new_token(running[2], 1)

# in the fourth iteration, both small requests are decoding
# so large request gets all the budget
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# large req gets 63 tokens (minus 1 for decode)
assert seq_group_meta[0].token_chunk_size == 63
assert seq_group_meta[1].token_chunk_size == 1 # decode
Review comment on lines +309 to +312 from prashantgupta24 (Contributor), Nov 15, 2024:
Not sure if this is a bug, but at this stage, request#3 should be decoding, but it didn't get any budget. Request#2 got budget for 1 decode token, and request#0 got the remaining budget for prefilling 63 tokens. Is that expected?

Based on this comment,

        # Update new running requests.
        # By default, vLLM scheduler prioritizes prefills.
        # Once chunked prefill is enabled,
        # the policy is changed to prioritize decode requests.

vllm should have prioritized decode requests and given both request#2 and request#3 1 budget, and request#0 62?

prashantgupta24 (Contributor) replied, Nov 15, 2024:
This does happen in the next iteration though

assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64

assert running[0].is_prefill()
assert running[1].is_prefill()
assert not running[2].is_prefill()
assert not running[3].is_prefill()

# both the small seq groups have a new token appended
append_new_token(running[2], 1)
append_new_token(running[3], 1)

# in the fifth iteration, large request gets all the budget
# while both small requests are decoding
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 62
assert seq_group_meta[1].token_chunk_size == 1 # decode
assert seq_group_meta[2].token_chunk_size == 1 # decode
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64


def test_complex():
block_size = 4
max_seqs = 60
@@ -506,7 +717,7 @@ def test_chunked_prefill_max_seqs():
assert not running[1].is_prefill()


def test_perfix_caching():
def test_prefix_caching():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size = 4
max_seqs = 10
@@ -546,3 +757,86 @@ def test_perfix_caching():
assert seq_group_meta[1].token_chunk_size == 12
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 62


def test_prefix_caching_with_concurrent_partial_prefills():
"""Verify allocating full blocks when prefix caching is enabled with
--max-num-partial-prefills > 1."""
block_size = 4
max_seqs = 10
max_model_len = 8000
max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens
scheduler_config = SchedulerConfig("generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2)
cache_config = CacheConfig(block_size,
1.0,
1,
"auto",
enable_prefix_caching=True)
cache_config.num_cpu_blocks = 0
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
block_size=block_size,
prompt_length=50)
scheduler.add_seq_group(seq_group)
running.append(seq_group)

seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# To partially prefill both sequences, both can chunk up to 30 tokens
# But the next lowest multiple of the block size (4) is 28
assert seq_group_meta[0].token_chunk_size == 28
assert seq_group_meta[1].token_chunk_size == 28
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 56

# On the next iteration, both sequences should finish prefill
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# Both sequences have 50 - 28 = 22 tokens left to prefill.
# This is not a multiple of the block size, but we don't care since we don't
# cache the final partial block of prefix sequences
assert seq_group_meta[0].token_chunk_size == 22
assert seq_group_meta[1].token_chunk_size == 22
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 44


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
def test_chunked_prefill_with_actual_engine(model: str,
Review comment from the PR author (Contributor): cc @rickyyx here's what we tried to do to test that the sampler doesn't throw any assertions - we put multiple prompts into an engine and manually step it forward with them all partially prefilled

max_num_partial_prefills: int):
"""Make sure the model can actually sample with concurrent
partial prefills
"""

prompt = "hello" * 40

engine_args = EngineArgs(
model=model,
max_num_partial_prefills=max_num_partial_prefills,
max_num_batched_tokens=40,
max_num_seqs=8,
enable_chunked_prefill=True,
gpu_memory_utilization=0.8,
)

engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(temperature=0)

for req_num in range(max_num_partial_prefills):
engine.add_request(f"{req_num}", prompt, sampling_params)
# first step
request_outputs = engine.step()
# no outputs yet, meaning all requests are still prefilling
assert len(request_outputs) == 0
assert len(engine.scheduler[0].running) == max_num_partial_prefills