diff --git a/src/executor/docqa/eval/docqa.py b/src/executor/docqa/eval/docqa.py
new file mode 100644
index 00000000..75946984
--- /dev/null
+++ b/src/executor/docqa/eval/docqa.py
@@ -0,0 +1,380 @@
+#!/bin/python3
+# -*- coding: UTF-8 -*-
+
+from typing import Generator, Iterable
+from collections import namedtuple
+from pathlib import Path
+from urllib.error import HTTPError
+from pprint import pformat
+from langchain.docstore.document import Document
+from kuwa.executor import Modelfile
+
+from .recursive_url_multimedia_loader import RecursiveUrlMultimediaLoader
+from .document_store import DocumentStore
+from .kuwa_llm_client import KuwaLlmClient
+from .eval import Eval
+
+import i18n
+import re
+import gc
+import os
+import logging
+import chevron
+import asyncio
+import copy
+import time
+import pathlib
+import json
+
+class DocQa:
+
+    def __init__(
+        self,
+        document_store:DocumentStore = DocumentStore,
+        vector_db:str = None,
+        llm:KuwaLlmClient = KuwaLlmClient(),
+        lang:str="en",
+        with_ref:bool=False,
+        user_agent:str = None
+    ):
+        self.logger = logging.getLogger(__name__)
+        self.llm = llm
+        self.lang = lang
+        self.with_ref = with_ref
+        self.user_agent = user_agent
+        if vector_db != None:
+            self.pre_build_db = True
+            self.document_store = DocumentStore.load(vector_db)
+        else:
+            self.pre_build_db = False
+            self.document_store:DocumentStore = document_store
+
+    def generate_llm_input(self, task, question, related_docs, override_prompt:str=None):
+        """
+        This function generates a formatted input string suitable for the Kuwa LLM model based on the provided task, question, related documents, and an optional override prompt.
+
+        Args:
+            task (str): The type of task to be performed (e.g., "qa" for question answering, "summary" for summarization, and "translate").
+            question (str): The user's question or prompt.
+            related_docs (List[Document]): A list of Document objects containing relevant information for the task. These documents come from the retriever's embedding search.
+            override_prompt (str, optional): An optional override prompt to use in place of the default prompt for the task. Defaults to None.
+
+        Returns:
+            str: The formatted LLM input string.
+        """
+        # Convert related documents to a dictionary format suitable for templating
+        docs = [dict(title=doc.metadata.get("title"), **dict(doc)) for doc in related_docs]
+        # Load the appropriate prompt template based on the task and language
+        template_path = f'lang/{self.lang}/prompt_template/llm_input_{task}.mustache'
+        llm_input_template = Path(template_path).read_text(encoding="utf8")
+        # Render the template with the provided data
+        llm_input = chevron.render(llm_input_template, {
+            'docs': docs,
+            'question': question,
+            'override_prompt': override_prompt
+        })
+        return llm_input
+
+    def replace_chat_history(self, chat_history:[dict], task:str, question:str, related_docs:[str], override_prompt:str=None):
+        """
+        This function modifies the provided chat history to include the generated LLM input for processing by the LLM model.
+
+        Args:
+            chat_history (List[Dict]): A list of dictionaries representing the chat history, where each dictionary contains keys like "msg" (message content) and "isbot" (boolean indicating if the message is from a bot).
+            task (str): The type of task to be performed (e.g., "qa" for question answering, "summary" for summarization).
+            question (str): The user's question or prompt.
+            related_docs (List[str]): A list of URLs or file paths pointing to documents relevant to the user's query.
+            override_prompt (str, optional): An optional override prompt to use in place of the default prompt for the task. Defaults to None.
+
+        Returns:
+            List[Dict]: The modified chat history with the LLM input message inserted.
+        """
+        # Generate the LLM input using the provided information
+        llm_input = self.generate_llm_input(task, question, related_docs, override_prompt)
+        # Modify the chat history to include the LLM input
+        modified_chat_history = chat_history[:-1] + [{"isbot": False, "msg": llm_input}]
+        # If the first message in the chat history is empty
+        if modified_chat_history[0]["msg"] is None:
+            # If it's a multi-round conversation, set the first message to a summary prompt
+            if len(modified_chat_history) != 2: # Multi-round
+                modified_chat_history[0]["msg"] = i18n.t("docqa.summary_prompt")
+            # Otherwise, remove the first message
+            else: # Single-round
+                modified_chat_history = modified_chat_history[1:]
+        # Replace empty messages with a placeholder
+        modified_chat_history = [
+            {"msg": "[Empty message]", "isbot": r["isbot"]} if r["msg"] == '' else r
+            for r in modified_chat_history
+        ]
+        return modified_chat_history
+
+    def is_english(self, paragraph:str, threshold=0.8):
+        total_count = len(paragraph)
+        english_character_count = len(paragraph.encode("ascii", "ignore"))
+        english_rate = 0 if total_count == 0 else english_character_count / total_count
+
+        return english_rate >= threshold
+
+    def get_final_user_input(self, chat_history: [dict]) -> str:
+        """
+        This function extracts the final user input (message) from the provided chat history.
+
+        Args:
+            chat_history (List[Dict]): A list of dictionaries representing the chat history, where each dictionary contains keys like "msg" (message content) and "isbot" (boolean indicating if the message is from a bot).
+
+        Returns:
+            str: The final user input message from the chat history, or None if no user messages are found.
+        """
+        final_user_record = next(filter(lambda x: x['isbot'] == False, reversed(chat_history)))
+        return final_user_record['msg']
+
+    async def fetch_documents(self, url:str):
+        """
+        This function asynchronously fetches documents from the provided URL using a RecursiveUrlMultimediaLoader.
+
+        Args:
+            url (str): The URL of the webpage or document to be fetched.
+
+        Returns:
+            List[Document]: A list of Document objects containing the fetched content, or an empty list if there were errors or no documents were found.
+        """
+        # Fetching documents
+        self.logger.info(f'Fetching URL "{url}"')
+        docs = []
+        # Create a RecursiveUrlMultimediaLoader to handle fetching documents
+        loader = RecursiveUrlMultimediaLoader(
+            url=url,
+            max_depth=1,
+            prevent_outside=False,
+            use_async = True,
+            cache_proxy_url = os.environ.get('HTTP_CACHE_PROXY', None),
+            forge_user_agent=self.user_agent
+        )
+        try:
+            # Asynchronously load the documents using the loader
+            docs = await loader.async_load()
+            # Log the number of documents fetched
+            self.logger.info(f'Fetched {len(docs)} documents.')
+        except Exception as e:
+            docs = []
+        finally:
+            return docs
+
+    async def construct_document_store(self, docs: [Document]):
+        document_store = self.document_store
+        # Run the docs through the document store's embedding pipeline to build the vector DB
+        await document_store.from_documents(docs)
+        return document_store
+
+    def filter_detail(self, msg):
+        """
+        This function removes HTML details tags (<details>...</details>) from the provided message string.
+
+        Args:
+            msg (str): The message string potentially containing HTML details tags.
+
+        Returns:
+            str: The message string with HTML details tags removed. If the input message is None, it returns None.
+        """
+        if msg is None: return None
+        pattern = r"<details>.*</details>"
+        return re.sub(pattern=pattern, repl='', string=msg, flags=re.DOTALL)
+
+    def format_references(self, docs:[Document]):
+        """
+        This function formats a list of Document objects into a user-readable reference list with snippets.
+
+        Args:
+            docs (List[Document]): A list of Document objects containing information about the reference sources.
+
+        Returns:
+            str: The formatted reference list as a string, including source links, titles, and snippets. If no documents have a valid source, it returns an empty string.
+        """
+        Reference = namedtuple("Reference", "source, title, content")
+        refs = [
+            Reference(
+                source=doc.metadata.get("source"),
+                title=doc.metadata.get("title", doc.metadata.get("filename")),
+                content=doc.page_content,
+            ) for doc in docs
+        ]
+        refs = filter(lambda x: x.source, refs)
+        result = f"\n\n<details><summary>{i18n.t('docqa.reference')}</summary>\n\n"
+        for i, ref in enumerate(refs):
+
+            src = ref.source
+            title = ref.title if ref.title is not None else src
+            content = ref.content
+            link = src if src.startswith("http") else pathlib.Path(src).as_uri()
+            result += f'{i+1}. [{title}]({link})\n\n```plaintext\n{content}\n```\n\n'
+        result += f"</details>"
+
+        return result
+
+    async def process(self, urls: Iterable, chat_history: [dict], modelfile:Modelfile, auth_token=None) -> Generator[str, None, None]:
+
+        """
+        This asynchronous function processes a conversation history and a list of URLs to answer questions or summarize documents in a conversational manner.
+
+        It interacts with the Kuwa LLM model to generate responses based on the provided information.
+
+        Args:
+            urls (Iterable[str]): An iterable of URLs pointing to documents relevant to the user's query.
+            chat_history (List[Dict]): A list of dictionaries representing the conversation history, where each dictionary contains keys like "msg" (message content) and "isbot" (boolean indicating if the message is from a bot).
+            modelfile (Modelfile): A Modelfile object containing configuration parameters for the conversation.
+            auth_token (str, optional): An optional authentication token for the LLM model. Defaults to None.
+
+        Yields:
+            str: Yields chunks of the conversation generated by the LLM model, including the answer or summary and potentially references to retrieved documents.
+
+        Returns:
+            None: The function itself does not return a value, it yields the conversation chunks asynchronously.
+        """
+        retriever_result = {"questions": []}
+        override_qa_prompt = modelfile.override_system_prompt
+        chat_history = [{"msg": i["content"], "isbot": i["role"]=="assistant"} for i in chat_history]
+        chat_history = [{"msg": self.filter_detail(i["msg"]), "isbot": i["isbot"]} for i in chat_history]
+
+        self.logger.debug(f"Chat history: {chat_history}")
+
+        # final_user_input = self.get_final_user_input(chat_history)
+        # if final_user_input is not None:
+        #     final_user_input = "{before}{user}{after}".format(
+        #         before = modelfile.before_prompt,
+        #         user = final_user_input,
+        #         after = modelfile.after_prompt
+        #     )
+
+
+        document_store = self.document_store
+        docs = None
+        if not self.pre_build_db:
+            if len(urls) == 1:
+
+                docs = await self.fetch_documents(urls[0])
+                if len(docs) == 0:
+                    await asyncio.sleep(2) # To prevent SSE error of web page.
+                    yield i18n.t('docqa.error_fetching_document')
+                    return
+            else:
+                docs = await asyncio.gather(*[self.fetch_documents(url) for url in urls])
+                docs = [doc for sub_docs in docs for doc in sub_docs]
+            document_store = await self.construct_document_store(docs)
+
+        text = docs[0] # The text comes from the document the user uploaded
+        data_sample = Eval.clean_llm_response(Eval.generate_questions_RAG(text)) #evalotor.generate_question(text)
+        data_sample['contexts'] = []
+        data_sample['answer'] = []
+        gen_questions = data_sample['questions']
+        self.logger.info(f"{gen_questions}")
+        for gen_question in gen_questions:
+            retriever_result = {"questions": []}
+            final_user_input = gen_question
+            task = ''
+            if final_user_input == "":
+                question = i18n.t("docqa.summary_question")
+                llm_question = None
+                task = 'summary'
+                await asyncio.sleep(2) # To prevent SSE error of web page.
+                yield i18n.t("docqa.summary_prefix")+'\n'
+            else:
+                question = final_user_input
+                llm_question = question
+                task = 'qa'
+
+            # Shortcut
+            if docs != None:
+                related_docs = docs
+                modified_chat_history = self.replace_chat_history(chat_history, task, llm_question, related_docs, override_prompt=override_qa_prompt)
+
+            if docs == None or self.llm.is_too_long(modified_chat_history):
+                # Retrieve
+                related_docs = copy.deepcopy(await document_store.retrieve(question))
+
+                self.logger.info("Related documents: {}".format(related_docs))
+                # [TODO] the related-document will be cleared when the history is too long
+                while True:
+                    modified_chat_history = self.replace_chat_history(chat_history, task, llm_question, related_docs, override_prompt=override_qa_prompt)
+                    if not self.llm.is_too_long(modified_chat_history) or len(related_docs)==0: break
+                    related_docs = related_docs[:-1]
+                    self.logger.info("Prompt length exceeded the permitted limit, necessitating truncation.")
+
+            #result = Eval.eval_retriever(gen_question, related_docs)
+            #retriever_result["questions"].append(result)
+            #retriever_result = Eval.yield_retriever(retriever_result)
+            #yield retriever_result
+
+            page_content = Eval.extract_page_content(related_docs)
+            data_sample['contexts'].append(page_content)
+            answer = Eval.generate_answer(gen_question, page_content)
+            data_sample['answer'].append(answer)
+            yield pformat(data_sample)
+            Eval.ragas_eval(data_sample)
+
+            break
+            #end2end_result = Eval.eval_end2end(gen_question, answer)
+            #yield Eval.yield_end2end(end2end_result)
+            # yield pformat(Eval.eval_end2end(gen_question, answer))
+
+        # Free the unused VRAM
+        # del document_store
+        # gc.collect()
+
+        # # # Generate
+        # llm_input = self.generate_llm_input(task, llm_question, related_docs, override_prompt=override_qa_prompt)
+        # self.logger.info("Related documents: {}".format(related_docs))
+        # self.logger.info('LLM input: {}'.format(llm_input))
+        # # result = ''
+        # generator = self.llm.chat_complete(
+        #     auth_token=auth_token,
+        #     messages=modified_chat_history
+        # )
+        # buffer = ""
+        # async for chunk in generator:
+        #     buffer += chunk
+
+
+        #     self.logger.info("Buffer is: " + buffer)
+        #     result = evalotor.eval_end2end(gen_question, buffer)
+        #     self.logger.info("Return result is: " + result)
+        #     yield result
+        #     yield chunk
+
+
+        if self.with_ref and len(related_docs)!=0:
+            yield self.format_references(related_docs)
+
+
+        # # # Generate
+        # llm_input = self.generate_llm_input(task, llm_question, related_docs, override_prompt=override_qa_prompt)
+        # self.logger.info("Related documents: {}".format(related_docs))
+        # self.logger.info('LLM input: {}'.format(llm_input))
+        # generator = self.llm.chat_complete(
+        #     # auth_token=auth_token, # Optional
+        #     # messages=modified_chat_history # Optional (depending on your model)
+        # )
+        # async for chunk in generator:
+        #     # Call your function that returns a JSON object
+        #     json_data = await self.your_function(chunk) # Replace with your function name
+        #     # Convert the JSON object to a string
+        #     json_string = json.dumps(json_data)
+        #     yield json_string
+
+        # if self.with_ref and len(related_docs)!=0:
+        #     yield self.format_references(related_docs)
+
+        # Egress filter
+        # is_english = self.is_english(result)
+        # self.logger.info(f'Is English: {is_english}')
+        # if task == 'summary' and is_english:
+        #     result = await self.llm.chat_complete(
+        #         auth_token=auth_token,
+        #         messages=[
+        #             {
+        #                 "isbot": False,
+        #                 "msg": self.generate_llm_input('translate', result, [])
+        #             },
+        #         ]
+        #     )
+
+        # yield result
\ No newline at end of file
diff --git a/src/executor/docqa/eval/eval.py b/src/executor/docqa/eval/eval.py
new file mode 100644
index 00000000..7d169b56
--- /dev/null
+++ b/src/executor/docqa/eval/eval.py
@@ -0,0 +1,290 @@
+import pandas as pd
+import os
+import google.generativeai as genai
+import json
+import time
+from typing import List, Union
+import re
+from datasets import Dataset
+from ragas import evaluate
+from ragas.metrics import faithfulness, answer_correctness
+
+API_KEY = ''
+
+class Eval:
+    @staticmethod
+    def save_to_json(data, filename):
+        with open(filename, 'w') as f:
+            json.dump(data, f, ensure_ascii=False, indent=4)
+
+    @staticmethod
+    def read_csv_as_text(file_path, num_rows=None):
+        df = pd.read_csv(file_path, nrows=num_rows)
+        text = df.to_csv(index=False)
+        return text
+
+    @staticmethod
+    def read_document(file_path):
+        with open(file_path, 'r', encoding='utf-8') as file:
+            content = file.read()
+        return content
+
+    @staticmethod
+    def eval_retriever(question, relevant_chunks):
+        genai.configure(api_key=API_KEY)
+        model = genai.GenerativeModel('gemini-pro')
+        yes_chunks = []
+        no_chunks = []
+        error_chunks = []
+        for chunk in relevant_chunks:
+            prompt = f"""
+            提問:{question}
+            出處:{chunk}
+            請問出處的資料有沒有跟提問的問題有關聯性? 請只使用'y' 或是 'n' 來回答,不要包含任何解釋
+            """
+            while(True):
+                try:
+                    response = model.generate_content(prompt)
+                    break
+                except Exception as E:
+                    print("\033[1;32m Quota exhausted. Waiting before retrying...\033[0m")
+                    time.sleep(1)
+            response_text = response.text.strip()
+            if response_text == 'y':
+                yes_chunks.append(chunk)
+            elif response_text == 'n':
+                no_chunks.append(chunk)
+            else:
+                error_chunks.append({"chunk": chunk, "response": response_text})
+
+        yes_length = len(yes_chunks)
+        accuracy = yes_length / (yes_length + len(no_chunks) + len(error_chunks))
+        retriever_result = {
+            "question": question,
+            "y": yes_chunks,
+            "n": no_chunks,
+            "error": error_chunks,
+            "accuracy": accuracy
+        }
+        # Eval.save_to_json(retriever_result, 'retriever_result.json')
+        return retriever_result
+
+    @staticmethod
+    def eval_end2end(question, answer):
+        genai.configure(api_key=API_KEY)
+        model = genai.GenerativeModel('gemini-pro')
+
+        prompt = f"""
+        問題:{question}
+        答案:{answer}
+        請問答案有沒有回答到問題? 請只使用'y' 或是 'n' 來回答,不要包含任何解釋
+        """
+
+        while(True):
+            try:
+                response = model.generate_content(prompt)
+                break
+            except Exception as E:
+                print("\033[1;32m Quota exhausted. Waiting before retrying...\033[0m")
+                time.sleep(1)
+
+        response_text = response.text.strip()
+        retriever_result = {
+            "question": question,
+            "answer": answer,
+            "related": response_text
+        }
+
+        return retriever_result
+
+    @staticmethod
+    def generate_question(text: Union[str, list]) -> list:
+        genai.configure(api_key=API_KEY)
+        model = genai.GenerativeModel('gemini-pro')
+
+        if isinstance(text, list):
+            text_string = " ".join(text)
+        else:
+            text_string = text
+
+        prompt = f"請閱讀以下文章並且使用一個 \"- \" 做區分跟列出相關問題,最多隨機產生10個問題,請只包含問題不包含其他不必要資訊: {text_string}"
+
+        try:
+            response = model.generate_content(prompt)
+        except Exception as E:
+            print("\033[1;32m Quota exhausted. Waiting before retrying...\033[0m")
+            time.sleep(1)
+            return E
+
+        pattern = r"- (.*)"
+        strings = re.findall(pattern, response.text, flags=re.MULTILINE)
+        return strings
+
+    @staticmethod
+    def generate_answer(question, relevant_chunks):
+        genai.configure(api_key=API_KEY)
+        model = genai.GenerativeModel('gemini-pro')
+        prompt = f"""
+        提問:{question}
+        出處:{relevant_chunks}
+        請使用出處的資料來回答提問,並且用"From:"列出用來回答提問的出處
+        """
+        while(True):
+            try:
+                response = model.generate_content(prompt)
+                break
+            except Exception as E:
+                print("\033[1;32m Quota exhausted. Waiting before retrying...\033[0m")
+                time.sleep(1)
+        return response.text
+
+    @staticmethod
+    def yield_retriever(data):
+        """
+        Formats information from the provided JSON data for display.
+
+        Args:
+            data: A JSON object containing questions, their accuracies, and relevant/irrelevant chunks.
+
+        Returns:
+            A formatted string representing the question, accuracy, and chunks.
+        """
+
+        result_string = ""
+        for question_data in data['questions']:
+            question = question_data['question']
+            accuracy = f"Accuracy: {question_data['accuracy']:.2f}"
+            relevant_chunks = Eval.extract_page_content(question_data['y'])
+            irrelevant_chunks = Eval.extract_page_content(question_data['n'])
+
+            formatted_relevant_chunks = Eval.format_chunks(relevant_chunks)
+            formatted_irrelevant_chunks = Eval.format_chunks(irrelevant_chunks)
+
+            result_string += f"\nQuestion: {question}\nAccuracy: {accuracy}\n\nRelevant Chunks:\n{formatted_relevant_chunks}\n\nIrrelevant Chunks:\n{formatted_irrelevant_chunks}\n\n"
+
+        return result_string
+
+    @staticmethod
+    def yield_end2end(data):
+
+        question = data['question']
+        answer = data['answer']
+        related = data['related']
+        return f"""
+        Question: {question}
+        Answer: {answer}
+        Is related: {related}
+        """
+
+    def format_chunks(chunks):
+        return "\n".join(f"{index + 1}. {chunk}" for index, chunk in enumerate(chunks))
+
+    @staticmethod
+    def extract_page_content(relevant_chunks):
+        """Extracts page content from a list of Document objects.
+
+        Args:
+            relevant_chunks: A list of Document objects.
+
+        Returns:
+            A list of page content strings.
+        """
+
+        page_contents = []
+        for document in relevant_chunks:
+            page_contents.append(document.page_content)
+        return page_contents
+
+    @staticmethod
+    def clean_llm_response(response):
+        try:
+            # Find the start and end positions of the JSON object
+            start_marker = '{'
+            end_marker = '}'
+
+            start_index = response.index(start_marker)
+            end_index = response.rindex(end_marker) + 1
+
+            # Extract the JSON part
+            json_str = response[start_index:end_index]
+
+            # Parse the JSON string
+            parsed_json = json.loads(json_str)
+
+            return parsed_json
+
+        except ValueError as e:
+            raise ValueError(f"Error processing the response: {e}")
+
+    def filter_questions(parsed_json):
+        for item in parsed_json:
+            text = item['text']
+            questions = item['questions']
+
+            # Filter out questions where the answer does not match the text word by word
+            filtered_questions = [q for q in questions if q['answer'] in text.split()]
+
+            # Update the questions list with filtered questions
+            item['questions'] = filtered_questions
+
+        return parsed_json
+
+    @staticmethod
+    def ragas_eval(data_sample):
+        dataset = Dataset.from_dict(data_sample)
+        score = evaluate(dataset, metrics=[faithfulness, answer_correctness])
+        df = score.to_pandas()
+        df.to_csv('score.csv', index=False)
+
+    @staticmethod
+    def generate_questions_RAG(context):
+        genai.configure(api_key=API_KEY)
+        model = genai.GenerativeModel('gemini-pro')
+
+        prompt = f"""
+        使用所提供的上下文產生至少一個可以直接從文字逐字回答的問題。
+        問題不能簡單,答案也不能太短,必須逐字從文字中回答。
+        使用JSON格式來回答 Context。
+        - 'questions': list of str
+        - 'ground_truth': The ground truth answer to the questions that you will answer word by word from the text
+
+        Here is the Context: {context}
+        """
+        while(True):
+            try:
+                response = model.generate_content(prompt)
+                break
+            except Exception as E:
+                print("\033[1;32m Quota exhausted. Waiting before retrying...\033[0m")
+                time.sleep(1)
+        return response.text
+
+
+def main():
+    # file_path = r'src/executor/docqa/src/AS-AIGFAQ.csv'
+    # text = Eval.read_csv_as_text(file_path, num_rows=2)
+
+    # questions = Eval.generate_question(text)
+    # # response = Eval.model.count_tokens(text)
+    # # print(f"Prompt Token Count: {response.total_tokens}")
+
+    # # if response.total_tokens >= 200000:
+    # #     print("Please trim the file before pass in...")
+
+    # for question in questions:
+    #     answer = "sample answer" # You need to provide an actual answer here
+    #     Eval.eval_end2end(question, answer)
+
+
+    # print(Eval.clean_llm_response(generate_questions(context)))
+
+    # print(Eval.parse_and_filter_questions(jason_str))
+
+    return None
+
+if __name__ == '__main__':
+    main()