From f97eacfae851ca40b7acc353f1bb26562a62f9bc Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 6 Nov 2024 16:15:34 -0700 Subject: [PATCH 01/54] :bug: fix multi-chunked-prefill sampler bug Signed-off-by: Joe Runde --- vllm/model_executor/layers/sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c10efefea5471..ddc8aadb81d8f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -995,7 +995,9 @@ def get_logprobs( if len(query_indices) == 0: empty_sampled_logprob: SampleLogprobs = [] empty_prompt_logprob: Optional[PromptLogprobs] = None - return [empty_prompt_logprob], [empty_sampled_logprob] + num_seq_groups = len(sampling_metadata.seq_groups) + return [empty_prompt_logprob + ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups selected_logprobs, ranks = None, None top_logprobs, top_token_ids = None, None From b50a6b805fd4ac0ba2a4b0fa8ddda05c7de9cbcd Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 8 Nov 2024 14:36:47 -0800 Subject: [PATCH 02/54] =?UTF-8?q?=F0=9F=9A=A7=20add=20num=5Fprefill=5Fslot?= =?UTF-8?q?s=20arg?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/config.py | 2 ++ vllm/engine/arg_utils.py | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index e69cbd3eb402a..c0ef0a0595c19 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1092,6 +1092,7 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, + num_prefill_slots: int = 1, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, @@ -1132,6 +1133,7 @@ def __init__(self, ) self.max_num_batched_tokens = max_num_batched_tokens + self.num_prefill_slots = num_prefill_slots if enable_chunked_prefill: logger.info( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9288cd22c0036..9e57f99fc3c25 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -120,6 +120,7 @@ class EngineArgs: cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None + num_prefill_slots: Optional[int] = 1 max_num_seqs: int = 256 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False @@ -467,6 +468,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.max_num_batched_tokens, help='Maximum number of batched tokens per ' 'iteration.') + parser.add_argument( + "--num-prefill-slots", + type=int, + default=EngineArgs.num_prefill_slots, + help="For chunked prefill, the number of prefill slots to use. Defaults to 1", + ) parser.add_argument('--max-num-seqs', type=int, default=EngineArgs.max_num_seqs, @@ -1105,7 +1112,8 @@ def create_engine_config(self) -> VllmConfig: multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), - policy=self.scheduling_policy) + policy=self.scheduling_policy, + num_prefill_slots=self.num_prefill_slots) lora_config = LoRAConfig( bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, From 7f23c04340e77c7afe150485f4db4e17fa218c38 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 8 Nov 2024 15:41:17 -0700 Subject: [PATCH 03/54] :sparkles: start to write prefill slot logic Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 85 +++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index af4671ec29be9..8ffb936b6ce9c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -395,6 +395,23 @@ def __init__( # for processing and deallocation by the free_finished_seq_groups() self._async_stopped: List[SequenceGroup] = [] + # For chunked prefill, we allocate a set of "prefill slots" that + # each represent a sequence group that can be concurrently prefilled + self.num_prefill_slots = 2 # todo from config + self.prefill_slots_running = 0 + self.big_prefill_requests = 0 + # Requests with more than (4% max context length) tokens to prefill + # are "big" + self.big_prefill_threshold = scheduler_config.max_model_len // 25 + + # Dict cache with the chunk sizes to hand out to each sequence depending + # on how many prefill slots are used. + # This is just the full budget / number of prefill slots + self.prefill_chunk_sizes = { + slots: self.scheduler_config.max_num_batched_tokens / slots + for slots in range(self.num_prefill_slots) + } + @property def next_cache_id(self): return (self.cache_id + 1) % self.num_cache_iters @@ -1602,46 +1619,44 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, for seq in seqs: num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 - # Chunk if a running request cannot fit in the given budget. + + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > budget.remaining_token_budget() \ + else num_new_tokens + # If number of seq > 1, it means it is doing beam search # in a decode phase. Do not chunk. - if enable_chunking and len(seqs) == 1: + elif enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() - if self.scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens - elif self.cache_config.enable_prefix_caching: + chunk_size = remaining_token_budget + + if self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate # the number of new tokens that is dividable by the block # size to avoid partial block matching. block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size - else: - num_new_tokens = min(num_new_tokens, remaining_token_budget) + # Set chunk size to the next lowest multiple of block size + # so we don't exceed our budget + chunk_size = (chunk_size // block_size) * block_size + # NB: In the case where num_new_tokens < chunk_size, this does + # not allocate a multiple of `block_size` tokens. + + num_new_tokens = min(num_new_tokens, chunk_size) + return num_new_tokens From d271cc9e0da0a12f681db79821110e99e8b0a7fa Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 8 Nov 2024 14:46:37 -0800 Subject: [PATCH 04/54] =?UTF-8?q?=F0=9F=8E=A8=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/engine/arg_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9e57f99fc3c25..4b52f82e71eca 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -472,7 +472,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--num-prefill-slots", type=int, default=EngineArgs.num_prefill_slots, - help="For chunked prefill, the number of prefill slots to use. Defaults to 1", + help='For chunked prefill, the number of prefill slots to use. ' + 'Defaults to 1', ) parser.add_argument('--max-num-seqs', type=int, From b2cb96f957be8ab668e31d745dea7526b7e67edc Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 8 Nov 2024 16:01:11 -0700 Subject: [PATCH 05/54] :sparkles: update num tokens for prefill slots Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 8ffb936b6ce9c..2896d48bdd3d2 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1644,7 +1644,8 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, # in a decode phase. Do not chunk. elif enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() - chunk_size = remaining_token_budget + # Get the number of tokens to allocate to this prefill slot + prefill_slot_budget = self.prefill_chunk_sizes[self.prefill_slots_running] if self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate @@ -1653,10 +1654,10 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, block_size = self.cache_config.block_size # Set chunk size to the next lowest multiple of block size # so we don't exceed our budget - chunk_size = (chunk_size // block_size) * block_size + prefill_slot_budget = (prefill_slot_budget // block_size) * block_size # NB: In the case where num_new_tokens < chunk_size, this does # not allocate a multiple of `block_size` tokens. - num_new_tokens = min(num_new_tokens, chunk_size) + num_new_tokens = min(num_new_tokens, remaining_token_budget, prefill_slot_budget) return num_new_tokens From c349ac07057e666c447dcb4043dadaab343fc9a0 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 8 Nov 2024 15:34:03 -0800 Subject: [PATCH 06/54] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20add=20schedule=5Fchu?= =?UTF-8?q?nked=5Fprefill=20logic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 2896d48bdd3d2..f3b80f1f663a0 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -824,17 +824,17 @@ def _schedule_priority_preemption( SequenceStatus.WAITING, False, budget) - #Only preempt if priority inversion exists + # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - #Only preempt if waiting sequence cannot be allocated + # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens and can_allocate == AllocStatus.OK and budget.can_schedule(num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs)): break - #Adjust budget to remove the victim sequence group + # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens = self._get_num_new_tokens( vseq_group, SequenceStatus.RUNNING, False, budget) @@ -844,11 +844,11 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - #Preempt out the victim sequence group + # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - #Put the sequence back into the waiting queue + # Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -1141,6 +1141,16 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: curr_loras, enable_chunking=True) + prefilling = running_scheduled.prefill_seq_groups + prefills.seq_groups + + prefilling = [ + p for p in prefilling if self._will_still_be_prefilling(p) + ] + + # Set slot counts for next iteration + self.prefill_slots_running = len(prefilling) + self.big_prefill_requests = self._count_big(prefilling) + assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs @@ -1187,6 +1197,19 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: len(running_scheduled.swapped_out)), ) + def _count_big(self, + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> int: + return len([ + scheduled_seq_group for scheduled_seq_group in scheduled_seq_groups + if scheduled_seq_group.seq_group.seqs[0].get_num_new_tokens() >= + self.big_prefill_threshold + ]) + + def _will_still_be_prefilling(self, + seq_group: ScheduledSequenceGroup) -> bool: + return seq_group.token_chunk_size != seq_group.seq_group.seqs[ + 0].get_num_new_tokens() + def _schedule(self) -> SchedulerOutputs: """Schedule queued requests.""" if self.scheduler_config.chunked_prefill_enabled: @@ -1645,7 +1668,8 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, elif enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() # Get the number of tokens to allocate to this prefill slot - prefill_slot_budget = self.prefill_chunk_sizes[self.prefill_slots_running] + prefill_slot_budget = self.prefill_chunk_sizes[ + self.prefill_slots_running] if self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate @@ -1654,10 +1678,12 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, block_size = self.cache_config.block_size # Set chunk size to the next lowest multiple of block size # so we don't exceed our budget - prefill_slot_budget = (prefill_slot_budget // block_size) * block_size + prefill_slot_budget = (prefill_slot_budget // + block_size) * block_size # NB: In the case where num_new_tokens < chunk_size, this does # not allocate a multiple of `block_size` tokens. - num_new_tokens = min(num_new_tokens, remaining_token_budget, prefill_slot_budget) + num_new_tokens = min(num_new_tokens, remaining_token_budget, + prefill_slot_budget) return num_new_tokens From e20518d6bd5054664613143cda747abac4e80b45 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 8 Nov 2024 15:41:14 -0800 Subject: [PATCH 07/54] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20change=20function=20?= =?UTF-8?q?name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f3b80f1f663a0..953a86f96ff1d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -397,7 +397,7 @@ def __init__( # For chunked prefill, we allocate a set of "prefill slots" that # each represent a sequence group that can be concurrently prefilled - self.num_prefill_slots = 2 # todo from config + self.num_prefill_slots = self.num_prefill_slots self.prefill_slots_running = 0 self.big_prefill_requests = 0 # Requests with more than (4% max context length) tokens to prefill @@ -1149,7 +1149,10 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Set slot counts for next iteration self.prefill_slots_running = len(prefilling) - self.big_prefill_requests = self._count_big(prefilling) + self.big_prefill_requests = [ + seq_group for seq_group in prefilling + if self._is_big_seq_group(seq_group) + ] assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) @@ -1197,13 +1200,10 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: len(running_scheduled.swapped_out)), ) - def _count_big(self, - scheduled_seq_groups: List[ScheduledSequenceGroup]) -> int: - return len([ - scheduled_seq_group for scheduled_seq_group in scheduled_seq_groups - if scheduled_seq_group.seq_group.seqs[0].get_num_new_tokens() >= - self.big_prefill_threshold - ]) + def _is_big_seq_group(self, + scheduled_seq_group: ScheduledSequenceGroup) -> bool: + return scheduled_seq_group.seq_group.seqs[0].get_num_new_tokens( + ) >= self.big_prefill_threshold def _will_still_be_prefilling(self, seq_group: ScheduledSequenceGroup) -> bool: From 6ba0e34243c4d07cfcb9b806adf105777c63ce5b Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 8 Nov 2024 16:42:49 -0700 Subject: [PATCH 08/54] :sparkles: reserve incoming prefill slots Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 953a86f96ff1d..c355fb6362ec5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -403,6 +403,7 @@ def __init__( # Requests with more than (4% max context length) tokens to prefill # are "big" self.big_prefill_threshold = scheduler_config.max_model_len // 25 + self.max_big_requests = 1 # TODO: something # Dict cache with the chunk sizes to hand out to each sequence depending # on how many prefill slots are used. @@ -1125,6 +1126,12 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: prefills = SchedulerPrefillOutputs.create_empty() swapped_in = SchedulerSwappedInOutputs.create_empty() + # Before any scheduling, decide if we have prefill slots available + # to pull new requests from the waiting queue + if self.prefill_slots_running < self.num_prefill_slots and len( + self.waiting) > 0: + self._reserve_prefill_slots_from_waiting_queue() + # Decoding should be always scheduled first by fcfs. running_scheduled = self._schedule_running(budget, curr_loras, @@ -1210,6 +1217,22 @@ def _will_still_be_prefilling(self, return seq_group.token_chunk_size != seq_group.seq_group.seqs[ 0].get_num_new_tokens() + def _reserve_prefill_slots_from_waiting_queue(self): + # Increment self.num_slots_filled for each request in the waiting queue + # that we can fit into a slot + for seq_group in self.waiting: + # Don't fill more slots than we have + if self.prefill_slots_running >= self.num_prefill_slots: + break + + # Disallow multiple big requests + if self._is_big_seq_group(seq_group): + if self.big_prefill_requests >= self.max_big_requests: + continue + self.big_prefill_requests += 1 + + self.prefill_slots_running += 1 + def _schedule(self) -> SchedulerOutputs: """Schedule queued requests.""" if self.scheduler_config.chunked_prefill_enabled: From a7491cc84ef974e4e22a59ab3def2c593fd2b802 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 8 Nov 2024 15:49:44 -0800 Subject: [PATCH 09/54] =?UTF-8?q?=F0=9F=8E=A8=20fix=20some=20typos?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c355fb6362ec5..214aa5e968585 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -397,7 +397,7 @@ def __init__( # For chunked prefill, we allocate a set of "prefill slots" that # each represent a sequence group that can be concurrently prefilled - self.num_prefill_slots = self.num_prefill_slots + self.num_prefill_slots = scheduler_config.num_prefill_slots self.prefill_slots_running = 0 self.big_prefill_requests = 0 # Requests with more than (4% max context length) tokens to prefill @@ -1156,10 +1156,10 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Set slot counts for next iteration self.prefill_slots_running = len(prefilling) - self.big_prefill_requests = [ + self.big_prefill_requests = len([ seq_group for seq_group in prefilling - if self._is_big_seq_group(seq_group) - ] + if self._is_big_seq_group(seq_group.seq_group) + ]) assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) @@ -1207,10 +1207,9 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: len(running_scheduled.swapped_out)), ) - def _is_big_seq_group(self, - scheduled_seq_group: ScheduledSequenceGroup) -> bool: - return scheduled_seq_group.seq_group.seqs[0].get_num_new_tokens( - ) >= self.big_prefill_threshold + def _is_big_seq_group(self, seq_group: SequenceGroup) -> bool: + return (seq_group.seqs[0].get_num_new_tokens() >= + self.big_prefill_threshold) def _will_still_be_prefilling(self, seq_group: ScheduledSequenceGroup) -> bool: From 1ee6feab0cc5d06a2948a634e71e843c8f8b454f Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 8 Nov 2024 17:06:57 -0700 Subject: [PATCH 10/54] :zap: finish awesome scheduler Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 214aa5e968585..eae01077c7fd8 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -405,14 +405,6 @@ def __init__( self.big_prefill_threshold = scheduler_config.max_model_len // 25 self.max_big_requests = 1 # TODO: something - # Dict cache with the chunk sizes to hand out to each sequence depending - # on how many prefill slots are used. - # This is just the full budget / number of prefill slots - self.prefill_chunk_sizes = { - slots: self.scheduler_config.max_num_batched_tokens / slots - for slots in range(self.num_prefill_slots) - } - @property def next_cache_id(self): return (self.cache_id + 1) % self.num_cache_iters @@ -887,6 +879,15 @@ def _schedule_prefills( Returns: SchedulerPrefillOutputs. """ + if self.prefill_slots_running >= self.num_prefill_slots \ + or budget.remaining_token_budget() == 0: + # Do nothing: Can't add any more prefill anyway + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking)) + ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] @@ -900,6 +901,13 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") + + if self._is_big_seq_group( + seq_group + ) and self.big_prefill_requests >= self.max_big_requests: + # Cannot schedule more big requests than max_big_requests + break + num_new_tokens = self._get_num_new_tokens(seq_group, SequenceStatus.WAITING, enable_chunking, budget) @@ -1208,8 +1216,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: ) def _is_big_seq_group(self, seq_group: SequenceGroup) -> bool: - return (seq_group.seqs[0].get_num_new_tokens() >= - self.big_prefill_threshold) + return seq_group.seqs[0].get_num_new_tokens( + ) >= self.big_prefill_threshold def _will_still_be_prefilling(self, seq_group: ScheduledSequenceGroup) -> bool: @@ -1690,8 +1698,8 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, elif enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() # Get the number of tokens to allocate to this prefill slot - prefill_slot_budget = self.prefill_chunk_sizes[ - self.prefill_slots_running] + prefill_slot_budget = \ + budget.token_budget // self.prefill_slots_running if self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate From 517915aeb267c9d34abb0b699960d91025aa8dda Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 11 Nov 2024 09:49:28 -0700 Subject: [PATCH 11/54] :bug: fix the deadlocks Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index eae01077c7fd8..be40585582f73 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -2,7 +2,7 @@ import os import random import time -from collections import deque +from collections import defaultdict, deque from dataclasses import dataclass, field from typing import Callable, Deque, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence @@ -405,6 +405,17 @@ def __init__( self.big_prefill_threshold = scheduler_config.max_model_len // 25 self.max_big_requests = 1 # TODO: something + # Dict cache with the chunk sizes to hand out to each sequence depending + # on how many prefill slots are used. + # This is just the full budget / number of prefill slots + self.prefill_chunk_sizes = defaultdict( + lambda: scheduler_config.max_num_batched_tokens // self. + num_prefill_slots) + self.prefill_chunk_sizes[0] = scheduler_config.max_num_batched_tokens + for i in range(1, self.num_prefill_slots): + self.prefill_chunk_sizes[i] = \ + scheduler_config.max_num_batched_tokens // i + @property def next_cache_id(self): return (self.cache_id + 1) % self.num_cache_iters @@ -879,8 +890,7 @@ def _schedule_prefills( Returns: SchedulerPrefillOutputs. """ - if self.prefill_slots_running >= self.num_prefill_slots \ - or budget.remaining_token_budget() == 0: + if budget.remaining_token_budget() == 0: # Do nothing: Can't add any more prefill anyway return SchedulerPrefillOutputs( seq_groups=[], @@ -902,11 +912,10 @@ def _schedule_prefills( "Waiting sequence group should have only one prompt " "sequence.") - if self._is_big_seq_group( - seq_group - ) and self.big_prefill_requests >= self.max_big_requests: + is_big = self._is_big_seq_group(seq_group) + if is_big and self.big_prefill_requests >= self.max_big_requests: # Cannot schedule more big requests than max_big_requests - break + continue num_new_tokens = self._get_num_new_tokens(seq_group, SequenceStatus.WAITING, @@ -973,6 +982,9 @@ def _schedule_prefills( waiting_queue.popleft() self._allocate_and_set_running(seq_group) + if is_big: + self.big_prefill_requests += 1 + if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] # init_multi_step_from_lookahead_slots happens in append_slots @@ -1227,6 +1239,7 @@ def _will_still_be_prefilling(self, def _reserve_prefill_slots_from_waiting_queue(self): # Increment self.num_slots_filled for each request in the waiting queue # that we can fit into a slot + queued_big_requests = 0 for seq_group in self.waiting: # Don't fill more slots than we have if self.prefill_slots_running >= self.num_prefill_slots: @@ -1234,9 +1247,10 @@ def _reserve_prefill_slots_from_waiting_queue(self): # Disallow multiple big requests if self._is_big_seq_group(seq_group): - if self.big_prefill_requests >= self.max_big_requests: + if self.big_prefill_requests + queued_big_requests \ + >= self.max_big_requests: continue - self.big_prefill_requests += 1 + queued_big_requests += 1 self.prefill_slots_running += 1 @@ -1699,7 +1713,7 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, remaining_token_budget = budget.remaining_token_budget() # Get the number of tokens to allocate to this prefill slot prefill_slot_budget = \ - budget.token_budget // self.prefill_slots_running + self.prefill_chunk_sizes[self.prefill_slots_running] if self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate From ed298c32aff9822db305ad07aebfb738678add77 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 11 Nov 2024 10:46:24 -0700 Subject: [PATCH 12/54] :memo: Add more docstrings Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index be40585582f73..bcdf7be1c1f96 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -396,18 +396,29 @@ def __init__( self._async_stopped: List[SequenceGroup] = [] # For chunked prefill, we allocate a set of "prefill slots" that - # each represent a sequence group that can be concurrently prefilled + # each represent a sequence group that can be partially prefilled. + # Having multiple partial prefills in flight allows us to minimize TTFT + # and avoid decode starvation in cases where a single sequence group + # with a very large prompt blocks the queue for too many iterations. self.num_prefill_slots = scheduler_config.num_prefill_slots self.prefill_slots_running = 0 self.big_prefill_requests = 0 # Requests with more than (4% max context length) tokens to prefill - # are "big" + # are "big". + # The number of big prefill requests is limited so that smaller + # requests may jump the queue in front of them and get to the decode + # phase faster. self.big_prefill_threshold = scheduler_config.max_model_len // 25 self.max_big_requests = 1 # TODO: something # Dict cache with the chunk sizes to hand out to each sequence depending - # on how many prefill slots are used. - # This is just the full budget / number of prefill slots + # on how many prefill slots are used. This is slightly faster than + # running an integer division every time a prefill is scheduled. + # This splits the budget evenly among all prefill slots. + # We use a defaultdict here to handle the case where we prefill many + # more requests than prefill slots. This is normal when requests have + # very small prompts to prefill, and we want to give them each the same + # budget. self.prefill_chunk_sizes = defaultdict( lambda: scheduler_config.max_num_batched_tokens // self. num_prefill_slots) @@ -1228,17 +1239,25 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: ) def _is_big_seq_group(self, seq_group: SequenceGroup) -> bool: + """Simple heuristic to check if a sequence group needs a lot of prefill + work.""" return seq_group.seqs[0].get_num_new_tokens( ) >= self.big_prefill_threshold def _will_still_be_prefilling(self, seq_group: ScheduledSequenceGroup) -> bool: + """Check if a sequence will be mid-prefill after this iteration. + We need to know how many partial prefills will be running in order to + properly budget the next iteration.""" return seq_group.token_chunk_size != seq_group.seq_group.seqs[ 0].get_num_new_tokens() def _reserve_prefill_slots_from_waiting_queue(self): - # Increment self.num_slots_filled for each request in the waiting queue - # that we can fit into a slot + """Peek into the waiting queue to see how many requests we may be able + to start prefilling during this scheduling iteration. This allows us to + budget fewer tokens for currently running prefills if we know that more + requests from the queue will fit. + """ queued_big_requests = 0 for seq_group in self.waiting: # Don't fill more slots than we have From 90e0c07dcb67b84528226a09cbbf8190578c2e41 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 11 Nov 2024 13:37:53 -0700 Subject: [PATCH 13/54] :bug: fix deadlock Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 2 ++ vllm/model_executor/layers/sampler.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index bcdf7be1c1f96..301f81daecc6b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -926,6 +926,8 @@ def _schedule_prefills( is_big = self._is_big_seq_group(seq_group) if is_big and self.big_prefill_requests >= self.max_big_requests: # Cannot schedule more big requests than max_big_requests + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() continue num_new_tokens = self._get_num_new_tokens(seq_group, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index ddc8aadb81d8f..8792bd42d54d2 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1264,6 +1264,10 @@ def _build_sampler_output( assert sample_logprobs is not None assert not isinstance(maybe_deferred_sample_results, SampleResultArgsType) + assert len(sampling_metadata.seq_groups) \ + == len(maybe_deferred_sample_results) \ + == len(prompt_logprobs) \ + == len(sample_logprobs) deferred_sample_results_args = None for (seq_group, sample_result, group_prompt_logprobs, From 1c92ac2d1ec007737e2832cc5743a46be323486c Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 11 Nov 2024 16:38:39 -0700 Subject: [PATCH 14/54] :construction: WIP scheduler tests Signed-off-by: Joe Runde --- tests/core/test_chunked_prefill_scheduler.py | 130 +++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index acd82065ae457..d048d2b7d4737 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -121,6 +121,136 @@ def test_chunk(): assert out.num_batched_tokens == 57 +def test_concurrent_chunking(): + """Verify prefills are chunked properly when --num-prefill-slots is > 1""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + num_prefill_slots=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Verify both requests are chunked with half of max_num_batched_tokens each + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 32 + assert seq_group_meta[1].token_chunk_size == 32 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + + # After one iteration, both should have 60 - 32 = 28 tokens left to prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 28 + assert seq_group_meta[1].token_chunk_size == 28 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 56 + + +def test_concurrent_chunking_large_requests(): + """Verify large prefill requests are run one at a time""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + num_prefill_slots=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=1200, # Very large prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + + # Verify only a single request is chunked, and it gets all 64 tokens + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 1 + assert seq_group_meta[0].token_chunk_size == 64 + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + +def test_small_prompts_jump_big_prompts_in_queue(): + """Verify large prefill requests are punted behind smaller ones if + another large prefill request is already running""" + block_size = 4 + max_seqs = 60 + max_model_len = 2000 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + num_prefill_slots=2, # Up to 2 partial prefills at a time + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add 2 large seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=1200, # Very large prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Add 2 small seq groups behind them + for i in range(2): + _, seq_group = create_dummy_prompt( + str(i), + prompt_length=12, # Very small prompt + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Verify one large req and two small reqs chunked + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert len(get_sequence_groups(out)) == 3 + assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens + assert seq_group_meta[ + 1].token_chunk_size == 12 # both small reqs fit in remaining 32 tokens + assert seq_group_meta[2].token_chunk_size == 12 + assert out.num_prefill_groups == 3 + assert out.num_batched_tokens == 64 + + def test_complex(): block_size = 4 max_seqs = 60 From de95f628a15c520d6391dd3725a606de62f335d3 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 12 Nov 2024 09:43:22 -0700 Subject: [PATCH 15/54] :bug: fix prefix caching Signed-off-by: Joe Runde --- tests/core/test_chunked_prefill_scheduler.py | 23 +++++++--------- vllm/core/scheduler.py | 28 ++++++++++++-------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index d048d2b7d4737..1e4de76cda41d 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -181,8 +181,8 @@ def test_concurrent_chunking_large_requests(): num_prefill_slots=2, # Up to 2 partial prefills at a time ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 + cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests + cache_config.num_gpu_blocks = 3200 scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. @@ -217,10 +217,9 @@ def test_small_prompts_jump_big_prompts_in_queue(): num_prefill_slots=2, # Up to 2 partial prefills at a time ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 + cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests + cache_config.num_gpu_blocks = 3200 scheduler = Scheduler(scheduler_config, cache_config, None) - running: List[SequenceGroup] = [] # Add 2 large seq groups to scheduler. for i in range(2): @@ -229,26 +228,24 @@ def test_small_prompts_jump_big_prompts_in_queue(): prompt_length=1200, # Very large prompt block_size=block_size) scheduler.add_seq_group(seq_group) - running.append(seq_group) # Add 2 small seq groups behind them for i in range(2): _, seq_group = create_dummy_prompt( - str(i), + str(i + 2), prompt_length=12, # Very small prompt block_size=block_size) scheduler.add_seq_group(seq_group) - running.append(seq_group) # Verify one large req and two small reqs chunked seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(get_sequence_groups(out)) == 3 + # assert len(get_sequence_groups(out)) == 3 assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens - assert seq_group_meta[ - 1].token_chunk_size == 12 # both small reqs fit in remaining 32 tokens + # both small reqs fit in remaining 32 tokens + assert seq_group_meta[1].token_chunk_size == 12 assert seq_group_meta[2].token_chunk_size == 12 assert out.num_prefill_groups == 3 - assert out.num_batched_tokens == 64 + assert out.num_batched_tokens == 56 def test_complex(): @@ -597,7 +594,7 @@ def test_chunked_prefill_max_seqs(): assert not running[1].is_prefill() -def test_perfix_caching(): +def test_prefix_caching(): """Verify allocating full blocks when prefix caching is enabled.""" block_size = 4 max_seqs = 10 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 301f81daecc6b..1d59e062cf88f 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -924,8 +924,11 @@ def _schedule_prefills( "sequence.") is_big = self._is_big_seq_group(seq_group) - if is_big and self.big_prefill_requests >= self.max_big_requests: - # Cannot schedule more big requests than max_big_requests + if is_big and self.big_prefill_requests >= self.max_big_requests \ + and self.num_prefill_slots > 1: + # When we have more than one prefill slot, we limit the number + # of big requests to avoid filling all of our slots with partial + # prefills of big prompt sequences. leftover_waiting_sequences.appendleft(seq_group) waiting_queue.popleft() continue @@ -1737,16 +1740,19 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, self.prefill_chunk_sizes[self.prefill_slots_running] if self.cache_config.enable_prefix_caching: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block - # size to avoid partial block matching. + # When prefix caching is enabled and we're partially prefilling + # a sequence, we always allocate a number of new tokens that is + # divisible by the block size to avoid partial block matching. block_size = self.cache_config.block_size - # Set chunk size to the next lowest multiple of block size - # so we don't exceed our budget - prefill_slot_budget = (prefill_slot_budget // - block_size) * block_size - # NB: In the case where num_new_tokens < chunk_size, this does - # not allocate a multiple of `block_size` tokens. + # Don't exceed either the total budget or slot budget. + # Take min of those and get the next lowest multiple of the + # block size: + remaining_token_budget = \ + (min(remaining_token_budget, prefill_slot_budget) // + block_size) * block_size + # NB: In the case where num_new_tokens < budget, we are + # finishing prefill for this sequence, so we do not need to + # allocate a full block. num_new_tokens = min(num_new_tokens, remaining_token_budget, prefill_slot_budget) From 41e20ca853a962870e0b9002a913cf6912772ad4 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 12 Nov 2024 09:55:21 -0700 Subject: [PATCH 16/54] :test_tube: add prefix caching test Signed-off-by: Joe Runde --- tests/core/test_chunked_prefill_scheduler.py | 52 ++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 1e4de76cda41d..cf3a49f73fa5f 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -634,3 +634,55 @@ def test_prefix_caching(): assert seq_group_meta[1].token_chunk_size == 12 assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 62 + + +def test_prefix_caching_with_concurrent_partial_prefills(): + """Verify allocating full blocks when prefix caching is enabled with + --num-prefill-slots > 1.""" + block_size = 4 + max_seqs = 10 + max_model_len = 8000 + max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + num_prefill_slots=2) + cache_config = CacheConfig(block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=True) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + block_size=block_size, + prompt_length=50) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # To partially prefill both sequences, both can chunk up to 30 tokens + # But the next lowest multiple of the block size (4) is 28 + assert seq_group_meta[0].token_chunk_size == 28 + assert seq_group_meta[1].token_chunk_size == 28 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 56 + + # On the next iteration, both sequences should finish prefill + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + # Both sequences have 50 - 28 = 22 tokens left to prefill. + # This is not a multiple of the block size, but we don't care since we don't + # cache the final partial block of prefix sequences + assert seq_group_meta[0].token_chunk_size == 22 + assert seq_group_meta[1].token_chunk_size == 22 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 44 From 4dc7310bde0c890a0425d780141279bfa568c0fe Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 12 Nov 2024 12:27:16 -0800 Subject: [PATCH 17/54] =?UTF-8?q?=E2=9C=85=20add=20second=20test=20iterati?= =?UTF-8?q?on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index cf3a49f73fa5f..19a16b6b7f9bb 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -247,6 +247,14 @@ def test_small_prompts_jump_big_prompts_in_queue(): assert out.num_prefill_groups == 3 assert out.num_batched_tokens == 56 + # in the second iteration, both small requests are completed + # so large request gets all the budget + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[ + 0].token_chunk_size == 64 # large req gets all tokens now + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + def test_complex(): block_size = 4 From 8e3118e11f63d3776d2b4ca3a853a8bdc233788f Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 12 Nov 2024 15:38:35 -0800 Subject: [PATCH 18/54] =?UTF-8?q?=E2=9C=85=20add=20llm=20engine=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 31 ++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 19a16b6b7f9bb..dfafb886eebe7 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -5,6 +5,9 @@ from vllm.config import CacheConfig, SchedulerConfig from vllm.core.scheduler import Scheduler +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob, SequenceGroup from .utils import create_dummy_prompt @@ -694,3 +697,31 @@ def test_prefix_caching_with_concurrent_partial_prefills(): assert seq_group_meta[1].token_chunk_size == 22 assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 44 + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("num_prefill_slots", [2, 4, 8]) +def test_chunked_prefill_with_actual_engine(model: str, + num_prefill_slots: int): + + prompt = "hello" * 40 + + engine_args = EngineArgs( + model=model, + num_prefill_slots=num_prefill_slots, + max_num_batched_tokens=40, + max_num_seqs=8, + enable_chunked_prefill=True, + gpu_memory_utilization=0.8, + ) + + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams(temperature=0) + + for req_num in range(num_prefill_slots): + engine.add_request(f"{req_num}", prompt, sampling_params) + # first step + request_outputs = engine.step() + # means all are prefilling + assert len(request_outputs) == 0 + assert len(engine.scheduler[0].running) == num_prefill_slots From b6ebec87e85419728509071afb3a00aa491499e9 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 12 Nov 2024 15:38:52 -0800 Subject: [PATCH 19/54] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20quicker=20budget=20c?= =?UTF-8?q?heck?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1d59e062cf88f..f1ad50e7de6a4 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -936,6 +936,13 @@ def _schedule_prefills( num_new_tokens = self._get_num_new_tokens(seq_group, SequenceStatus.WAITING, enable_chunking, budget) + + num_new_seqs = seq_group.get_max_num_running_seqs() + # quick budget check + if num_new_tokens == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs): + break + if not enable_chunking: num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens @@ -986,12 +993,6 @@ def _schedule_prefills( waiting_queue.popleft() continue - num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): - break - # Can schedule this request. if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) From 7e936683234130843e870e3f075532f27203fe42 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 10:01:50 -0800 Subject: [PATCH 20/54] =?UTF-8?q?=F0=9F=8E=A8=20rename=20to=20max=5Fnum=5F?= =?UTF-8?q?partial=5Fprefills?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 18 +++++++++--------- vllm/config.py | 4 ++-- vllm/core/scheduler.py | 13 +++++++------ vllm/engine/arg_utils.py | 12 ++++++------ 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index dfafb886eebe7..31e706b6cd7d8 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -136,7 +136,7 @@ def test_concurrent_chunking(): max_seqs, max_model_len, enable_chunked_prefill=True, - num_prefill_slots=2, # Up to 2 partial prefills at a time + max_num_partial_prefills=2, # Up to 2 partial prefills at a time ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 32 @@ -181,7 +181,7 @@ def test_concurrent_chunking_large_requests(): max_seqs, max_model_len, enable_chunked_prefill=True, - num_prefill_slots=2, # Up to 2 partial prefills at a time + max_num_partial_prefills=2, # Up to 2 partial prefills at a time ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests @@ -217,7 +217,7 @@ def test_small_prompts_jump_big_prompts_in_queue(): max_seqs, max_model_len, enable_chunked_prefill=True, - num_prefill_slots=2, # Up to 2 partial prefills at a time + max_num_partial_prefills=2, # Up to 2 partial prefills at a time ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests @@ -659,7 +659,7 @@ def test_prefix_caching_with_concurrent_partial_prefills(): max_seqs, max_model_len, enable_chunked_prefill=True, - num_prefill_slots=2) + max_num_partial_prefills=2) cache_config = CacheConfig(block_size, 1.0, 1, @@ -700,15 +700,15 @@ def test_prefix_caching_with_concurrent_partial_prefills(): @pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("num_prefill_slots", [2, 4, 8]) +@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8]) def test_chunked_prefill_with_actual_engine(model: str, - num_prefill_slots: int): + max_num_partial_prefills: int): prompt = "hello" * 40 engine_args = EngineArgs( model=model, - num_prefill_slots=num_prefill_slots, + max_num_partial_prefills=max_num_partial_prefills, max_num_batched_tokens=40, max_num_seqs=8, enable_chunked_prefill=True, @@ -718,10 +718,10 @@ def test_chunked_prefill_with_actual_engine(model: str, engine = LLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(temperature=0) - for req_num in range(num_prefill_slots): + for req_num in range(max_num_partial_prefills): engine.add_request(f"{req_num}", prompt, sampling_params) # first step request_outputs = engine.step() # means all are prefilling assert len(request_outputs) == 0 - assert len(engine.scheduler[0].running) == num_prefill_slots + assert len(engine.scheduler[0].running) == max_num_partial_prefills diff --git a/vllm/config.py b/vllm/config.py index c0ef0a0595c19..02ec67bed19b6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1092,7 +1092,7 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - num_prefill_slots: int = 1, + max_num_partial_prefills: int = 1, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, @@ -1133,7 +1133,7 @@ def __init__(self, ) self.max_num_batched_tokens = max_num_batched_tokens - self.num_prefill_slots = num_prefill_slots + self.max_num_partial_prefills = max_num_partial_prefills if enable_chunked_prefill: logger.info( diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f1ad50e7de6a4..a0d5dc03fd321 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -400,7 +400,8 @@ def __init__( # Having multiple partial prefills in flight allows us to minimize TTFT # and avoid decode starvation in cases where a single sequence group # with a very large prompt blocks the queue for too many iterations. - self.num_prefill_slots = scheduler_config.num_prefill_slots + self.max_num_partial_prefills = \ + scheduler_config.max_num_partial_prefills self.prefill_slots_running = 0 self.big_prefill_requests = 0 # Requests with more than (4% max context length) tokens to prefill @@ -421,9 +422,9 @@ def __init__( # budget. self.prefill_chunk_sizes = defaultdict( lambda: scheduler_config.max_num_batched_tokens // self. - num_prefill_slots) + max_num_partial_prefills) self.prefill_chunk_sizes[0] = scheduler_config.max_num_batched_tokens - for i in range(1, self.num_prefill_slots): + for i in range(1, self.max_num_partial_prefills): self.prefill_chunk_sizes[i] = \ scheduler_config.max_num_batched_tokens // i @@ -925,7 +926,7 @@ def _schedule_prefills( is_big = self._is_big_seq_group(seq_group) if is_big and self.big_prefill_requests >= self.max_big_requests \ - and self.num_prefill_slots > 1: + and self.max_num_partial_prefills > 1: # When we have more than one prefill slot, we limit the number # of big requests to avoid filling all of our slots with partial # prefills of big prompt sequences. @@ -1165,7 +1166,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Before any scheduling, decide if we have prefill slots available # to pull new requests from the waiting queue - if self.prefill_slots_running < self.num_prefill_slots and len( + if self.prefill_slots_running < self.max_num_partial_prefills and len( self.waiting) > 0: self._reserve_prefill_slots_from_waiting_queue() @@ -1267,7 +1268,7 @@ def _reserve_prefill_slots_from_waiting_queue(self): queued_big_requests = 0 for seq_group in self.waiting: # Don't fill more slots than we have - if self.prefill_slots_running >= self.num_prefill_slots: + if self.prefill_slots_running >= self.max_num_partial_prefills: break # Disallow multiple big requests diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4b52f82e71eca..8be768cc14bc1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -120,7 +120,7 @@ class EngineArgs: cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None - num_prefill_slots: Optional[int] = 1 + max_num_partial_prefills: Optional[int] = 1 max_num_seqs: int = 256 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False @@ -469,11 +469,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='Maximum number of batched tokens per ' 'iteration.') parser.add_argument( - "--num-prefill-slots", + "--max-num-partial-prefills", type=int, - default=EngineArgs.num_prefill_slots, - help='For chunked prefill, the number of prefill slots to use. ' - 'Defaults to 1', + default=EngineArgs.max_num_partial_prefills, + help="For chunked prefill, the number of prefill slots to use. " + "Defaults to 1", ) parser.add_argument('--max-num-seqs', type=int, @@ -1114,7 +1114,7 @@ def create_engine_config(self) -> VllmConfig: send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, - num_prefill_slots=self.num_prefill_slots) + max_num_partial_prefills=self.max_num_partial_prefills) lora_config = LoRAConfig( bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, From 557bfe3c392c43317bd7cbbc7943d0ac9e95450e Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 10:15:32 -0800 Subject: [PATCH 21/54] =?UTF-8?q?=F0=9F=8E=A8=20more=20renaming=20to=20max?= =?UTF-8?q?=5Fnum=5Fpartial=5Fprefills=20+=20docstring=20updates?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 4 ++-- vllm/core/scheduler.py | 13 +++++++------ vllm/engine/arg_utils.py | 3 ++- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 31e706b6cd7d8..48f7137a022c7 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -125,7 +125,7 @@ def test_chunk(): def test_concurrent_chunking(): - """Verify prefills are chunked properly when --num-prefill-slots is > 1""" + """Verify prefills are chunked properly when --max-num-partial-prefills is > 1""" block_size = 4 max_seqs = 60 max_model_len = 2000 @@ -649,7 +649,7 @@ def test_prefix_caching(): def test_prefix_caching_with_concurrent_partial_prefills(): """Verify allocating full blocks when prefix caching is enabled with - --num-prefill-slots > 1.""" + --max-num-partial-prefills > 1.""" block_size = 4 max_seqs = 10 max_model_len = 8000 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a0d5dc03fd321..6b4fa9f7defe2 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -395,8 +395,8 @@ def __init__( # for processing and deallocation by the free_finished_seq_groups() self._async_stopped: List[SequenceGroup] = [] - # For chunked prefill, we allocate a set of "prefill slots" that - # each represent a sequence group that can be partially prefilled. + # For chunked prefill, we allow a certain number of seqs + # to be partially prefilled. # Having multiple partial prefills in flight allows us to minimize TTFT # and avoid decode starvation in cases where a single sequence group # with a very large prompt blocks the queue for too many iterations. @@ -1164,11 +1164,12 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: prefills = SchedulerPrefillOutputs.create_empty() swapped_in = SchedulerSwappedInOutputs.create_empty() - # Before any scheduling, decide if we have prefill slots available - # to pull new requests from the waiting queue + # Before any scheduling, look at the reqests in the waiting queue. + # We may decide to budget fewer tokens for running prefills if there are + # requests in the queue we want to prefill concurrently if self.prefill_slots_running < self.max_num_partial_prefills and len( self.waiting) > 0: - self._reserve_prefill_slots_from_waiting_queue() + self._count_prefills_in_waiting_queue() # Decoding should be always scheduled first by fcfs. running_scheduled = self._schedule_running(budget, @@ -1259,7 +1260,7 @@ def _will_still_be_prefilling(self, return seq_group.token_chunk_size != seq_group.seq_group.seqs[ 0].get_num_new_tokens() - def _reserve_prefill_slots_from_waiting_queue(self): + def _count_prefills_in_waiting_queue(self): """Peek into the waiting queue to see how many requests we may be able to start prefilling during this scheduling iteration. This allows us to budget fewer tokens for currently running prefills if we know that more diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8be768cc14bc1..11f65835325fd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -472,7 +472,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--max-num-partial-prefills", type=int, default=EngineArgs.max_num_partial_prefills, - help="For chunked prefill, the number of prefill slots to use. " + help= + "For chunked prefill, the max number of concurrent partial prefills." "Defaults to 1", ) parser.add_argument('--max-num-seqs', From d3e94df91de9136ee1c7416d3e233945a920a6ba Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 10:19:34 -0800 Subject: [PATCH 22/54] =?UTF-8?q?=F0=9F=8E=A8=20rename=20big=20to=20long?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 2 +- vllm/core/scheduler.py | 42 ++++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 48f7137a022c7..bd3939cc7785e 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -204,7 +204,7 @@ def test_concurrent_chunking_large_requests(): assert out.num_batched_tokens == 64 -def test_small_prompts_jump_big_prompts_in_queue(): +def test_short_prompts_jump_long_prompts_in_queue(): """Verify large prefill requests are punted behind smaller ones if another large prefill request is already running""" block_size = 4 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6b4fa9f7defe2..6b74a7d841171 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -403,14 +403,14 @@ def __init__( self.max_num_partial_prefills = \ scheduler_config.max_num_partial_prefills self.prefill_slots_running = 0 - self.big_prefill_requests = 0 + self.long_prefill_requests = 0 # Requests with more than (4% max context length) tokens to prefill - # are "big". - # The number of big prefill requests is limited so that smaller + # are "long". + # The number of long prefill requests is limited so that smaller # requests may jump the queue in front of them and get to the decode # phase faster. - self.big_prefill_threshold = scheduler_config.max_model_len // 25 - self.max_big_requests = 1 # TODO: something + self.long_prefill_threshold = scheduler_config.max_model_len // 25 + self.max_long_requests = 1 # TODO: something # Dict cache with the chunk sizes to hand out to each sequence depending # on how many prefill slots are used. This is slightly faster than @@ -924,12 +924,12 @@ def _schedule_prefills( "Waiting sequence group should have only one prompt " "sequence.") - is_big = self._is_big_seq_group(seq_group) - if is_big and self.big_prefill_requests >= self.max_big_requests \ + is_long = self._is_long_seq_group(seq_group) + if is_long and self.long_prefill_requests >= self.max_long_requests \ and self.max_num_partial_prefills > 1: # When we have more than one prefill slot, we limit the number - # of big requests to avoid filling all of our slots with partial - # prefills of big prompt sequences. + # of long requests to avoid filling all of our slots with partial + # prefills of long prompt sequences. leftover_waiting_sequences.appendleft(seq_group) waiting_queue.popleft() continue @@ -1000,8 +1000,8 @@ def _schedule_prefills( waiting_queue.popleft() self._allocate_and_set_running(seq_group) - if is_big: - self.big_prefill_requests += 1 + if is_long: + self.long_prefill_requests += 1 if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] @@ -1195,9 +1195,9 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Set slot counts for next iteration self.prefill_slots_running = len(prefilling) - self.big_prefill_requests = len([ + self.long_prefill_requests = len([ seq_group for seq_group in prefilling - if self._is_big_seq_group(seq_group.seq_group) + if self._is_long_seq_group(seq_group.seq_group) ]) assert (budget.num_batched_tokens <= @@ -1246,11 +1246,11 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: len(running_scheduled.swapped_out)), ) - def _is_big_seq_group(self, seq_group: SequenceGroup) -> bool: + def _is_long_seq_group(self, seq_group: SequenceGroup) -> bool: """Simple heuristic to check if a sequence group needs a lot of prefill work.""" return seq_group.seqs[0].get_num_new_tokens( - ) >= self.big_prefill_threshold + ) >= self.long_prefill_threshold def _will_still_be_prefilling(self, seq_group: ScheduledSequenceGroup) -> bool: @@ -1266,18 +1266,18 @@ def _count_prefills_in_waiting_queue(self): budget fewer tokens for currently running prefills if we know that more requests from the queue will fit. """ - queued_big_requests = 0 + queued_long_requests = 0 for seq_group in self.waiting: # Don't fill more slots than we have if self.prefill_slots_running >= self.max_num_partial_prefills: break - # Disallow multiple big requests - if self._is_big_seq_group(seq_group): - if self.big_prefill_requests + queued_big_requests \ - >= self.max_big_requests: + # Disallow multiple long requests + if self._is_long_seq_group(seq_group): + if self.long_prefill_requests + queued_long_requests \ + >= self.max_long_requests: continue - queued_big_requests += 1 + queued_long_requests += 1 self.prefill_slots_running += 1 From 849baf6877c670e822429eff2591021f77663afb Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 10:35:37 -0800 Subject: [PATCH 23/54] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20add=20cli=20args=20f?= =?UTF-8?q?or=20partial=5Fprefill=20configs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/engine/arg_utils.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 11f65835325fd..086eb583c56d9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -121,6 +121,8 @@ class EngineArgs: gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_partial_prefills: Optional[int] = 1 + max_long_partial_prefills: Optional[int] = 1 + long_prefill_threshold: Optional[float] = 0.04 max_num_seqs: int = 256 max_logprobs: int = 20 # Default value for OpenAI Chat Completions API disable_log_stats: bool = False @@ -476,6 +478,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "For chunked prefill, the max number of concurrent partial prefills." "Defaults to 1", ) + parser.add_argument( + "--max-long-partial-prefills", + type=int, + default=EngineArgs.max_long_partial_prefills, + help="For chunked prefill, the max number of long concurrent partial prefills. \ + The length is determined by the long_prefill_threshold argument below" + "Defaults to 1", + ) + parser.add_argument( + "--long-prefill-threshold", + type=float, + default=EngineArgs.long_prefill_threshold, + help="For chunked prefill, a request is considered long if the prompt is longer than \ + the max_model_length * long_prefill_threshold. Defaults to 0.04%", + ) parser.add_argument('--max-num-seqs', type=int, default=EngineArgs.max_num_seqs, @@ -1115,7 +1132,9 @@ def create_engine_config(self) -> VllmConfig: send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, - max_num_partial_prefills=self.max_num_partial_prefills) + max_num_partial_prefills=self.max_num_partial_prefills, + max_long_partial_prefills=self.max_long_partial_prefills, + long_prefill_threshold=self.long_prefill_threshold) lora_config = LoRAConfig( bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, From beaf0868a15abfd00a87cce050605170369b046a Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 10:36:44 -0800 Subject: [PATCH 24/54] =?UTF-8?q?=F0=9F=8E=A8=20fix=20request=20word=20typ?= =?UTF-8?q?o?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6b74a7d841171..e4d513e517593 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1164,7 +1164,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: prefills = SchedulerPrefillOutputs.create_empty() swapped_in = SchedulerSwappedInOutputs.create_empty() - # Before any scheduling, look at the reqests in the waiting queue. + # Before any scheduling, look at the requests in the waiting queue. # We may decide to budget fewer tokens for running prefills if there are # requests in the queue we want to prefill concurrently if self.prefill_slots_running < self.max_num_partial_prefills and len( From 672a50c5c713ab5e37c2bd42e000b050784cda8a Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 11:02:26 -0800 Subject: [PATCH 25/54] =?UTF-8?q?=F0=9F=8E=A8=20more=20docstring=20changes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 3 +- vllm/core/scheduler.py | 34 +++++++++----------- vllm/engine/arg_utils.py | 14 ++++---- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index bd3939cc7785e..a2a534da1c9d2 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -125,7 +125,8 @@ def test_chunk(): def test_concurrent_chunking(): - """Verify prefills are chunked properly when --max-num-partial-prefills is > 1""" + """Verify prefills are chunked properly when + --max-num-partial-prefills is > 1""" block_size = 4 max_seqs = 60 max_model_len = 2000 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index e4d513e517593..f0bfb1712f22d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -2,7 +2,7 @@ import os import random import time -from collections import defaultdict, deque +from collections import deque from dataclasses import dataclass, field from typing import Callable, Deque, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence @@ -412,20 +412,16 @@ def __init__( self.long_prefill_threshold = scheduler_config.max_model_len // 25 self.max_long_requests = 1 # TODO: something - # Dict cache with the chunk sizes to hand out to each sequence depending - # on how many prefill slots are used. This is slightly faster than + # List with the chunk sizes to hand out to each sequence depending + # on how many partial prefills are running. This is slightly faster than # running an integer division every time a prefill is scheduled. - # This splits the budget evenly among all prefill slots. - # We use a defaultdict here to handle the case where we prefill many - # more requests than prefill slots. This is normal when requests have - # very small prompts to prefill, and we want to give them each the same - # budget. - self.prefill_chunk_sizes = defaultdict( - lambda: scheduler_config.max_num_batched_tokens // self. - max_num_partial_prefills) - self.prefill_chunk_sizes[0] = scheduler_config.max_num_batched_tokens + # This splits the budget evenly among all prefills. + self.partial_prefill_budget_lookup_list = [0] * ( + self.max_num_partial_prefills + 1) + self.partial_prefill_budget_lookup_list[ + 0] = scheduler_config.max_num_batched_tokens for i in range(1, self.max_num_partial_prefills): - self.prefill_chunk_sizes[i] = \ + self.partial_prefill_budget_lookup_list[i] = \ scheduler_config.max_num_batched_tokens // i @property @@ -925,11 +921,13 @@ def _schedule_prefills( "sequence.") is_long = self._is_long_seq_group(seq_group) - if is_long and self.long_prefill_requests >= self.max_long_requests \ + if is_long \ + and self.long_prefill_requests >= self.max_long_requests \ and self.max_num_partial_prefills > 1: - # When we have more than one prefill slot, we limit the number - # of long requests to avoid filling all of our slots with partial - # prefills of long prompt sequences. + # When concurrent partial prefills are enabled, + # we limit the number of long requests and only accept + # shorter requests from the queue while running them + # concurrently leftover_waiting_sequences.appendleft(seq_group) waiting_queue.popleft() continue @@ -1740,7 +1738,7 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, remaining_token_budget = budget.remaining_token_budget() # Get the number of tokens to allocate to this prefill slot prefill_slot_budget = \ - self.prefill_chunk_sizes[self.prefill_slots_running] + self.partial_prefill_budget_lookup_list[self.prefill_slots_running] if self.cache_config.enable_prefix_caching: # When prefix caching is enabled and we're partially prefilling diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 086eb583c56d9..38e424fbc2e04 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -474,24 +474,26 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--max-num-partial-prefills", type=int, default=EngineArgs.max_num_partial_prefills, - help= - "For chunked prefill, the max number of concurrent partial prefills." + help="For chunked prefill, the max number of concurrent \ + partial prefills." "Defaults to 1", ) parser.add_argument( "--max-long-partial-prefills", type=int, default=EngineArgs.max_long_partial_prefills, - help="For chunked prefill, the max number of long concurrent partial prefills. \ - The length is determined by the long_prefill_threshold argument below" + help="For chunked prefill, the max number of long concurrent " + "partial prefills. The length is determined by the " + "long-prefill-threshold argument. " "Defaults to 1", ) parser.add_argument( "--long-prefill-threshold", type=float, default=EngineArgs.long_prefill_threshold, - help="For chunked prefill, a request is considered long if the prompt is longer than \ - the max_model_length * long_prefill_threshold. Defaults to 0.04%", + help="For chunked prefill, a request is considered long " + "if the prompt is longer than the " + "max_model_length * long_prefill_threshold. Defaults to 0.04%", ) parser.add_argument('--max-num-seqs', type=int, From a2751ff28ab0085ba87f5a24e300912506e0622d Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 11:16:03 -0800 Subject: [PATCH 26/54] =?UTF-8?q?=F0=9F=8E=A8=20forgot=20to=20add=20the=20?= =?UTF-8?q?new=20args=20to=20config?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 02ec67bed19b6..b0f6485d2c641 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1093,6 +1093,8 @@ def __init__(self, max_num_seqs: int, max_model_len: int, max_num_partial_prefills: int = 1, + max_long_partial_prefills: int = 1, + long_prefill_threshold: float = 0.04, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, @@ -1134,6 +1136,8 @@ def __init__(self, self.max_num_batched_tokens = max_num_batched_tokens self.max_num_partial_prefills = max_num_partial_prefills + self.max_long_partial_prefills = max_long_partial_prefills + self.long_prefill_threshold = long_prefill_threshold if enable_chunked_prefill: logger.info( From dff757d8028b39103dda2cdfcc04816fb90b2da8 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 11:23:33 -0800 Subject: [PATCH 27/54] =?UTF-8?q?=F0=9F=90=9B=20fix=20range=20bug=20on=20p?= =?UTF-8?q?artial=5Fprefill=5Fbudget=5Flookup=5Flist?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f0bfb1712f22d..77c1bc4b8d347 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -420,7 +420,7 @@ def __init__( self.max_num_partial_prefills + 1) self.partial_prefill_budget_lookup_list[ 0] = scheduler_config.max_num_batched_tokens - for i in range(1, self.max_num_partial_prefills): + for i in range(1, self.max_num_partial_prefills + 1): self.partial_prefill_budget_lookup_list[i] = \ scheduler_config.max_num_batched_tokens // i From 86ffa0456161f30a7f9f3f0d849b148529899299 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 11:25:03 -0800 Subject: [PATCH 28/54] =?UTF-8?q?=F0=9F=8E=A8=20add=20docstring=20to=20tes?= =?UTF-8?q?t=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index a2a534da1c9d2..b14380e69ccaa 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -704,6 +704,9 @@ def test_prefix_caching_with_concurrent_partial_prefills(): @pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8]) def test_chunked_prefill_with_actual_engine(model: str, max_num_partial_prefills: int): + """Make sure the model can actually sample with concurrent + partial prefills + """ prompt = "hello" * 40 From 3d399425c79d197c1920d39f9886c76fc7cfea2d Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 13 Nov 2024 12:59:17 -0700 Subject: [PATCH 29/54] :construction: WIP move metadata to dataclass Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 92 +++++++++++++++++++++++++++++++++++++- vllm/entrypoints/logger.py | 12 ++--- 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 77c1bc4b8d347..47b1cef81f478 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -14,7 +14,7 @@ from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceGroupMetadataDelta, + SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStage, SequenceStatus) from vllm.utils import Device, PyObjectCache @@ -294,6 +294,74 @@ def scheduled_seq_group_builder(): token_chunk_size=0) # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) +# @dataclass +# class PartialPrefillConfig: +# max_num_partial_prefills: int +# max_long_partial_prefills: int +# long_prefill_threshold: int +# # Default this list to empty +# partial_prefill_budget_lookup_list: List[int] = field(default_factory=list) + +# def __post_init__(self): +# # Initialize partial_prefill_budget_lookup_list here +# # List with the chunk sizes to hand out to each sequence depending +# # on how many partial prefills are running. This is slightly faster than +# # running an integer division every time a prefill is scheduled. +# # This splits the budget evenly among all prefills. +# self.partial_prefill_budget_lookup_list = [0] * ( +# self.max_num_partial_prefills + 1) +# self.partial_prefill_budget_lookup_list[ +# 0] = scheduler_config.max_num_batched_tokens +# for i in range(1, self.max_num_partial_prefills + 1): +# self.partial_prefill_budget_lookup_list[i] = \ +# scheduler_config.max_num_batched_tokens // i + + +@dataclass +class PartialPrefillMetadata: + """Holds information about the partial prefills that are currently running.""" + partial_prefills: int + long_partial_prefills: int + + waiting_partial_prefills: int + waiting_long_partial_prefills: int + + def from_queues(running: Deque[SequenceGroup], + waiting: Deque[SequenceGroup], + long_prefill_threshold: int, + max_partial_prefills: int, + max_long_prefills: int) -> "PartialPrefillMetadata": + """Create a PartialPrefillMetadata object from the running queue.""" + partial_prefills = 0 + long_partial_prefills = 0 + + waiting_partial_prefills = 0 + waiting_long_prefills = 0 + + for sg in running: + # TODO: Check if this stage is correctly updated before scheduling + if sg.first_seq.data.stage == SequenceStage.PREFILL: + partial_prefills += 1 + if sg.first_seq.get_num_new_tokens() > long_prefill_threshold: + long_partial_prefills += 1 + + for sg in waiting: + # Don't bother looping through the rest of the queue if we know there are already at least max_partial_prefills requests to fill + if partial_prefills + waiting_partial_prefills >= max_partial_prefills: + break + + # Disallow multiple long requests + if sg.first_seq.get_num_new_tokens() > long_prefill_threshold: + if long_partial_prefills + waiting_long_prefills >= max_long_prefills: + continue + waiting_long_prefills += 1 + waiting_partial_prefills += 1 + + return PartialPrefillMetadata(partial_prefills, + long_partial_prefills, + waiting_partial_prefills, + waiting_long_prefills) + class Scheduler: @@ -423,6 +491,22 @@ def __init__( for i in range(1, self.max_num_partial_prefills + 1): self.partial_prefill_budget_lookup_list[i] = \ scheduler_config.max_num_batched_tokens // i + + @dataclass + class PartialPrefillConfig: + max_num_partial_prefills: int + max_long_partial_prefills: int + long_prefill_threshold: int + partial_prefill_budget_lookup_list: list + + + @dataclass + class PartialPrefillMetadata: + partial_prefills: int + long_partial_prefills: int + + def from_running_queue(running): + # ... @property def next_cache_id(self): @@ -1162,6 +1246,12 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: prefills = SchedulerPrefillOutputs.create_empty() swapped_in = SchedulerSwappedInOutputs.create_empty() + # Create partial prefill metadata + partial_prefill_metadata = PartialPrefillMetadata.from_running_queue( + self.running, self.long_prefill_threshold + ) + + # Before any scheduling, look at the requests in the waiting queue. # We may decide to budget fewer tokens for running prefills if there are # requests in the queue we want to prefill concurrently diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 584ee0d9e1c54..7e0880689f007 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -34,9 +34,9 @@ def log_inputs( if prompt_token_ids is not None: prompt_token_ids = prompt_token_ids[:max_log_len] - logger.info( - "Received request %s: prompt: %r, " - "params: %s, prompt_token_ids: %s, " - "lora_request: %s, prompt_adapter_request: %s.", request_id, - prompt, params, prompt_token_ids, lora_request, - prompt_adapter_request) + # logger.info( + # "Received request %s: prompt: %r, " + # "params: %s, prompt_token_ids: %s, " + # "lora_request: %s, prompt_adapter_request: %s.", request_id, + # prompt, params, prompt_token_ids, lora_request, + # prompt_adapter_request) From dbb9ae8c29b6d660ef2c2b23c67f581459a65a76 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 14:12:41 -0800 Subject: [PATCH 30/54] =?UTF-8?q?=F0=9F=8E=A8=20wrap=20up=20PartialPrefill?= =?UTF-8?q?Metadata?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 160 ++++++++++++++++++++++------------------- 1 file changed, 87 insertions(+), 73 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 47b1cef81f478..61ffa28fb886b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -14,8 +14,8 @@ from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStage, - SequenceStatus) + SequenceGroupMetadata, SequenceGroupMetadataDelta, + SequenceStage, SequenceStatus) from vllm.utils import Device, PyObjectCache logger = init_logger(__name__) @@ -294,19 +294,23 @@ def scheduled_seq_group_builder(): token_chunk_size=0) # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + # @dataclass # class PartialPrefillConfig: # max_num_partial_prefills: int # max_long_partial_prefills: int # long_prefill_threshold: int # # Default this list to empty -# partial_prefill_budget_lookup_list: List[int] = field(default_factory=list) +# partial_prefill_budget_lookup_list: List[int] = \ +# field(default_factory=list) # def __post_init__(self): # # Initialize partial_prefill_budget_lookup_list here # # List with the chunk sizes to hand out to each sequence depending -# # on how many partial prefills are running. This is slightly faster than -# # running an integer division every time a prefill is scheduled. +# # on how many partial prefills are running. +# # This is slightly faster than +# # running an integer division every time a prefill is +# # scheduled. # # This splits the budget evenly among all prefills. # self.partial_prefill_budget_lookup_list = [0] * ( # self.max_num_partial_prefills + 1) @@ -319,16 +323,15 @@ def scheduled_seq_group_builder(): @dataclass class PartialPrefillMetadata: - """Holds information about the partial prefills that are currently running.""" + """Holds information about the partial prefills that are + currently running.""" partial_prefills: int long_partial_prefills: int - waiting_partial_prefills: int - waiting_long_partial_prefills: int - - def from_queues(running: Deque[SequenceGroup], - waiting: Deque[SequenceGroup], - long_prefill_threshold: int, + # def can_schedule(): + @classmethod + def from_queues(cls, running: Deque[SequenceGroup], + waiting: Deque[SequenceGroup], long_prefill_threshold: int, max_partial_prefills: int, max_long_prefills: int) -> "PartialPrefillMetadata": """Create a PartialPrefillMetadata object from the running queue.""" @@ -346,21 +349,23 @@ def from_queues(running: Deque[SequenceGroup], long_partial_prefills += 1 for sg in waiting: - # Don't bother looping through the rest of the queue if we know there are already at least max_partial_prefills requests to fill - if partial_prefills + waiting_partial_prefills >= max_partial_prefills: + # Don't bother looping through the rest of the queue + # if we know there are already at + # least max_partial_prefills requests to fill + if partial_prefills + waiting_partial_prefills \ + >= max_partial_prefills: break # Disallow multiple long requests if sg.first_seq.get_num_new_tokens() > long_prefill_threshold: - if long_partial_prefills + waiting_long_prefills >= max_long_prefills: + if long_partial_prefills + waiting_long_prefills \ + >= max_long_prefills: continue waiting_long_prefills += 1 waiting_partial_prefills += 1 - return PartialPrefillMetadata(partial_prefills, - long_partial_prefills, - waiting_partial_prefills, - waiting_long_prefills) + return PartialPrefillMetadata( + partial_prefills + waiting_partial_prefills, long_partial_prefills) class Scheduler: @@ -472,13 +477,13 @@ def __init__( scheduler_config.max_num_partial_prefills self.prefill_slots_running = 0 self.long_prefill_requests = 0 - # Requests with more than (4% max context length) tokens to prefill - # are "long". # The number of long prefill requests is limited so that smaller # requests may jump the queue in front of them and get to the decode # phase faster. - self.long_prefill_threshold = scheduler_config.max_model_len // 25 - self.max_long_requests = 1 # TODO: something + self.long_prefill_threshold = int( + scheduler_config.max_model_len * + scheduler_config.long_prefill_threshold) + self.max_long_requests = scheduler_config.max_long_partial_prefills # List with the chunk sizes to hand out to each sequence depending # on how many partial prefills are running. This is slightly faster than @@ -491,22 +496,21 @@ def __init__( for i in range(1, self.max_num_partial_prefills + 1): self.partial_prefill_budget_lookup_list[i] = \ scheduler_config.max_num_batched_tokens // i - - @dataclass - class PartialPrefillConfig: - max_num_partial_prefills: int - max_long_partial_prefills: int - long_prefill_threshold: int - partial_prefill_budget_lookup_list: list + # @dataclass + # class PartialPrefillConfig: + # max_num_partial_prefills: int + # max_long_partial_prefills: int + # long_prefill_threshold: int + # partial_prefill_budget_lookup_list: list - @dataclass - class PartialPrefillMetadata: - partial_prefills: int - long_partial_prefills: int + # @dataclass + # class PartialPrefillMetadata: + # partial_prefills: int + # long_partial_prefills: int - def from_running_queue(running): - # ... + # def from_running_queue(running): + # # ... @property def next_cache_id(self): @@ -958,6 +962,7 @@ def _schedule_prefills( budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> SchedulerPrefillOutputs: """Schedule sequence groups that are in prefill stage. @@ -978,6 +983,8 @@ def _schedule_prefills( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. + partial_prefill_metadata: information about the partial prefills + that are currently running Returns: SchedulerPrefillOutputs. @@ -1004,10 +1011,12 @@ def _schedule_prefills( "Waiting sequence group should have only one prompt " "sequence.") - is_long = self._is_long_seq_group(seq_group) - if is_long \ - and self.long_prefill_requests >= self.max_long_requests \ - and self.max_num_partial_prefills > 1: + if (partial_prefill_metadata is not None + and seq_group.first_seq.get_num_new_tokens() > + self.long_prefill_threshold + and partial_prefill_metadata.long_partial_prefills >= + self.max_long_requests + and self.max_num_partial_prefills > 1): # When concurrent partial prefills are enabled, # we limit the number of long requests and only accept # shorter requests from the queue while running them @@ -1016,9 +1025,12 @@ def _schedule_prefills( waiting_queue.popleft() continue - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - enable_chunking, budget) + num_new_tokens = self._get_num_new_tokens( + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + partial_prefill_metadata=partial_prefill_metadata) num_new_seqs = seq_group.get_max_num_running_seqs() # quick budget check @@ -1082,8 +1094,10 @@ def _schedule_prefills( waiting_queue.popleft() self._allocate_and_set_running(seq_group) - if is_long: - self.long_prefill_requests += 1 + if partial_prefill_metadata is not None and \ + seq_group.first_seq.get_num_new_tokens( + ) > self.long_prefill_threshold: + partial_prefill_metadata.long_partial_prefills += 1 if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] @@ -1247,17 +1261,19 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: swapped_in = SchedulerSwappedInOutputs.create_empty() # Create partial prefill metadata - partial_prefill_metadata = PartialPrefillMetadata.from_running_queue( - self.running, self.long_prefill_threshold - ) - + partial_prefill_metadata = PartialPrefillMetadata.from_queues( + running=self.running, + waiting=self.waiting, + long_prefill_threshold=self.long_prefill_threshold, + max_partial_prefills=self.max_num_partial_prefills, + max_long_prefills=self.max_long_requests) # Before any scheduling, look at the requests in the waiting queue. # We may decide to budget fewer tokens for running prefills if there are # requests in the queue we want to prefill concurrently - if self.prefill_slots_running < self.max_num_partial_prefills and len( - self.waiting) > 0: - self._count_prefills_in_waiting_queue() + # if self.prefill_slots_running < self.max_num_partial_prefills and len( + # self.waiting) > 0: + # self._count_prefills_in_waiting_queue() # Decoding should be always scheduled first by fcfs. running_scheduled = self._schedule_running(budget, @@ -1271,22 +1287,12 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: swapped_in = self._schedule_swapped(budget, curr_loras) # Schedule new prefills. - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=True) - - prefilling = running_scheduled.prefill_seq_groups + prefills.seq_groups - - prefilling = [ - p for p in prefilling if self._will_still_be_prefilling(p) - ] - - # Set slot counts for next iteration - self.prefill_slots_running = len(prefilling) - self.long_prefill_requests = len([ - seq_group for seq_group in prefilling - if self._is_long_seq_group(seq_group.seq_group) - ]) + prefills = self._schedule_prefills( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) @@ -1783,9 +1789,14 @@ def _get_num_lookahead_slots(self, is_prefill: bool, return self.scheduler_config.num_lookahead_slots - def _get_num_new_tokens(self, seq_group: SequenceGroup, - status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> int: + def _get_num_new_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> int: """Get the next new tokens to compute for a given sequence group that's in a given `status`. @@ -1827,8 +1838,11 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, elif enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() # Get the number of tokens to allocate to this prefill slot - prefill_slot_budget = \ - self.partial_prefill_budget_lookup_list[self.prefill_slots_running] + prefill_slot_budget = remaining_token_budget \ + if partial_prefill_metadata is None \ + else self.partial_prefill_budget_lookup_list[ + partial_prefill_metadata.partial_prefills + ] if self.cache_config.enable_prefix_caching: # When prefix caching is enabled and we're partially prefilling From 4bac8ed1b85c0641324a7a7f28137ae68e47363c Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 14:33:14 -0800 Subject: [PATCH 31/54] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20add=20some=20utility?= =?UTF-8?q?=20functions=20within=20partial=5Fprefill=5Fmetadata?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 142 +++++++++++++++-------------------------- 1 file changed, 50 insertions(+), 92 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 61ffa28fb886b..4a5c4b9df769c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -323,17 +323,41 @@ def scheduled_seq_group_builder(): @dataclass class PartialPrefillMetadata: - """Holds information about the partial prefills that are - currently running.""" + """Holds information about the partial prefills that are + currently running. For chunked prefill, we allow a certain number of seqs + to be partially prefilled. Having multiple partial prefills in flight + allows us to minimize TTFT and avoid decode starvation in cases + where a single sequence group with a very large prompt blocks + the queue for too many iterations. + + The number of long prefill requests is limited so that smaller + requests may jump the queue in front of them and get to the decode + phase faster. + """ partial_prefills: int long_partial_prefills: int + scheduler_config: SchedulerConfig + + def cannot_schedule(self, seq_group: SequenceGroup): + """When concurrent partial prefills are enabled, + we limit the number of long requests and only accept + shorter requests from the queue while running them + concurrently""" + return seq_group.first_seq.get_num_new_tokens() > \ + self.scheduler_config.long_prefill_token_threshold \ + and self.long_partial_prefills >= \ + self.scheduler_config.max_long_partial_prefills \ + and self.scheduler_config.max_num_partial_prefills > 1 + + def increment_partial_prefills(self, seq_group: SequenceGroup): + if (seq_group.first_seq.get_num_new_tokens() > + self.scheduler_config.long_prefill_token_threshold): + self.long_partial_prefills += 1 - # def can_schedule(): @classmethod - def from_queues(cls, running: Deque[SequenceGroup], - waiting: Deque[SequenceGroup], long_prefill_threshold: int, - max_partial_prefills: int, - max_long_prefills: int) -> "PartialPrefillMetadata": + def from_queues( + cls, running: Deque[SequenceGroup], waiting: Deque[SequenceGroup], + scheduler_config: SchedulerConfig) -> "PartialPrefillMetadata": """Create a PartialPrefillMetadata object from the running queue.""" partial_prefills = 0 long_partial_prefills = 0 @@ -345,7 +369,8 @@ def from_queues(cls, running: Deque[SequenceGroup], # TODO: Check if this stage is correctly updated before scheduling if sg.first_seq.data.stage == SequenceStage.PREFILL: partial_prefills += 1 - if sg.first_seq.get_num_new_tokens() > long_prefill_threshold: + if sg.first_seq.get_num_new_tokens( + ) > scheduler_config.long_prefill_token_threshold: long_partial_prefills += 1 for sg in waiting: @@ -353,19 +378,22 @@ def from_queues(cls, running: Deque[SequenceGroup], # if we know there are already at # least max_partial_prefills requests to fill if partial_prefills + waiting_partial_prefills \ - >= max_partial_prefills: + >= scheduler_config.max_num_partial_prefills: break # Disallow multiple long requests - if sg.first_seq.get_num_new_tokens() > long_prefill_threshold: + if sg.first_seq.get_num_new_tokens( + ) > scheduler_config.long_prefill_token_threshold: if long_partial_prefills + waiting_long_prefills \ - >= max_long_prefills: + >= scheduler_config.max_long_partial_prefills: continue waiting_long_prefills += 1 waiting_partial_prefills += 1 - return PartialPrefillMetadata( - partial_prefills + waiting_partial_prefills, long_partial_prefills) + return PartialPrefillMetadata(partial_prefills + + waiting_partial_prefills, + long_partial_prefills, + scheduler_config=scheduler_config) class Scheduler: @@ -468,32 +496,15 @@ def __init__( # for processing and deallocation by the free_finished_seq_groups() self._async_stopped: List[SequenceGroup] = [] - # For chunked prefill, we allow a certain number of seqs - # to be partially prefilled. - # Having multiple partial prefills in flight allows us to minimize TTFT - # and avoid decode starvation in cases where a single sequence group - # with a very large prompt blocks the queue for too many iterations. - self.max_num_partial_prefills = \ - scheduler_config.max_num_partial_prefills - self.prefill_slots_running = 0 - self.long_prefill_requests = 0 - # The number of long prefill requests is limited so that smaller - # requests may jump the queue in front of them and get to the decode - # phase faster. - self.long_prefill_threshold = int( - scheduler_config.max_model_len * - scheduler_config.long_prefill_threshold) - self.max_long_requests = scheduler_config.max_long_partial_prefills - # List with the chunk sizes to hand out to each sequence depending # on how many partial prefills are running. This is slightly faster than # running an integer division every time a prefill is scheduled. # This splits the budget evenly among all prefills. self.partial_prefill_budget_lookup_list = [0] * ( - self.max_num_partial_prefills + 1) + self.scheduler_config.max_num_partial_prefills + 1) self.partial_prefill_budget_lookup_list[ 0] = scheduler_config.max_num_batched_tokens - for i in range(1, self.max_num_partial_prefills + 1): + for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): self.partial_prefill_budget_lookup_list[i] = \ scheduler_config.max_num_batched_tokens // i @@ -1006,21 +1017,14 @@ def _schedule_prefills( while self._passed_delay(time.time()) and waiting_queue: seq_group = waiting_queue[0] - waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + waiting_seqs = \ + seq_group.get_seqs(status=SequenceStatus.WAITING) assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - if (partial_prefill_metadata is not None - and seq_group.first_seq.get_num_new_tokens() > - self.long_prefill_threshold - and partial_prefill_metadata.long_partial_prefills >= - self.max_long_requests - and self.max_num_partial_prefills > 1): - # When concurrent partial prefills are enabled, - # we limit the number of long requests and only accept - # shorter requests from the queue while running them - # concurrently + if partial_prefill_metadata is not None and \ + partial_prefill_metadata.cannot_schedule(seq_group): leftover_waiting_sequences.appendleft(seq_group) waiting_queue.popleft() continue @@ -1094,10 +1098,8 @@ def _schedule_prefills( waiting_queue.popleft() self._allocate_and_set_running(seq_group) - if partial_prefill_metadata is not None and \ - seq_group.first_seq.get_num_new_tokens( - ) > self.long_prefill_threshold: - partial_prefill_metadata.long_partial_prefills += 1 + if partial_prefill_metadata is not None: + partial_prefill_metadata.increment_partial_prefills(seq_group) if enable_chunking and self.scheduler_config.is_multi_step: blocks_to_copy: List[Tuple[int, int]] = [] @@ -1264,16 +1266,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: partial_prefill_metadata = PartialPrefillMetadata.from_queues( running=self.running, waiting=self.waiting, - long_prefill_threshold=self.long_prefill_threshold, - max_partial_prefills=self.max_num_partial_prefills, - max_long_prefills=self.max_long_requests) - - # Before any scheduling, look at the requests in the waiting queue. - # We may decide to budget fewer tokens for running prefills if there are - # requests in the queue we want to prefill concurrently - # if self.prefill_slots_running < self.max_num_partial_prefills and len( - # self.waiting) > 0: - # self._count_prefills_in_waiting_queue() + scheduler_config=self.scheduler_config) # Decoding should be always scheduled first by fcfs. running_scheduled = self._schedule_running(budget, @@ -1340,41 +1333,6 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: len(running_scheduled.swapped_out)), ) - def _is_long_seq_group(self, seq_group: SequenceGroup) -> bool: - """Simple heuristic to check if a sequence group needs a lot of prefill - work.""" - return seq_group.seqs[0].get_num_new_tokens( - ) >= self.long_prefill_threshold - - def _will_still_be_prefilling(self, - seq_group: ScheduledSequenceGroup) -> bool: - """Check if a sequence will be mid-prefill after this iteration. - We need to know how many partial prefills will be running in order to - properly budget the next iteration.""" - return seq_group.token_chunk_size != seq_group.seq_group.seqs[ - 0].get_num_new_tokens() - - def _count_prefills_in_waiting_queue(self): - """Peek into the waiting queue to see how many requests we may be able - to start prefilling during this scheduling iteration. This allows us to - budget fewer tokens for currently running prefills if we know that more - requests from the queue will fit. - """ - queued_long_requests = 0 - for seq_group in self.waiting: - # Don't fill more slots than we have - if self.prefill_slots_running >= self.max_num_partial_prefills: - break - - # Disallow multiple long requests - if self._is_long_seq_group(seq_group): - if self.long_prefill_requests + queued_long_requests \ - >= self.max_long_requests: - continue - queued_long_requests += 1 - - self.prefill_slots_running += 1 - def _schedule(self) -> SchedulerOutputs: """Schedule queued requests.""" if self.scheduler_config.chunked_prefill_enabled: From c44ca1fbfa795a20bbc81925daa311ee57b67363 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 14:33:32 -0800 Subject: [PATCH 32/54] =?UTF-8?q?=F0=9F=8E=A8=20change=20to=20long=5Fprefi?= =?UTF-8?q?ll=5Ftoken=5Fthreshold?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index b0f6485d2c641..ae60f0e7d4830 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1137,7 +1137,8 @@ def __init__(self, self.max_num_batched_tokens = max_num_batched_tokens self.max_num_partial_prefills = max_num_partial_prefills self.max_long_partial_prefills = max_long_partial_prefills - self.long_prefill_threshold = long_prefill_threshold + self.long_prefill_token_threshold = int(max_model_len * + long_prefill_threshold) if enable_chunked_prefill: logger.info( From 38bad7aad4d4ca0df6485f7bdec7344abc544c63 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 14:56:44 -0800 Subject: [PATCH 33/54] =?UTF-8?q?=F0=9F=94=A5=20remove=20commented=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 43 ++---------------------------------------- 1 file changed, 2 insertions(+), 41 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4a5c4b9df769c..caf0a2db9b746 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -295,32 +295,6 @@ def scheduled_seq_group_builder(): # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) -# @dataclass -# class PartialPrefillConfig: -# max_num_partial_prefills: int -# max_long_partial_prefills: int -# long_prefill_threshold: int -# # Default this list to empty -# partial_prefill_budget_lookup_list: List[int] = \ -# field(default_factory=list) - -# def __post_init__(self): -# # Initialize partial_prefill_budget_lookup_list here -# # List with the chunk sizes to hand out to each sequence depending -# # on how many partial prefills are running. -# # This is slightly faster than -# # running an integer division every time a prefill is -# # scheduled. -# # This splits the budget evenly among all prefills. -# self.partial_prefill_budget_lookup_list = [0] * ( -# self.max_num_partial_prefills + 1) -# self.partial_prefill_budget_lookup_list[ -# 0] = scheduler_config.max_num_batched_tokens -# for i in range(1, self.max_num_partial_prefills + 1): -# self.partial_prefill_budget_lookup_list[i] = \ -# scheduler_config.max_num_batched_tokens // i - - @dataclass class PartialPrefillMetadata: """Holds information about the partial prefills that are @@ -508,21 +482,6 @@ def __init__( self.partial_prefill_budget_lookup_list[i] = \ scheduler_config.max_num_batched_tokens // i - # @dataclass - # class PartialPrefillConfig: - # max_num_partial_prefills: int - # max_long_partial_prefills: int - # long_prefill_threshold: int - # partial_prefill_budget_lookup_list: list - - # @dataclass - # class PartialPrefillMetadata: - # partial_prefills: int - # long_partial_prefills: int - - # def from_running_queue(running): - # # ... - @property def next_cache_id(self): return (self.cache_id + 1) % self.num_cache_iters @@ -1268,6 +1227,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: waiting=self.waiting, scheduler_config=self.scheduler_config) + print("partial_prefill_metadata : ", partial_prefill_metadata) + # Decoding should be always scheduled first by fcfs. running_scheduled = self._schedule_running(budget, curr_loras, From 0f3efa15d00cc26a6cc0d0be9caa47f5807bb544 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 13 Nov 2024 15:14:10 -0800 Subject: [PATCH 34/54] =?UTF-8?q?=F0=9F=90=9B=20fix=20the=20big=20bug!=20(?= =?UTF-8?q?Thanks=20Joe)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index caf0a2db9b746..4339a0aa27e71 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -343,9 +343,9 @@ def from_queues( # TODO: Check if this stage is correctly updated before scheduling if sg.first_seq.data.stage == SequenceStage.PREFILL: partial_prefills += 1 - if sg.first_seq.get_num_new_tokens( - ) > scheduler_config.long_prefill_token_threshold: - long_partial_prefills += 1 + if sg.first_seq.get_num_new_tokens( + ) > scheduler_config.long_prefill_token_threshold: + long_partial_prefills += 1 for sg in waiting: # Don't bother looping through the rest of the queue @@ -581,6 +581,7 @@ def _schedule_running( budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, ) -> SchedulerRunningOutputs: """Schedule sequence groups that are running. @@ -595,7 +596,9 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + partial_prefill_metadata: information about the partial prefills + that are currently running + Returns: SchedulerRunningOutputs. """ @@ -629,7 +632,12 @@ def _schedule_running( while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget) + seq_group, + SequenceStatus.RUNNING, + enable_chunking, + budget, + partial_prefill_metadata, + ) if num_running_tokens == 0: # No budget => Stop @@ -1227,12 +1235,13 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: waiting=self.waiting, scheduler_config=self.scheduler_config) - print("partial_prefill_metadata : ", partial_prefill_metadata) - # Decoding should be always scheduled first by fcfs. - running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=True) + running_scheduled = self._schedule_running( + budget, + curr_loras, + enable_chunking=True, + partial_prefill_metadata=partial_prefill_metadata, + ) # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. From 3daf35fed5a599b56fb8a5a325f7c15b42bf68b9 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 13 Nov 2024 16:17:26 -0700 Subject: [PATCH 35/54] :memo: docstings galore Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4339a0aa27e71..a664c3c747d60 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -297,19 +297,27 @@ def scheduled_seq_group_builder(): @dataclass class PartialPrefillMetadata: - """Holds information about the partial prefills that are - currently running. For chunked prefill, we allow a certain number of seqs - to be partially prefilled. Having multiple partial prefills in flight - allows us to minimize TTFT and avoid decode starvation in cases - where a single sequence group with a very large prompt blocks - the queue for too many iterations. + """Holds information about the partial prefills that are currently running + during a single iteration of the Scheduler. + + When chunked prefill is enabled, we allow a certain number of seqs to be + partially prefilled during each iteration. Having multiple partial prefills + in flight allows us to minimize TTFT and avoid decode starvation in cases + where a single sequence group with a very large prompt blocks the queue for + too many iterations. The number of long prefill requests is limited so that smaller requests may jump the queue in front of them and get to the decode phase faster. """ + + # A minimum bound on the total number of prefills running during this + # scheduling step partial_prefills: int + + # The number of long prefill requests currently running long_partial_prefills: int + scheduler_config: SchedulerConfig def cannot_schedule(self, seq_group: SequenceGroup): @@ -324,6 +332,7 @@ def cannot_schedule(self, seq_group: SequenceGroup): and self.scheduler_config.max_num_partial_prefills > 1 def increment_partial_prefills(self, seq_group: SequenceGroup): + # When a new prefill is scheduled, we need to know if it is a if (seq_group.first_seq.get_num_new_tokens() > self.scheduler_config.long_prefill_token_threshold): self.long_partial_prefills += 1 @@ -332,7 +341,12 @@ def increment_partial_prefills(self, seq_group: SequenceGroup): def from_queues( cls, running: Deque[SequenceGroup], waiting: Deque[SequenceGroup], scheduler_config: SchedulerConfig) -> "PartialPrefillMetadata": - """Create a PartialPrefillMetadata object from the running queue.""" + """Create a PartialPrefillMetadata object from the current state of + the scheduler's queues. + + This accounts for the currently running prefill requests, and peeks into + the waiting queue to see if there are more prefills to potentially be + scheduled during this iteration.""" partial_prefills = 0 long_partial_prefills = 0 @@ -348,14 +362,15 @@ def from_queues( long_partial_prefills += 1 for sg in waiting: - # Don't bother looping through the rest of the queue - # if we know there are already at + # Don't bother looping through the rest of the queue if we know + # there are already at # least max_partial_prefills requests to fill if partial_prefills + waiting_partial_prefills \ >= scheduler_config.max_num_partial_prefills: break - # Disallow multiple long requests + # Don't count long requests from the waiting queue if we aren't + # going to schedule them anyway if sg.first_seq.get_num_new_tokens( ) > scheduler_config.long_prefill_token_threshold: if long_partial_prefills + waiting_long_prefills \ From 241853a7dc3f9f7da08b4d4f8ab71cda62a35a4b Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 14 Nov 2024 11:58:42 -0800 Subject: [PATCH 36/54] =?UTF-8?q?=F0=9F=8E=A8=20fix=20typo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a664c3c747d60..a6598118badf6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -320,7 +320,7 @@ class PartialPrefillMetadata: scheduler_config: SchedulerConfig - def cannot_schedule(self, seq_group: SequenceGroup): + def cannot_schedule(self, seq_group: SequenceGroup) -> bool: """When concurrent partial prefills are enabled, we limit the number of long requests and only accept shorter requests from the queue while running them @@ -331,8 +331,9 @@ def cannot_schedule(self, seq_group: SequenceGroup): self.scheduler_config.max_long_partial_prefills \ and self.scheduler_config.max_num_partial_prefills > 1 - def increment_partial_prefills(self, seq_group: SequenceGroup): + def increment_partial_prefills(self, seq_group: SequenceGroup) -> None: # When a new prefill is scheduled, we need to know if it is a + # long request if (seq_group.first_seq.get_num_new_tokens() > self.scheduler_config.long_prefill_token_threshold): self.long_partial_prefills += 1 From 07b6d728727e75540ccd9a16f9918149fe7594fb Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 14 Nov 2024 12:04:50 -0800 Subject: [PATCH 37/54] =?UTF-8?q?=E2=8F=AA=20revert=20logging=20change?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/entrypoints/logger.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 7e0880689f007..584ee0d9e1c54 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -34,9 +34,9 @@ def log_inputs( if prompt_token_ids is not None: prompt_token_ids = prompt_token_ids[:max_log_len] - # logger.info( - # "Received request %s: prompt: %r, " - # "params: %s, prompt_token_ids: %s, " - # "lora_request: %s, prompt_adapter_request: %s.", request_id, - # prompt, params, prompt_token_ids, lora_request, - # prompt_adapter_request) + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s, prompt_adapter_request: %s.", request_id, + prompt, params, prompt_token_ids, lora_request, + prompt_adapter_request) From c4bdf370204c8f80c6d12c5f76ad1dc210fc2436 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 14 Nov 2024 15:40:01 -0800 Subject: [PATCH 38/54] =?UTF-8?q?=E2=9C=85=20remove=20value=20error=20from?= =?UTF-8?q?=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit not sure if that's needed? Signed-off-by: Prashant Gupta --- tests/basic_correctness/test_chunked_prefill.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 469d18a4dd7af..a9d10d5a41709 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -250,12 +250,7 @@ def test_with_prefix_caching( check_result &= not should_fail outputs[enable] = [] # Send the request one-by-one to ensure the cache is populated. - with pytest.raises(ValueError) if should_fail else nullcontext(): - for prompt in full_prompts: - outputs[enable] += vllm_model.generate_greedy([prompt], - max_tokens) - # Check results only if we did not expect a failure. if check_result: check_outputs_equal( outputs_0_lst=outputs[False], From 7c8b400da8e2771ca0beee9cc32db88e2443301f Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 14 Nov 2024 15:47:56 -0800 Subject: [PATCH 39/54] =?UTF-8?q?=E2=9C=85=20remove=20value=20error=20from?= =?UTF-8?q?=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit not sure if that's needed? Signed-off-by: Prashant Gupta --- tests/basic_correctness/test_chunked_prefill.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index a9d10d5a41709..f445689aa4086 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -7,7 +7,6 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ import os -from contextlib import nullcontext import pytest @@ -249,6 +248,9 @@ def test_with_prefix_caching( should_fail = chunk_size % 16 != 0 and enable check_result &= not should_fail outputs[enable] = [] + for prompt in full_prompts: + outputs[enable] += vllm_model.generate_greedy([prompt], + max_tokens) # Send the request one-by-one to ensure the cache is populated. if check_result: From 21796fc22770e6bca5c578369015f55269ba63e1 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 14 Nov 2024 15:48:59 -0800 Subject: [PATCH 40/54] =?UTF-8?q?=F0=9F=8E=A8=20fix=20typo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/basic_correctness/test_chunked_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index f445689aa4086..964c68f2c3dfe 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -248,10 +248,10 @@ def test_with_prefix_caching( should_fail = chunk_size % 16 != 0 and enable check_result &= not should_fail outputs[enable] = [] + # Send the request one-by-one to ensure the cache is populated. for prompt in full_prompts: outputs[enable] += vllm_model.generate_greedy([prompt], max_tokens) - # Send the request one-by-one to ensure the cache is populated. if check_result: check_outputs_equal( From d9938611e84dbe39668a8185b5c6b383575ad841 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 15 Nov 2024 12:42:23 -0800 Subject: [PATCH 41/54] =?UTF-8?q?=E2=9C=85=20make=20test=20comprehensive?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 90 +++++++++++++++++--- 1 file changed, 79 insertions(+), 11 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index b14380e69ccaa..1906049bbc199 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -17,7 +17,7 @@ def get_sequence_groups(scheduler_output): return [s.seq_group for s in scheduler_output.scheduled_seq_groups] -def append_new_token(seq_group, token_id: int): +def append_new_token(seq_group: SequenceGroup, token_id: int): for seq in seq_group.get_seqs(): seq.append_token_id(token_id, {token_id: Logprob(token_id)}) @@ -224,6 +224,7 @@ def test_short_prompts_jump_long_prompts_in_queue(): cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests cache_config.num_gpu_blocks = 3200 scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] # Add 2 large seq groups to scheduler. for i in range(2): @@ -232,30 +233,97 @@ def test_short_prompts_jump_long_prompts_in_queue(): prompt_length=1200, # Very large prompt block_size=block_size) scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() # Add 2 small seq groups behind them for i in range(2): _, seq_group = create_dummy_prompt( str(i + 2), - prompt_length=12, # Very small prompt + prompt_length=40, # Very small prompt block_size=block_size) scheduler.add_seq_group(seq_group) + running.append(seq_group) + assert seq_group.is_prefill() - # Verify one large req and two small reqs chunked + # Verify one large req and 1 small req chunked seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # assert len(get_sequence_groups(out)) == 3 assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens - # both small reqs fit in remaining 32 tokens - assert seq_group_meta[1].token_chunk_size == 12 - assert seq_group_meta[2].token_chunk_size == 12 + assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens + + assert running[0].is_prefill() + assert running[1].is_prefill() + assert running[2].is_prefill() + assert running[3].is_prefill() + + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 64 + + # in the second iteration, + # the first small request had only 8 tokens left + # so it went to decode + # The other small req is scheduled + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + # the new small req got 64 - (32+8) tokens + assert (seq_group_meta[0].token_chunk_size == 24) + assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32 + # the other small request had only 8 tokens left + assert seq_group_meta[2].token_chunk_size == 8 # 40-32 + + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert running[3].is_prefill() + assert out.num_prefill_groups == 3 - assert out.num_batched_tokens == 56 + assert out.num_batched_tokens == 64 + # the small seq group has a new token appended. + append_new_token(running[2], 1) + + # in the third iteration, + # the first small request has entered decode + # and other small req had 16 tokens left + # so it went to decode + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 32 # large still got 32 + # small req prefilled 40-24=16 tokens + assert (seq_group_meta[1].token_chunk_size == 16) + assert seq_group_meta[2].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 49 # (32+16+1 decode) + + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert not running[3].is_prefill() - # in the second iteration, both small requests are completed + # the small seq group has a new token appended. + append_new_token(running[2], 1) + + # in the fourth iteration, both small requests are decoding # so large request gets all the budget seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert seq_group_meta[ - 0].token_chunk_size == 64 # large req gets all tokens now + # large req gets 63 tokens (minus 1 for decode) + assert seq_group_meta[0].token_chunk_size == 63 + assert seq_group_meta[1].token_chunk_size == 1 # decode + assert out.num_prefill_groups == 1 + assert out.num_batched_tokens == 64 + + assert running[0].is_prefill() + assert running[1].is_prefill() + assert not running[2].is_prefill() + assert not running[3].is_prefill() + + # both the small seq groups have a new token appended + append_new_token(running[2], 1) + append_new_token(running[3], 1) + + # in the fifth iteration, large request gets all the budget + # while both small requests are decoding + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert seq_group_meta[0].token_chunk_size == 62 + assert seq_group_meta[1].token_chunk_size == 1 # decode + assert seq_group_meta[2].token_chunk_size == 1 # decode assert out.num_prefill_groups == 1 assert out.num_batched_tokens == 64 From 946d297b746846f44c2e7fa31af263a02d37412d Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 15 Nov 2024 13:28:43 -0800 Subject: [PATCH 42/54] =?UTF-8?q?=F0=9F=8E=A8=20fix=20unused=20vars=20in?= =?UTF-8?q?=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- .../basic_correctness/test_chunked_prefill.py | 74 ++----------------- 1 file changed, 5 insertions(+), 69 deletions(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 964c68f2c3dfe..ef557603568a1 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -231,7 +231,6 @@ def test_with_prefix_caching( max_num_batched_tokens = max_num_seqs = chunk_size outputs = {} # type: ignore - check_result = True for enable in (True, False): with vllm_runner( model, @@ -243,78 +242,15 @@ def test_with_prefix_caching( enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, ) as vllm_model: - # It should fail when prefix caching is enable and chunk - # size is not a multiple of block size (16). - should_fail = chunk_size % 16 != 0 and enable - check_result &= not should_fail outputs[enable] = [] # Send the request one-by-one to ensure the cache is populated. for prompt in full_prompts: outputs[enable] += vllm_model.generate_greedy([prompt], max_tokens) - if check_result: - check_outputs_equal( - outputs_0_lst=outputs[False], - outputs_1_lst=outputs[True], - name_0="w/o prefix caching", - name_1="with prefix caching", - ) - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -@pytest.mark.parametrize("enforce_eager", [False]) -@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"]) -@pytest.mark.cpu_model -@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") -def test_models_cpu( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, - enforce_eager: bool, - attention_backend: str, - monkeypatch, -) -> None: - test_models( - hf_runner, - vllm_runner, - example_prompts, - model, - dtype, - max_tokens, - chunked_prefill_token_size, - enforce_eager, - 1, - attention_backend, - monkeypatch, - ) - - -@pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("enforce_eager", [False]) -@pytest.mark.parametrize("chunk_size", [30, 32]) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.cpu_model -@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") -def test_with_prefix_caching_cpu( - vllm_runner, - max_tokens: int, - enforce_eager: bool, - chunk_size: int, - dtype: str, -) -> None: - test_with_prefix_caching( - vllm_runner, - max_tokens, - enforce_eager, - chunk_size, - 1, - dtype, + check_outputs_equal( + outputs_0_lst=outputs[False], + outputs_1_lst=outputs[True], + name_0="w/o prefix caching", + name_1="with prefix caching", ) From 5535515a197e32c70dbac72fb84013c69eebfe2a Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Mon, 18 Nov 2024 10:12:14 -0800 Subject: [PATCH 43/54] =?UTF-8?q?=F0=9F=8E=A8=20some=20more=20comments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/core/test_chunked_prefill_scheduler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 1906049bbc199..34ced55828bd7 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -251,6 +251,7 @@ def test_short_prompts_jump_long_prompts_in_queue(): assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens + # all 4 are prefilling assert running[0].is_prefill() assert running[1].is_prefill() assert running[2].is_prefill() @@ -270,6 +271,8 @@ def test_short_prompts_jump_long_prompts_in_queue(): # the other small request had only 8 tokens left assert seq_group_meta[2].token_chunk_size == 8 # 40-32 + # notice the small request got to decode now + # this is because of max_num_partial_prefills logic assert running[0].is_prefill() assert running[1].is_prefill() assert not running[2].is_prefill() @@ -292,6 +295,7 @@ def test_short_prompts_jump_long_prompts_in_queue(): assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 49 # (32+16+1 decode) + # both small requests have now reached decode assert running[0].is_prefill() assert running[1].is_prefill() assert not running[2].is_prefill() From ba91ddfbf8ff538cd091a4038360500d5ab842a1 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 20 Nov 2024 10:10:41 -0800 Subject: [PATCH 44/54] =?UTF-8?q?=F0=9F=8E=A8=20fix=20merge=20conflict?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- .../basic_correctness/test_chunked_prefill.py | 62 ++++++++++++++++++- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index ef557603568a1..5800385077255 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -7,7 +7,6 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ import os - import pytest from tests.kernels.utils import override_backend_env_variable @@ -243,10 +242,9 @@ def test_with_prefix_caching( max_num_seqs=max_num_seqs, ) as vllm_model: outputs[enable] = [] - # Send the request one-by-one to ensure the cache is populated. for prompt in full_prompts: outputs[enable] += vllm_model.generate_greedy([prompt], - max_tokens) + max_tokens) check_outputs_equal( outputs_0_lst=outputs[False], @@ -254,3 +252,61 @@ def test_with_prefix_caching( name_0="w/o prefix caching", name_1="with prefix caching", ) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"]) +@pytest.mark.cpu_model +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") +def test_models_cpu( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + attention_backend: str, + monkeypatch, +) -> None: + test_models( + hf_runner, + vllm_runner, + example_prompts, + model, + dtype, + max_tokens, + chunked_prefill_token_size, + enforce_eager, + 1, + attention_backend, + monkeypatch, + ) + + +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("chunk_size", [30, 32]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.cpu_model +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") +def test_with_prefix_caching_cpu( + vllm_runner, + max_tokens: int, + enforce_eager: bool, + chunk_size: int, + dtype: str, +) -> None: + test_with_prefix_caching( + vllm_runner, + max_tokens, + enforce_eager, + chunk_size, + 1, + dtype, + ) From bccf86fbde8344146432b38f385d21ad639e2c90 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Wed, 20 Nov 2024 10:11:30 -0800 Subject: [PATCH 45/54] =?UTF-8?q?=F0=9F=8E=A8=20fmt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- tests/basic_correctness/test_chunked_prefill.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 5800385077255..5f90c52481793 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -7,6 +7,7 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ import os + import pytest from tests.kernels.utils import override_backend_env_variable @@ -244,7 +245,7 @@ def test_with_prefix_caching( outputs[enable] = [] for prompt in full_prompts: outputs[enable] += vllm_model.generate_greedy([prompt], - max_tokens) + max_tokens) check_outputs_equal( outputs_0_lst=outputs[False], From 75848c949177fca7233a970237e05d32bb92c0dd Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 22 Nov 2024 10:34:03 -0800 Subject: [PATCH 46/54] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20merge=20with=20main?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/config.py | 362 ++++++++++++++++++++++++++----------------------- 1 file changed, 195 insertions(+), 167 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ae60f0e7d4830..a4ec21451f549 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,5 +1,6 @@ import copy import enum +import hashlib import json import warnings from dataclasses import dataclass, field, replace @@ -13,8 +14,10 @@ from transformers import PretrainedConfig import vllm.envs as envs +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + get_quantization_config) from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback @@ -370,7 +373,7 @@ def _parse_quant_hf_config(self): return quant_cfg def _verify_quantization(self) -> None: - supported_quantization = [*QUANTIZATION_METHODS] + supported_quantization = QUANTIZATION_METHODS rocm_supported_quantization = [ "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", "fbgemm_fp8" @@ -392,7 +395,8 @@ def _verify_quantization(self) -> None: quant_method = quant_cfg.get("quant_method", "").lower() # Detect which checkpoint is it - for _, method in QUANTIZATION_METHODS.items(): + for name in QUANTIZATION_METHODS: + method = get_quantization_config(name) quantization_override = method.override_quantization_method( quant_cfg, self.quantization) if quantization_override: @@ -922,56 +926,56 @@ def _verify_load_format(self) -> None: f"{rocm_supported_load_format}") +@dataclass class ParallelConfig: - """Configuration for the distributed execution. + """Configuration for the distributed execution.""" - Args: - pipeline_parallel_size: Number of pipeline parallel groups. - tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Deprecated, use distributed_executor_backend instead. - max_parallel_loading_workers: Maximum number of multiple batches - when load model sequentially. To avoid RAM OOM when using tensor - parallel and large models. - disable_custom_all_reduce: Disable the custom all-reduce kernel and - fall back to NCCL. - tokenizer_pool_config: Config for the tokenizer pool. - If None, will use synchronous tokenization. - ray_workers_use_nsight: Whether to profile Ray workers with nsight, see - https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. - placement_group: ray distributed model workers placement group. - distributed_executor_backend: Backend to use for distributed model - workers, either "ray" or "mp" (multiprocessing). If the product - of pipeline_parallel_size and tensor_parallel_size is less than - or equal to the number of GPUs available, "mp" will be used to - keep processing on a single host. Otherwise, this will default - to "ray" if Ray is installed and fail otherwise. Note that tpu - and hpu only support Ray for distributed inference. - """ + pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. + tensor_parallel_size: int = 1 # Number of tensor parallel groups. - def __init__( - self, - pipeline_parallel_size: int, - tensor_parallel_size: int, - worker_use_ray: Optional[bool] = None, - max_parallel_loading_workers: Optional[int] = None, - disable_custom_all_reduce: bool = False, - tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, - ray_workers_use_nsight: bool = False, - placement_group: Optional["PlacementGroup"] = None, - distributed_executor_backend: Optional[Union[ - str, Type["ExecutorBase"]]] = None, - ) -> None: - self.pipeline_parallel_size = pipeline_parallel_size - self.tensor_parallel_size = tensor_parallel_size - self.distributed_executor_backend = distributed_executor_backend - self.max_parallel_loading_workers = max_parallel_loading_workers - self.disable_custom_all_reduce = disable_custom_all_reduce - self.tokenizer_pool_config = tokenizer_pool_config - self.ray_workers_use_nsight = ray_workers_use_nsight - self.placement_group = placement_group - self.world_size = pipeline_parallel_size * self.tensor_parallel_size - - if worker_use_ray: + # Deprecated, use distributed_executor_backend instead. + worker_use_ray: Optional[bool] = None + + # Maximum number of multiple batches + # when load model sequentially. To avoid RAM OOM when using tensor + # parallel and large models. + max_parallel_loading_workers: Optional[int] = None + + # Disable the custom all-reduce kernel and fall back to NCCL. + disable_custom_all_reduce: bool = False + + # Config for the tokenizer pool. If None, will use synchronous tokenization. + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None + + # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + ray_workers_use_nsight: bool = False + + # ray distributed model workers placement group. + placement_group: Optional["PlacementGroup"] = None + + # Backend to use for distributed model + # workers, either "ray" or "mp" (multiprocessing). If the product + # of pipeline_parallel_size and tensor_parallel_size is less than + # or equal to the number of GPUs available, "mp" will be used to + # keep processing on a single host. Otherwise, this will default + # to "ray" if Ray is installed and fail otherwise. Note that tpu + # and hpu only support Ray for distributed inference. + distributed_executor_backend: Optional[Union[str, + Type["ExecutorBase"]]] = None + + # the full name of the worker class to use. If "auto", the worker class + # will be determined based on the platform. + worker_cls: str = "auto" + + world_size: int = field(init=False) + + rank: int = 0 + + def __post_init__(self) -> None: + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + + if self.worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" elif not self.use_ray: @@ -1022,7 +1026,6 @@ def __init__( backend) self._verify_args() - self.rank: int = 0 @property def use_ray(self) -> bool: @@ -1055,107 +1058,108 @@ def _verify_args(self) -> None: "run with Ray.") +@dataclass class SchedulerConfig: - """Scheduler configuration. + """Scheduler configuration.""" - Args: - task: The task to use the model for. - max_num_batched_tokens: Maximum number of tokens to be processed in - a single iteration. - max_num_seqs: Maximum number of sequences to be processed in a single - iteration. - max_model_len: Maximum length of a sequence (including prompt - and generated text). - num_lookahead_slots: The number of slots to allocate per sequence per - step, beyond the known token ids. This is used in speculative - decoding to store KV activations of tokens which may or may not be - accepted. - delay_factor: Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt. - enable_chunked_prefill: If True, prefill requests can be chunked based - on the remaining max_num_batched_tokens. - preemption_mode: Whether to perform preemption by swapping or - recomputation. If not specified, we determine the mode as follows: - We use recomputation by default since it incurs lower overhead than - swapping. However, when the sequence group has multiple sequences - (e.g., beam search), recomputation is not currently supported. In - such a case, we use swapping instead. - send_delta_data: Private API. If used, scheduler sends delta data to - workers instead of an entire data. It should be enabled only - when SPMD worker architecture is enabled. I.e., - VLLM_USE_RAY_SPMD_WORKER=1 - policy: The scheduling policy to use. "fcfs" (default) or "priority". - """ + task: str = "generate" # The task to use the model for. + + # Maximum number of tokens to be processed in a single iteration. + max_num_batched_tokens: int = field(default=None) # type: ignore - def __init__(self, - task: _Task, - max_num_batched_tokens: Optional[int], - max_num_seqs: int, - max_model_len: int, - max_num_partial_prefills: int = 1, - max_long_partial_prefills: int = 1, - long_prefill_threshold: float = 0.04, - num_lookahead_slots: int = 0, - delay_factor: float = 0.0, - enable_chunked_prefill: bool = False, - is_multimodal_model: bool = False, - preemption_mode: Optional[str] = None, - num_scheduler_steps: int = 1, - multi_step_stream_outputs: bool = False, - send_delta_data: bool = False, - policy: str = "fcfs") -> None: - if max_num_batched_tokens is None: - if enable_chunked_prefill: - if num_scheduler_steps > 1: + # Maximum number of sequences to be processed in a single iteration. + max_num_seqs: int = 128 + + # Maximum length of a sequence (including prompt and generated text). + max_model_len: int = 8192 + + # Maximum number of sequences that can be partially prefilled concurrently + max_num_partial_prefills: int = 1, + + # Maximum number of “very long prompt” sequences that can be prefilled + # concurrently (long is defined by long_prefill_threshold) + max_long_partial_prefills: int = 1, + + # Set a percentage of the context length that determines which + # sequences are considered "long + long_prefill_threshold: float = 0.04 + + # The number of slots to allocate per sequence per + # step, beyond the known token ids. This is used in speculative + # decoding to store KV activations of tokens which may or may not be + # accepted. + num_lookahead_slots: int = 0 + + # Apply a delay (of delay factor multiplied by previous + # prompt latency) before scheduling next prompt. + delay_factor: float = 0.0 + + # If True, prefill requests can be chunked based + # on the remaining max_num_batched_tokens. + enable_chunked_prefill: bool = False + + is_multimodal_model: bool = False + + # Whether to perform preemption by swapping or + # recomputation. If not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + preemption_mode: Optional[str] = None + + num_scheduler_steps: int = 1 + + multi_step_stream_outputs: bool = False + + # Private API. If used, scheduler sends delta data to + # workers instead of an entire data. It should be enabled only + # when SPMD worker architecture is enabled. I.e., + # VLLM_USE_RAY_SPMD_WORKER=1 + send_delta_data: bool = False + + # The scheduling policy to use. "fcfs" (default) or "priority". + policy: str = "fcfs" + + chunked_prefill_enabled: bool = field(init=False) + + def __post_init__(self) -> None: + if self.max_num_batched_tokens is None: + if self.enable_chunked_prefill: + if self.num_scheduler_steps > 1: # Multi-step Chunked-Prefill doesn't allow prompt-chunking # for now. Have max_num_batched_tokens set to max_model_len # so we don't reject sequences on account of a short # max_num_batched_tokens. - max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_batched_tokens = max(self.max_model_len, 2048) else: # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. - max_num_batched_tokens = 512 + self.max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_batched_tokens = max(self.max_model_len, 2048) - if task == "embedding": + if self.task == "embedding": # For embedding, choose specific value for higher throughput - max_num_batched_tokens = max( - max_num_batched_tokens, + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, ) - if is_multimodal_model: + if self.is_multimodal_model: # The value needs to be at least the number of multimodal tokens - max_num_batched_tokens = max( - max_num_batched_tokens, + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, ) - self.max_num_batched_tokens = max_num_batched_tokens - self.max_num_partial_prefills = max_num_partial_prefills - self.max_long_partial_prefills = max_long_partial_prefills - self.long_prefill_token_threshold = int(max_model_len * - long_prefill_threshold) - - if enable_chunked_prefill: + if self.enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", self.max_num_batched_tokens) - self.task: Final = task - self.max_num_seqs = max_num_seqs - self.max_model_len = max_model_len - self.num_lookahead_slots = num_lookahead_slots - self.delay_factor = delay_factor - self.chunked_prefill_enabled = enable_chunked_prefill - self.preemption_mode = preemption_mode - self.num_scheduler_steps = num_scheduler_steps - self.multi_step_stream_outputs = multi_step_stream_outputs - self.send_delta_data = send_delta_data - self.policy = policy + self.chunked_prefill_enabled = self.enable_chunked_prefill self._verify_args() def _verify_args(self) -> None: @@ -1194,25 +1198,13 @@ def is_multi_step(self) -> bool: class DeviceConfig: device: Optional[torch.device] + device_type: str def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection - if current_platform.is_cuda_alike(): - self.device_type = "cuda" - elif current_platform.is_neuron(): - self.device_type = "neuron" - elif current_platform.is_hpu(): - self.device_type = "hpu" - elif current_platform.is_openvino(): - self.device_type = "openvino" - elif current_platform.is_tpu(): - self.device_type = "tpu" - elif current_platform.is_cpu(): - self.device_type = "cpu" - elif current_platform.is_xpu(): - self.device_type = "xpu" - else: + self.device_type = current_platform.device_type + if not self.device_type: raise RuntimeError("Failed to infer device type") else: # Device type is assigned explicitly @@ -2096,13 +2088,15 @@ class CompilationConfig(BaseModel): - 'none,+op1,+op2' to enable only op1 and op2 By default, all custom ops are enabled when running without Inductor and disabled when running with Inductor (compile_level >= Inductor). + - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation. - CudaGraph capture: - use_cudagraph: whether to use cudagraph inside compilation. - False: cudagraph inside compilation is not used. - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses. - Note that this is orthogonal to the cudagraph capture out - side of compilation. + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. TODO: move outside cudagraph logic into compilation. torch.compile will handle cudagraph capture logic in the future. - cudagraph_capture_sizes: sizes to capture cudagraph. @@ -2136,12 +2130,7 @@ class CompilationConfig(BaseModel): name because the config uses json format. If we pass the config from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. + - custom inductor passes: see PassConfig for more details Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -2156,6 +2145,11 @@ class CompilationConfig(BaseModel): level: int = 0 backend: str = "" custom_ops: List[str] = Field(default_factory=list) + splitting_ops: List[str] = Field(default_factory=lambda: [ + "vllm.unified_flash_attention", + "vllm.unified_flash_infer", + "vllm.unified_v1_flash_attention", + ]) use_inductor: bool = True inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None @@ -2164,14 +2158,47 @@ class CompilationConfig(BaseModel): inductor_passes: Dict[str, str] = Field(default_factory=dict) use_cudagraph: bool = False - non_cudagraph_ops: List[str] = Field(default_factory=list) cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None cudagraph_copy_inputs: bool = False - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True + class PassConfig(BaseModel): + """ + Configuration for custom Inductor passes. + This is separate from general CompilationConfig so that inductor passes + don't all have access to full configuration - that would create a cycle + as the PassManager is set as a property of config. + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graphs. Default is . + - enable_fusion: whether to enable the custom fusion pass. + - enable_reshape: whether to enable the custom reshape elimination pass. + TODO better pass enabling system. + """ + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + enable_reshape: bool = True + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + dict_ = self.model_dump( + include={"enable_fusion", "enable_reshape"}) + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).digest() + + def model_post_init(self, __context: Any) -> None: + if not self.enable_reshape and self.enable_fusion: + print_warning_once( + "Fusion enabled but reshape elimination disabled." + "RMSNorm + quant (fp8) fusion might not work") + + pass_config: PassConfig = Field(default_factory=PassConfig) # not configurable, computed after init compile_sizes: List[int] = PrivateAttr @@ -2197,8 +2224,9 @@ def model_post_init(self, __context: Any) -> None: for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v + f"pass {k} should be callable or a qualified name") + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) continue # resolve function from qualified name @@ -2206,7 +2234,8 @@ def model_post_init(self, __context: Any) -> None: module = ".".join(names[:-1]) func_name = names[-1] func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) self.enabled_custom_ops = Counter() self.disabled_custom_ops = Counter() @@ -2271,10 +2300,10 @@ class VllmConfig: model_config: ModelConfig = field(default=None, init=True) # type: ignore cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore - scheduler_config: SchedulerConfig = field(default=None, - init=True) # type: ignore + parallel_config: ParallelConfig = field(default_factory=ParallelConfig, + init=True) + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig, + init=True) device_config: DeviceConfig = field(default=None, init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore @@ -2348,18 +2377,17 @@ def __post_init__(self): if self.compilation_config is None: self.compilation_config = CompilationConfig() - if envs.VLLM_USE_V1: + if envs.VLLM_USE_V1 and not self.model_config.enforce_eager: # NOTE(woosuk): Currently, we use inductor because the piecewise # CUDA graphs do not work properly with the custom CUDA kernels. # FIXME(woosuk): Disable inductor to reduce the compilation time # and avoid any potential issues with the inductor. self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True - self.compilation_config.non_cudagraph_ops = [ - "vllm.unified_v1_flash_attention" - ] self.compilation_config.use_inductor = True - self.compilation_config.enable_fusion = False + self.compilation_config.pass_config.enable_fusion = False + self.compilation_config.pass_config.enable_reshape = False + self.compilation_config.level = CompilationLevel.PIECEWISE current_platform.check_and_update_config(self) From 4f1c322584301395ef18f8a4b5deeda43ab1565d Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 22 Nov 2024 10:48:01 -0800 Subject: [PATCH 47/54] =?UTF-8?q?=F0=9F=8E=A8=20fix=20fmt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/config.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a4ec21451f549..96d8be9f7e33b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1074,16 +1074,20 @@ class SchedulerConfig: max_model_len: int = 8192 # Maximum number of sequences that can be partially prefilled concurrently - max_num_partial_prefills: int = 1, + max_num_partial_prefills: int = 1 - # Maximum number of “very long prompt” sequences that can be prefilled + # Maximum number of "very long prompt" sequences that can be prefilled # concurrently (long is defined by long_prefill_threshold) - max_long_partial_prefills: int = 1, + max_long_partial_prefills: int = 1 - # Set a percentage of the context length that determines which - # sequences are considered "long + # Set a percentage of the context length that determines which + # sequences are considered "long" long_prefill_threshold: float = 0.04 + # calculate context length that determines which sequences are + # considered "long" + long_prefill_token_threshold = int(max_model_len * long_prefill_threshold) + # The number of slots to allocate per sequence per # step, beyond the known token ids. This is used in speculative # decoding to store KV activations of tokens which may or may not be From cb8fc93131a52d4ce1ce250b38d2da0ea851b110 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 22 Nov 2024 13:40:29 -0800 Subject: [PATCH 48/54] =?UTF-8?q?=E2=8F=AA=20revert=20quick=20budget=20che?= =?UTF-8?q?ck?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Turns out we need to add seqs to SequenceStatus.FINISHED_IGNORED if num_new_tokens > prompt_limit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a6598118badf6..127a82b6de91d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1019,12 +1019,6 @@ def _schedule_prefills( budget, partial_prefill_metadata=partial_prefill_metadata) - num_new_seqs = seq_group.get_max_num_running_seqs() - # quick budget check - if num_new_tokens == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs): - break - if not enable_chunking: num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens @@ -1075,6 +1069,12 @@ def _schedule_prefills( waiting_queue.popleft() continue + num_new_seqs = seq_group.get_max_num_running_seqs() + # quick budget check + if num_new_tokens == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs): + break + # Can schedule this request. if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) From 8a8a07f4966a9c69c72b323c5f1794abd4f3741f Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Fri, 22 Nov 2024 13:42:33 -0800 Subject: [PATCH 49/54] =?UTF-8?q?=F0=9F=8E=A8=20fmt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 127a82b6de91d..38b7b66216158 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1070,9 +1070,9 @@ def _schedule_prefills( continue num_new_seqs = seq_group.get_max_num_running_seqs() - # quick budget check - if num_new_tokens == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens, num_new_seqs=num_new_seqs): + if (num_new_tokens == 0 + or not budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): break # Can schedule this request. From 90a53ab24bddec9ca52aed6713fb54d5393cbe0f Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 26 Nov 2024 11:27:40 -0800 Subject: [PATCH 50/54] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20merge=20with=20main?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 292 +++++++++++++++++++++++++++------------ vllm/engine/arg_utils.py | 3 +- 2 files changed, 208 insertions(+), 87 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 38b7b66216158..5f3c83f1a1bf9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -299,13 +299,11 @@ def scheduled_seq_group_builder(): class PartialPrefillMetadata: """Holds information about the partial prefills that are currently running during a single iteration of the Scheduler. - When chunked prefill is enabled, we allow a certain number of seqs to be partially prefilled during each iteration. Having multiple partial prefills in flight allows us to minimize TTFT and avoid decode starvation in cases where a single sequence group with a very large prompt blocks the queue for too many iterations. - The number of long prefill requests is limited so that smaller requests may jump the queue in front of them and get to the decode phase faster. @@ -325,11 +323,11 @@ def cannot_schedule(self, seq_group: SequenceGroup) -> bool: we limit the number of long requests and only accept shorter requests from the queue while running them concurrently""" - return seq_group.first_seq.get_num_new_tokens() > \ - self.scheduler_config.long_prefill_token_threshold \ - and self.long_partial_prefills >= \ - self.scheduler_config.max_long_partial_prefills \ - and self.scheduler_config.max_num_partial_prefills > 1 + return (seq_group.first_seq.get_num_new_tokens() > + self.scheduler_config.long_prefill_token_threshold + and self.long_partial_prefills >= + self.scheduler_config.max_long_partial_prefills + and self.scheduler_config.max_num_partial_prefills > 1) def increment_partial_prefills(self, seq_group: SequenceGroup) -> None: # When a new prefill is scheduled, we need to know if it is a @@ -340,11 +338,13 @@ def increment_partial_prefills(self, seq_group: SequenceGroup) -> None: @classmethod def from_queues( - cls, running: Deque[SequenceGroup], waiting: Deque[SequenceGroup], - scheduler_config: SchedulerConfig) -> "PartialPrefillMetadata": + cls, + running: Deque[SequenceGroup], + waiting: Deque[SequenceGroup], + scheduler_config: SchedulerConfig, + ) -> "PartialPrefillMetadata": """Create a PartialPrefillMetadata object from the current state of the scheduler's queues. - This accounts for the currently running prefill requests, and peeks into the waiting queue to see if there are more prefills to potentially be scheduled during this iteration.""" @@ -358,32 +358,33 @@ def from_queues( # TODO: Check if this stage is correctly updated before scheduling if sg.first_seq.data.stage == SequenceStage.PREFILL: partial_prefills += 1 - if sg.first_seq.get_num_new_tokens( - ) > scheduler_config.long_prefill_token_threshold: + if (sg.first_seq.get_num_new_tokens() > + scheduler_config.long_prefill_token_threshold): long_partial_prefills += 1 for sg in waiting: # Don't bother looping through the rest of the queue if we know # there are already at # least max_partial_prefills requests to fill - if partial_prefills + waiting_partial_prefills \ - >= scheduler_config.max_num_partial_prefills: + if (partial_prefills + waiting_partial_prefills >= + scheduler_config.max_num_partial_prefills): break # Don't count long requests from the waiting queue if we aren't # going to schedule them anyway - if sg.first_seq.get_num_new_tokens( - ) > scheduler_config.long_prefill_token_threshold: - if long_partial_prefills + waiting_long_prefills \ - >= scheduler_config.max_long_partial_prefills: + if (sg.first_seq.get_num_new_tokens() > + scheduler_config.long_prefill_token_threshold): + if (long_partial_prefills + waiting_long_prefills >= + scheduler_config.max_long_partial_prefills): continue waiting_long_prefills += 1 waiting_partial_prefills += 1 - return PartialPrefillMetadata(partial_prefills + - waiting_partial_prefills, - long_partial_prefills, - scheduler_config=scheduler_config) + return PartialPrefillMetadata( + partial_prefills + waiting_partial_prefills, + long_partial_prefills, + scheduler_config=scheduler_config, + ) class Scheduler: @@ -492,11 +493,11 @@ def __init__( # This splits the budget evenly among all prefills. self.partial_prefill_budget_lookup_list = [0] * ( self.scheduler_config.max_num_partial_prefills + 1) - self.partial_prefill_budget_lookup_list[ - 0] = scheduler_config.max_num_batched_tokens + self.partial_prefill_budget_lookup_list[0] = ( + scheduler_config.max_num_batched_tokens) for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): - self.partial_prefill_budget_lookup_list[i] = \ - scheduler_config.max_num_batched_tokens // i + self.partial_prefill_budget_lookup_list[i] = ( + scheduler_config.max_num_batched_tokens // i) @property def next_cache_id(self): @@ -613,7 +614,7 @@ def _schedule_running( `budget.num_batched_tokens` has not enough capacity to schedule all tokens. partial_prefill_metadata: information about the partial prefills - that are currently running + that are currently running Returns: SchedulerRunningOutputs. @@ -647,7 +648,15 @@ def _schedule_running( assert len(self._async_stopped) == 0 while running_queue: seq_group = running_queue[0] - num_running_tokens = self._get_num_new_tokens( + # We discard the cached tokens info here because we don't need it + # for running sequence: + # 1. If a sequence is running with chunked prefill, the cached + # tokens info was already used for the first prefill. + # 2. If a sequence is running with non-chunked prefill, then + # there it's a decoding sequence, and the cached tokens info is + # irrelevant. + num_uncached_new_tokens, _ = \ + self._get_num_new_uncached_and_cached_tokens( seq_group, SequenceStatus.RUNNING, enable_chunking, @@ -655,6 +664,7 @@ def _schedule_running( partial_prefill_metadata, ) + num_running_tokens = num_uncached_new_tokens if num_running_tokens == 0: # No budget => Stop break @@ -977,7 +987,7 @@ def _schedule_prefills( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - partial_prefill_metadata: information about the partial prefills + partial_prefill_metadata: information about the partial prefills that are currently running Returns: @@ -989,8 +999,8 @@ def _schedule_prefills( seq_groups=[], ignored_seq_groups=[], num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking)) - + is_prefill=True, enable_chunking=enable_chunking), + ) ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] @@ -1005,19 +1015,20 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - - if partial_prefill_metadata is not None and \ - partial_prefill_metadata.cannot_schedule(seq_group): + if (partial_prefill_metadata is not None + and partial_prefill_metadata.cannot_schedule(seq_group)): leftover_waiting_sequences.appendleft(seq_group) waiting_queue.popleft() continue - - num_new_tokens = self._get_num_new_tokens( - seq_group, - SequenceStatus.WAITING, - enable_chunking, - budget, - partial_prefill_metadata=partial_prefill_metadata) + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, + SequenceStatus.WAITING, + enable_chunking, + budget, + partial_prefill_metadata=partial_prefill_metadata, + )) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached if not enable_chunking: num_prompt_tokens = waiting_seqs[0].get_len() @@ -1099,7 +1110,8 @@ def _schedule_prefills( num_scheduler_steps=self.scheduler_config. num_scheduler_steps, is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking) + enable_chunking=enable_chunking, + ) seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, @@ -1249,7 +1261,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: partial_prefill_metadata = PartialPrefillMetadata.from_queues( running=self.running, waiting=self.waiting, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + ) # Decoding should be always scheduled first by fcfs. running_scheduled = self._schedule_running( @@ -1265,7 +1278,6 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: running_scheduled.swapped_out) == 0: swapped_in = self._schedule_swapped(budget, curr_loras) - # Schedule new prefills. prefills = self._schedule_prefills( budget, curr_loras, @@ -1740,24 +1752,138 @@ def _get_num_new_tokens( enable_chunking: bool, budget: SchedulingBudget, partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> int: - """Get the next new tokens to compute for a given sequence group - that's in a given `status`. + ) -> Tuple[int, int]: + """ + Returns the number of new uncached and cached tokens to schedule for a + given sequence group that's in a given `status`. The API could chunk the number of tokens to compute based on `budget` if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. - Returns 0 if the new token cannot be computed due to token budget. + Returns (0, 0) if the new token cannot be computed due to token budget. + + The cached tokens's blocks are already computed, and the attention + backend will reuse the cached blocks rather than recomputing them. So + the scheduler could schedule these cached tokens "for free". + + Args: + seq_group: The sequence group to get the number of new tokens to + schedule. + status: The status of the sequences to get the number of new tokens + to schedule. + enable_chunking: Whether to chunk the number of tokens to compute. + budget: The budget to chunk the number of tokens to compute. + partial_prefill_metadata: information about the partial prefills + that are currently running + + + Returns: + A tuple of two ints. The first int is the number of new uncached + tokens to schedule. The second int is the number of cached tokens. + If no more new tokens can be scheduled, returns (0, 0). """ - num_new_tokens = 0 + num_cached_new_tokens = 0 + num_uncached_new_tokens = 0 + seqs = seq_group.get_seqs(status=status) for seq in seqs: - num_new_tokens += seq.get_num_new_tokens() - assert num_new_tokens > 0 + if not seq.is_prefill(): + # Decode sequences should always just have 1 uncached token + # TODO(rickyx): Actually is this still correct for multi-step? + num_uncached_new_tokens += 1 + continue + + num_computed_tokens_seq = seq.get_num_computed_tokens() + all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq + if not self.cache_config.enable_prefix_caching: + # If prefix caching is not enabled, all new tokens are uncached. + num_uncached_new_tokens += all_num_new_tokens_seq + continue + + # NOTE: the cache token might be currently in a block that's in an + # evictor meaning that it's not yet allocated. However, we don't + # exclude such tokens in the cache count because it will be + # guaranteed to be allocated later if the sequence can be allocated. + num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( + seq) + + # Sanity check. + if num_cached_tokens_seq < num_computed_tokens_seq: + # This should only happen with chunked prefill, and + # the seq is still in prefill. The `num_cached_tokens_seq` + # is the value we calculated on scheduling the first prefill. + # For subsequent continuous prefill steps, we cached the + # number of cache tokens for the sequence so the cached token + # count could be less than the number of computed tokens. + # See comments on `ComputedBlocksTracker` for more details. + assert ( + seq.is_prefill() and seq.status == SequenceStatus.RUNNING + and self.scheduler_config.chunked_prefill_enabled + ), ("Number of cached tokens should not be less than the " + "number of computed tokens for a sequence that's still " + f"in prefill. But there are {num_cached_tokens_seq} cached " + f"tokens and {num_computed_tokens_seq} computed tokens " + f"for sequence {seq.seq_id}.") + + num_cached_new_tokens_seq = max( + 0, num_cached_tokens_seq - num_computed_tokens_seq) + num_uncached_new_tokens_seq = (all_num_new_tokens_seq - + num_cached_new_tokens_seq) + + num_uncached_new_tokens += num_uncached_new_tokens_seq + num_cached_new_tokens += num_cached_new_tokens_seq + + if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: + # For a fully cached hit sequence, we actually need to recompute the + # last token. So we need at least 1 uncached token to schedule. + # See ModelRunner._compute_for_prefix_cache_hit for more details. + num_uncached_new_tokens = 1 + num_cached_new_tokens -= 1 + + if enable_chunking and len(seqs) == 1: + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( + self.scheduler_config, + self.cache_config, + budget, + self._get_prompt_limit(seq_group), + num_uncached_new_tokens, + self.partial_prefill_budget_lookup_list, + partial_prefill_metadata, + ) + + return num_uncached_new_tokens, num_cached_new_tokens + + @staticmethod + def _chunk_new_tokens_to_schedule( + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + budget: SchedulingBudget, + prompt_limit: int, + num_new_tokens: int, + partial_prefill_budget_lookup_list: List[int], + partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + ) -> int: + """ + Chunks the number of new tokens to schedule based on the budget when + chunked prefill is enabled. + + Args: + scheduler_config: The scheduler config. + cache_config: The cache config. + budget: The budget to chunk the number of tokens to compute. + prompt_limit: The maximum number of tokens allowed in a prompt. + num_new_tokens: The number of new tokens to schedule. - if self.scheduler_config.is_multi_step: + Returns: + The number of new tokens to schedule after chunking. + """ + remaining_token_budget = budget.remaining_token_budget() + if scheduler_config.is_multi_step: # The current multi-step + chunked prefill capability does # not actually support chunking prompts. # @@ -1771,39 +1897,33 @@ def _get_num_new_tokens( # If the seq_group is in prompt-stage, pass the # num_new_tokens as-is so the caller can ignore # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > budget.remaining_token_budget() \ - else num_new_tokens - - # If number of seq > 1, it means it is doing beam search - # in a decode phase. Do not chunk. - elif enable_chunking and len(seqs) == 1: - remaining_token_budget = budget.remaining_token_budget() - # Get the number of tokens to allocate to this prefill slot - prefill_slot_budget = remaining_token_budget \ - if partial_prefill_metadata is None \ - else self.partial_prefill_budget_lookup_list[ - partial_prefill_metadata.partial_prefills - ] - - if self.cache_config.enable_prefix_caching: - # When prefix caching is enabled and we're partially prefilling - # a sequence, we always allocate a number of new tokens that is - # divisible by the block size to avoid partial block matching. - block_size = self.cache_config.block_size - # Don't exceed either the total budget or slot budget. - # Take min of those and get the next lowest multiple of the - # block size: - remaining_token_budget = \ - (min(remaining_token_budget, prefill_slot_budget) // - block_size) * block_size - # NB: In the case where num_new_tokens < budget, we are - # finishing prefill for this sequence, so we do not need to - # allocate a full block. - - num_new_tokens = min(num_new_tokens, remaining_token_budget, - prefill_slot_budget) + return num_new_tokens + + return 0 if num_new_tokens > \ + remaining_token_budget else num_new_tokens + + # Get the number of tokens to allocate to this prefill slot + prefill_slot_budget = (remaining_token_budget + if partial_prefill_metadata is None else + partial_prefill_budget_lookup_list[ + partial_prefill_metadata.partial_prefills]) + + if cache_config.enable_prefix_caching: + # When prefix caching is enabled and we're partially prefilling + # a sequence, we always allocate a number of new tokens that is + # divisible by the block size to avoid partial block matching. + block_size = cache_config.block_size + # Don't exceed either the total budget or slot budget. + # Take min of those and get the next lowest multiple of the + # block size: + remaining_token_budget = ( + min(remaining_token_budget, prefill_slot_budget) // + block_size) * block_size + # NB: In the case where num_new_tokens < budget, we are + # finishing prefill for this sequence, so we do not need to + # allocate a full block. + + num_new_tokens = min(num_new_tokens, remaining_token_budget, + prefill_slot_budget) return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f4af87e901875..f35002cd26f74 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1148,7 +1148,8 @@ def create_engine_config(self) -> VllmConfig: policy=self.scheduling_policy, max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, - long_prefill_threshold=self.long_prefill_threshold) + long_prefill_threshold=self.long_prefill_threshold, + ) lora_config = LoRAConfig( bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, From 752ce1b7eff1bef3cde500b59fb5bea8b2719f0e Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 26 Nov 2024 11:42:10 -0800 Subject: [PATCH 51/54] =?UTF-8?q?=F0=9F=8E=A8=20fmt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- vllm/core/scheduler.py | 550 ++++++++++++++++++----------------------- 1 file changed, 238 insertions(+), 312 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4a72b3638bd03..96ee18e0ebca5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -13,15 +13,9 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ( - Sequence, - SequenceData, - SequenceGroup, - SequenceGroupMetadata, - SequenceGroupMetadataDelta, - SequenceStage, - SequenceStatus, -) +from vllm.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupMetadata, SequenceGroupMetadataDelta, + SequenceStage, SequenceStatus) from vllm.utils import Device, PyObjectCache logger = init_logger(__name__) @@ -29,8 +23,7 @@ # Test-only. If configured, decode is preempted with # ARTIFICIAL_PREEMPTION_PROB% probability. ENABLE_ARTIFICIAL_PREEMPT = bool( - os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False) -) # noqa + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa ARTIFICIAL_PREEMPTION_PROB = 0.5 ARTIFICIAL_PREEMPTION_MAX_CNT = 500 @@ -76,17 +69,16 @@ def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): # been cached. assert num_new_tokens >= 0 assert num_new_seqs != 0 - return ( - self.num_batched_tokens + num_new_tokens <= self.token_budget - and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs - ) + return (self.num_batched_tokens + num_new_tokens <= self.token_budget + and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) def remaining_token_budget(self): return self.token_budget - self.num_batched_tokens - def add_num_batched_tokens( - self, req_id: str, num_batched_tokens: int, num_cached_tokens: int = 0 - ): + def add_num_batched_tokens(self, + req_id: str, + num_batched_tokens: int, + num_cached_tokens: int = 0): if req_id in self._request_ids_num_batched_tokens: return assert num_cached_tokens >= 0 @@ -96,7 +88,8 @@ def add_num_batched_tokens( self._num_batched_tokens += num_batched_tokens self._num_cached_tokens += num_cached_tokens - def subtract_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + def subtract_num_batched_tokens(self, req_id: str, + num_batched_tokens: int): if req_id in self._request_ids_num_batched_tokens: self._request_ids_num_batched_tokens.remove(req_id) self._num_batched_tokens -= num_batched_tokens @@ -172,12 +165,8 @@ def __post_init__(self): def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. - return ( - not self.scheduled_seq_groups - and not self.blocks_to_swap_in - and not self.blocks_to_swap_out - and not self.blocks_to_copy - ) + return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) def _sort_by_lora_ids(self): self.scheduled_seq_groups = sorted( @@ -327,9 +316,8 @@ def scheduler_running_outputs_builder(): def scheduled_seq_group_builder(): - return ScheduledSequenceGroup( - SequenceGroup.__new__(SequenceGroup), token_chunk_size=0 - ) + return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup), + token_chunk_size=0) # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) @@ -361,21 +349,17 @@ def cannot_schedule(self, seq_group: SequenceGroup) -> bool: we limit the number of long requests and only accept shorter requests from the queue while running them concurrently""" - return ( - seq_group.first_seq.get_num_new_tokens() - > self.scheduler_config.long_prefill_token_threshold - and self.long_partial_prefills - >= self.scheduler_config.max_long_partial_prefills - and self.scheduler_config.max_num_partial_prefills > 1 - ) + return (seq_group.first_seq.get_num_new_tokens() > + self.scheduler_config.long_prefill_token_threshold + and self.long_partial_prefills >= + self.scheduler_config.max_long_partial_prefills + and self.scheduler_config.max_num_partial_prefills > 1) def increment_partial_prefills(self, seq_group: SequenceGroup) -> None: # When a new prefill is scheduled, we need to know if it is a # long request - if ( - seq_group.first_seq.get_num_new_tokens() - > self.scheduler_config.long_prefill_token_threshold - ): + if (seq_group.first_seq.get_num_new_tokens() > + self.scheduler_config.long_prefill_token_threshold): self.long_partial_prefills += 1 @classmethod @@ -400,32 +384,24 @@ def from_queues( # TODO: Check if this stage is correctly updated before scheduling if sg.first_seq.data.stage == SequenceStage.PREFILL: partial_prefills += 1 - if ( - sg.first_seq.get_num_new_tokens() - > scheduler_config.long_prefill_token_threshold - ): + if (sg.first_seq.get_num_new_tokens() > + scheduler_config.long_prefill_token_threshold): long_partial_prefills += 1 for sg in waiting: # Don't bother looping through the rest of the queue if we know # there are already at # least max_partial_prefills requests to fill - if ( - partial_prefills + waiting_partial_prefills - >= scheduler_config.max_num_partial_prefills - ): + if (partial_prefills + waiting_partial_prefills >= + scheduler_config.max_num_partial_prefills): break # Don't count long requests from the waiting queue if we aren't # going to schedule them anyway - if ( - sg.first_seq.get_num_new_tokens() - > scheduler_config.long_prefill_token_threshold - ): - if ( - long_partial_prefills + waiting_long_prefills - >= scheduler_config.max_long_partial_prefills - ): + if (sg.first_seq.get_num_new_tokens() > + scheduler_config.long_prefill_token_threshold): + if (long_partial_prefills + waiting_long_prefills >= + scheduler_config.max_long_partial_prefills): continue waiting_long_prefills += 1 waiting_partial_prefills += 1 @@ -455,13 +431,12 @@ def __init__( self.lora_config = lora_config version = "selfattn" - if ( - self.scheduler_config.task == "embedding" - or self.cache_config.is_attention_free - ): + if (self.scheduler_config.task == "embedding" + or self.cache_config.is_attention_free): version = "placeholder" - BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(version) + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( + version) num_gpu_blocks = cache_config.num_gpu_blocks if num_gpu_blocks: @@ -506,9 +481,9 @@ def __init__( # The following field is test-only. It is used to inject artificial # preemption. self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT - self.artificial_preempt_cnt = ( - ARTIFICIAL_PREEMPTION_MAX_CNT if self.enable_artificial_preemption else 0 - ) + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) self.num_cumulative_preemption: int = 0 # Used to cache python objects @@ -527,14 +502,11 @@ def __init__( self.cache_id = 0 for i in range(self.num_cache_iters): self._seq_group_metadata_cache.append( - PyObjectCache(seq_group_metadata_builder) - ) + PyObjectCache(seq_group_metadata_builder)) self._scheduler_running_outputs_cache.append( - PyObjectCache(scheduler_running_outputs_builder) - ) + PyObjectCache(scheduler_running_outputs_builder)) self._scheduled_seq_group_cache.append( - PyObjectCache(scheduled_seq_group_builder) - ) + PyObjectCache(scheduled_seq_group_builder)) # For async postprocessor, the extra decode run cannot be done # when the request reaches max_model_len. In this case, the request @@ -547,15 +519,12 @@ def __init__( # running an integer division every time a prefill is scheduled. # This splits the budget evenly among all prefills. self.partial_prefill_budget_lookup_list = [0] * ( - self.scheduler_config.max_num_partial_prefills + 1 - ) + self.scheduler_config.max_num_partial_prefills + 1) self.partial_prefill_budget_lookup_list[0] = ( - scheduler_config.max_num_batched_tokens - ) + scheduler_config.max_num_batched_tokens) for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): self.partial_prefill_budget_lookup_list[i] = ( - scheduler_config.max_num_batched_tokens // i - ) + scheduler_config.max_num_batched_tokens // i) @property def next_cache_id(self): @@ -598,7 +567,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: request_id: The ID(s) of the sequence group to abort. """ if isinstance(request_id, str): - request_id = (request_id,) + request_id = (request_id, ) request_ids = set(request_id) for state_queue in [self.waiting, self.running, self.swapped]: aborted_groups: List[SequenceGroup] = [] @@ -636,9 +605,8 @@ def _free_seq_group_cross_attn_blocks( self.block_manager.free_cross(seq_group) def has_unfinished_seqs(self) -> bool: - return ( - len(self.waiting) != 0 or len(self.running) != 0 or len(self.swapped) != 0 - ) + return (len(self.waiting) != 0 or len(self.running) != 0 + or len(self.swapped) != 0) def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_manager.get_prefix_cache_hit_rate(device) @@ -679,8 +647,7 @@ def _schedule_running( SchedulerRunningOutputs. """ ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[ - self.cache_id - ].get_object() + self.cache_id].get_object() ret.blocks_to_swap_out.clear() ret.blocks_to_copy.clear() ret.decode_seq_groups.clear() @@ -689,8 +656,7 @@ def _schedule_running( ret.swapped_out.clear() ret.num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill=False, enable_chunking=enable_chunking - ) + is_prefill=False, enable_chunking=enable_chunking) ret.decode_seq_groups_list.clear() ret.prefill_seq_groups_list.clear() @@ -700,7 +666,8 @@ def _schedule_running( blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups - prefill_seq_groups: List[ScheduledSequenceGroup] = ret.prefill_seq_groups + prefill_seq_groups: List[ + ScheduledSequenceGroup] = ret.prefill_seq_groups preempted: List[SequenceGroup] = ret.preempted swapped_out: List[SequenceGroup] = ret.swapped_out @@ -715,7 +682,8 @@ def _schedule_running( # 2. If a sequence is running with non-chunked prefill, then # there it's a decoding sequence, and the cached tokens info is # irrelevant. - num_uncached_new_tokens, _ = self._get_num_new_uncached_and_cached_tokens( + num_uncached_new_tokens, _ = \ + self._get_num_new_uncached_and_cached_tokens( seq_group, SequenceStatus.RUNNING, enable_chunking, @@ -734,27 +702,22 @@ def _schedule_running( # to process the final tokens. The check below avoids this extra # decode run when the model max len is reached, in order to avoid # a memory overflow. - if ( - self.use_async_output_proc - and seq_group.seqs[0].get_len() > self.scheduler_config.max_model_len - ): + if (self.use_async_output_proc and seq_group.seqs[0].get_len() > + self.scheduler_config.max_model_len): self._async_stopped.append(seq_group) continue # NOTE(woosuk): Preemption happens only when there is no available # slot to keep all the sequence groups in the RUNNING state. while not self._can_append_slots(seq_group, enable_chunking): - budget.subtract_num_batched_tokens( - seq_group.request_id, num_running_tokens - ) + budget.subtract_num_batched_tokens(seq_group.request_id, + num_running_tokens) num_running_seqs = seq_group.get_max_num_running_seqs() - budget.subtract_num_seqs(seq_group.request_id, num_running_seqs) + budget.subtract_num_seqs(seq_group.request_id, + num_running_seqs) - if ( - curr_loras is not None - and seq_group.lora_int_id > 0 - and seq_group.lora_int_id in curr_loras - ): + if (curr_loras is not None and seq_group.lora_int_id > 0 + and seq_group.lora_int_id in curr_loras): curr_loras.remove(seq_group.lora_int_id) # Determine victim sequence @@ -775,7 +738,8 @@ def _schedule_running( do_preempt = True if self.use_async_output_proc: assert self.output_proc_callback is not None - self.output_proc_callback(request_id=victim_seq_group.request_id) + self.output_proc_callback( + request_id=victim_seq_group.request_id) # It may be that the async pending "victim_seq_group" # becomes finished, in which case we simply free it. @@ -785,7 +749,8 @@ def _schedule_running( # Do preemption if do_preempt: - preempted_mode = self._preempt(victim_seq_group, blocks_to_swap_out) + preempted_mode = self._preempt(victim_seq_group, + blocks_to_swap_out) if preempted_mode == PreemptionMode.RECOMPUTE: preempted.append(victim_seq_group) else: @@ -798,8 +763,8 @@ def _schedule_running( is_prefill = seq_group.is_prefill() scheduled_seq_group: ScheduledSequenceGroup = ( - self._scheduled_seq_group_cache[self.cache_id].get_object() - ) + self._scheduled_seq_group_cache[ + self.cache_id].get_object()) scheduled_seq_group.seq_group = seq_group if is_prefill: scheduled_seq_group.token_chunk_size = num_running_tokens @@ -810,7 +775,8 @@ def _schedule_running( decode_seq_groups.append(scheduled_seq_group) ret.decode_seq_groups_list.append(seq_group) - budget.add_num_batched_tokens(seq_group.request_id, num_running_tokens) + budget.add_num_batched_tokens(seq_group.request_id, + num_running_tokens) # OPTIMIZATION: Note that get_max_num_running_seqs is # expensive. For the default scheduling chase where # enable_chunking is False, num_seqs are updated before running @@ -867,8 +833,8 @@ def _schedule_swapped( # If the sequence group cannot be swapped in, stop. is_prefill = seq_group.is_prefill() alloc_status = self.block_manager.can_swap_in( - seq_group, self._get_num_lookahead_slots(is_prefill, enable_chunking) - ) + seq_group, + self._get_num_lookahead_slots(is_prefill, enable_chunking)) if alloc_status == AllocStatus.LATER: break elif alloc_status == AllocStatus.NEVER: @@ -888,11 +854,8 @@ def _schedule_swapped( lora_int_id = seq_group.lora_int_id assert curr_loras is not None assert self.lora_config is not None - if ( - lora_int_id > 0 - and (lora_int_id not in curr_loras) - and len(curr_loras) >= self.lora_config.max_loras - ): + if (lora_int_id > 0 and (lora_int_id not in curr_loras) + and len(curr_loras) >= self.lora_config.max_loras): # We don't have a space for another LoRA, so # we ignore this request for now. leftover_swapped.appendleft(seq_group) @@ -904,13 +867,12 @@ def _schedule_swapped( num_new_seqs = seq_group.get_max_num_running_seqs() num_new_tokens_uncached, num_new_tokens_cached = ( self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.SWAPPED, enable_chunking, budget - ) - ) + seq_group, SequenceStatus.SWAPPED, enable_chunking, + budget)) if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, ): break @@ -924,14 +886,12 @@ def _schedule_swapped( prefill_seq_groups.append( ScheduledSequenceGroup( seq_group, - token_chunk_size=num_new_tokens_uncached - + num_new_tokens_cached, - ) - ) + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) else: decode_seq_groups.append( - ScheduledSequenceGroup(seq_group, token_chunk_size=1) - ) + ScheduledSequenceGroup(seq_group, token_chunk_size=1)) budget.add_num_batched_tokens( seq_group.request_id, num_batched_tokens=num_new_tokens_uncached, @@ -947,16 +907,13 @@ def _schedule_swapped( blocks_to_swap_in=blocks_to_swap_in, blocks_to_copy=blocks_to_copy, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False, enable_chunking=enable_chunking - ), + is_prefill=False, enable_chunking=enable_chunking), infeasible_seq_groups=infeasible_seq_groups, ) def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if ( - self.scheduler_config.chunked_prefill_enabled - and not self.scheduler_config.is_multi_step - ): + if (self.scheduler_config.chunked_prefill_enabled + and not self.scheduler_config.is_multi_step): prompt_limit = self.scheduler_config.max_model_len else: prompt_limit = min( @@ -971,7 +928,8 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: else: return prompt_limit - def _get_priority(self, seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: """Get the priority of the sequence group. Highest preference to user-defined priority, followed by arrival time. Args: @@ -1005,38 +963,33 @@ def _schedule_priority_preemption( if waiting_queue: seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens_uncached, _ = self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.WAITING, False, budget - ) + num_new_tokens_uncached, _ = \ + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, False, budget) # Only preempt if priority inversion exists while running_queue and self._get_priority( - running_queue[-1] - ) > self._get_priority(seq_group): + running_queue[-1]) > self._get_priority(seq_group): # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) - if ( - num_new_tokens_uncached > 0 - and can_allocate == AllocStatus.OK - and budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - ) - ): + if (num_new_tokens_uncached > 0 + and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): break # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens_uncached, _ = ( self._get_num_new_uncached_and_cached_tokens( - vseq_group, SequenceStatus.RUNNING, False, budget - ) - ) + vseq_group, SequenceStatus.RUNNING, False, budget)) budget.subtract_num_batched_tokens( - vseq_group.request_id, num_running_tokens_uncached - ) + vseq_group.request_id, num_running_tokens_uncached) num_running_seqs = vseq_group.get_max_num_running_seqs() - budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) @@ -1089,8 +1042,7 @@ def _schedule_prefills( seq_groups=[], ignored_seq_groups=[], num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking - ), + is_prefill=True, enable_chunking=enable_chunking), ) ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] @@ -1103,12 +1055,10 @@ def _schedule_prefills( waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert len(waiting_seqs) == 1, ( - "Waiting sequence group should have only one prompt " "sequence." - ) - if ( - partial_prefill_metadata is not None - and partial_prefill_metadata.cannot_schedule(seq_group) - ): + "Waiting sequence group should have only one prompt " + "sequence.") + if (partial_prefill_metadata is not None + and partial_prefill_metadata.cannot_schedule(seq_group)): leftover_waiting_sequences.appendleft(seq_group) waiting_queue.popleft() continue @@ -1119,8 +1069,7 @@ def _schedule_prefills( enable_chunking, budget, partial_prefill_metadata=partial_prefill_metadata, - ) - ) + )) num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached if not enable_chunking: @@ -1130,7 +1079,8 @@ def _schedule_prefills( prompt_limit = self._get_prompt_limit(seq_group) if num_new_tokens > prompt_limit: logger.warning( - "Input prompt (%d tokens) is too long" " and exceeds limit of %d", + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", num_new_tokens, prompt_limit, ) @@ -1143,13 +1093,11 @@ def _schedule_prefills( num_lookahead_slots: int = 0 if self.scheduler_config.is_multi_step and enable_chunking: num_lookahead_slots = self._get_num_lookahead_slots( - True, enable_chunking - ) + True, enable_chunking) # If the sequence group cannot be allocated, stop. can_allocate = self.block_manager.can_allocate( - seq_group, num_lookahead_slots=num_lookahead_slots - ) + seq_group, num_lookahead_slots=num_lookahead_slots) if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: @@ -1170,22 +1118,17 @@ def _schedule_prefills( lora_int_id = seq_group.lora_int_id assert curr_loras is not None assert self.lora_config is not None - if ( - self.lora_enabled - and lora_int_id > 0 - and lora_int_id not in curr_loras - and len(curr_loras) >= self.lora_config.max_loras - ): + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): # We don't have a space for another LoRA, so # we ignore this request for now. leftover_waiting_sequences.appendleft(seq_group) waiting_queue.popleft() continue - if ( - budget.num_batched_tokens - >= self.scheduler_config.max_num_batched_tokens - ): + if (budget.num_batched_tokens >= + self.scheduler_config.max_num_batched_tokens): # We've reached the budget limit - since there might be # continuous prefills in the running queue, we should break # to avoid scheduling any new prefills. @@ -1193,8 +1136,8 @@ def _schedule_prefills( num_new_seqs = seq_group.get_max_num_running_seqs() if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, ): break @@ -1219,16 +1162,15 @@ def _schedule_prefills( else: seq_group.init_multi_step_from_lookahead_slots( num_lookahead_slots, - num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + num_scheduler_steps=self.scheduler_config. + num_scheduler_steps, is_multi_step=self.scheduler_config.is_multi_step, enable_chunking=enable_chunking, ) seq_groups.append( - ScheduledSequenceGroup( - seq_group=seq_group, token_chunk_size=num_new_tokens - ) - ) + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) budget.add_num_batched_tokens( seq_group.request_id, num_batched_tokens=num_new_tokens_uncached, @@ -1245,8 +1187,7 @@ def _schedule_prefills( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking - ), + is_prefill=True, enable_chunking=enable_chunking), ) def _schedule_default(self) -> SchedulerOutputs: @@ -1265,18 +1206,11 @@ def _schedule_default(self) -> SchedulerOutputs: # Make sure we include num running seqs before scheduling prefill, # so that we don't schedule beyond max_num_seqs for prefill. for seq_group in self.running: - budget.add_num_seqs( - seq_group.request_id, seq_group.get_max_num_running_seqs() - ) - curr_loras = ( - set( - seq_group.lora_int_id - for seq_group in self.running - if seq_group.lora_int_id > 0 - ) - if self.lora_enabled - else None - ) + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + curr_loras = (set( + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None) prefills = SchedulerPrefillOutputs.create_empty() running_scheduled = SchedulerRunningOutputs.create_empty() @@ -1284,30 +1218,31 @@ def _schedule_default(self) -> SchedulerOutputs: # If any requests are swapped, prioritized swapped requests. if not self.swapped: - prefills = self._schedule_prefills( - budget, curr_loras, enable_chunking=False - ) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) - if len(prefills.seq_groups) == 0 and self.scheduler_config.policy == "priority": + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": self._schedule_priority_preemption(budget) # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running( - budget, curr_loras, enable_chunking=False - ) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. - if ( - len(running_scheduled.preempted) + len(running_scheduled.swapped_out) - == 0 - ): - swapped_in = self._schedule_swapped(budget, curr_loras) + if (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out) == 0): + swapped_in = \ + self._schedule_swapped(budget, curr_loras) - assert budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens + assert budget.num_batched_tokens <= \ + self.scheduler_config.max_num_batched_tokens assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1319,13 +1254,13 @@ def _schedule_default(self) -> SchedulerOutputs: self.running.extend(running_scheduled.decode_seq_groups_list) if len(swapped_in.decode_seq_groups) > 0: - self.running.extend([s.seq_group for s in swapped_in.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) preempted = len(running_scheduled.preempted) + len( - running_scheduled.swapped_out - ) + running_scheduled.swapped_out) # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. @@ -1350,7 +1285,8 @@ def _schedule_default(self) -> SchedulerOutputs: return SchedulerOutputs( scheduled_seq_groups=scheduled_seq_groups, num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=blocks_to_copy, @@ -1400,7 +1336,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. - if len(running_scheduled.preempted) + len(running_scheduled.swapped_out) == 0: + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: swapped_in = self._schedule_swapped(budget, curr_loras) prefills = self._schedule_prefills( @@ -1410,7 +1347,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: partial_prefill_metadata=partial_prefill_metadata, ) - assert budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens + assert budget.num_batched_tokens <= \ + self.scheduler_config.max_num_batched_tokens assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. @@ -1420,50 +1358,50 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # By default, vLLM scheduler prioritizes prefills. # Once chunked prefill is enabled, # the policy is changed to prioritize decode requests. - self.running.extend([s.seq_group for s in swapped_in.decode_seq_groups]) - self.running.extend([s.seq_group for s in swapped_in.prefill_seq_groups]) - self.running.extend([s.seq_group for s in running_scheduled.decode_seq_groups]) - self.running.extend([s.seq_group for s in running_scheduled.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.prefill_seq_groups]) self.running.extend([s.seq_group for s in prefills.seq_groups]) # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) # Put prefills first due to Attention backend ordering assumption. - scheduled_seq_groups = ( - prefills.seq_groups - + running_scheduled.prefill_seq_groups - + swapped_in.prefill_seq_groups - + running_scheduled.decode_seq_groups - + swapped_in.decode_seq_groups - ) - num_prefill_groups = ( - len(prefills.seq_groups) - + len(swapped_in.prefill_seq_groups) - + len(running_scheduled.prefill_seq_groups) - ) + scheduled_seq_groups = (prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups) + num_prefill_groups = (len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)) # If all prompts, then we set num_lookahead_slots to 0 # this allows us to go through the `no_spec` path in # `spec_decode_worker.py` all_prefills = len(scheduled_seq_groups) == num_prefill_groups - num_lookahead_slots = ( - 0 - if (all_prefills and not self.scheduler_config.is_multi_step) - else running_scheduled.num_lookahead_slots - ) + num_lookahead_slots = (0 if + (all_prefills + and not self.scheduler_config.is_multi_step) + else running_scheduled.num_lookahead_slots) return SchedulerOutputs( scheduled_seq_groups=scheduled_seq_groups, num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + swapped_in.blocks_to_copy, - ignored_seq_groups=prefills.ignored_seq_groups - + swapped_in.infeasible_seq_groups, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, num_lookahead_slots=num_lookahead_slots, running_queue_size=len(self.running), - preempted=( - len(running_scheduled.preempted) + len(running_scheduled.swapped_out) - ), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), ) def _schedule(self) -> SchedulerOutputs: @@ -1473,23 +1411,21 @@ def _schedule(self) -> SchedulerOutputs: else: return self._schedule_default() - def _can_append_slots( - self, seq_group: SequenceGroup, enable_chunking: bool - ) -> bool: + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ # It is True only for testing case to trigger artificial preemption. - if ( - self.enable_artificial_preemption - and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB - and self.artificial_preempt_cnt > 0 - ): + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): self.artificial_preempt_cnt -= 1 return False is_prefill = seq_group.is_prefill() - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill, enable_chunking) + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill, enable_chunking) if is_prefill and num_lookahead_slots > 0: # Appending prefill slots only happens multi-step and @@ -1497,18 +1433,18 @@ def _can_append_slots( assert self.scheduler_config.is_multi_step and enable_chunking return self.block_manager.can_append_slots( - seq_group=seq_group, num_lookahead_slots=num_lookahead_slots - ) + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: # async_output_proc is allowed only when we have a single sequence # in the sequence group no_single_seq = seq_group.sampling_params is None or ( - seq_group.sampling_params.n == 1 - ) + seq_group.sampling_params.n == 1) return no_single_seq - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. @@ -1524,14 +1460,14 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool] # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for i, scheduled_seq_group in enumerate(scheduler_outputs.scheduled_seq_groups): + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): seq_group = scheduled_seq_group.seq_group token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) seq_group_metadata = self._seq_group_metadata_cache[ - self.cache_id - ].get_object() + self.cache_id].get_object() seq_group_metadata.seq_data.clear() seq_group_metadata.block_tables.clear() @@ -1547,7 +1483,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool] encoder_seq_data = encoder_seq.data # Block table for cross-attention # Also managed at SequenceGroup level - cross_block_table = self.block_manager.get_cross_block_table(seq_group) + cross_block_table = self.block_manager.get_cross_block_table( + seq_group) else: encoder_seq_data = None cross_block_table = None @@ -1561,9 +1498,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool] if self.cache_config.enable_prefix_caching: common_computed_block_nums = ( self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING) - ) - ) + seq_group.get_seqs(status=SequenceStatus.RUNNING))) do_sample = True is_prompt = seq_group.is_prefill() @@ -1581,7 +1516,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool] # NOTE: We use get_len instead of get_prompt_len because when # a sequence is preempted, prefill includes previous generated # output tokens. - if token_chunk_size + num_computed_tokens < seqs[0].data.get_len(): + if token_chunk_size + num_computed_tokens < seqs[ + 0].data.get_len(): do_sample = False # It assumes the scheduled_seq_groups is ordered by @@ -1606,16 +1542,12 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool] # between engine and worker. # the subsequent comms can still use delta, but # `multi_modal_data` will be None. - multi_modal_data=( - seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups > 0 - else None - ), + multi_modal_data=(seq_group.multi_modal_data if + scheduler_outputs.num_prefill_groups > 0 + else None), multi_modal_placeholders=( seq_group.multi_modal_placeholders - if scheduler_outputs.num_prefill_groups > 0 - else None - ), + if scheduler_outputs.num_prefill_groups > 0 else None), mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, ) @@ -1637,7 +1569,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool] seq_group_metadata_list.append(seq_group_metadata) if allow_async_output_proc: - allow_async_output_proc = self._allow_async_output_proc(seq_group) + allow_async_output_proc = self._allow_async_output_proc( + seq_group) # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. @@ -1645,8 +1578,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool] # will crash the vLLM instance / will not retry. for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: self.block_manager.mark_blocks_as_computed( - scheduled_seq_group.seq_group, scheduled_seq_group.token_chunk_size - ) + scheduled_seq_group.seq_group, + scheduled_seq_group.token_chunk_size) self._seq_group_metadata_cache[self.next_cache_id].reset() @@ -1665,7 +1598,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool] self.cache_id = self.next_cache_id # Return results - return (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc) + return (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_manager.fork(parent_seq, child_seq) @@ -1739,8 +1673,7 @@ def _append_slots( """ is_prefill: bool = seq_group.is_prefill() num_lookahead_slots: int = self._get_num_lookahead_slots( - is_prefill, enable_chunking - ) + is_prefill, enable_chunking) seq_group.init_multi_step_from_lookahead_slots( num_lookahead_slots, @@ -1760,9 +1693,8 @@ def _append_slots( if len(cows) > 0: blocks_to_copy.extend(cows) - def _preempt( - self, seq_group: SequenceGroup, blocks_to_swap_out: List[Tuple[int, int]] - ) -> PreemptionMode: + def _preempt(self, seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]]) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than # swapping. However, when the sequence group has multiple sequences @@ -1844,8 +1776,7 @@ def _swap_out( # entire engine. raise RuntimeError( "Aborted due to the lack of CPU swap space. Please increase " - "the swap space to avoid this error." - ) + "the swap space to avoid this error.") mapping = self.block_manager.swap_out(seq_group) blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -1857,15 +1788,17 @@ def _passed_delay(self, now: float) -> bool: self.prev_time, self.prev_prompt = now, False # Delay scheduling prompts to let waiting queue fill up if self.scheduler_config.delay_factor > 0 and self.waiting: - earliest_arrival_time = min([e.metrics.arrival_time for e in self.waiting]) + earliest_arrival_time = min( + [e.metrics.arrival_time for e in self.waiting]) passed_delay = (now - earliest_arrival_time) > ( - self.scheduler_config.delay_factor * self.last_prompt_latency - ) or not self.running + self.scheduler_config.delay_factor * + self.last_prompt_latency) or not self.running else: passed_delay = True return passed_delay - def _get_num_lookahead_slots(self, is_prefill: bool, enable_chunking: bool) -> int: + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. @@ -1893,7 +1826,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool, enable_chunking: bool) -> i return self.scheduler_config.num_lookahead_slots - def _get_num_new_tokens( + def _get_num_new_uncached_and_cached_tokens( self, seq_group: SequenceGroup, status: SequenceStatus, @@ -1956,7 +1889,8 @@ def _get_num_new_tokens( # evictor meaning that it's not yet allocated. However, we don't # exclude such tokens in the cache count because it will be # guaranteed to be allocated later if the sequence can be allocated. - num_cached_tokens_seq = self.block_manager.get_num_cached_tokens(seq) + num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( + seq) # Sanity check. if num_cached_tokens_seq < num_computed_tokens_seq: @@ -1968,23 +1902,18 @@ def _get_num_new_tokens( # count could be less than the number of computed tokens. # See comments on `ComputedBlocksTracker` for more details. assert ( - seq.is_prefill() - and seq.status == SequenceStatus.RUNNING + seq.is_prefill() and seq.status == SequenceStatus.RUNNING and self.scheduler_config.chunked_prefill_enabled - ), ( - "Number of cached tokens should not be less than the " + ), ("Number of cached tokens should not be less than the " "number of computed tokens for a sequence that's still " f"in prefill. But there are {num_cached_tokens_seq} cached " f"tokens and {num_computed_tokens_seq} computed tokens " - f"for sequence {seq.seq_id}." - ) + f"for sequence {seq.seq_id}.") num_cached_new_tokens_seq = max( - 0, num_cached_tokens_seq - num_computed_tokens_seq - ) - num_uncached_new_tokens_seq = ( - all_num_new_tokens_seq - num_cached_new_tokens_seq - ) + 0, num_cached_tokens_seq - num_computed_tokens_seq) + num_uncached_new_tokens_seq = (all_num_new_tokens_seq - + num_cached_new_tokens_seq) num_uncached_new_tokens += num_uncached_new_tokens_seq num_cached_new_tokens += num_cached_new_tokens_seq @@ -2047,22 +1976,20 @@ def _chunk_new_tokens_to_schedule( # # Prompts with more tokens than the current remaining budget # are postponed to future scheduler steps - if num_new_tokens > self._get_prompt_limit(seq_group): + if num_new_tokens > prompt_limit: # If the seq_group is in prompt-stage, pass the # num_new_tokens as-is so the caller can ignore # the sequence. return num_new_tokens - return 0 if num_new_tokens > remaining_token_budget else num_new_tokens + return 0 if num_new_tokens > \ + remaining_token_budget else num_new_tokens # Get the number of tokens to allocate to this prefill slot - prefill_slot_budget = ( - remaining_token_budget - if partial_prefill_metadata is None - else partial_prefill_budget_lookup_list[ - partial_prefill_metadata.partial_prefills - ] - ) + prefill_slot_budget = (remaining_token_budget + if partial_prefill_metadata is None else + partial_prefill_budget_lookup_list[ + partial_prefill_metadata.partial_prefills]) if cache_config.enable_prefix_caching: # When prefix caching is enabled and we're partially prefilling @@ -2073,14 +2000,13 @@ def _chunk_new_tokens_to_schedule( # Take min of those and get the next lowest multiple of the # block size: remaining_token_budget = ( - min(remaining_token_budget, prefill_slot_budget) // block_size - ) * block_size + min(remaining_token_budget, prefill_slot_budget) // + block_size) * block_size # NB: In the case where num_new_tokens < budget, we are # finishing prefill for this sequence, so we do not need to # allocate a full block. - num_new_tokens = min( - num_new_tokens, remaining_token_budget, prefill_slot_budget - ) + num_new_tokens = min(num_new_tokens, remaining_token_budget, + prefill_slot_budget) return num_new_tokens From 03525f25c95e96eeb48f5a39d440d722030e75ba Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 18 Dec 2024 13:04:30 -0700 Subject: [PATCH 52/54] :bug: fix index out of range Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ef6598ea3addb..02220153f6cc6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -415,8 +415,9 @@ def from_queues( waiting_partial_prefills += 1 return PartialPrefillMetadata( - partial_prefills + waiting_partial_prefills, - long_partial_prefills, + partial_prefills=min(partial_prefills + waiting_partial_prefills, + scheduler_config.max_num_partial_prefills), + long_partial_prefills=long_partial_prefills, scheduler_config=scheduler_config, ) From d5f5eb61270922bfe0bd1437bfb0d83ea489f014 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 19 Dec 2024 14:39:12 -0700 Subject: [PATCH 53/54] :recycle: naming updates Signed-off-by: Joe Runde --- vllm/core/scheduler.py | 46 ++++++++++++++++++++-------------------- vllm/engine/arg_utils.py | 2 +- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 02220153f6cc6..33437bd76ed92 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -343,12 +343,12 @@ class PartialPrefillMetadata: phase faster. """ - # A minimum bound on the total number of prefills running during this - # scheduling step - partial_prefills: int + # A minimum bound on the total number of prefills to be scheduled during + # this iteration + schedulable_prefills: int # The number of long prefill requests currently running - long_partial_prefills: int + long_prefills: int scheduler_config: SchedulerConfig @@ -359,7 +359,7 @@ def cannot_schedule(self, seq_group: SequenceGroup) -> bool: concurrently""" return (seq_group.first_seq.get_num_new_tokens() > self.scheduler_config.long_prefill_token_threshold - and self.long_partial_prefills >= + and self.long_prefills >= self.scheduler_config.max_long_partial_prefills and self.scheduler_config.max_num_partial_prefills > 1) @@ -368,7 +368,7 @@ def increment_partial_prefills(self, seq_group: SequenceGroup) -> None: # long request if (seq_group.first_seq.get_num_new_tokens() > self.scheduler_config.long_prefill_token_threshold): - self.long_partial_prefills += 1 + self.long_prefills += 1 @classmethod def from_queues( @@ -382,42 +382,42 @@ def from_queues( This accounts for the currently running prefill requests, and peeks into the waiting queue to see if there are more prefills to potentially be scheduled during this iteration.""" - partial_prefills = 0 - long_partial_prefills = 0 + prefills = 0 + long_prefills = 0 - waiting_partial_prefills = 0 waiting_long_prefills = 0 for sg in running: - # TODO: Check if this stage is correctly updated before scheduling if sg.first_seq.data.stage == SequenceStage.PREFILL: - partial_prefills += 1 + prefills += 1 if (sg.first_seq.get_num_new_tokens() > scheduler_config.long_prefill_token_threshold): - long_partial_prefills += 1 + long_prefills += 1 for sg in waiting: # Don't bother looping through the rest of the queue if we know # there are already at # least max_partial_prefills requests to fill - if (partial_prefills + waiting_partial_prefills >= - scheduler_config.max_num_partial_prefills): + if prefills >= scheduler_config.max_num_partial_prefills: break # Don't count long requests from the waiting queue if we aren't # going to schedule them anyway if (sg.first_seq.get_num_new_tokens() > scheduler_config.long_prefill_token_threshold): - if (long_partial_prefills + waiting_long_prefills >= + if (long_prefills + waiting_long_prefills >= scheduler_config.max_long_partial_prefills): continue waiting_long_prefills += 1 - waiting_partial_prefills += 1 + prefills += 1 + # NB: long_prefills and waiting_long_prefills are tracked separately. + # We don't account for the waiting requests here because we need to use + # this metadata to track how many have actually been scheduled. return PartialPrefillMetadata( - partial_prefills=min(partial_prefills + waiting_partial_prefills, - scheduler_config.max_num_partial_prefills), - long_partial_prefills=long_partial_prefills, + schedulable_prefills=min( + prefills, scheduler_config.max_num_partial_prefills), + long_prefills=long_prefills, scheduler_config=scheduler_config, ) @@ -1995,10 +1995,10 @@ def _chunk_new_tokens_to_schedule( remaining_token_budget else num_new_tokens # Get the number of tokens to allocate to this prefill slot - prefill_slot_budget = (remaining_token_budget - if partial_prefill_metadata is None else - partial_prefill_budget_lookup_list[ - partial_prefill_metadata.partial_prefills]) + prefill_slot_budget = ( + remaining_token_budget if partial_prefill_metadata is None else + partial_prefill_budget_lookup_list[ + partial_prefill_metadata.schedulable_prefills]) if cache_config.enable_prefix_caching: # When prefix caching is enabled and we're partially prefilling diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 52593465712e2..f2674116cebc1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -531,7 +531,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.long_prefill_threshold, help="For chunked prefill, a request is considered long " "if the prompt is longer than the " - "max_model_length * long_prefill_threshold. Defaults to 0.04%", + "max_model_length * long_prefill_threshold. Defaults to 0.04", ) parser.add_argument('--max-num-seqs', type=int, From cb5361a949976ad95c5e130ac37b252a0f0b1af4 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 19 Dec 2024 14:44:19 -0700 Subject: [PATCH 54/54] :bug: fix long prefill threshold init Signed-off-by: Joe Runde --- vllm/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 841977c98632e..a6590924c549b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1281,7 +1281,7 @@ class SchedulerConfig: # calculate context length that determines which sequences are # considered "long" - long_prefill_token_threshold = int(max_model_len * long_prefill_threshold) + long_prefill_token_threshold: int = 0 # The number of slots to allocate per sequence per # step, beyond the known token ids. This is used in speculative @@ -1385,6 +1385,8 @@ def __post_init__(self) -> None: self.max_num_batched_tokens) self.chunked_prefill_enabled = self.enable_chunked_prefill + self.long_prefill_token_threshold = int(self.max_model_len * + self.long_prefill_threshold) self._verify_args() def _verify_args(self) -> None: