Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Dec 16, 2024
1 parent 32f7d9b commit cdac2e3
Show file tree
Hide file tree
Showing 14 changed files with 81 additions and 116 deletions.
4 changes: 2 additions & 2 deletions examlpes/offline_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from llumnix import launch_ray_cluster, connect_to_ray_cluster, init_manager, init_llumlets
from llumnix import (SamplingParams, ServerInfo, EngineManagerArgs, LLMEngineManager, Llumlet,
EngineArgs, QueueType)
EngineArgs, QueueType, BackendType)
from llumnix.utils import random_uuid
from llumnix.queue.ray_queue_server import RayQueueServer

Expand Down Expand Up @@ -40,7 +40,7 @@
llumlets: List[Llumlet] = None
llumlet_ids, llumlets = init_llumlets(
manager_args, engine_args, ray.get_runtime_context().get_node_id(),
QueueType("rayqueue")
QueueType("rayqueue"), BackendType.VLLM, 1,
)


Expand Down
26 changes: 10 additions & 16 deletions llumnix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from llumnix.server_info import ServerInfo
from llumnix.entrypoints.utils import (launch_ray_cluster,
connect_to_ray_cluster,
init_manager,
init_llumlets)
from llumnix.arg_utils import EngineManagerArgs
from llumnix.llm_engine_manager import LLMEngineManager
from llumnix.llumlet.llumlet import Llumlet
from llumnix.queue.queue_type import QueueType

from llumnix.server_info import ServerInfo
from llumnix.entrypoints.setup import (launch_ray_cluster,
connect_to_ray_cluster,
Expand All @@ -30,6 +20,8 @@
from llumnix.llm_engine_manager import LLMEngineManager
from llumnix.llumlet.llumlet import Llumlet
from llumnix.queue.queue_type import QueueType
from llumnix.backends.backend_interface import BackendType
from llumnix.version import __version__

__all__ = [
"__version__",
Expand All @@ -42,6 +34,7 @@
"LLMEngineManager",
"Llumlet",
"QueueType",
"BackendType",
]

try:
Expand All @@ -51,9 +44,10 @@
except ImportError:
pass

try:
import blade_llm
from blade_llm import *
__all__.extend(getattr(blade_llm, "__all__", []))
except ImportError:
pass
# TODO(KuilongCui): import blade_llm after cuda is ready
# try:
# import blade_llm
# from blade_llm import *
# __all__.extend(getattr(blade_llm, "__all__", []))
# except ImportError:
# pass
39 changes: 5 additions & 34 deletions llumnix/backends/bladellm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,41 +34,10 @@
from llumnix.backends.backend_interface import BackendInterface, EngineState
from llumnix.internal_config import MigrationConfig
from llumnix.server_info import ServerInfo
from llumnix.queue.utils import QueueType, QueueClientBase, init_request_output_queue_client
from llumnix.backends.utils import AsyncPutQueueActor
from llumnix.llumlet.request import LlumnixRequest, RequestStatus
from llumnix.instance_info import InstanceInfo

class AsyncPutQueueActor:
    """Ray actor that forwards engine request outputs to per-server output queues.

    Runs as a separate actor so queue I/O does not block the engine loop.
    When a server's queue cannot be reached, the requests destined for that
    server are aborted on the engine actor.
    """

    def __init__(self, instance_id, output_queue_type: QueueType):
        self.instance_id = instance_id
        self.output_queue_type = output_queue_type
        # Client used to push outputs to the API servers' queues (e.g. zmq / rayqueue).
        self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(output_queue_type)
        # Resolved lazily on first use: the engine actor may not be registered yet
        # when this actor is constructed.
        self.engine_actor_handle = None

    async def put_nowait_to_servers(self,
                                    server_request_outputs,
                                    server_info_dict) -> None:
        """Concurrently enqueue each server's batch of request outputs.

        Args:
            server_request_outputs: mapping server_id -> list of request outputs.
            server_info_dict: mapping server_id -> ServerInfo with queue address info.
        """
        if self.engine_actor_handle is None:
            self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
        tasks = []
        for server_id, req_outputs in server_request_outputs.items():
            server_info = server_info_dict[server_id]
            for req_output in req_outputs:
                if hasattr(req_output, 'request_timestamps'):
                    req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time()
            tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
        rets = await asyncio.gather(*tasks, return_exceptions=True)
        for idx, ret in enumerate(rets):
            # Broadened from (TimeoutError, ray.exceptions.RayActorError) to Exception
            # so that any failed put aborts the affected requests, consistent with the
            # shared AsyncPutQueueActor in llumnix/backends/utils.py.
            if isinstance(ret, Exception):
                server_id = list(server_request_outputs.keys())[idx]
                server_info = server_info_dict[server_id]
                logger.info("Server {} is dead".format(server_id))
                if self.output_queue_type == QueueType.ZMQ:
                    logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
                                                                               server_info.request_output_queue_port))
                req_outputs = list(server_request_outputs.values())[idx]
                request_ids = [req_output.request_id for req_output in req_outputs]
                # The engine/llumlet actor exposes abort_request (not abort); aligned
                # with the vllm backend copy and the shared implementation.
                self.engine_actor_handle.abort_request.remote(request_ids)
from llumnix.queue.queue_type import QueueType

class AsyncBackQueue(APIWrapper):
def __init__(self, placement_group, node_id, instance_id, output_queue_type) -> None:
Expand Down Expand Up @@ -233,9 +202,11 @@ def __init__(self,
node_id: Optional[str],
*args, **kwargs,
) -> None:
logger.info("aaa")
AsyncLLMEngine.__init__(self, *args, **kwargs)
logger.info("bbb")
AsyncLLMEngineLlumnixMixin.__init__(self, instance_id, output_queue_type, migration_config, placement_group, node_id)

logger.info("ccc")
class PrefillAsyncLLMEngineLlumnix(AsyncLLMEngineLlumnixMixin, PrefillAsyncLLMEngine):
def __init__(self,
instance_id: str,
Expand Down
4 changes: 2 additions & 2 deletions llumnix/backends/bladellm/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from blade_llm.service.block_space_manager import BlockSpaceManager
from blade_llm.service.schedulers.paged_scheduler import PagedScheduler

from llumnix.backends.bladellm.llm_engine import LLMEngineLlumnixMixin
from llumnix.backends.bladellm.llm_engine import AsyncLLMEngineLlumnixMixin
from llumnix.metrics.base_metrics import LlumnixMetrics
from llumnix.metrics.dumper import LoggerDumper

Expand All @@ -26,7 +26,7 @@ def block_manager_init_metrics(self, block_manager: BlockSpaceManager):
self.num_total_gpu_blocks.observe(block_manager.num_total_gpu_blocks)
self.num_watermark_blocks.observe(block_manager.reserved_blocks)

def engine_init_metrics(self, engine: LLMEngineLlumnixMixin):
def engine_init_metrics(self, engine: AsyncLLMEngineLlumnixMixin):
self.instance_id.observe(engine.instance_id)

def scheduler_step_metrics(self, scheduler: PagedScheduler):
Expand Down
40 changes: 39 additions & 1 deletion llumnix/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple
from typing import Optional, Tuple, Dict, List
import asyncio
import time

import ray
from ray.util.placement_group import PlacementGroup
from loguru import logger

from llumnix.backends.backend_interface import BackendInterface, BackendType
from llumnix.queue.queue_type import QueueType
from llumnix.queue.queue_client_base import QueueClientBase
from llumnix.queue.utils import init_request_output_queue_client
from llumnix.server_info import ServerInfo

class AsyncPutQueueActor:
    """Ray actor that pushes engine request outputs into the API servers'
    output queues, off the engine's critical path.

    If putting to a server's queue fails, the requests destined for that
    server are aborted on the engine actor.
    """

    def __init__(self, instance_id, request_output_queue_type: QueueType):
        self.instance_id = instance_id
        self.request_output_queue_type = request_output_queue_type
        # Queue client matching the configured transport (e.g. zmq / rayqueue).
        self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(request_output_queue_type)
        # Resolved lazily on first use; the engine actor may not be registered
        # yet when this actor is constructed.
        self.engine_actor_handle = None

    async def put_nowait_to_servers(self,
                                    server_request_outputs: Dict[str, List],
                                    server_info_dict: Dict[str, ServerInfo]) -> None:
        """Concurrently enqueue each server's batch of request outputs.

        Args:
            server_request_outputs: mapping server_id -> list of request outputs.
            server_info_dict: mapping server_id -> ServerInfo with queue address info.
        """
        if self.engine_actor_handle is None:
            # The engine actor registers itself under "instance_<id>" in the
            # "llumnix" namespace — presumably done by the llumlet setup; verify.
            self.engine_actor_handle = ray.get_actor("instance_{}".format(self.instance_id), namespace="llumnix")
        tasks = []
        for server_id, req_outputs in server_request_outputs.items():
            server_info = server_info_dict[server_id]
            for req_output in req_outputs:
                # Stamp the time the output left the engine side, for latency tracing.
                if hasattr(req_output, 'request_timestamps'):
                    req_output.request_timestamps.engine_actor_put_queue_timestamp = time.time()
            tasks.append(asyncio.create_task(self.request_output_queue_client.put_nowait(req_outputs, server_info)))
        # return_exceptions=True: collect per-server failures instead of raising,
        # so one dead server does not drop the other servers' outputs.
        rets = await asyncio.gather(*tasks, return_exceptions=True)
        for idx, ret in enumerate(rets):
            if isinstance(ret, Exception):
                # gather preserves task order, so index idx maps back to the
                # idx-th (server_id, req_outputs) pair of the dict above.
                server_id = list(server_request_outputs.keys())[idx]
                server_info = server_info_dict[server_id]
                logger.info("server {} is dead".format(server_id))
                if self.request_output_queue_type == QueueType.ZMQ:
                    logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
                                                                               server_info.request_output_queue_port))
                req_outputs = list(server_request_outputs.values())[idx]
                request_ids = [req_output.request_id for req_output in req_outputs]
                # Abort the orphaned requests so the engine stops generating
                # output nobody can receive.
                self.engine_actor_handle.abort_request.remote(request_ids)

def init_backend_engine(instance_id: str, request_output_queue_type: QueueType,
backend_type: BackendType, *args, **kwargs) -> BackendInterface:
Expand Down
45 changes: 3 additions & 42 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import time
import traceback
from typing import Any, List, Optional, Dict, Union, Iterable, Tuple, Deque
from typing import Any, List, Optional, Union, Iterable, Tuple, Deque
from collections import defaultdict
import threading
import asyncio
Expand All @@ -38,52 +38,13 @@
from llumnix.backends.profiling import LatencyMemData
from llumnix.server_info import ServerInfo
from llumnix.internal_config import MigrationConfig
from llumnix.queue.queue_client_base import QueueClientBase
from llumnix.queue.utils import init_request_output_queue_client, QueueType
from llumnix.queue.utils import QueueType
from llumnix.backends.utils import AsyncPutQueueActor

logger = init_logger(__name__)

NO_OUTPUTS_STEP_INTERVAL = 0.01


class AsyncPutQueueActor:
    """Ray actor that delivers RequestOutputs to the API servers' output queues.

    Delivery to all servers happens concurrently; requests belonging to a
    server whose queue is unreachable are aborted on the engine actor.
    """

    def __init__(self, instance_id, request_output_queue_type: QueueType):
        self.instance_id = instance_id
        self.request_output_queue_type = request_output_queue_type
        self.request_output_queue_client: QueueClientBase = init_request_output_queue_client(request_output_queue_type)
        # Looked up on demand; the engine actor may not exist yet at construction.
        self.engine_actor_handle = None

    async def put_nowait_to_servers(self,
                                    server_request_outputs: Dict[str, List[RequestOutput]],
                                    server_info_dict: Dict[str, ServerInfo]) -> None:
        """Push each server's outputs to its queue; abort requests of dead servers."""
        try:
            if self.engine_actor_handle is None:
                actor_name = "instance_{}".format(self.instance_id)
                self.engine_actor_handle = ray.get_actor(actor_name, namespace="llumnix")
            put_tasks = []
            for server_id, outputs in server_request_outputs.items():
                target_info = server_info_dict[server_id]
                for output in outputs:
                    # Record when the output left the engine side (latency tracing).
                    if hasattr(output, 'request_timestamps'):
                        output.request_timestamps.engine_actor_put_queue_timestamp = time.time()
                put_tasks.append(asyncio.create_task(
                    self.request_output_queue_client.put_nowait(outputs, target_info)))
            # Collect per-server failures rather than raising on the first one.
            results = await asyncio.gather(*put_tasks, return_exceptions=True)
            # gather preserves order, so results pair up with the dict items.
            for (server_id, outputs), result in zip(server_request_outputs.items(), results):
                if not isinstance(result, Exception):
                    continue
                server_info = server_info_dict[server_id]
                logger.info("server {} is dead".format(server_id))
                if self.request_output_queue_type == QueueType.ZMQ:
                    logger.info("request output queue ip: {}, port: {}".format(server_info.request_output_queue_ip,
                                                                               server_info.request_output_queue_port))
                # Abort orphaned requests so the engine stops producing for them.
                request_ids = [output.request_id for output in outputs]
                self.engine_actor_handle.abort_request.remote(request_ids)
        # pylint: disable=W0703
        except Exception as e:
            logger.error("Error in engine loop: {}".format(e))
            logger.error("exception traceback: {}".format(traceback.format_exc()))


class LLMEngineLlumnix(_AsyncLLMEngine):
def __init__(self,
instance_id: str,
Expand Down
2 changes: 1 addition & 1 deletion llumnix/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
_C.MANAGER.LAST_STAGE_MAX_BLOCKS = 16

# Communication backend of migration
_C.MANAGER.MIGRATION_BACKEND = "grpc"
_C.MANAGER.MIGRATION_BACKEND = "gloo"
# Transfer type for migration backend grpc and kvTransfer
_C.MANAGER.MIGRATION_BACKEND_TRANSFER_TYPE = ""
# Address of grpc server for migration backend
Expand Down
4 changes: 2 additions & 2 deletions llumnix/entrypoints/bladellm/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from llumnix.config import get_llumnix_config, LlumnixConfig
from llumnix.backends.backend_interface import BackendType
from llumnix.arg_utils import LlumnixEntrypointsArgs, EngineManagerArgs, LlumnixArgumentParser
from llumnix.entrypoints.utils import setup_ray_cluster, setup_llumnix, is_gpu_available
from llumnix.entrypoints.setup import setup_ray_cluster, setup_llumnix, is_gpu_available
from llumnix.entrypoints.bladellm.client import LlumnixClientBladeLLM
from llumnix.entrypoints.utils import LlumnixEntrypointsContext
from llumnix.entrypoints.setup import LlumnixEntrypointsContext
from llumnix.entrypoints.bladellm.utils import get_args

def setup_llumnix_api_server(bladellm_args: ServingArgs, loop: asyncio.AbstractEventLoop):
Expand Down
16 changes: 9 additions & 7 deletions llumnix/entrypoints/bladellm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import asyncio
import copy
import random
from typing import Dict

import ray

Expand All @@ -27,7 +28,7 @@
from blade_llm.service.communications.response import error_resp

from llumnix.server_info import RequestTimestamps
from llumnix.entrypoints.utils import LlumnixEntrypointsContext
from llumnix.entrypoints.setup import LlumnixEntrypointsContext
from llumnix.logger import init_logger

logger = init_logger(__name__)
Expand All @@ -40,6 +41,7 @@ def __init__(self, args: ServingArgs, llumnix_context: LlumnixEntrypointsContext
self.entrypoint_id2llumnix_id = {}
self.llumnix_id2entrypoint_id = {}
self.llumnix_context = llumnix_context
self.request_streams: Dict[str, asyncio.Queue] = {}
loop.create_task(self.background_process_outputs())

async def background_process_outputs(self):
Expand All @@ -50,14 +52,14 @@ async def background_process_outputs(self):
for (request_id, request_output) in request_outputs:
request_output = GenerateStreamResponse(**json.loads(request_output))
# Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
if request_id not in self.llumnix_context.request_streams:
if request_id not in self.request_streams:
continue
await self.llumnix_context.request_streams[request_id].put(request_output)
await self.request_streams[request_id].put(request_output)
if request_output.is_finished:
logger.info("Client Recv: {}".format(request_output))
del self.entrypoint_id2llumnix_id[self.llumnix_id2entrypoint_id[request_id]]
del self.llumnix_id2entrypoint_id[request_id]
del self.llumnix_context.request_streams[request_id]
del self.request_streams[request_id]

async def _add_request(self, request: ServerRequest) -> LLMResponse:
if request.sampling_params.n > 1 or request.sampling_params.use_beam_search:
Expand All @@ -74,7 +76,7 @@ async def _manager_generate(self, request, request_id: str) -> LLMResponse:
logger.debug("Client Add request: {}:{}".format(request_id, request))

results_queue = asyncio.Queue()
self.llumnix_context.request_streams[request_id] = results_queue
self.request_streams[request_id] = results_queue

# This request's outputs will be put to the request_output_queue of this api server no matter which instance it's running in.
# If manager is unavailable, request will be directly added to the llumlet held by api server.
Expand Down Expand Up @@ -136,8 +138,8 @@ async def get_stats(self) -> Stats:
async def get_metrics(self) -> str:
pass

def start_profiler(self) -> None:
async def start_profiler(self) -> None:
pass

def stop_profiler(self) -> None:
async def stop_profiler(self) -> None:
pass
5 changes: 3 additions & 2 deletions llumnix/entrypoints/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import sys
import os
import time
from typing import List, Tuple, Dict, Any
from typing import List, Tuple, Dict
import asyncio
import socket
import ray
Expand Down Expand Up @@ -157,8 +157,9 @@ def init_llumlets(engine_manager_args: EngineManagerArgs, engine_args, node_id:
instance_ids: List[str] = []
llumlets: List[Llumlet] = []
instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)]
if 'instance_ids' in kwargs and kwargs['instance_ids'][0]:
if 'instance_ids' in kwargs and kwargs['instance_ids']:
instance_ids = kwargs['instance_ids']
kwargs.pop('instance_ids')
migration_configs = engine_manager_args.create_migration_config()
for idx in range(engine_manager_args.initial_instances):
instance_id = instance_ids[idx]
Expand Down
6 changes: 3 additions & 3 deletions llumnix/llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(self,
self.instance_last_logged_empty = {}

# When manager starts, it automatically connects to all existing instances.
self._connect_to_instances()
asyncio.run_coroutine_threadsafe(self._connect_to_instances(), asyncio.get_event_loop())

async def generate(self, request_id: str, server_info: ServerInfo, *args, **kwargs,) -> None:
while self.num_instances == 0:
Expand Down Expand Up @@ -401,7 +401,7 @@ def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_bac

return self.num_instances

def _connect_to_instances(self):
async def _connect_to_instances(self):
actor_names_dict = ray.util.list_named_actors(True)
instance_actor_names = [actor_name_dict['name'] for actor_name_dict in actor_names_dict if actor_name_dict['name'] != MANAGER_ACTOR_NAME]
instance_actor_handles = [ray.get_actor(actor_name, namespace='llumnix') for actor_name in instance_actor_names]
Expand All @@ -411,7 +411,7 @@ def _connect_to_instances(self):
instance_id = instance_actor_name[len('instance_'):]
if instance_id not in self.instances:
try:
ray.get(instance_actor_handle.is_ready.remote())
await instance_actor_handle.is_ready.remote()
# pylint: disable=W0703
except Exception as e:
logger.info("connect to instance {} abort, which may be not ready or alive, err: {}".format(instance_id, e))
Expand Down
Loading

0 comments on commit cdac2e3

Please sign in to comment.