Skip to content

Commit

Permalink
resolve conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyi-ECNU committed Sep 11, 2024
1 parent 1810a2e commit c921ef9
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion llumnix/global_scheduler/dispatch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def update_instance_infos(self,
def add_instance(self, instance_id: str) -> None:
self.instance_id_set.add(instance_id)
self.num_instances = len(self.instance_id_set)
- if self.num_dispatch_instances == -1 or (self.num_dispatch_instances > 0 and
+ if self.num_dispatch_instances <= 0 or (self.num_dispatch_instances > 0 and
len(self.available_dispatch_instance_set) < self.num_dispatch_instances):
self.available_dispatch_instance_set.add(instance_id)
self.instance_num_requests[instance_id] = 0
Expand Down
8 changes: 3 additions & 5 deletions llumnix/global_scheduler/scaling_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class InstanceType(str, Enum):
NO_CONSTRAINTS = "NO_CONSTRAINTS"

# Specific to Prefill-Decoding disaggregation.
- PREFILL = "prefill"
- DECODE = "decode"
+ PREFILL = "PREFILL"
+ DECODE = "DECODE"


class ScalingScheduler:
Expand Down Expand Up @@ -79,14 +79,12 @@ def add_instance(self, instance_id: str) -> None:
instance_type = None
if self.maximum_prefill_instance_num > 0:
if len(self.instance_type_id_set[InstanceType.PREFILL]) < self.maximum_prefill_instance_num:
- self.instance_type_id_set[InstanceType.PREFILL].add(instance_id)
instance_type = InstanceType.PREFILL
else:
- self.instance_type_id_set[InstanceType.DECODE].add(instance_id)
instance_type = InstanceType.DECODE
else:
- self.instance_type_id_set[InstanceType.NO_CONSTRAINTS].add(instance_id)
instance_type = InstanceType.NO_CONSTRAINTS
+ self.instance_type_id_set[instance_type].add(instance_id)
return instance_type

def remove_instance(self, instance_id: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/global_scheduler/test_migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_get_migration_instance_infos(pair_migration_type):
constraint_prefill_instance_num = random.randint(-1, INSTANCE_NUM)
migration_scheduler = init_migration_scheduler()
if constraint_prefill_instance_num > 0:
- if len([info for info in migration_scheduler.instance_info.values() if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num:
+ if len([info for info in instance_info_dict.values() if info.instance_type == InstanceType.PREFILL]) < constraint_prefill_instance_num:
instance_info.instance_type = InstanceType.PREFILL
else:
instance_info.instance_type = InstanceType.DECODE
Expand Down
14 changes: 14 additions & 0 deletions tests/llumlet/test_migration_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_migrate_out_onestage(setup_ray_env):
dst_blocks = [1, 2]
backend_engine.get_request_incremental_blocks.return_value = src_blocks
migrate_out_request.should_abort_migration.return_value = False
migrate_out_request.blocking_migration = False
migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)

# Test normal migration scenario
Expand All @@ -52,6 +53,7 @@ def test_migrate_out_onestage(setup_ray_env):
dst_blocks = [3]
backend_engine.get_request_incremental_blocks.return_value = src_blocks
migrate_out_request.should_abort_migration.return_value = False
migrate_out_request.blocking_migration = False
migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
assert status == MigrationStatus.FINISHED_DONE
Expand All @@ -62,6 +64,7 @@ def test_migrate_out_onestage(setup_ray_env):
dst_blocks = []
backend_engine.get_request_incremental_blocks.return_value = src_blocks
migrate_out_request.should_abort_migration.return_value = False
migrate_out_request.blocking_migration = False
migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
assert status == MigrationStatus.FINISHED_ABORTED
Expand All @@ -71,6 +74,17 @@ def test_migrate_out_onestage(setup_ray_env):
dst_blocks = [1, 2]
backend_engine.get_request_incremental_blocks.return_value = src_blocks
migrate_out_request.should_abort_migration.return_value = True
migrate_out_request.blocking_migration = False
migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
assert status == MigrationStatus.FINISHED_ABORTED

migrate_out_request = MagicMock()
src_blocks = [1, 2, 3]
dst_blocks = [1, 2]
backend_engine.get_request_incremental_blocks.return_value = src_blocks
migrate_out_request.should_abort_migration.return_value = False
migrate_out_request.blocking_migration = True
migrate_in_ray_actor.execute_migration_method.remote.return_value = ray_remote_call.remote(dst_blocks)
status = coordinator.migrate_out_onestage(migrate_in_ray_actor, migrate_out_request)
assert status == MigrationStatus.FINISHED_ABORTED
Expand Down

0 comments on commit c921ef9

Please sign in to comment.