Commit 653ba46

[Core] Add back ray queue to put request output tokens back to the api server (#41)
KuilongCui authored Oct 9, 2024
1 parent e9cf870 commit 653ba46
Showing 38 changed files with 364 additions and 157 deletions.
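Note: the gist of the change is that request outputs generated on any llumlet instance are put onto a queue the API server can read, and a Ray-backed queue is restored as one transport for that path (selected via the new QUEUE_TYPE / QueueType plumbing below). A minimal sketch of that producer/consumer shape with ray.util.queue.Queue follows; it is illustrative only, since llumnix wraps this in RayQueueServer and a matching client.

# Minimal sketch of the output path this commit restores: engine actors
# put request outputs on a shared Ray queue; the API server drains it.
# Illustrative only; llumnix wraps this in RayQueueServer and a client.
import ray
from ray.util.queue import Queue as RayQueue

ray.init()
request_output_queue = RayQueue()  # shared, actor-backed queue

@ray.remote
def engine_step(queue, token):
    queue.put(token)  # producer side: an engine instance puts output tokens

ray.get([engine_step.remote(request_output_queue, f"token-{i}") for i in range(3)])
while not request_output_queue.empty():
    print(request_output_queue.get())  # consumer side: the API server
ray.shutdown()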
3 changes: 0 additions & 3 deletions .github/workflows/bench_test.yml
@@ -1,9 +1,6 @@
 name: bench_test

 on:
-  push:
-    branches:
-      - main
   pull_request:
     branches:
       - main
3 changes: 0 additions & 3 deletions .github/workflows/e2e_test.yml
@@ -1,9 +1,6 @@
 name: e2e_test

 on:
-  push:
-    branches:
-      - main
   pull_request:
     branches:
       - main
3 changes: 0 additions & 3 deletions .github/workflows/migration_test.yml
@@ -1,9 +1,6 @@
 name: migration_test

 on:
-  push:
-    branches:
-      - main
   pull_request:
     branches:
       - main
3 changes: 0 additions & 3 deletions .github/workflows/offline_inference.yml
@@ -1,9 +1,6 @@
 name: offline_inference

 on:
-  push:
-    branches:
-      - main
   pull_request:
     branches:
       - main
3 changes: 0 additions & 3 deletions .github/workflows/pylint.yml
@@ -1,9 +1,6 @@
 name: pylint

 on:
-  push:
-    branches:
-      - main
   pull_request:
     branches:
       - main
3 changes: 0 additions & 3 deletions .github/workflows/unit_test.yml
@@ -1,9 +1,6 @@
 name: unit_test

 on:
-  push:
-    branches:
-      - main
   pull_request:
     branches:
       - main
3 changes: 0 additions & 3 deletions .github/workflows/whl.yml → .github/workflows/whl_build.yml
@@ -1,9 +1,6 @@
 name: whl_build

 on:
-  push:
-    branches:
-      - main
   pull_request:
     branches:
       - main
3 changes: 2 additions & 1 deletion configs/base.yml
@@ -1,6 +1,7 @@
 SERVER:
   HOST: '127.0.0.1'
   PORT: 37000
+  QUEUE_TYPE: "rayqueue"

 RAY:
   RAY_CLUSTER_PORT: 30037
@@ -19,7 +20,7 @@ MANAGER:
   ENABLE_DEFRAG: True
   REQUEST_MIGRATION_POLICY: 'SJF'

-  MIGRATION_BACKEND: 'rpc'
+  MIGRATION_BACKEND: 'gloo'
   MIGRATION_CACHE_BLOCKS: 512

   ENABLE_SCALING: False
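Note: the two knobs this diff touches can be read straight out of the YAML. A hedged sketch with PyYAML; the valid-value sets in the asserts are assumptions, not something this commit states.

# Sketch: reading the two settings this diff changes, with PyYAML only.
# The allowed-value sets below are assumptions, not from this commit.
import yaml

with open("configs/base.yml") as f:
    cfg = yaml.safe_load(f)

queue_type = cfg["SERVER"]["QUEUE_TYPE"]                 # "rayqueue"
migration_backend = cfg["MANAGER"]["MIGRATION_BACKEND"]  # "gloo"
assert queue_type in ("rayqueue", "zmq")                 # assumed choices
assert migration_backend in ("gloo", "rpc")              # assumed choices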
27 changes: 8 additions & 19 deletions examlpes/offline_inference.py
@@ -1,21 +1,14 @@
 from typing import List
 import os
 import uuid
 import asyncio

 import ray
-from ray.util.queue import Queue as RayQueue
-from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
-
 from llumnix import launch_ray_cluster, connect_to_ray_cluster, init_manager, init_llumlets
 from llumnix import (SamplingParams, ServerInfo, EngineManagerArgs, LLMEngineManager, Llumlet,
-                     EngineArgs, RequestOutput)
+                     EngineArgs, QueueType)
 from llumnix.utils import random_uuid
-from llumnix.rpc.queue_server import QueueServer
-from llumnix.rpc.queue_client import QueueClient
-from llumnix.rpc.utils import get_open_zmq_ipc_path
-from llumnix.entrypoints.llumnix_utils import get_ip_address
-
+from llumnix.queue.ray_queue_server import RayQueueServer

 # Sample prompts.
 prompts = [
@@ -45,8 +38,10 @@
 # Create llumlets.
 llumlet_ids: List[str] = None
 llumlets: List[Llumlet] = None
-llumlet_ids, llumlets = init_llumlets(manager_args, engine_args,
-                                      node_id=ray.get_runtime_context().get_node_id())
+llumlet_ids, llumlets = init_llumlets(
+    manager_args, engine_args, ray.get_runtime_context().get_node_id(),
+    QueueType("rayqueue")
+)


 # Create a manager. If the manager is created first, and then the llumlets are created, manager.scale_up
@@ -55,11 +50,8 @@

 # The requests' outputs will be put to the request_output_queue no matter which instance it's running in.
 server_id = random_uuid()
-ip = get_ip_address()
-port = 1234
-server_info = ServerInfo(server_id, ip, port)
-rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
-request_output_queue = QueueServer(rpc_path)
+request_output_queue = RayQueueServer()
+server_info = ServerInfo(server_id, QueueType("rayqueue"), request_output_queue, None, None)

 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
@@ -94,9 +86,6 @@ async def main():
 for actor in named_actors:
     try:
         actor_handle = ray.get_actor(actor['name'], namespace=actor['namespace'])
-    except:
-        continue
-    try:
         ray.kill(actor_handle)
     except:
         continue
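Note: on the consuming side, the example's server now drains a RayQueueServer instead of a zmq QueueServer. A hedged sketch of what that loop could look like; treating get() as an awaitable is an assumption layered on vLLM's documented RequestOutput shape.

# Hedged sketch of the consumer loop; assumes RayQueueServer exposes an
# async get() and that queued items are vLLM RequestOutput objects.
import asyncio

async def drain_outputs(request_output_queue, num_requests: int):
    finished = 0
    while finished < num_requests:
        request_output = await request_output_queue.get()  # assumed API
        if request_output.finished:
            finished += 1
            print(request_output.outputs[0].text)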
4 changes: 3 additions & 1 deletion llumnix/__init__.py
@@ -20,6 +20,7 @@
 from llumnix.arg_utils import EngineManagerArgs
 from llumnix.llm_engine_manager import LLMEngineManager
 from llumnix.llumlet.llumlet import Llumlet
+from llumnix.queue.queue_type import QueueType

 from .version import __version__

@@ -32,7 +33,8 @@
     "init_llumlets",
     "EngineManagerArgs",
     "LLMEngineManager",
-    "Llumlet"
+    "Llumlet",
+    "QueueType",
 ]

 __all__.extend(getattr(vllm, "__all__", []))
13 changes: 7 additions & 6 deletions llumnix/backends/utils.py
@@ -12,30 +12,31 @@
 # limitations under the License.

 from typing import Optional, Tuple

 import ray
 # pylint: disable=unused-import
 from ray.util.placement_group import PlacementGroup

 from llumnix.backends.backend_interface import BackendInterface, BackendType
+from llumnix.queue.queue_type import QueueType

-
-def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kwargs) -> BackendInterface:
+def init_backend_engine(instance_id: str, output_queue_type: QueueType,
+                        backend_type: BackendType, *args, **kwargs) -> BackendInterface:
     if backend_type == BackendType.VLLM:
         # pylint: disable=import-outside-toplevel
         from llumnix.backends.vllm.llm_engine import BackendVLLM
-        backend_engine = BackendVLLM(instance_id, *args, **kwargs)
+        backend_engine = BackendVLLM(instance_id, output_queue_type, *args, **kwargs)
     elif backend_type == BackendType.SIM_VLLM:
         # pylint: disable=import-outside-toplevel
         from llumnix.backends.vllm.simulator import BackendSimVLLM
-        backend_engine = BackendSimVLLM(instance_id, *args, **kwargs)
+        backend_engine = BackendSimVLLM(instance_id, output_queue_type, *args, **kwargs)
     else:
         raise ValueError(f'Unsupported backend: {backend_type}')
     return backend_engine

 def initialize_placement_group(
     world_size: int = 1,
     detached: bool = False
-) -> Tuple[str, Optional["PlacementGroup"]]:
+) -> Tuple[str, Optional[PlacementGroup]]:
     """Initialize the distributed cluster probably with Ray.

     Args:
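Note: callers of the factory now thread the queue type through as the second argument. A hedged sketch; only the argument order shown in the diff is from this commit, and the trailing arguments are assumptions about what *args forwards to the backend.

# Hedged sketch of calling the updated factory.
from llumnix.backends.backend_interface import BackendType
from llumnix.backends.utils import init_backend_engine
from llumnix.queue.queue_type import QueueType

def build_backend(instance_id, migration_config, engine_args):
    # output_queue_type now comes right after instance_id (per this commit);
    # the remaining arguments are forwarded to the chosen backend via *args.
    return init_backend_engine(instance_id, QueueType("rayqueue"),
                               BackendType.VLLM, migration_config, engine_args)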
17 changes: 12 additions & 5 deletions llumnix/backends/vllm/llm_engine.py
@@ -35,16 +35,19 @@
 from llumnix.backends.profiling import LatencyMemData
 from llumnix.server_info import ServerInfo
 from llumnix.internal_config import MigrationConfig
-from llumnix.rpc.queue_client import QueueClient
+from llumnix.queue.queue_client_base import QueueClientBase
+from llumnix.queue.utils import get_output_queue_client, QueueType

 logger = init_logger(__name__)


 class AsyncPutQueueThread(threading.Thread):
-    def __init__(self, instance_id):
+    def __init__(self, instance_id, output_queue_type: QueueType):
         super().__init__()
         self.instance_id = instance_id
-        self.request_output_queue_client = QueueClient()
+
+        self.request_output_queue_client: QueueClientBase \
+            = get_output_queue_client(output_queue_type)
         self.engine_actor_handle = None
         self.loop = asyncio.new_event_loop()
         self.daemon = True
@@ -82,20 +85,21 @@ def put_nowait_batch_to_servers(self,


 class LLMEngineLlumnix(LLMEngine):
-    def __init__(self, instance_id: str, *arg, **kwargs) -> None:
+    def __init__(self, instance_id: str, output_queue_type: QueueType, *arg, **kwargs) -> None:
         super().__init__(*arg, **kwargs)
         self.instance_id = instance_id
         self.step_counter = Counter()
         self.instance_info = None
         # TODO(s5u13b): Reduce the overhead.
-        self.async_put_queue_thread = AsyncPutQueueThread(instance_id)
+        self.async_put_queue_thread = AsyncPutQueueThread(instance_id, output_queue_type)
         self.async_put_queue_thread.start()

     # pylint: disable=W0221
     @classmethod
     def from_engine_args(
         cls,
         engine_args: EngineArgs,
+        output_queue_type: QueueType,
         migration_config: MigrationConfig,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         instance_id: str = None,
@@ -124,6 +128,7 @@ def from_engine_args(
         # Create the LLM engine.
         engine = cls(
             instance_id=instance_id,
+            output_queue_type=output_queue_type,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_stats=not engine_args.disable_log_stats,
@@ -217,12 +222,14 @@ class BackendVLLM(BackendInterface):
     def __init__(
         self,
         instance_id: str,
+        output_queue_type: QueueType,
         migration_config: MigrationConfig,
         engine_args: EngineArgs,
         placement_group: PlacementGroup = None,
         node_id: str = None
     ) -> None:
         self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args,
+                                                                          output_queue_type=output_queue_type,
                                                                           migration_config=migration_config,
                                                                           instance_id=instance_id,
                                                                           placement_group=placement_group,
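Note: the AsyncPutQueueThread change is the heart of the commit. The engine's synchronous step loop hands output batches to a daemon thread that owns a private asyncio loop, so puts to a possibly remote queue never stall token generation. A self-contained sketch of the same pattern using only the standard library; all names here are illustrative, not llumnix APIs.

import asyncio
import threading
import time

class AsyncPutThread(threading.Thread):
    """Daemon thread owning a private asyncio loop, so the producer
    (the engine's synchronous step loop) never blocks on queue puts."""

    def __init__(self):
        super().__init__(daemon=True)
        self.loop = asyncio.new_event_loop()

    def run(self):
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    def put_nowait_batch(self, outputs):
        # Called from the engine thread; schedules work on the loop
        # thread and returns immediately.
        asyncio.run_coroutine_threadsafe(self._put_batch(outputs), self.loop)

    async def _put_batch(self, outputs):
        # Stand-in for an async put to a remote queue (e.g. Ray or zmq).
        await asyncio.sleep(0.01)  # simulated network latency
        print(f"delivered batch of {len(outputs)}")

thread = AsyncPutThread()
thread.start()
for step in range(3):
    thread.put_nowait_batch([f"token-{step}"])  # engine keeps stepping
time.sleep(0.1)  # let the daemon flush before the demo exits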
8 changes: 6 additions & 2 deletions llumnix/backends/vllm/simulator.py
@@ -14,6 +14,7 @@
 import os
 import threading
 from typing import List
+import ray.actor

 from vllm.engine.arg_utils import EngineArgs

@@ -22,7 +23,7 @@
 from llumnix.backends.vllm.scheduler import SchedulerLlumnix
 from llumnix.backends.vllm.llm_engine import LLMEngineLlumnix, BackendVLLM
 from llumnix.backends.profiling import ProfilingDatabase, LatencyMemData, ProfilingResult, SimParallelConfig
-
+from llumnix.queue.queue_type import QueueType

 logger = init_logger(__name__)

@@ -31,6 +32,7 @@ class BackendSimVLLM(BackendVLLM):
     def __init__(
         self,
         instance_id: str,
+        output_queue_type: QueueType,
         migration_config: MigrationConfig,
         profiling_result_file_path: str,
         engine_args: EngineArgs,
@@ -54,6 +56,7 @@ def __init__(
         latency_mem: LatencyMemData = profiling_result.para_dict[sim_parallel_config]
         # multi-instance args
         self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args,
+                                                                          output_queue_type=output_queue_type,
                                                                           migration_config=migration_config,
                                                                           instance_id=instance_id,
                                                                           latency_mem=latency_mem)
@@ -66,5 +69,6 @@ def __init__(
         )
         self._thread.start()

-    def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[int], dst_blocks: List[int]) -> None:
+    # pylint: disable=unused-argument
+    def send_blocks(self, dst_ray_actor: ray.actor.ActorHandle, src_blocks: List[int], dst_blocks: List[int]) -> None:
         self.engine.model_executor.send_blocks(len(src_blocks))
4 changes: 3 additions & 1 deletion llumnix/config/default.py
@@ -26,6 +26,8 @@
 _C.SERVER.HOST = "localhost"
 # Port number for the server
 _C.SERVER.PORT = 8000
+# Queue type for request output queue
+_C.SERVER.QUEUE_TYPE = "rayqueue"
 # Port number for the request output queue
 _C.SERVER.REQUEST_OUTPUT_QUEUE_PORT = 1234
 # Path to SSL key file for secure connections
@@ -95,7 +97,7 @@
 _C.MANAGER.LAST_STAGE_MAX_BLOCKS = 16

 # Communication backend of migration
-_C.MANAGER.MIGRATION_BACKEND = "rpc"
+_C.MANAGER.MIGRATION_BACKEND = "gloo"
 # Timeout(s) for initializing migration backend
 _C.MANAGER.MIGRATION_BACKEND_INIT_TIMEOUT = 10.0
 # Number of cache blocks in migration
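Note: the _C.SECTION.KEY pattern suggests a yacs-style CfgNode. Assuming that, overriding the new defaults looks like the sketch below; the yacs API is real, but the llumnix wiring around it is an assumption.

# Sketch of overriding the new defaults, assuming a yacs-style CfgNode
# (suggested by the file's _C.SERVER.* / _C.MANAGER.* pattern).
from yacs.config import CfgNode as CN

_C = CN()
_C.SERVER = CN()
_C.SERVER.QUEUE_TYPE = "rayqueue"      # new default from this commit
_C.MANAGER = CN()
_C.MANAGER.MIGRATION_BACKEND = "gloo"  # changed default, was "rpc"

cfg = _C.clone()
cfg.merge_from_list(["MANAGER.MIGRATION_BACKEND", "rpc"])  # opt back into rpc
print(cfg.SERVER.QUEUE_TYPE, cfg.MANAGER.MIGRATION_BACKEND)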
(Diffs for the remaining 24 changed files are not shown.)
