[Core] Add elastic support for the migration backend
KuilongCui committed Aug 12, 2024
1 parent f07eb64 commit ed62763
Showing 9 changed files with 122 additions and 37 deletions.
5 changes: 2 additions & 3 deletions llumnix/arg_utils.py
@@ -70,16 +70,15 @@ def create_engine_manager_configs(
return global_scheduler_config

def create_migration_configs(
self, instance_rank_map, pp_or_tp_enabled, group_name
self, instance_rank_map, pp_or_tp_enabled
) -> MigrationConfig:
migration_config = MigrationConfig(self.migrate_policy,
self.migration_backend,
self.migration_cache_blocks,
self.last_stage_max_blocks,
self.max_stages,
instance_rank_map,
pp_or_tp_enabled,
group_name)
pp_or_tp_enabled)
return migration_config

@classmethod
7 changes: 2 additions & 5 deletions llumnix/backends/utils.py
@@ -32,9 +32,8 @@ def init_backend_engine(instance_id: str, backend_type: BackendType, *args, **kw
raise ValueError(f'unimplemented backend {backend_type}')
return backend_engine

def initialize_cluster(
world_size: int = 1,
ray_address: Optional[str] = None,
def get_placement_group(
world_size: int = 1
) -> Tuple[str, Optional["PlacementGroup"]]:
"""Initialize the distributed cluster probably with Ray.
@@ -54,8 +53,6 @@ def initialize_cluster(
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=ray_address, ignore_reinit_error=True)

# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
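
Only part of the reworked helper appears in this hunk: initialize_cluster has been renamed to get_placement_group and no longer calls ray.init, since connecting to the cluster now happens in the API server (see api_server.py below). For orientation, a placement-group helper of this shape typically reduces to the sketch below; the bundle layout, strategy, and exact return value are assumptions, not taken from the diff.

from typing import Optional

import ray
from ray.util import placement_group as create_placement_group
from ray.util.placement_group import PlacementGroup


def get_placement_group_sketch(world_size: int = 1) -> Optional[PlacementGroup]:
    # Reuse the placement group this actor/task is already scheduled into, if any.
    current_pg = ray.util.get_current_placement_group()
    if current_pg is not None:
        return current_pg
    # Otherwise reserve one GPU bundle per worker and block until it is granted.
    pg = create_placement_group([{"GPU": 1}] * world_size)
    ray.get(pg.ready())
    return pg
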
31 changes: 17 additions & 14 deletions llumnix/backends/vllm/migrate_backend.py
@@ -13,7 +13,7 @@

class MigrateBackendBase(ABC):
@abstractmethod
def init_col(self, name) -> None:
def init_col(self, name, world_size, rank) -> None:
raise NotImplementedError

@abstractmethod
@@ -65,9 +65,7 @@ def __init__(self, migrate_config: MigrationConfig, cache_engine: CacheEngine,
self.dtype = dtype
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache
self.init_col(None)

def init_col(self, name) -> None:
self.cache_device = "cpu"

self.num_migration_cache_blocks = self.migrate_config.migration_cache_blocks
@@ -82,6 +80,8 @@ def init_col(self, name) -> None:
)

self.migration_stream = torch.cuda.Stream()

def init_col(self, name, world_size, rank) -> None:
logger.info("create rpc migrate backend success.")

def destory_col(self) -> None:
@@ -144,19 +144,14 @@ def __init__(self, migrate_config: MigrationConfig, cache_engine: CacheEngine, r

self.ray_world_size = ray_world_size
self.ray_rank = ray_rank
self.group_name = migrate_config.group_name
self.group_name = None

self.local_rank = local_rank
self.actor = ProxyActor.options(scheduling_strategy=scheduling_strategy).remote()
self.dtype = dtype
self.is_driver_worker = is_driver_worker
self.gpu_cache = gpu_cache

self.init_col(migrate_config.group_name)

def init_col(self, name) -> None:
self.group_name = name

migration_cache_size = self.cache_engine.block_size * self.cache_engine.num_heads * self.cache_engine.head_size

if self.backend == 'gloo':
@@ -176,11 +171,16 @@ def init_col(self, name) -> None:

self.migration_stream = torch.cuda.Stream()

def init_col(self, name, world_size, rank) -> None:
self.group_name = name
self.ray_world_size = world_size
self.ray_rank = rank

col.init_collective_group(world_size=self.ray_world_size, rank=self.ray_rank,
backend=self.backend, group_name=self.group_name)

logger.info("create ray collective group success (group_name:{}, backbend: {})."
.format(self.group_name, self.backend))
logger.info("create ray collective group success (group_name:{}, world_size: {}, rank: {}, backbend: {})."
.format(self.group_name, self.ray_world_size, self.ray_rank, self.backend))

def warmup(self) -> None:
enable_warmup = self.ray_world_size > 1
@@ -193,7 +193,10 @@ def warmup(self) -> None:
self.do_recv(self.ray_rank-1, [0])

def destory_col(self) -> None:
col.destroy_collective_group(self.group_name)
if self.group_name is not None:
col.destroy_collective_group(self.group_name)
logger.info("destory ray collective group success (group_name:{}, backbend: {})."
.format(self.group_name, self.backend))

def migrate_cache(self, src_handle, src_blocks: List[int], dst_blocks: List[int]) -> None:
tot_blocks = len(src_blocks)
@@ -252,7 +255,7 @@ def to_tensor(data):

raise TypeError("Input data must be either a numpy array or a PyTorch tensor")

def get_migrate_collective(migrate_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy, \
def get_migrate_backend(migrate_config: MigrationConfig, cache_engine: CacheEngine, worker_handle_list, scheduling_strategy, \
dtype, is_driver_worker, gpu_cache, ray_world_size, ray_rank, worker_rank, local_rank) -> MigrateBackendBase:
if migrate_config.pp_or_tp_enabled and migrate_config.migration_backend == 'nccl':
logger.warning("NCCL backend is not supported for PP or TP enabled model, using gloo instead.")
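
The interface change above is what enables elastic rebuilds: the collective group is no longer fixed at construction time, so a backend can leave its current group and re-join under a new group name, world size, and rank whenever instances are added or removed. A minimal, self-contained sketch of the intended call order (the dummy class below is illustrative, not the real RayColMigrateBackend):

class DummyMigrateBackend:
    """Toy stand-in for MigrateBackendBase, only to show the rebuild call order."""

    def init_col(self, name, world_size, rank):
        print(f"join group {name} (world_size={world_size}, rank={rank})")

    def destory_col(self):  # spelling kept as in the diff
        print("leave the current group, if any")

    def warmup(self):
        print("ring-style send/recv warmup")


def rebuild_backend(backend, group_name, world_size, rank):
    backend.destory_col()                           # leave the stale group
    backend.init_col(group_name, world_size, rank)  # join the new one
    backend.warmup()                                # verify the new group end to end


rebuild_backend(DummyMigrateBackend(), "group-42", world_size=4, rank=1)

In the actual flow, the first two calls are driven by Worker.rebuild_migrate_backend and the warmup is issued as a separate task by the manager, so a failed warmup can be retried under yet another group name.
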
25 changes: 18 additions & 7 deletions llumnix/backends/vllm/worker.py
@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import Dict, List
import math
import os
import ray
@@ -25,7 +25,7 @@
from llumnix.logger import init_logger
from llumnix.backends.vllm.utils import _sample_with_torch

from llumnix.backends.vllm.migrate_backend import get_migrate_collective
from llumnix.backends.vllm.migrate_backend import get_migrate_backend
from llumnix.config import MigrationConfig

logger = init_logger(__name__)
@@ -104,7 +104,7 @@ def init_migration(self, instance_id: str, migration_config: MigrationConfig, sr
num_instance = len(migration_config.instance_rank_map)
self.ray_world_size = num_instance * self.parallel_config.world_size
self.ray_rank = self.rank + migration_config.instance_rank_map[self.instance_id] * self.parallel_config.world_size
self.migrate_col = get_migrate_collective(migrate_config=migration_config,
self.migrate_backend = get_migrate_backend(migrate_config=migration_config,
cache_engine=self.cache_engine,
worker_handle_list=src_worker_handle_list,
scheduling_strategy=scheduling_strategy,
@@ -115,22 +115,32 @@
ray_rank=self.ray_rank,
worker_rank=self.rank,
local_rank=self.local_rank)
self.migrate_col.warmup()

def migrate_cache(self, src_worker_handle_list, src_blocks: List[int], dst_blocks: List[int]):
try:
src_worker_handle = src_worker_handle_list[self.rank]
self.migrate_col.migrate_cache(src_worker_handle, src_blocks, dst_blocks)
self.migrate_backend.migrate_cache(src_worker_handle, src_blocks, dst_blocks)
except ray.exceptions.RayActorError:
logger.info("[migrate_cache] self.rank: {}, src_worker_handle {} is dead".format(self.rank, src_worker_handle))

def do_recv(self, src_handle, blocks: List[int]):
return self.migrate_col.do_recv(src_handle, blocks=blocks)
return self.migrate_backend.do_recv(src_handle, blocks=blocks)

def do_send(self, dst_handle, blocks: List[int]):
return self.migrate_col.do_send(dst_handle, blocks=blocks)
return self.migrate_backend.do_send(dst_handle, blocks=blocks)

def rebuild_migrate_backend(self, id_rank_map: Dict[str, int], group_name: str) -> None:
self.migrate_backend.destory_col()
num_instance = len(id_rank_map)
self.ray_world_size = num_instance * self.parallel_config.world_size
self.ray_rank = self.rank + id_rank_map[self.instance_id] * self.parallel_config.world_size
return self.migrate_backend.init_col(group_name, self.ray_world_size, self.ray_rank)

def warmup(self):
self.migrate_backend.warmup()

def shutdown(self) -> None:
# self.migrate_backend.destory_col()
torch.cuda.synchronize()
del self.model_runner
del self.cache_engine
@@ -139,5 +149,6 @@ def shutdown(self) -> None:
torch.cuda.reset_max_memory_allocated()

def restart(self) -> None:
# self.migrate_backend.destory_col()
self.init_model()
self.init_cache_engine(self.cache_config)
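
Both init_migration and the new rebuild_migrate_backend derive the worker's global collective rank from the same formula: each instance owns a contiguous block of ranks whose size equals the engine's parallel world size. A small worked example (instance ids and sizes are made up):

# Two instances, each running tensor parallelism across 2 workers.
id_rank_map = {"instance-a": 0, "instance-b": 1}
parallel_world_size = 2

def global_rank(instance_id: str, worker_rank: int) -> int:
    return worker_rank + id_rank_map[instance_id] * parallel_world_size

ray_world_size = len(id_rank_map) * parallel_world_size            # 4
assert [global_rank("instance-a", r) for r in range(2)] == [0, 1]
assert [global_rank("instance-b", r) for r in range(2)] == [2, 3]
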
4 changes: 1 addition & 3 deletions llumnix/config.py
@@ -20,16 +20,14 @@ def __init__(
last_stage_max_blocks: int,
max_stages: int,
instance_rank_map: dict,
pp_or_tp_enabled: bool,
group_name: str) -> None:
pp_or_tp_enabled: bool) -> None:
self.migrate_policy = migrate_policy
self.migration_backend = migration_backend
self.migration_cache_blocks = migration_cache_blocks
self.last_stage_max_blocks = last_stage_max_blocks
self.max_stages = max_stages
self.instance_rank_map = instance_rank_map
self.pp_or_tp_enabled = pp_or_tp_enabled
self.group_name = group_name

class GlobalSchedulerConfig:
def __init__(
2 changes: 1 addition & 1 deletion llumnix/entrypoints/llumnix_utils.py
@@ -126,7 +126,7 @@ def init_llumlets(engine_manager_args: EngineManagerArgs,
instance_ids = [random_uuid() for _ in range(engine_manager_args.initial_instances)]
id_rank_map = {instance_id: index for index, instance_id in enumerate(instance_ids)}
pp_or_tp_enabled = parallel_config.world_size > 1
migration_configs = engine_manager_args.create_migration_configs(id_rank_map, pp_or_tp_enabled, random_uuid())
migration_configs = engine_manager_args.create_migration_configs(id_rank_map, pp_or_tp_enabled)

for idx in range(engine_manager_args.initial_instances):
instance_id = instance_ids[idx]
5 changes: 5 additions & 0 deletions llumnix/entrypoints/vllm/api_server.py
@@ -246,6 +246,11 @@ def set_runtime_env(manager_args: EngineManagerArgs):
if args.launch_ray_cluster:
# Launch the ray cluster for multi-node serving.
launch_ray_cluster(args.ray_cluster_port)
else:
# Connect to a ray cluster.
head_node_ip = os.getenv('HEAD_NODE_IP')
ray.init(address=f"{head_node_ip}:{args.ray_cluster_port}", ignore_reinit_error=True, namespace="llumnix")


# if gpu is not available, it means that this node is head pod without any llumnix components
if is_gpu_available():
76 changes: 74 additions & 2 deletions llumnix/llm_engine_manager.py
@@ -20,6 +20,8 @@
import traceback
import ray

from vllm.utils import random_uuid

from llumnix.llumlet.llumlet import Llumlet
from llumnix.logger import init_logger
from llumnix.global_scheduler.global_scheduler import GlobalScheduler
@@ -68,6 +70,7 @@ def __init__(self,

self.instances: Dict[str, Llumlet] = {}
self.instance_migrating: Dict[str, bool] = {}
self.pending_rebuild_migrate_instances = 0
self.global_scheduler = GlobalScheduler(global_scheduler_config)
# When manager starts, it automatically connects to all existing instances.
self._connect_to_instances()
@@ -248,28 +251,97 @@ async def _migrate(self) -> None:
logger.error("unexpected exception occurs: {}".format(e))
logger.error("exception traceback: {}".format(traceback.format_exc()))

async def rebuild_migrate_backend(self) -> None:
# wait for all instances to finish migration
while any(self.instance_migrating.values()):
await asyncio.sleep(0.1)

# when rebuild migrate backend, disable migrate
origin_config = self.enable_migrate
self.enable_migrate = False

async def run_task(alive_instances: List[str], task: str, *args, **kwargs) -> List:
tasks = []
for instance_name in alive_instances:
llumlet_handle = self.instances[instance_name]
tasks.append(llumlet_handle.execute_engine_method.remote(
"_run_workers", task, *args, **kwargs))

rets = await asyncio.gather(*tasks, return_exceptions=True)
dead_instances = []
for instance_name, ret in zip(alive_instances, rets):
if isinstance(ret, ray.exceptions.RayActorError):
self.scale_down(instance_name, rebuild_migrate_backend=False)
dead_instances.append(instance_name)
logger.info(f"{task} fail, {instance_name}: {ret}")

return dead_instances

alive_instances = sorted(self.instances.keys())
pending_task = self.pending_rebuild_migrate_instances

while len(alive_instances) > 0 and self.pending_rebuild_migrate_instances > 0:
group_name = random_uuid()
id_rank_map = {instance_id: index for index, instance_id in enumerate(alive_instances)}

dead_instances = await run_task(alive_instances, "rebuild_migrate_backend", id_rank_map, group_name)

if len(dead_instances) == 0:
dead_instances.extend(await run_task(alive_instances, "warmup"))

if len(dead_instances) == 0:
self.pending_rebuild_migrate_instances -= pending_task

alive_instances = sorted(self.instances.keys())
pending_task = self.pending_rebuild_migrate_instances

if len(alive_instances) > 0:
logger.info(f"rebuild migrate backend done, group_name: {group_name}, alive instance: {alive_instances}")

# restore migrate config
self.enable_migrate = origin_config

def scale_up(self, instance_id: Union[str, Iterable[str]], llumlet_actor_handles: List["ray.actor.ActorHandle"]) -> None:
if isinstance(instance_id, str):
instance_id = [instance_id,]
instance_ids = list(instance_id)

indeed_update = False
for idx, ins_id in enumerate(instance_ids):
if ins_id not in self.instances:
indeed_update = True
self.instances[ins_id] = llumlet_actor_handles[idx]
self.instance_migrating[ins_id] = False
self.pending_rebuild_migrate_instances += 1
self.global_scheduler.scale_up(instance_ids)
self.num_instance = len(self.instances)

def scale_down(self, instance_id: Union[str, Iterable[str]]) -> None:

# When scaling up, we need to rebuild the migration backend. But if self.pending_rebuild_migrate_instances > 1,
# a coroutine is already handling the membership change. And the coroutine will account for the membership changes
# caused by this scale-up (see rebuild_migrate_backend for details). Therefore, we simply return in this case.
if indeed_update and self.engine_manager_args.migration_backend != "rpc" and \
self.pending_rebuild_migrate_instances == 1:
asyncio.create_task(self.rebuild_migrate_backend())

def scale_down(self, instance_id: Union[str, Iterable[str]], rebuild_migrate_backend: bool = True) -> None:
if isinstance(instance_id, str):
instance_id = [instance_id,]
instance_ids = list(instance_id)

indeed_update = False
for ins_id in instance_ids:
if ins_id in self.instances:
indeed_update = True
del self.instances[ins_id]
del self.instance_migrating[ins_id]
self.pending_rebuild_migrate_instances += 1
self.global_scheduler.scale_down(instance_ids)
self.num_instance = len(self.instances)

if indeed_update and rebuild_migrate_backend and self.engine_manager_args.migration_backend != "rpc" \
and self.pending_rebuild_migrate_instances == 1:
asyncio.create_task(self.rebuild_migrate_backend())

def _connect_to_instances(self):
actor_names_dict = ray.util.list_named_actors(True)
instance_actor_names = [actor_name_dict['name'] for actor_name_dict in actor_names_dict if actor_name_dict['name'] != MANAGER_ACTOR_NAME]
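
The counter pending_rebuild_migrate_instances is what keeps concurrent scale-ups and scale-downs from spawning competing rebuilds: only the first pending change (the one that takes the counter to 1) schedules the rebuild coroutine, and the coroutine then loops, each round under a fresh random group name, until the counter drains back to zero. A condensed, single-process sketch of that protocol (FakeManager and rebuild_group_on are stand-ins for the manager state and the remote "rebuild_migrate_backend" / "warmup" calls, not part of the diff):

import asyncio
import uuid


class FakeManager:
    """Only the fields the rebuild protocol touches."""

    def __init__(self, instance_ids):
        self.instances = {iid: object() for iid in instance_ids}
        self.pending_rebuild_migrate_instances = len(instance_ids)


async def rebuild_group_on(alive_instances, group_name):
    # Stand-in for issuing "rebuild_migrate_backend" and "warmup" to each llumlet;
    # returns False if any instance died during the round.
    await asyncio.sleep(0)
    return True


async def rebuild_until_stable(manager):
    alive = sorted(manager.instances.keys())
    pending = manager.pending_rebuild_migrate_instances
    while alive and manager.pending_rebuild_migrate_instances > 0:
        group_name = uuid.uuid4().hex            # fresh group name per round
        if await rebuild_group_on(alive, group_name):
            # Only the changes seen at the start of this round count as handled;
            # changes that arrived meanwhile keep the counter positive and force
            # another round with a new group name.
            manager.pending_rebuild_migrate_instances -= pending
        alive = sorted(manager.instances.keys())
        pending = manager.pending_rebuild_migrate_instances


asyncio.run(rebuild_until_stable(FakeManager(["instance-a", "instance-b"])))
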
4 changes: 2 additions & 2 deletions llumnix/llumlet/llumlet.py
@@ -20,7 +20,7 @@
from llumnix.logger import init_logger
from llumnix.instance_info import InstanceInfo
from llumnix.backends.backend_interface import BackendInterface, BackendType
from llumnix.backends.utils import init_backend_engine, initialize_cluster
from llumnix.backends.utils import init_backend_engine, get_placement_group
from llumnix.llumlet.migration_coordinator import MigrationCoordinator, MigrationStatus
from llumnix.llumlet.local_migration_scheduler import LocalMigrationScheduler
from llumnix.server_info import ServerInfo
@@ -63,7 +63,7 @@ def from_args(cls,
assert backend_type in [backend_type.VLLM, backend_type.SIM_VLLM], f'unimplemented backend {backend_type}'
if backend_type == backend_type.VLLM:
if not fixed_node_init:
placement_group = initialize_cluster(world_size)
placement_group = get_placement_group(world_size)
kwargs["placement_group"] = placement_group
engine_class = ray.remote(num_cpus=1,
name=f"instance_{instance_id}",
