[Core] Support for Scheduling-defined Prefill-Decode Disaggregation feature #15

Merged: 13 commits, Oct 14, 2024
21 changes: 19 additions & 2 deletions llumnix/arg_utils.py
@@ -32,6 +32,7 @@ class EngineManagerArgs:
polling_interval: float = None

dispatch_policy: str = None
num_dispatch_instances: int = None

enable_migration: bool = None
enable_defrag: bool = None
@@ -60,23 +61,34 @@ class EngineManagerArgs:
last_stage_max_blocks: int = None
max_stages: int = None

enable_pd_disagg: bool = None

def __post_init__(self):
# Check if all fields default to None
for field_info in dataclasses.fields(self):
if field_info.default is not None:
raise ValueError(f"The default value of '{field_info.name}' should be None")

for attr in dataclasses.fields(self):
if getattr(self, attr.name) is None:
setattr(self, attr.name, getattr(_C.MANAGER, attr.name.upper()))

def create_global_scheduler_configs(
self,
) -> Tuple[GlobalSchedulerConfig]:

# Create the GlobalScheduler Configuration.
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
self.load_metric,
self.dispatch_policy,
self.num_dispatch_instances,
self.pair_migration_policy,
self.migrate_out_threshold,
self.enable_defrag,
self.scaling_policy,
self.scale_up_threshold,
self.scale_down_threshold)
self.scale_down_threshold,
self.enable_pd_disagg)
return global_scheduler_config

def create_migration_config(self) -> MigrationConfig:
@@ -134,6 +146,9 @@ def add_cli_args(
type=str,
choices=['balanced', 'load', 'queue', 'flood'],
help='request dispatch policy')
parser.add_argument('--num-available-dispatch-instances',
type=int,
help='number of available instances for dispatching')

parser.add_argument('--enable-migration',
action='store_true',
@@ -211,5 +226,7 @@ def add_cli_args(
parser.add_argument('--max-stages',
type=int,
help='drop migration if the number of stages > max_stages')

parser.add_argument('--enable-pd-disagg',
type=bool,
help='enable prefill decoding disaggregation')
return parser
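
As a quick sanity check, here is a minimal sketch (not from this PR; it assumes the `_C.MANAGER` fallback in `__post_init__` above resolves every unset field) of how the two new knobs flow into the global scheduler config:

from llumnix.arg_utils import EngineManagerArgs

args = EngineManagerArgs()               # every None field falls back to _C.MANAGER.*
args.enable_pd_disagg = True             # opt in to prefill-decode disaggregation
args.num_dispatch_instances = 2          # cap the pool of dispatch (prefill) instances
global_scheduler_config = args.create_global_scheduler_configs()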
7 changes: 6 additions & 1 deletion llumnix/backends/backend_interface.py
@@ -37,7 +37,7 @@ def is_sim_backend(status: "BackendType") -> bool:
class BackendInterface(ABC):
# Methods for inference
@abstractmethod
def add_request(self, request_id: str, server_info: ServerInfo,
def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int,
*args, **kwargs) -> None:
"""Adds a new inference request to the backend's processing queue.

@@ -47,6 +47,11 @@ def add_request(self, request_id: str, server_info: ServerInfo,
Args:
request_id: Request ID.
server_info: The information of the api server where the request comes from.
expected_steps: The expected number of steps for the request to run. The number of steps
represents the number of times 'engine.step()' has been called by the backend
instances for the request. Currently, `expected_steps` is used to implement
prefill-decoding disaggregation: for requests dispatched to prefill instances,
`expected_steps` is set to 1.
*args: Positional arguments that represent request-specific data.
**kwargs: Keyword arguments that contain metadata of the backend request
(request_id, arrival_time, etc.).
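
To make the `expected_steps` contract concrete, here is a small self-contained sketch (illustration only, not code from this PR) of the predicate the scheduler applies; math.inf models a request that runs to completion:

import math

def should_keep_running(output_len: int, expected_steps: float) -> bool:
    # A request stays schedulable only while its generated length is still
    # below its step budget (1 for prefill-only requests, math.inf otherwise).
    return output_len < expected_steps

assert should_keep_running(0, 1)             # prefill step not executed yet
assert not should_keep_running(1, 1)         # prefill done; hand off to a decode instance
assert should_keep_running(100, math.inf)    # decode instance: never filtered out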
10 changes: 6 additions & 4 deletions llumnix/backends/vllm/llm_engine.py
@@ -243,13 +243,14 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None:
instance_info.num_blocks_last_running_request = self.instance_info.num_blocks_last_running_request
self.instance_info = instance_info

def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs):
def add_request(self, request_id: str, server_info: ServerInfo, expected_steps: int, *args, **kwargs):
super().add_request(request_id, *args, **kwargs)
seq_group = self.scheduler.waiting[-1]
if hasattr(server_info, 'request_timestamps'):
server_info.request_timestamps.engine_add_request_timestamp = time.time()
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, [seq_group.get_seqs()[0]], seq_group.sampling_params,
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]],
seq_group.sampling_params, seq_group.metrics.arrival_time, seq_group.lora_request,
seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _start_put_queue_loop(self):
@@ -346,10 +347,11 @@ def execute_worker_method(self, method, *args, **kwargs):
def add_request(self,
request_id: str,
server_info: ServerInfo,
expected_steps: int,
*args,
**kwargs) -> None:
# Store the server information of each request so that its outputs can be routed back to the corresponding api server.
self.engine.add_request(request_id, server_info, *args, **kwargs)
self.engine.add_request(request_id, server_info, expected_steps, *args, **kwargs)

def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
14 changes: 14 additions & 0 deletions llumnix/backends/vllm/scheduler.py
@@ -15,6 +15,7 @@
import time
import threading
from typing import Dict, List, Optional, Tuple
from collections import deque

from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable
from vllm.core.scheduler import (Scheduler, PreemptionMode, SequenceStatus, SequenceGroupMetadata, SchedulerOutputs)
@@ -207,6 +208,19 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups]))
return seq_group_metadata_list, scheduler_outputs

def _schedule_running(self, running_queue: deque, *args, **kwargs):
filtered_running_queue = deque()
remove_running = deque()
for seq_group in running_queue:
if seq_group.output_len >= seq_group.expected_steps:
remove_running.extend([seq_group])
else:
filtered_running_queue.extend([seq_group])
remaining_running, running_scheduled = super()._schedule_running(filtered_running_queue, *args, **kwargs)
for seq_group in remove_running:
remaining_running.extend([seq_group])
return remaining_running, running_scheduled

def add_seq_group(self, *args, **kwargs):
# The scheduler lock is manually released at the end of the LLMEngineLlumnix.add_request function.
# pylint: disable=R1732
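
The filtering above can be exercised in isolation. A toy, runnable sketch with (output_len, expected_steps) tuples standing in for sequence groups:

from collections import deque

# (output_len, expected_steps) pairs standing in for sequence groups
running = deque([(0, 1), (5, float('inf')), (1, 1)])
keep, hold = deque(), deque()
for output_len, expected_steps in running:
    (hold if output_len >= expected_steps else keep).append((output_len, expected_steps))
print(list(keep))   # [(0, 1), (5, inf)]  -> forwarded to the parent _schedule_running
print(list(hold))   # [(1, 1)]            -> appended back onto remaining_running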
4 changes: 2 additions & 2 deletions llumnix/backends/vllm/sequence.py
@@ -17,9 +17,9 @@


class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest):
def __init__(self, request_id, server_info, *args, **kwargs) -> None:
def __init__(self, request_id, server_info, expected_steps: int, *args, **kwargs) -> None:
SequenceGroup.__init__(self, request_id, *args, **kwargs)
LlumnixRequest.__init__(self, request_id, server_info)
LlumnixRequest.__init__(self, request_id, server_info, expected_steps)

@property
def prompt_len(self) -> int:
10 changes: 10 additions & 0 deletions llumnix/config/default.py
@@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

from .config import LlumnixConfig as LC

# -----------------------------------------------------------------------------
@@ -77,6 +79,8 @@
_C.MANAGER.LOAD_METRIC = 'remaining_steps'
# Request dispatch policy
_C.MANAGER.DISPATCH_POLICY = 'load'
# Number of instances available for dispatching. A non-positive value or math.inf indicates that all instances can be used for dispatching
_C.MANAGER.NUM_DISPATCH_INSTANCES = math.inf

# -----------------------------------------------------------------------------
# MIGRATION CONFIGURATION
@@ -124,3 +128,9 @@
_C.MANAGER.SCALE_UP_THRESHOLD = 10
# Scale down threshold
_C.MANAGER.SCALE_DOWN_THRESHOLD = 60

# -----------------------------------------------------------------------------
# PREFILL DECODING DISAGGREGATION CONFIGURATION
# -----------------------------------------------------------------------------
# Enable prefill decoding disaggregation
_C.MANAGER.ENABLE_PD_DISAGG = False
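
For illustration, a hedged sketch of flipping the new defaults in code, assuming the `_C.MANAGER.<NAME>` attribute access shown above is the intended override surface:

from llumnix.config.default import _C

_C.MANAGER.ENABLE_PD_DISAGG = True        # opt in to prefill-decode disaggregation
_C.MANAGER.NUM_DISPATCH_INSTANCES = 2     # only two instances accept newly dispatched requests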
1 change: 0 additions & 1 deletion llumnix/entrypoints/llumnix_utils.py
@@ -137,7 +137,6 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id:

instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)]
migration_configs = engine_manager_args.create_migration_config()

for idx in range(engine_manager_args.initial_instances):
instance_id = instance_ids[idx]
if not engine_manager_args.profiling_result_file_path:
18 changes: 14 additions & 4 deletions llumnix/global_scheduler/dispatch_scheduler.py
@@ -24,11 +24,14 @@
class DispatchScheduler:
def __init__(self,
dispatch_policy: str,
instance_load_calculator: InstanceLoadCalculator) -> None:
instance_load_calculator: InstanceLoadCalculator,
num_dispatch_instances: int) -> None:
self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy)
self.instance_load_calculator = instance_load_calculator
self.num_instances = 0
self.instance_id_set: Set[str] = set()
self.available_dispatch_instance_set: Set[str] = set()
self.num_dispatch_instances = num_dispatch_instances
# instance info args
self.instance_info: Dict[str, InstanceInfo] = {}
self.sorted_instance_infos: List[InstanceInfo] = None
@@ -56,22 +59,29 @@ def update_instance_infos(self,
def add_instance(self, instance_id: str) -> None:
self.instance_id_set.add(instance_id)
self.num_instances = len(self.instance_id_set)
self.instance_num_requests[instance_id] = 0
if self.num_dispatch_instances <= 0 or (self.num_dispatch_instances > 0 and
len(self.available_dispatch_instance_set) < self.num_dispatch_instances):
self.available_dispatch_instance_set.add(instance_id)
self.instance_num_requests[instance_id] = 0

def remove_instance(self, instance_id: str) -> None:
self.instance_id_set.remove(instance_id)
self.num_instances = len(self.instance_id_set)
del self.instance_num_requests[instance_id]
if instance_id in self.instance_num_requests:
del self.instance_num_requests[instance_id]
if instance_id in self.available_dispatch_instance_set:
self.available_dispatch_instance_set.remove(instance_id)

def _sort_instance_infos(self,
descending: bool = True) -> None:
instance_infos: List[InstanceInfo] = list(self.instance_info.values())
available_instance_infos = [info for info in instance_infos if info.instance_id in self.available_dispatch_instance_set]
if isinstance(self.dispatch_policy, Queue):
key_attr = 'num_waiting_requests'
else:
key_attr = 'instance_load_dispatch_scale'
self.sorted_instance_infos = sorted(
instance_infos,
available_instance_infos,
key=lambda instance_info: getattr(instance_info, key_attr),
reverse=descending
)
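
A short illustration of the capping behavior (a sketch, not a test from this PR; `None` stands in for the `InstanceLoadCalculator`, which `add_instance` never touches):

from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler

scheduler = DispatchScheduler('load', None, num_dispatch_instances=2)
for instance_id in ('instance_0', 'instance_1', 'instance_2'):
    scheduler.add_instance(instance_id)
print(scheduler.available_dispatch_instance_set)   # holds only the first two instances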
18 changes: 12 additions & 6 deletions llumnix/global_scheduler/global_scheduler.py
@@ -12,12 +12,13 @@
# limitations under the License.

from typing import Dict, List, Tuple, Union, Iterable, Set
import math

from llumnix.logger import init_logger
from llumnix.internal_config import GlobalSchedulerConfig
from llumnix.instance_info import InstanceLoadCalculator, InstanceInfo
from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler
from llumnix.global_scheduler.migration_scheduler import MigrationScheduler, PairMigrationConstraints
from llumnix.global_scheduler.scaling_scheduler import ScalingScheduler

logger = init_logger(__name__)
@@ -30,12 +31,14 @@ def __init__(self,
# instance load and instance info args
self.load_metric = global_scheduler_config.load_metric
self.enable_defrag = global_scheduler_config.enable_defrag
self.enable_pd_disagg = global_scheduler_config.enable_pd_disagg
self.instance_load_calculator = InstanceLoadCalculator(load_metric=self.load_metric,
enable_defrag=self.enable_defrag)
# dispatch args
self.dispatch_policy = global_scheduler_config.dispatch_policy
self.dispatch_scheduler = DispatchScheduler(global_scheduler_config.dispatch_policy,
self.instance_load_calculator)
self.instance_load_calculator,
global_scheduler_config.num_dispatch_instances)
# migrate args
self.migration_scheduler = MigrationScheduler(global_scheduler_config.pair_migration_policy,
global_scheduler_config.migrate_out_load_threshold,
@@ -44,7 +47,8 @@ def __init__(self,
self.scaling_scheduler = ScalingScheduler(global_scheduler_config.scale_up_threshold,
global_scheduler_config.scale_down_threshold,
global_scheduler_config.scaling_policy,
self.instance_load_calculator)
self.instance_load_calculator,
global_scheduler_config.num_dispatch_instances)

self.num_instances = 0
self.instance_id_set: Set[str] = set()
@@ -56,16 +60,18 @@ def update_instance_infos(self, instance_infos: List[InstanceInfo]) -> None:
# Llumnix has different instance load computation methods for dispatch/migrate/scale.
instance_info.instance_load_dispatch_scale = self.instance_load_calculator.compute_instance_load(instance_info, action='dispatch')
instance_info.instance_load_migrate = self.instance_load_calculator.compute_instance_load(instance_info, action='migrate')
instance_info.instance_type = self.scaling_scheduler.get_instance_type_info(instance_info.instance_id)
self.instance_info[instance_info.instance_id] = instance_info

def dispatch(self) -> str:
self.dispatch_scheduler.update_instance_infos(self.instance_info)
instance_id = self.dispatch_scheduler.dispatch()
return instance_id
request_expected_steps = 1 if self.enable_pd_disagg else math.inf
return instance_id, request_expected_steps

def pair_migration(self) -> List[Tuple[str, str]]:
def pair_migration(self, pair_migration_type: PairMigrationConstraints) -> List[Tuple[str, str]]:
self.migration_scheduler.update_instance_infos(self.instance_info)
migrate_instance_pairs = self.migration_scheduler.pair_migration()
migrate_instance_pairs = self.migration_scheduler.pair_migration(pair_migration_type)
return migrate_instance_pairs

def check_scale(self) -> Tuple[str, str]:
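
The new dispatch contract reduces to one branch; a minimal runnable sketch mirroring the logic above:

import math

def expected_steps_for(enable_pd_disagg: bool) -> float:
    # Mirrors GlobalScheduler.dispatch(): a prefill-bound request gets a budget
    # of exactly one engine step; otherwise it runs to completion.
    return 1 if enable_pd_disagg else math.inf

assert expected_steps_for(True) == 1
assert expected_steps_for(False) == math.inf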