From 11d10930bd0aa808332d1019354968baf871dae6 Mon Sep 17 00:00:00 2001 From: aisensiy Date: Mon, 6 Nov 2023 16:02:40 +0800 Subject: [PATCH] Manage session id using random int for gradio local mode (#553) * Use session id from gradio state * use a new session id after reset * rename session id like a state * update comments * reformat files * init session id on block loaded * use auto increased session id * remove session id textbox * apply to api_server and tritonserver * update docstring * add lock for safety --------- Co-authored-by: AllentDan --- lmdeploy/serve/gradio/api_server_backend.py | 73 +++++++-------- .../serve/gradio/triton_server_backend.py | 33 ++++--- lmdeploy/serve/gradio/turbomind_coupled.py | 89 +++++++++---------- 3 files changed, 95 insertions(+), 100 deletions(-) diff --git a/lmdeploy/serve/gradio/api_server_backend.py b/lmdeploy/serve/gradio/api_server_backend.py index ce64508795..8dd92fa0fd 100644 --- a/lmdeploy/serve/gradio/api_server_backend.py +++ b/lmdeploy/serve/gradio/api_server_backend.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -import threading import time +from threading import Lock from typing import Sequence import gradio as gr @@ -8,35 +8,27 @@ from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn from lmdeploy.serve.openai.api_client import (get_model_list, get_streaming_response) -from lmdeploy.serve.openai.api_server import ip2id class InterFace: api_server_url: str = None + global_session_id: int = 0 + lock = Lock() -def chat_stream_restful( - instruction: str, - state_chatbot: Sequence, - cancel_btn: gr.Button, - reset_btn: gr.Button, - request: gr.Request, -): +def chat_stream_restful(instruction: str, state_chatbot: Sequence, + cancel_btn: gr.Button, reset_btn: gr.Button, + session_id: int): """Chat with AI assistant. Args: instruction (str): user's prompt state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user + session_id (int): the session id """ - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - bot_summarized_response = '' state_chatbot = state_chatbot + [(instruction, None)] - yield (state_chatbot, state_chatbot, disable_btn, enable_btn, - f'{bot_summarized_response}'.strip()) + yield (state_chatbot, state_chatbot, disable_btn, enable_btn) for response, tokens, finish_reason in get_streaming_response( instruction, @@ -56,27 +48,21 @@ def chat_stream_restful( state_chatbot[-1] = (state_chatbot[-1][0], state_chatbot[-1][1] + response ) # piece by piece - yield (state_chatbot, state_chatbot, enable_btn, disable_btn, - f'{bot_summarized_response}'.strip()) + yield (state_chatbot, state_chatbot, enable_btn, disable_btn) - yield (state_chatbot, state_chatbot, disable_btn, enable_btn, - f'{bot_summarized_response}'.strip()) + yield (state_chatbot, state_chatbot, disable_btn, enable_btn) def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, - request: gr.Request): + session_id: int): """reset the session. Args: instruction_txtbox (str): user's prompt state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user + session_id (int): the session id """ state_chatbot = [] - - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) # end the session for response, tokens, finish_reason in get_streaming_response( '', @@ -94,18 +80,15 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button, - reset_btn: gr.Button, request: gr.Request): + reset_btn: gr.Button, session_id: int): """stop the session. Args: instruction_txtbox (str): user's prompt state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user + session_id (int): the session id """ yield (state_chatbot, disable_btn, disable_btn) - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) # end the session for out in get_streaming_response( '', @@ -152,6 +135,7 @@ def run_api_server(api_server_url: str, with gr.Blocks(css=CSS, theme=THEME) as demo: state_chatbot = gr.State([]) + state_session_id = gr.State(0) with gr.Column(elem_id='container'): gr.Markdown('## LMDeploy Playground') @@ -164,25 +148,34 @@ def run_api_server(api_server_url: str, cancel_btn = gr.Button(value='Cancel', interactive=False) reset_btn = gr.Button(value='Reset') - send_event = instruction_txtbox.submit( - chat_stream_restful, - [instruction_txtbox, state_chatbot, cancel_btn, reset_btn], - [state_chatbot, chatbot, cancel_btn, reset_btn]) + send_event = instruction_txtbox.submit(chat_stream_restful, [ + instruction_txtbox, state_chatbot, cancel_btn, reset_btn, + state_session_id + ], [state_chatbot, chatbot, cancel_btn, reset_btn]) instruction_txtbox.submit( lambda: gr.Textbox.update(value=''), [], [instruction_txtbox], ) - cancel_btn.click(cancel_restful_func, - [state_chatbot, cancel_btn, reset_btn], - [state_chatbot, cancel_btn, reset_btn], - cancels=[send_event]) + cancel_btn.click( + cancel_restful_func, + [state_chatbot, cancel_btn, reset_btn, state_session_id], + [state_chatbot, cancel_btn, reset_btn], + cancels=[send_event]) reset_btn.click(reset_restful_func, - [instruction_txtbox, state_chatbot], + [instruction_txtbox, state_chatbot, state_session_id], [state_chatbot, chatbot, instruction_txtbox], cancels=[send_event]) + def init(): + with InterFace.lock: + InterFace.global_session_id += 1 + new_session_id = InterFace.global_session_id + return new_session_id + + demo.load(init, inputs=None, outputs=[state_session_id]) + print(f'server is gonna mount on: http://{server_name}:{server_port}') demo.queue(concurrency_count=batch_size, max_size=100, api_open=True).launch( diff --git a/lmdeploy/serve/gradio/triton_server_backend.py b/lmdeploy/serve/gradio/triton_server_backend.py index 5936f4ba5f..9148903cc5 100644 --- a/lmdeploy/serve/gradio/triton_server_backend.py +++ b/lmdeploy/serve/gradio/triton_server_backend.py @@ -1,19 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -import threading from functools import partial +from threading import Lock from typing import Sequence import gradio as gr from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn -from lmdeploy.serve.openai.api_server import ip2id from lmdeploy.serve.turbomind.chatbot import Chatbot +class InterFace: + global_session_id: int = 0 + lock = Lock() + + def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, - cancel_btn: gr.Button, reset_btn: gr.Button, - request: gr.Request): + cancel_btn: gr.Button, reset_btn: gr.Button, session_id: int): """Chat with AI assistant. Args: @@ -22,12 +25,9 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, llama_chatbot (Chatbot): the instance of a chatbot cancel_btn (bool): enable the cancel button or not reset_btn (bool): enable the reset button or not - request (gr.Request): the request from a user + session_id (int): the session id """ instruction = state_chatbot[-1][0] - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) bot_response = llama_chatbot.stream_infer( session_id, instruction, f'{session_id}-{len(state_chatbot)}') @@ -92,6 +92,7 @@ def run_triton_server(triton_server_addr: str, llama_chatbot = gr.State( Chatbot(triton_server_addr, log_level=log_level, display=True)) state_chatbot = gr.State([]) + state_session_id = gr.State(0) model_name = llama_chatbot.value.model_name reset_all = partial(reset_all_func, model_name=model_name, @@ -110,10 +111,10 @@ def run_triton_server(triton_server_addr: str, send_event = instruction_txtbox.submit( add_instruction, [instruction_txtbox, state_chatbot], - [instruction_txtbox, state_chatbot]).then( - chat_stream, - [state_chatbot, llama_chatbot, cancel_btn, reset_btn], - [state_chatbot, chatbot, cancel_btn, reset_btn]) + [instruction_txtbox, state_chatbot]).then(chat_stream, [ + state_chatbot, llama_chatbot, cancel_btn, reset_btn, + state_session_id + ], [state_chatbot, chatbot, cancel_btn, reset_btn]) cancel_btn.click(cancel_func, [state_chatbot, llama_chatbot, cancel_btn, reset_btn], @@ -125,6 +126,14 @@ def run_triton_server(triton_server_addr: str, [llama_chatbot, state_chatbot, chatbot, instruction_txtbox], cancels=[send_event]) + def init(): + with InterFace.lock: + InterFace.global_session_id += 1 + new_session_id = InterFace.global_session_id + return new_session_id + + demo.load(init, inputs=None, outputs=[state_session_id]) + print(f'server is gonna mount on: http://{server_name}:{server_port}') demo.queue(concurrency_count=4, max_size=100, api_open=True).launch( max_threads=10, diff --git a/lmdeploy/serve/gradio/turbomind_coupled.py b/lmdeploy/serve/gradio/turbomind_coupled.py index d5cd59867f..e344abcbda 100644 --- a/lmdeploy/serve/gradio/turbomind_coupled.py +++ b/lmdeploy/serve/gradio/turbomind_coupled.py @@ -1,16 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -import threading +from threading import Lock from typing import Sequence import gradio as gr from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn -from lmdeploy.serve.openai.api_server import ip2id class InterFace: async_engine: AsyncEngine = None + global_session_id: int = 0 + lock = Lock() async def chat_stream_local( @@ -18,25 +19,20 @@ async def chat_stream_local( state_chatbot: Sequence, cancel_btn: gr.Button, reset_btn: gr.Button, - request: gr.Request, + session_id: int, ): """Chat with AI assistant. Args: instruction (str): user's prompt state_chatbot (Sequence): the chatting history - cancel_btn (bool): enable the cancel button or not - reset_btn (bool): enable the reset button or not - request (gr.Request): the request from a user + cancel_btn (gr.Button): the cancel button + reset_btn (gr.Button): the reset button + session_id (int): the session id """ - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - bot_summarized_response = '' state_chatbot = state_chatbot + [(instruction, None)] - yield (state_chatbot, state_chatbot, disable_btn, enable_btn, - f'{bot_summarized_response}'.strip()) + yield (state_chatbot, state_chatbot, disable_btn, enable_btn) async for outputs in InterFace.async_engine.generate( instruction, @@ -57,27 +53,21 @@ async def chat_stream_local( state_chatbot[-1] = (state_chatbot[-1][0], state_chatbot[-1][1] + response ) # piece by piece - yield (state_chatbot, state_chatbot, enable_btn, disable_btn, - f'{bot_summarized_response}'.strip()) + yield (state_chatbot, state_chatbot, enable_btn, disable_btn) - yield (state_chatbot, state_chatbot, disable_btn, enable_btn, - f'{bot_summarized_response}'.strip()) + yield (state_chatbot, state_chatbot, disable_btn, enable_btn) async def reset_local_func(instruction_txtbox: gr.Textbox, - state_chatbot: gr.State, request: gr.Request): + state_chatbot: Sequence, session_id: int): """reset the session. Args: instruction_txtbox (str): user's prompt state_chatbot (Sequence): the chatting history - request (gr.Request): the request from a user + session_id (int): the session id """ state_chatbot = [] - - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) # end the session async for out in InterFace.async_engine.generate('', session_id, @@ -86,29 +76,21 @@ async def reset_local_func(instruction_txtbox: gr.Textbox, sequence_start=False, sequence_end=True): pass - - return ( - state_chatbot, - state_chatbot, - gr.Textbox.update(value=''), - ) + return (state_chatbot, state_chatbot, gr.Textbox.update(value='')) -async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button, - reset_btn: gr.Button, request: gr.Request): +async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button, + reset_btn: gr.Button, session_id: int): """stop the session. Args: + instruction_txtbox (str): user's prompt state_chatbot (Sequence): the chatting history - cancel_btn (bool): enable the cancel button or not - reset_btn (bool): enable the reset button or not - request (gr.Request): the request from a user + cancel_btn (gr.Button): the cancel button + reset_btn (gr.Button): the reset button + session_id (int): the session id """ - yield (state_chatbot, disable_btn, disable_btn) - session_id = threading.current_thread().ident - if request is not None: - session_id = ip2id(request.kwargs['client']['host']) - # end the session + yield (state_chatbot, disable_btn, enable_btn) async for out in InterFace.async_engine.generate('', session_id, request_output_len=0, @@ -152,6 +134,7 @@ def run_local(model_path: str, with gr.Blocks(css=CSS, theme=THEME) as demo: state_chatbot = gr.State([]) + state_session_id = gr.State(0) with gr.Column(elem_id='container'): gr.Markdown('## LMDeploy Playground') @@ -166,24 +149,34 @@ def run_local(model_path: str, cancel_btn = gr.Button(value='Cancel', interactive=False) reset_btn = gr.Button(value='Reset') - send_event = instruction_txtbox.submit( - chat_stream_local, - [instruction_txtbox, state_chatbot, cancel_btn, reset_btn], - [state_chatbot, chatbot, cancel_btn, reset_btn]) + send_event = instruction_txtbox.submit(chat_stream_local, [ + instruction_txtbox, state_chatbot, cancel_btn, reset_btn, + state_session_id + ], [state_chatbot, chatbot, cancel_btn, reset_btn]) instruction_txtbox.submit( lambda: gr.Textbox.update(value=''), [], [instruction_txtbox], ) - cancel_btn.click(cancel_local_func, - [state_chatbot, cancel_btn, reset_btn], - [state_chatbot, cancel_btn, reset_btn], - cancels=[send_event]) - - reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot], + cancel_btn.click( + cancel_local_func, + [state_chatbot, cancel_btn, reset_btn, state_session_id], + [state_chatbot, cancel_btn, reset_btn], + cancels=[send_event]) + + reset_btn.click(reset_local_func, + [instruction_txtbox, state_chatbot, state_session_id], [state_chatbot, chatbot, instruction_txtbox], cancels=[send_event]) + def init(): + with InterFace.lock: + InterFace.global_session_id += 1 + new_session_id = InterFace.global_session_id + return new_session_id + + demo.load(init, inputs=None, outputs=[state_session_id]) + print(f'server is gonna mount on: http://{server_name}:{server_port}') demo.queue(concurrency_count=batch_size, max_size=100, api_open=True).launch(