Commit 26f4d6a: add unit test
Xinyi-ECNU committed Sep 5, 2024 (1 parent: 16a05d1)
Showing 23 changed files with 207 additions and 122 deletions.
42 changes: 21 additions & 21 deletions llumnix/__init__.py
@@ -11,28 +11,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-# import vllm
-# from vllm import *
+import vllm
+from vllm import *

-# from llumnix.server_info import ServerInfo
-# from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster,
-#                                                init_manager, init_llumlets)
-# from llumnix.arg_utils import EngineManagerArgs
-# from llumnix.llm_engine_manager import LLMEngineManager
-# from llumnix.llumlet.llumlet import Llumlet
+from llumnix.server_info import ServerInfo
+from llumnix.entrypoints.llumnix_utils import (launch_ray_cluster, connect_to_ray_cluster,
+                                               init_manager, init_llumlets)
+from llumnix.arg_utils import EngineManagerArgs
+from llumnix.llm_engine_manager import LLMEngineManager
+from llumnix.llumlet.llumlet import Llumlet

-# from .version import __version__
+from .version import __version__

-# __all__ = [
-#     "__version__",
-#     "ServerInfo",
-#     "launch_ray_cluster",
-#     "connect_to_ray_cluster",
-#     "init_manager",
-#     "init_llumlets",
-#     "EngineManagerArgs",
-#     "LLMEngineManager",
-#     "Llumlet"
-# ]
+__all__ = [
+    "__version__",
+    "ServerInfo",
+    "launch_ray_cluster",
+    "connect_to_ray_cluster",
+    "init_manager",
+    "init_llumlets",
+    "EngineManagerArgs",
+    "LLMEngineManager",
+    "Llumlet"
+]

-# __all__.extend(getattr(vllm, "__all__", []))
+__all__.extend(getattr(vllm, "__all__", []))
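This hunk turns the package from a commented-out stub into the real public surface: it re-exports vLLM's API plus Llumnix's own entry points. As orientation, a minimal sketch of what becomes importable afterwards; only the names in __all__ above are guaranteed by this diff, the descriptions in the comments are inferred from the other files in this commit:

# Minimal sketch: the names come straight from the __all__ above; anything
# vLLM itself exports is additionally pulled in by the extend() call.
from llumnix import (
    EngineManagerArgs,    # manager/scheduler options, see arg_utils.py below
    LLMEngineManager,     # global manager that dispatches and migrates requests
    Llumlet,              # per-instance wrapper around a backend engine
    launch_ray_cluster,   # Ray bootstrap helper; its signature is not shown in this commit
    init_manager,
)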
5 changes: 3 additions & 2 deletions llumnix/arg_utils.py
@@ -66,9 +66,10 @@ class EngineManagerArgs:
def create_global_scheduler_configs(
self,
) -> Tuple[GlobalSchedulerConfig]:

config_data = get_cfg()
-        config_data.merge_from_file(self.config_file)
+        if self.config_file:
+            config_data.merge_from_file(self.config_file)

# Create the GlobalScheduler Configuration.
global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
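The new guard makes config_file optional: when no file is passed, the defaults from get_cfg() are used as-is. A small sketch of the same pattern with a yacs-style config node; Llumnix's get_cfg() is assumed to return something compatible, and the default key here is a placeholder rather than one of Llumnix's real defaults:

# Sketch of the guarded merge introduced above, assuming a yacs-style config.
from typing import Optional
from yacs.config import CfgNode

def load_config(config_file: Optional[str]) -> CfgNode:
    cfg = CfgNode({"initial_instances": 1})  # stand-in for get_cfg()
    if config_file:                          # merge only when a path was given
        cfg.merge_from_file(config_file)     # yacs rejects keys absent from the defaults
    return cfg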
18 changes: 9 additions & 9 deletions llumnix/backends/backend_interface.py
@@ -42,9 +42,9 @@ def add_request(self, request_id: str, server_info: ServerInfo, request_expected
Args:
request_id: Request ID.
server_info: The information of the api server where the request come.
-            request_expected_steps: The expected number of steps for the request to run.The number of steps
-                                    represents the sum of the times 'engine.step()' has been called by the
-                                    backend instances for the request.
+            request_expected_steps: The expected number of steps for the request to run.The number of steps
+                                    represents the sum of the times 'engine.step()' has been called by the
+                                    backend instances for the request.
*args: Positional arguments that represent request-specific data.
**kwargs: Keyword arguments that contain metadata of the backend request
(request_id, arrival_time, etc.).
@@ -270,19 +270,19 @@ def commit_dst_request(self, backend_request: LlumnixRequest) -> None:
of the request.
"""
raise NotImplementedError

@abstractmethod
-    def update_strict_pre_migration(self, new_strict_pre_migration: bool) -> None:
+    def update_strict_pre_migration(self, strict_pre_migration: bool) -> None:
        """Update the status of whether to force migration in the backend engine.
-        This method updates the status of whether to force migration in the backend engine. This action is performed only when the
+        This method updates the status of whether to force migration in the backend engine. This action is performed only when the
        corresponding status in the llumlet is changed.
-        `pre_migration` represents whether the backend instance enables migration. By default, `pre_migration` is set to True, indicating that
-        the instance enables migration when `request.output_len` >= `request.request_expected_steps`. If `pre_migration` is set
+        `pre_migration` represents whether the backend instance enables migration. By default, `pre_migration` is set to True, indicating that
+        the instance enables migration when `request.output_len` >= `request.request_expected_steps`. If `pre_migration` is set
        to False, migration will not occur, and requests on the instance that reach the `request_expected_steps` will continue with inference.
        Args:
-            new_strict_pre_migration: New migration status provided for backend engine.
+            strict_pre_migration: control whether the backend engine enables migration.
"""
raise NotImplementedError

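The parameter rename is cosmetic, but the rule the docstring describes is easier to see as code. A hedged sketch of that gate; output_len and request_expected_steps are the field names quoted in the docstring above, while the Request class itself is a stand-in:

# Illustrative only: the migration gate described by the docstring above.
class Request:
    def __init__(self, output_len: int, request_expected_steps: int):
        self.output_len = output_len
        self.request_expected_steps = request_expected_steps

def should_migrate(request: Request, strict_pre_migration: bool) -> bool:
    # With the flag off, the request keeps running on its current instance.
    return strict_pre_migration and request.output_len >= request.request_expected_steps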
20 changes: 10 additions & 10 deletions llumnix/backends/vllm/llm_engine.py
@@ -104,6 +104,7 @@ def _process_model_outputs(
for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs):
seq_group = scheduled_seq_group.seq_group
if seq_group.get_seqs(SequenceStatus.RUNNING):
+                # print(seq_group)
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
@@ -149,10 +150,10 @@ def update_instance_info(self, instance_info: InstanceInfo) -> None:

def add_request(self, request_id: str, server_info: ServerInfo, request_expected_steps: int, *args, **kwargs):
super().add_request(request_id, *args, **kwargs)
-        logger.info("add_request")
seq_group = self.scheduler.waiting[-1]
-        self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, request_expected_steps, [seq_group.get_seqs()[0]], seq_group.sampling_params,
-                                                          seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
+        self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, request_expected_steps, [seq_group.get_seqs()[0]],
+                                                          seq_group.sampling_params,
+                                                          seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _put_request_output_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
@@ -182,14 +183,13 @@ def __init__(
placement_group: "PlacementGroup" = None,
node_id: str = None
) -> None:
-        self.strict_pre_migration = True
self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args,
migration_config=migration_config,
instance_id=instance_id,
placement_group=placement_group,
node_id=node_id)
# multi-instance args
-        self.engine.scheduler = SchedulerLlumnix(self.strict_pre_migration, self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
+        self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.instance_id = instance_id
@@ -220,7 +220,6 @@ def add_request(self,
# Store the server information of each request to put the request outputs back to the corresponding api server correctly.
self.engine.add_request(request_id, server_info, request_expected_steps, *args, **kwargs)


def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
seq.seq_id = next(self.engine.seq_counter)
@@ -252,10 +251,10 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:

def get_running_queue(self ) -> List[SequenceGroupLlumnix]:
return self.engine.scheduler.get_running_queue()
-    def update_strict_pre_migration(self, new_migration_state: bool):
-        if self.strict_pre_migration != new_migration_state:
-            self.strict_pre_migration = new_migration_state
-            self.engine.scheduler.update_strict_pre_migration(new_migration_state)
+
+    def update_strict_pre_migration(self, strict_pre_migration: bool):
+        self.engine.scheduler.update_strict_pre_migration(strict_pre_migration)

def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]:
return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs)

@@ -288,5 +287,6 @@ def free_dst_pre_alloc_cache(self, *args, **kwargs) -> None:

def free_src_request(self, backend_request: SequenceGroup) -> None:
return self.engine.scheduler.free_src_request(backend_request)

def get_all_request_ids(self) -> List[str]:
return self.engine.scheduler.get_all_request_ids()
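Beyond the renames, the net change in this file is that BackendVLLM stops caching strict_pre_migration and forwards straight to the scheduler, which now owns the flag (see scheduler.py below). Also worth noting is the lock discipline around add_request: add_seq_group acquires the scheduler lock, and LLMEngineLlumnix.add_request releases it manually only after rewrapping the queue tail. A stand-in sketch of that handoff, with the real classes replaced by plain lists and a bare lock:

# Stand-in sketch of the lock handoff visible in this diff: the lock taken
# when the request is enqueued is released only after the tail of the waiting
# queue has been rewrapped, so no scheduling step sees the unwrapped request.
import threading

lock = threading.Lock()
waiting: list = []

def add_seq_group(seq_group) -> None:
    lock.acquire()                 # deliberately not released here
    waiting.append(seq_group)

def add_request(raw_request) -> None:
    add_seq_group(raw_request)
    waiting[-1] = ("llumnix-wrapped", raw_request)  # stand-in for SequenceGroupLlumnix
    lock.release()                 # matches the manual release in the diff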
11 changes: 5 additions & 6 deletions llumnix/backends/vllm/scheduler.py
@@ -14,7 +14,6 @@
from asyncio.log import logger
import time
import threading
-import copy
from typing import Dict, List, Optional, Tuple

from vllm.core.block_manager_v1 import BlockSpaceManagerV1, BlockTable
@@ -46,7 +45,7 @@ def add_block_table(self, block_table: BlockTable, seq_id: int) -> None:
self.block_tables[seq_id] = block_table.copy()

class SchedulerLlumnix(Scheduler):
-    def __init__(self, strict_pre_migration, *args, **kwargs) -> None:
+    def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.block_manager = BlockManagerLlumnix(
block_size=self.cache_config.block_size,
@@ -57,7 +56,7 @@ def __init__(self, strict_pre_migration, *args, **kwargs) -> None:
self.pre_alloc_cache_dict: Dict[str, BlockTable] = {}
self.scheduler_lock = threading.Lock()
self.migrating_out_request_last_stage: List[LlumnixRequest] = []
-        self.strict_pre_migration = strict_pre_migration
+        self.strict_pre_migration = True

def add_update_instance_info_callback(self, update_instance_info_callback):
self.update_instance_info_callback = update_instance_info_callback
@@ -209,7 +208,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:

def _schedule_running(self, *args, **kwargs):
args_list = list(args)
-        args_list[0] = copy.deepcopy(self.running)
+        args_list[0] = self.running.copy()
remove_running = []
if self.strict_pre_migration:
for seq_group in list(args_list[0]):
Expand All @@ -223,8 +222,8 @@ def _schedule_running(self, *args, **kwargs):
return remaining_running, running_scheduled

@scheduler_lock
-    def update_strict_pre_migration(self, new_migration_state: bool) -> None:
-        self.strict_pre_migration = new_migration_state
+    def update_strict_pre_migration(self, strict_pre_migration: bool) -> None:
+        self.strict_pre_migration = strict_pre_migration

def add_seq_group(self, *args, **kwargs):
# The scheduler lock is mannually released in the end of LLMEngineLlumnix.add_request function.
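Two things are worth noting in this file. First, the deep copy of self.running becomes a shallow list copy: removing a group from the copy no longer duplicates per-sequence state, and self.running itself is untouched either way. Second, the filter below that copy holds back requests that have already reached their expected steps. The exact predicate is collapsed in this diff, so the one in this sketch is an assumption based on the update_strict_pre_migration docstring quoted earlier:

# Hedged sketch of _schedule_running's filtering; the real condition is not
# shown in this diff and the field names follow the docstrings above.
from typing import List, Tuple

def split_running(running: List, strict_pre_migration: bool) -> Tuple[List, List]:
    schedulable = running.copy()   # shallow copy, as in the diff
    held_back = []
    if strict_pre_migration:
        for seq_group in list(schedulable):
            if seq_group.output_len >= seq_group.request_expected_steps:
                schedulable.remove(seq_group)   # keep it out of this step
                held_back.append(seq_group)     # candidate for migration
    return schedulable, held_back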
2 changes: 1 addition & 1 deletion llumnix/config.py
@@ -57,4 +57,4 @@ def __init__(
self.scale_down_threshold = scale_down_threshold*(-1)

self.enable_pd_disaggregation = enable_pd_disaggregation
-        self.available_dispatch_instance_num = available_dispatch_instance_num
+        self.available_dispatch_instance_num = available_dispatch_instance_num
5 changes: 0 additions & 5 deletions llumnix/entrypoints/llumnix_utils.py
@@ -141,7 +141,6 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,

instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)]
migration_configs = engine_manager_args.create_migration_config()
-    print("??",engine_manager_args.initial_instances)
for idx in range(engine_manager_args.initial_instances):
instance_id = instance_ids[idx]
if not engine_manager_args.profiling_result_file_path:
@@ -200,10 +199,6 @@ def init_llumnix_components(engine_manager_args: EngineManagerArgs,
ray.get(task)
available_instance_ids.append(instance_ids[idx])
available_llumlets.append(llumlets[idx])
-        except Exception as e:
-            import traceback
-            logger.error("unexpected exception occurs: {}".format(e))
-            logger.error("exception traceback: {}".format(traceback.format_exc()))
except ray.exceptions.RayActorError:
dead_instance_ids.append(instance_ids[idx])

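Dropping the blanket except here is more than log cleanup: ray.exceptions.RayActorError inherits from Exception, so the broad handler listed first would have swallowed dead-actor failures before the RayActorError clause could run, and dead instances were never recorded for cleanup. A sketch of the narrowed handling; the names are stand-ins for the loop in init_llumnix_components:

# Sketch of the narrowed failure handling: only a dead llumlet actor is
# treated as an expected outcome; any other error now propagates loudly.
import ray

def wait_for_llumlet(instance_id, ready_task, available_ids, dead_ids):
    try:
        ray.get(ready_task)               # raises RayActorError if the actor died
        available_ids.append(instance_id)
    except ray.exceptions.RayActorError:
        dead_ids.append(instance_id)      # recorded for cleanup, as in the diff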
3 changes: 2 additions & 1 deletion llumnix/global_scheduler/dispatch_scheduler.py
@@ -59,7 +59,8 @@ def update_instance_infos,
def add_instance(self, instance_id: str) -> None:
self.instance_id_set.add(instance_id)
self.num_instances = len(self.instance_id_set)
-        if self.available_dispatch_instance_num == -1 or (self.available_dispatch_instance_num > 0 and len(self.available_dispatch_instance_set) < self.available_dispatch_instance_num):
+        if self.available_dispatch_instance_num == -1 or (self.available_dispatch_instance_num > 0 and
+                                                          len(self.available_dispatch_instance_set) < self.available_dispatch_instance_num):
self.available_dispatch_instance_set.add(instance_id)
self.instance_num_requests[instance_id] = 0

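The rewrapped condition is purely cosmetic, but the rule it encodes is easy to miss: -1 means every instance may receive requests, while a positive value caps how many instances join the dispatch set. Isolated as a tiny predicate (a sketch, not Llumnix code):

# Sketch of the capacity rule in add_instance above.
def admit_to_dispatch_set(cap: int, current_size: int) -> bool:
    return cap == -1 or (cap > 0 and current_size < cap)

assert admit_to_dispatch_set(-1, 10)      # -1: no cap at all
assert admit_to_dispatch_set(2, 1)        # room left under the cap
assert not admit_to_dispatch_set(2, 2)    # dispatch set is full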
7 changes: 4 additions & 3 deletions llumnix/global_scheduler/migration_scheduler.py
@@ -80,7 +80,7 @@ def _get_migration_settings(self, migration_target:str) -> Dict[str, InstanceInf
sorted_dst_instance_infos = [i for i in self.sorted_instance_infos[InstanceType.NO_CONSTRAINTS]
if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold]
elif migration_target == PairMigrationConstraints.PREFILL_2_DECODING:
-            sorted_src_instance_infos = [i for i in reversed(self.sorted_instance_infos[InstanceType.PREFILL])]
+            sorted_src_instance_infos = list(reversed(self.sorted_instance_infos[InstanceType.PREFILL]))
sorted_dst_instance_infos = [i for i in self.sorted_instance_infos[InstanceType.DECODE]
if i.num_killed_requests == 0]
# TODO[xinyi]: Considering decoding instances load, try to decode on the prefill instance(set strict_pre_migration as False).
@@ -90,7 +90,7 @@
sorted_dst_instance_infos = [i for i in self.sorted_instance_infos[InstanceType.DECODE]
if i.num_killed_requests == 0 and i.instance_load_migrate < self.migrate_out_load_threshold]
return sorted_src_instance_infos, sorted_dst_instance_infos, strict_pre_migration

def update_instance_infos(self,
instance_info: Dict[str, InstanceInfo]) -> None:
self.instance_info = instance_info
@@ -154,7 +154,8 @@ def pair_migration(self,
continue
load_diff_after_mig = left_load_after_mig - right_load_after_mig
if (0 < load_diff_after_mig < load_diff_before_mig) or (sorted_dst_instance_infos[i].instance_load_migrate == -np.inf):
-                migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id, sorted_dst_instance_infos[i].instance_id, strict_pre_migration))
+                migrate_instance_pairs.append((sorted_src_instance_infos[i].instance_id,
+                                               sorted_dst_instance_infos[i].instance_id, strict_pre_migration))
return migrate_instance_pairs

def _compute_instance_load_after_migrate(self, instance_info: InstanceInfo, is_migrate_in: bool) -> float:
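For orientation, the acceptance rule visible in pair_migration: the i-th most loaded source is paired with the i-th least loaded destination, and the pair is kept only when migration strictly shrinks their load gap, or when the destination reports a load of -inf (an empty instance). A sketch with the same arithmetic; plain floats stand in for InstanceInfo objects:

# Sketch of pair_migration's acceptance test, following the diff's arithmetic.
import numpy as np

def keep_pair(src_load_after: float, dst_load_after: float,
              load_diff_before: float, dst_load_before: float) -> bool:
    load_diff_after = src_load_after - dst_load_after
    # migrate only if the gap strictly shrinks while staying positive,
    # or the destination is empty (its load is reported as -inf)
    return (0 < load_diff_after < load_diff_before) or dst_load_before == -np.inf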
13 changes: 4 additions & 9 deletions llumnix/llm_engine_manager.py
@@ -115,7 +115,7 @@ async def generate(
logger.info("No instance available temporarily, sleep {}s, "
"and retry generate request {} again....".format(RETRIES_INTERVALS, request_id))
await asyncio.sleep(RETRIES_INTERVALS)

try:
instance_id, request_expected_steps = self.global_scheduler.dispatch()
await self.instances[instance_id].generate.remote(request_id, server_info, request_expected_steps, *args, **kwargs)
@@ -126,10 +126,6 @@
except (ray.exceptions.RayActorError, KeyError):
logger.info("[generate] instance {} is dead, regenerate request {}".format(instance_id, request_id))
self.scale_down(instance_id)
-        except Exception as e:
-            import traceback
-            logger.error("unexpected exception occurs: {}".format(e))
-            logger.error("exception traceback: {}".format(traceback.format_exc()))

async def abort(self, request_id: Union[str, Iterable[str]]) -> None:
if isinstance(request_id, str):
@@ -240,7 +236,7 @@ async def _post_migrate(self, rets: List[str], call_migrate_instance_pairs: List
self.request_instance[migrate_out_request_id] = call_migrate_instance_pairs[i][1]
logger.info("{}->{} migrate done, migrate request {}".format(
call_migrate_instance_pairs[i][0], call_migrate_instance_pairs[i][1], migrate_out_request_ids))

async def _migrate_control(self) -> None:
# Push migrate when the instance_info have updated a certain number of times.
if self.enable_pd_disaggregation:
@@ -258,15 +254,14 @@ async def _migrate(self, migration_target:str, migrate_in_num_requests:int) -> N
call_migrate_instance_pairs: List[Tuple[str, str]] = []
for _, migrate_instance_pair in enumerate(migrate_instance_pairs):
migrate_out_instance_id, migrate_in_instance_id, strict_pre_migration = migrate_instance_pair
-            # logger.info("[_migrate] migrate_instance_pairs {} {} {} {} {}".format(migration_target, migrate_out_instance_id, migrate_in_instance_id, self.instance_migrating[migrate_out_instance_id], self.instance_migrating[migrate_in_instance_id]))
if self.instance_migrating[migrate_out_instance_id] or self.instance_migrating[migrate_in_instance_id]:
continue
-            # logger.info("[_migrate] migrate_instance_pairs {} {} {} ".format(migration_target, migrate_out_instance_id, migrate_in_instance_id))
self.instance_migrating[migrate_out_instance_id] = True
self.instance_migrating[migrate_in_instance_id] = True
migrate_in_instance_name = "instance_{}".format(migrate_in_instance_id)
call_migrate_instance_pairs.append(migrate_instance_pair)
-            task = self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name, migrate_in_num_requests, strict_pre_migration)
+            task = self.instances[migrate_out_instance_id].migrate_out.remote(migrate_in_instance_name,
+                                                                              migrate_in_num_requests, strict_pre_migration)
migration_tasks.append(task)
# TODO(s5u13b): Migration failover could be implemented in Llumlet rather than manager.
rets = await asyncio.gather(*migration_tasks, return_exceptions=True)
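The loop above guards each migration with per-instance busy flags, then gathers all the remote migrate_out calls with return_exceptions=True, so one dead instance cannot cancel the whole batch. A condensed sketch of that fan-out pattern; names are simplified from the diff and the Ray actor call is stubbed out with an ordinary coroutine:

# Condensed sketch of _migrate's fan-out: skip pairs with a busy endpoint,
# mark both ends busy, then gather results without letting one failure
# cancel the rest.
import asyncio
from typing import Dict, List, Tuple

async def migrate_all(pairs: List[Tuple[str, str]],
                      migrating: Dict[str, bool],
                      migrate_out):            # coroutine: (src, dst) -> result
    tasks = []
    for src, dst in pairs:
        if migrating[src] or migrating[dst]:   # one endpoint already migrating
            continue
        migrating[src] = migrating[dst] = True
        tasks.append(migrate_out(src, dst))
    # return_exceptions=True: a failure shows up in the results instead of raising
    return await asyncio.gather(*tasks, return_exceptions=True)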
(The diffs for the remaining changed files were not loaded in this view.)