Commit

update
ZeldaHuang committed Dec 17, 2024
1 parent a1bd218 commit 05b5499
Showing 12 changed files with 69 additions and 45 deletions.
6 changes: 3 additions & 3 deletions llumnix/backends/backend_interface.py
@@ -13,7 +13,7 @@

from abc import ABC, abstractmethod
from enum import Enum
-from typing import Iterable, List, Union, Deque
+from typing import Iterable, List, Union, Deque, Tuple

from llumnix.llumlet.request import LlumnixRequest, RequestStatus
from llumnix.server_info import ServerInfo
@@ -75,7 +75,7 @@ async def _start_engine_step_loop(self) -> None:

# Methods for migration
@abstractmethod
-def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]:
+def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> Tuple[List[int], List[int]]:
"""Retrieves the incremental block table for a given request.
This method is used to fetch a list of block numbers that represent the incremental
@@ -92,7 +92,7 @@ def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_st
need to be fetched in the current stage.
Returns:
-A list of integers, where each integer represents a block number that indicates
+A list of integers and their corresponding token ids, where each integer represents a block number that indicates
physical index of kv cache block tensor. These block numbers can then be used
to transfer to destination instance.
"""
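With this change, callers receive the incremental blocks together with the token ids they hold. A minimal caller sketch under that contract (the surrounding migration loop and the `backend`/`request` names are hypothetical):

    # Hypothetical caller of the updated interface.
    blocks, token_ids = backend.get_request_incremental_blocks(request, pre_stage_num_blocks)
    # Each entry of `blocks` is the physical index of a KV-cache block tensor;
    # `token_ids` are the tokens those blocks cover, letting the destination
    # rebuild its own block table without querying the source again.
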
9 changes: 6 additions & 3 deletions llumnix/backends/vllm/llm_engine.py
@@ -375,16 +375,16 @@ def get_running_queue(self) -> List[SequenceGroupLlumnix]:
return self.engine.scheduler[0].get_running_queue()

def get_waiting_queue(self) -> Deque[SequenceGroupLlumnix]:
-return self.engine.scheduler.get_waiting_queue()
+return self.engine.scheduler[0].get_waiting_queue()

-def get_request_incremental_blocks(self, *args, **kwargs) -> List[int]:
+def get_request_incremental_blocks(self, *args, **kwargs) -> Tuple[List[int], List[int]]:
return self.engine.scheduler[0].get_request_incremental_blocks(*args, **kwargs)

def remove_running_request(self, *args, **kwargs) -> None:
return self.engine.scheduler[0].remove_running_request(*args, **kwargs)

def remove_waiting_request(self, *args, **kwargs) -> bool:
-return self.engine.scheduler.remove_waiting_request(*args, **kwargs)
+return self.engine.scheduler[0].remove_waiting_request(*args, **kwargs)

def add_migrating_out_request_last_stage(self, *args, **kwargs) -> None:
return self.engine.scheduler[0].add_migrating_out_request_last_stage(*args, **kwargs)
@@ -404,6 +404,9 @@ def should_abort_migration(self, *args, **kwargs) -> bool:
def add_running_request(self, *args, **kwargs) -> None:
return self.engine.scheduler[0].add_running_request(*args, **kwargs)

+def add_waiting_request(self, *args, **kwargs) -> None:
+return self.engine.scheduler[0].add_waiting_request(*args, **kwargs)

def is_request_running(self, *args, **kwargs) -> bool:
return self.engine.scheduler[0].is_request_running(*args, **kwargs)

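The repeated `scheduler[0]` indexing tracks newer vLLM releases, where `LLMEngine.scheduler` appears to hold one scheduler per virtual engine (a list) rather than a single object; Llumnix pins its delegation to the first virtual engine. Sketch of the assumed layout (not verified against a specific vLLM version):

    # Assumed vLLM engine layout in recent releases:
    #   engine.scheduler: List[Scheduler]   # one scheduler per virtual engine
    waiting = engine.scheduler[0].get_waiting_queue()  # was: engine.scheduler.get_waiting_queue()
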
14 changes: 7 additions & 7 deletions llumnix/backends/vllm/migration_backend.py
@@ -61,12 +61,12 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: List[CacheEn
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache
self.cache_device = "cpu"
-self.num_migration_cache_blocks = self.migration_config.migration_cache_blocks
+self.num_migration_buffer_blocks = self.migration_config.migration_buffer_blocks
self.num_layers = self.cache_engine[0].num_attention_layers
self.migration_cache_size = self.cache_engine[0].block_size * self.cache_engine[0].num_kv_heads * self.cache_engine[0].head_size

self.dummy_cache = torch.empty(
-size=(self.num_migration_cache_blocks, self.num_layers, 2, self.migration_cache_size),
+size=(self.num_migration_buffer_blocks, self.num_layers, 2, self.migration_cache_size),
dtype=self.cache_engine[0].dtype,
device=self.cache_device,
pin_memory=True
@@ -162,7 +162,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: List[CacheEn
self.cache_engine = cache_engine
self.backend = migration_config.migration_backend
self.migration_num_layers = min(migration_config.migration_num_layers, self.cache_engine[0].num_attention_layers)
-self.num_migration_cache_blocks = migration_config.migration_cache_blocks
+self.num_migration_buffer_blocks = migration_config.migration_buffer_blocks

self.backend = migration_config.migration_backend
self.global_world_size = -1
@@ -184,7 +184,7 @@ def __init__(self, migration_config: MigrationConfig, cache_engine: List[CacheEn

pin_memory = (self.backend == 'gloo')
self.dummy_cache = torch.empty(
-size=(self.num_migration_cache_blocks, self.migration_num_layers, 2, self.migration_cache_size),
+size=(self.num_migration_buffer_blocks, self.migration_num_layers, 2, self.migration_cache_size),
dtype=self.cache_engine[0].dtype,
device=self.cache_device,
pin_memory=pin_memory
@@ -297,10 +297,10 @@ def do_recv(self, src_handle, blocks: List[int], virtuel_engine: int=0):

def get_migration_backend(migration_config: MigrationConfig, cache_engine: List[CacheEngine], worker_handle_list, scheduling_strategy,
is_driver_worker, gpu_cache, worker_rank, local_rank) -> MigrationBackendBase:
-if cache_engine[0].num_gpu_blocks < migration_config.migration_cache_blocks:
+if cache_engine[0].num_gpu_blocks < migration_config.migration_buffer_blocks:
logger.warning("migration_cache_blocks({}) is larger than num_gpu_blocks({}), reducing it to num_gpu_blocks."
-.format(migration_config.migration_cache_blocks, cache_engine.num_gpu_blocks))
-migration_config.migration_cache_blocks = cache_engine[0].num_gpu_blocks
+.format(migration_config.migration_buffer_blocks, cache_engine[0].num_gpu_blocks))
+migration_config.migration_buffer_blocks = cache_engine[0].num_gpu_blocks

target_migration_backend = None
backend = migration_config.migration_backend
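Since the staging buffer is a dense tensor of shape (buffer blocks, layers, K/V, elements per block), its footprint is easy to estimate. A back-of-the-envelope sketch with purely illustrative numbers, none taken from this repo:

    # Illustrative config: block_size=16, num_kv_heads=8, head_size=128,
    # num_attention_layers=32, fp16 (2 bytes/element), 512 buffer blocks.
    migration_cache_size = 16 * 8 * 128          # elements per block, per layer, per K or V
    num_elements = 512 * 32 * 2 * migration_cache_size
    print(num_elements * 2 / 2**30, "GiB")       # -> 1.0 GiB of pinned host memory
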
13 changes: 8 additions & 5 deletions llumnix/backends/vllm/scheduler.py
@@ -97,16 +97,20 @@ def get_all_request_ids(self) -> List[str]:
request_ids.append(seq_group.request_id)
return request_ids

-def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> List[int]:
+def get_request_incremental_blocks(self, backend_request: LlumnixRequest, pre_stage_num_blocks: int) -> Tuple[List[int], List[int]]:
seq = backend_request.get_seqs()[0]
blocks = self.block_manager.get_block_table(seq)
-return blocks[pre_stage_num_blocks:]
+block_table = self.block_manager.block_tables[seq.seq_id]
+token_ids = backend_request.token_ids
+return blocks[pre_stage_num_blocks:], token_ids[pre_stage_num_blocks * self.block_manager.block_size:block_table.num_full_slots]

def remove_running_request(self, request_id: str) -> bool:
-for seq_group in self.running:
+for seq_group in reversed(self.running):
if seq_group.request_id == request_id:
self.running.remove(seq_group)
seq_group.set_status(RequestStatus.RUNNING_MIGRATING)
logger.info(f"remove running req {request_id}")
logger.info(f"len:{len(self.running)}")
return True
return False

@@ -138,8 +142,7 @@ def pre_alloc(self,
# Only migrate waiting request when the waiting request is the earliest arrival one
# among the requests of dst instance's waiting queue.
if request_status == RequestStatus.WAITING_MIGRATING:
-if (self.waiting and request_arrival_time > self.waiting[0].arrival_time) \
-or block_num * self.cache_config.block_size > self.prompt_limit:
+if self.waiting and request_arrival_time > self.waiting[0].arrival_time:
return []
block_table = self.pre_alloc_cache_dict.get(request_id, None)
if not block_table:
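The returned token-id range mirrors the block range: token positions advance in units of `block_size`, and `num_full_slots` truncates at the last completely filled slot. A worked example with made-up values:

    block_size = 4
    blocks = [7, 2, 9]                    # physical block ids of the request
    token_ids = list(range(12))           # 3 full blocks * 4 tokens each
    pre_stage_num_blocks = 1              # one block already migrated last stage
    num_full_slots = 12
    incremental_blocks = blocks[pre_stage_num_blocks:]                               # [2, 9]
    incremental_tokens = token_ids[pre_stage_num_blocks * block_size:num_full_slots]  # tokens 4..11
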
10 changes: 7 additions & 3 deletions llumnix/backends/vllm/sequence.py
@@ -21,6 +21,10 @@ def __init__(self, request_id, server_info, expected_steps: int, *args, **kwargs
SequenceGroup.__init__(self, request_id, *args, **kwargs)
LlumnixRequest.__init__(self, request_id, server_info, expected_steps)

+@property
+def block_size(self) -> int:
+return self.get_seqs()[0].block_size

@property
def prompt_len(self) -> int:
return self.get_seqs()[0].get_prompt_len()
@@ -54,8 +58,8 @@ def finished(self) -> bool:
return self.get_seqs()[0].is_finished()

@property
-def arrival_time(self) -> float:
-return self.metrics.arrival_time
+def request_arrival_time(self) -> float:
+return self.arrival_time

@property
def status(self) -> RequestStatus:
@@ -73,4 +77,4 @@ def status(self) -> RequestStatus:
@property
def prefill_num_blocks(self) -> int:
# Get the prefill len of the waiting request.
-return len(self.get_seqs()[0].logical_token_blocks)
+return self.get_seqs()[0].n_blocks
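The `arrival_time` → `request_arrival_time` rename is presumably about avoiding a collision: `SequenceGroupLlumnix` inherits from both vLLM's `SequenceGroup`, which already defines `arrival_time`, and `LlumnixRequest`, so a Llumnix property of the same name would shadow the vLLM attribute it needs to read. A condensed sketch of the resulting pattern (class bodies elided):

    class SequenceGroupLlumnix(SequenceGroup, LlumnixRequest):
        @property
        def request_arrival_time(self) -> float:
            # Delegates to SequenceGroup.arrival_time instead of shadowing it.
            return self.arrival_time
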
25 changes: 16 additions & 9 deletions llumnix/llumlet/migration_coordinator.py
@@ -12,6 +12,7 @@
# limitations under the License.

import time
+import traceback
import enum
from typing import List

@@ -71,8 +72,9 @@ async def migrate_out_waiting_request(self,
dst_blocks = await migrate_in_ray_actor.execute_migration_method \
.remote("migrate_in_pre_alloc", migrate_out_request.request_id,
migrate_out_request.status,
-migrate_out_request.arrival_time,
-migrate_out_request.prefill_num_blocks)
+migrate_out_request.request_arrival_time,
+migrate_out_request.prefill_num_blocks,
+migrate_out_request.token_ids)
if len(dst_blocks) != migrate_out_request.prefill_num_blocks:
self.backend_engine.add_waiting_request(migrate_out_request)
self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request)
@@ -115,18 +117,20 @@ async def _migrate_out_onestage(self,
return MigrationStatus.ABORTED_SRC

pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list)
-incremental_blocks = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks)
+incremental_blocks, incremental_token_ids = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks)
# live migration, transfer all blocks except last one(currently updating)
is_last_stage = (len(incremental_blocks) <= self.last_stage_max_blocks) or migrate_out_request.blocking_migration
if not is_last_stage:
migration_status = MigrationStatus.RUNNING
src_blocks = incremental_blocks[:-1]
+incremental_token_ids = incremental_token_ids[:len(src_blocks)*migrate_out_request.block_size]
stage_block_num = len(incremental_blocks) - 1
dst_blocks = await migrate_in_ray_actor.execute_migration_method \
.remote("migrate_in_pre_alloc", migrate_out_request.request_id,
migrate_out_request.status,
-migrate_out_request.arrival_time,
-stage_block_num)
+migrate_out_request.request_arrival_time,
+stage_block_num,
+incremental_token_ids)
else:
# last stage migration, stop inference, transfer all blocks
migration_status = MigrationStatus.FINISHED
Expand All @@ -139,8 +143,9 @@ async def _migrate_out_onestage(self,
dst_blocks = await migrate_in_ray_actor.execute_migration_method \
.remote("migrate_in_pre_alloc", migrate_out_request.request_id,
migrate_out_request.status,
-migrate_out_request.arrival_time,
-stage_block_num)
+migrate_out_request.request_arrival_time,
+stage_block_num,
+incremental_token_ids)

if len(dst_blocks) != len(src_blocks):
# migrate-in instance failed to pre alloc
@@ -172,13 +177,15 @@ def migrate_in_pre_alloc(self,
request_id: str,
request_status: RequestStatus,
request_arrival_time: float,
-block_num: int) -> List[int]:
+block_num: int,
+token_ids: List[int]) -> List[int]:
"""prev alloc blocks to migrate in request
"""
pre_alloc_blocks = self.backend_engine.pre_alloc(request_id,
request_status,
request_arrival_time,
-block_num)
+block_num,
+token_ids)
if len(pre_alloc_blocks) != block_num:
# failed to alloc, abort request
self.free_dst_pre_alloc_cache(request_id)
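Read end to end, `_migrate_out_onestage` implements iterative pre-copy: each stage ships the blocks accumulated since the previous stage while inference continues, holding back the final (still-mutating) block, and only the last stage runs with the request paused. A condensed sketch under the interfaces above (error handling and abort checks omitted):

    # Condensed sketch; names follow the methods above, control flow simplified.
    async def migrate_out(backend_engine, migrate_in_ray_actor, request, last_stage_max_blocks):
        while True:
            sent = sum(request.stage_num_blocks_list)
            blocks, token_ids = backend_engine.get_request_incremental_blocks(request, sent)
            last_stage = len(blocks) <= last_stage_max_blocks or request.blocking_migration
            if not last_stage:
                blocks = blocks[:-1]                                   # final block is still being written
                token_ids = token_ids[:len(blocks) * request.block_size]
            dst_blocks = await migrate_in_ray_actor.execute_migration_method.remote(
                "migrate_in_pre_alloc", request.request_id, request.status,
                request.request_arrival_time, len(blocks), token_ids)
            if len(dst_blocks) != len(blocks):
                return  # destination could not pre-allocate; abort and roll back
            # (the actual KV block transfer into dst_blocks happens here in the real code)
            if last_stage:
                return  # remaining blocks sent with inference stopped
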
6 changes: 5 additions & 1 deletion llumnix/llumlet/request.py
@@ -92,7 +92,7 @@ def finished(self) -> bool:
raise NotImplementedError

@property
-def arrival_time(self) -> float:
+def request_arrival_time(self) -> float:
raise NotImplementedError

@property
@@ -111,6 +111,10 @@ def n_blocks(self) -> int:
def token_ids(self) -> int:
raise NotImplementedError

+@property
+def block_size(self) -> int:
+raise NotImplementedError

# Whether the migration of request is completed within one stage. For requests that have already reached
# the expected steps, blocking_migration is True.
@property
7 changes: 4 additions & 3 deletions tests/unit_test/backends/vllm/test_migration.py
@@ -58,26 +58,27 @@ class MockLlumletDoNotSchedule(Llumlet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# stop the schedule in engine step loop
-self.backend_engine.engine.scheduler.schedule = MagicMock()
+self.backend_engine.engine.scheduler[0].schedule = MagicMock()

# For some reason, if MockScheduelrOutputs is defined outside, the constructor would raise error.
class MockScheduelrOutputs:
def __init__(self):
self.scheduled_seq_groups = []
self.ignored_seq_groups = []
self.num_batched_tokens = 0
+self.preempted = False

def is_empty(self) -> bool:
return not self.scheduled_seq_groups

scheduler_outputs = MockScheduelrOutputs()
-self.backend_engine.engine.scheduler.schedule.return_value = ([], scheduler_outputs)
+self.backend_engine.engine.scheduler[0].schedule.return_value = ([], scheduler_outputs, False)

self.step_async = self.backend_engine.engine.step_async

async def step_async_try_schedule():
request_outputs, server_infos = await self.step_async()
-for seq_group in self.backend_engine.engine.scheduler.waiting:
+for seq_group in self.backend_engine.engine.scheduler[0].waiting:
seq_group.try_schedule_times += 1
return request_outputs, server_infos

5 changes: 3 additions & 2 deletions tests/unit_test/backends/vllm/test_scheduler.py
@@ -213,8 +213,8 @@ def test_schedule_running():
def test_try_schedule_times():
# total 8 blocks
scheduler = initialize_scheduler()
_, seq_group_1 = create_dummy_prompt("1", prompt_length=8, block_size=1)
_, seq_group_2 = create_dummy_prompt("2", prompt_length=8, block_size=1)
_, seq_group_1 = create_dummy_prompt("1", prompt_length=32, block_size=4)
_, seq_group_2 = create_dummy_prompt("2", prompt_length=32, block_size=4)
scheduler.add_seq_group(seq_group_1)
scheduler.add_seq_group(seq_group_2)
waiting_queue = scheduler.get_waiting_queue()
@@ -225,6 +225,7 @@ def test_try_schedule_times():
# seq_group_2 cannot be scheduled due to lack of blocks
assert seq_group_1.try_schedule_times == 0
assert seq_group_2.try_schedule_times == 1
+append_new_token_seq_group(1, seq_group_1, 1)
scheduler.schedule()
# seq_group_1 is preempted to waiting queue
assert seq_group_1.try_schedule_times == 1
9 changes: 5 additions & 4 deletions tests/unit_test/backends/vllm/test_simulator.py
@@ -42,11 +42,12 @@ async def test_executor():
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
lora_config=engine_config.lora_config,
-vision_language_config=engine_config.vision_language_config,
speculative_config=engine_config.speculative_config,
-load_config=engine_config.load_config)
+load_config=engine_config.load_config,
+prompt_adapter_config=engine_config.prompt_adapter_config,
+observability_config=engine_config.observability_config)
scheduler = initialize_scheduler()
-scheduler.schedule()
+metas, out, _ = scheduler.schedule()
_, seq_group_0 = create_dummy_prompt(
"0", prompt_length=7, block_size=4
)
Expand All @@ -55,7 +56,7 @@ async def test_executor():
)
scheduler.add_seq_group(seq_group_0)
scheduler.add_seq_group(seq_group_1)
-metas, out = scheduler.schedule()
+metas, out, _ = scheduler.schedule()
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=metas,
blocks_to_swap_in=out.blocks_to_swap_in,
2 changes: 1 addition & 1 deletion tests/unit_test/llumlet/test_local_migration_scheduler.py
@@ -46,7 +46,7 @@ def output_len(self) -> int:
return self.length

@property
-def arrival_time(self) -> float:
+def request_arrival_time(self) -> float:
pass

@property
