diff --git a/src/executor/docqa/eval/docqa.py b/src/executor/docqa/eval/docqa.py
new file mode 100644
index 00000000..75946984
--- /dev/null
+++ b/src/executor/docqa/eval/docqa.py
@@ -0,0 +1,380 @@
+#!/bin/python3
+# -*- coding: UTF-8 -*-
+
+from typing import Generator, Iterable, List
+from collections import namedtuple
+from pathlib import Path
+from urllib.error import HTTPError
+from pprint import pformat
+from langchain.docstore.document import Document
+from kuwa.executor import Modelfile
+
+from .recursive_url_multimedia_loader import RecursiveUrlMultimediaLoader
+from .document_store import DocumentStore
+from .kuwa_llm_client import KuwaLlmClient
+from .eval import Eval
+
+import i18n
+import re
+import gc
+import os
+import logging
+import chevron
+import asyncio
+import copy
+import time
+import pathlib
+import json
+
+class DocQa:
+
+ def __init__(
+ self,
+        document_store:DocumentStore = DocumentStore(),
+ vector_db:str = None,
+ llm:KuwaLlmClient = KuwaLlmClient(),
+ lang:str="en",
+ with_ref:bool=False,
+ user_agent:str = None
+ ):
+ self.logger = logging.getLogger(__name__)
+ self.llm = llm
+ self.lang = lang
+ self.with_ref = with_ref
+ self.user_agent = user_agent
+        if vector_db is not None:
+ self.pre_build_db = True
+ self.document_store = DocumentStore.load(vector_db)
+ else:
+ self.pre_build_db = False
+ self.document_store:DocumentStore = document_store
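+    # Illustrative usage (a sketch, not shipped configuration; the URL, chat history,
+    # and modelfile below are assumed to be supplied by the executor):
+    #   docqa = DocQa(document_store=DocumentStore(), lang="en", with_ref=True)
+    #   async for chunk in docqa.process(urls=["https://example.com/doc"],
+    #                                    chat_history=history, modelfile=modelfile):
+    #       print(chunk)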
+
+ def generate_llm_input(self, task, question, related_docs, override_prompt:str=None):
+ """
+ This function generates a formatted input string suitable for the Kuwa LLM model based on the provided task, question, related documents, and an optional override prompt.
+
+ Args:
+ task (str): The type of task to be performed (e.g., "qa" for question answering, "summary" for summarization, and "translate").
+ question (str): The user's question or prompt.
+            related_docs (List[Document]): A list of Document objects containing relevant information for the task. These documents come from the retriever.
+ override_prompt (str, optional): An optional override prompt to use in place of the default prompt for the task. Defaults to None.
+
+ Returns:
+ str: The formatted LLM input string.
+ """
+ # Convert related documents to a dictionary format suitable for templating
+ docs = [dict(title=doc.metadata.get("title"), **dict(doc)) for doc in related_docs]
+ # Load the appropriate prompt template based on the task and language
+ template_path = f'lang/{self.lang}/prompt_template/llm_input_{task}.mustache'
+ llm_input_template = Path(template_path).read_text(encoding="utf8")
+ # Render the template with the provided data
+ llm_input = chevron.render(llm_input_template, {
+ 'docs': docs,
+ 'question': question,
+ 'override_prompt': override_prompt
+ })
+ return llm_input
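+    # For reference, a minimal illustrative template for
+    # lang/en/prompt_template/llm_input_qa.mustache (the templates actually shipped
+    # with the executor may differ):
+    #   {{override_prompt}}
+    #   {{#docs}}
+    #   Title: {{title}}
+    #   {{page_content}}
+    #   {{/docs}}
+    #   Question: {{question}}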
+
+    def replace_chat_history(self, chat_history:List[dict], task:str, question:str, related_docs:List[Document], override_prompt:str=None):
+ """
+ This function modifies the provided chat history to include the generated LLM input for processing by the LLM model.
+
+ Args:
+ chat_history (List[Dict]): A list of dictionaries representing the chat history, where each dictionary contains keys like "msg" (message content) and "isbot" (boolean indicating if the message is from a bot).
+ task (str): The type of task to be performed (e.g., "qa" for question answering, "summary" for summarization).
+ question (str): The user's question or prompt.
+            related_docs (List[Document]): A list of Document objects relevant to the user's query.
+ override_prompt (str, optional): An optional override prompt to use in place of the default prompt for the task. Defaults to None.
+
+ Returns:
+ List[Dict]: The modified chat history with the LLM input message inserted.
+ """
+ # Generate the LLM input using the provided information
+ llm_input = self.generate_llm_input(task, question, related_docs, override_prompt)
+ # Modify the chat history to include the LLM input
+ modified_chat_history = chat_history[:-1] + [{"isbot": False, "msg": llm_input}]
+ # If the first message in the chat history is empty
+ if modified_chat_history[0]["msg"] is None:
+ # If it's a multi-round conversation, set the first message to a summary prompt
+ if len(modified_chat_history) != 2: # Multi-round
+ modified_chat_history[0]["msg"] = i18n.t("docqa.summary_prompt")
+ # Otherwise, remove the first message
+ else: # Single-round
+ modified_chat_history = modified_chat_history[1:]
+ # Replace empty messages with a placeholder
+ modified_chat_history = [
+ {"msg": "[Empty message]", "isbot": r["isbot"]} if r["msg"] == '' else r
+ for r in modified_chat_history
+ ]
+ return modified_chat_history
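+    # Illustrative example (messages assumed): given
+    #   [{"isbot": False, "msg": None}, {"isbot": True, "msg": "Hi"}, {"isbot": False, "msg": "What is X?"}]
+    # the final user turn is replaced by the rendered LLM input, and because this is a
+    # multi-round conversation the leading empty message becomes the summary prompt:
+    #   [{"isbot": False, "msg": "<docqa.summary_prompt>"}, {"isbot": True, "msg": "Hi"},
+    #    {"isbot": False, "msg": "<rendered llm_input>"}]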
+
+ def is_english(self, paragraph:str, threshold=0.8):
+ total_count = len(paragraph)
+        english_character_count = len(paragraph.encode("ascii", "ignore"))
+        english_rate = 0 if total_count == 0 else english_character_count / total_count
+
+ return english_rate >= threshold
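+    # For example (illustrative): is_english("Hello world") -> True,
+    # is_english("你好世界") -> False, since non-ASCII characters are dropped before counting.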
+
+    def get_final_user_input(self, chat_history:List[dict]) -> str:
+ """
+ This function extracts the final user input (message) from the provided chat history.
+
+ Args:
+ chat_history (List[Dict]): A list of dictionaries representing the chat history, where each dictionary contains keys like "msg" (message content) and "isbot" (boolean indicating if the message is from a bot).
+
+ Returns:
+ str: The final user input message from the chat history, or None if no user messages are found.
+ """
+        final_user_record = next(filter(lambda x: not x['isbot'], reversed(chat_history)), None)
+        return None if final_user_record is None else final_user_record['msg']
+
+ async def fetch_documents(self, url:str):
+ """
+ This function asynchronously fetches documents from the provided URL using a RecursiveUrlMultimediaLoader.
+
+ Args:
+ url (str): The URL of the webpage or document to be fetched.
+
+ Returns:
+ List[Document]: A list of Document objects containing the fetched content, or an empty list if there were errors or no documents were found.
+ """
+ # Fetching documents
+ self.logger.info(f'Fetching URL "{url}"')
+ docs = []
+ # Create a RecursiveUrlMultimediaLoader to handle fetching documents
+ loader = RecursiveUrlMultimediaLoader(
+ url=url,
+ max_depth=1,
+ prevent_outside=False,
+ use_async = True,
+ cache_proxy_url = os.environ.get('HTTP_CACHE_PROXY', None),
+ forge_user_agent=self.user_agent
+ )
+        try:
+            # Asynchronously load the documents using the loader
+            docs = await loader.async_load()
+            # Log the number of documents fetched
+            self.logger.info(f'Fetched {len(docs)} documents.')
+        except Exception as e:
+            self.logger.warning(f'Failed to fetch URL "{url}": {e}')
+            docs = []
+        return docs
+
+    async def construct_document_store(self, docs: List[Document]):
+ document_store = self.document_store
+        # Embed the documents via the document store to build the vector DB
+ await document_store.from_documents(docs)
+ return document_store
+
+ def filter_detail(self, msg):
+ """
+        This function removes HTML <details>...</details> blocks from the provided message string.
+
+ Args:
+ msg (str): The message string potentially containing HTML details tags.
+
+ Returns:
+ str: The message string with HTML details tags removed. If the input message is None, it returns None.
+ """
+        if msg is None: return None
+        pattern = r"<details>.*</details>"
+        return re.sub(pattern=pattern, repl='', string=msg, flags=re.DOTALL)
+
+    def format_references(self, docs:List[Document]):
+ """
+ This function formats a list of Document objects into a user-readable reference list with snippets.
+
+ Args:
+ docs (List[Document]): A list of Document objects containing information about the reference sources.
+
+ Returns:
+ str: The formatted reference list as a string, including source links, titles, and snippets. If no documents have a valid source, it returns an empty string.
+ """
+ Reference = namedtuple("Reference", "source, title, content")
+ refs = [
+ Reference(
+ source=doc.metadata.get("source"),
+ title=doc.metadata.get("title", doc.metadata.get("filename")),
+ content=doc.page_content,
+ ) for doc in docs
+ ]
+ refs = filter(lambda x: x.source, refs)
+        result = f"\n\n<details><summary>{i18n.t('docqa.reference')}</summary>\n\n"
+        for i, ref in enumerate(refs):
+            src = ref.source
+            title = ref.title if ref.title is not None else src
+            content = ref.content
+            link = src if src.startswith("http") else pathlib.Path(src).as_uri()
+            result += f'{i+1}. [{title}]({link})\n\n```plaintext\n{content}\n```\n\n'
+        result += "</details>"
+
+ return result
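+    # Illustrative output (source, title, and snippet below are assumptions):
+    #   <details><summary>References</summary>
+    #   1. [Example page](https://example.com/page)
+    #   ```plaintext
+    #   ...retrieved snippet...
+    #   ```
+    #   </details>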
+
+    async def process(self, urls: Iterable, chat_history:List[dict], modelfile:Modelfile, auth_token=None) -> Generator[str, None, None]:
+ """
+ This asynchronous function processes a conversation history and a list of URLs to answer questions or summarize documents in a conversational manner.
+
+ It interacts with the Kuwa LLM model to generate responses based on the provided information.
+
+ Args:
+ urls (Iterable[str]): An iterable of URLs pointing to documents relevant to the user's query.
+ chat_history (List[Dict]): A list of dictionaries representing the conversation history, where each dictionary contains keys like "msg" (message content) and "isbot" (boolean indicating if the message is from a bot).
+ modelfile (Modelfile): A Modelfile object containing configuration parameters for the conversation.
+ auth_token (str, optional): An optional authentication token for the LLM model. Defaults to None.
+
+ Yields:
+ str: Yields chunks of the conversation generated by the LLM model, including the answer or summary and potentially references to retrieved documents.
+
+ Returns:
+ None: The function itself does not return a value, it yields the conversation chunks asynchronously.
+ """
+ retriever_result = {"questions": []}
+ override_qa_prompt = modelfile.override_system_prompt
+ chat_history = [{"msg": i["content"], "isbot": i["role"]=="assistant"} for i in chat_history]
+ chat_history = [{"msg": self.filter_detail(i["msg"]), "isbot": i["isbot"]} for i in chat_history]
+
+ self.logger.debug(f"Chat history: {chat_history}")
+
+ # final_user_input = self.get_final_user_input(chat_history)
+ # if final_user_input is not None:
+ # final_user_input = "{before}{user}{after}".format(
+ # before = modelfile.before_prompt,
+ # user = final_user_input,
+ # after = modelfile.after_prompt
+ # )
+
+
+ document_store = self.document_store
+ docs = None
+ if not self.pre_build_db:
+            if len(urls) == 1:
+                docs = await self.fetch_documents(urls[0])
+ if len(docs) == 0:
+ await asyncio.sleep(2) # To prevent SSE error of web page.
+ yield i18n.t('docqa.error_fetching_document')
+ return
+ else:
+ docs = await asyncio.gather(*[self.fetch_documents(url) for url in urls])
+ docs = [doc for sub_docs in docs for doc in sub_docs]
+ document_store = await self.construct_document_store(docs)
+
+        text = docs[0] # The text comes from the document the user uploaded
+        data_sample = Eval.clean_llm_response(Eval.generate_questions_RAG(text))  # alternatively: Eval.generate_question(text)
+ data_sample['contexts'] = []
+ data_sample['answer'] = []
+ gen_questions = data_sample['questions']
+        self.logger.info(f"Generated questions: {gen_questions}")
+ for gen_question in gen_questions:
+ retriever_result = {"questions": []}
+ final_user_input = gen_question
+ task = ''
+ if final_user_input == "":
+ question = i18n.t("docqa.summary_question")
+ llm_question = None
+ task = 'summary'
+ await asyncio.sleep(2) # To prevent SSE error of web page.
+ yield i18n.t("docqa.summary_prefix")+'\n'
+ else:
+ question = final_user_input
+ llm_question = question
+ task = 'qa'
+
+ # Shortcut
+            if docs is not None:
+ related_docs = docs
+ modified_chat_history = self.replace_chat_history(chat_history, task, llm_question, related_docs, override_prompt=override_qa_prompt)
+
+            if docs is None or self.llm.is_too_long(modified_chat_history):
+ # Retrieve
+ related_docs = copy.deepcopy(await document_store.retrieve(question))
+
+ self.logger.info("Related documents: {}".format(related_docs))
+                # [TODO] The related documents will be cleared when the history is too long
+ while True:
+ modified_chat_history = self.replace_chat_history(chat_history, task, llm_question, related_docs, override_prompt=override_qa_prompt)
+ if not self.llm.is_too_long(modified_chat_history) or len(related_docs)==0: break
+ related_docs = related_docs[:-1]
+ self.logger.info("Prompt length exceeded the permitted limit, necessitating truncation.")
+
+ #result = Eval.eval_retriever(gen_question, related_docs)
+ #retriever_result["questions"].append(result)
+ #retriever_result = Eval.yield_retriever(retriever_result)
+ #yield retriever_result
+
+ page_content = Eval.extract_page_content(related_docs)
+ data_sample['contexts'].append(page_content)
+ answer = Eval.generate_answer(gen_question, page_content)
+ data_sample['answer'].append(answer)
+ yield pformat(data_sample)
+ Eval.ragas_eval(data_sample)
+
+ break
+ #end2end_result = Eval.eval_end2end(gen_question, answer)
+ #yield Eval.yield_end2end(end2end_result)
+ # yield pformat(Eval.eval_end2end(gen_question, answer))
+
+ # Free the unused VRAM
+ # del document_store
+ # gc.collect()
+
+ # # # Generate
+ # llm_input = self.generate_llm_input(task, llm_question, related_docs, override_prompt=override_qa_prompt)
+ # self.logger.info("Related documents: {}".format(related_docs))
+ # self.logger.info('LLM input: {}'.format(llm_input))
+ # # result = ''
+ # generator = self.llm.chat_complete(
+ # auth_token=auth_token,
+ # messages=modified_chat_history
+ # )
+ # buffer = ""
+ # async for chunk in generator:
+ # buffer += chunk
+
+
+ # self.logger.info("Buffer is: " + buffer)
+ # result = evalotor.eval_end2end(gen_question, buffer)
+ # self.logger.info("Return result is: " + result)
+ # yield result
+ # yield chunk
+
+
+ if self.with_ref and len(related_docs)!=0:
+ yield self.format_references(related_docs)
+
+
+ # # # Generate
+ # llm_input = self.generate_llm_input(task, llm_question, related_docs, override_prompt=override_qa_prompt)
+ # self.logger.info("Related documents: {}".format(related_docs))
+ # self.logger.info('LLM input: {}'.format(llm_input))
+ # generator = self.llm.chat_complete(
+ # # auth_token=auth_token, # Optional
+ # # messages=modified_chat_history # Optional (depending on your model)
+ # )
+ # async for chunk in generator:
+ # # Call your function that returns a JSON object
+ # json_data = await self.your_function(chunk) # Replace with your function name
+ # # Convert the JSON object to a string
+ # json_string = json.dumps(json_data)
+ # yield json_string
+
+ # if self.with_ref and len(related_docs)!=0:
+ # yield self.format_references(related_docs)
+
+ # Egress filter
+ # is_english = self.is_english(result)
+ # self.logger.info(f'Is English: {is_english}')
+ # if task == 'summary' and is_english:
+ # result = await self.llm.chat_complete(
+ # auth_token=auth_token,
+ # messages=[
+ # {
+ # "isbot": False,
+ # "msg": self.generate_llm_input('translate', result, [])
+ # },
+ # ]
+ # )
+
+ # yield result
\ No newline at end of file
diff --git a/src/executor/docqa/eval/eval.py b/src/executor/docqa/eval/eval.py
new file mode 100644
index 00000000..7d169b56
--- /dev/null
+++ b/src/executor/docqa/eval/eval.py
@@ -0,0 +1,290 @@
+import pandas as pd
+import os
+import google.generativeai as genai
+import json
+import time
+from typing import List, Union
+import re
+from datasets import Dataset
+from ragas import evaluate
+from ragas.metrics import faithfulness, answer_correctness
+
+API_KEY = ''
+
+class Eval:
+ @staticmethod
+ def save_to_json(data, filename):
+        with open(filename, 'w', encoding='utf-8') as f:
+ json.dump(data, f, ensure_ascii=False, indent=4)
+
+ @staticmethod
+ def read_csv_as_text(file_path, num_rows=None):
+ df = pd.read_csv(file_path, nrows=num_rows)
+ text = df.to_csv(index=False)
+ return text
+
+ @staticmethod
+ def read_document(file_path):
+ with open(file_path, 'r', encoding='utf-8') as file:
+ content = file.read()
+ return content
+
+ @staticmethod
+ def eval_retriever(question, relevant_chunks):
+ genai.configure(api_key=API_KEY)
+ model = genai.GenerativeModel('gemini-pro')
+ yes_chunks = []
+ no_chunks = []
+ error_chunks = []
+ for chunk in relevant_chunks:
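+            # The prompt below (in Chinese) asks: "Question: {question} / Source: {chunk} /
+            # Is the source material relevant to the question? Answer using only 'y' or 'n',
+            # with no explanation."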
+ prompt = f"""
+ 提問:{question}
+ 出處:{chunk}
+ 請問出處的資料有沒有跟提問的問題有關聯性? 請用只使用'y' 或是 'n' 來回答,不要包含任何解釋
+ """
+            while True:
+                try:
+                    response = model.generate_content(prompt)
+                    break
+                except Exception:
+                    print("\033[1;32mQuota exhausted. Waiting before retrying...\033[0m")
+                    time.sleep(1)
+ response_text = response.text.strip()
+ if response_text == 'y':
+ yes_chunks.append(chunk)
+ elif response_text == 'n':
+ no_chunks.append(chunk)
+ else:
+ error_chunks.append({"chunk": chunk, "response": response_text})
+
+        yes_length = len(yes_chunks)
+        total_count = yes_length + len(no_chunks) + len(error_chunks)
+        accuracy = yes_length / total_count if total_count > 0 else 0.0
+ retriever_result = {
+ "question": question,
+ "y": yes_chunks,
+ "n": no_chunks,
+ "error": error_chunks,
+ "accuracy": accuracy
+ }
+ # Eval.save_to_json(retriever_result, 'retriever_result.json')
+ return retriever_result
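+    # The returned dict looks like (illustrative values):
+    #   {"question": "...", "y": [chunk, ...], "n": [...],
+    #    "error": [{"chunk": ..., "response": "maybe"}], "accuracy": 0.75}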
+
+ @staticmethod
+ def eval_end2end(question, answer):
+ genai.configure(api_key=API_KEY)
+ model = genai.GenerativeModel('gemini-pro')
+
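+        # The prompt below (in Chinese) asks: "Question: {question} / Answer: {answer} /
+        # Does the answer address the question? Answer using only 'y' or 'n', with no explanation."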
+ prompt = f"""
+ 問題:{question}
+ 答案:{answer}
+ 請問答案有沒有回答到問題? 請用只使用'y' 或是 'n' 來回答,不要包含任何解釋
+ """
+
+        while True:
+            try:
+                response = model.generate_content(prompt)
+                break
+            except Exception:
+                print("\033[1;32mQuota exhausted. Waiting before retrying...\033[0m")
+                time.sleep(1)
+
+ response_text = response.text.strip()
+        end2end_result = {
+            "question": question,
+            "answer": answer,
+            "related": response_text
+        }
+
+        return end2end_result
+
+ @staticmethod
+    def generate_question(text: Union[str, list]) -> List[str]:
+ genai.configure(api_key=API_KEY)
+ model = genai.GenerativeModel('gemini-pro')
+
+ if isinstance(text, list):
+ text_string = " ".join(text)
+ else:
+ text_string = text
+
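+        # The prompt (in Chinese) asks the model to read the text and list related
+        # questions, each prefixed with "- ", generating at most 10 questions and
+        # including nothing but the questions themselves.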
+ prompt = f"請閱讀以下文章並且使用一個 \"- \" 做區分跟列出相關問題,最多隨機產生10個問題,請只包含問題不包含其他不必要資訊: {text_string}"
+
+        while True:
+            try:
+                response = model.generate_content(prompt)
+                break
+            except Exception:
+                print("\033[1;32mQuota exhausted. Waiting before retrying...\033[0m")
+                time.sleep(1)
+
+ pattern = r"- (.*)"
+ strings = re.findall(pattern, response.text, flags=re.MULTILINE)
+ return strings
+
+ @staticmethod
+ def generate_answer(question, relevant_chunks):
+ genai.configure(api_key=API_KEY)
+ model = genai.GenerativeModel('gemini-pro')
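+        # The prompt (in Chinese) asks the model to answer the question using only the
+        # provided source chunks and to list the chunks it used after a "From:" marker.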
+ prompt = f"""
+ 提問:{question}
+ 出處:{relevant_chunks}
+ 請使用出處的資料來回答提問,並且用"From:"列出用來回答提問的出處
+ """
+        while True:
+            try:
+                response = model.generate_content(prompt)
+                break
+            except Exception:
+                print("\033[1;32mQuota exhausted. Waiting before retrying...\033[0m")
+                time.sleep(1)
+ return response.text
+
+ @staticmethod
+ def yield_retriever(data):
+        """
+        Formats the retriever evaluation results into a human-readable report.
+
+        Args:
+            data: A dict containing questions, their accuracies, and relevant/irrelevant chunks.
+
+        Returns:
+            A formatted string listing each question, its accuracy, and its chunks.
+        """
+
+ result_string = ""
+ for question_data in data['questions']:
+ question = question_data['question']
+            accuracy = f"{question_data['accuracy']:.2f}"
+ relevant_chunks = Eval.extract_page_content(question_data['y'])
+ irrelevant_chunks = Eval.extract_page_content(question_data['n'])
+
+ formatted_relevant_chunks = Eval.format_chunks(relevant_chunks)
+ formatted_irrelevant_chunks = Eval.format_chunks(irrelevant_chunks)
+
+ result_string += f"\nQuestion: {question}\nAccuracy: {accuracy}\n\nRelevant Chunks:\n{formatted_relevant_chunks}\n\nIrrelevant Chunks:\n{formatted_irrelevant_chunks}\n\n"
+
+ return result_string
+
+ @staticmethod
+ def yield_end2end(data):
+
+ question = data['question']
+ answer = data['answer']
+ related = data['related']
+ return f"""
+ Question: {question}
+ Answer: {answer}
+ Is related: {related}
+ """
+
+    @staticmethod
+    def format_chunks(chunks):
+ return "\n".join(f"{index + 1}. {chunk}" for index, chunk in enumerate(chunks))
+
+ @staticmethod
+ def extract_page_content(relevant_chunks):
+ """Extracts page content from a list of Document objects.
+
+ Args:
+ relevant_chunks: A list of Document objects.
+
+ Returns:
+ A list of page content strings.
+ """
+
+ page_contents = []
+ for document in relevant_chunks:
+ page_contents.append(document.page_content)
+ return page_contents
+
+ @staticmethod
+ def clean_llm_response(response):
+ try:
+ # Find the start and end positions of the JSON object
+ start_marker = '{'
+ end_marker = '}'
+
+ start_index = response.index(start_marker)
+ end_index = response.rindex(end_marker) + 1
+
+ # Extract the JSON part
+ json_str = response[start_index:end_index]
+
+ # Parse the JSON string
+ parsed_json = json.loads(json_str)
+
+ return parsed_json
+
+ except ValueError as e:
+ raise ValueError(f"Error processing the response: {e}")
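+    # For example (illustrative):
+    #   clean_llm_response('Sure: {"questions": ["Q1"], "ground_truth": ["A1"]}')
+    #   -> {"questions": ["Q1"], "ground_truth": ["A1"]}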
+
+    @staticmethod
+    def filter_questions(parsed_json):
+ for item in parsed_json:
+ text = item['text']
+ questions = item['questions']
+
+ # Filter out questions where the answer does not match the text word by word
+ filtered_questions = [q for q in questions if q['answer'] in text.split()]
+
+ # Update the questions list with filtered questions
+ item['questions'] = filtered_questions
+
+ return parsed_json
+
+ @staticmethod
+ def ragas_eval(data_sample):
+ dataset = Dataset.from_dict(data_sample)
+ score = evaluate(dataset, metrics=[faithfulness, answer_correctness])
+ df = score.to_pandas()
+ df.to_csv('score.csv', index=False)
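+    # The Dataset is expected to carry RAGAS-style columns; with faithfulness and
+    # answer_correctness these are typically "question", "answer", "contexts", and
+    # "ground_truth" (column names assumed from the ragas documentation).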
+
+ @staticmethod
+ def generate_questions_RAG(context):
+ genai.configure(api_key=API_KEY)
+ model = genai.GenerativeModel('gemini-pro')
+
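+        # The prompt (partly in Chinese) asks: using the provided context, generate at
+        # least one non-trivial question whose answer can be quoted verbatim from the
+        # text, and reply as JSON with the keys described below.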
+ prompt = f"""
+ 使用所提供的上下文產生至少一個可以直接從文字逐字回答的問題。
+ 問題不能簡單,答案也不能太短,必須逐字從文字中回答。
+ 使用JSON格式來回答 Context。
+ - 'questions': list of str
+ - 'ground_truth': The ground truth answer to the questions that you will answer word by word from the text
+
+ Here is the Context: {context}
+ """
+        while True:
+            try:
+                response = model.generate_content(prompt)
+                break
+            except Exception:
+                print("\033[1;32mQuota exhausted. Waiting before retrying...\033[0m")
+                time.sleep(1)
+ return response.text
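+    # The raw model reply is expected to contain JSON shaped roughly like
+    #   {"questions": ["..."], "ground_truth": ["..."]}
+    # (illustrative), which clean_llm_response() then extracts and parses.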
+
+
+
+def main():
+ # file_path = r'src/executor/docqa/src/AS-AIGFAQ.csv'
+ # text = Eval.read_csv_as_text(file_path, num_rows=2)
+
+ # questions = Eval.generate_question(text)
+ # # response = Eval.model.count_tokens(text)
+ # # print(f"Prompt Token Count: {response.total_tokens}")
+
+ # # if response.total_tokens >= 200000:
+ # # print("Please trim the file before pass in...")
+
+ # for question in questions:
+ # answer = "sample answer" # You need to provide an actual answer here
+ # Eval.eval_end2end(question, answer)
+
+
+
+
+ # print(Eval.clean_llm_response(generate_questions(context)))
+
+ #print(Eval.parse_and_filter_questions(jason_str))
+
+
+ return None
+
+if __name__ == '__main__':
+ main()