[ColossalQA] refactor server and webui & add new feature (hpcaitech#5138)

* refactor server and webui & add new feature

* add requirements

* modify readme and ui
MichelleMa8 authored Nov 30, 2023
1 parent 2a2ec49 commit c7fd9a5
Showing 12 changed files with 380 additions and 257 deletions.
2 changes: 1 addition & 1 deletion applications/ColossalQA/README.md
@@ -121,7 +121,7 @@ Read comments under ./colossalqa/data_loader for more detail regarding supported
### Run The Script

We provide a simple Web UI demo of ColossalQA, enabling you to upload your files as a knowledge base and interact with them through a chat interface in your browser. More details can be found [here](examples/webui_demo/README.md).
![ColossalQA Demo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/ui.png)
![ColossalQA Demo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/new_ui.png)

We also provide scripts for a Chinese document-retrieval-based conversation system, an English one, a bilingual one, and an experimental AI agent with document-retrieval and SQL-query functionality. The bilingual system is a high-level wrapper around the other two. We write separate scripts for each language because retrieval QA requires different embedding models, LLMs, and prompts for each language setting. For now, we use LLaMa2 for English retrieval QA and ChatGLM2 for Chinese retrieval QA, as they perform better.

applications/ColossalQA/colossalqa/data_loader/document_loader.py
@@ -126,3 +126,11 @@ def load_data(self, path: str) -> None:
else:
# May be a directory; we strictly follow the glob path and will not load files in subdirectories
pass

def clear(self):
"""
Clear loaded data.
"""
self.data = {}
self.kwargs = {}
self.all_data = []
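
The new `clear` method lets a single loader instance be reused between knowledge-base uploads. A minimal usage sketch (the import path and file name are assumptions, not shown in this diff):

```python
# Sketch only: assumes DocumentLoader lives at this module path and that
# ./policy.txt exists; both are illustrative, not taken from the diff.
from colossalqa.data_loader.document_loader import DocumentLoader

loader = DocumentLoader([["./policy.txt", "policy_doc"]])
docs = loader.all_data          # parsed documents, ready for the retriever

loader.clear()                  # resets data, kwargs, and all_data
assert loader.all_data == []    # the loader can now ingest a new knowledge base
```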
20 changes: 13 additions & 7 deletions applications/ColossalQA/colossalqa/prompt/prompt.py
@@ -4,6 +4,9 @@

from langchain.prompts.prompt import PromptTemplate


# Below are Chinese retrieval qa prompts

_CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。
例1:
@@ -27,8 +30,6 @@
新的摘要:"""


# Chinese retrieval qa prompt

_ZH_RETRIEVAL_QA_PROMPT = """<指令>根据下列支持文档和对话历史,简洁和专业地来回答问题。如果无法从支持文档中得到答案,请说 “根据已知信息无法回答该问题”。回答中请不要涉及支持文档中没有提及的信息,答案请使用中文。 </指令>
{context}
@@ -70,7 +71,8 @@
句子: {input}
消除歧义的句子:"""

# English retrieval qa prompt

# Below are English retrieval qa prompts

_EN_RETRIEVAL_QA_PROMPT = """[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist content.
If the answer cannot be inferred based on the given context, please say "I cannot answer the question based on the information given.".<</SYS>>
@@ -105,20 +107,24 @@
disambiguated sentence:"""


# Prompt templates

# English retrieval prompt; the model generates the answer based on this prompt
PROMPT_RETRIEVAL_QA_EN = PromptTemplate(
template=_EN_RETRIEVAL_QA_PROMPT, input_variables=["question", "chat_history", "context"]
)

# English disambiguation prompt, which replaces any ambiguous references in the user's input with the specific names or entities mentioned in the chat history
PROMPT_DISAMBIGUATE_EN = PromptTemplate(template=_EN_DISAMBIGUATION_PROMPT, input_variables=["chat_history", "input"])

# Chinese summary prompt, which summarizes the chat history
SUMMARY_PROMPT_ZH = PromptTemplate(input_variables=["summary", "new_lines"], template=_CUSTOM_SUMMARIZER_TEMPLATE_ZH)

# Chinese disambiguation prompt, which replaces any ambiguous references in the user's input with the specific names or entities mentioned in the chat history
PROMPT_DISAMBIGUATE_ZH = PromptTemplate(template=_ZH_DISAMBIGUATION_PROMPT, input_variables=["chat_history", "input"])

# Chinese retrieval prompt; the model generates the answer based on this prompt
PROMPT_RETRIEVAL_QA_ZH = PromptTemplate(
template=_ZH_RETRIEVAL_QA_PROMPT, input_variables=["question", "chat_history", "context"]
)

# Chinese retrieval prompt for a use case to analyze fault causes
PROMPT_RETRIEVAL_CLASSIFICATION_USE_CASE_ZH = PromptTemplate(
template=_ZH_RETRIEVAL_CLASSIFICATION_USE_CASE, input_variables=["question", "context"]
)
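
Each exported name above is a plain LangChain `PromptTemplate`, so callers fill it with `.format(...)`. A short sketch (the history and question strings are placeholders):

```python
# Sketch: fill the English disambiguation prompt with a chat history
# and the user's latest input; the strings below are placeholders.
from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_EN

filled = PROMPT_DISAMBIGUATE_EN.format(
    chat_history="Human: Tell me about ColossalQA.\nAssistant: It is a retrieval-based QA application.",
    input="How do I install it?",
)
print(filled)  # full prompt text, ready to pass to the LLM
```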
19 changes: 18 additions & 1 deletion applications/ColossalQA/colossalqa/retriever.py
@@ -73,6 +73,7 @@ def add_documents(
data_by_source[doc.metadata["source"]].append(doc)
elif mode == "merge":
data_by_source["merged"] = docs

for source in data_by_source:
if source not in self.vector_stores:
hash_encoding = hashlib.sha3_224(source.encode()).hexdigest()
@@ -81,8 +82,10 @@ def add_documents(
os.remove(f"{self.sql_file_path}/{hash_encoding}.db")
# Create a new SQL database to store indexes; SQL files are stored in the same directory as the source file
sql_path = f"sqlite:///{self.sql_file_path}/{hash_encoding}.db"
self.vector_stores[source] = Chroma(embedding_function=embedding, collection_name=hash_encoding)
# Record the SQL database files, keyed by their source
self.sql_index_database[source] = f"{self.sql_file_path}/{hash_encoding}.db"

self.vector_stores[source] = Chroma(embedding_function=embedding, collection_name=hash_encoding)
self.record_managers[source] = SQLRecordManager(source, db_url=sql_path)
self.record_managers[source].create_schema()
index(
@@ -93,6 +96,20 @@ def add_documents(
source_id_key="source",
)

def clear_documents(self):
"""Clear all document vectors from database"""
for source in self.vector_stores:
index(
[],
self.record_managers[source],
self.vector_stores[source],
cleanup="full",
source_id_key="source"
)
self.vector_stores = {}
self.sql_index_database = {}
self.record_managers = {}

def __del__(self):
for source in self.sql_index_database:
if os.path.exists(self.sql_index_database[source]):
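
The new `clear_documents` method empties every vector store through a full-cleanup `index([], ...)` pass before dropping the per-source bookkeeping, so a populated retriever can be reset in place. A rough usage sketch, assuming `retriever` is a `CustomRetriever` that was filled earlier with `add_documents`:

```python
# Sketch: `retriever` was previously populated with
# retriever.add_documents(docs=splits, cleanup="incremental",
#                         mode="by_source", embedding=embed_model)

retriever.clear_documents()           # delete all vectors from every store
assert retriever.vector_stores == {}  # stores, SQL indexes, and record managers dropped
```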
101 changes: 47 additions & 54 deletions applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py
@@ -1,3 +1,4 @@
import os
from typing import Dict, Tuple

from colossalqa.chain.retrieval_qa.base import RetrievalQA
@@ -12,29 +13,11 @@
ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
)
from colossalqa.retriever import CustomRetriever
from colossalqa.text_splitter import ChineseTextSplitter
from langchain import LLMChain
from langchain.embeddings import HuggingFaceEmbeddings

logger = get_logger()

DEFAULT_RAG_CFG = {
"retri_top_k": 3,
"retri_kb_file_path": "./",
"verbose": True,
"mem_summary_prompt": SUMMARY_PROMPT_ZH,
"mem_human_prefix": "用户",
"mem_ai_prefix": "Assistant",
"mem_max_tokens": 2000,
"mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
"disambig_prompt": PROMPT_DISAMBIGUATE_ZH,
"disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
"embed_model_name_or_path": "moka-ai/m3e-base",
"embed_model_device": {"device": "cpu"},
"gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
"gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH,
}


class RAG_ChatBot:
def __init__(
@@ -44,13 +27,16 @@ def __init__(
) -> None:
self.llm = llm
self.rag_config = rag_config
self.set_embed_model(**self.rag_config)
self.set_text_splitter(**self.rag_config)
self.set_memory(**self.rag_config)
self.set_info_retriever(**self.rag_config)
self.set_rag_chain(**self.rag_config)
if self.rag_config.get("disambig_prompt", None):
self.set_disambig_retriv(**self.rag_config)
self.set_embed_model(**self.rag_config["embed"])
self.set_text_splitter(**self.rag_config["splitter"])
self.set_memory(**self.rag_config["chain"])
self.set_info_retriever(**self.rag_config["retrieval"])
self.set_rag_chain(**self.rag_config["chain"])
if self.rag_config["chain"].get("disambig_prompt", None):
self.set_disambig_retriv(**self.rag_config["chain"])

self.documents = []
self.docs_names = []

def set_embed_model(self, **kwargs):
self.embed_model = HuggingFaceEmbeddings(
@@ -61,7 +47,7 @@ def set_embed_model(self, **kwargs):

def set_text_splitter(self, **kwargs):
# Initialize text_splitter
self.text_splitter = ChineseTextSplitter()
self.text_splitter = kwargs["name"]()

def set_memory(self, **kwargs):
params = {"llm_kwargs": kwargs["mem_llm_kwargs"]} if kwargs.get("mem_llm_kwargs", None) else {}
@@ -91,10 +77,6 @@ def set_rag_chain(self, **kwargs):
**params,
)

def split_docs(self, documents):
doc_splits = self.text_splitter.split_documents(documents)
return doc_splits

def set_disambig_retriv(self, **kwargs):
params = {"llm_kwargs": kwargs["disambig_llm_kwargs"]} if kwargs.get("disambig_llm_kwargs", None) else {}
self.llm_chain_disambiguate = LLMChain(llm=self.llm, prompt=kwargs["disambig_prompt"], **params)
@@ -106,42 +88,50 @@ def disambiguity(input: str):
self.info_retriever.set_rephrase_handler(disambiguity)

def load_doc_from_console(self, json_parse_args: Dict = {}):
documents = []
print("Select files for constructing Chinese retriever")
print("Select files for constructing the retriever")
while True:
file = input("Enter a file path or press Enter directly without input to exit:").strip()
if file == "":
break
data_name = input("Enter a short description of the data:")
docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
documents.extend(docs)
self.documents = documents
self.split_docs_and_add_to_mem(**self.rag_config)
self.documents.extend(docs)
self.docs_names.append(data_name)
self.split_docs_and_add_to_mem(**self.rag_config["chain"])

def load_doc_from_files(self, files, data_name="default_kb", json_parse_args: Dict = {}):
documents = []
for file in files:
docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
documents.extend(docs)
self.documents = documents
self.split_docs_and_add_to_mem(**self.rag_config)
self.documents.extend(docs)
self.docs_names.append(os.path.basename(file))
self.split_docs_and_add_to_mem(**self.rag_config["chain"])

def split_docs_and_add_to_mem(self, **kwargs):
self.doc_splits = self.split_docs(self.documents)
doc_splits = self.split_docs(self.documents)
self.info_retriever.add_documents(
docs=self.doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
docs=doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
)
self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)

def split_docs(self, documents):
doc_splits = self.text_splitter.split_documents(documents)
return doc_splits

def clear_docs(self, **kwargs):
self.documents = []
self.docs_names = []
self.info_retriever.clear_documents()
self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)

def reset_config(self, rag_config):
self.rag_config = rag_config
self.set_embed_model(**self.rag_config)
self.set_text_splitter(**self.rag_config)
self.set_memory(**self.rag_config)
self.set_info_retriever(**self.rag_config)
self.set_rag_chain(**self.rag_config)
if self.rag_config.get("disambig_prompt", None):
self.set_disambig_retriv(**self.rag_config)
self.set_embed_model(**self.rag_config["embed"])
self.set_text_splitter(**self.rag_config["splitter"])
self.set_memory(**self.rag_config["chain"])
self.set_info_retriever(**self.rag_config["retrieval"])
self.set_rag_chain(**self.rag_config["chain"])
if self.rag_config["chain"].get("disambig_prompt", None):
self.set_disambig_retriv(**self.rag_config["chain"])

def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
if memory:
@@ -153,7 +143,7 @@ def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
)
return result.split("\n")[0], memory
return result, memory

def start_test_session(self):
"""
@@ -170,15 +160,18 @@ def start_test_session(self):

if __name__ == "__main__":
# Initialize a LangChain LLM (here we use ChatGPT as an example)
import config
from langchain.llms import OpenAI

llm = OpenAI(openai_api_key="YOUR_OPENAI_API_KEY")
# you need to: export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
llm = OpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"))

# The ChatGPT API does not accept these generation kwargs (temperature, do_sample, etc.), so unset them
DEFAULT_RAG_CFG["mem_llm_kwargs"] = None
DEFAULT_RAG_CFG["disambig_llm_kwargs"] = None
DEFAULT_RAG_CFG["gen_llm_kwargs"] = None
all_config = config.ALL_CONFIG
all_config["chain"]["mem_llm_kwargs"] = None
all_config["chain"]["disambig_llm_kwargs"] = None
all_config["chain"]["gen_llm_kwargs"] = None

rag = RAG_ChatBot(llm, DEFAULT_RAG_CFG)
rag = RAG_ChatBot(llm, all_config)
rag.load_doc_from_console()
rag.start_test_session()
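
The demo now reads its settings from `config.ALL_CONFIG` instead of the old flat `DEFAULT_RAG_CFG` dict. The `config` module itself is not shown in this diff; from the keys accessed above (`embed`, `splitter`, `chain`, `retrieval`) and the removed defaults, it plausibly looks like the sketch below. Every value is an assumption carried over from the old `DEFAULT_RAG_CFG`:

```python
# config.py -- hypothetical reconstruction; only the key layout is implied by the diff.
from colossalqa.prompt.prompt import (
    PROMPT_DISAMBIGUATE_ZH,
    PROMPT_RETRIEVAL_QA_ZH,
    SUMMARY_PROMPT_ZH,
)
from colossalqa.text_splitter import ChineseTextSplitter

ALL_CONFIG = {
    "embed": {
        "embed_model_name_or_path": "moka-ai/m3e-base",
        "embed_model_device": {"device": "cpu"},
    },
    # set_text_splitter instantiates kwargs["name"](), so "name" holds a class
    "splitter": {"name": ChineseTextSplitter},
    "retrieval": {"retri_top_k": 3, "retri_kb_file_path": "./", "verbose": True},
    "chain": {
        "mem_summary_prompt": SUMMARY_PROMPT_ZH,
        "mem_human_prefix": "用户",
        "mem_ai_prefix": "Assistant",
        "mem_max_tokens": 2000,
        "mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
        "disambig_prompt": PROMPT_DISAMBIGUATE_ZH,
        "disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
        "gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
        "gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH,
    },
}
```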
