rebase

Xinyi-ECNU committed Aug 28, 2024
1 parent 56c0c3c commit 7470889

Showing 13 changed files with 316 additions and 110 deletions.
25 changes: 22 additions & 3 deletions llumnix/arg_utils.py
@@ -15,7 +15,7 @@
from dataclasses import dataclass
import argparse
from typing import Tuple

import json
from llumnix.config import GlobalSchedulerConfig, MigrationConfig


@@ -59,9 +59,23 @@ class EngineManagerArgs:
last_stage_max_blocks: int = 16
max_stages: int = 3

pdd_config: str = None

def create_engine_manager_configs(
self,
) -> Tuple[GlobalSchedulerConfig]:
# Create the configuration for prefill-decoding disaggregation.
prefill_instance_num = -1
# TODO[xinyi]: Bind the prefill instance to several fixed IP addresses,
# expanding into heterogeneous scenarios.
# prefill_instance_ip = None
if self.pdd_config:
pdd_config = json.load(open(self.pdd_config, 'r', encoding='utf-8'))
# TODO[xinyi]: the key fields read from pdd_config are currently hardcoded.
prefill_instance_num = pdd_config["prefill_instance_num"]
# prefill_instance_ip = pdd_config["prefill_instance_ip"]

# Create the GlobalScheduler configuration.
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
self.load_metric,
self.dispatch_policy,
@@ -70,7 +84,9 @@ def create_engine_manager_configs(
self.enable_defrag,
self.scaling_policy,
self.scale_up_threshold,
self.scale_down_threshold)
self.scale_down_threshold,
self.pdd_config is not None,
prefill_instance_num)
return global_scheduler_config

def create_migration_config(self) -> MigrationConfig:
@@ -221,5 +237,8 @@ def add_cli_args(
type=int,
default=EngineManagerArgs.max_stages,
help='drop migration if the number of stages > max_stages')

parser.add_argument("--pdd-config",
type=str,
default=EngineManagerArgs.pdd_config,
help="configuration for prefill decoding disaggregation")
return parser
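
Note: the only key this commit reads from the --pdd-config JSON is prefill_instance_num, as the TODO above indicates. A minimal usage sketch (the file name and its contents are illustrative, not taken from the repository):

import json

from llumnix.arg_utils import EngineManagerArgs

# Hypothetical config file; "prefill_instance_ip" is not parsed anywhere yet.
with open("pdd.json", "w", encoding="utf-8") as f:
    json.dump({"prefill_instance_num": 2}, f)

args = EngineManagerArgs(pdd_config="pdd.json")  # equivalent to --pdd-config pdd.json
global_scheduler_config = args.create_engine_manager_configs()
# The returned config now carries enable_pd_disaggregation=True and
# available_dispatch_instance_num=2 (see llumnix/config.py below).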
37 changes: 34 additions & 3 deletions llumnix/backends/backend_interface.py
@@ -36,7 +36,7 @@ class BackendInferenceType(str, Enum):
class BackendInterface(ABC):
# Methods for inference
@abstractmethod
def add_request(self, request_id: str, server_info: ServerInfo,
def add_request(self, request_id: str, server_info: ServerInfo, req_expected_step: int,
*args, **kwargs) -> None:
"""Adds a new inference request to the backend's processing queue.
@@ -46,6 +46,7 @@ def add_request(self, request_id: str, server_info: ServerInfo,
Args:
request_id: Request ID.
server_info: The information of the api server where the request comes from.
req_expected_step: The expected number of engine steps for the request; once the
completed steps reach it, the request becomes eligible for migration.
*args: Positional arguments that represent request-specific data.
**kwargs: Keyword arguments that contain metadata of the backend request
(request_id, arrival_time, etc.).
@@ -194,7 +195,7 @@ def should_abort_migration(self, backend_request: Any, last_stage_time: int) ->
raise NotImplementedError

@abstractmethod
def add_running_request(self, backend_request: Any) -> None:
def add_running_request(self, backend_request: Any, req_expected_step: Optional[int]) -> None:
"""
Adds a backend request to the running queue for processing.
@@ -206,6 +207,7 @@ def add_running_request(self, backend_request: Any) -> None:
backend_request: An object representing the backend request. The type of this
object is dependent on the backend implementation and the details
of the request.
req_expected_step: The expected number of engine steps for the request; once the
completed steps reach it, the request becomes eligible for migration.
"""
raise NotImplementedError

@@ -275,7 +277,7 @@ def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[i
raise NotImplementedError

@abstractmethod
def commit_dst_request(self, backend_request: Any, server_info: ServerInfo) -> None:
def commit_dst_request(self, backend_request: Any, server_info: ServerInfo, req_expected_step: int) -> None:
"""Commits the migrating request to the destination instance.
This method finalizes the migration process by transferring all necessary metadata and resource
@@ -287,6 +289,7 @@ def commit_dst_request(self, backend_request: Any, server_info: ServerInfo) -> N
object is dependent on the backend implementation and the details
of the request.
server_info: The information of the api server where the request comes from.
req_expected_step: The expected number of engine steps for the request; once the
completed steps reach it, the request becomes eligible for migration.
"""
raise NotImplementedError

@@ -330,6 +333,34 @@ def get_shortest_running_request(self) -> Optional[MigratingRequest]:
has generated output, or None if no such request exists.
"""
raise NotImplementedError

@abstractmethod
def get_pre_migration_request(self) -> Optional[MigratingRequest]:
"""Retrieves the request which meets the migration conditions from the running queue.
This method iterates over the running queue in reverse order and returns the last request
that has moved past the prefilling stage and met the migration conditions. In the current
version, a request is considered to meet the migration confitions if its number of steps
exceeds expected_steps and backend.pre_migration is True.
Returns:
An instance of MigratingRequest representing the last request in the running queue
that is not prefilling and meets the migration conditions, or None if there are no
such requests in the queue.
"""
raise NotImplementedError

@abstractmethod
def update_pre_migration(self, new_pre_migration: bool) -> None:
"""Update the status of whether to force migration in the backend engine.
This method updates the status of whether to force migration in the backend engine. This
action is performed only when the corresponding status in the llumlet is changed.
Args:
new_pre_migration: New migration status provided for backend engine.
"""
raise NotImplementedError

@abstractmethod
def get_request_server_info(self, request_id: str) -> ServerInfo:
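
Note: the call pattern for the two new hooks is not visible in this file. A rough sketch of a caller in the llumlet's migration loop, assuming the llumlet polls the backend between steps (the helper name is illustrative):

from typing import Optional

from llumnix.backends.backend_interface import BackendInterface
from llumnix.llumlet.migrating_request import MigratingRequest

def poll_pd_migration(backend: BackendInterface) -> Optional[MigratingRequest]:
    # A request that has completed its expected (prefill) steps should now be
    # migrated from this prefill instance to a decode instance.
    migrate_out_request = backend.get_pre_migration_request()
    if migrate_out_request is not None:
        ...  # hand off to the existing migration machinery
    return migrate_out_request

# If the llumlet's force-migration status flips, mirror it into the backend so the
# scheduler stops (or resumes) withholding such requests:
# backend.update_pre_migration(False)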
21 changes: 18 additions & 3 deletions llumnix/backends/vllm/llm_engine.py
@@ -32,7 +32,7 @@
from llumnix.llumlet.migrating_request import MigratingRequest
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface
from llumnix.backends.vllm.scheduler import SchedulerLlumnix
from llumnix.backends.vllm.scheduler import SchedulerLlumnix, RequestInfo
from llumnix.backends.vllm.utils import detect_unsupported_feature
from llumnix.backends.profiling import LatencyMemData
from llumnix.server_info import ServerInfo
@@ -110,6 +110,7 @@ def _process_model_outputs(
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
self.scheduler.request_info[seq_group.request_id].completed_step += 1
scheduled_seq_groups = new_scheduled_seq_groups
output[0].outputs = new_output
seq_group_metadata_list = new_seq_group_metadata_list
@@ -186,6 +187,8 @@ def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None:
del self.request_server_info[req_id]
if req_id in self.scheduler.last_preemption_time_dict:
del self.scheduler.last_preemption_time_dict[req_id]
if req_id in self.scheduler.request_info:
del self.scheduler.request_info[req_id]

class BackendVLLM(BackendInterface):
def __init__(
@@ -206,6 +209,7 @@ def __init__(
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.instance_id = instance_id
self.pre_migration = True
self.worker_handle_list = self.engine.model_executor.workers.copy()
if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size:
self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix"))
@@ -249,22 +253,25 @@ def restart_workers(self) -> None:
def add_request(self,
request_id: str,
server_info: ServerInfo,
req_expected_step: int,
*args,
**kwargs) -> None:
# When the manager is unavailable, the api server might re-dispatch a request that has already been dispatched.
if request_id in self.engine.request_server_info:
return
# Store each request's server information so that request outputs can be returned to the correct api server.
self.engine.request_server_info[request_id] = server_info
with self.engine.scheduler.scheduler_lock:
self.engine.scheduler.request_info[request_id] = RequestInfo(expected_step=req_expected_step, completed_step=0)
self.engine.add_request(request_id, *args, **kwargs)

def commit_dst_request(self, backend_request: SequenceGroup, server_info: ServerInfo) -> None:
def commit_dst_request(self, backend_request: SequenceGroup, server_info: ServerInfo, req_expected_step: int) -> None:
seq = backend_request.get_seqs()[0]
seq.seq_id = next(self.engine.seq_counter)
logger.info("add seq {} to block table".format(seq.seq_id))
pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id)
self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id)
self.add_running_request(backend_request)
self.add_running_request(backend_request, req_expected_step)
self.engine.request_server_info[backend_request.request_id] = server_info

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
@@ -291,6 +298,11 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None:
return self.engine.free_request_states(request_id)

def update_pre_migration(self, new_migration_state: bool):
if self.pre_migration != new_migration_state:
self.pre_migration = new_migration_state
self.engine.scheduler.update_pre_migration(new_migration_state)

def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]:
return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs)

@@ -332,6 +344,9 @@ def get_longest_running_request(self) -> Optional[MigratingRequest]:

def get_shortest_running_request(self) -> Optional[MigratingRequest]:
return self.engine.scheduler.get_shortest_running_request()

def get_pre_migration_request(self) -> Optional[MigratingRequest]:
return self.engine.scheduler.get_pre_migration_request()

def get_request_server_info(self, request_id: str) -> ServerInfo:
return self.engine.request_server_info[request_id]
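
Note: the bookkeeping added in this file is easiest to see in isolation: expected_step arrives through add_request(), and completed_step is incremented once per engine step in _process_model_outputs(). A minimal sketch with illustrative values:

from llumnix.backends.vllm.scheduler import RequestInfo

# e.g. a prefill instance that should hand the request off after one step
info = RequestInfo(completed_step=0, expected_step=1)
info.completed_step += 1  # one engine step finished

# the same eligibility test used by get_pre_migration_request()
if info.expected_step > 0 and info.completed_step >= info.expected_step:
    print("request is ready to migrate to a decode instance")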
50 changes: 49 additions & 1 deletion llumnix/backends/vllm/scheduler.py
@@ -14,6 +14,7 @@
from asyncio.log import logger
import time
import threading
import copy
from typing import Dict, List, Optional, Tuple

from vllm.sequence import SequenceGroup
@@ -29,6 +30,13 @@

logger = init_logger(__name__)

<<<<<<< HEAD
=======
class RequestInfo():
def __init__(self, completed_step: int, expected_step: int) -> None:
self.expected_step = expected_step
self.completed_step = completed_step
>>>>>>> fe66a36 (init)

# TODO(ZeldaHuang): adapt prefix cache and sliding window, now use v1 manager
class BlockManagerLlumnix(BlockSpaceManagerV1):
@@ -62,6 +70,8 @@ def __init__(self, *args, **kwargs) -> None:
self.prefilling_seq_groups = []
self.scheduler_lock = threading.Lock()
self.migrating_out_request_last_stage = []
self.pre_migration = True
self.request_info: Dict[str, RequestInfo] = {}

def add_update_instance_info_callback(self, update_instance_info_callback):
self.update_instance_info_callback = update_instance_info_callback
@@ -82,6 +92,19 @@ def _get_num_killed_requests(self) -> int:
if seq_group.request_id in self.last_preemption_time_dict:
cnt += 1
return cnt

# TODO(xinyi): Currently, this function is only used for prefill-decoding disaggregation,
# and only selects requests that migrate from the prefill instance to the decode instance.
@scheduler_lock
def get_pre_migration_request(self) -> Optional[MigratingRequest]:
pre_migration_request = None
if self.pre_migration:
for seq_group in reversed(self.running):
# logger.info("get_pre_migration_request {} {}".format(self.request_info[seq_group.request_id].expected_step,self.request_info[seq_group.request_id].completed_step))
if self.request_info[seq_group.request_id].expected_step > 0 and \
self.request_info[seq_group.request_id].completed_step >= self.request_info[seq_group.request_id].expected_step:
return MigratingRequest(seq_group.request_id, seq_group, expected_step=-1, blocking_migration=False)
return pre_migration_request

@scheduler_lock
def get_last_running_request(self) -> Optional[MigratingRequest]:
@@ -161,10 +184,15 @@ def should_abort_migration(self, backend_request: SequenceGroup, last_stage_time
return False

@scheduler_lock
def add_running_request(self, backend_request: SequenceGroup) -> None:
def add_running_request(self, backend_request: SequenceGroup, req_expected_step: Optional[int] = None) -> None:
seq = backend_request.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
self.running.append(backend_request)
if req_expected_step is not None:
if backend_request.request_id not in self.request_info:
self.request_info[backend_request.request_id] = RequestInfo(expected_step=req_expected_step, completed_step=0)
else:
self.request_info[backend_request.request_id].expected_step = req_expected_step

@scheduler_lock
def is_request_running(self, backend_request: SequenceGroup) -> bool:
@@ -246,6 +274,26 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
self.update_instance_info_callback(self._get_instance_info())
return seq_group_metadata_list, scheduler_outputs

def _schedule_running(self, *args, **kwargs):
# args[0] is the running queue. Schedule a copy from which requests that have already
# reached their expected steps are withheld, then re-append them to the remaining queue
# so they wait for migration instead of continuing to decode on this instance.
args_list = list(args)
args_list[0] = copy.deepcopy(self.running)
remove_running = []
if self.pre_migration:
for seq_group in list(args_list[0]):
if self.request_info[seq_group.request_id].expected_step > 0 and \
self.request_info[seq_group.request_id].completed_step >= self.request_info[seq_group.request_id].expected_step:
args_list[0].remove(seq_group)
remove_running.append(seq_group)
new_args = tuple(args_list)
remaining_running, running_scheduled = super()._schedule_running(*new_args, **kwargs)
for seq_group in remove_running:
remaining_running.append(seq_group)
return remaining_running, running_scheduled

@scheduler_lock
def update_pre_migration(self, new_migration_state: bool) -> None:
self.pre_migration = new_migration_state

@scheduler_lock
def add_seq_group(self, *args, **kwargs):
return super().add_seq_group(*args, **kwargs)
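
Note: the @scheduler_lock decorator applied throughout this file is defined outside the hunks shown here. A typical shape, inferred from the self.scheduler_lock attribute initialized in __init__ (an assumption, not the verbatim definition):

import functools

def scheduler_lock(method):
    # Serialize scheduler entry points on the instance's lock, since llumlet
    # threads and the engine loop touch the scheduler queues concurrently.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        with self.scheduler_lock:
            return method(self, *args, **kwargs)
    return wrapper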
7 changes: 6 additions & 1 deletion llumnix/config.py
@@ -40,7 +40,9 @@ def __init__(
enable_defrag: bool,
scaling_policy: str,
scale_up_threshold: float,
scale_down_threshold: float) -> None:
scale_down_threshold: float,
enable_pd_disaggregation: bool,
available_dispatch_instance_num: int) -> None:
self.initial_instances = initial_instances
self.load_metric = load_metric

@@ -53,3 +55,6 @@
self.scaling_policy = scaling_policy
self.scale_up_threshold = scale_up_threshold*(-1)
self.scale_down_threshold = scale_down_threshold*(-1)

self.enable_pd_disaggregation = enable_pd_disaggregation
self.available_dispatch_instance_num = available_dispatch_instance_num
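
Note: read together with arg_utils.py above, available_dispatch_instance_num is simply prefill_instance_num, so new requests are dispatched only to prefill instances while decode instances receive work via migration alone. Continuing the earlier sketch with {"prefill_instance_num": 2}:

global_scheduler_config = args.create_engine_manager_configs()
assert global_scheduler_config.enable_pd_disaggregation is True
assert global_scheduler_config.available_dispatch_instance_num == 2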
11 changes: 9 additions & 2 deletions llumnix/global_scheduler/dispatch_scheduler.py
@@ -24,11 +24,14 @@
class DispatchScheduler:
def __init__(self,
dispatch_policy: str,
instance_load_calculator: InstanceLoadCalculator) -> None:
instance_load_calculator: InstanceLoadCalculator,
available_dispatch_instance_num: int) -> None:
self.dispatch_policy = DispatchPolicyFactory.get_policy(dispatch_policy)
self.instance_load_calculator = instance_load_calculator
self.num_instances = 0
self.instance_id_set: Set[str] = set()
self.available_dispatch_instance_set: Set[str] = set()
self.available_dispatch_instance_num = available_dispatch_instance_num
# instance info args
self.instance_info: Dict[str, InstanceInfo] = {}
self.sorted_instance_infos: List[InstanceInfo] = None
@@ -57,6 +60,9 @@ def add_instance(self, instance_id: str) -> None:
self.instance_id_set.add(instance_id)
self.num_instances = len(self.instance_id_set)
self.instance_num_requests[instance_id] = 0
if self.available_dispatch_instance_num > 0 and len(self.available_dispatch_instance_set) < self.available_dispatch_instance_num:
self.available_dispatch_instance_set.add(instance_id)


def remove_instance(self, instance_id: str) -> None:
self.instance_id_set.remove(instance_id)
@@ -66,12 +72,13 @@ def remove_instance(self, instance_id: str) -> None:
def _sort_instance_infos(self,
descending: bool = True) -> None:
instance_infos: List[InstanceInfo] = list(self.instance_info.values())
available_instance_infos = [info for info in instance_infos if info.instance_id in self.available_dispatch_instance_set]
if isinstance(self.dispatch_policy, Queue):
key_attr = 'num_waiting_requests'
else:
key_attr = 'instance_load_dispatch_scale'
self.sorted_instance_infos = sorted(
instance_infos,
available_instance_infos,
key=lambda instance_info: getattr(instance_info, key_attr),
reverse=descending
)
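
Note: with the cap in place, only the first available_dispatch_instance_num instances to register become dispatch targets; later instances still join instance_id_set but are filtered out of _sort_instance_infos(). A small sketch (the instance ids and the "load" policy name are illustrative, and the load calculator is stubbed out):

from llumnix.global_scheduler.dispatch_scheduler import DispatchScheduler

scheduler = DispatchScheduler(dispatch_policy="load",
                              instance_load_calculator=None,  # assumed built elsewhere
                              available_dispatch_instance_num=2)
for instance_id in ("instance_0", "instance_1", "instance_2"):
    scheduler.add_instance(instance_id)

# The first two registrations became dispatch targets (e.g. prefill instances);
# instance_2 can still receive migrated requests but no fresh dispatches.
assert scheduler.available_dispatch_instance_set == {"instance_0", "instance_1"}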