Skip to content

Commit

Permalink
[Core] Support for Scheduling-defined Prefill-Decode Disaggregation f…
Browse files Browse the repository at this point in the history
…eature (#15)
  • Loading branch information
Xinyi-ECNU authored Oct 14, 2024
1 parent 7097e70 commit ce45945
Show file tree
Hide file tree
Showing 28 changed files with 607 additions and 163 deletions.
21 changes: 19 additions & 2 deletions llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class EngineManagerArgs:
polling_interval: float = None

dispatch_policy: str = None
num_dispatch_instances: int = None

enable_migration: bool = None
enable_defrag: bool = None
Expand Down Expand Up @@ -60,23 +61,34 @@ class EngineManagerArgs:
last_stage_max_blocks: int = None
max_stages: int = None

enable_pd_disagg: bool = None

def __post_init__(self):
    # Every field must default to None so that "unset by the user" is
    # detectable and can be filled in from the global default config below.
    for field_info in dataclasses.fields(self):
        if field_info.default is not None:
            raise ValueError(f"The default value of '{field_info.name}' should be None")

    # Fill each unset attribute from _C.MANAGER, whose config keys are the
    # upper-cased field names.
    for field in dataclasses.fields(self):
        attr_name = field.name
        if getattr(self, attr_name) is None:
            setattr(self, attr_name, getattr(_C.MANAGER, attr_name.upper()))

def create_global_scheduler_configs(
self,
) -> Tuple[GlobalSchedulerConfig]:

# Create the GlobalScheduler Configuration.
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
self.load_metric,
self.dispatch_policy,
self.num_dispatch_instances,
self.pair_migration_policy,
self.migrate_out_threshold,
self.enable_defrag,
self.scaling_policy,
self.scale_up_threshold,
self.scale_down_threshold)
self.scale_down_threshold,
self.enable_pd_disagg)
return global_scheduler_config

def create_migration_config(self) -> MigrationConfig:
Expand Down Expand Up @@ -134,6 +146,9 @@ def add_cli_args(
type=str,
choices=['balanced', 'load', 'queue', 'flood'],
help='request dispatch policy')
parser.add_argument('--num-available-dispatch-instances',
type=int,
help='number of available instances for dispatching')

parser.add_argument('--enable-migration',
action='store_true',
Expand Down Expand Up @@ -211,5 +226,7 @@ def add_cli_args(
parser.add_argument('--max-stages',
type=int,
help='drop migration if the number of stages > max_stages')

parser.add_argument('--enable-pd-disagg',
type=bool,
help='enable prefill decoding disaggregation')
return parser
7 changes: 6 additions & 1 deletion llumnix/backends/backend_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def is_sim_backend(status: "BackendType") -> bool:
class BackendInterface(ABC):
# Methods for inference
@abstractmethod
def add_request(self, request_id: str, server_info: ServerInfo,
def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int,
*args, **kwargs) -> None:
"""Adds a new inference request to the backend's processing queue.
Expand All @@ -47,6 +47,11 @@ def add_request(self, request_id: str, server_info: ServerInfo,
Args:
request_id: Request ID.
server_info: The information of the api server where the request comes from.
expected_steps: The expected number of steps for the request to run. The number of steps
represents the times 'engine.step()' has been called by the backend
instances for the request. Currently, `expected_steps` is used
to implement prefill-decoding disaggregation. For requests dispatched to
prefill instances `expected_steps` is set to 1.
*args: Positional arguments that represent request-specific data.
**kwargs: Keyword arguments that contain metadata of the backend request
(request_id, arrival_time, etc.).
Expand Down
10 changes: 6 additions & 4 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,14 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None:
instance_info.num_blocks_last_running_request = self.instance_info.num_blocks_last_running_request
self.instance_info = instance_info

def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs):
def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs):
super().add_request(request_id, *args, **kwargs)
seq_group = self.scheduler.waiting[-1]
if hasattr(server_info, 'request_timestamps'):
server_info.request_timestamps.engine_add_request_timestamp = time.time()
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, [seq_group.get_seqs()[0]], seq_group.sampling_params,
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]],
seq_group.sampling_params, seq_group.metrics.arrival_time, seq_group.lora_request,
seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _start_put_queue_loop(self):
Expand Down Expand Up @@ -346,10 +347,11 @@ def execute_worker_method(self, method, *args, **kwargs):
def add_request(self,
request_id: str,
server_info: ServerInfo,
expected_steps: int,
*args,
**kwargs) -> None:
# Store the server information of each request to put the request outputs back to the corresponding api server correctly.
self.engine.add_request(request_id, server_info, *args, **kwargs)
self.engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)

def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
Expand Down
14 changes: 14 additions & 0 deletions llumnix/backends/vllm/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import time
import threading
from typing import Dict, List, Optional, Tuple
from collections import deque

from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable
from vllm.core.scheduler import (Scheduler, PreemptionMode, SequenceStatus, SequenceGroupMetadata, SchedulerOutputs)
Expand Down Expand Up @@ -207,6 +208,19 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups]))
return seq_group_metadata_list, scheduler_outputs

def _schedule_running(self, running_queue: deque, *args, **kwargs):
    """Exclude finished-by-policy requests before delegating to the base scheduler.

    Sequence groups whose produced output length has reached their
    ``expected_steps`` (e.g. requests dispatched to prefill instances, where
    ``expected_steps`` is 1) are filtered out so the base scheduler does not
    schedule further decode steps for them on this instance. They are appended
    back onto the remaining-running queue afterwards — presumably so that
    downstream bookkeeping (e.g. migration) still sees them; TODO confirm.

    Args:
        running_queue: Queue of currently running sequence groups.
        *args, **kwargs: Forwarded verbatim to the base ``_schedule_running``.

    Returns:
        Tuple of (remaining running queue, scheduled running outputs) as
        returned by the base scheduler, with the filtered-out groups
        re-appended to the remaining running queue.
    """
    filtered_running_queue = deque()
    remove_running = deque()
    for seq_group in running_queue:
        # Reached its expected step budget: keep it out of this schedule pass.
        if seq_group.output_len >= seq_group.expected_steps:
            remove_running.append(seq_group)
        else:
            filtered_running_queue.append(seq_group)
    remaining_running, running_scheduled = super()._schedule_running(filtered_running_queue, *args, **kwargs)
    remaining_running.extend(remove_running)
    return remaining_running, running_scheduled

def add_seq_group(self, *args, **kwargs):
# The scheduler lock is manually released at the end of the LLMEngineLlumnix.add_request function.
# pylint: disable=R1732
Expand Down
4 changes: 2 additions & 2 deletions llumnix/backends/vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@


class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest):
def __init__(self, request_id, server_info, *args, **kwargs) -> None:
def __init__(self, request_id, server_info, expected_steps: int, *args, **kwargs) -> None:
SequenceGroup.__init__(self, request_id, *args, **kwargs)
LlumnixRequest.__init__(self, request_id, server_info)
LlumnixRequest.__init__(self, request_id, server_info, expected_steps)

@property
def prompt_len(self) -> int:
Expand Down
10 changes: 10 additions & 0 deletions llumnix/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from .config import LlumnixConfig as LC

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -77,6 +79,8 @@
_C.MANAGER.LOAD_METRIC = 'remaining_steps'
# Request dispatch policy
_C.MANAGER.DISPATCH_POLICY = 'load'
# Number of available dispatch instances. math.inf (the default) — or any value <= 0 — indicates that all instances can be used for dispatching
_C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf

# -----------------------------------------------------------------------------
# MIGRATION CONFIGURATION
Expand Down Expand Up @@ -124,3 +128,9 @@
_C.MANAGER.SCALE_UP_THRESHOLD = 10
# Scale down threshold
_C.MANAGER.SCALE_DOWN_THRESHOLD = 60

# -----------------------------------------------------------------------------
# PREFILL DECODING DISAGGREGATION CONFIGURATION
# -----------------------------------------------------------------------------
# Enable prefill decoding disaggregation
_C.MANAGER.ENABLE_PD_DISAGG = False
1 change: 0 additions & 1 deletion llumnix/entrypoints/llumnix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id:

instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)]
migration_configs = engine_manager_args.create_migration_config()

for idx in range(engine_manager_args.initial_instances):
instance_id = instance_ids[idx]
if not engine_manager_args.profiling_result_file_path:
Expand Down
18 changes: 14 additions & 4 deletions llumnix/global_scheduler/dispatch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@
class DispatchScheduler:
def __init__(self,
dispatch_policy: str,
instance_load_calculator: InstanceLoadCalculator) -> None:
instance_load_calculator: InstanceLoadCalculator,
num_dispatch_instances: int) -> None:
self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy)
self.instance_load_calculator = instance_load_calculator
self.num_instances = 0
self.instance_id_set: Set[str] = set()
self.available_dispatch_instance_set: Set[str] = set()
self.num_dispatch_instances = num_dispatch_instances
# instance info args
self.instance_info: Dict[str, InstanceInfo] = {}
self.sorted_instance_infos: List[InstanceInfo] = None
Expand Down Expand Up @@ -56,22 +59,29 @@ def update_instance_infos(self,
def add_instance(self, instance_id: str) -> None:
self.instance_id_set.add(instance_id)
self.num_instances = len(self.instance_id_set)
self.instance_num_requests[instance_id] = 0
if self.num_dispatch_instances <= 0 or (self.num_dispatch_instances > 0 and
len(self.available_dispatch_instance_set) < self.num_dispatch_instances):
self.available_dispatch_instance_set.add(instance_id)
self.instance_num_requests[instance_id] = 0

def remove_instance(self, instance_id: str) -> None:
self.instance_id_set.remove(instance_id)
self.num_instances = len(self.instance_id_set)
del self.instance_num_requests[instance_id]
if instance_id in self.instance_num_requests:
del self.instance_num_requests[instance_id]
if instance_id in self.available_dispatch_instance_set:
self.available_dispatch_instance_set.remove(instance_id)

def _sort_instance_infos(self,
descending: bool = True) -> None:
instance_infos: List[InstanceInfo] = list(self.instance_info.values())
available_instance_infos = [info for info in instance_infos if info.instance_id in self.available_dispatch_instance_set]
if isinstance(self.dispatch_policy, Queue):
key_attr = 'num_waiting_requests'
else:
key_attr = 'instance_load_dispatch_scale'
self.sorted_instance_infos = sorted(
instance_infos,
available_instance_infos,
key=lambda instance_info: getattr(instance_info, key_attr),
reverse=descending
)
Expand Down
18 changes: 12 additions & 6 deletions llumnix/global_scheduler/global_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# limitations under the License.

from typing import Dict, List, Tuple, Union, Iterable, Set
import math

from llumnix.logger import init_logger
from llumnix.internal_config import GlobalSchedulerConfig
from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo
from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints
from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler

logger = init_logger(__name__)
Expand All @@ -30,12 +31,14 @@ def __init__(self,
# instance load and instance info args
self.load_metric = global_scheduler_config.load_metric
self.enable_defrag = global_scheduler_config.enable_defrag
self.enable_pd_disagg = global_scheduler_config.enable_pd_disagg
self.instance_load_calculator = InstanceLoadCalculator(load_metric=self.load_metric,
enable_defrag=self.enable_defrag)
# dispatch args
self.dispatch_policy = global_scheduler_config.dispatch_policy
self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy,
self.instance_load_calculator)
self.instance_load_calculator,
global_scheduler_config.num_dispatch_instances)
# migrate args
self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy,
global_scheduler_config.migrate_out_load_threshold,
Expand All @@ -44,7 +47,8 @@ def __init__(self,
self.scaling_scheduler = ScalingScheduler(global_scheduler_config.scale_up_threshold,
global_scheduler_config.scale_down_threshold,
global_scheduler_config.scaling_policy,
self.instance_load_calculator)
self.instance_load_calculator,
global_scheduler_config.num_dispatch_instances)

self.num_instances = 0
self.instance_id_set: Set[str] = set()
Expand All @@ -56,16 +60,18 @@ def update_instance_infos(self, instance_infos: List[InstanceInfo]) -> None:
# Llumnix has different instance load computation methods for dispatch/migrate/scale.
instance_info.instance_load_dispatch_scale = self.instance_load_calculator.compute_instance_load(instance_info, action='dispatch')
instance_info.instance_load_migrate = self.instance_load_calculator.compute_instance_load(instance_info, action='migrate')
instance_info.instance_type = self.scaling_scheduler.get_instance_type_info(instance_info.instance_id)
self.instance_info[instance_info.instance_id] = instance_info

def dispatch(self) -> str:
self.dispatch_scheduler.update_instance_infos(self.instance_info)
instance_id = self.dispatch_scheduler.dispatch()
return instance_id
request_expected_steps = 1 if self.enable_pd_disagg else math.inf
return instance_id, request_expected_steps

def pair_migration(self) -> List[Tuple[str, str]]:
def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]:
self.migration_scheduler.update_instance_infos(self.instance_info)
migrate_instance_pairs = self.migration_scheduler.pair_migration()
migrate_instance_pairs = self.migration_scheduler.pair_migration(pair_migration_type)
return migrate_instance_pairs

def check_scale(self) -> Tuple[str, str]:
Expand Down
Loading

0 comments on commit ce45945

Please sign in to comment.