remove stream output
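This commit removes the streaming machinery from chat() in lmdeploy/serve/async_engine.py: the per-session Queue, the background worker thread, and Session.stream_response() are deleted. chat() now calls stream_infer(..., stream_output=False), merges the incremental outputs into a single Response via _merge_response(), and stores the result on the session before returning.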
irexyc committed Mar 11, 2024
1 parent 6a77145 commit d04bd45
Showing 1 changed file with 42 additions and 74 deletions.
116 changes: 42 additions & 74 deletions lmdeploy/serve/async_engine.py
@@ -81,10 +81,8 @@ class Session:
def __init__(self):
self._id = next(self._ids)
self._step = 0
self._queue = None
self._prompt = None
self._response = None
self._work_thread = None
self._engine = None
self.history = []

@@ -96,40 +94,14 @@ def _merge_response(self, resp: Response, step: Response):
return resp

def response(self) -> Response:
if self._response:
return self._response
elif self._queue is not None:
resp = Response('', -1, -1, self._id)
while True:
step = self._queue.get()
resp = self._merge_response(resp, step)
if resp.finish_reason is not None:
break
self._step += resp.generate_token_len + resp.input_token_len
self._response = resp
self.history.append((self._prompt, resp.text))
return self._response

def stream_response(self):
resp = Response('', -1, -1, self._id)
while self._queue is not None:
step = self._queue.get()
resp = self._merge_response(resp, step)
yield step
if resp.finish_reason is not None:
break
self._step += resp.generate_token_len + resp.input_token_len
self._response = resp
self.history.append((self._prompt, resp.text))

def close(self):
if self._engine:
inst = self._engine.create_instance()
inst.cancel(self._id)
_ = self.response()

def __repr__(self) -> str:
_ = self.response()
res = ''
for user, assistant in self.history:
res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
@@ -620,10 +592,8 @@ def chat(self,
session._engine = self.engine

# sync & init
_ = session.response()
session._prompt = prompt
session._response = None
session._queue = Queue()

sequence_start = session._step == 0

@@ -645,49 +615,47 @@
prompt = self.chat_template.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)

def _work_thread():
if gen_config.max_new_tokens is None:
# for interactive endpoint, will try maximum possible token num
gen_config.max_new_tokens = max(
128, self.session_len - session._step - len(input_ids))
finish_reason = None
if session._step + len(
input_ids) + gen_config.max_new_tokens > self.session_len:
finish_reason = 'length'
resp = Response('', 0, len(input_ids), session._id,
finish_reason)
session._queue.put(resp)
else:
generator = self.engine.create_instance()
state = DetokenizeState()
for outputs in generator.stream_infer(
session_id=session._id,
input_ids=input_ids,
sequence_start=sequence_start,
step=session._step,
gen_config=gen_config,
stream_output=True):
_, res, tokens = outputs
print(session._id, tokens)
response, state = self.tokenizer.detokenize_incrementally(
res,
state,
skip_special_tokens=gen_config.skip_special_tokens)
resp = Response(response, tokens, len(input_ids),
session._id, finish_reason)
session._queue.put(resp)

finish_reason = 'length' \
if tokens >= gen_config.max_new_tokens else 'stop'
# utf-8 char at the end means it's a potential unfinished
# byte sequence
if not response.endswith('�'):
response = '' # avoid returning the last response twice
resp = Response(response, tokens, len(input_ids), session._id,
finish_reason)
session._queue.put(resp)
if gen_config.max_new_tokens is None:
# for interactive endpoint, will try maximum possible token num
gen_config.max_new_tokens = max(
128, self.session_len - session._step - len(input_ids))
finish_reason = None
if session._step + len(
input_ids) + gen_config.max_new_tokens > self.session_len:
finish_reason = 'length'
resp = Response('', 0, len(input_ids), session._id, finish_reason)
else:
generator = self.engine.create_instance()
state = DetokenizeState()
resp = Response('', -1, -1, session._id)
for outputs in generator.stream_infer(
session_id=session._id,
input_ids=input_ids,
sequence_start=sequence_start,
step=session._step,
gen_config=gen_config,
stream_output=False):
_, res, tokens = outputs
response, state = self.tokenizer.detokenize_incrementally(
res,
state,
skip_special_tokens=gen_config.skip_special_tokens)
_resp = Response(response, tokens, len(input_ids), session._id,
finish_reason)
resp = session._merge_response(resp, _resp)

finish_reason = 'length' \
if tokens >= gen_config.max_new_tokens else 'stop'
# utf-8 char at the end means it's a potential unfinished
# byte sequence
if not response.endswith('�'):
response = '' # avoid returning the last response twice
_resp = Response(response, tokens, len(input_ids), session._id,
finish_reason)
resp = session._merge_response(resp, _resp)

session._response = resp
session._step += resp.generate_token_len + resp.input_token_len
session.history.append((session._prompt, resp.text))

work_thread = Thread(target=_work_thread)
work_thread.start()
session._work_thread = work_thread
return session
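
For context, the sketch below shows how a caller interacts with chat() after this change. It is a hypothetical usage sketch: the pipeline(...) constructor, the model id, and the session= keyword are assumptions for illustration, since the diff above only establishes chat(), response(), and the removal of stream_response().

from lmdeploy import pipeline

pipe = pipeline('internlm/internlm2-chat-7b')  # model id assumed for illustration

# Before this commit, chat() returned immediately while a worker thread
# pushed partial Response objects into session._queue; callers drained
# them via session.stream_response() (removed here).
#
# After this commit, chat() runs stream_infer(stream_output=False) to
# completion, merges the chunks with _merge_response(), and stores the
# final Response on the session, so the call is synchronous:
session = pipe.chat('Hello, what is LMDeploy?')
print(session.response().text)

# The session keeps history and the accumulated token step, so passing
# it back continues the same conversation (keyword assumed):
session = pipe.chat('Summarize that in one sentence.', session=session)
print(session.response().text)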

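A side note on the "if not response.endswith('�')" guard that survives in both versions of the loop above: when incremental detokenization stops inside a multi-byte UTF-8 character, the decoder emits the replacement character U+FFFD ('�'), meaning the tail of the text may still change on the next chunk. A self-contained illustration using only the standard library:

text = '你好'                                         # two 3-byte characters in UTF-8
data = text.encode('utf-8')                           # 6 bytes total
partial = data[:4].decode('utf-8', errors='replace')  # cut inside the 2nd char
print(partial)                                        # -> '你�'
print(partial.endswith('�'))                          # -> True: chunk may be unfinished
complete = data.decode('utf-8')
print(complete.endswith('�'))                         # -> False: safe to treat as final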