From c2877b4f1439c5791beb6a9ace489db9b1afd393 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=99=BA=E9=B8=A3?=
Date: Thu, 25 Jul 2024 13:07:40 +0800
Subject: [PATCH 1/3] fix typo

---
 benchmark/benchmark_serving.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/benchmark_serving.py b/benchmark/benchmark_serving.py
index b5afe3b7..c915e1f5 100644
--- a/benchmark/benchmark_serving.py
+++ b/benchmark/benchmark_serving.py
@@ -424,7 +424,7 @@ async def benchmark(
     allow_variable_generation_length: bool,
     verbose: bool,
     results_filename: str,
-    ip_ports: list[int],
+    ip_ports: List[int],
     distribution: str,
     qps: float,
     coefficient_variation: float,

From 572cdbbf88c58f489d915eab32b2c1e7377c9253 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=99=BA=E9=B8=A3?=
Date: Thu, 25 Jul 2024 14:03:01 +0800
Subject: [PATCH 2/3] record inference latency in executor

---
 llumnix/backends/vllm/executor.py   | 20 ++++++++++++++++----
 llumnix/backends/vllm/llm_engine.py |  4 +---
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/llumnix/backends/vllm/executor.py b/llumnix/backends/vllm/executor.py
index 8cfd136f..c3fd183a 100644
--- a/llumnix/backends/vllm/executor.py
+++ b/llumnix/backends/vllm/executor.py
@@ -36,6 +36,7 @@ class LlumnixRayGPUExecutor(RayGPUExecutor):
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
+        self.last_inference_latency = 0
         if self.parallel_config.tensor_parallel_size == 1:
             # For single GPU case, we use a ray worker with constrained memory.
             num_gpus = self.cache_config.gpu_memory_utilization
@@ -73,7 +74,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 num_cpus=0,
                 num_gpus=num_gpus,
                 scheduling_strategy=scheduling_strategy,
-                max_concurrency=4,
+                max_concurrency=2,
                 **ray_remote_kwargs,
             )(RayWorkerWrapper).remote(
                 worker_module_name="llumnix.backends.vllm.worker",
@@ -146,10 +147,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)
 
+    def execute_model(self, *args, **kwargs):
+        t0 = time.time()
+        outputs = super().execute_model(*args, **kwargs)
+        t1 = time.time()
+        self.last_inference_latency = (t1 - t0) * 1000
+        return outputs
+
 
 class SimGPUExecutor(GPUExecutor):
     latency_mem: LatencyMemData = None
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
+        self.last_inference_latency = 0
         self.migration_bandwidth = self.latency_mem.migration_bandwidth
         # TODO(ziming) add swap bandwidth
@@ -187,10 +196,13 @@ def execute_model(
             decode_bs = _pad_to_alignment(decode_bs, 8)
         latency = 0
         if prefill_seq_len:
-            latency += model_prefill(prefill_seq_len, *self.latency_mem.prefill_model_params) / 1000
+            latency += self.latency_mem.prefill_latency[prefill_seq_len][0] if prefill_seq_len in self.latency_mem.prefill_latency \
+                       else model_prefill(prefill_seq_len, *self.latency_mem.prefill_model_params)
         if decode_bs:
-            latency += model_decode((decode_bs, decode_seq_len), *self.latency_mem.decode_model_params) / 1000
-        time.sleep(latency)
+            decode_meta_data = (decode_bs, decode_seq_len)
+            latency += self.latency_mem.decode_latency[decode_meta_data][0] if decode_meta_data in self.latency_mem.decode_latency \
+                       else model_decode((decode_bs, decode_seq_len), *self.latency_mem.decode_model_params)
+        time.sleep(latency/1000)
         sampler_outputs = []
         for meta_data in execute_model_req.seq_group_metadata_list:
             samples = []
diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py
index c3311d1a..cee6cd1a 100644
--- a/llumnix/backends/vllm/llm_engine.py
+++ b/llumnix/backends/vllm/llm_engine.py
@@ -177,9 +177,7 @@ def send_blocks(self, dst_ray_actor: "ray.actor.ActorHandle", src_blocks: List[i
                       src_worker_handle_list=self.worker_handle_list))
 
     def step(self) -> Tuple[List[RequestOutput], InstanceInfo, List[ServerInfo]]:
-        t0_inference_begin = time.time()
         output_list = self.engine.step()
-        t1_inference_end = time.time()
 
         instance_info: InstanceInfo = self.engine.scheduler.get_record_instance_info()
 
@@ -191,7 +189,7 @@ def step(self) -> Tuple[List[RequestOutput], InstanceInfo, List[ServerInfo]]:
         instance_info.instance_id = self.instance_id
         instance_info.step_id = next(self.step_counter)
         instance_info.timestamp = time.time()
-        instance_info.latency = (t1_inference_end - t0_inference_begin)*1000
+        instance_info.latency = self.engine.model_executor.last_inference_latency
         seq_groups = self.engine.scheduler.running
         if seq_groups:
             tot_blocks = []

From 15919f22cccd8f898c9738e721c9d82602a54e5b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=99=BA=E9=B8=A3?=
Date: Thu, 25 Jul 2024 14:26:16 +0800
Subject: [PATCH 3/3] use await in _background_process_outputs to improve api server throughput

---
 llumnix/entrypoints/vllm/api_server.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llumnix/entrypoints/vllm/api_server.py b/llumnix/entrypoints/vllm/api_server.py
index 21663cd9..344b879b 100644
--- a/llumnix/entrypoints/vllm/api_server.py
+++ b/llumnix/entrypoints/vllm/api_server.py
@@ -49,7 +49,8 @@
 
 async def _background_process_outputs():
     while True:
-        request_outputs = request_output_queue.get_nowait_batch(num_items=request_output_queue.qsize())
+        qsize = await request_output_queue.actor.qsize.remote()
+        request_outputs = await request_output_queue.actor.get_nowait_batch.remote(qsize)
         for request_output in request_outputs:
            request_id = request_output.request_id
            # Request could be dispatched twice when manager is dead, the first request will free the request_streams when finished.
@@ -59,7 +60,6 @@ async def _background_process_outputs():
             if request_output.finished:
                 request_streams[request_id].finish()
                 del request_streams[request_id]
-        await asyncio.sleep(0.01)
 
# pylint: disable=unused-argument
@asynccontextmanager
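
Reviewer note on PATCH 2/3: the change moves per-step timing out of LLMEngineLlumnix.step() and into the executor's execute_model(), so instance_info.latency now reflects only model execution rather than the whole engine step. The snippet below is a minimal, standalone sketch of that timing pattern, not code from this patch; LatencyRecordingExecutor and _DummyExecutor are hypothetical names used only for illustration.

import time


class LatencyRecordingExecutor:
    """Minimal sketch of executor-level step timing (hypothetical wrapper;
    the real change lives in LlumnixRayGPUExecutor.execute_model())."""

    def __init__(self, base_executor):
        self._base = base_executor
        # Read by the engine after each step, mirroring
        # model_executor.last_inference_latency in the patch.
        self.last_inference_latency = 0.0

    def execute_model(self, *args, **kwargs):
        # Time only the model execution, not the surrounding engine step.
        t0 = time.time()
        outputs = self._base.execute_model(*args, **kwargs)
        t1 = time.time()
        self.last_inference_latency = (t1 - t0) * 1000  # milliseconds
        return outputs


if __name__ == "__main__":
    class _DummyExecutor:
        def execute_model(self, *args, **kwargs):
            time.sleep(0.005)  # stand-in for a ~5 ms model step
            return ["dummy_output"]

    executor = LatencyRecordingExecutor(_DummyExecutor())
    executor.execute_model()
    print(f"last step took {executor.last_inference_latency:.2f} ms")

Measured this way, scheduling and bookkeeping work inside step() no longer inflates the reported latency, which is what the patch relies on when it assigns instance_info.latency = self.engine.model_executor.last_inference_latency. On PATCH 3/3, awaiting the queue actor's qsize.remote() and get_nowait_batch.remote() calls yields control to the event loop while outputs are fetched, and the fixed await asyncio.sleep(0.01) polling delay is removed; that appears to be the source of the claimed API-server throughput gain.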