class MockBackendSim(BackendSimVLLM):
    """Simulator backend for unit tests that skips the latency-profile lookup."""

    # NOTE: the misspelt method name is intentional; it has to match the hook
    # that BackendSimVLLM.__init__ invokes, otherwise the override is ignored.
    def _get_lantecy_mem(self, *args, **kwargs):
        """Return a placeholder ``LatencyMemData``.

        Every positional and keyword argument is accepted and ignored; the
        returned object is built from three empty dicts and carries all-zero
        prefill/decode model parameters.
        """
        stub_profile = LatencyMemData({}, {}, {})
        stub_profile.prefill_model_params = (0, 0)
        stub_profile.decode_model_params = (0, 0, 0)
        return stub_profile
def test_init_llumlets_sim(setup_ray_env, engine_manager):
    """Initialise llumlets with the simulator backend mocked, then scale up.

    The test sets ``profiling_result_file_path`` on ``engine_manager``, swaps
    ``BackendSimVLLM`` in ``llumnix.backends.vllm.simulator`` for
    ``MockBackendSim``, and asserts that ``scale_up`` reports as many instances
    as ``EngineManagerArgs().initial_instances``.

    The module attribute is restored in a ``finally`` clause so the mock does
    not leak into tests that run later in the same process.
    """
    # NOTE(review): engine_manager is driven through ``.remote()`` below, so it
    # looks like a Ray actor handle. If so, neither this attribute assignment
    # nor the module patch made in this process is guaranteed to be visible
    # inside the actor -- confirm the simulator path is really exercised.
    engine_manager.profiling_result_file_path="//"
    # pylint: disable=import-outside-toplevel
    import llumnix.backends.vllm.simulator
    simulator_module = llumnix.backends.vllm.simulator
    # Remember the real class so the global patch can be undone afterwards.
    real_backend_cls = simulator_module.BackendSimVLLM
    simulator_module.BackendSimVLLM = MockBackendSim
    try:
        engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
        node_id = ray.get_runtime_context().get_node_id()
        instance_ids, llumlets = ray.get(
            engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue")))
        num_instances = ray.get(engine_manager.scale_up.remote(instance_ids, llumlets))
        engine_manager_args = EngineManagerArgs()
        assert num_instances == engine_manager_args.initial_instances
    finally:
        # Undo the patch even when the assertion or a Ray call fails.
        simulator_module.BackendSimVLLM = real_backend_cls