rebase
Xinyi-ECNU committed Aug 28, 2024
1 parent 7470889 commit c28584b
Showing 14 changed files with 383 additions and 166 deletions.
2 changes: 1 addition & 1 deletion benchmark/benchmark_serving.py
@@ -380,7 +380,7 @@ def __init__(self):
self._all_decode_latencies = []

def measure(self, f):
-async def measured(*args, **gs):
+async def measured(*args, **kwargs):
start = time.time()
prompt, output = await f(*args, **kwargs)
# Do not record latency if request failed.
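For orientation: `measure()` wraps an async request coroutine so the benchmark can time each call, and the comment above shows that failed requests are skipped when recording latency. A minimal, self-contained sketch of this pattern (the class name, `_all_latencies`, and the success check are assumptions; only the quoted lines come from the file):

```python
import time

class BenchmarkMetrics:
    """Sketch of the measure() decorator pattern above; illustrative only."""

    def __init__(self):
        self._all_latencies = []          # assumed aggregate store
        self._all_decode_latencies = []   # field visible in the diff

    def measure(self, f):
        async def measured(*args, **kwargs):
            start = time.time()
            prompt, output = await f(*args, **kwargs)
            # Do not record latency if request failed.
            if output is not None:        # assumed failure signal
                self._all_latencies.append(time.time() - start)
            return prompt, output
        return measured
```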
31 changes: 14 additions & 17 deletions llumnix/arg_utils.py
@@ -15,9 +15,12 @@
from dataclasses import dataclass
import argparse
from typing import Tuple
-import json

+from llumnix.common.config import get_cfg
from llumnix.config import GlobalSchedulerConfig, MigrationConfig
+from llumnix.logger import init_logger

+logger = init_logger(__name__)

@dataclass
class EngineManagerArgs:
@@ -59,21 +62,15 @@ class EngineManagerArgs:
last_stage_max_blocks: int = 16
max_stages: int = 3

-pdd_config: str = None
+config_file: str = None

def create_engine_manager_configs(
self,
) -> Tuple[GlobalSchedulerConfig]:
-# Create the Configuration for prefill decoding disaggregation.
-prefill_instance_num = -1
-# TODO[xinyi]: Bind the prefill instance to several fixed IP addresses,
-# expanding into heterogeneous scenarios.
-# prefill_instance_ip = None
-if self.pdd_config:
-    pdd_config = json.load(open(self.pdd_config, 'r', encoding='utf-8'))
-    # TODO[xinyi]: hardcode the key fields in the pdd_config.
-    prefill_instance_num = pdd_config["prefill_instance_num"]
-    # prefill_instance_ip = pdd_config["prefill_instance_ip"]

+if self.config_file:
+    config_data = get_cfg()
+    config_data.merge_from_file(self.config_file)

# Create the GlobalScheduler Configuration.
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
Expand All @@ -85,8 +82,8 @@ def create_engine_manager_configs(
self.scaling_policy,
self.scale_up_threshold,
self.scale_down_threshold,
-self.pdd_config != None,
-prefill_instance_num)
+config_data.PDD_CONFIG.ENABLE_PREFILL_DISAGGREATION,
+config_data.PDD_CONFIG.PREFILL_INSTANCE_NUM)
return global_scheduler_config

def create_migration_config(self) -> MigrationConfig:
Expand Down Expand Up @@ -237,8 +234,8 @@ def add_cli_args(
type=int,
default=EngineManagerArgs.max_stages,
help='drop migration if the number of stages > max_stages')
parser.add_argument("--pdd-config",
parser.add_argument("--config-file",
type=str,
default=EngineManagerArgs.pdd_config,
help="configuration for prefill decoding disaggregation")
default=EngineManagerArgs.config_file,
help="path to the configuration file")
return parser
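The `--config-file` flow above replaces ad-hoc JSON loading: `get_cfg()` supplies a defaults object, `merge_from_file()` overlays the user's file, and typed keys such as `PDD_CONFIG.PREFILL_INSTANCE_NUM` are then read directly. `get_cfg()` itself is outside this excerpt; a plausible sketch, assuming a yacs-style `CfgNode` (the defaults below are guesses, and the key names copy the diff, including the `DISAGGREATION` spelling):

```python
# Hypothetical sketch of llumnix/common/config.py, assuming a yacs-style config.
from yacs.config import CfgNode

_C = CfgNode()
_C.PDD_CONFIG = CfgNode()
_C.PDD_CONFIG.ENABLE_PREFILL_DISAGGREATION = False  # assumed default
_C.PDD_CONFIG.PREFILL_INSTANCE_NUM = -1             # assumed default

def get_cfg() -> CfgNode:
    # Hand out a clone so merge_from_file() cannot mutate the shared defaults.
    return _C.clone()
```

A matching YAML file would then look like, e.g., `PDD_CONFIG: {ENABLE_PREFILL_DISAGGREATION: true, PREFILL_INSTANCE_NUM: 2}` (hypothetical values).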
36 changes: 20 additions & 16 deletions llumnix/backends/backend_interface.py
@@ -36,7 +36,7 @@ class BackendInferenceType(str, Enum):
class BackendInterface(ABC):
# Methods for inference
@abstractmethod
-def add_request(self, request_id: str, server_info: ServerInfo, req_expected_step: int,
+def add_request(self, request_id: str, server_info: ServerInfo, instance_expected_steps: int,
*args, **kwargs) -> None:
"""Adds a new inference request to the backend's processing queue.
@@ -46,7 +46,9 @@ def add_request(self, request_id: str, server_info: ServerInfo, req_expected_step: int,
Args:
request_id: Request ID.
server_info: The information of the api server where the request come.
-req_expected_step: The expected number of steps for the request.
+instance_expected_steps: The expected number of steps for the request to run on the instance.
+    The number of steps represents the times 'engine.step()' has been
+    called by the backend instance for the request.
*args: Positional arguments that represent request-specific data.
**kwargs: Keyword arguments that contain metadata of the backend request
(request_id, arrival_time, etc.).
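To make the renamed parameter concrete: each request carries a step budget for the instance it is on, and the scheduler counts completed `engine.step()` calls against that budget. A self-contained sketch of the implied bookkeeping (the `RequestInfo` shape is inferred from the `llm_engine.py` changes below; the `-1` sentinel is an assumption):

```python
from dataclasses import dataclass

@dataclass
class RequestInfo:
    instance_expected_steps: int   # step budget on this instance; assume -1 means unbounded
    instance_completed_steps: int  # engine.step() calls completed for this request

def reached_expected_steps(info: RequestInfo) -> bool:
    # A request becomes a migration candidate once its completed steps
    # reach its (non-negative) expected-step budget.
    return 0 <= info.instance_expected_steps <= info.instance_completed_steps
```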
@@ -195,7 +197,7 @@ def should_abort_migration(self, backend_request: Any, last_stage_time: int) ->
raise NotImplementedError

@abstractmethod
-def add_running_request(self, backend_request: Any, req_expected_step: Optional[int]) -> None:
+def add_running_request(self, backend_request: Any) -> None:
"""
Adds a backend request to the running queue for processing.
@@ -207,7 +209,6 @@ def add_running_request(self, backend_request: Any, req_expected_step: Optional[int]) -> None:
backend_request: An object representing the backend request. The type of this
object is dependent on the backend implementation and the details
of the request.
-req_expected_step: The expected number of steps for the request.
"""
raise NotImplementedError

@@ -277,7 +278,7 @@ def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
raise NotImplementedError

@abstractmethod
-def commit_dst_request(self, backend_request: Any, server_info: ServerInfo, req_expected_step: int) -> None:
+def commit_dst_request(self, backend_request: Any, server_info: ServerInfo, instance_expected_steps: int) -> None:
"""Commits the migrating request to the destination instance.
This method finalizes the migration process by transferring all necessary metadata and resource
@@ -289,7 +290,8 @@ def commit_dst_request(self, backend_request: Any, server_info: ServerInfo, req_expected_step: int) -> None:
object is dependent on the backend implementation and the details
of the request.
server_info: The information of the api server where the request come.
-req_expected_step: The expected number of steps for the request.
+instance_expected_steps: The expected number of steps for the request to run on the instance. In the current version, it is implemented for prefill-decoding
+    disaggregation.
"""
raise NotImplementedError

@@ -335,13 +337,12 @@ def get_shortest_running_request(self) -> Optional[MigratingRequest]:
raise NotImplementedError

@abstractmethod
-def get_pre_migration_request(self) -> Optional[MigratingRequest]:
-    """Retrieves the request which meets the migration conditions from the running queue.
+def get_ready_migration_request(self) -> Optional[MigratingRequest]:
+    """Retrieves the request which is ready for migration from the running queue.

-    This method iterates over the running queue in reverse order and returns the last request
-    that has moved past the prefilling stage and met the migration conditions. In the current
-    version, a request is considered to meet the migration confitions if its number of steps
-    exceeds expected_steps and backend.pre_migration is True.
+    This method iterates over the running queue in reverse order and returns the last request that has
+    met the migration conditions. A request is considered to meet the migration conditions if its number
+    of steps exceeds instance_expected_steps and backend.strict_pre_migration is True.
Returns:
An instance of MigratingRequest representing the last request in the running queue
Expand All @@ -351,14 +352,17 @@ def get_pre_migration_request(self) -> Optional[MigratingRequest]:
raise NotImplementedError
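The rule the new docstring states can be sketched as follows, reusing the `RequestInfo` fields from earlier (this scheduler class is illustrative only; the real code returns a `MigratingRequest` wrapper rather than the raw sequence group):

```python
from typing import List, Optional

class SchedulerSketch:
    """Illustrative only; not the real llumnix scheduler."""

    def __init__(self):
        self.running: List = []            # running seq_groups, oldest first
        self.request_infos = {}            # request_id -> RequestInfo (sketched earlier)
        self.strict_pre_migration = True

    def get_ready_migration_request(self) -> Optional[object]:
        # Walk the running queue in reverse and return the last request whose
        # completed steps have reached its expected-step budget.
        if not self.strict_pre_migration:
            return None
        for seq_group in reversed(self.running):
            info = self.request_infos[seq_group.request_id]
            if 0 <= info.instance_expected_steps <= info.instance_completed_steps:
                return seq_group  # real code wraps this in a MigratingRequest
        return None
```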

@abstractmethod
-def update_pre_migration(self, new_pre_migration: bool) -> None:
+def update_strict_pre_migration(self, new_strict_pre_migration: bool) -> None:
"""Update the status of whether to force migration in the backend engine.
-This method updates the status of whether to force migration in the backend engine. This
-action is performed only when the corresponding status in the llumlet is changed.
+This method updates the status of whether to force migration in the backend engine. This action is performed only when the
+corresponding status in the llumlet is changed.
+
+`pre_migration` represents whether the backend instance enables migration. By default, `pre_migration` is set to True, indicating that
+the instance enables migration when `request.instance_completed_steps` >= `request.instance_expected_steps`. If `pre_migration` is set
+to False, migration will not occur, and requests on the instance that reach the `instance_expected_steps` will continue with inference.
Args:
-new_pre_migration: New migration status provided for backend engine.
+new_strict_pre_migration: New migration status provided for backend engine.
"""
raise NotImplementedError
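In practice the flag is a cheap toggle: the backend propagates it to the scheduler only when the value actually changes, as `update_strict_pre_migration` in `llm_engine.py` below shows. A runnable stub illustrating the call pattern (the stub and call site are hypothetical):

```python
class BackendStub:
    """Stands in for BackendVLLM; real propagation is shown in llm_engine.py below."""

    def __init__(self):
        self.strict_pre_migration = True

    def update_strict_pre_migration(self, new_migration_state: bool) -> None:
        # Only forward a real change; repeated identical updates are no-ops.
        if self.strict_pre_migration != new_migration_state:
            self.strict_pre_migration = new_migration_state
            # the real backend also forwards the new state to its scheduler here

backend = BackendStub()
backend.update_strict_pre_migration(False)  # requests keep decoding past their expected steps
backend.update_strict_pre_migration(True)   # requests at/over their budget become migratable again
```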

30 changes: 15 additions & 15 deletions llumnix/backends/vllm/llm_engine.py
@@ -110,7 +110,7 @@ def _process_model_outputs(
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
-self.scheduler.request_info[seq_group.request_id].completed_step += 1
+self.scheduler.request_infos[seq_group.request_id].instance_completed_steps += 1
scheduled_seq_groups = new_scheduled_seq_groups
output[0].outputs = new_output
seq_group_metadata_list = new_seq_group_metadata_list
@@ -187,8 +187,8 @@ def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None:
del self.request_server_info[req_id]
if req_id in self.scheduler.last_preemption_time_dict:
del self.scheduler.last_preemption_time_dict[req_id]
-if req_id in self.scheduler.request_info:
-    del self.scheduler.request_info[req_id]
+if req_id in self.scheduler.request_infos:
+    del self.scheduler.request_infos[req_id]

class BackendVLLM(BackendInterface):
def __init__(
Expand All @@ -209,7 +209,7 @@ def __init__(
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.instance_id = instance_id
-self.pre_migration = True
+self.strict_pre_migration = True
self.worker_handle_list = self.engine.model_executor.workers.copy()
if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size:
self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix"))
@@ -253,25 +253,25 @@ def restart_workers(self) -> None:
def add_request(self,
request_id: str,
server_info: ServerInfo,
-req_expected_step: int,
+instance_expected_steps: int,
*args,
**kwargs) -> None:
# When manager is unavailable, api server might dispatch the request that has already been dispatched.
if request_id in self.engine.request_server_info:
return
# Store the server information of each request to put the request outputs back to the corresponding api server correctly.
self.engine.request_server_info[request_id] = server_info
-with self.engine.scheduler.scheduler_lock:
-    self.engine.scheduler.request_info[request_id] = RequestInfo(expected_step=req_expected_step, completed_step=0)
+self.engine.scheduler.update_request_infos(request_id, instance_expected_steps=instance_expected_steps, instance_completed_steps=0)
self.engine.add_request(request_id, *args, **kwargs)

-def commit_dst_request(self, backend_request: SequenceGroup, server_info: ServerInfo, req_expected_step: int) -> None:
+def commit_dst_request(self, backend_request: SequenceGroup, server_info: ServerInfo, instance_expected_steps: int) -> None:
seq = backend_request.get_seqs()[0]
seq.seq_id = next(self.engine.seq_counter)
logger.info("add seq {} to block table".format(seq.seq_id))
pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id)
self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id)
-self.add_running_request(backend_request, req_expected_step)
+self.add_running_request(backend_request)
+self.engine.scheduler.update_request_infos(backend_request.request_id, instance_expected_steps=instance_expected_steps)
self.engine.request_server_info[backend_request.request_id] = server_info

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
@@ -298,10 +298,10 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None:
return self.engine.free_request_states(request_id)

-def update_pre_migration(self, new_migration_state: bool):
-    if self.pre_migration != new_migration_state:
-        self.pre_migration = new_migration_state
-        self.engine.scheduler.update_pre_migration(new_migration_state)
+def update_strict_pre_migration(self, new_migration_state: bool):
+    if self.strict_pre_migration != new_migration_state:
+        self.strict_pre_migration = new_migration_state
+        self.engine.scheduler.update_strict_pre_migration(new_migration_state)

def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]:
return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs)
@@ -345,8 +345,8 @@ def get_longest_running_request(self) -> Optional[MigratingRequest]:
def get_shortest_running_request(self) -> Optional[MigratingRequest]:
return self.engine.scheduler.get_shortest_running_request()

-def get_pre_migration_request(self) -> Optional[MigratingRequest]:
-    return self.engine.scheduler.get_pre_migration_request()
+def get_ready_migration_request(self) -> Optional[MigratingRequest]:
+    return self.engine.scheduler.get_ready_migration_request()

def get_request_server_info(self, request_id: str) -> ServerInfo:
return self.engine.request_server_info[request_id]
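`update_request_infos` is introduced by this commit but its body is outside the excerpt; the old call site took `scheduler_lock` and wrote `request_info` directly. A plausible body consistent with the two call sites above (the lock usage, create-or-update behavior, and `-1` default are assumptions):

```python
import threading
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class RequestInfo:  # as sketched earlier; the real class lives in the scheduler module
    instance_expected_steps: int
    instance_completed_steps: int

class SchedulerInfoSketch:
    def __init__(self):
        self.scheduler_lock = threading.Lock()
        self.request_infos: Dict[str, RequestInfo] = {}

    def update_request_infos(self, request_id: str,
                             instance_expected_steps: Optional[int] = None,
                             instance_completed_steps: Optional[int] = None) -> None:
        # Create-or-update under the scheduler lock, matching the call sites above:
        # add_request() sets both fields, commit_dst_request() only the step budget.
        with self.scheduler_lock:
            info = self.request_infos.setdefault(
                request_id,
                RequestInfo(instance_expected_steps=-1, instance_completed_steps=0))
            if instance_expected_steps is not None:
                info.instance_expected_steps = instance_expected_steps
            if instance_completed_steps is not None:
                info.instance_completed_steps = instance_completed_steps
```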
(The diff for the remaining 10 changed files is not shown.)
