Commit

refine
KuilongCui committed Dec 12, 2024
1 parent ec98fef commit 8ecb891
Showing 4 changed files with 9 additions and 61 deletions.
6 changes: 3 additions & 3 deletions llumnix/backends/bladellm/llm_engine.py
@@ -91,20 +91,20 @@ def _start_put_queue_loop(self):
             self._put_request_outputs_to_server(request_outputs, req_id_outputs, server_info_outputs)
 
     def _put_request_outputs_to_server(self, request_outputs: List[GenerateStreamResponse],
-                                       req_ids: List[int], server_infos: List[ServerInfo]) -> None:
+                                       req_ids: List[str], server_infos: List[ServerInfo]) -> None:
         server_request_outputs = defaultdict(list)
         server_info_dict = {}
         # Reorganize data in orther to put request output to queue in batch at one time.
         for request_output, req_id, server_info in zip(request_outputs, req_ids, server_infos):
             server_id = server_info.server_id
-            server_request_outputs[server_id].append((str(req_id), request_output.model_dump_json()))
+            server_request_outputs[server_id].append((req_id, request_output.model_dump_json()))
             if server_id not in server_info_dict:
                 server_info_dict[server_id] = server_info
         logger.debug("_put_request_outputs_to_server, {}", server_request_outputs)
         self.async_put_queue_actor.put_nowait_to_servers.remote(server_request_outputs, server_info_dict)
 
     async def send(self, req_id, msg, reset=False):
-        self.put_queue_args_queue.put_nowait((msg, req_id, self.request_client_map[req_id]))
+        self.put_queue_args_queue.put_nowait((msg, str(req_id), self.request_client_map[req_id]))
         if msg.is_finished:
             self.request_client_map.pop(req_id)
 
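Note: with this change req_id stays a str from the moment send() enqueues it, so the per-server batching no longer converts ids on the way out. A minimal, self-contained sketch of the same grouping pattern (illustrative names only, not the actual Llumnix API):

from collections import defaultdict
from typing import Dict, List, Tuple

def group_outputs_by_server(request_outputs: List[str],
                            req_ids: List[str],
                            server_ids: List[str]) -> Dict[str, List[Tuple[str, str]]]:
    # Collect (req_id, serialized_output) pairs per server so each server
    # queue receives a single batched put instead of one put per request.
    server_request_outputs = defaultdict(list)
    for output, req_id, server_id in zip(request_outputs, req_ids, server_ids):
        server_request_outputs[server_id].append((req_id, output))
    return server_request_outputs

# Example: two outputs routed to server "s0", one to "s1".
print(group_outputs_by_server(['{"text": "a"}', '{"text": "b"}', '{"text": "c"}'],
                              ["1", "2", "3"], ["s0", "s0", "s1"]))
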
57 changes: 0 additions & 57 deletions llumnix/backends/bladellm/scheduler.py
@@ -11,74 +11,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import time
-from typing import List
-
-from loguru import logger
-
-from blade_llm.service.proto.bladellm_pb2 import WorkerStepRequest
 from blade_llm.service.schedulers import PagedScheduler
 from blade_llm.service.scheduler_types import SchedulerStepOutput
 from blade_llm.service.args import ServingArgs
 
-from llumnix.instance_info import InstanceInfo
-from llumnix.llumlet.request import RequestInferenceType
 from llumnix.backends.bladellm.metrics import BladellmMetrics
 
 class PagedSchedulerLlumnix(PagedScheduler):
     def __init__(self, serving_args: ServingArgs, *args, **kwargs) -> None:
         PagedScheduler.__init__(self, serving_args, *args, **kwargs)
         self.llumnix_metrics = BladellmMetrics()
         self.llumnix_metrics.block_manager_init_metrics(self.block_manager)
 
-    def _get_instance_info(self, steps: List[WorkerStepRequest]) -> InstanceInfo:
-        num_total_gpu_blocks = self._max_processing_units
-        num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-        num_used_gpu_blocks = num_total_gpu_blocks - num_free_gpu_blocks
-        gpu_cache_usage = num_used_gpu_blocks / num_total_gpu_blocks
-        if self.waiting:
-            num_blocks_waiting_requests = []
-            waiting_time_waiting_requests = []
-            for seq_group in self.waiting:
-                num_prompt_tokens = len(seq_group.paged_reqs[0].token_ids)
-                num_blocks = num_prompt_tokens / self.block_size
-                waiting_time = time.monotonic() - seq_group.receive_time
-                num_blocks_waiting_requests.append(num_blocks)
-                waiting_time_waiting_requests.append(waiting_time)
-            num_blocks_first_waiting_request = num_blocks_waiting_requests[0]
-            waiting_time_first_waiting_request = waiting_time_waiting_requests[0]
-            num_blocks_all_waiting_requests = sum(num_blocks_waiting_requests)
-        else:
-            num_blocks_first_waiting_request = 0
-            waiting_time_first_waiting_request = 0
-            num_blocks_all_waiting_requests = 0
-        instance_info = InstanceInfo(
-            num_total_gpu_blocks=num_total_gpu_blocks,
-            num_watermark_blocks=self.block_manager.reserved_blocks,
-            num_used_gpu_blocks=num_used_gpu_blocks,
-            num_free_gpu_blocks=num_free_gpu_blocks,
-            gpu_cache_usage=gpu_cache_usage,
-            num_running_requests=len(self.running),
-            num_waiting_requests=len(self.waiting),
-            num_killed_requests=self._get_num_killed_requests(),
-            num_blocks_first_waiting_request=num_blocks_first_waiting_request,
-            waiting_time_first_waiting_request=waiting_time_first_waiting_request,
-            num_blocks_all_waiting_requests=num_blocks_all_waiting_requests,
-        )
-
-        for gen_group in self.running:
-            instance_info.running_seq_lens.extend([len(req_state.token_ids) for req_state in gen_group.paged_reqs])
-            instance_info.num_seqs = len(instance_info.running_seq_lens)
-
-        instance_info.inference_type = RequestInferenceType.generate_inference_type(
-            exist_prefill=any(len(step.prefill) > 0 for step in steps),
-            exist_decode=any(len(step.decode) > 0 for step in steps))
-        instance_info.num_batched_tokens = sum([
-            len(step.decode) + sum([len(prefill.prompt_tokens) for prefill in step.prefill]) for step in steps
-        ])
-        instance_info.finished_request_ids = len(self._finished_req_to_remove)
-        logger.info("update in scheduler {}".format(instance_info.num_running_requests))
-        return instance_info
-
     def step(self) -> SchedulerStepOutput:
         step_out = super().step()
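
For reference, the removed _get_instance_info derived KV-cache load from scheduler state. A condensed, standalone sketch of just that arithmetic (hypothetical helper, not part of the codebase):

import time
from typing import Dict, List

def block_usage_stats(num_total_gpu_blocks: int, num_free_gpu_blocks: int,
                      waiting_prompt_lens: List[int], receive_times: List[float],
                      block_size: int) -> Dict[str, float]:
    # Cache usage: the fraction of GPU KV-cache blocks currently allocated.
    num_used_gpu_blocks = num_total_gpu_blocks - num_free_gpu_blocks
    stats = {"gpu_cache_usage": num_used_gpu_blocks / num_total_gpu_blocks}
    # Waiting queue: block demand of the first queued prompt, how long it has
    # waited, and the total block demand of the whole waiting queue.
    if waiting_prompt_lens:
        blocks_per_request = [n / block_size for n in waiting_prompt_lens]
        stats["num_blocks_first_waiting_request"] = blocks_per_request[0]
        stats["waiting_time_first_waiting_request"] = time.monotonic() - receive_times[0]
        stats["num_blocks_all_waiting_requests"] = sum(blocks_per_request)
    else:
        stats.update(num_blocks_first_waiting_request=0,
                     waiting_time_first_waiting_request=0,
                     num_blocks_all_waiting_requests=0)
    return stats
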
6 changes: 5 additions & 1 deletion llumnix/entrypoints/bladellm/client.py
@@ -39,6 +39,7 @@ class LlumnixClientBladellm(MultiProcessingLLMClient):
     def __init__(self, args: ServingArgs, llumnix_context: LlumnixEntrypointsContext, loop: asyncio.AbstractEventLoop):
         super().__init__(args, -1)
         self.entrypoint_id2llumnix_id = {}
+        self.llumnix_id2entrypoint_id = {}
         self.llumnix_context = llumnix_context
         loop.create_task(self.background_process_outputs())

@@ -54,14 +55,17 @@ async def background_process_outputs(self):
                     continue
                 await self.llumnix_context.request_streams[request_id].put(request_output)
                 if request_output.is_finished:
-                    del self.entrypoint_id2llumnix_id[request_id]
+                    logger.info("Client Recv: {}".format(request_output))
+                    del self.entrypoint_id2llumnix_id[self.llumnix_id2entrypoint_id[request_id]]
+                    del self.llumnix_id2entrypoint_id[request_id]
                     del self.llumnix_context.request_streams[request_id]
 
     async def _add_request(self, request: ServerRequest) -> LLMResponse:
         if request.sampling_params.n > 1 or request.sampling_params.use_beam_search:
             return error_resp(request.id, err_code=400, err_msg="Unsupported feature: multiple sequence decoding in Llumnix.")
 
         llumnix_id = random.randint(0, 2147483647) # 1<<31-1
+        self.llumnix_id2entrypoint_id[str(llumnix_id)] = request.id
         self.entrypoint_id2llumnix_id[request.id] = llumnix_id
         request.id = llumnix_id
         resp_stream = await self._manager_generate(request.model_dump_json(), str(llumnix_id))
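
The client now keeps the id mapping in both directions and tears both entries down when the final output for a request arrives. A small, self-contained sketch of that bookkeeping (simplified: ids are kept as strings throughout, and the names are illustrative):

import random

class RequestIdMapper:
    """Toy model of the two-way entrypoint-id <-> llumnix-id bookkeeping."""

    def __init__(self):
        self.entrypoint_id2llumnix_id = {}
        self.llumnix_id2entrypoint_id = {}

    def register(self, entrypoint_id: str) -> str:
        # Llumnix-internal ids are random 31-bit integers, at most (1 << 31) - 1,
        # stored here as strings because outputs come back keyed by string ids.
        llumnix_id = str(random.randint(0, 2147483647))
        self.llumnix_id2entrypoint_id[llumnix_id] = entrypoint_id
        self.entrypoint_id2llumnix_id[entrypoint_id] = llumnix_id
        return llumnix_id

    def finish(self, llumnix_id: str) -> None:
        # On the last output of a request, drop both directions so neither
        # dictionary leaks entries for completed requests.
        entrypoint_id = self.llumnix_id2entrypoint_id.pop(llumnix_id)
        del self.entrypoint_id2llumnix_id[entrypoint_id]

mapper = RequestIdMapper()
llumnix_id = mapper.register("req-42")
mapper.finish(llumnix_id)
assert not mapper.entrypoint_id2llumnix_id and not mapper.llumnix_id2entrypoint_id
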
1 change: 1 addition & 0 deletions llumnix/entrypoints/vllm/utils.py
@@ -17,6 +17,7 @@
 
 WAIT_MANAGER_INTERVAL = 5
 
+
 def add_cli_args(parser):
     parser.set_namespace("llumnix")
     parser = LlumnixEntrypointsArgs.add_cli_args(parser)