[Refactor] Asynchronous llumlet (#56)
ZeldaHuang authored Oct 16, 2024
1 parent e2fe3e2 commit 28d2743
Showing 12 changed files with 118 additions and 168 deletions.
4 changes: 2 additions & 2 deletions llumnix/backends/backend_interface.py
@@ -68,7 +68,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
raise NotImplementedError

@abstractmethod
def _start_engine_step_loop(self) -> None:
async def _start_engine_step_loop(self) -> None:
"""Start step loop of backend engine.
"""
raise NotImplementedError
@@ -244,7 +244,7 @@ def free_src_request(self, backend_request: LlumnixRequest) -> None:
raise NotImplementedError

@abstractmethod
def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]):
async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]):
"""
Sends cache blocks from the source instance to the destination instance.
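Note: both interface methods above are now coroutines, so every backend has to implement them with async def and drive them from an event loop rather than a thread. A minimal sketch of a conforming backend (the DummyBackend class and its recv_blocks remote call are hypothetical, not part of this commit):

import asyncio
from typing import List

class DummyBackend:
    """Hypothetical backend illustrating the async interface signatures."""

    def __init__(self) -> None:
        self._stop_event = asyncio.Event()

    async def _start_engine_step_loop(self) -> None:
        # The step loop runs as a coroutine on the actor's event loop
        # instead of a dedicated daemon thread.
        while not self._stop_event.is_set():
            await asyncio.sleep(0.01)  # placeholder for one engine step

    async def send_blocks(self, dst_ray_actor, src_blocks: List[int], dst_blocks: List[int]) -> None:
        # Awaiting the Ray ObjectRef keeps the event loop responsive while
        # the destination instance receives the KV-cache blocks.
        await dst_ray_actor.recv_blocks.remote(src_blocks, dst_blocks)
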
26 changes: 14 additions & 12 deletions llumnix/backends/vllm/executor.py
@@ -12,16 +12,18 @@
# limitations under the License.

import time
import asyncio

from collections import defaultdict
from typing import List, Optional, Tuple
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy
# pylint: disable=unused-import
from ray.util.placement_group import PlacementGroup

from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayWorkerWrapper, get_distributed_init_method,\
get_ip, get_vllm_instance_id, get_open_port
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync, RayWorkerWrapper,\
get_distributed_init_method, get_ip, get_vllm_instance_id, get_open_port

from vllm import envs
from vllm.sequence import Logprob, SequenceOutput, SequenceGroupOutput, SamplerOutput, ExecuteModelRequest
@@ -34,7 +36,7 @@

logger = init_logger(__name__)

class LlumnixRayGPUExecutor(RayGPUExecutor):
class LlumnixRayGPUExecutor(RayGPUExecutorAsync):
node_id: str = None
migration_config: MigrationConfig = None

@@ -157,17 +159,17 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
cache_config=self.cache_config,
parallel_config=self.parallel_config)

def execute_model(self, *args, **kwargs):
async def execute_model_async(self, *args, **kwargs):
t0 = time.time()
outputs = super().execute_model(*args, **kwargs)
outputs = await super().execute_model_async(*args, **kwargs)
t1 = time.time()
self.last_inference_latency = (t1 - t0) * 1000
return outputs

class SimGPUExecutor(GPUExecutor):
class SimGPUExecutor(RayGPUExecutor):
latency_mem: LatencyMemData = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
RayGPUExecutor.__init__(self, *args, **kwargs)
self.last_inference_latency = 0
self.migration_bandwidth = self.latency_mem.migration_bandwidth
# TODO(ZeldaHuang): add swap bandwidth
@@ -191,7 +193,7 @@ def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks)

def execute_model(
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
prefill_seq_len = 0
@@ -213,7 +215,7 @@ def execute_model(
decode_meta_data = (decode_bs, decode_seq_len)
latency += self.latency_mem.decode_latency[decode_meta_data][0] if decode_meta_data in self.latency_mem.decode_latency \
else model_decode((decode_bs, decode_seq_len), *self.latency_mem.decode_model_params)
time.sleep(latency/1000)
await asyncio.sleep(latency/1000)
sampler_outputs = []
for meta_data in execute_model_req.seq_group_metadata_list:
samples = []
@@ -225,6 +227,6 @@
sampler_outputs.append(output)
return [SamplerOutput(outputs=sampler_outputs)]

def send_blocks(self, blocks_len) -> None:
async def send_blocks(self, blocks_len) -> None:
migration_latency = (self.cache_block_size * blocks_len) / self.migration_bandwidth
time.sleep(migration_latency)
await asyncio.sleep(migration_latency)
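Note: SimGPUExecutor above now simulates model latency with await asyncio.sleep instead of time.sleep, so a simulated step no longer blocks the actor's event loop. A small self-contained illustration of why that matters (hypothetical, not from the repository):

import asyncio

async def simulated_step(latency_ms: float) -> None:
    # Yields to the event loop for the simulated inference duration,
    # unlike a blocking time.sleep call.
    await asyncio.sleep(latency_ms / 1000)
    print("simulated step done")

async def concurrent_migration() -> None:
    # Makes progress while simulated_step is sleeping; a blocking sleep
    # on the same loop would have stalled this coroutine as well.
    await asyncio.sleep(0.001)
    print("migration RPC finished")

async def main() -> None:
    await asyncio.gather(simulated_step(50), concurrent_migration())

if __name__ == "__main__":
    asyncio.run(main())
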
107 changes: 49 additions & 58 deletions llumnix/backends/vllm/llm_engine.py
@@ -22,7 +22,7 @@
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy

from vllm.engine.llm_engine import LLMEngine
from vllm.engine.async_llm_engine import _AsyncLLMEngine
from vllm.core.scheduler import ScheduledSequenceGroup
from vllm.outputs import RequestOutput
from vllm.sequence import SequenceGroup, SequenceStatus, SamplerOutput, SequenceGroupMetadata
@@ -82,7 +82,7 @@ async def put_nowait_to_servers(self,
logger.error("exception traceback: {}".format(traceback.format_exc()))


class LLMEngineLlumnix(LLMEngine):
class LLMEngineLlumnix(_AsyncLLMEngine):
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
@@ -171,38 +171,37 @@ def _process_model_outputs(
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[List[RequestOutput], List[ServerInfo]]:
# ensure scheduled_seq_groups matching output
with self.scheduler.scheduler_lock:
server_infos = []
if output:
new_output = []
new_scheduled_seq_groups = []
new_seq_group_metadata_list = []
for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs):
seq_group = scheduled_seq_group.seq_group
if seq_group.get_seqs(SequenceStatus.RUNNING):
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
server_infos.append(seq_group.server_info)
scheduled_seq_groups = new_scheduled_seq_groups
output[0].outputs = new_output
seq_group_metadata_list = new_seq_group_metadata_list
for ignored_seq_group in ignored_seq_groups:
server_infos.append(ignored_seq_group.server_info)
for server_info in server_infos:
if hasattr(server_info, 'request_timestamps'):
server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time()
request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
for request_output, server_info in zip(request_outputs, server_infos):
if hasattr(server_info, 'request_timestamps'):
request_output.request_timestamps = server_info.request_timestamps
request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time()
# TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
return request_outputs, server_infos

def step(self) -> None:
server_infos = []
if output:
new_output = []
new_scheduled_seq_groups = []
new_seq_group_metadata_list = []
for scheduled_seq_group, seq_group_meta, seq_group_output in zip(scheduled_seq_groups, seq_group_metadata_list, output[0].outputs):
seq_group = scheduled_seq_group.seq_group
if seq_group.get_seqs(SequenceStatus.RUNNING):
new_scheduled_seq_groups.append(scheduled_seq_group)
new_seq_group_metadata_list.append(seq_group_meta)
new_output.append(seq_group_output)
server_infos.append(seq_group.server_info)
scheduled_seq_groups = new_scheduled_seq_groups
output[0].outputs = new_output
seq_group_metadata_list = new_seq_group_metadata_list
for ignored_seq_group in ignored_seq_groups:
server_infos.append(ignored_seq_group.server_info)
for server_info in server_infos:
if hasattr(server_info, 'request_timestamps'):
server_info.request_timestamps.engine_process_model_outputs_timestamp_begin = time.time()
request_outputs = super()._process_model_outputs(output, scheduled_seq_groups, ignored_seq_groups, seq_group_metadata_list)
for request_output, server_info in zip(request_outputs, server_infos):
if hasattr(server_info, 'request_timestamps'):
request_output.request_timestamps = server_info.request_timestamps
request_output.request_timestamps.engine_process_model_outputs_timestamp_end = time.time()
# TODO(ZeldaHuang): Use LlumnixRequestOutput to store llumnix output args.
return request_outputs, server_infos

async def step_async(self) -> None:
step_begin_time = time.time()
request_outputs, server_infos = super().step()
request_outputs, server_infos = await super().step_async()
for request_output in request_outputs:
if hasattr(request_output, 'request_timestamps'):
request_output.request_timestamps.engine_step_timestamp_begin = step_begin_time
@@ -251,7 +250,6 @@ def add_request(self, request_id: str, server_info: ServerInfo, expected_steps:
self.scheduler.waiting[-1] = SequenceGroupLlumnix(request_id, server_info, expected_steps, [seq_group.get_seqs()[0]],
seq_group.sampling_params, seq_group.metrics.arrival_time, seq_group.lora_request,
seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _start_put_queue_loop(self):
while True:
@@ -301,45 +299,38 @@ def __init__(
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group, node_id=node_id)

self.state_lock = threading.Lock()
self.state = EngineState.INIT
logger.info("engine ({}) current state {}".format(self.instance_id, self.state))

self._stop_event = threading.Event()
self.engine_step_loop_thread = threading.Thread(
target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
)
self.engine_step_loop_thread.start()
self._stop_event = asyncio.Event()
asyncio.create_task(self._start_engine_step_loop())

def _start_engine_step_loop(self) -> None:
async def _start_engine_step_loop(self) -> None:
self._stop_event.clear()

with self.state_lock:
previous_state = self.state
self.state = EngineState.RUNNING
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
previous_state = self.state
self.state = EngineState.RUNNING
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))

while not self._stop_event.is_set():
try:
request_outputs, _ = self.engine.step()
request_outputs, _ = await self.engine.step_async()
if len(request_outputs) == 0:
time.sleep(0.01)
await asyncio.sleep(0.01)
# pylint: disable=broad-except
except Exception as e:
logger.error("Error in engine loop: {}".format(e))
logger.error("exception traceback: {}".format(traceback.format_exc()))
self._run_workers("shutdown")

with self.state_lock:
previous_state = self.state
self.state = EngineState.CRASHED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
previous_state = self.state
self.state = EngineState.CRASHED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, previous_state, self.state))
break

with self.state_lock:
if self.state == EngineState.RUNNING:
self.state = EngineState.STOPPED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))
if self.state == EngineState.RUNNING:
self.state = EngineState.STOPPED
logger.info("engine ({}) change state: {} -> {}".format(self.instance_id, EngineState.RUNNING, self.state))

def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
@@ -362,12 +353,12 @@ def commit_dst_request(self, backend_request: SequenceGroupLlumnix) -> None:
backend_request.reset_migration_args()
self.add_running_request(backend_request)

def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
ray.get(dst_ray_actor.execute_engine_method.remote("_run_workers",
async def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
await dst_ray_actor.execute_engine_method.remote("_run_workers",
"migrate_cache",
dst_blocks=dst_blocks,
src_blocks=src_blocks,
src_worker_handle_list=self.worker_handle_list))
src_worker_handle_list=self.worker_handle_list)

def _run_workers(self, *args, **kwargs):
# pylint: disable=protected-access
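Note: in the hunks above the engine step loop moves off a daemon thread guarded by threading.Event and state_lock and onto a plain asyncio task, and the blocking ray.get in send_blocks becomes a direct await on the remote call. Because everything now runs on one event loop, the state transitions no longer need a lock. A condensed sketch of the new loop pattern (the EngineState values and step/back-off logic mirror the diff; the AsyncLlumletSketch class itself is hypothetical):

import asyncio
import enum

class EngineState(enum.Enum):
    INIT = "init"
    RUNNING = "running"
    STOPPED = "stopped"
    CRASHED = "crashed"

class AsyncLlumletSketch:
    def __init__(self, engine) -> None:
        self.engine = engine
        self.state = EngineState.INIT
        self._stop_event = asyncio.Event()
        # Requires a running event loop, e.g. inside an async Ray actor.
        asyncio.create_task(self._start_engine_step_loop())

    async def _start_engine_step_loop(self) -> None:
        self.state = EngineState.RUNNING
        while not self._stop_event.is_set():
            try:
                request_outputs, _ = await self.engine.step_async()
                if not request_outputs:
                    await asyncio.sleep(0.01)  # back off when idle
            # pylint: disable=broad-except
            except Exception:
                self.state = EngineState.CRASHED
                break
        if self.state == EngineState.RUNNING:
            self.state = EngineState.STOPPED
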
23 changes: 0 additions & 23 deletions llumnix/backends/vllm/scheduler.py
@@ -13,7 +13,6 @@

from asyncio.log import logger
import time
import threading
from typing import Dict, List, Optional, Tuple
from collections import deque

@@ -23,7 +22,6 @@
from llumnix.instance_info import InstanceInfo
from llumnix.logger import init_logger
from llumnix.llumlet.request import RequestInferenceType
from llumnix.backends.vllm.utils import scheduler_lock
from llumnix.backends.vllm.sequence import SequenceGroupLlumnix

logger = init_logger(__name__)
@@ -56,7 +54,6 @@ def __init__(self, *args, **kwargs) -> None:
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
self.pre_alloc_cache_dict: Dict[str, BlockTable] = {}
self.scheduler_lock = threading.Lock()
self.migrating_out_request_last_stage: List[SequenceGroupLlumnix] = []

def add_update_instance_info_callback(self, update_instance_info_callback):
@@ -79,25 +76,21 @@ def _get_num_killed_requests(self) -> int:
cnt += 1
return cnt

@scheduler_lock
def get_running_queue(self):
return self.running

@scheduler_lock
def get_all_request_ids(self) -> List[str]:
request_ids : List[str] = []
for state_queue in [self.waiting, self.running, self.swapped]:
for seq_group in state_queue:
request_ids.append(seq_group.request_id)
return request_ids

@scheduler_lock
def get_request_incremental_blocks(self, backend_request: SequenceGroupLlumnix, pre_stage_num_blocks: int) -> List[int]:
seq = backend_request.get_seqs()[0]
blocks = self.block_manager.get_block_table(seq)
return blocks[pre_stage_num_blocks:]

@scheduler_lock
def remove_running_request(self, request_id: str) -> None:
for seq_group in self.running:
if seq_group.request_id == request_id:
@@ -117,7 +110,6 @@ def pop_migrating_out_requests_last_stage(self) -> List[SequenceGroupLlumnix]:
self.migrating_out_request_last_stage.clear()
return migrating_out_request_last_stage

@scheduler_lock
def pre_alloc(self, request_id: str, block_num: int) -> List[int]:
blocks = self.block_manager.get_free_blocks(block_num)
pre_blocks = self.pre_alloc_cache_dict.get(request_id, [])
@@ -126,17 +118,14 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
blocks = [block.block_number for block in blocks]
return blocks

@scheduler_lock
def add_running_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
self.running.append(backend_request)

@scheduler_lock
def is_request_running(self, backend_request: SequenceGroupLlumnix) -> bool:
return backend_request in self.running

@scheduler_lock
def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
if request_id:
blocks = self.pre_alloc_cache_dict.pop(request_id, [])
@@ -150,7 +139,6 @@ def free_dst_pre_alloc_cache(self, request_id: str = None) -> None:
# pylint: disable=protected-access
self.block_manager._free_block_table(blocks)

@scheduler_lock
def free_src_request(self, backend_request: SequenceGroupLlumnix) -> None:
seq = backend_request.get_seqs()[0]
logger.info("free seq {}".format(seq.seq_id))
@@ -201,7 +189,6 @@ def _get_instance_info(self, scheduled_seq_groups: List[SequenceGroupLlumnix]) -
instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()]
return instance_info

@scheduler_lock
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
seq_group_metadata_list, scheduler_outputs = super().schedule()
self.update_instance_info_callback(self._get_instance_info([scheduled_seq_group.seq_group \
@@ -220,13 +207,3 @@ def _schedule_running(self, running_queue: deque, *args, **kwargs):
for seq_group in remove_running:
remaining_running.extend([seq_group])
return remaining_running, running_scheduled

def add_seq_group(self, *args, **kwargs):
# The scheduler lock is mannually released in the end of LLMEngineLlumnix.add_request function.
# pylint: disable=R1732
self.scheduler_lock.acquire()
return super().add_seq_group(*args, **kwargs)

@scheduler_lock
def abort_seq_group(self, *args, **kwargs):
return super().abort_seq_group(*args, **kwargs)
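Note: every @scheduler_lock decorator, the scheduler_lock attribute, and the manual acquire in add_seq_group (released at the end of LLMEngineLlumnix.add_request) are dropped above because the scheduler is now touched only from a single asyncio event loop: coroutines interleave solely at await points, so a synchronous scheduler method always runs to completion atomically. A hypothetical illustration of that guarantee (not from the repository):

import asyncio

class TinyScheduler:
    """Hypothetical stand-in for the llumnix scheduler."""

    def __init__(self) -> None:
        self.running = []

    def add_running_request(self, request_id: str) -> None:
        # No lock needed: the method contains no await, so no other
        # coroutine can observe a half-updated running list.
        self.running.append(request_id)

async def add_from_coroutine(scheduler: TinyScheduler, request_id: str) -> None:
    scheduler.add_running_request(request_id)
    await asyncio.sleep(0)  # the only point where control can switch

async def main() -> None:
    scheduler = TinyScheduler()
    await asyncio.gather(*(add_from_coroutine(scheduler, f"req-{i}") for i in range(3)))
    print(scheduler.running)  # ['req-0', 'req-1', 'req-2']

if __name__ == "__main__":
    asyncio.run(main())
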