From c8a7e93273ff4338d6f89f8a63ff16426ac240b8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 31 Jul 2024 23:51:09 -0700 Subject: [PATCH] [core][scheduler] simplify and improve scheduler (#6867) --- tests/core/block/e2e/test_correctness.py | 2 +- tests/core/test_scheduler.py | 163 ++++++++++------------- vllm/core/policy.py | 45 ------- vllm/core/scheduler.py | 116 ++++++---------- 4 files changed, 112 insertions(+), 214 deletions(-) delete mode 100644 vllm/core/policy.py diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 8502eab0f8da0..e0dee43f500a0 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, # Allow only 2 sequences of ~128 tokens in worst case. # Note 16 = 128/block_size - "num_gpu_blocks_override": 2 * (16 + 1), + "num_gpu_blocks_override": 2 * (16 + 2), } ]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 4ca2260b5e017..447e8f8a586f6 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -1,13 +1,12 @@ import time from collections import deque -from typing import Deque, List, Set, Tuple +from typing import List, Set, Tuple from unittest.mock import MagicMock import pytest # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus -from vllm.core.policy import PolicyFactory from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, SequenceGroup, SequenceStatus @@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len(): """ scheduler = initialize_scheduler(max_model_len=30) _, seq_group = create_dummy_prompt("0", prompt_length=60) - waiting = deque([seq_group]) + scheduler.add_seq_group(seq_group) budget = create_token_budget() - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 1 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 @@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget(): Test token budget respected. """ scheduler = initialize_scheduler() - waiting: Deque[SequenceGroup] = deque() budget = create_token_budget(token_budget=0) for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) # 0 token budget == nothing is scheduled. - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 @@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget(): # 60 token budget == 1 request scheduled. budget = create_token_budget(token_budget=60) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 1 assert budget.num_batched_tokens == 60 @@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget(): # Test when current_batched_tokens respected. scheduler = initialize_scheduler() - waiting = deque() budget = create_token_budget(token_budget=60) add_token_budget(budget, 30, 0) _, seq_group = create_dummy_prompt(str(i), prompt_length=60) # Cannot schedule a prompt that doesn't fit the budget. - waiting.append(seq_group) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + scheduler.add_seq_group(seq_group) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 30 @@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 1 budget = create_token_budget(token_budget=90) add_token_budget(budget, 30, 0) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.seq_groups) == 1 assert budget.num_batched_tokens == 90 assert budget.num_curr_seqs == 1 @@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs(): Test max seq respected. """ scheduler = initialize_scheduler() - waiting: Deque[SequenceGroup] = deque() budget = create_token_budget(max_num_seqs=2) for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + scheduler.add_seq_group(seq_group) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 2 assert budget.num_batched_tokens == 120 @@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs(): assert len(remaining_waiting) == 1 # Verify curr_num_seqs respected. - waiting = deque() + scheduler.waiting = deque() budget = create_token_budget(max_num_seqs=2) add_token_budget(budget, 0, 2) _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + scheduler.add_seq_group(seq_group) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 @@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora(): """ lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) scheduler = initialize_scheduler(lora_config=lora_config) - waiting: Deque[SequenceGroup] = deque() budget = create_token_budget(token_budget=120) curr_loras: Set[int] = set() for i in range(2): @@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora(): lora_name=str(i), lora_int_id=i + 1, lora_path="abc")) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) # Add two more requests to verify lora is prioritized. # 0: Lora, 1: Lora, 2: regular, 3: regular # In the first iteration, index 0, 2 is scheduled. @@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora(): # prioritized. Verify that. for i in range(2, 4): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) # Schedule 2 requests (0 and 2) - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, curr_loras) + output = scheduler._schedule_prefills(budget, curr_loras) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 2 assert budget.num_batched_tokens == 120 @@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora(): # Reset curr_loras so that it can be scheduled. curr_loras = set() budget = create_token_budget(token_budget=60) - remaining_waiting, output = scheduler._schedule_prefills( - remaining_waiting, budget, curr_loras) + output = scheduler._schedule_prefills(budget, curr_loras) + remaining_waiting = scheduler.waiting assert len(output.seq_groups) == 1 assert output.seq_groups[0].seq_group.request_id == "1" assert len(remaining_waiting) == 1 @@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity(): Test sequence cannot be scheduled due to block manager has no capacity. """ scheduler = initialize_scheduler() - waiting: Deque[SequenceGroup] = deque() budget = create_token_budget() for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER - remainig_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 0 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 assert budget.num_curr_seqs == 0 - assert len(remainig_waiting) == 3 + assert len(remaining_waiting) == 3 scheduler = initialize_scheduler() - waiting = deque() budget = create_token_budget() for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - waiting.append(seq_group) + scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER - remaining_waiting, output = scheduler._schedule_prefills( - waiting, budget, None) + output = scheduler._schedule_prefills(budget, None) + remaining_waiting = scheduler.waiting assert len(output.ignored_seq_groups) == 3 assert len(output.seq_groups) == 0 assert budget.num_batched_tokens == 0 @@ -536,14 +529,12 @@ def test_decode_schedule_preempted(): Test decodes cannot be scheduled and preempted. """ scheduler = initialize_scheduler() - running: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) - running.append(seq_group) + scheduler._add_seq_group_to_running(seq_group) scheduler.block_manager.can_append_slots = MagicMock() def cannot_append_second_group(seq_group, num_lookahead_slots): @@ -555,8 +546,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # 1 cannot be scheduled, and the lowest priority (request 2) # should be preempted. 1 will also be preempted. budget = create_token_budget() - remainig_running, output = scheduler._schedule_running( - running, budget, curr_loras, policy) + output = scheduler._schedule_running(budget, curr_loras) + remainig_running = scheduler.running assert len(remainig_running) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 @@ -577,14 +568,12 @@ def test_decode_swap_beam_search(): Test best_of > 1 swap out blocks """ scheduler = initialize_scheduler() - running: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None budget = create_token_budget() for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) - running.append(seq_group) + scheduler._add_seq_group_to_running(seq_group) append_new_token_seq_group(60, seq_group, 1) budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) @@ -603,8 +592,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): expected_swap_mapping = [("5", "7")] scheduler.block_manager.swap_out.return_value = expected_swap_mapping - remainig_running, output = scheduler._schedule_running( - running, budget, curr_loras, policy) + output = scheduler._schedule_running(budget, curr_loras) + remainig_running = scheduler.running assert len(remainig_running) == 0 assert len(output.decode_seq_groups) == 2 assert len(output.prefill_seq_groups) == 0 @@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update(): """ scheduler = initialize_scheduler() _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - running: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) - running.append(seq_group) + scheduler._add_seq_group_to_running(seq_group) # The last request should be swapped out. scheduler.block_manager.append_slots = MagicMock() scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() - remaining_running, output = scheduler._schedule_running( - running, budget, curr_loras, policy) + output = scheduler._schedule_running(budget, curr_loras) + remaining_running = scheduler.running assert len(remaining_running) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 @@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update(): def test_schedule_swapped_simple(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 0 assert budget.num_batched_tokens == 1 assert budget.num_curr_seqs == 2 @@ -683,8 +668,6 @@ def test_schedule_swapped_simple(): def test_schedule_swapped_max_token_budget(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for _ in range(2): @@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) budget = create_token_budget(token_budget=1) - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 1 assert budget.num_batched_tokens == 1 assert budget.num_curr_seqs == 2 @@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget(): # Verify num_batched_tokens are respected. budget = create_token_budget(token_budget=1) add_token_budget(budget, 1, 0) - remaining_swapped, output = scheduler._schedule_swapped( - remaining_swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 1 assert budget.num_batched_tokens == 1 assert budget.num_curr_seqs == 0 @@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget(): def test_schedule_swapped_max_seqs(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(4): @@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) budget = create_token_budget(max_num_seqs=2) - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 2 assert budget.num_batched_tokens == 2 assert budget.num_curr_seqs == 2 @@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs(): assert len(output.prefill_seq_groups) == 0 # Verify num_curr_seqs are respected. - remaining_swapped, output = scheduler._schedule_swapped( - remaining_swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 2 assert budget.num_batched_tokens == 2 assert budget.num_curr_seqs == 2 @@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs(): def test_schedule_swapped_max_loras(): lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) scheduler = initialize_scheduler(lora_config=lora_config) - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras: Set[int] = set() blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(2): @@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 1 assert budget.num_batched_tokens == 1 assert budget.num_curr_seqs == 1 @@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras(): def test_schedule_swapped_cannot_swap_in(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for _ in range(2): @@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) # The last request should be swapped out. scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER # Since we cannot swap in, none of the requests are swapped in. budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 2 assert budget.num_batched_tokens == 0 assert budget.num_curr_seqs == 0 @@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in(): def test_infeasible_swap(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for _ in range(2): @@ -815,15 +790,15 @@ def test_infeasible_swap(): scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) # The last request should be swapped out. scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER # Since we cannot swap in, none of the requests are swapped in. budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 0 assert len(output.infeasible_seq_groups) == 2 assert budget.num_batched_tokens == 0 @@ -834,23 +809,21 @@ def test_infeasible_swap(): def test_schedule_swapped_blocks_to_copy(): scheduler = initialize_scheduler() - swapped: Deque[SequenceGroup] = deque() - policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) blocks_to_swap_out: List[Tuple[int, int]] = [] scheduler._swap_out(seq_group, blocks_to_swap_out) - swapped.append(seq_group) + scheduler._add_seq_group_to_swapped(seq_group) # The last request should be swapped out. scheduler.block_manager.append_slots = MagicMock() scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() - remaining_swapped, output = scheduler._schedule_swapped( - swapped, budget, curr_loras, policy) + output = scheduler._schedule_swapped(budget, curr_loras) + remaining_swapped = scheduler.swapped assert len(remaining_swapped) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 diff --git a/vllm/core/policy.py b/vllm/core/policy.py deleted file mode 100644 index a4463ac0f340e..0000000000000 --- a/vllm/core/policy.py +++ /dev/null @@ -1,45 +0,0 @@ -from collections import deque -from typing import Deque - -from vllm.sequence import SequenceGroup - - -class Policy: - - def get_priority( - self, - now: float, - seq_group: SequenceGroup, - ) -> float: - raise NotImplementedError - - def sort_by_priority( - self, - now: float, - seq_groups: Deque[SequenceGroup], - ) -> Deque[SequenceGroup]: - return deque( - sorted( - seq_groups, - key=lambda seq_group: self.get_priority(now, seq_group), - reverse=True, - )) - - -class FCFS(Policy): - - def get_priority( - self, - now: float, - seq_group: SequenceGroup, - ) -> float: - return now - seq_group.metrics.arrival_time - - -class PolicyFactory: - - _POLICY_REGISTRY = {'fcfs': FCFS} - - @classmethod - def get_policy(cls, policy_name: str, **kwargs) -> Policy: - return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5cdf1d15c31e1..11d020be0c940 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -8,7 +8,6 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest @@ -345,6 +344,16 @@ def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) + def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the running queue. + # Only for testing purposes. + self.running.append(seq_group) + + def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the swapped queue. + # Only for testing purposes. + self.swapped.append(seq_group) + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a sequence group with the given ID. @@ -398,32 +407,26 @@ def get_and_reset_finished_requests_ids(self) -> List[str]: def _schedule_running( self, - running_queue: deque, budget: SchedulingBudget, curr_loras: Optional[Set[int]], - policy: Policy, enable_chunking: bool = False, - ) -> Tuple[deque, SchedulerRunningOutputs]: + ) -> SchedulerRunningOutputs: """Schedule sequence groups that are running. Running queue should include decode and chunked prefill requests. Args: - running_queue: The queue that contains running requests (i.e., - decodes). The given arguments are NOT in-place modified. budget: The scheduling budget. The argument is in-place updated when any decodes are preempted. curr_loras: Currently batched lora request ids. The argument is in-place updated when any decodes are preempted. - policy: The sorting policy to sort running_queue. enable_chunking: If True, seq group can be chunked and only a chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. Returns: - A tuple of remaining running queue (should be always 0) after - scheduling and SchedulerRunningOutputs. + SchedulerRunningOutputs. """ # Blocks that need to be swapped or copied before model execution. blocks_to_swap_out: List[Tuple[int, int]] = [] @@ -436,10 +439,9 @@ def _schedule_running( # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. - # In this case, the policy is responsible for deciding which sequence - # groups to preempt. - now = time.time() - running_queue = policy.sort_by_priority(now, running_queue) + + running_queue = self.running + while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( @@ -503,7 +505,7 @@ def _schedule_running( if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - return running_queue, SchedulerRunningOutputs( + return SchedulerRunningOutputs( decode_seq_groups=decode_seq_groups, prefill_seq_groups=prefill_seq_groups, preempted=preempted, @@ -515,12 +517,10 @@ def _schedule_running( def _schedule_swapped( self, - swapped_queue: deque, budget: SchedulingBudget, curr_loras: Optional[Set[int]], - policy: Policy, enable_chunking: bool = False, - ) -> Tuple[deque, SchedulerSwappedInOutputs]: + ) -> SchedulerSwappedInOutputs: """Schedule sequence groups that are swapped out. It schedules swapped requests as long as it fits `budget` and @@ -528,20 +528,16 @@ def _schedule_swapped( `budget` and `curr_loras` are updated based on scheduled seq_groups. Args: - swapped_queue: The queue that contains swapped out requests. - The given arguments are NOT in-place modified. budget: The scheduling budget. The argument is in-place updated when any requests are swapped in. curr_loras: Currently batched lora request ids. The argument is in-place updated when any requests are swapped in. - policy: The sorting policy to sort swapped_queue. enable_chunking: If True, seq group can be chunked and only a chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. Returns: - A tuple of remaining swapped_queue after scheduling and SchedulerSwappedInOutputs. """ # Blocks that need to be swapped or copied before model execution. @@ -549,10 +545,10 @@ def _schedule_swapped( blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] - now = time.time() - swapped_queue = policy.sort_by_priority(now, swapped_queue) infeasible_seq_groups: List[SequenceGroup] = [] + swapped_queue = self.swapped + leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] @@ -617,7 +613,7 @@ def _schedule_swapped( swapped_queue.extendleft(leftover_swapped) - return swapped_queue, SchedulerSwappedInOutputs( + return SchedulerSwappedInOutputs( decode_seq_groups=decode_seq_groups, prefill_seq_groups=prefill_seq_groups, blocks_to_swap_in=blocks_to_swap_in, @@ -644,11 +640,10 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: def _schedule_prefills( self, - waiting_queue: deque, budget: SchedulingBudget, curr_loras: Optional[Set[int]], enable_chunking: bool = False, - ) -> Tuple[deque, SchedulerPrefillOutputs]: + ) -> SchedulerPrefillOutputs: """Schedule sequence groups that are in prefill stage. Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE @@ -660,8 +655,6 @@ def _schedule_prefills( `budget` and `curr_loras` are updated based on scheduled seq_groups. Args: - waiting_queue: The queue that contains prefill requests. - The given arguments are NOT in-place modified. budget: The scheduling budget. The argument is in-place updated when any requests are scheduled. curr_loras: Currently batched lora request ids. The argument is @@ -672,14 +665,12 @@ def _schedule_prefills( all tokens. Returns: - A tuple of remaining waiting_queue after scheduling and SchedulerSwappedInOutputs. """ ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[SequenceGroup] = [] - # We don't sort waiting queue because we assume it is sorted. - # Copy the queue so that the input queue is not modified. - waiting_queue = deque([s for s in waiting_queue]) + + waiting_queue = self.waiting leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: @@ -758,7 +749,7 @@ def _schedule_prefills( if len(seq_groups) > 0: self.prev_prompt = True - return waiting_queue, SchedulerPrefillOutputs( + return SchedulerPrefillOutputs( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) @@ -785,53 +776,43 @@ def _schedule_default(self) -> SchedulerOutputs: seq_group.lora_int_id for seq_group in self.running if seq_group.lora_int_id > 0) if self.lora_enabled else None - remaining_waiting, prefills = (self.waiting, - SchedulerPrefillOutputs.create_empty()) - remaining_running, running_scheduled = ( - self.running, SchedulerRunningOutputs.create_empty()) - remaining_swapped, swapped_in = ( - self.swapped, SchedulerSwappedInOutputs.create_empty()) + prefills = SchedulerPrefillOutputs.create_empty() + running_scheduled = SchedulerRunningOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() # If any requests are swapped, prioritized swapped requests. if not self.swapped: - remaining_waiting, prefills = self._schedule_prefills( - self.waiting, budget, curr_loras, enable_chunking=False) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) - fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. if len(prefills.seq_groups) == 0: - remaining_running, running_scheduled = self._schedule_running( - self.running, - budget, - curr_loras, - fcfs_policy, - enable_chunking=False) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. if len(running_scheduled.preempted) + len( running_scheduled.swapped_out) == 0: - remaining_swapped, swapped_in = self._schedule_swapped( - self.swapped, budget, curr_loras, fcfs_policy) + swapped_in = self._schedule_swapped(budget, curr_loras) assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. - self.waiting = remaining_waiting self.waiting.extendleft(running_scheduled.preempted) # Update new running requests. - self.running = remaining_running self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend( [s.seq_group for s in running_scheduled.decode_seq_groups]) self.running.extend( [s.seq_group for s in swapped_in.decode_seq_groups]) # Update swapped requests. - self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) preempted = (len(running_scheduled.preempted) + len(running_scheduled.swapped_out)) @@ -877,42 +858,32 @@ def _schedule_chunked_prefill(self): ) curr_loras: Set[int] = set() - remaining_waiting, prefills = (self.waiting, - SchedulerPrefillOutputs.create_empty()) - remaining_running, running_scheduled = ( - self.running, SchedulerRunningOutputs.create_empty()) - remaining_swapped, swapped_in = ( - self.swapped, SchedulerSwappedInOutputs.create_empty()) + prefills = SchedulerPrefillOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() # Decoding should be always scheduled first by fcfs. - fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") - remaining_running, running_scheduled = self._schedule_running( - self.running, - budget, - curr_loras, - fcfs_policy, - enable_chunking=True) + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=True) # Schedule swapped out requests. # If preemption happens, it means we don't have space for swap-in. if len(running_scheduled.preempted) + len( running_scheduled.swapped_out) == 0: - remaining_swapped, swapped_in = self._schedule_swapped( - self.swapped, budget, curr_loras, fcfs_policy) + swapped_in = self._schedule_swapped(budget, curr_loras) # Schedule new prefills. - remaining_waiting, prefills = self._schedule_prefills( - self.waiting, budget, curr_loras, enable_chunking=True) + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=True) assert (budget.num_batched_tokens <= self.scheduler_config.max_num_batched_tokens) assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs # Update waiting requests. - self.waiting = remaining_waiting self.waiting.extendleft(running_scheduled.preempted) # Update new running requests. - self.running = remaining_running self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend( [s.seq_group for s in running_scheduled.decode_seq_groups]) @@ -923,7 +894,6 @@ def _schedule_chunked_prefill(self): self.running.extend( [s.seq_group for s in swapped_in.prefill_seq_groups]) # Update swapped requests. - self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups +