Commit e4617d9
refactor
fix
fix pylint
update
fix
add todo
add todo
fix
ZeldaHuang committed Aug 30, 2024
1 parent d4cd8fa commit e4617d9
Showing 17 changed files with 313 additions and 404 deletions.
121 changes: 17 additions & 104 deletions llumnix/backends/backend_interface.py
@@ -13,9 +13,9 @@
 
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Any, Iterable, List, Optional, Union
+from typing import Iterable, List, Union
 
-from llumnix.llumlet.migrating_request import MigratingRequest
+from llumnix.llumlet.request import LlumnixRequest
 from llumnix.server_info import ServerInfo
 
 
@@ -29,10 +29,6 @@ def is_sim_backend(status: "BackendType") -> bool:
         BackendType.SIM_VLLM,
     ]
 
-class BackendInferenceType(str, Enum):
-    PREFILL = "prefill"
-    DECODE = "decode"
-
 class BackendInterface(ABC):
     # Methods for inference
     @abstractmethod
@@ -69,7 +65,7 @@ def _start_engine_loop(self) -> None:
 
     # Methods for migration
     @abstractmethod
-    def get_request_incremental_blocks(self, backend_request: Any, pre_stage_num_blocks: int) -> List[int]:
+    def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]:
         """Retrieves the incremental block table for a given request.
 
         This method is used to fetch a list of block numbers that represent the incremental
@@ -92,6 +88,13 @@ def get_request_incremental_blocks(self, backend_request: Any, pre_stage_num_blo
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def get_running_queue(self) -> List[LlumnixRequest]:
+        """
+        Return the backend's running queue.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def remove_running_request(self, request_id: str) -> None:
         """
@@ -108,7 +111,7 @@ def remove_running_request(self, request_id: str) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def add_migrating_out_request_last_stage(self, backend_request: Any) -> None:
+    def add_migrating_out_request_last_stage(self, backend_request: LlumnixRequest) -> None:
         """
         Adds a backend request to the list of migrating-out requests in the last stage.
@@ -123,7 +126,7 @@ def add_migrating_out_request_last_stage(self, backend_request: Any) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def remove_migrating_out_request_last_stage(self, backend_request: Any) -> None:
+    def remove_migrating_out_request_last_stage(self, backend_request: LlumnixRequest) -> None:
         """
         Removes a backend request from the list of migrating-out requests in the last stage.
@@ -138,7 +141,7 @@ def remove_migrating_out_request_last_stage(self, backend_request: Any) -> None:
         raise NotImplementedError
 
     @abstractmethod
-    def pop_migrating_out_requests_last_stage(self) -> List[Any]:
+    def pop_migrating_out_requests_last_stage(self) -> List[LlumnixRequest]:
         """
         Pops the list of migrating-out requests in the last stage.
@@ -170,31 +173,7 @@ def pre_alloc(self, request_id: str, block_num: int) -> List[int]:
         raise NotImplementedError
 
     @abstractmethod
-    def should_abort_migration(self, backend_request: Any, last_stage_time: int) -> bool:
-        """
-        Determines whether the migration for a specific backend request should be aborted.
-
-        This method evaluates the conditions under which a migration process, associated with a
-        backend request, should be terminated before completion. If the backend request is no
-        longer valid, or a preemption event has occurred since the last migration stage, the
-        migration needs to be aborted.
-
-        Args:
-            backend_request: An object representing the backend request. The type of this
-                object is dependent on the backend implementation and the details
-                of the request.
-            last_stage_time: An integer timestamp representing the last successful stage of the
-                migration process. This is used to determine whether any significant event
-                has occurred after this point that would warrant aborting the migration.
-
-        Returns:
-            True if the migration should be aborted based on the evaluation of the backend_request
-            and last_stage_time; False otherwise.
-        """
-        raise NotImplementedError
-
     @abstractmethod
-    def add_running_request(self, backend_request: Any) -> None:
+    def add_running_request(self, backend_request: LlumnixRequest) -> None:
         """
         Adds a backend request to the running queue for processing.
raise NotImplementedError

@abstractmethod
def is_request_running(self, backend_request: Any) -> bool:
def is_request_running(self, backend_request: LlumnixRequest) -> bool:
"""Checks if a given backend request is currently in the running queue.
This method determines whether a backend request is present and actively being processed
Expand Down Expand Up @@ -241,7 +220,7 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
raise NotImplementedError

@abstractmethod
def free_src_request(self, backend_request: Any) -> None:
def free_src_request(self, backend_request: LlumnixRequest) -> None:
"""Frees blocks associated with a migrating request on the source instance.
Upon completion or cancellation of a migration process, this method is invoked to clean up and
Expand Down Expand Up @@ -275,7 +254,7 @@ def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[i
raise NotImplementedError

@abstractmethod
def commit_dst_request(self, backend_request: Any, server_info: ServerInfo) -> None:
def commit_dst_request(self, backend_request: LlumnixRequest) -> None:
"""Commits the migrating request to the destination instance.
This method finalizes the migration process by transferring all necessary metadata and resource
@@ -286,61 +265,6 @@ def commit_dst_request(self, backend_request: Any, server_info: ServerInfo) -> N
             backend_request: An object representing the backend request. The type of this
                 object is dependent on the backend implementation and the details
                 of the request.
-            server_info: The information of the api server where the request comes from.
         """
         raise NotImplementedError
 
-    @abstractmethod
-    def get_last_running_request(self) -> Optional[MigratingRequest]:
-        """Retrieves the last non-prefilling request from the running queue.
-
-        This method iterates over the running queue in reverse order and returns the last request
-        that has moved past the prefilling stage. Prefilling requests are not considered for
-        migration, as they have not yet begun processing and therefore do not have any state that
-        needs to be preserved.
-
-        Returns:
-            An instance of MigratingRequest representing the last request in the running queue that
-            is not prefilling, or None if there are no such requests in the queue.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_longest_running_request(self) -> Optional[MigratingRequest]:
-        """Retrieves the request with the longest sequence length from the running queue.
-
-        This method should sort the running queue by request length and return the longest
-        non-prefilling request.
-
-        Returns:
-            An instance of MigratingRequest representing the longest-running request in the queue that
-            has generated output, or None if no such request exists.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_shortest_running_request(self) -> Optional[MigratingRequest]:
-        """Retrieves the request with the shortest sequence length from the running queue.
-
-        This method should sort the running queue by request length and return the shortest
-        non-prefilling request.
-
-        Returns:
-            An instance of MigratingRequest representing the shortest-running request in the queue that
-            has generated output, or None if no such request exists.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_request_server_info(self, request_id: str) -> ServerInfo:
-        """Retrieves the information of the api server where the request comes from.
-
-        This method is used by the migration coordinator to get the information of the server
-        where the migrating request comes from.
-
-        Args:
-            request_id: Request ID.
-
-        Returns:
-            The request output queue of the api server where the request comes from.
-        """
-        raise NotImplementedError
-
@@ -354,14 +278,3 @@ def get_all_request_ids(self) -> List[str]:
         The list of request IDs.
         """
         raise NotImplementedError
-
-    @abstractmethod
-    def free_request_states(self, request_id: Union[str, Iterable[str]]) -> None:
-        """Free request states recorded in the backend engine.
-
-        This method is used by the llumlet or backend engine to free the request states when the
-        request is finished/migrated/aborted.
-
-        Args:
-            request_id: A single request ID or an iterable of request IDs.
-        """
-        raise NotImplementedError
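
As a usage illustration (not part of this commit): a minimal in-memory sketch of how a backend might implement the narrowed, LlumnixRequest-typed bookkeeping methods above. `ToyBackend` and the stub `LlumnixRequest` are hypothetical stand-ins; the real `LlumnixRequest` comes from `llumnix.llumlet.request` and carries engine state.

```python
from typing import List

class LlumnixRequest:
    # Hypothetical stub for llumnix.llumlet.request.LlumnixRequest,
    # which in the real codebase wraps an engine request.
    def __init__(self, request_id: str) -> None:
        self.request_id = request_id

class ToyBackend:
    # Illustrative in-memory bookkeeping for the running queue and the
    # set of requests migrating out in their last stage.
    def __init__(self) -> None:
        self._running: List[LlumnixRequest] = []
        self._migrating_out_last_stage: List[LlumnixRequest] = []

    def get_running_queue(self) -> List[LlumnixRequest]:
        # Newly added abstract method: expose the running queue so callers
        # can pick migration candidates themselves (the get_last/longest/
        # shortest_running_request helpers were removed in this commit).
        return self._running

    def add_running_request(self, backend_request: LlumnixRequest) -> None:
        self._running.append(backend_request)

    def remove_running_request(self, request_id: str) -> None:
        self._running = [r for r in self._running if r.request_id != request_id]

    def is_request_running(self, backend_request: LlumnixRequest) -> bool:
        return backend_request in self._running

    def add_migrating_out_request_last_stage(self, backend_request: LlumnixRequest) -> None:
        self._migrating_out_last_stage.append(backend_request)

    def remove_migrating_out_request_last_stage(self, backend_request: LlumnixRequest) -> None:
        self._migrating_out_last_stage.remove(backend_request)

    def pop_migrating_out_requests_last_stage(self) -> List[LlumnixRequest]:
        popped = self._migrating_out_last_stage
        self._migrating_out_last_stage = []
        return popped

backend = ToyBackend()
req = LlumnixRequest("req-0")
backend.add_running_request(req)
assert backend.get_running_queue() == [req]
```

Presumably the point of replacing the three selection helpers with `get_running_queue` is that candidate-selection policy can now live outside the backend, operating on the returned queue.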
13 changes: 7 additions & 6 deletions llumnix/backends/profiling.py
@@ -23,7 +23,8 @@
 import pandas as pd
 import numpy as np
 
-from llumnix.backends.backend_interface import BackendType, BackendInferenceType
+from llumnix.backends.backend_interface import BackendType
+from llumnix.llumlet.request import RequestInferenceType
 
 # 2D parallel configuration
 # (gpu, tensor parallel, pipeline parallel)
@@ -53,8 +54,8 @@ class LatencyMemData:
     decode_model_params: Any = None
     prefill_model_params: Any = None
 
-    def add_latency_result(self, inference_type: BackendInferenceType, batch_size: int, tot_seq_len: int, latency: List[float]):
-        if inference_type == BackendInferenceType.PREFILL:
+    def add_latency_result(self, inference_type: RequestInferenceType, batch_size: int, tot_seq_len: int, latency: List[float]):
+        if inference_type == RequestInferenceType.PREFILL:
             self.prefill_latency[batch_size] = latency
         else:
             self.decode_latency[(batch_size, tot_seq_len)] = latency
@@ -79,13 +80,13 @@ class ProfilingResult:
     # The latency of postprocess on CPU.
     postprocess_cpu: float = 0.0
 
-    def add_latency_result(self, parallel_config: SimParallelConfig, inference_type: BackendInferenceType, batch_size: int,
+    def add_latency_result(self, parallel_config: SimParallelConfig, inference_type: RequestInferenceType, batch_size: int,
                            tot_seq_len: int, stage_latency: List[float], metadata: Any = None):
         """Add or overwrite the profiling results of a model."""
         if parallel_config not in self.para_dict:
             self.para_dict[parallel_config] = LatencyMemData(
                 metadata=metadata, prefill_latency={}, decode_latency={}, cache_dict={})
-        if inference_type == BackendInferenceType.PREFILL:
+        if inference_type == RequestInferenceType.PREFILL:
             self.para_dict[parallel_config].prefill_latency = {tot_seq_len: stage_latency}
         else:
             self.para_dict[parallel_config].decode_latency = {(batch_size, tot_seq_len): stage_latency}
@@ -145,7 +146,7 @@ def update(self, result: ProfilingResult):
 
     def _extract_data(self, row):
        """Extract the profiling results from a row of the profiling CSV file."""
-        inference_type = BackendInferenceType.PREFILL if row["inference_type"] == "prefill" else BackendInferenceType.DECODE
+        inference_type = RequestInferenceType.PREFILL if row["inference_type"] == "prefill" else RequestInferenceType.DECODE
         # assert pp==1
         stage_latencies = [float(row["latency"])]
         batch_size = _pad_to_alignment(int(row["bs"]), 8)
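
For reference, a self-contained sketch of the dispatch that `add_latency_result` performs after this rename. The enum definition below is an assumption: it mirrors the `BackendInferenceType` removed above (a `str` enum with "prefill"/"decode" values, consistent with the string comparison in `_extract_data`), and the module-level dicts are illustrative stand-ins for the `LatencyMemData` fields.

```python
from enum import Enum
from typing import Dict, List, Tuple

class RequestInferenceType(str, Enum):
    # Assumed to mirror the removed BackendInferenceType; values match the
    # "inference_type" CSV column parsed in _extract_data.
    PREFILL = "prefill"
    DECODE = "decode"

# Stand-ins for LatencyMemData fields: prefill samples are keyed by batch
# size, decode samples by (batch size, total sequence length).
prefill_latency: Dict[int, List[float]] = {}
decode_latency: Dict[Tuple[int, int], List[float]] = {}

def add_latency_result(inference_type: RequestInferenceType, batch_size: int,
                       tot_seq_len: int, latency: List[float]) -> None:
    if inference_type == RequestInferenceType.PREFILL:
        prefill_latency[batch_size] = latency
    else:
        decode_latency[(batch_size, tot_seq_len)] = latency

add_latency_result(RequestInferenceType.PREFILL, 8, 512, [3.1])
add_latency_result(RequestInferenceType.DECODE, 8, 512, [0.9])
assert prefill_latency == {8: [3.1]}
assert decode_latency == {(8, 512): [0.9]}
```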
(Diffs for the remaining 15 changed files are not loaded on this page.)
