diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py
index b744ced6..7de04c8a 100644
--- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py
+++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py
@@ -26,6 +26,7 @@
 from llumnix.server_info import ServerInfo
 from llumnix.queue.queue_type import QueueType
 from llumnix.global_scheduler.scaling_scheduler import InstanceType
+from llumnix.backends.vllm.simulator import BackendSimVLLM
 
 # pylint: disable=unused-import
 from tests.conftest import setup_ray_env
@@ -81,6 +82,22 @@ def migrate_out(self, src_instance_name, dst_instance_name):
     def get_num_migrate_out(self):
         return self.num_migrate_out
 
+class MockBackendSim(BackendSimVLLM):
+    """BackendSimVLLM variant whose latency/memory lookup returns zeroed stub data."""
+
+    # NOTE(review): the method name (including its spelling) presumably has to match the
+    # hook defined on BackendSimVLLM for this override to take effect - confirm there.
+    def _get_lantecy_mem(self, *args, **kwargs):
+        """Return a LatencyMemData built from empty dicts with all-zero model params.
+
+        All positional and keyword arguments are accepted and ignored.
+        """
+        # NOTE(review): LatencyMemData is not imported by the hunks visible in this
+        # patch - confirm it is in scope elsewhere in the test module.
+        latency_mem = LatencyMemData({}, {}, {})
+        latency_mem.prefill_model_params = (0, 0)
+        latency_mem.decode_model_params = (0, 0, 0)
+        return latency_mem
 
 def init_manager():
     try:
@@ -138,6 +155,27 @@ def test_init_llumlets(setup_ray_env, engine_manager):
     engine_manager_args = EngineManagerArgs()
     assert num_instances == engine_manager_args.initial_instances
 
+def test_init_llumlets_sim(setup_ray_env, engine_manager):
+    """Init llumlets with BackendSimVLLM replaced by MockBackendSim, then scale up."""
+    # NOTE(review): assumes the manager attribute is spelled profiling_result_file_path, and
+    # that setting it on this handle (used via .remote() below) reaches the manager - confirm.
+    engine_manager.profiling_result_file_path = "//"
+    # pylint: disable=import-outside-toplevel
+    import llumnix.backends.vllm.simulator
+    original_backend_cls = llumnix.backends.vllm.simulator.BackendSimVLLM
+    llumnix.backends.vllm.simulator.BackendSimVLLM = MockBackendSim
+    try:
+        engine_args = EngineArgs(model="facebook/opt-125m", worker_use_ray=True)
+        node_id = ray.get_runtime_context().get_node_id()
+        instance_ids, llumlets = ray.get(
+            engine_manager.init_llumlets.remote(engine_args, node_id, QueueType("rayqueue")))
+        num_instances = ray.get(engine_manager.scale_up.remote(instance_ids, llumlets))
+    finally:
+        # Undo the module-level patch so it cannot leak into tests that run later.
+        llumnix.backends.vllm.simulator.BackendSimVLLM = original_backend_cls
+    engine_manager_args = EngineManagerArgs()
+    assert num_instances == engine_manager_args.initial_instances
+
 def test_scale_up_and_down(setup_ray_env, engine_manager):
     initial_instances = 4
     instance_ids, llumlets = init_llumlets(initial_instances)