[Core][Migration] Support waiting request and multiple requests migration (#36)
s5u13b authored Nov 12, 2024
1 parent 844c836 commit e92c9ac
Showing 32 changed files with 790 additions and 381 deletions.
15 changes: 8 additions & 7 deletions Makefile
@@ -29,30 +29,31 @@ lint: check_pylint_installed check_pytest_installed
 
 .PHONY: test
 test: check_pytest_installed
-	@pytest -v -x --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings
+	@pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings
 	@python examlpes/offline_inference.py
-	@pytest -v -x tests/e2e_test/test_e2e.py
-	@pytest -v -x ./tests/e2e_test/test_migration.py
+	@pytest -v ./tests/e2e_test/test_e2e.py
+	@pytest -v ./tests/e2e_test/test_bench.py
+	@pytest -v ./tests/e2e_test/test_migration.py
 
 .PHONY: unit_test
 unit_test: check_pytest_installed
-	@pytest -v -x --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings
+	@pytest -v --ignore=third_party/ --ignore=tests/e2e_test --disable-warnings
 
 .PHONY: offline_test
 offline_test:
 	@python examlpes/offline_inference.py
 
 .PHONY: e2e_test
 e2e_test:
-	@pytest -v -x tests/e2e_test/test_e2e.py
+	@pytest -v ./tests/e2e_test/test_e2e.py
 
 .PHONY: bench_test
 bench_test:
-	@pytest -v -x ./tests/e2e_test/test_bench.py
+	@pytest -v ./tests/e2e_test/test_bench.py
 
 .PHONY: migration_test
 migration_test:
-	@pytest -v -x ./tests/e2e_test/test_migration.py
+	@pytest -v ./tests/e2e_test/test_migration.py
 
 #################### pygloo install for gloo migration backend begin ####################
 
2 changes: 1 addition & 1 deletion configs/base.yml
@@ -16,7 +16,7 @@ MANAGER:
 
 ENABLE_MIGRATION: True
 ENABLE_DEFRAG: True
-REQUEST_MIGRATION_POLICY: 'SJF'
+REQUEST_MIGRATION_POLICY: 'SR'
 
 MIGRATION_BACKEND: 'gloo'
 MIGRATION_BUFFER_BLOCKS: 512
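Any existing config that still says 'SJF' will now fail the new choices check. A minimal sketch of reading the renamed value with plain PyYAML (the MANAGER nesting is inferred from the hunk header above; llumnix itself resolves config through get_llumnix_config):

```python
# Minimal sketch, not llumnix's config loader: read the renamed policy value
# from a file like configs/base.yml. The ["MANAGER"] nesting is an assumption
# based on the "@@ ... MANAGER:" hunk header above.
import yaml

with open("configs/base.yml") as f:
    cfg = yaml.safe_load(f)

policy = cfg["MANAGER"]["REQUEST_MIGRATION_POLICY"]
assert policy in ("LCR", "SR", "LR", "FCW", "FCWSR"), policy  # 'SR' after this commit
```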
6 changes: 3 additions & 3 deletions docs/Arguments.md
@@ -17,7 +17,7 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
                [--pair-migration-frequency PAIR_MIGRATION_FREQUENCY]
                [--pair-migration-policy {balanced,defrag_constrained,defrag_relaxed}]
                [--migrate-out-threshold MIGRATE_OUT_THRESHOLD]
-               [--request-migration-policy {LCFS,SJF,LJF}]
+               [--request-migration-policy {LCR,SR,LR,FCW,FCWSR}]
                [--enable-defrag ENABLE_DEFRAG]
                [--enable-scaling]
                [--min-instances MIN_INSTANCES]
@@ -90,8 +90,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
 
 `--request-migration-policy`
 - Request migration policy.
-- Possible choices: LCFS, SJF, LJF
-- Default: "SJF"
+- Possible choices: LCR, SR, LR, FCW, FCWSR
+- Default: "SR"
 
 `--enable-defrag`
 - Enable defragmentation through migration based on virtual usage.
24 changes: 20 additions & 4 deletions llumnix/arg_utils.py
@@ -22,6 +22,7 @@
 from llumnix.config import LlumnixConfig, get_llumnix_config
 from llumnix.config.default import _C
 
+
 class LlumnixArgumentParser(argparse.ArgumentParser):
     def __init__(self, *args, **kwargs):
         self.cur_namespace = "llumnix"
@@ -228,7 +229,11 @@ def add_cli_args(
     parser.add_argument('--dispatch-policy',
                         type=str,
                         choices=['balanced', 'load', 'queue', 'flood'],
-                        help='request dispatch policy')
+                        help='The request dispatch policy.\n\n'
+                             '* "balanced" dispatch requests to the instance with the fewest dispatched requests.\n'
+                             '* "load" dispatch requests to the instance with the lowest instance load.\n'
+                             '* "queue" dispatch requests to the instance with the shortest waiting queue.\n'
+                             '* "flood" dispatch requests to the instance with the most dispatched requests.\n')
     parser.add_argument('--num-available-dispatch-instances',
                         type=int,
                         help='number of available instances for dispatching')
@@ -242,14 +247,25 @@
     parser.add_argument('--pair-migration-policy',
                         type=str,
                         choices=['balanced', 'defrag_constrained', 'defrag_relaxed'],
-                        help='pair migration policy')
+                        help='The pair migration policy.\n\n'
+                             '* "balanced" pair instances so that instance loads become more balanced.\n'
+                             '* "defrag_constrained" pair instances without the balance constraint to '
+                             'achieve defragmentation thoroughly (with instance constraints).\n'
+                             '* "defrag_relaxed" pair instances without the balance constraint to '
+                             'achieve defragmentation thoroughly (without instance constraints).\n')
     parser.add_argument('--migrate-out-threshold',
                         type=float,
                         help='migrate out instance load threshold')
     parser.add_argument('--request-migration-policy',
                         type=str,
-                        choices=['LCFS', 'SJF', 'LJF'],
-                        help='request migration policy')
+                        default=None,
+                        choices=['LCR', 'SR', 'LR', 'FCW', 'FCWSR'],
+                        help='The request migration policy.\n\n'
+                             '* "LCR" migrate the last-come running request.\n'
+                             '* "SR" migrate the shortest running request.\n'
+                             '* "LR" migrate the longest running request.\n'
+                             '* "FCW" migrate the first-come waiting request.\n'
+                             '* "FCWSR" migrate the first-come waiting request and the shortest running request.\n')
     parser.add_argument('--enable-defrag',
                         type=bool,
                         help='enable defragmentation through migration based on virtual usage')
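The new help text defines the five policies only tersely. As a reading aid, a hedged sketch of what each might select; `Req`, `arrival_time`, and `output_len` are illustrative stand-ins, not llumnix's actual scheduler types:

```python
# Illustrative sketch of the five request migration policies above; not the
# selection logic llumnix itself uses. `arrival_time`/`output_len` are assumed fields.
from typing import List, Optional


class Req:
    def __init__(self, arrival_time: float, output_len: int):
        self.arrival_time = arrival_time
        self.output_len = output_len


def pick_migrating_request(policy: str, running: List[Req],
                           waiting: List[Req]) -> Optional[Req]:
    if policy == "LCR":    # last-come running request
        return max(running, key=lambda r: r.arrival_time, default=None)
    if policy == "SR":     # shortest running request (fewest tokens generated so far)
        return min(running, key=lambda r: r.output_len, default=None)
    if policy == "LR":     # longest running request
        return max(running, key=lambda r: r.output_len, default=None)
    if policy == "FCW":    # first-come waiting request
        return min(waiting, key=lambda r: r.arrival_time, default=None)
    if policy == "FCWSR":  # waiting first-come, falling back to shortest running
        return (min(waiting, key=lambda r: r.arrival_time, default=None)
                or min(running, key=lambda r: r.output_len, default=None))
    raise ValueError(f"unknown policy: {policy}")
```

The FCWSR fallback here is one plausible reading of "first-come waiting and shortest running"; the real policy may consider both queues together.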
64 changes: 48 additions & 16 deletions llumnix/backends/backend_interface.py
@@ -13,9 +13,9 @@
 
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Iterable, List, Union
+from typing import Iterable, List, Union, Deque
 
-from llumnix.llumlet.request import LlumnixRequest
+from llumnix.llumlet.request import LlumnixRequest, RequestStatus
 from llumnix.server_info import ServerInfo
 
 class EngineState(str, Enum):
@@ -99,14 +99,21 @@ def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_st
         raise NotImplementedError
 
     @abstractmethod
-    def get_running_queue(self) -> List[LlumnixRequest]:
+    def get_running_queue(self) -> Deque[LlumnixRequest]:
         """
         Return backend's running queue.
         """
         raise NotImplementedError
 
     @abstractmethod
-    def remove_running_request(self, request_id: str) -> None:
+    def get_waiting_queue(self) -> Deque[LlumnixRequest]:
+        """
+        Return backend's waiting queue.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def remove_running_request(self, request_id: str) -> bool:
         """
         Removes a request from the backend's running queue.
@@ -117,6 +124,26 @@ def remove_running_request(self, request_id: str) -> None:
         Args:
             request_id: A string identifier for the request that is to be removed from the running
                 queue. This ID uniquely identifies the request within the backend system.
+
+        Returns:
+            True if the request was successfully removed from the running queue, False otherwise.
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def remove_waiting_request(self, request_id: str) -> bool:
+        """
+        Removes a request from the backend's waiting queue.
+
+        This method is responsible for safely removing a request from the waiting
+        queue of the backend engine. This action is performed during waiting-request migration.
+
+        Args:
+            request_id: A string identifier for the request that is to be removed from the waiting
+                queue. This ID uniquely identifies the request within the backend system.
+
+        Returns:
+            True if the request was successfully removed from the waiting queue, False otherwise.
+        """
+        raise NotImplementedError
+
@@ -164,17 +191,25 @@ def pop_migrating_out_requests_last_stage(self) -> List[LlumnixRequest]:
         raise NotImplementedError
 
     @abstractmethod
-    def pre_alloc(self, request_id: str, block_num: int) -> List[int]:
+    def pre_alloc(self,
+                  request_id: str,
+                  request_status: RequestStatus,
+                  request_arrival_time: float,
+                  block_num: int) -> List[int]:
         """Pre-allocates cache blocks for a migrating request.
 
         This method selects a specified number of free cache blocks to be reserved for an incoming
         migration request identified by the given request ID. It updates the pre-allocation cache
         dictionary with the allocated blocks, which ensures that these blocks are not used by
-        another process until the migration is finished.
+        another process until the migration is finished. For a waiting request, it only reserves
+        free cache blocks when the request is the earliest arrival among the requests in the dst
+        instance's waiting queue.
 
         Args:
             request_id: The unique identifier of the migration request for which cache blocks
                 are to be pre-allocated.
+            request_status: The status (waiting/running) of the request.
+            request_arrival_time: The arrival time of the request.
             block_num: The number of cache blocks that need to be pre-allocated for the request.
 
         Returns:
@@ -187,9 +222,8 @@ def add_running_request(self, backend_request: LlumnixRequest) -> None:
         """
         Adds a backend request to the running queue for processing.
 
-        This method enqueues a backend request into engine running queue, marking it for
-        active processing. It is used when a suspend migrating request should be added back
-        to running queue.
+        This method enqueues a backend request into the engine's running queue.
+        It is used when a suspended migrating request should be added back to the running queue.
 
         Args:
             backend_request: An object representing the backend request. The type of this
@@ -199,19 +233,17 @@
         raise NotImplementedError
 
     @abstractmethod
-    def is_request_running(self, backend_request: LlumnixRequest) -> bool:
-        """Checks if a given backend request is currently in the running queue.
+    def add_waiting_request(self, backend_request: LlumnixRequest) -> None:
+        """
+        Adds a backend request to the waiting queue for processing.
 
-        This method determines whether a backend request is present and actively being processed
-        in the running queue.
+        This method enqueues a backend request into the engine's waiting queue.
+        It is used when a suspended migrating request should be added back to the waiting queue.
 
         Args:
             backend_request: An object representing the backend request. The type of this
                 object is dependent on the backend implementation and the details
                 of the request.
-
-        Returns:
-            True if the backend request is currently in the running queue; False otherwise.
         """
         raise NotImplementedError
 
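The most subtle part of the new interface is the pre_alloc guard: a migrating waiting request reserves destination blocks only if it would be the earliest arrival in the destination's waiting queue, so migration cannot let it jump ahead of requests that arrived before it. A toy sketch of that guard, with made-up names (`free_blocks`, `waiting_arrival_times`) rather than the real scheduler internals:

```python
# Hedged sketch of the pre_alloc contract for waiting requests described above;
# the attribute names here are illustrative, not the actual scheduler's.
from typing import Dict, List


class TinyPreAllocator:
    def __init__(self, num_free_blocks: int):
        self.free_blocks: List[int] = list(range(num_free_blocks))
        self.pre_alloc_cache_dict: Dict[str, List[int]] = {}
        self.waiting_arrival_times: List[float] = []  # dst instance's waiting queue

    def pre_alloc(self, request_id: str, is_waiting: bool,
                  arrival_time: float, block_num: int) -> List[int]:
        # A migrating *waiting* request only reserves blocks if it would be the
        # earliest arrival in the destination's waiting queue.
        if is_waiting and any(t < arrival_time for t in self.waiting_arrival_times):
            return []
        if block_num > len(self.free_blocks):
            return []  # not enough space; the caller aborts the migration
        blocks = [self.free_blocks.pop() for _ in range(block_num)]
        self.pre_alloc_cache_dict.setdefault(request_id, []).extend(blocks)
        return blocks
```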
47 changes: 31 additions & 16 deletions llumnix/backends/vllm/llm_engine.py
@@ -13,7 +13,7 @@
 
 import time
 import traceback
-from typing import Any, List, Optional, Dict, Union, Iterable, Tuple
+from typing import Any, List, Optional, Dict, Union, Iterable, Tuple, Deque
 from collections import defaultdict
 import threading
 import asyncio
@@ -34,7 +34,7 @@
 from llumnix.instance_info import InstanceInfo
 from llumnix.backends.backend_interface import BackendInterface, EngineState
 from llumnix.backends.vllm.scheduler import SchedulerLlumnix
-from llumnix.backends.vllm.sequence import SequenceGroupLlumnix
+from llumnix.backends.vllm.sequence import SequenceGroupLlumnix, RequestStatus
 from llumnix.backends.profiling import LatencyMemData
 from llumnix.server_info import ServerInfo
 from llumnix.internal_config import MigrationConfig
@@ -199,7 +199,7 @@ def _process_model_outputs(
         # TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
         return request_outputs, server_infos
 
-    async def step_async(self) -> None:
+    async def step_async(self) -> Tuple[List[RequestOutput], List[ServerInfo]]:
         step_begin_time = time.time()
         request_outputs, server_infos = await super().step_async()
         for request_output in request_outputs:
@@ -295,9 +295,11 @@ def __init__(
         self.worker_handle_list = self.engine.model_executor.workers.copy()
         if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size:
             self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix"))
-        self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\
-                          src_worker_handle_list=self.worker_handle_list,
-                          placement_group=placement_group, node_id=node_id)
+        self._run_workers("init_migration", instance_id=instance_id,
+                          migration_config=migration_config,
+                          src_worker_handle_list=self.worker_handle_list,
+                          placement_group=placement_group,
+                          node_id=node_id)
 
         self.state = EngineState.INIT
         logger.info("engine ({}) current state {}".format(self.instance_id, self.state))
@@ -350,15 +352,22 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
             logger.info("add seq {} to block table".format(seq.seq_id))
         pre_alloc_blocks = self.engine.scheduler.pre_alloc_cache_dict.pop(backend_request.request_id)
         self.engine.scheduler.block_manager.add_block_table(pre_alloc_blocks, seq.seq_id)
-        backend_request.reset_migration_args()
-        self.add_running_request(backend_request)
+        backend_request.reset_migration_args_dst()
+        assert backend_request.status in [RequestStatus.WAITING_MIGRATING, RequestStatus.RUNNING_MIGRATING], \
+            "The status of request migrated to dst instance should be \
+            RequestStatus.WAITING_MIGRATING or RequestStatus.RUNNING_MIGRATING"
+        if backend_request.status == RequestStatus.WAITING_MIGRATING:
+            self.add_waiting_request(backend_request)
+        else: # RUNNING_MIGRATING:
+            backend_request.reset_status()
+            self.add_running_request(backend_request)
 
     async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
         await dst_ray_actor.execute_engine_method.remote("_run_workers",
-                                                         "migrate_cache",
-                                                         dst_blocks=dst_blocks,
-                                                         src_blocks=src_blocks,
-                                                         src_worker_handle_list=self.worker_handle_list)
+                                                        "migrate_cache",
+                                                        dst_blocks=dst_blocks,
+                                                        src_blocks=src_blocks,
+                                                        src_worker_handle_list=self.worker_handle_list)
 
     def _run_workers(self, *args, **kwargs):
         # pylint: disable=protected-access
@@ -373,15 +382,21 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
         request_ids = set(request_id)
         return self.engine.abort_request(request_ids)
 
-    def get_running_queue(self) -> List[SequenceGroupLlumnix]:
+    def get_running_queue(self) -> Deque[SequenceGroupLlumnix]:
         return self.engine.scheduler.get_running_queue()
 
+    def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]:
+        return self.engine.scheduler.get_waiting_queue()
+
     def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]:
         return self.engine.scheduler.get_request_incremental_blocks(*args, **kwargs)
 
-    def remove_running_request(self, *args, **kwargs) -> None:
+    def remove_running_request(self, *args, **kwargs) -> bool:
         return self.engine.scheduler.remove_running_request(*args, **kwargs)
 
+    def remove_waiting_request(self, *args, **kwargs) -> bool:
+        return self.engine.scheduler.remove_waiting_request(*args, **kwargs)
+
     def add_migrating_out_request_last_stage(self, *args, **kwargs) -> None:
         return self.engine.scheduler.add_migrating_out_request_last_stage(*args, **kwargs)

@@ -400,8 +415,8 @@ def should_abort_migration(self, *args, **kwargs) -> bool:
     def add_running_request(self, *args, **kwargs) -> None:
         return self.engine.scheduler.add_running_request(*args, **kwargs)
 
-    def is_request_running(self, *args, **kwargs) -> bool:
-        return self.engine.scheduler.is_request_running(*args, **kwargs)
+    def add_waiting_request(self, *args, **kwargs) -> None:
+        return self.engine.scheduler.add_waiting_request(*args, **kwargs)
 
     def free_dst_pre_alloc_cache(self, *args, **kwargs) -> None:
         return self.engine.scheduler.free_dst_pre_alloc_cache(*args, **kwargs)
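On the destination side, commit_dst_request is where the two migration kinds diverge. A condensed sketch of that control flow; the RequestStatus values and the queue-method names mirror the diff, everything else is illustrative scaffolding:

```python
# Condensed sketch of the commit_dst_request branch above; `backend` and
# `request` are stand-ins, not the actual llumnix classes.
from enum import Enum


class RequestStatus(Enum):
    RUNNING_MIGRATING = "running_migrating"
    WAITING_MIGRATING = "waiting_migrating"


def commit_dst_request_sketch(backend, request) -> None:
    assert request.status in (RequestStatus.WAITING_MIGRATING,
                              RequestStatus.RUNNING_MIGRATING)
    if request.status == RequestStatus.WAITING_MIGRATING:
        # No KV cache has been computed yet, so the request re-enters the
        # destination's waiting queue.
        backend.add_waiting_request(request)
    else:
        # A running request resumes where it left off on the destination.
        request.reset_status()
        backend.add_running_request(request)
```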
10 changes: 5 additions & 5 deletions llumnix/backends/vllm/migration_backend.py
@@ -286,15 +286,15 @@ def get_migration_backend(migration_config: MigrationConfig, cache_engine: Cache
                     .format(migration_config.migration_buffer_blocks, cache_engine.num_gpu_blocks))
         migration_config.migration_buffer_blocks = cache_engine.num_gpu_blocks
 
-    target_col = None
+    target_migration_backend = None
     backend = migration_config.migration_backend
-    assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported backend: {} for VLLM".format(backend)
+    assert backend in ['nccl', 'gloo', 'rpc'], "Unsupported migration backend: {} for llumnix".format(backend)
 
     if backend in ['nccl', 'gloo']:
-        target_col = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy,
+        target_migration_backend = RayColMigrationBackend(migration_config, cache_engine, local_rank, scheduling_strategy,
                                             is_driver_worker, gpu_cache)
     else:
-        target_col = RayRpcMigrationBackend(migration_config, cache_engine, worker_rank, worker_handle_list,
+        target_migration_backend = RayRpcMigrationBackend(migration_config, cache_engine, worker_rank, worker_handle_list,
                                             scheduling_strategy, is_driver_worker, gpu_cache)
 
-    return target_col
+    return target_migration_backend
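With target_col renamed to target_migration_backend, the function reads as a plain dispatch-by-name factory. The same shape in miniature, with stub classes standing in for the real Ray-backed backends:

```python
# Minimal sketch of the dispatch-by-name pattern used in get_migration_backend;
# the stub classes below are illustrative, not llumnix's actual backends.
class RayColMigrationBackendStub:
    """Stands in for the collective-communication (nccl/gloo) path."""
    def __init__(self, group_name: str):
        self.group_name = group_name


class RayRpcMigrationBackendStub:
    """Stands in for the Ray RPC path."""


def make_backend(name: str):
    assert name in ("nccl", "gloo", "rpc"), f"Unsupported migration backend: {name}"
    if name in ("nccl", "gloo"):
        return RayColMigrationBackendStub(name)
    return RayRpcMigrationBackendStub()
```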