-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
103 lines (86 loc) · 3.66 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import streamlit as st
from langchain.chat_models.gigachat import GigaChat
from typing import Optional, Type
import dotenv
import os
from langchain.pydantic_v1 import BaseModel, Field
from services.retrievers_ensemble_retriever import get_ensemble_retriver
from langchain.tools import BaseTool
from langchain.agents import (
AgentExecutor,
create_gigachat_functions_agent,
)
from langchain.agents.gigachat_functions_agent.base import (
format_to_gigachat_function_messages,
)
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
try:
dotenv.load_dotenv()
except:
pass
GIGACHAT_API_CREDENTIALS = os.environ.get("GIGACHAT_API_CREDENTIALS")
class SearchInput(BaseModel):
question: str = Field(
description="вопрос пользователя"
)
class SearchTool(BaseTool):
name = "search"
description = """Выполняет поиск вопроса пользователя по медицине в базе данных
"""
args_schema: Type[BaseModel] = SearchInput
def _run(
self,
question: str,
run_manager=None
) -> str:
result = get_ensemble_retriver().get_relevant_documents(question)
result_string = "Agent Tool RAG: Найденные статьи:\n\n"
for index, item in enumerate(result):
result_string += f"{index+1} \t" + item.page_content
result_string += "\n" + item.metadata['source'] + "\n\n"
return result_string
giga = GigaChat(credentials=GIGACHAT_API_CREDENTIALS,
scope='GIGACHAT_API_CORP',
verify_ssl_certs=False,
model = 'GigaChat-Pro-preview',
profanity_check=False,
timeout=600,
#streaming=True
)
tools = [SearchTool()]
agent = create_gigachat_functions_agent(giga, tools)
# AgentExecutor создает среду, в которой будет работать агент
agent_executor = AgentExecutor(
agent=agent, tools=tools, verbose=False, return_intermediate_steps=True
)
system = f"""Ты ИИ-ассистент и справочник по медицине
У тебя есть доступные функции:
Получить RAG context из векторной базы данных
""" # noqa
chat_history = [SystemMessage(content=system)]
with st.sidebar:
"Ответы от чат-бота носят справочный характер."
"Ответы от ИИ-ассистента не являются врачебной рекомендацией, лучше обратиться в мед. организацию."
st.title("💬 Мед чат-бот")
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "assistant", "content": "Вас приветствует мед чат-бот"}]
for msg in st.session_state.messages:
st.chat_message(msg["role"]).write(msg["content"])
if question := st.chat_input():
st.session_state.messages.append({"role": "user", "content": question})
st.chat_message("user").write(question)
result = agent_executor.invoke(
{
"chat_history": chat_history,
"input": question,
}
)
#result = giga(st.session_state["messages"])
msg = result["output"]
details = result["intermediate_steps"]
if len(details) > 0:
tool_question, tool_answer = details[0]
st.session_state.messages.append({"role": "assistant", "content": tool_answer})
st.chat_message("assistant", avatar='🤖').write(tool_answer)
st.session_state.messages.append({"role": "assistant", "content": msg})
st.chat_message("assistant").write(msg)