Skip to content

Commit

Permalink
feat(model):multi llm switch (#576)
Browse files Browse the repository at this point in the history
1.multi llm switch
2.web page refactor
  • Loading branch information
fangyinc authored Sep 13, 2023
2 parents 4e6bc44 + f39efb0 commit 4854cba
Show file tree
Hide file tree
Showing 122 changed files with 586 additions and 569 deletions.
12 changes: 6 additions & 6 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
autodoc_pydantic
myst_parser
nbsphinx==0.8.9
sphinx==4.5.0
nbsphinx
sphinx
recommonmark
sphinx_intl
sphinx-autobuild==2021.3.14
sphinx-autobuild
sphinx_book_theme
sphinx_rtd_theme==1.0.0
sphinx-typlog-theme==0.8.0
sphinx_rtd_theme
sphinx-typlog-theme
sphinx-panels
toml
myst_nb
sphinx_copybutton
pydata-sphinx-theme==0.13.1
pydata-sphinx-theme
furo
61 changes: 52 additions & 9 deletions pilot/openapi/api_v1/api_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ async def dialogue_list(user_id: str = None):
conv_uid = item.get("conv_uid")
summary = item.get("summary")
chat_mode = item.get("chat_mode")
model_name = item.get("model_name", CFG.LLM_MODEL)

messages = json.loads(item.get("messages"))
last_round = max(messages, key=lambda x: x["chat_order"])
Expand All @@ -160,6 +161,7 @@ async def dialogue_list(user_id: str = None):
conv_uid=conv_uid,
user_input=summary,
chat_mode=chat_mode,
model_name=model_name,
select_param=select_param,
)
dialogues.append(conv_vo)
Expand Down Expand Up @@ -215,8 +217,10 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):


@router.post("/v1/chat/mode/params/file/load")
async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
print(f"params_load: {conv_uid},{chat_mode}")
async def params_load(
conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)
):
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
if doc_file:
## file save
Expand All @@ -235,7 +239,10 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
)
## chat prepare
dialogue = ConversationVo(
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
conv_uid=conv_uid,
chat_mode=chat_mode,
select_param=doc_file.filename,
model_name=model_name,
)
chat: BaseChat = get_chat_instance(dialogue)
resp = await chat.prepare()
Expand All @@ -259,8 +266,11 @@ def get_hist_messages(conv_uid: str):
history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages:
for once in history_messages:
print(f"once:{once}")
model_name = once.get("model_name", CFG.LLM_MODEL)
once_message_vos = [
message2Vo(element, once["chat_order"]) for element in once["messages"]
message2Vo(element, once["chat_order"], model_name)
for element in once["messages"]
]
message_vos.extend(once_message_vos)
return message_vos
Expand All @@ -287,15 +297,19 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:

chat_param = {
"chat_session_id": dialogue.conv_uid,
"user_input": dialogue.user_input,
"current_user_input": dialogue.user_input,
"select_param": dialogue.select_param,
"model_name": dialogue.model_name,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
chat: BaseChat = CHAT_FACTORY.get_implementation(
dialogue.chat_mode, **{"chat_param": chat_param}
)
return chat


@router.post("/v1/chat/prepare")
async def chat_prepare(dialogue: ConversationVo = Body()):
# dialogue.model_name = CFG.LLM_MODEL
logger.info(f"chat_prepare:{dialogue}")
## check conv_uid
chat: BaseChat = get_chat_instance(dialogue)
Expand All @@ -307,7 +321,9 @@ async def chat_prepare(dialogue: ConversationVo = Body()):

@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
print(
f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}"
)
chat: BaseChat = get_chat_instance(dialogue)
# background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore)
Expand All @@ -332,6 +348,30 @@ async def chat_completions(dialogue: ConversationVo = Body()):
)


@router.get("/v1/model/types")
async def model_types(request: Request):
print(f"/controller/model/types")
try:
import httpx

async with httpx.AsyncClient() as client:
base_url = request.base_url
response = await client.get(
f"{base_url}api/controller/models?healthy_only=true",
)
types = set()
if response.status_code == 200:
models = json.loads(response.text)
for model in models:
worker_type = model["model_name"].split("@")[1]
if worker_type == "llm":
types.add(model["model_name"].split("@")[0])
return Result.succ(list(types))

except Exception as e:
return Result.faild(code="E000X", msg=f"controller model types error {e}")


async def no_stream_generator(chat):
msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n")
Expand All @@ -356,7 +396,10 @@ async def stream_generator(chat):
chat.memory.append(chat.current_message)


def message2Vo(message: dict, order) -> MessageVo:
def message2Vo(message: dict, order, model_name) -> MessageVo:
return MessageVo(
role=message["type"], context=message["data"]["content"], order=order
role=message["type"],
context=message["data"]["content"],
order=order,
model_name=model_name,
)
9 changes: 9 additions & 0 deletions pilot/openapi/api_view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class ConversationVo(BaseModel):
chat scene select param
"""
select_param: str = None
"""
llm model name
"""
model_name: str = None


class MessageVo(BaseModel):
Expand All @@ -74,3 +78,8 @@ class MessageVo(BaseModel):
time the current message was sent
"""
time_stamp: Any = None

"""
model_name
"""
model_name: str
31 changes: 17 additions & 14 deletions pilot/scene/base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import traceback
import warnings
from abc import ABC, abstractmethod
from typing import Any, List
from typing import Any, List, Dict

from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
Expand Down Expand Up @@ -32,13 +32,13 @@ class Config:

arbitrary_types_allowed = True

def __init__(
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL
def __init__(self, chat_param: Dict):
self.chat_session_id = chat_param["chat_session_id"]
self.chat_mode = chat_param["chat_mode"]
self.current_user_input: str = chat_param["current_user_input"]
self.llm_model = (
chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
)
self.llm_echo = False

### load prompt template
Expand All @@ -55,14 +55,17 @@ def __init__(
)

### can configurable storage methods
self.memory = DuckdbHistoryMemory(chat_session_id)
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])

self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
if select_param:
if len(chat_mode.param_types()) > 0:
self.current_message.param_type = chat_mode.param_types()[0]
self.current_message.param_value = select_param
self.current_message: OnceConversation = OnceConversation(
self.chat_mode.value()
)
self.current_message.model_name = self.llm_model
if chat_param["select_param"]:
if len(self.chat_mode.param_types()) > 0:
self.current_message.param_type = self.chat_mode.param_types()[0]
self.current_message.param_value = chat_param["select_param"]
self.current_tokens_used: int = 0

class Config:
Expand Down
24 changes: 7 additions & 17 deletions pilot/scene/chat_dashboard/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import uuid
from typing import List
from typing import List, Dict

from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
Expand All @@ -21,30 +21,20 @@ class ChatDashboard(BaseChat):
report_name: str
"""Number of results to return from the query"""

def __init__(
self,
chat_session_id,
user_input,
select_param: str = "",
report_name: str = "report",
):
def __init__(self, chat_param: Dict):
""" """
self.db_name = select_param
super().__init__(
chat_mode=ChatScene.ChatDashboard,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
)
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatDashboard
super().__init__(chat_param=chat_param)
if not self.db_name:
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
self.db_name = self.db_name
self.report_name = report_name
self.report_name = chat_param["report_name"] or "report"

self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)

self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(report_name)
self.dashboard_template = self.__load_dashboard_template(self.report_name)

def __load_dashboard_template(self, template_name):
current_dir = os.getcwd()
Expand Down
20 changes: 9 additions & 11 deletions pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,22 @@ class ChatExcel(BaseChat):
chat_scene: str = ChatScene.ChatExcel.value()
chat_retention_rounds = 1

def __init__(self, chat_session_id, user_input, select_param: str = ""):
def __init__(self, chat_param: Dict):
chat_mode = ChatScene.ChatExcel

self.select_param = select_param
if has_path(select_param):
self.excel_reader = ExcelReader(select_param)
self.select_param = chat_param["select_param"]
self.model_name = chat_param["model_name"]
chat_param["chat_mode"] = ChatScene.ChatExcel
if has_path(self.select_param):
self.excel_reader = ExcelReader(self.select_param)
else:
self.excel_reader = ExcelReader(
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
)
)

super().__init__(
chat_mode=chat_mode,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=select_param,
)
super().__init__(chat_param=chat_param)

def _generate_command_string(self, command: Dict[str, Any]) -> str:
"""
Expand Down Expand Up @@ -85,6 +82,7 @@ async def prepare(self):
"parent_mode": self.chat_mode,
"select_param": self.excel_reader.excel_file_name,
"excel_reader": self.excel_reader,
"model_name": self.model_name,
}
learn_chat = ExcelLearning(**chat_param)
result = await learn_chat.nostream_call()
Expand Down
15 changes: 9 additions & 6 deletions pilot/scene/chat_data/chat_excel/excel_learning/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,20 @@ def __init__(
parent_mode: Any = None,
select_param: str = None,
excel_reader: Any = None,
model_name: str = None,
):
chat_mode = ChatScene.ExcelLearning
""" """
self.excel_file_path = select_param
self.excel_reader = excel_reader
super().__init__(
chat_mode=chat_mode,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=select_param,
)
chat_param = {
"chat_mode": chat_mode,
"chat_session_id": chat_session_id,
"current_user_input": user_input,
"select_param": select_param,
"model_name": model_name,
}
super().__init__(chat_param=chat_param)
if parent_mode:
self.current_message.chat_mode = parent_mode.value()

Expand Down
12 changes: 6 additions & 6 deletions pilot/scene/chat_db/auto_execute/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
Expand All @@ -12,15 +14,13 @@ class ChatWithDbAutoExecute(BaseChat):

"""Number of results to return from the query"""

def __init__(self, chat_session_id, user_input, select_param: str = ""):
def __init__(self, chat_param: Dict):
chat_mode = ChatScene.ChatWithDbExecute
self.db_name = select_param
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = chat_mode
""" """
super().__init__(
chat_mode=chat_mode,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
chat_param=chat_param,
)
if not self.db_name:
raise ValueError(
Expand Down
14 changes: 6 additions & 8 deletions pilot/scene/chat_db/professional_qa/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
Expand All @@ -12,15 +14,11 @@ class ChatWithDbQA(BaseChat):

"""Number of results to return from the query"""

def __init__(self, chat_session_id, user_input, select_param: str = ""):
def __init__(self, chat_param: Dict):
""" """
self.db_name = select_param
super().__init__(
chat_mode=ChatScene.ChatWithDbQA,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
)
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatWithDbQA
super().__init__(chat_param=chat_param)

if self.db_name:
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
Expand Down
Loading

0 comments on commit 4854cba

Please sign in to comment.