Use thread to overlap the put queue remote call overhead
s5u13b committed Oct 9, 2024
1 parent abe5e45 commit 3649b02
Showing 3 changed files with 56 additions and 20 deletions.
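
The core change is a producer/consumer handoff: the engine step no longer calls the output queue actor directly, it only appends (request_outputs, server_infos) to a local queue.Queue, and a daemon thread drains that queue and issues the remote put, so the cross-actor latency overlaps with subsequent engine steps. Below is a minimal, self-contained sketch of that pattern; the engine step and the remote call are simulated stand-ins, not the repository's actual API.

import queue
import threading
import time

put_queue_args_queue = queue.Queue()

def simulated_remote_put(request_outputs, server_infos):
    # Stand-in for the AsyncPutQueueActor remote call; pretend it costs ~5 ms
    # of cross-actor / network overhead.
    time.sleep(0.005)

def put_queue_loop():
    # Consumer thread, mirroring _start_put_queue_loop: block on the local
    # queue and forward each batch to the (slow) remote put.
    while True:
        request_outputs, server_infos = put_queue_args_queue.get()
        simulated_remote_put(request_outputs, server_infos)

threading.Thread(target=put_queue_loop, daemon=True, name="put_queue_loop").start()

def engine_step(i):
    # Stand-in for LLMEngine.step(); every other step produces no output.
    return (["output-%d" % i], ["server-0"]) if i % 2 == 0 else ([], [])

def engine_step_loop(num_steps=20):
    # Producer, mirroring _start_engine_step_loop: enqueue locally and keep
    # stepping; back off briefly when a step produced nothing.
    for i in range(num_steps):
        request_outputs, server_infos = engine_step(i)
        if request_outputs:
            put_queue_args_queue.put((request_outputs, server_infos))
        else:
            time.sleep(0.01)

engine_step_loop()
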
2 changes: 1 addition & 1 deletion llumnix/backends/backend_interface.py
@@ -58,7 +58,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
raise NotImplementedError

@abstractmethod
def _start_engine_loop(self) -> None:
def _start_engine_step_loop(self) -> None:
"""Start step loop of backend engine.
"""
raise NotImplementedError
57 changes: 43 additions & 14 deletions llumnix/backends/vllm/llm_engine.py
@@ -16,6 +16,7 @@
from collections import defaultdict
import threading
import asyncio
import queue
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
@@ -70,20 +71,35 @@ async def put_nowait_batch_to_servers(self,


class LLMEngineLlumnix(LLMEngine):
def __init__(self, instance_id: str, *arg, **kwargs) -> None:
def __init__(self,
instance_id: str,
placement_group: Optional[PlacementGroup],
node_id: Optional[str],
*arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.instance_id = instance_id
self.step_counter = Counter()
self.instance_info = None
# TODO(s5u13b): Reduce the cross-actor overhead.
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
soft=False
# Place the async put queue actor together with the instance.
if placement_group:
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
)
else:
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
)
self.put_queue_args_queue = queue.Queue()
self.put_queue_loop_thread = threading.Thread(
target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop"
)
self.async_put_queue_actor = ray.remote(
num_cpus=1,
scheduling_strategy=scheduling_strategy
)(AsyncPutQueueActor).remote(instance_id)
self.put_queue_loop_thread.start()

# pylint: disable=W0221
@classmethod
@@ -94,7 +110,7 @@ def from_engine_args(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
instance_id: str = None,
placement_group: Optional[PlacementGroup] = None,
node_id: str = None,
node_id: Optional[str] = None,
latency_mem: Optional[LatencyMemData] = None
) -> "LLMEngineLlumnix":
"""Creates an LLM engine from the engine arguments."""
@@ -118,6 +134,8 @@
# Create the LLM engine.
engine = cls(
instance_id=instance_id,
placement_group=placement_group,
node_id=node_id,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
@@ -171,10 +189,12 @@ def step(self) -> None:
tot_blocks.extend(blocks)
tot_blocks = set(tot_blocks)
instance_info.num_blocks_last_running_request = len(tot_blocks)

if request_outputs:
self._put_request_outputs_to_server(request_outputs, server_infos)
self.put_queue_args_queue.put((request_outputs, server_infos))
self.instance_info = instance_info
num_request_outputs = len(request_outputs)

return num_request_outputs

def update_instance_info(self, instance_info: InstanceInfo) -> None:
# These fields are updated after step.
@@ -193,7 +213,13 @@ def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs)
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
def _start_put_queue_loop(self):
while True:
args = self.put_queue_args_queue.get()
request_outputs, server_infos = args
self._put_request_outputs_to_server(request_outputs, server_infos)

def _put_request_outputs_to_server(self, request_outputs: List[RequestOutput], server_infos: List[ServerInfo]) -> None:
server_request_outputs = defaultdict(list)
server_info_dict = {}
# Reorganize data in order to put request outputs to the queue in a single batch.
@@ -202,6 +228,7 @@ def _put_request_outputs_to_server(self, request_outputs, server_infos: List[Ser
server_request_outputs[server_id].append(request_output)
if server_id not in server_info_dict:
server_info_dict[server_id] = server_info
# TODO(s5u13b): Reduce the cross-actor overhead.
self.async_put_queue_actor.put_nowait_batch_to_servers.remote(server_request_outputs, server_info_dict)

class BackendVLLM(BackendInterface):
@@ -228,14 +255,16 @@ def __init__(
self._run_workers("init_migration", instance_id=instance_id, migration_config=migration_config,\
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group, node_id=node_id)
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
self.engine_step_loop_thread = threading.Thread(
target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
)
self._thread.start()
self.engine_step_loop_thread.start()

def _start_engine_loop(self) -> None:
def _start_engine_step_loop(self) -> None:
while True:
self.engine.step()
num_request_outputs = self.engine.step()
if num_request_outputs == 0:
time.sleep(0.01)

def execute_worker_method(self, method, *args, **kwargs):
return self.engine.model_executor.driver_worker.execute_method(method, *args, **kwargs)
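
The constructor also decides where the AsyncPutQueueActor runs: when a placement group is available, the actor is scheduled into the same group as the instance (capturing child tasks and actors), otherwise it is pinned to the given node with hard affinity. The following is a hedged, standalone sketch of that colocation logic using Ray's public scheduling strategies; the actor body and the create_put_queue_actor helper are illustrative placeholders, not the repository's implementation.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import (
    NodeAffinitySchedulingStrategy,
    PlacementGroupSchedulingStrategy,
)

@ray.remote
class AsyncPutQueueActor:
    # Placeholder actor; the real one forwards batches to each server's queue.
    def __init__(self, instance_id):
        self.instance_id = instance_id

    def put_nowait_batch_to_servers(self, server_request_outputs, server_info_dict):
        return len(server_request_outputs)

def create_put_queue_actor(instance_id, pg=None, node_id=None):
    # Mirror of the scheduling decision in LLMEngineLlumnix.__init__: prefer the
    # instance's placement group, otherwise pin the actor to the caller's node.
    if pg is not None:
        strategy = PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
        )
    else:
        strategy = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False)
    return AsyncPutQueueActor.options(
        num_cpus=1, scheduling_strategy=strategy
    ).remote(instance_id)

ray.init()
pg = placement_group([{"CPU": 2}])
ray.get(pg.ready())
actor = create_put_queue_actor("instance_0", pg=pg)
print(ray.get(actor.put_nowait_batch_to_servers.remote({"server-0": ["out"]}, {})))
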
17 changes: 12 additions & 5 deletions llumnix/llumlet/llumlet.py
@@ -74,7 +74,9 @@ def from_args(cls,
lifetime=lifetime)(cls).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=0,))
placement_group_bundle_index=0,
)
)
else:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
@@ -84,16 +86,21 @@
lifetime=lifetime)(cls).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,))
soft=False,
)
)
else: # backend_type == backend_type.SIM_VLLM:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,))
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
)
)
llumlet = engine_class.remote(instance_id, backend_type, migration_config, *args, **kwargs)
return llumlet

