diff --git a/.github/workflows/offline_inference.yml b/.github/workflows/offline_inference.yml index 91a5d46a..65c6c848 100644 --- a/.github/workflows/offline_inference.yml +++ b/.github/workflows/offline_inference.yml @@ -24,9 +24,4 @@ jobs: steps: - uses: actions/checkout@v4 - name: Run offline inference example - run: | - nvidia-docker run --rm -t --net host --ipc host \ - -v ${PWD}:/workspace \ - -w /workspace \ - registry.cn-beijing.aliyuncs.com/llumnix/llumnix-dev:20240909_action_678a439 \ - bash -c "pip install -e . > /dev/null && make offline_test" + run: ./tools/offline_test.sh diff --git a/llumnix/backends/vllm/worker.py b/llumnix/backends/vllm/worker.py index e38c3423..7c910526 100644 --- a/llumnix/backends/vllm/worker.py +++ b/llumnix/backends/vllm/worker.py @@ -15,6 +15,7 @@ from typing import Dict, List import math import torch +import ray from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy from vllm.utils import is_pin_memory_available diff --git a/llumnix/llm_engine_manager.py b/llumnix/llm_engine_manager.py index 5d8c48a5..5212f3bb 100644 --- a/llumnix/llm_engine_manager.py +++ b/llumnix/llm_engine_manager.py @@ -232,13 +232,14 @@ async def migrate_done_callback(ret, migrate_instance_pair: Tuple[str, str]) -> # TODO(s5u13b): Add more exception types for failover. if isinstance(ret, (ray.exceptions.RayActorError, ray.exceptions.RayTaskError, KeyError)): has_error_pair = await self._check_instance_error(migrate_instance_pair) - for i, has_error in enumerate(has_error_pair): - # Instance without error should clear migration states. - if not has_error: - try: - await self.instances[migrate_instance_pair[i]].clear_migration_states.remote(is_migrate_in=bool(i)) - except (ray.exceptions.RayActorError, ray.exceptions.RayTaskError, KeyError): - has_error = True + # TODO(s5u13b): clear_migration_states by instance_id + # for i, has_error in enumerate(has_error_pair): + # # Instance without error should clear migration states. + # if not has_error: + # try: + # await self.instances[migrate_instance_pair[i]].clear_migration_states.remote(is_migrate_in=bool(i)) + # except (ray.exceptions.RayActorError, ray.exceptions.RayTaskError, KeyError): + # has_error = True for i, has_error in enumerate(has_error_pair): if has_error: instance_id = migrate_instance_pair[i] diff --git a/llumnix/llumlet/llumlet.py b/llumnix/llumlet/llumlet.py index 3af73ac5..b1c91fe9 100644 --- a/llumnix/llumlet/llumlet.py +++ b/llumnix/llumlet/llumlet.py @@ -133,6 +133,10 @@ async def migrate_out(self, dst_instance_name: str) -> List[str]: migrate_out_requests = self.migration_scheduler.get_migrate_out_requests() if len(migrate_out_requests) == 0: return [] + + for migrate_out_request in migrate_out_requests: + migrate_out_request.is_migrating = True + migrated_request_list = [] for migrate_out_request in migrate_out_requests: migrated_request = await self._migrate_out_one_request(migrate_out_request, dst_instance_name) @@ -148,12 +152,16 @@ async def _migrate_out_one_request(self, migrate_out_request: LlumnixRequest, ds dst_instance_id = dst_instance_name[len("instance_"):] logger.info("{}->{} begin migrate out".format(self.instance_id, dst_instance_id)) migrated_request = [] + if migrate_out_request.status == RequestStatus.RUNNING: + migrate_out_request.migration_start_time = time.time() status = await self.migration_coordinator.migrate_out_running_request(migrate_in_ray_actor, migrate_out_request) elif migrate_out_request.status == RequestStatus.WAITING: + migrate_out_request.migration_start_time = time.time() status = await self.migration_coordinator.migrate_out_waiting_request(migrate_in_ray_actor, migrate_out_request) else: return migrated_request + if status == MigrationStatus.FINISHED: await migrate_in_ray_actor.execute_engine_method.remote("commit_dst_request", migrate_out_request) self.backend_engine.free_src_request(migrate_out_request) diff --git a/llumnix/llumlet/local_migration_scheduler.py b/llumnix/llumlet/local_migration_scheduler.py index 4f30f850..b3aee50d 100644 --- a/llumnix/llumlet/local_migration_scheduler.py +++ b/llumnix/llumlet/local_migration_scheduler.py @@ -58,6 +58,7 @@ def _filter_running_queue(self, running, min_request_len, max_request_len): if request.status == RequestStatus.RUNNING \ and request.inference_type == RequestInferenceType.DECODE \ and min_request_len < request.request_len < max_request_len \ + and (not request.is_migrating) \ ] return filtered_running @@ -67,6 +68,7 @@ def _filter_waiting_queue(self, waiting, min_request_len, max_request_len): if request.status == RequestStatus.WAITING \ and request.try_schedule_times >= 1 \ and min_request_len < request.request_len < max_request_len \ + and (not request.is_migrating) \ ] return filtered_waiting diff --git a/llumnix/llumlet/migration_coordinator.py b/llumnix/llumlet/migration_coordinator.py index 224c41c3..3ef54766 100644 --- a/llumnix/llumlet/migration_coordinator.py +++ b/llumnix/llumlet/migration_coordinator.py @@ -95,6 +95,9 @@ async def _migrate_out_onestage(self, migrate_out_request: LlumnixRequest) -> "MigrationStatus": """one-stage live migration until last stage for a running request """ + if migrate_out_request.should_abort_migration(): + return MigrationStatus.ABORTED_SRC + pre_stage_num_blocks = sum(migrate_out_request.stage_num_blocks_list) incremental_blocks = self.backend_engine.get_request_incremental_blocks(migrate_out_request, pre_stage_num_blocks) # live migration, transfer all blocks except last one(currently updating) @@ -129,12 +132,15 @@ async def _migrate_out_onestage(self, self.backend_engine.add_running_request(migrate_out_request) self.backend_engine.remove_migrating_out_request_last_stage(migrate_out_request) return MigrationStatus.ABORTED_DST + if migrate_out_request.should_abort_migration(): + return MigrationStatus.ABORTED_SRC # do stage send/recv migrate_out_request.stage_timestamps.append(time.time()) migrate_out_request.stage_num_blocks_list.append(stage_block_num) # TODO(ZeldaHuang): send_blocks in migrate_in_pre_alloc/migrate_in_last_stage await self.backend_engine.send_blocks(migrate_in_ray_actor, src_blocks, dst_blocks) + if not is_last_stage and migrate_out_request.should_abort_migration(): # migrate-out request abort by scheduler during send/recv return MigrationStatus.ABORTED_SRC diff --git a/llumnix/llumlet/request.py b/llumnix/llumlet/request.py index d92e6564..d6c7dac5 100644 --- a/llumnix/llumlet/request.py +++ b/llumnix/llumlet/request.py @@ -41,6 +41,8 @@ def __init__(self, request_id: int, server_info: ServerInfo, expected_steps: int self.stage_num_blocks_list = [] self.try_schedule_times = 0 self._status = None + self.migration_start_time = None + self.is_migrating = False # end-of-migration, for multiple requests migration self.eom = False @@ -53,11 +55,15 @@ def reset_migration_args_dst(self): self.stage_timestamps = [] self.stage_num_blocks_list = [] self.try_schedule_times = 0 + self.migration_start_time = None + self.is_migrating = False def reset_migration_args_src(self): self.last_preemption_time = None self.stage_timestamps = [] self.stage_num_blocks_list = [] + self.migration_start_time = None + self.is_migrating = False def reset_status(self): self._status = None @@ -104,5 +110,7 @@ def blocking_migration(self) -> bool: return self.output_len >= self.expected_steps def should_abort_migration(self) -> bool: - return self.finished \ - or (self.last_preemption_time is not None and self.last_preemption_time > self.stage_timestamps[-1]) + begin_time = self.stage_timestamps[-1] if len(self.stage_timestamps) > 0 else self.migration_start_time + preempted = self.last_preemption_time is not None and self.last_preemption_time > begin_time + + return self.finished or preempted diff --git a/tests/e2e_test/test_bench.py b/tests/e2e_test/test_bench.py index 5eba27d1..482ae11d 100644 --- a/tests/e2e_test/test_bench.py +++ b/tests/e2e_test/test_bench.py @@ -11,17 +11,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import subprocess +from concurrent.futures import ThreadPoolExecutor, as_completed import asyncio import json import os -import subprocess import pytest import torch import numpy as np from .test_e2e import generate_launch_command, clear_ray_state -# pylint: disable=unused-import -from .utils import to_markdown_table, setup_ray_env +from .utils import to_markdown_table def launch_llumnix_service(command): subprocess.run(command, shell=True, check=True) @@ -91,7 +91,7 @@ def get_markdown_data(key: str, head_name: str): @pytest.mark.asyncio @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="at least 1 gpus required for simple benchmark") @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) -async def test_simple_benchmark(setup_ray_env, model): +async def test_simple_benchmark(model): device_count = torch.cuda.device_count() base_port = 37037 for i in range(device_count): @@ -99,27 +99,43 @@ async def test_simple_benchmark(setup_ray_env, model): launch_ray_cluster=False, port=base_port+i, model=model) subprocess.run(launch_command, shell=True, check=True) - await asyncio.sleep(60) + await asyncio.sleep(30) - async def run_bench_command(command): - process = await asyncio.create_subprocess_shell(command) - await process.wait() - assert process.returncode == 0 + def run_bench_command(command): + # pylint: disable=consider-using-with + process = subprocess.Popen(command, shell=True) + return process tasks = [] for i in range(device_count): - bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=300, - dataset_type="sharegpt", - dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl" , - qps=2, - results_filename=f"{base_port+i}.out") - tasks.append(run_bench_command(bench_command)) - - await asyncio.wait(tasks, timeout=60*30) + bench_command = generate_bench_command( + ip_ports=f"127.0.0.1:{base_port + i}", + model=model, + num_prompts=200, + dataset_type="sharegpt", + dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl", + qps=5, + results_filename=f"{base_port + i}.out" + ) + tasks.append(bench_command) + + with ThreadPoolExecutor() as executor: + future_to_command = {executor.submit(run_bench_command, command): command for command in tasks} + + for future in as_completed(future_to_command): + try: + process = future.result() + process.wait(timeout=60*30) + assert process.returncode == 0, "bench_test failed with return code {}.".format(process.returncode) + # pylint: disable=broad-except + except subprocess.TimeoutExpired: + process.kill() + print("bench_test timed out after 30 minutes.") with open("performance.txt", "w", encoding="utf-8") as f: f.write(parse_log_file()) + # TODO(KuilongCui): change clear_state function to fixture shutdown_llumnix_service() clear_ray_state() await asyncio.sleep(3) diff --git a/tests/e2e_test/test_e2e.py b/tests/e2e_test/test_e2e.py index a3bf1977..8dcd9398 100644 --- a/tests/e2e_test/test_e2e.py +++ b/tests/e2e_test/test_e2e.py @@ -19,8 +19,6 @@ import torch from vllm import LLM, SamplingParams -# pylint: disable=unused-import -from .utils import setup_ray_env def parse_launch_mode(launch_mode: str): # 'eief' means that enable init instance by manager and enable fixed node init instance, and so on. @@ -140,7 +138,7 @@ def run_vllm(model, max_model_len, sampling_params): @pytest.mark.parametrize("model", ['/mnt/model/Qwen-7B']) @pytest.mark.parametrize("migration_backend", ['rpc', 'gloo']) @pytest.mark.parametrize("launch_mode", ['eief', 'eidf', 'dief', 'didf']) -async def test_e2e(setup_ray_env, model, migration_backend, launch_mode): +async def test_e2e(model, migration_backend, launch_mode): if migration_backend == 'gloo' and launch_mode != 'eief': pytest.skip("When the migration backend is gloo, the launch mode of llumnix can only be eief") max_model_len = 370 diff --git a/tests/e2e_test/test_migration.py b/tests/e2e_test/test_migration.py index ced1e0be..91f0a3e2 100644 --- a/tests/e2e_test/test_migration.py +++ b/tests/e2e_test/test_migration.py @@ -11,18 +11,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import subprocess +from concurrent.futures import ThreadPoolExecutor, as_completed import asyncio from collections import defaultdict import re -import subprocess import pytest import torch import pandas as pd from .test_e2e import generate_launch_command from .test_bench import generate_bench_command, clear_ray_state, shutdown_llumnix_service -# pylint: disable=unused-import -from .utils import to_markdown_table, setup_ray_env +from .utils import to_markdown_table size_pattern = re.compile(r'total_kv_cache_size:\s*([\d.]+)\s*(B|KB|MB|GB|KB|TB)') speed_pattern = re.compile(r'speed:\s*([\d.]+)GB/s') @@ -49,7 +49,9 @@ def parse_instance_log_file(log_files): speeds.sort() trimmed_speeds = speeds[1:-1] - average_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds) + + if len(trimmed_speeds) > 0: + average_speed[transfer_size] = sum(trimmed_speeds) / len(trimmed_speeds) assert len(average_speed) > 0, "Migration should have occurred, but it was not detected. " @@ -86,31 +88,44 @@ async def test_migration_benchmark(model, migration_backend, migrated_request_st log_instance_info=True, request_migration_policy=request_migration_policy) subprocess.run(launch_command, shell=True, check=True) - await asyncio.sleep(5) await asyncio.sleep(30) - async def run_bench_command(command): - process = await asyncio.create_subprocess_shell(command) - await process.wait() - assert process.returncode == 0 + def run_bench_command(command): + # pylint: disable=consider-using-with + process = subprocess.Popen(command, shell=True) + return process tasks = [] - for i in range(device_count//2): - bench_command = generate_bench_command(ip_ports=f"127.0.0.1:{base_port+i}", model=model, num_prompts=300, - dataset_type="sharegpt", - dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl" , - qps=10, - results_filename=f"{base_port+i}.out") - tasks.append(asyncio.create_task(run_bench_command(bench_command))) - - _, pending = await asyncio.wait(tasks, timeout=60*30) - - await asyncio.sleep(10) - - if len(pending) > 0: - raise RuntimeError("migration task Timeout") - - parse_manager_log_file("manager_instance.csv") + for i in range(device_count // 2): + bench_command = generate_bench_command( + ip_ports=f"127.0.0.1:{base_port + i}", + model=model, + num_prompts=300, + dataset_type="sharegpt", + dataset_path="/mnt/dataset/sharegpt_gpt4/sharegpt_gpt4.jsonl", + qps=10, + results_filename=f"{base_port + i}.out" + ) + tasks.append(bench_command) + + # Execute the commands concurrently using ThreadPoolExecutor + with ThreadPoolExecutor() as executor: + future_to_command = {executor.submit(run_bench_command, command): command for command in tasks} + + for future in as_completed(future_to_command): + try: + process = future.result() + process.wait(timeout=60*30) + assert process.returncode == 0, "migration_test failed with return code {}.".format(process.returncode) + # pylint: disable=broad-except + except subprocess.TimeoutExpired: + process.kill() + print("bench_test timed out after 30 minutes.") + + await asyncio.sleep(5) + + # TODO(s5u13b): use a more definitive way to determine that there is no memory leak. + # parse_manager_log_file("manager_instance.csv") if migrated_request_status == 'running': average_speed = parse_instance_log_file(instance_output_logs) @@ -124,4 +139,4 @@ async def run_bench_command(command): shutdown_llumnix_service() clear_ray_state() - await asyncio.sleep(10) + await asyncio.sleep(3) diff --git a/tests/e2e_test/utils.py b/tests/e2e_test/utils.py index 1c38dcc8..62d9bff8 100644 --- a/tests/e2e_test/utils.py +++ b/tests/e2e_test/utils.py @@ -11,10 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import subprocess -import pytest - def to_markdown_table(data): headers = data[0] rows = data[1:] @@ -31,11 +27,3 @@ def to_markdown_table(data): table = f"{header_row}\n{separator_row}\n" + "\n".join(data_rows) + "\n\n" return table - -@pytest.fixture -def setup_ray_env(): - subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - subprocess.run(["ray", "start", "--head", "--disable-usage-stats", "--port=6379"], check=True, - stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - yield - subprocess.run(["ray", "stop"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) diff --git a/tests/unit_test/global_scheduler/test_llm_engine_manager.py b/tests/unit_test/global_scheduler/test_llm_engine_manager.py index 5f81baf6..5a09b283 100644 --- a/tests/unit_test/global_scheduler/test_llm_engine_manager.py +++ b/tests/unit_test/global_scheduler/test_llm_engine_manager.py @@ -288,3 +288,7 @@ def test_update_instance_info_loop_and_migrate(setup_ray_env, engine_manager): assert num_migrate_in == 0 and num_migrate_out > 1 else: assert num_migrate_in == 0 and num_migrate_out == 0 + +@pytest.mark.skip("Not implemented yet") +def test_concurrent_migrate(setup_ray_env): + pass diff --git a/tools/bench_test.sh b/tools/bench_test.sh index ec12cef9..7b18f794 100755 --- a/tools/bench_test.sh +++ b/tools/bench_test.sh @@ -1,6 +1,8 @@ #!/bin/bash set -ex +pgrep -f llumnix.entrypoints.vllm.api_server | { while read pid; do kill -9 "$pid"; done; } + nvidia-docker run --rm -t --net host --ipc host -v ${PWD}:/workspace -v /mnt:/mnt -w /workspace \ registry.cn-beijing.aliyuncs.com/llumnix/llumnix-dev:20240909_action_678a439 \ bash -c "pip install -e . > /dev/null && make bench_test" diff --git a/tools/e2e_test.sh b/tools/e2e_test.sh index 867a8aaf..6db4911a 100755 --- a/tools/e2e_test.sh +++ b/tools/e2e_test.sh @@ -1,6 +1,8 @@ # #!/bin/bash set -ex +pgrep -f llumnix.entrypoints.vllm.api_server | { while read pid; do kill -9 "$pid"; done; } + nvidia-docker run --rm -t --net host --ipc host -v ${PWD}:/workspace -v /mnt:/mnt -w /workspace \ registry.cn-beijing.aliyuncs.com/llumnix/llumnix-dev:20240909_action_678a439 \ bash -c "pip install -e . > /dev/null && make e2e_test" diff --git a/tools/migration_test.sh b/tools/migration_test.sh index 3e13ce55..4bf601e9 100755 --- a/tools/migration_test.sh +++ b/tools/migration_test.sh @@ -1,6 +1,8 @@ # #!/bin/bash set -ex +pgrep -f llumnix.entrypoints.vllm.api_server | { while read pid; do kill -9 "$pid"; done; } + nvidia-docker run --rm -t --net host --ipc host -v ${PWD}:/workspace -v /mnt:/mnt -w /workspace \ registry.cn-beijing.aliyuncs.com/llumnix/llumnix-dev:20240909_action_678a439 \ bash -c "pip install -e . > /dev/null && make migration_test" diff --git a/tools/offline_test.sh b/tools/offline_test.sh new file mode 100755 index 00000000..43ba9141 --- /dev/null +++ b/tools/offline_test.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -ex + +pgrep -f llumnix.entrypoints.vllm.api_server | { while read pid; do kill -9 "$pid"; done; } + +nvidia-docker run --rm -t --net host --ipc host -v ${PWD}:/workspace -v /mnt:/mnt -w /workspace \ + registry.cn-beijing.aliyuncs.com/llumnix/llumnix-dev:20240909_action_678a439 \ + bash -c "pip install -e . > /dev/null && make offline_test" diff --git a/tools/unit_test.sh b/tools/unit_test.sh index 0b075df3..114fb773 100755 --- a/tools/unit_test.sh +++ b/tools/unit_test.sh @@ -1,6 +1,8 @@ # #!/bin/bash set -ex +pgrep -f llumnix.entrypoints.vllm.api_server | { while read pid; do kill -9 "$pid"; done; } + nvidia-docker run --rm -t --net host --ipc host -v ${PWD}:/workspace -v /mnt:/mnt -w /workspace \ registry.cn-beijing.aliyuncs.com/llumnix/llumnix-dev:20240909_action_678a439 \ bash -c "pip install -e . > /dev/null && make unit_test"