[Core] Optimize request output tokens putting back implementation to reduce overhead (#45)
s5u13b authored Oct 11, 2024
1 parent fc5ecee commit 6de3bd7
Showing 22 changed files with 173 additions and 104 deletions.
2 changes: 1 addition & 1 deletion configs/base.yml
@@ -4,7 +4,7 @@ SERVER:
QUEUE_TYPE: "rayqueue"

RAY:
RAY_CLUSTER_PORT: 30037
RAY_CLUSTER_PORT: 6379
LAUNCH_RAY_CLUSTER: True

MANAGER:
15 changes: 8 additions & 7 deletions examlpes/offline_inference.py
@@ -24,7 +24,7 @@
# Launch ray cluster
os.environ['HEAD_NODE'] = '1'
os.environ['HEAD_NODE_IP'] = '127.0.0.1'
ray_cluster_port=37000
ray_cluster_port=6379

# Note: launch_ray_cluster will stop current ray cluster first, then init a new one.
launch_ray_cluster(ray_cluster_port=ray_cluster_port)
@@ -58,12 +58,13 @@
async def background_process_outputs(num_tasks):
finish_task = 0
while finish_task != num_tasks:
request_output = await request_output_queue.get()
if request_output.finished:
finish_task += 1
prompt = request_output.prompt
generated_text = request_output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
request_outputs = await request_output_queue.get()
for request_output in request_outputs:
if request_output.finished:
finish_task += 1
prompt = request_output.prompt
generated_text = request_output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
request_output_queue.cleanup()

async def main():
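
The updated loop above consumes one batch of outputs per await instead of one RequestOutput at a time, which is where the overhead reduction in this commit comes from. A minimal sketch of the same consumer pattern, assuming an asyncio-style queue whose get() returns a list of vLLM RequestOutput objects; the function and variable names here are illustrative:

from typing import List

async def consume_output_batches(output_queue, num_tasks: int) -> None:
    # Each get() now yields a whole batch, so the number of queue
    # round-trips no longer grows with the number of generated tokens.
    finished = 0
    while finished != num_tasks:
        request_outputs: List = await output_queue.get()
        for request_output in request_outputs:
            if request_output.finished:
                finished += 1
                print(f"Prompt: {request_output.prompt!r}, "
                      f"Generated text: {request_output.outputs[0].text!r}")
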
2 changes: 1 addition & 1 deletion llumnix/backends/backend_interface.py
@@ -63,7 +63,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
raise NotImplementedError

@abstractmethod
def _start_engine_loop(self) -> None:
def _start_engine_step_loop(self) -> None:
"""Start step loop of backend engine.
"""
raise NotImplementedError
105 changes: 67 additions & 38 deletions llumnix/backends/vllm/llm_engine.py
@@ -17,8 +17,10 @@
from collections import defaultdict
import threading
import asyncio
import queue
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from vllm.engine.llm_engine import LLMEngine
from vllm.core.scheduler import ScheduledSequenceGroup
@@ -42,58 +42,71 @@
logger = init_logger(__name__)


class AsyncPutQueueThread(threading.Thread):
class AsyncPutQueueActor:
def __init__(self, instance_id, output_queue_type: QueueType):
super().__init__()
self.instance_id = instance_id

self.request_output_queue_client: QueueClientBase \
= get_output_queue_client(output_queue_type)
self.request_output_queue_client: QueueClientBase = get_output_queue_client(output_queue_type)
self.engine_actor_handle = None
self.loop = asyncio.new_event_loop()
self.daemon = True

def run(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

async def _put_nowait_batch_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
async def put_nowait_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
if self.engine_actor_handle is None:
self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
tasks = []
for server_id, req_outputs in server_request_outputs.items():
server_info = server_info_dict[server_id]
tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait_batch(req_outputs, server_info)))
tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
rets = await asyncio.gather(*tasks, return_exceptions=True)
for idx, ret in enumerate(rets):
if isinstance(ret, TimeoutError):
if isinstance(ret, (TimeoutError, ray.exceptions.RayActorError)):
server_id = list(server_request_outputs.keys())[idx]
server_info = server_info_dict[server_id]
logger.info("Server {} is dead".format(server_id))
logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
server_info.request_output_queue_port))
if output_queue_type == QueueType.ZMQ:
logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
server_info.request_output_queue_port))
req_outputs = list(server_request_outputs.values())[idx]
request_ids = [req_output.request_id for req_output in req_outputs]
self.engine_actor_handle.abort_request.remote(request_ids)

def put_nowait_batch_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
asyncio.run_coroutine_threadsafe(self._put_nowait_batch_to_servers(server_request_outputs, server_info_dict),
self.loop)


class LLMEngineLlumnix(LLMEngine):
def __init__(self, instance_id: str, output_queue_type: QueueType, *arg, **kwargs) -> None:
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
placement_group: Optional[PlacementGroup],
node_id: Optional[str],
*arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.instance_id = instance_id
self.step_counter = Counter()
self.instance_info = None
# TODO(s5u13b): Reduce the overhead.
self.async_put_queue_thread = AsyncPutQueueThread(instance_id, output_queue_type)
self.async_put_queue_thread.start()
# Place the async put queue actor together with the instance.
if placement_group:
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
)
elif node_id:
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
)
else: # When using the simulator, placement_group and node_id are both None.
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
soft=False,
)
self.put_queue_args_queue = queue.Queue()
self.put_queue_loop_thread = threading.Thread(
target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop"
)
self.async_put_queue_actor = ray.remote(
num_cpus=1,
scheduling_strategy=scheduling_strategy
)(AsyncPutQueueActor).remote(instance_id, output_queue_type)
self.put_queue_loop_thread.start()

# pylint: disable=W0221
@classmethod
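
The LLMEngineLlumnix constructor above moves the expensive part of returning outputs (serialization and the network put) off the engine's critical path: step() only performs a cheap local queue.put, while a daemon thread and a dedicated, colocated Ray actor do the rest. A condensed sketch of that wiring, assuming the AsyncPutQueueActor class from this diff and an already-chosen scheduling_strategy; the helper name and signature are illustrative:

import queue
import threading

import ray

def build_output_put_path(instance_id, output_queue_type, scheduling_strategy, loop_target):
    # Cheap, thread-safe hand-off point: the engine's step() only enqueues here.
    put_queue_args_queue = queue.Queue()
    # Daemon thread that drains the hand-off queue in the background.
    put_queue_loop_thread = threading.Thread(
        target=loop_target, args=(put_queue_args_queue,), daemon=True, name="put_queue_loop")
    # A dedicated actor with its own CPU, pinned next to the engine, performs
    # the actual puts so they never block a decoding step.
    async_put_queue_actor = ray.remote(
        num_cpus=1,
        scheduling_strategy=scheduling_strategy,
    )(AsyncPutQueueActor).remote(instance_id, output_queue_type)
    put_queue_loop_thread.start()
    return put_queue_args_queue, put_queue_loop_thread, async_put_queue_actor
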
@@ -105,7 +120,7 @@ def from_engine_args(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
instance_id: str = None,
placement_group: Optional[PlacementGroup] = None,
node_id: str = None,
node_id: Optional[str] = None,
latency_mem: Optional[LatencyMemData] = None
) -> "LLMEngineLlumnix":
"""Creates an LLM engine from the engine arguments."""
@@ -130,6 +145,8 @@ def from_engine_args(
engine = cls(
instance_id=instance_id,
output_queue_type=output_queue_type,
placement_group=placement_group,
node_id=node_id,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
@@ -186,10 +203,12 @@ def step(self) -> None:
tot_blocks.extend(blocks)
tot_blocks = set(tot_blocks)
instance_info.num_blocks_last_running_request = len(tot_blocks)

if request_outputs:
self._put_request_outputs_to_server(request_outputs, server_infos)
self.put_queue_args_queue.put((request_outputs, server_infos))
self.instance_info = instance_info
num_request_outputs = len(request_outputs)

return num_request_outputs

def update_instance_info(self, instance_info: InstanceInfo) -> None:
# These fields are updated after step.
@@ -208,7 +227,13 @@ def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs)
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
def _start_put_queue_loop(self):
while True:
args = self.put_queue_args_queue.get()
request_outputs, server_infos = args
self._put_request_outputs_to_server(request_outputs, server_infos)

def _put_request_outputs_to_server(self, request_outputs: List[RequestOutput], server_infos: List[ServerInfo]) -> None:
server_request_outputs = defaultdict(list)
server_info_dict = {}
# Reorganize data in order to put request outputs to the queue in one batch.
@@ -217,7 +242,9 @@ def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
server_request_outputs[server_id].append(request_output)
if server_id not in server_info_dict:
server_info_dict[server_id] = server_info
self.async_put_queue_thread.put_nowait_batch_to_servers(server_request_outputs, server_info_dict)
# TODO(s5u13b): Reduce the across-actor overhead.
# TODO(s5u13b): It is not necessary to use async_put_queue_actor when output_queue_type is RayQueue.
self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict)

class BackendVLLM(BackendInterface):
def __init__(
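
The _start_put_queue_loop and _put_request_outputs_to_server hunk above forms the consuming end of that hand-off: a blocking loop pulls (request_outputs, server_infos) pairs off the queue, regroups them by API server, and issues one batched remote put per server rather than one put per output. A standalone sketch of the regrouping step, with illustrative names and a stand-in callback for the actual remote put:

import queue
from collections import defaultdict
from typing import Callable, List

def put_queue_loop(args_queue: queue.Queue,
                   put_batch_to_server: Callable[[str, List], None]) -> None:
    # Runs in a daemon thread; blocks until step() enqueues new outputs.
    while True:
        request_outputs, server_infos = args_queue.get()
        server_request_outputs = defaultdict(list)
        for request_output, server_info in zip(request_outputs, server_infos):
            server_request_outputs[server_info.server_id].append(request_output)
        # One batched call per server keeps the number of cross-actor puts small.
        for server_id, outputs in server_request_outputs.items():
            put_batch_to_server(server_id, outputs)
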
@@ -251,12 +278,12 @@ def __init__(
logger.info("engine ({}) current state {}".format(self.instance_id, self.state))

self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
self.engine_step_loop_thread = threading.Thread(
target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
)
self._thread.start()
self.engine_step_loop_thread.start()

def _start_engine_loop(self) -> None:
def _start_engine_step_loop(self) -> None:
self._stop_event.clear()

with self.state_lock:
@@ -266,7 +293,9 @@ def _start_engine_loop(self) -> None:

while not self._stop_event.is_set():
try:
self.engine.step()
num_request_outputs = self.engine.step()
if num_request_outputs == 0:
time.sleep(0.01)
# pylint: disable=broad-except
except Exception as e:
logger.error("Error in engine loop: {}".format(e))
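
Because step() now reports how many outputs it produced, the step loop above can back off briefly when the engine is idle instead of spinning. A minimal sketch of that loop, assuming an engine whose step() returns the number of request outputs; the 10 ms back-off mirrors the diff, the rest is illustrative:

import threading
import time

def engine_step_loop(engine, stop_event: threading.Event) -> None:
    # Step continuously while there is work; sleep briefly on idle steps so an
    # idle engine does not busy-wait on a CPU core.
    while not stop_event.is_set():
        num_request_outputs = engine.step()
        if num_request_outputs == 0:
            time.sleep(0.01)
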
6 changes: 3 additions & 3 deletions llumnix/backends/vllm/simulator.py
@@ -64,10 +64,10 @@ def __init__(
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.instance_id = instance_id
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
self.engine_step_loop_thread = threading.Thread(
target=self._start_engine_step_loop, args=(), daemon=True, name="engine_loop"
)
self._thread.start()
self.engine_step_loop_thread.start()

# pylint: disable=unused-argument
def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
2 changes: 1 addition & 1 deletion llumnix/config/default.py
@@ -42,7 +42,7 @@
# -----------------------------------------------------------------------------
_C.RAY = LC()
# Port number for the Ray cluster
_C.RAY.RAY_CLUSTER_PORT = 30050
_C.RAY.RAY_CLUSTER_PORT = 6379
# If True, launch Ray cluster in API server
_C.RAY.LAUNCH_RAY_CLUSTER = False

23 changes: 12 additions & 11 deletions llumnix/entrypoints/vllm/api_server.py
@@ -56,15 +56,16 @@

async def _background_process_outputs():
while True:
request_output = await request_output_queue.get()
request_id = request_output.request_id
# Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
if request_id not in request_streams:
continue
request_streams[request_id].put(request_output)
if request_output.finished:
request_streams[request_id].finish()
del request_streams[request_id]
request_outputs = await request_output_queue.get()
for request_output in request_outputs:
request_id = request_output.request_id
# A request could be dispatched twice when the manager is dead; the first dispatch frees its entry in request_streams when finished.
if request_id not in request_streams:
continue
request_streams[request_id].put(request_output)
if request_output.finished:
request_streams[request_id].finish()
del request_streams[request_id]

# pylint: disable=unused-argument
@asynccontextmanager
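
The server-side consumer above mirrors the offline example: each get() now returns a batch, and every output in it is routed to its request's stream. A minimal sketch of that dispatch, assuming a dict of per-request stream objects exposing put() and finish(); the names are illustrative:

async def dispatch_output_batches(output_queue, request_streams: dict) -> None:
    while True:
        request_outputs = await output_queue.get()  # one batch per await
        for request_output in request_outputs:
            stream = request_streams.get(request_output.request_id)
            if stream is None:
                # The request was already freed, e.g. after being dispatched
                # twice because the manager died.
                continue
            stream.put(request_output)
            if request_output.finished:
                stream.finish()
                del request_streams[request_output.request_id]
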
@@ -180,11 +181,11 @@ async def generate_benchmark(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

start = time.time()

results_generator = await manager_generate(prompt, sampling_params, request_id)

per_token_latency = []
start = time.time()

# Non-streaming case
final_output = None
async for request_output in results_generator:
24 changes: 15 additions & 9 deletions llumnix/llumlet/llumlet.py
@@ -15,8 +15,7 @@
from typing import List, Union, Iterable
import time
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
@@ -86,7 +85,9 @@ def from_args(cls,
lifetime=lifetime)(cls).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=0,))
placement_group_bundle_index=0,
)
)
else:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
@@ -96,16 +97,21 @@
lifetime=lifetime)(cls).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,))
soft=False,
)
)
else: # backend_type == backend_type.SIM_VLLM:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,))
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
)
)
llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs)
return llumlet
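
Both the Llumlet above and the put-queue actor in llm_engine.py pick between the two Ray scheduling strategies based on what is available: a placement group if given, otherwise node affinity. A minimal sketch of that selection, following the put-queue actor's variant (placement_group_capture_child_tasks); the helper name is illustrative:

from typing import Optional

import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import (NodeAffinitySchedulingStrategy,
                                            PlacementGroupSchedulingStrategy)

def choose_scheduling_strategy(placement_group: Optional[PlacementGroup],
                               node_id: Optional[str]):
    if placement_group is not None:
        # Co-schedule with the rest of the instance's placement group bundle.
        return PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True)
    if node_id is not None:
        # Hard-pin to the requested node.
        return NodeAffinitySchedulingStrategy(node_id=node_id, soft=False)
    # Neither given (e.g. the simulator): stay on the current node.
    return NodeAffinitySchedulingStrategy(
        node_id=ray.get_runtime_context().get_node_id(), soft=False)
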

@@ -118,8 +124,8 @@ def check_state(self):
logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id))
# pylint: disable=protected-access
self.backend_engine._stop_event.set()
if self.backend_engine._thread.is_alive():
self.backend_engine._thread.join()
if self.backend_engine.engine_step_loop_thread.is_alive():
self.backend_engine.engine_step_loop_thread.join()

self_actor = ray.get_actor(self.actor_name)
ray.kill(self_actor)
5 changes: 5 additions & 0 deletions llumnix/queue/queue_client_base.py
@@ -12,11 +12,16 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any
from collections.abc import Iterable

from llumnix.server_info import ServerInfo

class QueueClientBase(ABC):
@abstractmethod
async def put_nowait(self, item: Any, server_info: ServerInfo):
raise NotImplementedError

@abstractmethod
async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
raise NotImplementedError
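
The new abstract put_nowait lets a caller hand over a whole list of outputs in a single call, alongside the existing put_nowait_batch. A toy in-process client, purely to illustrate the interface shape; it is not one of llumnix's real ZMQ or Ray queue clients, and the import paths are inferred from the file layout in this diff:

import asyncio
from collections.abc import Iterable
from typing import Any

from llumnix.queue.queue_client_base import QueueClientBase
from llumnix.server_info import ServerInfo

class InMemoryQueueClient(QueueClientBase):
    """Toy client that pushes everything into a local asyncio.Queue."""

    def __init__(self):
        self.queue: asyncio.Queue = asyncio.Queue()

    async def put_nowait(self, item: Any, server_info: ServerInfo):
        # Real clients would address the output queue of server_info's server.
        self.queue.put_nowait(item)

    async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
        for item in items:
            self.queue.put_nowait(item)
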
