forked from chatchat-space/Langchain-Chatchat
-
Notifications
You must be signed in to change notification settings - Fork 2
/
knowledge_based_chatglm.py
124 lines (105 loc) · 4.22 KB
/
knowledge_based_chatglm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from langchain.chains import RetrievalQA
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import UnstructuredFileLoader
from chatglm_llm import ChatGLM
import sentence_transformers
import torch
import os
import readline
# Global Parameters
# Key into embedding_model_dict selecting the sentence-embedding model.
EMBEDDING_MODEL = "text2vec"
# Number of document chunks the retriever returns per question.
VECTOR_SEARCH_TOP_K = 6
# Key into llm_model_dict selecting the ChatGLM checkpoint.
LLM_MODEL = "chatglm-6b"
# Value assigned to chatglm.history_len by init_cfg.
LLM_HISTORY_LEN = 3

# Pick the best available torch device: CUDA, then Apple MPS, then CPU.
# torch.backends.mps only exists in newer torch releases (1.12+), so probe it
# with getattr; otherwise older CPU-only installs crash here at import time.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Show reply with source text from input document
REPLY_WITH_SOURCE = True

# Short embedding-model names -> model identifiers passed to HuggingFaceEmbeddings.
embedding_model_dict = {
    "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
    "ernie-base": "nghuyong/ernie-3.0-base-zh",
    "text2vec": "GanymedeNil/text2vec-large-chinese",
}

# Short LLM names -> model_name_or_path values passed to ChatGLM.load_model.
llm_model_dict = {
    "chatglm-6b-int4-qe": "THUDM/chatglm-6b-int4-qe",
    "chatglm-6b-int4": "THUDM/chatglm-6b-int4",
    "chatglm-6b": "THUDM/chatglm-6b",
}
def init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN, V_SEARCH_TOP_K=None):
    """Load the ChatGLM LLM and the embedding model into module globals.

    Sets the module-level ``chatglm``, ``embeddings`` and (optionally)
    ``VECTOR_SEARCH_TOP_K`` used by the other functions in this module.

    Args:
        LLM_MODEL: key into ``llm_model_dict`` choosing the ChatGLM checkpoint.
        EMBEDDING_MODEL: key into ``embedding_model_dict`` choosing the
            embedding model.
        LLM_HISTORY_LEN: value assigned to ``chatglm.history_len``.
        V_SEARCH_TOP_K: number of chunks retrieved per query. When omitted,
            the current module-level ``VECTOR_SEARCH_TOP_K`` is kept instead
            of being silently reset to a hard-coded 6.

    Raises:
        KeyError: if either model key is unknown. Both keys are resolved
            before any model is loaded, so a typo fails fast.
    """
    global chatglm, embeddings, VECTOR_SEARCH_TOP_K
    # Resolve both names up front: loading the LLM is slow, so don't do it
    # only to crash afterwards on a bad embedding key.
    llm_path = llm_model_dict[LLM_MODEL]
    embedding_name = embedding_model_dict[EMBEDDING_MODEL]

    if V_SEARCH_TOP_K is not None:
        VECTOR_SEARCH_TOP_K = V_SEARCH_TOP_K

    chatglm = ChatGLM()
    chatglm.load_model(model_name_or_path=llm_path)
    chatglm.history_len = LLM_HISTORY_LEN

    embeddings = HuggingFaceEmbeddings(model_name=embedding_name)
    # Replace the embedding client with a SentenceTransformer built for the
    # same model but placed explicitly on DEVICE.
    embeddings.client = sentence_transformers.SentenceTransformer(
        embeddings.model_name, device=DEVICE
    )
def _load_documents(path):
    """Load one file with UnstructuredFileLoader; return its documents, or None on failure."""
    name = os.path.split(path)[-1]
    try:
        loader = UnstructuredFileLoader(path, mode="elements")
        docs = loader.load()
    except Exception:
        # Best-effort: report and let the caller decide. Unlike a bare
        # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
        print(f"{name} 未能成功加载")
        return None
    print(f"{name} 已成功加载")
    return docs


def init_knowledge_vector_store(filepath: str):
    """Build a FAISS vector store from a file or from every entry of a directory.

    Uses the module-level ``embeddings`` (set by ``init_cfg``) to embed the
    loaded documents. Progress and errors are printed, not raised.

    Args:
        filepath: path to a single document, or to a directory whose direct
            entries are each loaded (entries that fail to load are skipped).

    Returns:
        The FAISS vector store, or None when the path does not exist, is
        neither a file nor a directory, a single file fails to load, or no
        documents were loaded at all. An empty document list would otherwise
        make ``FAISS.from_documents`` fail.
    """
    if not os.path.exists(filepath):
        print("路径不存在")
        return None

    if os.path.isfile(filepath):
        docs = _load_documents(filepath)
        if docs is None:
            return None
    elif os.path.isdir(filepath):
        docs = []
        # Sorted so the index is built in a deterministic order.
        for name in sorted(os.listdir(filepath)):
            loaded = _load_documents(os.path.join(filepath, name))
            if loaded is not None:
                docs += loaded
    else:
        # Exists but is neither a regular file nor a directory (e.g. a
        # socket); previously this fell through with `docs` unbound.
        print(f"{filepath} 不是文件或文件夹")
        return None

    if not docs:
        print(f"{filepath} 中没有成功加载的内容")
        return None

    return FAISS.from_documents(docs, embeddings)
def get_knowledge_based_answer(query, vector_store, chat_history=None):
    """Answer *query* with ChatGLM, grounded on chunks retrieved from *vector_store*.

    Requires ``init_cfg`` to have been called first, because it creates the
    module-level ``chatglm`` model used here.

    Args:
        query: the user's question.
        vector_store: store returned by ``init_knowledge_vector_store``; its
            retriever fetches ``VECTOR_SEARCH_TOP_K`` chunks per query.
        chat_history: prior conversation turns used to seed
            ``chatglm.history``. Defaults to a fresh empty list on each call,
            rather than one mutable default list shared across calls.

    Returns:
        ``(result, history)``. ``result`` is the RetrievalQA output dict;
        the answer text is under ``"result"``, and source documents are
        included because ``return_source_documents`` is enabled.
        ``history`` is ``chatglm.history`` after the call.
    """
    if chat_history is None:
        chat_history = []

    # The system prompt tells the model to answer only from {context}, in
    # Chinese, and to admit when the context is insufficient.
    system_template = """基于以下内容,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "不知道" 或 "没有足够的相关信息",不要试图编造答案。答案请使用中文。
----------------
{context}
----------------
"""
    prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ])

    chatglm.history = chat_history
    knowledge_chain = RetrievalQA.from_llm(
        llm=chatglm,
        retriever=vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K}),
        prompt=prompt,
    )
    knowledge_chain.return_source_documents = True
    result = knowledge_chain({"query": query})

    # NOTE(review): presumably ChatGLM records the full context-stuffed prompt
    # as the first element of the newest history entry; overwrite it with the
    # raw query so later turns carry the user's words, not the retrieved
    # context. Confirm against chatglm_llm.ChatGLM.
    chatglm.history[-1][0] = query
    return result, chatglm.history
if __name__ == "__main__":
    # Pass the configured top-k explicitly so VECTOR_SEARCH_TOP_K is honoured
    # regardless of init_cfg's default.
    init_cfg(LLM_MODEL, EMBEDDING_MODEL, LLM_HISTORY_LEN, VECTOR_SEARCH_TOP_K)
    try:
        # Keep asking for a path until one yields a usable vector store.
        vector_store = None
        while vector_store is None:
            filepath = input("Input your local knowledge file path 请输入本地知识文件路径:")
            vector_store = init_knowledge_vector_store(filepath)

        history = []
        while True:
            query = input("Input your question 请输入问题:")
            if not query.strip():
                # Nothing to ask; don't spend an LLM call on an empty prompt.
                continue
            resp, history = get_knowledge_based_answer(query=query,
                                                       vector_store=vector_store,
                                                       chat_history=history)
            if REPLY_WITH_SOURCE:
                print(resp)
            else:
                print(resp["result"])
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C ends the interactive session without a traceback.
        print()