Commit

async recv

grimoire committed Feb 2, 2024
1 parent f3146e0 commit 6b6dcae
Showing 3 changed files with 72 additions and 9 deletions.
41 changes: 33 additions & 8 deletions lmdeploy/pytorch/engine/engine.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import asyncio
 import time
 from dataclasses import dataclass
 from queue import Queue
@@ -877,6 +876,7 @@ async def async_stream_infer(self,
                                  session_id: int,
                                  input_ids: List[int],
                                  gen_config: EngineGenerationConfig = None,
+                                 adapter_name: str = None,
                                  **kwargs):
         """Send stream inference request.
@@ -891,12 +891,38 @@ async def async_stream_infer(self,
             List[int]: The streaming output tokens.
             int: The number of the output tokens.
         """
-        for item in self.stream_infer(session_id=session_id,
-                                      input_ids=input_ids,
-                                      gen_config=gen_config,
-                                      **kwargs):
-            await asyncio.sleep(0)
-            yield item
+        gen_config = gen_config or EngineGenerationConfig()
+        request_output_len = gen_config.max_new_tokens
+        sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
+        self._try_add_session(session_id)
+        msg = dict(
+            token_ids=input_ids,
+            session_id=session_id,
+            max_request_output_len=request_output_len,
+            sampling_param=sampling_param,
+            adapter_name=adapter_name,
+        )
+        req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
+
+        token_ids = []
+        while True:
+            if not self.engine.loop_threads.is_alive():
+                yield (ResponseType.ENGINE_STOP_ERROR, [], 0)
+                break
+            resp = await self.req_sender.async_recv(req_id)
+            # avoid token decoding and scheduling simultaneously
+            if resp.req_id != req_id:
+                continue
+            if resp.type == ResponseType.SUCCESS:
+                token_ids += resp.data['token_ids']
+                yield (resp.type, token_ids, len(token_ids))
+            elif resp.type == ResponseType.FINISH:
+                token_ids += resp.data['token_ids']
+                yield (resp.type, token_ids, len(token_ids))
+                break
+            else:
+                yield (resp.type, [], 0)
+                break
 
     def stream_infer(self,
                      session_id: int,
@@ -937,7 +963,6 @@ def stream_infer(self,
             if not self.engine.loop_threads.is_alive():
                 yield (ResponseType.ENGINE_STOP_ERROR, [], 0)
                 break
-
             resp = self.req_sender.recv(req_id)
             # avoid token decoding and scheduling simultaneously
             if resp.req_id != req_id:
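The rewritten async_stream_infer no longer wraps the blocking stream_infer generator with asyncio.sleep(0); it submits the request itself and awaits each response through the new req_sender.async_recv, so callers consume it as a regular async generator. Below is a minimal consumption sketch; the `instance` object (an EngineInstance-like object exposing async_stream_infer) and its construction are assumptions, not part of this diff.

import asyncio


async def generate(instance, prompt_ids):
    # `instance` is assumed to expose the async_stream_infer shown above;
    # gen_config=None falls back to EngineGenerationConfig() inside the method.
    token_ids = []
    async for resp_type, token_ids, num_tokens in instance.async_stream_infer(
            session_id=0, input_ids=prompt_ids, gen_config=None):
        # Each yield carries (response type, accumulated token ids, count);
        # a FINISH or error response type ends the stream.
        print(resp_type, num_tokens)
    return token_ids


# Example (requires a real engine instance):
# asyncio.run(generate(instance, [1, 2, 3]))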
38 changes: 38 additions & 0 deletions lmdeploy/pytorch/engine/request.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
 import enum
 from dataclasses import dataclass, field
 from queue import Empty, Queue
@@ -80,6 +81,26 @@ def _resp_que_get(self, block: bool = True, timeout: float = None):
 
         return self.resp_que.get(timeout=timeout_counter)
 
+    async def _async_resp_que_get(self,
+                                  block: bool = True,
+                                  timeout: float = None):
+        """wrap of resp_que.get."""
+        if not block:
+            return self.resp_que.get(block=block, timeout=timeout)
+        timeout_counter = timeout or float(1 << 30)
+        while timeout_counter > self.THREAD_ALIVE_INTERVAL:
+            if self.resp_que.qsize() == 0:
+                await asyncio.sleep(self.THREAD_ALIVE_INTERVAL)
+                timeout_counter -= self.THREAD_ALIVE_INTERVAL
+            else:
+                return self.resp_que.get(block=False)
+            if self._thread and not self._thread.is_alive():
+                logger.error('Engine main loop stopped.')
+                exit(1)
+
+        await asyncio.sleep(self.THREAD_ALIVE_INTERVAL)
+        return self.resp_que.get(block=False)
+
     def _push_resp(self, req_id: int, resp: Response):
         """push response."""
         self.resp_dict.setdefault(req_id, [])
@@ -168,6 +189,23 @@ def recv(self, req_id: int, que_timeout: float = None) -> Response:
             else:
                 return resp
 
+    async def async_recv(self,
+                         req_id: int,
+                         que_timeout: float = None) -> Response:
+        """receive response of given request id async."""
+        ret = self._pop_resp(req_id, default=None)
+        if ret is not None:
+            return ret
+
+        # check resp que
+        while True:
+            resp: Response = await self._async_resp_que_get(
+                timeout=que_timeout)
+            if resp.req_id != req_id:
+                self._push_resp(req_id, resp)
+            else:
+                return resp
+
     def send(self,
              req_type: RequestType,
              data: Any,
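The pattern behind _async_resp_que_get: the engine's background loop thread feeds results into a thread-safe queue.Queue, and the async side polls qsize() and yields control with await asyncio.sleep(...) instead of calling the blocking Queue.get, so the event loop stays responsive while waiting. A standalone sketch of that polling pattern follows (illustrative names only, independent of lmdeploy):

import asyncio
import queue
import threading
import time

POLL_INTERVAL = 0.5  # illustrative stand-in for THREAD_ALIVE_INTERVAL


async def async_queue_get(q: queue.Queue, timeout: float = None):
    """Await an item from a thread-fed queue without blocking the event loop."""
    remaining = timeout if timeout is not None else float('inf')
    while remaining > 0:
        if q.qsize() == 0:
            # Nothing yet: yield to the event loop instead of calling the
            # blocking q.get(), then charge the wait against the timeout.
            await asyncio.sleep(POLL_INTERVAL)
            remaining -= POLL_INTERVAL
        else:
            return q.get(block=False)
    raise TimeoutError('no item arrived in time')


def producer(q: queue.Queue):
    time.sleep(1.0)          # simulate the engine loop doing work
    q.put('token batch')     # hand a result to the async side


async def main():
    q = queue.Queue()
    threading.Thread(target=producer, args=(q,), daemon=True).start()
    print(await async_queue_get(q, timeout=5.0))


asyncio.run(main())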
2 changes: 1 addition & 1 deletion lmdeploy/tokenizer.py
@@ -416,7 +416,7 @@ def detokenize_incrementally(self,
         # Please notice that in VLLM, indexes are detokenized one by one
         # while in LMDeploy, every turn, the detokenized indexes length
         # can be different.
-        if skip_special_tokens and new_tokens[
+        if skip_special_tokens and new_tokens and new_tokens[
                 0] in tokenizer.all_special_ids:
             read_offset = 1  # skip special token
         output_tokens = new_tokens
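The tokenizer change only adds a truthiness check so that new_tokens[0] is never read from an empty list: when an incremental decode step produces no new tokens, the old condition raised IndexError whenever skip_special_tokens was enabled. A tiny standalone illustration of the guard (not lmdeploy code):

special_ids = {0, 2}    # stand-in for tokenizer.all_special_ids
new_tokens = []         # a decode step may produce no new tokens
skip_special_tokens = True
read_offset = 0

# Old form: `new_tokens[0]` raises IndexError on an empty list.
# The added `and new_tokens` short-circuits before indexing.
if skip_special_tokens and new_tokens and new_tokens[0] in special_ids:
    read_offset = 1  # skip special token
print(read_offset)  # 0: the empty step is handled instead of crashing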
