[Core] Optimize request output tokens putting back implementation to reduce overhead #45

Merged
merged 15 commits
Oct 11, 2024
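
For orientation before the per-file diff: this change replaces the per-engine AsyncPutQueueThread with a local args queue, a daemon put-queue thread, and a colocated AsyncPutQueueActor, so step() only enqueues (request_outputs, server_infos) and the cross-process put happens off the critical path, one batch per server. The sketch below illustrates just the queue-plus-thread half of that pattern; PutQueueLoop and deliver_to_servers are illustrative names, not Llumnix APIs.

import queue
import threading

class PutQueueLoop:
    """Simplified sketch: decouple engine stepping from request-output delivery."""

    def __init__(self, deliver_to_servers):
        # deliver_to_servers stands in for the remote call to AsyncPutQueueActor.
        self.deliver_to_servers = deliver_to_servers
        self.put_queue_args_queue = queue.Queue()
        self.put_queue_loop_thread = threading.Thread(
            target=self._start_put_queue_loop, daemon=True, name="put_queue_loop")
        self.put_queue_loop_thread.start()

    def enqueue(self, request_outputs, server_infos):
        # Called from the engine step loop; returns immediately.
        self.put_queue_args_queue.put((request_outputs, server_infos))

    def _start_put_queue_loop(self):
        while True:
            request_outputs, server_infos = self.put_queue_args_queue.get()
            self.deliver_to_servers(request_outputs, server_infos)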
2 changes: 1 addition & 1 deletion configs/base.yml
@@ -4,7 +4,7 @@ SERVER:
QUEUE_TYPE: "rayqueue"

RAY:
RAY_CLUSTER_PORT: 30037
RAY_CLUSTER_PORT: 6379
LAUNCH_RAY_CLUSTER: True

MANAGER:
15 changes: 8 additions & 7 deletions examlpes/offline_inference.py
@@ -24,7 +24,7 @@
# Launch ray cluster
os.environ['HEAD_NODE'] = '1'
os.environ['HEAD_NODE_IP'] = '127.0.0.1'
ray_cluster_port=37000
ray_cluster_port=6379

# Note: launch_ray_cluster will stop the current ray cluster first, then init a new one.
launch_ray_cluster(ray_cluster_port=ray_cluster_port)
@@ -58,12 +58,13 @@
async def background_process_outputs(num_tasks):
finish_task = 0
while finish_task != num_tasks:
request_output = await request_output_queue.get()
if request_output.finished:
finish_task += 1
prompt = request_output.prompt
generated_text = request_output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
request_outputs = await request_output_queue.get()
for request_output in request_outputs:
if request_output.finished:
finish_task += 1
prompt = request_output.prompt
generated_text = request_output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
request_output_queue.cleanup()

async def main():
2 changes: 1 addition & 1 deletion llumnix/backends/backend_interface.py
@@ -63,7 +63,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
raise NotImplementedError

@abstractmethod
def _start_engine_loop(self) -> None:
def _start_engine_step_loop(self) -> None:
"""Start step loop of backend engine.
"""
raise NotImplementedError
105 changes: 67 additions & 38 deletions llumnix/backends/vllm/llm_engine.py
@@ -17,8 +17,10 @@
from collections import defaultdict
import threading
import asyncio
import queue
import ray
from ray.util.placement_group import PlacementGroup
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from vllm.engine.llm_engine import LLMEngine
from vllm.core.scheduler import ScheduledSequenceGroup
Expand All @@ -42,58 +44,71 @@
logger = init_logger(__name__)


class AsyncPutQueueThread(threading.Thread):
class AsyncPutQueueActor:
def __init__(self, instance_id, output_queue_type: QueueType):
super().__init__()
        self.instance_id = instance_id
        self.output_queue_type = output_queue_type

self.request_output_queue_client: QueueClientBase \
= get_output_queue_client(output_queue_type)
self.request_output_queue_client: QueueClientBase = get_output_queue_client(output_queue_type)
self.engine_actor_handle = None
self.loop = asyncio.new_event_loop()
self.daemon = True

def run(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

async def _put_nowait_batch_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
async def put_nowait_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
if self.engine_actor_handle is None:
self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
tasks = []
for server_id, req_outputs in server_request_outputs.items():
server_info = server_info_dict[server_id]
tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait_batch(req_outputs, server_info)))
tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
rets = await asyncio.gather(*tasks, return_exceptions=True)
for idx, ret in enumerate(rets):
if isinstance(ret, TimeoutError):
if isinstance(ret, (TimeoutError, ray.exceptions.RayActorError)):
server_id = list(server_request_outputs.keys())[idx]
server_info = server_info_dict[server_id]
logger.info("Server {} is dead".format(server_id))
logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
server_info.request_output_queue_port))
                if self.output_queue_type == QueueType.ZMQ:
logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
server_info.request_output_queue_port))
req_outputs = list(server_request_outputs.values())[idx]
request_ids = [req_output.request_id for req_output in req_outputs]
self.engine_actor_handle.abort_request.remote(request_ids)

def put_nowait_batch_to_servers(self,
server_request_outputs: Dict[str, List[RequestOutput]],
server_info_dict: Dict[str, ServerInfo]) -> None:
asyncio.run_coroutine_threadsafe(self._put_nowait_batch_to_servers(server_request_outputs, server_info_dict),
self.loop)


class LLMEngineLlumnix(LLMEngine):
def __init__(self, instance_id: str, output_queue_type: QueueType, *arg, **kwargs) -> None:
def __init__(self,
instance_id: str,
output_queue_type: QueueType,
placement_group: Optional[PlacementGroup],
node_id: Optional[str],
*arg, **kwargs) -> None:
super().__init__(*arg, **kwargs)
self.instance_id = instance_id
self.step_counter = Counter()
self.instance_info = None
# TODO(s5u13b): Reduce the overhead.
self.async_put_queue_thread = AsyncPutQueueThread(instance_id, output_queue_type)
self.async_put_queue_thread.start()
# Place the async put queue actor together with the instance.
if placement_group:
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
)
elif node_id:
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
)
        else: # When using the simulator, placement_group and node_id are both None.
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
soft=False,
)
self.put_queue_args_queue = queue.Queue()
self.put_queue_loop_thread = threading.Thread(
target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop"
)
self.async_put_queue_actor = ray.remote(
num_cpus=1,
scheduling_strategy=scheduling_strategy
)(AsyncPutQueueActor).remote(instance_id, output_queue_type)
self.put_queue_loop_thread.start()

# pylint: disable=W0221
@classmethod
@@ -105,7 +120,7 @@ def from_engine_args(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
instance_id: str = None,
placement_group: Optional[PlacementGroup] = None,
node_id: str = None,
node_id: Optional[str] = None,
latency_mem: Optional[LatencyMemData] = None
) -> "LLMEngineLlumnix":
"""Creates an LLM engine from the engine arguments."""
@@ -130,6 +145,8 @@ def from_engine_args(
engine = cls(
instance_id=instance_id,
output_queue_type=output_queue_type,
placement_group=placement_group,
node_id=node_id,
**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
@@ -186,10 +203,12 @@ def step(self) -> None:
tot_blocks.extend(blocks)
tot_blocks = set(tot_blocks)
instance_info.num_blocks_last_running_request = len(tot_blocks)

if request_outputs:
self._put_request_outputs_to_server(request_outputs, server_infos)
self.put_queue_args_queue.put((request_outputs, server_infos))
self.instance_info = instance_info
num_request_outputs = len(request_outputs)

return num_request_outputs

def update_instance_info(self, instance_info: InstanceInfo) -> None:
# These fields are updated after step.
@@ -208,7 +227,13 @@ def add_request(self, request_id: str, server_info: ServerInfo, *args, **kwargs)
seq_group.metrics.arrival_time, seq_group.lora_request, seq_group.multi_modal_data)
self.scheduler.scheduler_lock.release()

def _put_request_outputs_to_server(self, request_outputs, server_infos: List[ServerInfo]) -> None:
def _start_put_queue_loop(self):
while True:
args = self.put_queue_args_queue.get()
request_outputs, server_infos = args
self._put_request_outputs_to_server(request_outputs, server_infos)

def _put_request_outputs_to_server(self, request_outputs: List[RequestOutput], server_infos: List[ServerInfo]) -> None:
server_request_outputs = defaultdict(list)
server_info_dict = {}
        # Reorganize data in order to put request outputs to the queue in one batch.
@@ -217,7 +242,9 @@ def _put_request_outputs_to_server(self, request_outputs, server_infos: List[Ser
server_request_outputs[server_id].append(request_output)
if server_id not in server_info_dict:
server_info_dict[server_id] = server_info
self.async_put_queue_thread.put_nowait_batch_to_servers(server_request_outputs, server_info_dict)
        # TODO(s5u13b): Reduce the cross-actor overhead.
# TODO(s5u13b): It is not necessary to use async_put_queue_actor when output_queue_type is RayQueue.
self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict)

class BackendVLLM(BackendInterface):
def __init__(
@@ -251,12 +278,12 @@ def __init__(
logger.info("engine ({}) current state {}".format(self.instance_id, self.state))

self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
self.engine_step_loop_thread = threading.Thread(
target=self._start_engine_step_loop, args=(), daemon=True, name="engine_step_loop"
)
self._thread.start()
self.engine_step_loop_thread.start()

def _start_engine_loop(self) -> None:
def _start_engine_step_loop(self) -> None:
self._stop_event.clear()

with self.state_lock:
@@ -266,7 +293,9 @@ def _start_engine_loop(self) -> None:

while not self._stop_event.is_set():
try:
self.engine.step()
num_request_outputs = self.engine.step()
if num_request_outputs == 0:
time.sleep(0.01)
# pylint: disable=broad-except
except Exception as e:
logger.error("Error in engine loop: {}".format(e))
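The colocation logic above (use the instance's placement group when given, otherwise pin to an explicitly requested node, otherwise pin to the current node in the simulator case where both are None) is the same pattern used in llumnix/llumlet/llumlet.py below. A compact sketch of that decision, using only Ray scheduling APIs that already appear in this diff; choose_scheduling_strategy is an illustrative helper name, not part of the PR.

import ray
from ray.util.scheduling_strategies import (
    NodeAffinitySchedulingStrategy,
    PlacementGroupSchedulingStrategy,
)

def choose_scheduling_strategy(placement_group=None, node_id=None):
    # Prefer colocating with the instance's placement group.
    if placement_group:
        return PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True,
        )
    # Otherwise pin to an explicitly requested node.
    if node_id:
        return NodeAffinitySchedulingStrategy(node_id=node_id, soft=False)
    # Simulator case: neither is given, so pin to the node we are running on.
    return NodeAffinitySchedulingStrategy(
        node_id=ray.get_runtime_context().get_node_id(), soft=False)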
6 changes: 3 additions & 3 deletions llumnix/backends/vllm/simulator.py
@@ -64,10 +64,10 @@ def __init__(
self.engine.scheduler.add_update_instance_info_callback(self.engine.update_instance_info)
self.engine.output_processor.scheduler = self.engine.scheduler
self.instance_id = instance_id
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
self.engine_step_loop_thread = threading.Thread(
target=self._start_engine_step_loop, args=(), daemon=True, name="engine_loop"
)
self._thread.start()
self.engine_step_loop_thread.start()

# pylint: disable=unused-argument
def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
2 changes: 1 addition & 1 deletion llumnix/config/default.py
@@ -42,7 +42,7 @@
# -----------------------------------------------------------------------------
_C.RAY = LC()
# Port number for the Ray cluster
_C.RAY.RAY_CLUSTER_PORT = 30050
_C.RAY.RAY_CLUSTER_PORT = 6379
# If True, launch Ray cluster in API server
_C.RAY.LAUNCH_RAY_CLUSTER = False

23 changes: 12 additions & 11 deletions llumnix/entrypoints/vllm/api_server.py
@@ -56,15 +56,16 @@

async def _background_process_outputs():
while True:
request_output = await request_output_queue.get()
request_id = request_output.request_id
        # A request could be dispatched twice when the manager is dead; the first dispatch frees its entry in request_streams once it finishes.
if request_id not in request_streams:
continue
request_streams[request_id].put(request_output)
if request_output.finished:
request_streams[request_id].finish()
del request_streams[request_id]
request_outputs = await request_output_queue.get()
for request_output in request_outputs:
request_id = request_output.request_id
            # A request could be dispatched twice when the manager is dead; the first dispatch frees its entry in request_streams once it finishes.
if request_id not in request_streams:
continue
request_streams[request_id].put(request_output)
if request_output.finished:
request_streams[request_id].finish()
del request_streams[request_id]

# pylint: disable=unused-argument
@asynccontextmanager
@@ -180,11 +181,11 @@ async def generate_benchmark(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

start = time.time()

results_generator = await manager_generate(prompt, sampling_params, request_id)

per_token_latency = []
start = time.time()

# Non-streaming case
final_output = None
async for request_output in results_generator:
24 changes: 15 additions & 9 deletions llumnix/llumlet/llumlet.py
@@ -15,8 +15,7 @@
from typing import List, Union, Iterable
import time
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy

from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
@@ -86,7 +85,9 @@ def from_args(cls,
lifetime=lifetime)(cls).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_bundle_index=0,))
placement_group_bundle_index=0,
)
)
else:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
@@ -96,16 +97,21 @@
lifetime=lifetime)(cls).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,))
soft=False,
)
)
else: # backend_type == backend_type.SIM_VLLM:
kwargs["node_id"] = node_id
engine_class = ray.remote(num_cpus=1,
name=actor_name,
namespace='llumnix',
max_concurrency=4,
lifetime=lifetime)(cls).options(
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,))
scheduling_strategy=NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
)
)
llumlet = engine_class.remote(instance_id, output_queue_type, backend_type, migration_config, *args, **kwargs)
return llumlet

@@ -118,8 +124,8 @@ def check_state(self):
logger.warning("llumlet ({}) detected backend engine crashed. Stopping...".format(self.instance_id))
# pylint: disable=protected-access
self.backend_engine._stop_event.set()
if self.backend_engine._thread.is_alive():
self.backend_engine._thread.join()
if self.backend_engine.engine_step_loop_thread.is_alive():
self.backend_engine.engine_step_loop_thread.join()

self_actor = ray.get_actor(self.actor_name)
ray.kill(self_actor)
5 changes: 5 additions & 0 deletions llumnix/queue/queue_client_base.py
@@ -12,11 +12,16 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any
from collections.abc import Iterable

from llumnix.server_info import ServerInfo

class QueueClientBase(ABC):
@abstractmethod
async def put_nowait(self, item: Any, server_info: ServerInfo):
raise NotImplementedError

@abstractmethod
async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
raise NotImplementedError
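
For reference, a minimal client that satisfies the extended interface could look like the sketch below. It is illustrative only: InMemoryQueueClient is not part of Llumnix, the server_id attribute lookup on ServerInfo is an assumption, and the real ZMQ and Ray queue clients live elsewhere under llumnix/queue.

import asyncio
from typing import Any
from collections.abc import Iterable

from llumnix.server_info import ServerInfo
from llumnix.queue.queue_client_base import QueueClientBase

class InMemoryQueueClient(QueueClientBase):
    """Illustrative only: keeps one local asyncio.Queue per server."""

    def __init__(self):
        self._queues = {}

    def _queue_for(self, server_info: ServerInfo) -> asyncio.Queue:
        key = getattr(server_info, "server_id", id(server_info))
        return self._queues.setdefault(key, asyncio.Queue())

    async def put_nowait(self, item: Any, server_info: ServerInfo):
        # In this PR, item is typically a List[RequestOutput] destined for one server.
        self._queue_for(server_info).put_nowait(item)

    async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
        for item in items:
            self._queue_for(server_info).put_nowait(item)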