From 8b9050d8b199d5a9d747a26e2c5d2e181d0c5cea Mon Sep 17 00:00:00 2001 From: root Date: Tue, 5 Nov 2024 02:46:39 +0000 Subject: [PATCH] fix --- .../test_llm_engine_manager.py | 62 ++++++++++++++----- .../llumlet/test_local_migration_scheduler.py | 3 + 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py index b744ced6..06ecc836 100644 --- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py +++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py @@ -40,6 +40,7 @@ def __init__(self, instance_id): self.request_id_set = set() self.instance_info = None self.num_migrate_out = 0 + self.num_migrate_in = 0 def get_instance_id(self) -> str: return self.instance_id @@ -75,12 +76,22 @@ def abort(self, request_id): self.num_requests = len(self.request_id_set) return self.num_requests - def migrate_out(self, src_instance_name, dst_instance_name): + def migrate_out(self, dst_instance_name, num_requests): self.num_migrate_out += 1 + migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix') + ray.get(migrate_in_ray_actor.migrate_in.remote(self.actor_name, num_requests)) + time.sleep(0.1) + return self.num_migrate_out + + def migrate_in(self, src_instance_name, num_requests): + self.num_migrate_in += 1 + return self.num_migrate_in def get_num_migrate_out(self): return self.num_migrate_out + def get_num_migrate_in(self): + return self.num_migrate_in def init_manager(): try: @@ -222,20 +233,37 @@ def get_instance_info_migrate_out(instance_id): return instance_info def test_update_instance_info_loop_and_migrate(setup_ray_env, engine_manager): - instance_ids, llumlets = init_llumlets(2) - instance_id, instance_id_1 = instance_ids[0], instance_ids[1] - llumlet, llumlet_1 = llumlets[0], llumlets[1] - request_id = random_uuid() - request_id_1 = random_uuid() - ray.get(llumlet.generate.remote(request_id, None, math.inf, None, None)) - ray.get(llumlet_1.generate.remote(request_id_1, None, math.inf, None, None)) - instance_info_migrate_out = get_instance_info_migrate_out(instance_id) - instance_info_migrate_in = get_instance_info_migrate_in(instance_id_1) - ray.get(llumlet.set_instance_info.remote(instance_info_migrate_out)) - ray.get(llumlet_1.set_instance_info.remote(instance_info_migrate_in)) - num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote()) - assert num_migrate_out == 0 + num_llumlets = 5 + instance_ids, llumlets = init_llumlets(num_llumlets) + + for i in range(num_llumlets): + for _ in range(2*(i+1)): + ray.get(llumlets[i].generate.remote(random_uuid(), None, math.inf, None, None)) + + instance_info = InstanceInfo() + instance_info.instance_type = InstanceType.NO_CONSTRAINTS + + for i in range(num_llumlets): + instance_info.instance_id = instance_ids[i] + instance_info.num_available_gpu_blocks = 40 - i * 10 + instance_info.num_running_requests = i + instance_info.num_blocks_first_waiting_request = i + ray.get(llumlets[i].set_instance_info.remote(instance_info)) + + for i in range(num_llumlets): + num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote()) + assert num_migrate_out == 0 + ray.get(engine_manager.scale_up.remote(instance_ids, llumlets)) - time.sleep(0.5) - num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote()) - assert num_migrate_out != 0 + time.sleep(2) + + for i in range(num_llumlets): + num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote()) + num_migrate_in = ray.get(llumlets[i].get_num_migrate_in.remote()) + + if i == 0: + assert num_migrate_in > 1 and num_migrate_out == 0 + elif i == num_llumlets - 1: + assert num_migrate_in == 0 and num_migrate_out > 1 + else: + assert num_migrate_in == 0 and num_migrate_out == 0 diff --git a/tests/unit_test/llumlet/test_local_migration_scheduler.py b/tests/unit_test/llumlet/test_local_migration_scheduler.py index 2db1c304..adf96810 100644 --- a/tests/unit_test/llumlet/test_local_migration_scheduler.py +++ b/tests/unit_test/llumlet/test_local_migration_scheduler.py @@ -67,11 +67,14 @@ def test_scheduler_policy(): engine.add_request(request_id="3", length=2, expected_steps=1) engine.add_request(request_id="4", length=3, expected_steps=math.inf) + engine.add_request(request_id="5", length=4, expected_steps=math.inf) scheduler.request_migration_policy = "LCFS" request = scheduler.get_migrate_out_request() assert request.request_id == "3" assert request.output_len >= request.expected_steps and request.inference_type == RequestInferenceType.DECODE request = scheduler.get_migrate_out_request() + assert request.request_id == "5" + request = scheduler.get_migrate_out_request() assert request.request_id == "4" def test_scheduler_should_abort_migration():