[CI] Add unittest for global_scheduler and entrypoints (#12)
s5u13b authored Aug 23, 2024
1 parent 309c296 commit 0b48bbc
Showing 23 changed files with 1,026 additions and 256 deletions.
30 changes: 15 additions & 15 deletions benchmark/benchmark_serving.py
@@ -31,8 +31,8 @@
from typing import List


num_finished_request = 0
server_num_request = {}
num_finished_requests = 0
server_num_requests = {}


def get_wait_time(mean_time_between_requests: float, distribution: str, coefficient_variation: float = 0.0) -> float:
@@ -76,11 +76,11 @@ async def query_model_vllm(prompt, verbose, ip_ports):
prompt, prompt_len, expected_response_len = prompt

# Round-robin dispatch of requests to the given api servers.
global server_num_request
server_id = min(server_num_request, key=server_num_request.get)
server_num_request[server_id] += 1
global server_num_requests
server_id = min(server_num_requests, key=server_num_requests.get)
server_num_requests[server_id] += 1
timeout = aiohttp.ClientTimeout(total=4*60*60)
global num_finished_request
global num_finished_requests

async with aiohttp.ClientSession(timeout=timeout) as session:
# TODO(yiwang): Remove hard codes of params.
@@ -111,8 +111,8 @@ async def query_model_vllm(prompt, verbose, ip_ports):
output['response_len'] = expected_response_len
if verbose and 'generated_text' in output:
print(json.dumps(output['generated_text']))
num_finished_request += 1
print("num_finised_request: {}".format(num_finished_request))
num_finished_requests += 1
print("num_finised_requests: {}".format(num_finished_requests))
return (prompt, output)
except aiohttp.ClientError as e:
print(f"Connect to {ip_ports[server_id]} failed with: {str(e)}")
@@ -334,18 +334,18 @@ def plot_instance(log_filename_0):
log_files.sort(key=os.path.getmtime, reverse=True)
df_0 = pd.read_csv(log_files[0]).sort_values(by=["timestamp"])
timestamp_list_0 = df_0["timestamp"].to_numpy()
instance_num_list_0 = df_0["num_instance"].to_numpy()
num_instances_list_0 = df_0["num_instances"].to_numpy()
time_0 = 0
sum_0 = 0
for idx, t in enumerate(timestamp_list_0):
if t > time_0:
time_0 += 1
sum_0 += instance_num_list_0[idx]
sum_0 += num_instances_list_0[idx]
print(f"{sum_0/time_0} gpu/s")
avg_instance_num = np.round(sum_0/time_0, 2)

fig, ax = plt.subplots()
ax.plot(timestamp_list_0, instance_num_list_0, color="red", label=f"instance_num(avg {avg_instance_num} /s)")
ax.plot(timestamp_list_0, num_instances_list_0, color="red", label=f"instance_num(avg {avg_instance_num} /s)")
ax.legend(loc='upper left')
fig_filename = os.path.splitext(log_filename_0)[0] + "_instance.png"
index1 = fig_filename.rfind('/')
@@ -437,10 +437,10 @@ async def benchmark(
else:
raise ValueError(f'unknown backend {backend}')

global server_num_request
num_server = len(ip_ports)
for server_id in range(num_server):
server_num_request[server_id] = 0
global server_num_requests
num_servers = len(ip_ports)
for server_id in range(num_servers):
server_num_requests[server_id] = 0

m = MeasureLatency()

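A note on the renamed counters: `server_num_requests` drives the benchmark's dispatch, always sending the next request to the API server that has received the fewest requests so far. A minimal standalone sketch of that pattern, with a hypothetical `pick_server` helper and assumed endpoints (neither is part of the benchmark script):

```python
# Least-loaded selection over a per-server counter (the script's "round-robin" dispatch).
# Endpoints and helper name are illustrative only.
ip_ports = ["127.0.0.1:8000", "127.0.0.1:8001"]
server_num_requests = {server_id: 0 for server_id in range(len(ip_ports))}

def pick_server() -> int:
    # Choose the server with the smallest counter, then charge the request to it.
    server_id = min(server_num_requests, key=server_num_requests.get)
    server_num_requests[server_id] += 1
    return server_id

print([pick_server() for _ in range(4)])  # [0, 1, 0, 1]
```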
12 changes: 6 additions & 6 deletions docs/Arguments.md
@@ -9,12 +9,12 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
[--fixed-node-init-instance]
[--init-instance-by-manager]
[--initial-instances INITIAL_INSTANCES]
[--load-metric {consumed_speed,used_ratio}]
[--load-metric {remaining_steps,usage_ratio}]
[--polling-interval POLLING_INTERVAL]
[--dispatch-policy {balanced,load,queue}]
[--enable-migration]
[--pair-migration-frequency PAIR_MIGRATION_FREQUENCY]
[--pair-migration-policy {balanced,prefill_constrained,prefill_relaxed}]
[--pair-migration-policy {balanced,defrag_constrained,defrag_relaxed}]
[--migrate-out-threshold MIGRATE_OUT_THRESHOLD]
[--request-migration-policy {LCFS,SJF,LJF}]
[--enable-defrag ENABLE_DEFRAG]
@@ -48,8 +48,8 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]

`--load-metric`
- Instance load metric.
- Possible choices: consumed_speed, used_ratio
- Default: "consumed_speed"
- Possible choices: remaining_steps, usage_ratio
- Default: "remaining_steps"

`--polling-interval`
- Time interval(s) to update instance info and pair migration.
@@ -139,11 +139,11 @@ usage: -m llumnix.entrypoints.vllm.api_server [-h]
- Default: 512

`--last-stage-max-blocks`
- If the remaining blocks num < last_stage_max_blocks, do last stage migration.
- If the number of remaining blocks < last_stage_max_blocks, do last stage migration.
- Default: 4

`--max-stages`
- Drop migration if stage num > max_stages.
- Drop migration if the number of stages > max_stages.
- Default: 3
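
Taken together, the renamed values can be passed straight through to the entrypoint. A sketch of a launch command using the new names; the engine/model flags that vLLM itself requires are omitted, and the host/port values are assumptions, not part of this diff:

```python
# Illustrative launch of the Llumnix API server with the renamed options.
import subprocess
import sys

cmd = [
    sys.executable, "-m", "llumnix.entrypoints.vllm.api_server",
    "--host", "localhost",
    "--port", "8000",
    "--load-metric", "remaining_steps",                # was: consumed_speed
    "--dispatch-policy", "load",
    "--pair-migration-policy", "defrag_constrained",   # was: prefill_constrained
]
server = subprocess.Popen(cmd)   # long-running; stop it with server.terminate()
```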

# Unsupported vLLM feature options
16 changes: 8 additions & 8 deletions llumnix/arg_utils.py
@@ -26,15 +26,15 @@ class EngineManagerArgs:
initial_instances: int = 1
fixed_node_init_instance: bool = False

load_metric: str = 'consumed_speed'
load_metric: str = 'remaining_steps'
polling_interval: float = 0.05

dispatch_policy: str = 'load'

enable_migration: bool = True
enable_defrag: bool = True
pair_migration_frequency: int = 1
pair_migration_policy: str = 'prefill_constrained'
pair_migration_policy: str = 'defrag_constrained'
migrate_out_threshold: float = 3.0
request_migration_policy: str = 'SJF'

@@ -87,8 +87,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineManagerArgs':
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args
engine_manager_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_manager_args

@staticmethod
def add_cli_args(
@@ -107,7 +107,7 @@ def add_cli_args(
parser.add_argument('--load-metric',
type=str,
default=EngineManagerArgs.load_metric,
choices=['consumed_speed', 'used_ratio'],
choices=['remaining_steps', 'usage_ratio'],
help='instance load metric')
parser.add_argument('--polling-interval',
type=float,
@@ -130,7 +130,7 @@ def add_cli_args(
parser.add_argument('--pair-migration-policy',
type=str,
default=EngineManagerArgs.pair_migration_policy,
choices=['balanced', 'prefill_constrained', 'prefill_relaxed'],
choices=['balanced', 'defrag_constrained', 'defrag_relaxed'],
help='pair migration policy')
parser.add_argument('--migrate-out-threshold',
type=float,
@@ -207,10 +207,10 @@ def add_cli_args(
parser.add_argument('--last-stage-max-blocks',
type=int,
default=EngineManagerArgs.last_stage_max_blocks,
help='if the remain blocks num < last_stage_max_blocks, do last stage migration')
help='if the number of remaining blocks < last_stage_max_blocks, do last stage migration')
parser.add_argument('--max-stages',
type=int,
default=EngineManagerArgs.max_stages,
help='drop migration if stage num > max_stages')
help='drop migration if the number of stages > max_stages')

return parser
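
For reference, `add_cli_args` and the renamed `engine_manager_args` return value are typically wired together as below. This is a minimal sketch assuming every manager flag has a default, and the argv list is hypothetical:

```python
# Parse only the renamed choices and round-trip them into EngineManagerArgs.
import argparse

from llumnix.arg_utils import EngineManagerArgs

parser = argparse.ArgumentParser()
parser = EngineManagerArgs.add_cli_args(parser)
args = parser.parse_args([
    "--load-metric", "usage_ratio",
    "--pair-migration-policy", "defrag_relaxed",
])
engine_manager_args = EngineManagerArgs.from_cli_args(args)
assert engine_manager_args.load_metric == "usage_ratio"
assert engine_manager_args.pair_migration_policy == "defrag_relaxed"
```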
8 changes: 4 additions & 4 deletions llumnix/backends/vllm/llm_engine.py
@@ -121,9 +121,9 @@ def step(self) -> None:
instance_info: InstanceInfo = self.scheduler.get_instance_info()

if self.scaling_down:
instance_info.num_running_request = 1
instance_info.num_available_gpu_block = -self.cache_config.num_gpu_blocks
instance_info.num_available_gpu_block_waiting = -self.cache_config.num_gpu_blocks
instance_info.num_running_requests = 1
instance_info.num_available_gpu_blocks = -self.cache_config.num_gpu_blocks
instance_info.num_available_gpu_blocks_waiting = -self.cache_config.num_gpu_blocks

instance_info.instance_id = self.instance_id
instance_info.step_id = next(self.step_counter)
@@ -136,7 +136,7 @@ def step(self) -> None:
blocks = self.scheduler.block_manager.get_block_table(seq)
tot_blocks.extend(blocks)
tot_blocks = set(tot_blocks)
instance_info.num_block_last_running_request = len(tot_blocks)
instance_info.num_blocks_last_running_request = len(tot_blocks)

self.free_request_states(instance_info.finished_request_ids)

46 changes: 23 additions & 23 deletions llumnix/backends/vllm/scheduler.py
@@ -71,7 +71,7 @@ def _preempt(
self.last_preemption_time_dict[seq_group.request_id] = time.time()
return super()._preempt(seq_group, blocks_to_swap_out, preemption_mode)

def _get_num_killed_request(self) -> int:
def _get_num_killed_requests(self) -> int:
cnt = len(self.swapped)
for seq_group in self.waiting:
if seq_group.request_id in self.last_preemption_time_dict:
@@ -187,44 +187,44 @@ def free_src_request(self, backend_request: SequenceGroup) -> None:

@scheduler_lock
def get_instance_info(self) -> InstanceInfo:
num_total_gpu_block = self.cache_config.num_gpu_blocks
num_free_gpu_block = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_block = num_total_gpu_block - num_free_gpu_block
gpu_cache_usage = num_used_gpu_block / num_total_gpu_block
num_total_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = num_total_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / num_total_gpu_blocks
if self.waiting:
num_block_waiting_requests = []
num_blocks_waiting_requests = []
waiting_time_waiting_requests = []
for seq_group in self.waiting:
num_prompt_token = seq_group.get_seqs()[0].get_len()
num_block = num_prompt_token / self.cache_config.block_size
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
num_blocks = num_prompt_tokens / self.cache_config.block_size
waiting_time = time.time() - seq_group.metrics.arrival_time
num_block_waiting_requests.append(num_block)
num_blocks_waiting_requests.append(num_blocks)
waiting_time_waiting_requests.append(waiting_time)
num_block_first_waiting_request = num_block_waiting_requests[0]
num_blocks_first_waiting_request = num_blocks_waiting_requests[0]
waiting_time_first_waiting_request = waiting_time_waiting_requests[0]
num_block_all_waiting_request = sum(num_block_waiting_requests)
num_blocks_all_waiting_requests = sum(num_blocks_waiting_requests)
else:
num_block_first_waiting_request = 0
num_blocks_first_waiting_request = 0
waiting_time_first_waiting_request = 0
num_block_all_waiting_request = 0
num_blocks_all_waiting_requests = 0
instance_info = InstanceInfo(
num_total_gpu_block=num_total_gpu_block,
num_watermark_block=self.block_manager.watermark_blocks,
num_free_gpu_block=num_free_gpu_block,
num_used_gpu_block=num_used_gpu_block,
num_total_gpu_blocks=num_total_gpu_blocks,
num_watermark_blocks=self.block_manager.watermark_blocks,
num_free_gpu_blocks=num_free_gpu_blocks,
num_used_gpu_blocks=num_used_gpu_blocks,
gpu_cache_usage=gpu_cache_usage,
num_running_request=len(self.running),
num_waiting_request=len(self.waiting),
num_killed_request=self._get_num_killed_request(),
num_block_first_waiting_request=num_block_first_waiting_request,
num_running_requests=len(self.running),
num_waiting_requests=len(self.waiting),
num_killed_requests=self._get_num_killed_requests(),
num_blocks_first_waiting_request=num_blocks_first_waiting_request,
waiting_time_first_waiting_request=waiting_time_first_waiting_request,
num_block_all_waiting_request=num_block_all_waiting_request,
num_blocks_all_waiting_requests=num_blocks_all_waiting_requests,
inference_type=BackendInferenceType.PREFILL if self.prefilling_seq_groups \
else BackendInferenceType.DECODE,
)
for seq_group in self.running:
instance_info.running_seq_lens.extend([seq.get_len() for seq in seq_group.get_seqs()])
instance_info.num_seq = len(instance_info.running_seq_lens)
instance_info.num_seqs = len(instance_info.running_seq_lens)
instance_info.num_batched_tokens = sum([seq_group.get_seqs()[0].get_len() for seq_group in self.prefilling_seq_groups])\
if self.prefilling_seq_groups else len(instance_info.running_seq_lens)
instance_info.finished_request_ids = [seq_group.request_id for seq_group in self.running if seq_group.is_finished()]
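The renamed fields in `get_instance_info` are plain block arithmetic; a worked example with assumed numbers (block size 16, 1000 total GPU blocks, 250 free, one waiting prompt of 96 tokens):

```python
# Block accounting as computed above, with made-up inputs.
num_total_gpu_blocks = 1000
num_free_gpu_blocks = 250
num_used_gpu_blocks = num_total_gpu_blocks - num_free_gpu_blocks   # 750
gpu_cache_usage = num_used_gpu_blocks / num_total_gpu_blocks       # 0.75

block_size = 16
num_prompt_tokens = 96
num_blocks_first_waiting_request = num_prompt_tokens / block_size  # 6.0
```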
4 changes: 2 additions & 2 deletions llumnix/backends/vllm/worker.py
@@ -209,8 +209,8 @@ def restart(self) -> None:
self.init_cache_engine(self.cache_config)

# instance_id was changed from int to str; modify this function accordingly if it is used.
# def init_migration_dist_ray(self, num_instance, instance_id):
# self.ray_world_size = num_instance * self.parallel_config.world_size
# def init_migration_dist_ray(self, num_instances, instance_id):
# self.ray_world_size = num_instances * self.parallel_config.world_size
# self.ray_rank = self.rank + instance_id * self.parallel_config.world_size
# logger.info(f"{self.ray_world_size, self.ray_rank}")
# # col.init_collective_group(world_size=self.ray_world_size, rank=self.ray_rank , backend="gloo")
17 changes: 16 additions & 1 deletion llumnix/entrypoints/llumnix_utils.py
@@ -1,3 +1,16 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
import sys
import os
@@ -15,6 +28,7 @@
from llumnix.logger import init_logger
from llumnix.arg_utils import EngineManagerArgs


logger = init_logger(__name__)

# TODO(s5u13b): Set the values through tests.
@@ -29,7 +43,7 @@ def get_ip_address():
ip_address = result.stdout.decode('utf-8').strip()
return ip_address

def launch_ray_cluster(ray_cluster_port: int) -> None:
def launch_ray_cluster(ray_cluster_port: int) -> subprocess.CompletedProcess:
head_node_ip = os.getenv('HEAD_NODE_IP')
node_ip_address = get_ip_address()
try:
@@ -66,6 +80,7 @@ def launch_ray_cluster(ray_cluster_port: int) -> None:
sys.exit(1)
logger.info("'{}' succeeed with: \n{}".format(ray_start_command, result.stdout))
ray.init(address=f"{head_node_ip}:{ray_cluster_port}", ignore_reinit_error=True, namespace='llumnix')
return result

def is_gpu_available() -> bool:
try:
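With this change, `launch_ray_cluster` hands the `subprocess.CompletedProcess` of the `ray start` command back to the caller (it still exits the process on failure). A usage sketch; the port value and the fallback HEAD_NODE_IP are assumptions:

```python
# Inspect the result of starting the local Ray node.
import os

from llumnix.entrypoints.llumnix_utils import launch_ray_cluster

os.environ.setdefault("HEAD_NODE_IP", "127.0.0.1")
result = launch_ray_cluster(ray_cluster_port=6379)
print(result.returncode)   # 0, since failures call sys.exit(1) before returning
print(result.stdout)       # output of the `ray start` command
```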
22 changes: 11 additions & 11 deletions llumnix/entrypoints/vllm/api_server.py
@@ -36,14 +36,14 @@
logger = init_logger(__name__)
engine_manager = None
instances = {}
instance_num_request: Dict[str, int] = {}
instance_num_requests: Dict[str, int] = {}
# request_output_queue could be None if initialized in lifespan.
request_output_queue = None
server_id = None
TIMEOUT_KEEP_ALIVE = 5 # seconds.
request_streams: Dict[str, AsyncStream] = {}
log_requests = None
num_finished_request = 0
num_finished_requests = 0
WAIT_MANAGER_INTERVAL = 5


@@ -82,9 +82,9 @@ async def manager_generate(prompt, sampling_params, request_id) -> AsyncStream:
await engine_manager.generate.remote(request_id, server_info, prompt, sampling_params)
except ray.exceptions.RayActorError:
try:
if instance_num_request:
instance_id = min(instance_num_request, key=instance_num_request.get)
instance_num_request[instance_id] += 1
if instance_num_requests:
instance_id = min(instance_num_requests, key=instance_num_requests.get)
instance_num_requests[instance_id] += 1
await instances[instance_id].generate.remote(request_id, server_info, prompt, sampling_params)
print("Manager is unavailable, directly pass request {} to instance {}".format(request_id, instance_id))
else:
@@ -96,7 +96,7 @@ async def manager_generate(prompt, sampling_params, request_id) -> AsyncStream:
if instance_id in instances:
print("[manager_generate] instance {} is dead".format(instance_id))
del instances[instance_id]
del instance_num_request[instance_id]
del instance_num_requests[instance_id]
return await asyncio.create_task(manager_generate(prompt, sampling_params, request_id))
return results_generator

@@ -185,12 +185,12 @@ async def generate_benchmark(request: Request) -> Response:
start = now_time
final_output = request_output

global num_finished_request
global num_finished_requests
if log_requests:
# TODO(s5u13b): Use logger.
print(f"Finished request {request_id}.")
num_finished_request += 1
print(f"num_finished_request {num_finished_request}.")
num_finished_requests += 1
print(f"num_finished_requests {num_finished_requests}.")

generation = final_output.outputs[0].text
num_output_tokens = len(final_output.outputs[0].token_ids)
@@ -218,7 +218,7 @@ async def is_ready():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8003)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument('--disable-log-requests-server',
@@ -249,7 +249,7 @@ async def is_ready():
engine_manager, instance_ids, llumlets, request_output_queue = init_llumnix_components(engine_manager_args, engine_args, node_id)
for idx, ins_id in enumerate(instance_ids):
instances[ins_id] = llumlets[idx]
instance_num_request[ins_id] = 0
instance_num_requests[ins_id] = 0
log_requests = not args.disable_log_requests_server
# Start the api server after all the components of llumnix are ready.
print(f"Start Api Server on '{args.host}:{args.port}'")
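Finally, a client-side sketch for poking the server above. The route paths (`/is_ready`, `/generate_benchmark`) and the payload fields are inferred from the handler names and are assumptions, not confirmed by this diff:

```python
# Hypothetical client; adjust paths/payload to the actual routes.
import asyncio

import aiohttp

async def main() -> None:
    base = "http://localhost:8000"
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base}/is_ready") as resp:
            print(resp.status, await resp.text())
        payload = {"prompt": "Hello, Llumnix!", "stream": False}  # assumed fields
        async with session.post(f"{base}/generate_benchmark", json=payload) as resp:
            print(await resp.json())

asyncio.run(main())
```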