Skip to content

Commit

Permalink
Manage session id using random int for gradio local mode (#553)
Browse files Browse the repository at this point in the history
* Use session id from gradio state

* use a new session id after reset

* rename session id like a state

* update comments

* reformat files

* init session id on block loaded

* use an auto-incremented session id

* remove session id textbox

* apply to api_server and tritonserver

* update docstring

* add lock for safety

---------

Co-authored-by: AllentDan <[email protected]>
  • Loading branch information
aisensiy and AllentDan authored Nov 6, 2023
1 parent 85d2f66 commit 11d1093
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 100 deletions.
73 changes: 33 additions & 40 deletions lmdeploy/serve/gradio/api_server_backend.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
import threading
import time
from threading import Lock
from typing import Optional, Sequence

import gradio as gr

from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_client import (get_model_list,
                                              get_streaming_response)
from lmdeploy.serve.openai.api_server import ip2id


class InterFace:
    """Process-wide shared state for the gradio frontend.

    Attributes:
        api_server_url: base URL of the backing api_server; ``None`` until
            set by the server launcher before the UI starts.
        global_session_id: monotonically increasing counter used to hand
            out a unique session id to each newly loaded UI session.
        lock: guards ``global_session_id`` — gradio may run the per-session
            init callback concurrently for multiple browser sessions.
    """
    # Fix: the original annotated this as plain ``str`` while defaulting to
    # None; ``Optional[str]`` states the actual contract.
    api_server_url: Optional[str] = None
    global_session_id: int = 0
    lock = Lock()


def chat_stream_restful(
instruction: str,
state_chatbot: Sequence,
cancel_btn: gr.Button,
reset_btn: gr.Button,
request: gr.Request,
):
def chat_stream_restful(instruction: str, state_chatbot: Sequence,
cancel_btn: gr.Button, reset_btn: gr.Button,
session_id: int):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
session_id (int): the session id
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]

yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn)

for response, tokens, finish_reason in get_streaming_response(
instruction,
Expand All @@ -56,27 +48,21 @@ def chat_stream_restful(
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response
) # piece by piece
yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, enable_btn, disable_btn)

yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn)


def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
request: gr.Request):
session_id: int):
"""reset the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
session_id (int): the session id
"""
state_chatbot = []

session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
for response, tokens, finish_reason in get_streaming_response(
'',
Expand All @@ -94,18 +80,15 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,


def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request):
reset_btn: gr.Button, session_id: int):
"""stop the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
session_id (int): the session id
"""
yield (state_chatbot, disable_btn, disable_btn)
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
for out in get_streaming_response(
'',
Expand Down Expand Up @@ -152,6 +135,7 @@ def run_api_server(api_server_url: str,

with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
state_session_id = gr.State(0)

with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
Expand All @@ -164,25 +148,34 @@ def run_api_server(api_server_url: str,
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')

send_event = instruction_txtbox.submit(
chat_stream_restful,
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
send_event = instruction_txtbox.submit(chat_stream_restful, [
instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
state_session_id
], [state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
cancel_btn.click(cancel_restful_func,
[state_chatbot, cancel_btn, reset_btn],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
cancel_btn.click(
cancel_restful_func,
[state_chatbot, cancel_btn, reset_btn, state_session_id],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])

reset_btn.click(reset_restful_func,
[instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot, state_session_id],
[state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])

def init():
    """Hand out a fresh, unique session id for a newly loaded UI session.

    Returns:
        int: the next value of the process-wide session counter.
    """
    # The lock makes the increment-and-read atomic across concurrent
    # demo.load callbacks; the lock is released when the block exits,
    # including via this return.
    with InterFace.lock:
        InterFace.global_session_id += 1
        return InterFace.global_session_id

demo.load(init, inputs=None, outputs=[state_session_id])

print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch(
Expand Down
33 changes: 21 additions & 12 deletions lmdeploy/serve/gradio/triton_server_backend.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import threading
from functools import partial
from threading import Lock
from typing import Sequence

import gradio as gr

from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_server import ip2id
from lmdeploy.serve.turbomind.chatbot import Chatbot


class InterFace:
    """Process-wide shared state for the gradio frontend.

    Holds the counter used to assign each browser session a unique id.
    """
    # Monotonically increasing session counter; incremented under ``lock``.
    global_session_id: int = 0
    # Guards global_session_id against concurrent per-session init calls.
    lock = Lock()


def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
cancel_btn: gr.Button, reset_btn: gr.Button,
request: gr.Request):
cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
"""Chat with AI assistant.
Args:
Expand All @@ -22,12 +25,9 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
llama_chatbot (Chatbot): the instance of a chatbot
cancel_btn (bool): enable the cancel button or not
reset_btn (bool): enable the reset button or not
request (gr.Request): the request from a user
session_id (int): the session id
"""
instruction = state_chatbot[-1][0]
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])

bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}')
Expand Down Expand Up @@ -92,6 +92,7 @@ def run_triton_server(triton_server_addr: str,
llama_chatbot = gr.State(
Chatbot(triton_server_addr, log_level=log_level, display=True))
state_chatbot = gr.State([])
state_session_id = gr.State(0)
model_name = llama_chatbot.value.model_name
reset_all = partial(reset_all_func,
model_name=model_name,
Expand All @@ -110,10 +111,10 @@ def run_triton_server(triton_server_addr: str,

send_event = instruction_txtbox.submit(
add_instruction, [instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot]).then(
chat_stream,
[state_chatbot, llama_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
[instruction_txtbox, state_chatbot]).then(chat_stream, [
state_chatbot, llama_chatbot, cancel_btn, reset_btn,
state_session_id
], [state_chatbot, chatbot, cancel_btn, reset_btn])

cancel_btn.click(cancel_func,
[state_chatbot, llama_chatbot, cancel_btn, reset_btn],
Expand All @@ -125,6 +126,14 @@ def run_triton_server(triton_server_addr: str,
[llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])

def init():
    """Allocate and return a unique session id for a new UI session.

    Returns:
        int: the next value of the process-wide session counter.
    """
    # Atomic increment-and-read under the shared lock; returning inside
    # the ``with`` still releases the lock on exit.
    with InterFace.lock:
        InterFace.global_session_id += 1
        return InterFace.global_session_id

demo.load(init, inputs=None, outputs=[state_session_id])

print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
max_threads=10,
Expand Down
Loading

0 comments on commit 11d1093

Please sign in to comment.