From d04bd45323bb913ab3f5d3723f95bcce8273d46b Mon Sep 17 00:00:00 2001 From: irexyc Date: Mon, 11 Mar 2024 06:25:05 +0000 Subject: [PATCH] remove stream output --- lmdeploy/serve/async_engine.py | 116 ++++++++++++--------------------- 1 file changed, 42 insertions(+), 74 deletions(-) diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index a9258358a..1bcaf0333 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -81,10 +81,8 @@ class Session: def __init__(self): self._id = next(self._ids) self._step = 0 - self._queue = None self._prompt = None self._response = None - self._work_thread = None self._engine = None self.history = [] @@ -96,40 +94,14 @@ def _merge_response(self, resp: Response, step: Response): return resp def response(self) -> Response: - if self._response: - return self._response - elif self._queue is not None: - resp = Response('', -1, -1, self._id) - while True: - step = self._queue.get() - resp = self._merge_response(resp, step) - if resp.finish_reason is not None: - break - self._step += resp.generate_token_len + resp.input_token_len - self._response = resp - self.history.append((self._prompt, resp.text)) return self._response - def stream_response(self): - resp = Response('', -1, -1, self._id) - while self._queue is not None: - step = self._queue.get() - resp = self._merge_response(resp, step) - yield step - if resp.finish_reason is not None: - break - self._step += resp.generate_token_len + resp.input_token_len - self._response = resp - self.history.append((self._prompt, resp.text)) - def close(self): if self._engine: inst = self._engine.create_instance() inst.cancel(self._id) - _ = self.response() def __repr__(self) -> str: - _ = self.response() res = '' for user, assistant in self.history: res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n' @@ -620,10 +592,8 @@ def chat(self, session._engine = self.engine # sync & init - _ = session.response() session._prompt = prompt session._response = None - session._queue = Queue() sequence_start = session._step == 0 @@ -645,49 +615,47 @@ def chat(self, prompt = self.chat_template.messages2prompt(prompt, sequence_start) input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start) - def _work_thread(): - if gen_config.max_new_tokens is None: - # for interactive endpoint, will try maximum possible token num - gen_config.max_new_tokens = max( - 128, self.session_len - session._step - len(input_ids)) - finish_reason = None - if session._id + len( - input_ids) + gen_config.max_new_tokens > self.session_len: - finish_reason = 'length' - resp = Response('', 0, len(input_ids), session._id, - finish_reason) - session._queue.put(resp) - else: - generator = self.engine.create_instance() - state = DetokenizeState() - for outputs in generator.stream_infer( - session_id=session._id, - input_ids=input_ids, - sequence_start=sequence_start, - step=session._step, - gen_config=gen_config, - stream_output=True): - _, res, tokens = outputs - print(session._id, tokens) - response, state = self.tokenizer.detokenize_incrementally( - res, - state, - skip_special_tokens=gen_config.skip_special_tokens) - resp = Response(response, tokens, len(input_ids), - session._id, finish_reason) - session._queue.put(resp) - - finish_reason = 'length' \ - if tokens >= gen_config.max_new_tokens else 'stop' - # utf-8 char at the end means it's a potential unfinished - # byte sequence - if not response.endswith('�'): - response = '' # avaid returning the last response twice - resp = Response(response, tokens, len(input_ids), session._id, - finish_reason) - session._queue.put(resp) + if gen_config.max_new_tokens is None: + # for interactive endpoint, will try maximum possible token num + gen_config.max_new_tokens = max( + 128, self.session_len - session._step - len(input_ids)) + finish_reason = None + if session._id + len( + input_ids) + gen_config.max_new_tokens > self.session_len: + finish_reason = 'length' + resp = Response('', 0, len(input_ids), session._id, finish_reason) + else: + generator = self.engine.create_instance() + state = DetokenizeState() + resp = Response('', -1, -1, session._id) + for outputs in generator.stream_infer( + session_id=session._id, + input_ids=input_ids, + sequence_start=sequence_start, + step=session._step, + gen_config=gen_config, + stream_output=False): + _, res, tokens = outputs + response, state = self.tokenizer.detokenize_incrementally( + res, + state, + skip_special_tokens=gen_config.skip_special_tokens) + _resp = Response(response, tokens, len(input_ids), session._id, + finish_reason) + resp = session._merge_response(resp, _resp) + + finish_reason = 'length' \ + if tokens >= gen_config.max_new_tokens else 'stop' + # utf-8 char at the end means it's a potential unfinished + # byte sequence + if not response.endswith('�'): + response = '' # avaid returning the last response twice + _resp = Response(response, tokens, len(input_ids), session._id, + finish_reason) + resp = session._merge_response(resp, _resp) + + session._response = resp + session._step += resp.generate_token_len + resp.input_token_len + session.history.append((session._prompt, resp.text)) - work_thread = Thread(target=_work_thread) - work_thread.start() - session._work_thread = work_thread return session