Skip to content

Commit

Permalink
[BugFix] Support initialize llumlet by manager (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
s5u13b authored Aug 21, 2024
1 parent e10bcec commit 309c296
Show file tree
Hide file tree
Showing 13 changed files with 157 additions and 50 deletions.
8 changes: 6 additions & 2 deletions docs/Arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ Note: since Llumnix is still in alpha stage, the interface and arguments are *su

```
usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--fixed-node-init]
[--fixed-node-init-instance]
[--init-instance-by-manager]
[--initial-instances INITIAL_INSTANCES]
[--load-metric {consumed_speed,used_ratio}]
[--polling-interval POLLING_INTERVAL]
Expand Down Expand Up @@ -35,9 +36,12 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--max-stages MAX_STAGES]
```

`--fixed-node-init`
`--fixed-node-init-instance`
- Fix the placement of instance to current node.

`--init-instance-by-manager`
- Initialize the instance by the manager.

`--initial-instances`
- Number of model instances created at initialization.
- Default: 1
Expand Down
8 changes: 6 additions & 2 deletions llumnix/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
@dataclass
class EngineManagerArgs:
launch_ray_cluster: bool = True
init_instance_by_manager: bool = True
initial_instances: int = 1
fixed_node_init: bool = False
fixed_node_init_instance: bool = False

load_metric: str = 'consumed_speed'
polling_interval: float = 0.05
Expand Down Expand Up @@ -92,9 +93,12 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineManagerArgs':
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument('--fixed-node-init',
parser.add_argument('--fixed-node-init-instance',
action='store_true',
help='fix the placement of instance to current node')
parser.add_argument('--init-instance-by-manager',
action='store_true',
help='initialize instance by manager')
parser.add_argument('--initial-instances',
type=int,
default=EngineManagerArgs.initial_instances,
Expand Down
6 changes: 4 additions & 2 deletions llumnix/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kw
def initialize_cluster(
world_size: int = 1,
ray_address: Optional[str] = None,
detached: bool = False,
) -> Tuple[str, Optional["PlacementGroup"]]:
"""Initialize the distributed cluster probably with Ray.
Expand All @@ -55,8 +56,9 @@ def initialize_cluster(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=ray_address, ignore_reinit_error=True)
ray.init(address=ray_address, ignore_reinit_error=True, namespace='llumnix')

lifetime = "detached" if detached else None
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
Expand Down Expand Up @@ -84,7 +86,7 @@ def initialize_cluster(
# Create a new placement group
placement_group_specs = ([{"CPU": 1}] + [{"GPU": 1}] * world_size)
current_placement_group = ray.util.placement_group(
placement_group_specs, "STRICT_PACK")
placement_group_specs, "STRICT_PACK", lifetime=lifetime)
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
Expand Down
5 changes: 4 additions & 1 deletion llumnix/backends/vllm/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
logger = init_logger(__name__)

class LlumnixRayGPUExecutor(RayGPUExecutor):
node_id: str = None

def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
self.last_inference_latency = 0
Expand All @@ -56,6 +58,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",

# Create the workers.
driver_ip = get_ip()
node_id = self.node_id
for rank in range(self.parallel_config.world_size):
if placement_group:
bundle = placement_group.bundle_specs[rank+1]
Expand All @@ -67,7 +70,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
)
else:
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
node_id=node_id,
soft=False,
)
worker = ray.remote(
Expand Down
18 changes: 13 additions & 5 deletions llumnix/backends/vllm/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def from_engine_args(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
instance_id: str = None,
placement_group: Optional["PlacementGroup"] = None,
node_id: str = None,
latency_mem: Optional[LatencyMemData] = None
) -> "LLMEngineLlumnix":
"""Creates an LLM engine from the engine arguments."""
Expand All @@ -77,6 +78,9 @@ def from_engine_args(
executor_class = LlumnixRayGPUExecutor
else:
raise ValueError('unimplemented executor backend')
# TODO(s5u13b): Do not hack here.
# Hack to pass node_id to _init_workers_ray function.
executor_class.node_id = node_id
# Create the LLM engine.
engine = cls(
instance_id=instance_id,
Expand Down Expand Up @@ -177,11 +181,14 @@ def __init__(
instance_id: int,
migration_config: MigrationConfig,
engine_args: EngineArgs,
placement_group: "PlacementGroup"
placement_group: "PlacementGroup" = None,
node_id: str = None
) -> None:
assert migration_config.migration_backend == "rpc", "Gloo support will be released later."
self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args, instance_id=instance_id,
placement_group=placement_group)
self.engine: LLMEngineLlumnix = LLMEngineLlumnix.from_engine_args(engine_args=engine_args,
instance_id=instance_id,
placement_group=placement_group,
node_id=node_id)
# multi-instance args
self.engine.scheduler = SchedulerLlumnix(self.engine.scheduler_config, self.engine.cache_config, self.engine.lora_config)
self.engine.output_processor.scheduler = self.engine.scheduler
Expand All @@ -190,8 +197,9 @@ def __init__(
if len(self.worker_handle_list) + 1 == self.engine.parallel_config.world_size:
self.worker_handle_list.insert(0, ray.get_actor(f"instance_{self.instance_id}", namespace="llumnix"))
self._run_workers("init_migration", num_migration_cache_blocks=migration_config.migration_cache_blocks,\
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group)
src_worker_handle_list=self.worker_handle_list,
placement_group=placement_group,
node_id=node_id)
self._thread = threading.Thread(
target=self._start_engine_loop, args=(), daemon=True, name="engine_loop"
)
Expand Down
6 changes: 3 additions & 3 deletions llumnix/backends/vllm/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ def load_model(self):
torch.cuda.set_device(self.device)
return super().load_model()

def init_migration(self, num_migration_cache_blocks: int, src_worker_handle_list, placement_group=None) -> None:
def init_migration(self, num_migration_cache_blocks: int, src_worker_handle_list, placement_group=None, node_id=None) -> None:
if placement_group:
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
)
else:
scheduling_strategy = NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
node_id=node_id,
soft=False,
)
self.recv_actor = RecvActor.options(scheduling_strategy=scheduling_strategy).remote()
Expand Down Expand Up @@ -147,8 +147,8 @@ def recv_cpu_cache_v2(self, blocks: List[int], rpc_numpy_cache):
self.cache_engine.attn_backend.swap_blocks(self.migration_cache[layer_idx], self.gpu_cache[layer_idx],src_to_dst)
torch.cuda.Stream.synchronize(self.migration_stream)


def migrate_gpu_cache_ray_rpc(self, src_worker_handle_list, src_blocks: List[int], dst_blocks: List[int]):
# TODO(s5u13b): Raise exception here.
try:
src_worker_handle = src_worker_handle_list[self.rank]
tot_blocks = len(src_blocks)
Expand Down
49 changes: 28 additions & 21 deletions llumnix/entrypoints/llumnix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from ray.util.queue import Queue as RayQueue
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from vllm.utils import random_uuid
from vllm.engine.arg_utils import AsyncEngineArgs

from llumnix.utils import random_uuid
from llumnix.llm_engine_manager import LLMEngineManager, MANAGER_ACTOR_NAME
from llumnix.llumlet.llumlet import Llumlet
from llumnix.backends.backend_interface import BackendType
Expand Down Expand Up @@ -106,8 +104,6 @@ async def retry_manager_method_async(ray_call, method_name, *args, **kwargs):

def init_manager(engine_manager_args: EngineManagerArgs) -> LLMEngineManager:
# Only one instance create the manager actor, the other instances get the existing manager actor through ray.
# if 'HEAD_NODE' in os.environ:
# time.sleep(20)
try:
engine_manager = LLMEngineManager.from_args(engine_manager_args, None)
logger.info("Init LLMEngineManager on current node")
Expand All @@ -117,7 +113,8 @@ def init_manager(engine_manager_args: EngineManagerArgs) -> LLMEngineManager:
return engine_manager

def init_llumlets(engine_manager_args: EngineManagerArgs,
engine_args: AsyncEngineArgs) -> Tuple[List[str], List[Llumlet]]:
engine_args,
node_id: str) -> Tuple[List[str], List[Llumlet]]:
engine_config = engine_args.create_engine_config()
parallel_config = engine_config.parallel_config
instance_ids: List[str] = []
Expand All @@ -126,7 +123,9 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,
instance_id = random_uuid()
if not engine_manager_args.profiling_result_file_path:
llumlet = Llumlet.from_args(
engine_manager_args.fixed_node_init,
engine_manager_args.fixed_node_init_instance,
False,
node_id,
instance_id,
BackendType.VLLM,
parallel_config.world_size,
Expand All @@ -135,7 +134,9 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,
)
else:
llumlet = Llumlet.from_args(
engine_manager_args.fixed_node_init,
engine_manager_args.fixed_node_init_instance,
False,
node_id,
instance_id,
BackendType.SIM_VLLM,
parallel_config.world_size,
Expand All @@ -158,23 +159,29 @@ def init_request_output_queue() -> RayQueue:
return request_output_queue

def init_llumnix_components(engine_manager_args: EngineManagerArgs,
engine_args: AsyncEngineArgs) -> Tuple[LLMEngineManager, List[Llumlet], RayQueue]:
engine_args,
node_id: str) -> Tuple[LLMEngineManager, List[Llumlet], RayQueue]:
assert engine_args.engine_use_ray and engine_args.worker_use_ray, \
("In Llumnix, engine and worker must be ray actor in orther to run step and migrate concurrently.")

engine_manager = init_manager(engine_manager_args)
logger.info("Init LLMEngineManager done")
instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args)
# TODO(s5u13b): Add arguments checker for Llumnix.
if not engine_manager_args.init_instance_by_manager:
assert engine_manager_args.migration_backend != 'gloo', \
("Llumlet should be initialized by manager when using gloo as migration backend for auto-scaling, "
"please set --init-instance-by-manager argument.")
instance_ids, llumlets = init_llumlets(engine_manager_args, engine_args, node_id)
retry_manager_method_sync(engine_manager.scale_up.remote, 'scale_up', instance_ids, llumlets)
else:
instance_ids, llumlets = retry_manager_method_sync(engine_manager.init_llumlets.remote, 'init_llumlets', engine_args, node_id)
request_output_queue = init_request_output_queue()
logger.info("Init request_output_queue done")
ray.get([llumlet.is_ready.remote() for llumlet in llumlets])
logger.info("Init Llumlets done")
retry_manager_method_sync(engine_manager.scale_up.remote, 'scale_up', instance_ids, llumlets)
logger.info("Scale up instance done")
# We now call run_engine_loop after llumlet's creation.
# for llumlet in llumlets:
# llumlet.run_engine_loop.remote()

try:
ray.get([llumlet.is_ready.remote() for llumlet in llumlets])
except ray.exceptions.RayActorError:
for idx, llumlet in enumerate(llumlets):
try:
ray.get(llumlet.is_ready.remote())
except ray.exceptions.RayActorError:
retry_manager_method_sync(engine_manager.scale_down.remote, 'scale_down', instance_ids[idx])
logger.info("Init Llumnix components done")

return engine_manager, instance_ids, llumlets, request_output_queue
7 changes: 4 additions & 3 deletions llumnix/entrypoints/vllm/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from vllm.sampling_params import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncStream
from vllm.utils import random_uuid

from llumnix.utils import random_uuid
from llumnix.arg_utils import EngineManagerArgs
from llumnix.server_info import ServerInfo
from llumnix.entrypoints.llumnix_utils import launch_ray_cluster, is_gpu_available, init_llumnix_components
Expand Down Expand Up @@ -235,7 +235,7 @@ async def is_ready():
engine_manager_args = EngineManagerArgs.from_cli_args(args)
engine_args = AsyncEngineArgs.from_cli_args(args)

logger.info("engine_args: {}".format(engine_args))
print("engine_args: {}".format(engine_args))

if args.launch_ray_cluster:
# Launch the ray cluster for multi-node serving.
Expand All @@ -245,7 +245,8 @@ async def is_ready():
if is_gpu_available():
    # Launch the Llumnix components on current node.
server_id = random_uuid()
engine_manager, instance_ids, llumlets, request_output_queue = init_llumnix_components(engine_manager_args, engine_args)
node_id = ray.get_runtime_context().get_node_id()
engine_manager, instance_ids, llumlets, request_output_queue = init_llumnix_components(engine_manager_args, engine_args, node_id)
for idx, ins_id in enumerate(instance_ids):
instances[ins_id] = llumlets[idx]
instance_num_request[ins_id] = 0
Expand Down
41 changes: 41 additions & 0 deletions llumnix/llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from llumnix.arg_utils import EngineManagerArgs
from llumnix.backends.profiling import ProfilingDatabase
from llumnix.server_info import ServerInfo
from llumnix.backends.backend_interface import BackendType
from llumnix.utils import random_uuid


logger = init_logger(__name__)
Expand Down Expand Up @@ -316,6 +318,45 @@ def from_args(cls,
logger.info("engine_manager_args: {}".format(engine_manager_args))
return engine_manager

def init_llumlets(self,
                  engine_args,
                  node_id: str) -> Tuple[List[str], List[Llumlet]]:
    """Create the initial Llumlet actors on the given node and register them.

    Spawns ``initial_instances`` Llumlets pinned to *node_id*. When a
    profiling result file is configured, simulated backends
    (``BackendType.SIM_VLLM``) are created instead of real vLLM backends.
    The new instances are immediately registered via ``scale_up``.

    Args:
        engine_args: Engine arguments; only used to derive the parallel
            configuration (world size) and forwarded to each Llumlet.
        node_id: Ray node id the instances are pinned to.

    Returns:
        A pair ``(instance_ids, llumlets)`` of the newly created instance
        ids and their Llumlet actor handles.
    """
    manager_args = self.engine_manager_args
    parallel_config = engine_args.create_engine_config().parallel_config
    # Simulated backend is selected solely by the presence of a profiling
    # result file.
    use_sim_backend = bool(manager_args.profiling_result_file_path)

    instance_ids: List[str] = []
    llumlets: List[Llumlet] = []
    for _ in range(manager_args.initial_instances):
        instance_id = random_uuid()
        if use_sim_backend:
            llumlet = Llumlet.from_args(
                manager_args.fixed_node_init_instance,
                True,
                node_id,
                instance_id,
                BackendType.SIM_VLLM,
                parallel_config.world_size,
                manager_args.create_migration_configs(),
                manager_args.profiling_result_file_path,
                manager_args.gpu_type,
                engine_args,
            )
        else:
            llumlet = Llumlet.from_args(
                manager_args.fixed_node_init_instance,
                True,
                node_id,
                instance_id,
                BackendType.VLLM,
                parallel_config.world_size,
                manager_args.create_migration_configs(),
                engine_args,
            )
        instance_ids.append(instance_id)
        llumlets.append(llumlet)

    # Register the new instances with the manager's scaling bookkeeping.
    self.scale_up(instance_ids, llumlets)
    return instance_ids, llumlets

def get_actor_name(self) -> str:
    """Return the Ray actor name this engine manager is registered under."""
    return self.actor_name

Expand Down
Loading

0 comments on commit 309c296

Please sign in to comment.