Manage session id using random int for gradio local mode #553

Merged
12 commits merged on
Nov 6, 2023
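In short: the handlers stop deriving a session id from the worker thread's ident or the client IP (ip2id) and instead take it as an input; a process-wide counter guarded by a Lock hands out a fresh id on every page load via demo.load, and each browser session keeps its own copy in gr.State. A minimal sketch of the wiring, assuming the gradio 3.x API; `build_demo` and `echo` are illustrative names, not lmdeploy API:

```python
# Minimal sketch of the session-id scheme this PR introduces (gradio 3.x).
# `build_demo` and `echo` are hypothetical names for illustration only.
from threading import Lock

import gradio as gr


class InterFace:
    """Process-wide counter; mirrors the class added in this PR."""
    global_session_id: int = 0
    lock = Lock()


def init() -> int:
    # Runs once per page load (wired via demo.load); the lock keeps the
    # ids unique when several clients connect concurrently.
    with InterFace.lock:
        InterFace.global_session_id += 1
        return InterFace.global_session_id


def echo(message: str, session_id: int) -> str:
    # Stand-in for a real handler: it simply tags output with the id.
    return f'[session {session_id}] {message}'


def build_demo() -> gr.Blocks:
    with gr.Blocks() as demo:
        state_session_id = gr.State(0)  # each browser session gets its own copy
        box = gr.Textbox()
        out = gr.Textbox()
        box.submit(echo, [box, state_session_id], [out])
        demo.load(init, inputs=None, outputs=[state_session_id])
    return demo
```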
73 changes: 33 additions & 40 deletions lmdeploy/serve/gradio/api_server_backend.py
@@ -1,42 +1,34 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import threading
 import time
+from threading import Lock
 from typing import Sequence
 
 import gradio as gr
 
 from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
 from lmdeploy.serve.openai.api_client import (get_model_list,
                                               get_streaming_response)
-from lmdeploy.serve.openai.api_server import ip2id
 
 
+class InterFace:
+    api_server_url: str = None
+    global_session_id: int = 0
+    lock = Lock()
+
+
-def chat_stream_restful(
-    instruction: str,
-    state_chatbot: Sequence,
-    cancel_btn: gr.Button,
-    reset_btn: gr.Button,
-    request: gr.Request,
-):
+def chat_stream_restful(instruction: str, state_chatbot: Sequence,
+                        cancel_btn: gr.Button, reset_btn: gr.Button,
+                        session_id: int):
     """Chat with AI assistant.
 
     Args:
         instruction (str): user's prompt
         state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
+        session_id (int): the session id
     """
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
-    bot_summarized_response = ''
     state_chatbot = state_chatbot + [(instruction, None)]
 
-    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
-           f'{bot_summarized_response}'.strip())
+    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
 
     for response, tokens, finish_reason in get_streaming_response(
             instruction,
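The removed lines above are the crux of the change: threading.current_thread().ident identifies the queue worker, not the user, so once requests are served from a worker pool two users can collide on one "session". A standalone demonstration of that ident reuse (not lmdeploy code):

```python
# Demonstrates ident reuse in a fixed-size pool: many tasks, few distinct
# thread idents, so idents cannot serve as per-user session ids.
import threading
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=2) as pool:
    idents = set(
        pool.map(lambda _: threading.current_thread().ident, range(100)))

print(len(idents))  # at most 2: every request on a worker shares its ident
```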
@@ -56,27 +48,21 @@ def chat_stream_restful(
             state_chatbot[-1] = (state_chatbot[-1][0],
                                  state_chatbot[-1][1] + response
                                  )  # piece by piece
-            yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
-                   f'{bot_summarized_response}'.strip())
+            yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
 
-    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
-           f'{bot_summarized_response}'.strip())
+    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
 
 
 def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
-                       request: gr.Request):
+                       session_id: int):
     """reset the session.
 
     Args:
         instruction_txtbox (str): user's prompt
         state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
+        session_id (int): the session id
     """
     state_chatbot = []
 
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
     # end the session
     for response, tokens, finish_reason in get_streaming_response(
             '',
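get_streaming_response comes from lmdeploy.serve.openai.api_client and yields (response, tokens, finish_reason) triples as chunks arrive. As a rough illustration of what such a client does, here is a sketch assuming the server streams one JSON object per line; the URL, payload fields, and function name are assumptions, not the real lmdeploy signature:

```python
# Illustrative streaming client, assuming the server emits one JSON object
# per line; `api_url` and the field names are assumptions, not the exact
# lmdeploy.serve.openai.api_client implementation.
import json
from typing import Iterator, Tuple

import requests


def stream_chat(prompt: str, api_url: str,
                session_id: int) -> Iterator[Tuple[str, int, str]]:
    payload = {'prompt': prompt, 'session_id': session_id, 'stream': True}
    with requests.post(api_url, json=payload, stream=True) as resp:
        for line in resp.iter_lines():
            if not line:
                continue  # skip keep-alive blank lines
            data = json.loads(line)
            yield (data.get('text', ''), data.get('tokens', 0),
                   data.get('finish_reason', ''))
```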
@@ -94,18 +80,15 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
 
 
 def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
-                        reset_btn: gr.Button, request: gr.Request):
+                        reset_btn: gr.Button, session_id: int):
     """stop the session.
 
     Args:
         instruction_txtbox (str): user's prompt
         state_chatbot (Sequence): the chatting history
-        request (gr.Request): the request from a user
+        session_id (int): the session id
     """
     yield (state_chatbot, disable_btn, disable_btn)
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
     # end the session
     for out in get_streaming_response(
             '',
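All three handlers yield disable_btn/enable_btn to toggle the buttons while a request is in flight. Those constants are imported from lmdeploy.serve.gradio.constants; a plausible definition in the gradio 3.x idiom (an assumption shown for readability, the real ones live in constants.py) is:

```python
# Assumed shape of the imported button constants (gradio 3.x update
# payloads); see lmdeploy/serve/gradio/constants.py for the actual code.
import gradio as gr

enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
```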
@@ -152,6 +135,7 @@ def run_api_server(api_server_url: str,
 
     with gr.Blocks(css=CSS, theme=THEME) as demo:
         state_chatbot = gr.State([])
+        state_session_id = gr.State(0)
 
         with gr.Column(elem_id='container'):
             gr.Markdown('## LMDeploy Playground')
@@ -164,25 +148,34 @@
                 cancel_btn = gr.Button(value='Cancel', interactive=False)
                 reset_btn = gr.Button(value='Reset')
 
-        send_event = instruction_txtbox.submit(
-            chat_stream_restful,
-            [instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
-            [state_chatbot, chatbot, cancel_btn, reset_btn])
+        send_event = instruction_txtbox.submit(chat_stream_restful, [
+            instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
+            state_session_id
+        ], [state_chatbot, chatbot, cancel_btn, reset_btn])
         instruction_txtbox.submit(
             lambda: gr.Textbox.update(value=''),
             [],
             [instruction_txtbox],
         )
-        cancel_btn.click(cancel_restful_func,
-                         [state_chatbot, cancel_btn, reset_btn],
-                         [state_chatbot, cancel_btn, reset_btn],
-                         cancels=[send_event])
+        cancel_btn.click(
+            cancel_restful_func,
+            [state_chatbot, cancel_btn, reset_btn, state_session_id],
+            [state_chatbot, cancel_btn, reset_btn],
+            cancels=[send_event])
 
         reset_btn.click(reset_restful_func,
-                        [instruction_txtbox, state_chatbot],
+                        [instruction_txtbox, state_chatbot, state_session_id],
                         [state_chatbot, chatbot, instruction_txtbox],
                         cancels=[send_event])
 
+        def init():
+            with InterFace.lock:
+                InterFace.global_session_id += 1
+                new_session_id = InterFace.global_session_id
+            return new_session_id
+
+        demo.load(init, inputs=None, outputs=[state_session_id])
+
     print(f'server is gonna mount on: http://{server_name}:{server_port}')
     demo.queue(concurrency_count=batch_size, max_size=100,
                api_open=True).launch(
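Because demo.load fires once per connected client, concurrent page loads race on the counter; holding the lock across both the increment and the read makes the pair atomic. A quick standalone check of the uniqueness property (a demonstration, not part of the PR):

```python
# Standalone check that the locked counter hands out unique ids under
# concurrency; simulates many simultaneous page loads.
from concurrent.futures import ThreadPoolExecutor
from threading import Lock


class InterFace:
    global_session_id: int = 0
    lock = Lock()


def init() -> int:
    with InterFace.lock:
        InterFace.global_session_id += 1
        return InterFace.global_session_id


with ThreadPoolExecutor(max_workers=32) as pool:
    ids = list(pool.map(lambda _: init(), range(1000)))

assert len(set(ids)) == 1000  # every simulated page load got a distinct id
```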
33 changes: 21 additions & 12 deletions lmdeploy/serve/gradio/triton_server_backend.py
@@ -1,19 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
-import threading
 from functools import partial
+from threading import Lock
 from typing import Sequence
 
 import gradio as gr
 
 from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
-from lmdeploy.serve.openai.api_server import ip2id
 from lmdeploy.serve.turbomind.chatbot import Chatbot
 
 
+class InterFace:
+    global_session_id: int = 0
+    lock = Lock()
+
+
 def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
-                cancel_btn: gr.Button, reset_btn: gr.Button,
-                request: gr.Request):
+                cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int):
     """Chat with AI assistant.
 
     Args:
@@ -22,12 +25,9 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
         llama_chatbot (Chatbot): the instance of a chatbot
         cancel_btn (bool): enable the cancel button or not
         reset_btn (bool): enable the reset button or not
-        request (gr.Request): the request from a user
+        session_id (int): the session id
     """
     instruction = state_chatbot[-1][0]
-    session_id = threading.current_thread().ident
-    if request is not None:
-        session_id = ip2id(request.kwargs['client']['host'])
 
     bot_response = llama_chatbot.stream_infer(
         session_id, instruction, f'{session_id}-{len(state_chatbot)}')
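Note the third argument to stream_infer: the request id is composed from the session id and the current history length, so every turn within a session gets a distinct id:

```python
# Illustration of the per-turn request id built above; pure string
# formatting, no lmdeploy imports needed.
session_id = 7
state_chatbot = [('hi', 'hello'), ('how are you?', None)]
request_id = f'{session_id}-{len(state_chatbot)}'
assert request_id == '7-2'  # unique per (session, turn) pair
```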
@@ -92,6 +92,7 @@ def run_triton_server(triton_server_addr: str,
         llama_chatbot = gr.State(
             Chatbot(triton_server_addr, log_level=log_level, display=True))
         state_chatbot = gr.State([])
+        state_session_id = gr.State(0)
         model_name = llama_chatbot.value.model_name
         reset_all = partial(reset_all_func,
                             model_name=model_name,
@@ -110,10 +111,10 @@
 
         send_event = instruction_txtbox.submit(
             add_instruction, [instruction_txtbox, state_chatbot],
-            [instruction_txtbox, state_chatbot]).then(
-                chat_stream,
-                [state_chatbot, llama_chatbot, cancel_btn, reset_btn],
-                [state_chatbot, chatbot, cancel_btn, reset_btn])
+            [instruction_txtbox, state_chatbot]).then(chat_stream, [
+                state_chatbot, llama_chatbot, cancel_btn, reset_btn,
+                state_session_id
+            ], [state_chatbot, chatbot, cancel_btn, reset_btn])
 
         cancel_btn.click(cancel_func,
                          [state_chatbot, llama_chatbot, cancel_btn, reset_btn],
@@ -125,6 +126,14 @@
             [llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
             cancels=[send_event])
 
+        def init():
+            with InterFace.lock:
+                InterFace.global_session_id += 1
+                new_session_id = InterFace.global_session_id
+            return new_session_id
+
+        demo.load(init, inputs=None, outputs=[state_session_id])
+
     print(f'server is gonna mount on: http://{server_name}:{server_port}')
     demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
         max_threads=10,
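One practical consequence of the new scheme: session identity no longer depends on which queue worker picks up a request, so the queue settings above remain a pure throughput knob. In gradio 3.x the call shape is as follows (values mirror the diff; server_name and server_port come from run_triton_server's arguments and are shown here with assumed example values):

```python
# gradio 3.x: concurrency_count caps simultaneous handler executions,
# max_size caps waiting requests. `demo` is the gr.Blocks built above.
demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
    max_threads=10,
    server_name='0.0.0.0',  # assumed example value
    server_port=6006)       # assumed example value
```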