Skip to content

Commit

Permalink
[Bugfix] Address Request Status Changes During Migration Asynchronous Operations
Browse files Browse the repository at this point in the history
… Operations
  • Loading branch information
KuilongCui committed Nov 18, 2024
1 parent bcd49ba commit df22eef
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 40 deletions.
1 change: 1 addition & 0 deletions llumnix/llm_engine_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) ->
# Instance without error should clear migration states.
if not has_error:
try:
# TODO(s5u13b): clear_migration_states by instance_id
await self.instances[migrate_instance_pair[i]].clear_migration_states.remote(is_migrate_in=bool(i))
except (ray.exceptions.RayActorError, ray.exceptions.RayTaskError, KeyError):
has_error = True
Expand Down
8 changes: 8 additions & 0 deletions llumnix/llumlet/llumlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ async def migrate_out(self, dst_instance_name: str) -> List[str]:
migrate_out_requests = self.migration_scheduler.get_migrate_out_requests()
if len(migrate_out_requests) == 0:
return []

for migrate_out_request in migrate_out_requests:
migrate_out_request.is_migrating = True

migrated_request_list = []
for migrate_out_request in migrate_out_requests:
migrated_request = await self._migrate_out_one_request(migrate_out_request, dst_instance_name)
Expand All @@ -148,12 +152,16 @@ async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, ds
dst_instance_id = dst_instance_name[len("instance_"):]
logger.info("{}->{} begin migrate out".format(self.instance_id, dst_instance_id))
migrated_request = []

if migrate_out_request.status == RequestStatus.RUNNING:
migrate_out_request.migration_start_time = time.time()
status = await self.migration_coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request)
elif migrate_out_request.status == RequestStatus.WAITING:
migrate_out_request.migration_start_time = time.time()
status = await self.migration_coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request)
else:
return migrated_request

if status == MigrationStatus.FINISHED:
await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request)
self.backend_engine.free_src_request(migrate_out_request)
Expand Down
2 changes: 2 additions & 0 deletions llumnix/llumlet/local_migration_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _filter_running_queue(self, running, min_request_len, max_request_len):
if request.status == RequestStatus.RUNNING \
and request.inference_type == RequestInferenceType.DECODE \
and min_request_len < request.request_len < max_request_len \
and (not request.is_migrating) \
]
return filtered_running

Expand All @@ -67,6 +68,7 @@ def _filter_waiting_queue(self, waiting, min_request_len, max_request_len):
if request.status == RequestStatus.WAITING \
and request.try_schedule_times >= 1 \
and min_request_len < request.request_len < max_request_len \
and (not request.is_migrating) \
]
return filtered_waiting

Expand Down
6 changes: 6 additions & 0 deletions llumnix/llumlet/migration_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ async def _migrate_out_onestage(self,
migrate_out_request: LlumnixRequest) -> "MigrationStatus":
"""one-stage live migration until last stage for a running request
"""
if migrate_out_request.should_abort_migration():
return MigrationStatus.ABORTED_SRC

pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list)
incremental_blocks = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks)
# live migration, transfer all blocks except last one(currently updating)
Expand Down Expand Up @@ -129,12 +132,15 @@ async def _migrate_out_onestage(self,
self.backend_engine.add_running_request(migrate_out_request)
self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request)
return MigrationStatus.ABORTED_DST
if migrate_out_request.should_abort_migration():
return MigrationStatus.ABORTED_SRC

# do stage send/recv
migrate_out_request.stage_timestamps.append(time.time())
migrate_out_request.stage_num_blocks_list.append(stage_block_num)
# TODO(ZeldaHuang): send_blocks in migrate_in_pre_alloc/migrate_in_last_stage
await self.backend_engine.send_blocks(migrate_in_ray_actor, src_blocks, dst_blocks)

if not is_last_stage and migrate_out_request.should_abort_migration():
# migrate-out request abort by scheduler during send/recv
return MigrationStatus.ABORTED_SRC
Expand Down
12 changes: 10 additions & 2 deletions llumnix/llumlet/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int
self.stage_num_blocks_list = []
self.try_schedule_times = 0
self._status = None
self.migration_start_time = None
self.is_migrating = False

# end-of-migration, for multiple requests migration
self.eom = False
Expand All @@ -53,11 +55,15 @@ def reset_migration_args_dst(self):
self.stage_timestamps = []
self.stage_num_blocks_list = []
self.try_schedule_times = 0
self.migration_start_time = None
self.is_migrating = False

def reset_migration_args_src(self):
    """Clear per-migration bookkeeping on the source-side request.

    Resets preemption tracking, stage history, and the migration flags so the
    request can be considered for a future migration attempt.
    """
    self.last_preemption_time = None
    self.stage_timestamps, self.stage_num_blocks_list = [], []
    self.migration_start_time = None
    self.is_migrating = False

def reset_status(self):
    """Reset the cached request status so it is re-derived on next access."""
    # _status is lazily re-populated elsewhere; None marks it as stale.
    self._status = None
Expand Down Expand Up @@ -104,5 +110,7 @@ def blocking_migration(self) -> bool:
return self.output_len >= self.expected_steps

def should_abort_migration(self) -> bool:
    """Return True if the in-flight migration of this request must be aborted.

    A migration is aborted when the request has already finished, or when it
    was preempted by the scheduler after the current migration stage began.
    The reference time for "after" is the timestamp of the most recent
    completed stage, falling back to the overall migration start time when no
    stage has completed yet.

    NOTE(review): the original diff paste kept both the pre-change and
    post-change bodies; only the post-change logic is retained here.
    """
    # assumes migration_start_time was set before any stage ran — if both it
    # and stage_timestamps are unset, the comparison below would raise;
    # callers set migration_start_time in _migrate_out_one_request.
    begin_time = self.stage_timestamps[-1] if len(self.stage_timestamps) > 0 else self.migration_start_time
    preempted = self.last_preemption_time is not None and self.last_preemption_time > begin_time

    return self.finished or preempted
45 changes: 31 additions & 14 deletions tests/e2e_test/test_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
import json
import os
import subprocess
import pytest
import torch
import numpy as np
Expand Down Expand Up @@ -99,27 +100,43 @@ async def test_simple_benchmark(setup_ray_env, model):
launch_ray_cluster=False, port=base_port+i, model=model)
subprocess.run(launch_command, shell=True, check=True)

await asyncio.sleep(60)
await asyncio.sleep(30)

async def run_bench_command(command):
process = await asyncio.create_subprocess_shell(command)
await process.wait()
assert process.returncode == 0
def run_bench_command(command):
    """Start the benchmark *command* in a shell and return its Popen handle.

    The caller is responsible for waiting on (or killing) the process.
    """
    # pylint: disable=consider-using-with
    return subprocess.Popen(command, shell=True)

tasks = []
for i in range(device_count):
bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=300,
dataset_type="sharegpt",
dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl" ,
qps=2,
results_filename=f"{base_port+i}.out")
tasks.append(run_bench_command(bench_command))

await asyncio.wait(tasks, timeout=60*30)
bench_command = generate_bench_command(
ip_ports=f"127.0.0.1:{base_port + i}",
model=model,
num_prompts=300,
dataset_type="sharegpt",
dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl",
qps=2,
results_filename=f"{base_port + i}.out"
)
tasks.append(bench_command)

with ThreadPoolExecutor() as executor:
future_to_command = {executor.submit(run_bench_command, command): command for command in tasks}

for future in as_completed(future_to_command):
try:
process = future.result()
process.wait(timeout=60*30)
assert process.returncode == 0, "bench_test failed with return code {}.".format(process.returncode)
# pylint: disable=broad-except
except subprocess.TimeoutExpired:
process.kill()
print("bench_test timed out after 30 minutes.")

with open("performance.txt", "w", encoding="utf-8") as f:
f.write(parse_log_file())

# TODO(KuilongCui): change clear_state function to fixture
shutdown_llumnix_service()
clear_ray_state()
await asyncio.sleep(3)
63 changes: 39 additions & 24 deletions tests/e2e_test/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from collections import defaultdict
import re
import subprocess
import pytest
import torch
import pandas as pd
Expand Down Expand Up @@ -49,7 +50,9 @@ def parse_instance_log_file(log_files):

speeds.sort()
trimmed_speeds = speeds[1:-1]
average_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds)

if len(trimmed_speeds) > 0:
average_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds)

assert len(average_speed) > 0, "Migration should have occurred, but it was not detected. "

Expand All @@ -66,8 +69,8 @@ def parse_manager_log_file(log_file):
@pytest.mark.asyncio
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="at least 2 gpus required for migration bench")
@pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B'])
@pytest.mark.parametrize("migration_backend", ['rpc', 'gloo'])
@pytest.mark.parametrize("migrated_request_status", ['running', 'waiting'])
@pytest.mark.parametrize("migration_backend", ['gloo'])
@pytest.mark.parametrize("migrated_request_status", ['running'])
async def test_migration_benchmark(model, migration_backend, migrated_request_status):
if migrated_request_status == 'waiting' and migration_backend != 'rpc':
pytest.skip("When the migrated request status is waiting, only test the rpc migration backend.")
Expand All @@ -86,29 +89,41 @@ async def test_migration_benchmark(model, migration_backend, migrated_request_st
log_instance_info=True,
request_migration_policy=request_migration_policy)
subprocess.run(launch_command, shell=True, check=True)
await asyncio.sleep(5)
await asyncio.sleep(30)

async def run_bench_command(command):
process = await asyncio.create_subprocess_shell(command)
await process.wait()
assert process.returncode == 0
def run_bench_command(command):
    """Spawn *command* via the shell without blocking; return the process object."""
    # pylint: disable=consider-using-with
    process = subprocess.Popen(command, shell=True)
    return process

tasks = []
for i in range(device_count//2):
bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=300,
dataset_type="sharegpt",
dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl" ,
qps=10,
results_filename=f"{base_port+i}.out")
tasks.append(asyncio.create_task(run_bench_command(bench_command)))

_, pending = await asyncio.wait(tasks, timeout=60*30)

await asyncio.sleep(10)

if len(pending) > 0:
raise RuntimeError("migration task Timeout")
for i in range(device_count // 2):
bench_command = generate_bench_command(
ip_ports=f"127.0.0.1:{base_port + i}",
model=model,
num_prompts=300,
dataset_type="sharegpt",
dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl",
qps=10,
results_filename=f"{base_port + i}.out"
)
tasks.append(bench_command)

# Execute the commands concurrently using ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
future_to_command = {executor.submit(run_bench_command, command): command for command in tasks}

for future in as_completed(future_to_command):
try:
process = future.result()
process.wait(timeout=60*30)
assert process.returncode == 0, "migration_test failed with return code {}.".format(process.returncode)
# pylint: disable=broad-except
except subprocess.TimeoutExpired:
process.kill()
print("bench_test timed out after 30 minutes.")

await asyncio.sleep(5)

parse_manager_log_file("manager_instance.csv")

Expand All @@ -124,4 +139,4 @@ async def run_bench_command(command):

shutdown_llumnix_service()
clear_ray_state()
await asyncio.sleep(10)
await asyncio.sleep(3)

0 comments on commit df22eef

Please sign in to comment.