Commit: make IPv6 compatible, safe run for coroutine interrupting (#487)
* make IPv6 compatible, safe run for coroutine interrupting

* instance_id -> session_id and fix api_client.py

* update doc

* remove useless faq

* safe ip mapping

* update app.py

* remove print

* update doc
AllentDan authored Oct 11, 2023
1 parent fbd9770 commit 759e1dd
Showing 8 changed files with 137 additions and 107 deletions.
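The IPv6 fix replaces the old session-id derivation `int(host.replace('.', ''))`, which only worked for dotted IPv4 strings, with an `ip2id` helper that the gradio app imports from `lmdeploy.serve.openai.api_server` (its definition sits in one of the file diffs not expanded on this page). A minimal sketch of what such a safe ip mapping might look like, assuming it only needs to produce a deterministic non-negative int per client host:

```python
# Hypothetical sketch only -- the real ip2id lives in
# lmdeploy/serve/openai/api_server.py and may differ in detail.
import ipaddress


def ip2id(host: str) -> int:
    """Map an IPv4 or IPv6 host string to a deterministic int session id."""
    try:
        # ipaddress accepts both '127.0.0.1' and '::1' and converts
        # the address to its integer value.
        return int(ipaddress.ip_address(host)) % (2**30)
    except ValueError:
        # Not a literal IP (e.g. a hostname): fall back to a hash.
        return hash(host) % (2**30)
```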
6 changes: 3 additions & 3 deletions benchmark/profile_restful_api.py
@@ -14,7 +14,7 @@

def get_streaming_response(prompt: str,
api_url: str,
- instance_id: int,
+ session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
@@ -24,7 +24,7 @@ def get_streaming_response(prompt: str,
pload = {
'prompt': prompt,
'stream': stream,
- 'instance_id': instance_id,
+ 'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
@@ -36,7 +36,7 @@ def get_streaming_response(prompt: str,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
- delimiter=b'\0'):
+ delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
13 changes: 6 additions & 7 deletions docs/en/restful_api.md
@@ -22,7 +22,7 @@ from typing import Iterable, List

def get_streaming_response(prompt: str,
api_url: str,
- instance_id: int,
+ session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
@@ -32,7 +32,7 @@ def get_streaming_response(prompt: str,
pload = {
'prompt': prompt,
'stream': stream,
- 'instance_id': instance_id,
+ 'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
@@ -41,7 +41,7 @@ def get_streaming_response(prompt: str,
response = requests.post(
api_url, headers=headers, json=pload, stream=stream)
for chunk in response.iter_lines(
- chunk_size=8192, decode_unicode=False, delimiter=b'\0'):
+ chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
@@ -91,7 +91,7 @@ curl http://{server_ip}:{server_port}/generate \
-H "Content-Type: application/json" \
-d '{
"prompt": "Hello! How are you?",
"instance_id": 1,
"session_id": 1,
"sequence_start": true,
"sequence_end": true
}'
@@ -146,11 +146,10 @@ python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True

2. When OOM appears on the server side, please reduce the `instance_num` when launching the service.

- 3. When a request with the same `instance_id` to `generate` gets an empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and those afterwards.
+ 3. When a request with the same `session_id` to `generate` gets an empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and those afterwards.

4. Requests were previously handled sequentially rather than concurrently. To resolve this issue:

-   - kindly provide unique instance_id values when calling the `generate` API, or else your requests may be associated with client IP addresses
-   - additionally, setting `stream=true` enables processing multiple requests simultaneously
+   - kindly provide unique session_id values when calling the `generate` API, or else your requests may be associated with client IP addresses

5. Both the `generate` api and `v1/chat/completions` support engaging in multiple rounds of conversation, where the input `prompt` or `messages` consists of either a single string or an entire chat history. These inputs are interpreted using multi-turn dialogue mode. However, if you want to turn the mode off and manage the chat history on the client side, please set the parameter `sequence_end: true` when utilizing the `generate` function, or specify `renew_session: true` when making use of `v1/chat/completions`.
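To make items 3 and 4 above concrete, here is a hedged client-side sketch of a two-turn conversation against `/generate`. The host and port are placeholders, and the shape of the non-streaming response (a single JSON object with a `text` field) is an assumption based on the streaming example earlier in this document:

```python
# Sketch of a two-turn conversation with the /generate endpoint.
# API_URL is a placeholder; the single-JSON response shape for
# stream=False is assumed, not confirmed by this page.
import requests

API_URL = 'http://localhost:23333/generate'  # replace with your server


def ask(prompt: str, session_id: int, sequence_start: bool) -> str:
    pload = {
        'prompt': prompt,
        'session_id': session_id,  # keep constant within one conversation
        'request_output_len': 512,
        'sequence_start': sequence_start,  # True only on the first turn
        'sequence_end': False,
        'stream': False,
    }
    resp = requests.post(API_URL, json=pload)
    return resp.json()['text']


print(ask('Hello! How are you?', session_id=1, sequence_start=True))
# Second turn of the same session: sequence_start must now be False,
# otherwise the server returns an empty string and negative tokens (item 3).
print(ask('Summarize your last answer.', session_id=1, sequence_start=False))
```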
13 changes: 6 additions & 7 deletions docs/zh_cn/restful_api.md
@@ -24,7 +24,7 @@ from typing import Iterable, List

def get_streaming_response(prompt: str,
api_url: str,
- instance_id: int,
+ session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
@@ -34,7 +34,7 @@ def get_streaming_response(prompt: str,
pload = {
'prompt': prompt,
'stream': stream,
- 'instance_id': instance_id,
+ 'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
@@ -43,7 +43,7 @@ def get_streaming_response(prompt: str,
response = requests.post(
api_url, headers=headers, json=pload, stream=stream)
for chunk in response.iter_lines(
- chunk_size=8192, decode_unicode=False, delimiter=b'\0'):
+ chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
@@ -93,7 +93,7 @@ curl http://{server_ip}:{server_port}/generate \
-H "Content-Type: application/json" \
-d '{
"prompt": "Hello! How are you?",
"instance_id": 1,
"session_id": 1,
"sequence_start": true,
"sequence_end": true
}'
@@ -148,12 +148,11 @@ python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True

2. When the server runs out of GPU memory (OOM), appropriately reduce the `instance_num` specified when launching the service.

- 3. When a request with the same `instance_id` is sent to `generate` and returns an empty string and a negative `tokens`, it is most likely because `sequence_start=false` was not set for the second question.
+ 3. When a request with the same `session_id` is sent to `generate` and returns an empty string and a negative `tokens`, it is most likely because `sequence_start=false` was not set for the second question.

4. If requests seem to be processed one at a time instead of concurrently, set the following parameters:

-   - pass distinct instance_id values to the `generate` api; otherwise session ids are automatically bound to numbers derived from the client's ip address
-   - set `stream=true` so the model can admit other requests while a forward pass is running
+   - pass distinct session_id values to the `generate` api; otherwise session ids are automatically bound to numbers derived from the client's ip address

5. Both the `generate` api and `v1/chat/completions` support multi-turn conversation. The `messages` or `prompt` parameter can be either a simple string for a single user question or a full conversation history.
Both apis enable multi-turn dialogue by default; if you want to turn it off and manage the conversation history on the client side, pass `sequence_end: true` to `generate`, or set `renew_session: true` for `v1/chat/completions`.
94 changes: 55 additions & 39 deletions lmdeploy/serve/async_engine.py
@@ -47,14 +47,31 @@ def __init__(self, model_path, instance_num=32, tp=1) -> None:
self.starts = [None] * instance_num
self.steps = {}

+    def stop_session(self, session_id: int):
+        instance_id = session_id % self.instance_num
+        input_ids = self.tokenizer.encode('')
+        for outputs in self.generators[instance_id].stream_infer(
+                session_id,
+                input_ids,
+                request_output_len=0,
+                sequence_start=False,
+                sequence_end=False,
+                stop=True):
+            pass
+        self.available[instance_id] = True
+
    @contextmanager
-    def safe_run(self, instance_id: int, stop: bool = False):
+    def safe_run(self, instance_id: int, session_id: Optional[int] = None):
        self.available[instance_id] = False
-        yield
+        try:
+            yield
+        except (Exception, asyncio.CancelledError) as e:  # noqa
+            self.stop_session(session_id)
        self.available[instance_id] = True

-    async def get_embeddings(self, prompt):
-        prompt = self.model.get_prompt(prompt)
+    async def get_embeddings(self, prompt, do_prerpocess=False):
+        if do_prerpocess:
+            prompt = self.model.get_prompt(prompt)
        input_ids = self.tokenizer.encode(prompt)
        return input_ids

@@ -68,7 +85,7 @@ async def get_generator(self, instance_id: int, stop: bool = False):
async def generate(
self,
messages,
- instance_id,
+ session_id,
stream_response=True,
sequence_start=True,
sequence_end=False,
@@ -85,7 +102,7 @@ async def generate(
Args:
messages (str | List): chat history or prompt
- instance_id (int): actually request host ip
+ session_id (int): the session id
stream_response (bool): whether return responses streamingly
request_output_len (int): output token nums
sequence_start (bool): indicator for starting a sequence
@@ -102,8 +119,7 @@ async def generate(
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
"""
-        session_id = instance_id
-        instance_id %= self.instance_num
+        instance_id = session_id % self.instance_num
if str(session_id) not in self.steps:
self.steps[str(session_id)] = 0
if step != 0:
@@ -119,7 +135,7 @@ async def generate(
finish_reason)
else:
generator = await self.get_generator(instance_id, stop)
-            with self.safe_run(instance_id):
+            with self.safe_run(instance_id, session_id):
response_size = 0
async for outputs in generator.async_stream_infer(
session_id=session_id,
@@ -188,14 +204,14 @@ async def generate_openai(
instance_id %= self.instance_num
sequence_start = False
generator = await self.get_generator(instance_id)
-        self.available[instance_id] = False
if renew_session: # renew a session
empty_input_ids = self.tokenizer.encode('')
for outputs in generator.stream_infer(session_id=session_id,
input_ids=[empty_input_ids],
request_output_len=0,
sequence_start=False,
- sequence_end=True):
+ sequence_end=True,
+ stop=True):
pass
self.steps[str(session_id)] = 0
if str(session_id) not in self.steps:
@@ -212,31 +228,31 @@ async def generate_openai(
yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
finish_reason)
else:
-            response_size = 0
-            async for outputs in generator.async_stream_infer(
-                    session_id=session_id,
-                    input_ids=[input_ids],
-                    stream_output=stream_response,
-                    request_output_len=request_output_len,
-                    sequence_start=(sequence_start),
-                    sequence_end=False,
-                    step=self.steps[str(session_id)],
-                    stop=stop,
-                    top_k=top_k,
-                    top_p=top_p,
-                    temperature=temperature,
-                    repetition_penalty=repetition_penalty,
-                    ignore_eos=ignore_eos,
-                    random_seed=seed if sequence_start else None):
-                res, tokens = outputs[0]
-                # decode res
-                response = self.tokenizer.decode(res.tolist(),
-                                                 offset=response_size)
-                # response, history token len, input token len, gen token len
-                yield GenOut(response, self.steps[str(session_id)],
-                             len(input_ids), tokens, finish_reason)
-                response_size = tokens
-
-            # update step
-            self.steps[str(session_id)] += len(input_ids) + tokens
-            self.available[instance_id] = True
+            with self.safe_run(instance_id, session_id):
+                response_size = 0
+                async for outputs in generator.async_stream_infer(
+                        session_id=session_id,
+                        input_ids=[input_ids],
+                        stream_output=stream_response,
+                        request_output_len=request_output_len,
+                        sequence_start=(sequence_start),
+                        sequence_end=False,
+                        step=self.steps[str(session_id)],
+                        stop=stop,
+                        top_k=top_k,
+                        top_p=top_p,
+                        temperature=temperature,
+                        repetition_penalty=repetition_penalty,
+                        ignore_eos=ignore_eos,
+                        random_seed=seed if sequence_start else None):
+                    res, tokens = outputs[0]
+                    # decode res
+                    response = self.tokenizer.decode(res.tolist(),
+                                                     offset=response_size)
+                    # response, history len, input len, generation len
+                    yield GenOut(response, self.steps[str(session_id)],
+                                 len(input_ids), tokens, finish_reason)
+                    response_size = tokens
+
+                # update step
+                self.steps[str(session_id)] += len(input_ids) + tokens
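This is the heart of the "safe run for coroutine interrupting" part of the commit: `generate_openai` now runs its whole streaming loop inside `safe_run`, so a cancelled coroutine (for instance, a client that disconnects mid-stream) stops the backend session and releases the instance slot instead of leaving it marked busy forever. A self-contained sketch of the same pattern, with a plain `stop` callback standing in for `AsyncEngine.stop_session`:

```python
# Standalone sketch of the safe-run pattern from async_engine.py above.
# `available` and `stop` stand in for AsyncEngine's bookkeeping.
import asyncio
from contextlib import contextmanager
from typing import Callable, List


@contextmanager
def safe_run(available: List[bool], instance_id: int,
             stop: Callable[[], None]):
    available[instance_id] = False  # mark the instance busy
    try:
        yield
    except (Exception, asyncio.CancelledError):
        stop()  # abort the in-flight generation on the backend
    available[instance_id] = True  # always release the slot
```

Catching `asyncio.CancelledError` explicitly matters here: since Python 3.8 it derives from `BaseException`, not `Exception`, so a bare `except Exception` would miss client disconnects and leak busy instances.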
23 changes: 12 additions & 11 deletions lmdeploy/serve/gradio/app.py
@@ -12,6 +12,7 @@
from lmdeploy.serve.gradio.css import CSS
from lmdeploy.serve.openai.api_client import (get_model_list,
get_streaming_response)
+ from lmdeploy.serve.openai.api_server import ip2id
from lmdeploy.serve.turbomind.chatbot import Chatbot

THEME = gr.themes.Soft(
@@ -37,7 +38,7 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
instruction = state_chatbot[-1][0]
session_id = threading.current_thread().ident
if request is not None:
- session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ session_id = ip2id(request.kwargs['client']['host'])

bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}')
@@ -166,7 +167,7 @@ def chat_stream_restful(
"""
session_id = threading.current_thread().ident
if request is not None:
- session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]

Expand All @@ -176,7 +177,7 @@ def chat_stream_restful(
for response, tokens, finish_reason in get_streaming_response(
instruction,
f'{InterFace.restful_api_url}/generate',
- instance_id=session_id,
+ session_id=session_id,
request_output_len=512,
sequence_start=(len(state_chatbot) == 1),
sequence_end=False):
@@ -212,12 +213,12 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,

session_id = threading.current_thread().ident
if request is not None:
- session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ session_id = ip2id(request.kwargs['client']['host'])
# end the session
for response, tokens, finish_reason in get_streaming_response(
'',
f'{InterFace.restful_api_url}/generate',
- instance_id=session_id,
+ session_id=session_id,
request_output_len=0,
sequence_start=False,
sequence_end=True):
@@ -241,11 +242,11 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
"""
session_id = threading.current_thread().ident
if request is not None:
- session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ session_id = ip2id(request.kwargs['client']['host'])
# end the session
for out in get_streaming_response('',
f'{InterFace.restful_api_url}/generate',
- instance_id=session_id,
+ session_id=session_id,
request_output_len=0,
sequence_start=False,
sequence_end=False,
Expand All @@ -259,7 +260,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
messages.append(dict(role='assistant', content=qa[1]))
for out in get_streaming_response(messages,
f'{InterFace.restful_api_url}/generate',
- instance_id=session_id,
+ session_id=session_id,
request_output_len=0,
sequence_start=True,
sequence_end=False):
@@ -346,7 +347,7 @@ async def chat_stream_local(
"""
session_id = threading.current_thread().ident
if request is not None:
- session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]

@@ -391,7 +392,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,

session_id = threading.current_thread().ident
if request is not None:
- session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ session_id = ip2id(request.kwargs['client']['host'])
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
@@ -419,7 +420,7 @@ async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button,
"""
session_id = threading.current_thread().ident
if request is not None:
- session_id = int(request.kwargs['client']['host'].replace('.', ''))
+ session_id = ip2id(request.kwargs['client']['host'])
# end the session
async for out in InterFace.async_engine.generate('',
session_id,