Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
KuilongCui committed Nov 5, 2024
1 parent 68c3a9a commit 8b9050d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
62 changes: 45 additions & 17 deletions tests/unit_test/global_scheduler/test_llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, instance_id):
self.request_id_set = set()
self.instance_info = None
self.num_migrate_out = 0
self.num_migrate_in = 0

def get_instance_id(self) -> str:
return self.instance_id
Expand Down Expand Up @@ -75,12 +76,22 @@ def abort(self, request_id):
self.num_requests = len(self.request_id_set)
return self.num_requests

def migrate_out(self, src_instance_name, dst_instance_name):
def migrate_out(self, dst_instance_name, num_requests):
self.num_migrate_out += 1
migrate_in_ray_actor = ray.get_actor(dst_instance_name, namespace='llumnix')
ray.get(migrate_in_ray_actor.migrate_in.remote(self.actor_name, num_requests))
time.sleep(0.1)
return self.num_migrate_out

def migrate_in(self, src_instance_name, num_requests):
self.num_migrate_in += 1
return self.num_migrate_in

def get_num_migrate_out(self):
return self.num_migrate_out

def get_num_migrate_in(self):
return self.num_migrate_in

def init_manager():
try:
Expand Down Expand Up @@ -222,20 +233,37 @@ def get_instance_info_migrate_out(instance_id):
return instance_info

def test_update_instance_info_loop_and_migrate(setup_ray_env, engine_manager):
instance_ids, llumlets = init_llumlets(2)
instance_id, instance_id_1 = instance_ids[0], instance_ids[1]
llumlet, llumlet_1 = llumlets[0], llumlets[1]
request_id = random_uuid()
request_id_1 = random_uuid()
ray.get(llumlet.generate.remote(request_id, None, math.inf, None, None))
ray.get(llumlet_1.generate.remote(request_id_1, None, math.inf, None, None))
instance_info_migrate_out = get_instance_info_migrate_out(instance_id)
instance_info_migrate_in = get_instance_info_migrate_in(instance_id_1)
ray.get(llumlet.set_instance_info.remote(instance_info_migrate_out))
ray.get(llumlet_1.set_instance_info.remote(instance_info_migrate_in))
num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote())
assert num_migrate_out == 0
num_llumlets = 5
instance_ids, llumlets = init_llumlets(num_llumlets)

for i in range(num_llumlets):
for _ in range(2*(i+1)):
ray.get(llumlets[i].generate.remote(random_uuid(), None, math.inf, None, None))

instance_info = InstanceInfo()
instance_info.instance_type = InstanceType.NO_CONSTRAINTS

for i in range(num_llumlets):
instance_info.instance_id = instance_ids[i]
instance_info.num_available_gpu_blocks = 40 - i * 10
instance_info.num_running_requests = i
instance_info.num_blocks_first_waiting_request = i
ray.get(llumlets[i].set_instance_info.remote(instance_info))

for i in range(num_llumlets):
num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote())
assert num_migrate_out == 0

ray.get(engine_manager.scale_up.remote(instance_ids, llumlets))
time.sleep(0.5)
num_migrate_out = ray.get(llumlet.get_num_migrate_out.remote())
assert num_migrate_out != 0
time.sleep(2)

for i in range(num_llumlets):
num_migrate_out = ray.get(llumlets[i].get_num_migrate_out.remote())
num_migrate_in = ray.get(llumlets[i].get_num_migrate_in.remote())

if i == 0:
assert num_migrate_in > 1 and num_migrate_out == 0
elif i == num_llumlets - 1:
assert num_migrate_in == 0 and num_migrate_out > 1
else:
assert num_migrate_in == 0 and num_migrate_out == 0
3 changes: 3 additions & 0 deletions tests/unit_test/llumlet/test_local_migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ def test_scheduler_policy():

engine.add_request(request_id="3", length=2, expected_steps=1)
engine.add_request(request_id="4", length=3, expected_steps=math.inf)
engine.add_request(request_id="5", length=4, expected_steps=math.inf)
scheduler.request_migration_policy = "LCFS"
request = scheduler.get_migrate_out_request()
assert request.request_id == "3"
assert request.output_len >= request.expected_steps and request.inference_type == RequestInferenceType.DECODE
request = scheduler.get_migrate_out_request()
assert request.request_id == "5"
request = scheduler.get_migrate_out_request()
assert request.request_id == "4"

def test_scheduler_should_abort_migration():
Expand Down

0 comments on commit 8b9050d

Please sign in to comment.