
Commit

Merge branch 'main' into fix-unittest
s5u13b committed Aug 29, 2024
2 parents edc84a7 + d4cd8fa commit 5f3286d
Showing 5 changed files with 16 additions and 14 deletions.
6 changes: 3 additions & 3 deletions llumnix/arg_utils.py
@@ -59,7 +59,7 @@ class EngineManagerArgs:
     last_stage_max_blocks: int = 16
     max_stages: int = 3
 
-    def create_engine_manager_configs(
+    def create_global_scheduler_configs(
         self,
     ) -> Tuple[GlobalSchedulerConfig]:
         global_scheduler_config = GlobalSchedulerConfig(self.initial_instances,
@@ -94,8 +94,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineManagerArgs':

     @classmethod
     def _check_args(cls, args):
-        assert args.migration_backend == 'gloo' \
-            and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance, \
+        assert args.migration_backend != 'gloo' or (args.migration_backend == 'gloo' \
+            and not args.disable_init_instance_by_manager and not args.disable_fixed_node_init_instance), \
             ("When using gloo as migration backend, "
              "do not set --disable-init-instance-by-manager and --disable-fixed-node-init-instance.")

10 changes: 9 additions & 1 deletion llumnix/backends/vllm/utils.py
@@ -23,6 +23,7 @@
     _modify_greedy_probs_inplace, _beam_search_sample
 
 from llumnix.logger import init_logger
+from llumnix.arg_utils import EngineManagerArgs
 
 logger = init_logger(__name__)

@@ -40,9 +41,16 @@ def detect_unsupported_feature(engine_args: EngineArgs) -> None:
     if unsupported_feature:
         raise ValueError(f'Unsupported feature: Llumnix does not support "{unsupported_feature}" currently.')
 
-def check_engine_args(engine_args: AsyncEngineArgs) -> None:
+def check_engine_args(engine_args: AsyncEngineArgs, engine_manager_args: EngineManagerArgs) -> None:
     assert engine_args.engine_use_ray and engine_args.worker_use_ray, \
         ("In Llumnix, engine and worker must be ray actor.")
+    migration_config = engine_manager_args.create_migration_config()
+    engine_config = engine_args.create_engine_config()
+    parallel_config = engine_config.parallel_config
+    if parallel_config.world_size > 1 and migration_config.migration_backend == 'nccl':
+        # TODO(s5u13b): fix logger
+        print("Llumnix does not support TP or PP enabled model when the migration backend is nccl, change migration backend to gloo.")
+        engine_manager_args.migration_backend = 'gloo'
     detect_unsupported_feature(engine_args)
 
 def _get_dtype_size(dtype: torch.dtype) -> int:
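
The block added to check_engine_args reproduces the fallback this commit removes from worker.py below: when tensor or pipeline parallelism is in use (world_size > 1), an nccl migration backend is downgraded to gloo before any engine is created. A minimal sketch of just that decision rule, with hypothetical helper names not taken from the commit:

def resolve_migration_backend(requested: str, world_size: int) -> str:
    # nccl cannot serve KV-cache migration for TP/PP (world_size > 1) models,
    # so fall back to gloo; all other combinations pass through unchanged.
    if world_size > 1 and requested == 'nccl':
        return 'gloo'
    return requested

assert resolve_migration_backend('nccl', 4) == 'gloo'   # TP/PP enabled: downgraded
assert resolve_migration_backend('nccl', 1) == 'nccl'   # single worker: kept
assert resolve_migration_backend('gloo', 4) == 'gloo'   # gloo passes through
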
7 changes: 1 addition & 6 deletions llumnix/backends/vllm/worker.py
@@ -46,12 +46,7 @@ def get_global_rank(self):
         return self.global_rank
 
     def reserve_memory_for_migration(self, migration_config: MigrationConfig, model_config: ModelConfig,
-                                     cache_config: CacheConfig, parallel_config: ParallelConfig) -> int:
-        # TODO(s5u13b): move this to arguments checker
-        if parallel_config.world_size > 1 and migration_config.migration_backend == 'nccl':
-            logger.warning("nccl backend is not supported for PP or TP enabled model, use gloo instead.")
-            migration_config.migration_backend = 'gloo'
-
+                                     cache_config: CacheConfig, parallel_config: ParallelConfig) -> int:
         migrate_cache_blocks_size = migration_config.migration_cache_blocks
         migrate_num_layers = migration_config.migration_num_layers
         dummy_cache_size = migrate_num_layers * migrate_cache_blocks_size * CacheEngine.get_cache_block_size(
5 changes: 2 additions & 3 deletions llumnix/entrypoints/vllm/api_server.py
@@ -83,8 +83,7 @@ async def manager_generate(prompt, sampling_params, request_id) -> AsyncStream:
     try:
         # await to catch exception
         await engine_manager.generate.remote(request_id, server_info, prompt, sampling_params)
-        if not manager_available:
-            manager_available = True
+        manager_available = True
     except ray.exceptions.RayActorError:
         # Do not re-generate the request to avoid duplicate requests.
         if manager_available:
@@ -243,7 +242,7 @@ async def is_ready():
     engine_manager_args = EngineManagerArgs.from_cli_args(args)
     engine_args = AsyncEngineArgs.from_cli_args(args)
 
-    check_engine_args(engine_args)
+    check_engine_args(engine_args, engine_manager_args)
 
     print("engine_args: {}".format(engine_args))

2 changes: 1 addition & 1 deletion llumnix/llm_engine_manager.py
@@ -410,7 +410,7 @@ async def _check_instance_error(self, migrate_instance_pairs: Tuple[str, str]) -
     def from_args(cls,
                   engine_manager_args: EngineManagerArgs,
                   profiling_database: ProfilingDatabase=None) -> "LLMEngineManager":
-        global_scheduler_config = engine_manager_args.create_engine_manager_configs()
+        global_scheduler_config = engine_manager_args.create_global_scheduler_configs()
         # Init manager actor in 'llumnix' namespace to ensure that only one manager can be created.
         manager_class = ray.remote(num_cpus=0,
                                    max_restarts=-1,
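
Taken together with the api_server.py change above, call order in the entrypoint becomes significant: check_engine_args mutates engine_manager_args in place, so it must run before from_args turns those args into the global scheduler config. A simplified sketch of that flow, assuming module paths inferred from the file names in this commit and omitting Ray setup and the optional profiling_database argument:

from vllm.engine.arg_utils import AsyncEngineArgs
from llumnix.arg_utils import EngineManagerArgs
from llumnix.backends.vllm.utils import check_engine_args
from llumnix.llm_engine_manager import LLMEngineManager

engine_manager_args = EngineManagerArgs.from_cli_args(args)  # args: parsed CLI namespace
engine_args = AsyncEngineArgs.from_cli_args(args)

# May rewrite engine_manager_args.migration_backend from 'nccl' to 'gloo'.
check_engine_args(engine_args, engine_manager_args)

# from_args (which calls create_global_scheduler_configs internally) now sees
# the final backend choice.
engine_manager = LLMEngineManager.from_args(engine_manager_args)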
