diff --git a/.gitignore b/.gitignore index f233bd5f..ab0e45ab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # Local data -data/local_data/ +resources/local_data/ # Prototyping notebooks notebooks/ diff --git a/data/paul_graham_essay.docx b/data/paul_graham_essay.docx deleted file mode 100644 index 54249bfb..00000000 Binary files a/data/paul_graham_essay.docx and /dev/null differ diff --git a/knowledge_gpt/components/sidebar.py b/knowledge_gpt/components/sidebar.py index 22ba156e..8461d029 100644 --- a/knowledge_gpt/components/sidebar.py +++ b/knowledge_gpt/components/sidebar.py @@ -7,10 +7,6 @@ load_dotenv() -def set_openai_api_key(api_key: str): - st.session_state["OPENAI_API_KEY"] = api_key - - def sidebar(): with st.sidebar: st.markdown( @@ -28,8 +24,7 @@ def sidebar(): or st.session_state.get("OPENAI_API_KEY", ""), ) - if api_key_input: - set_openai_api_key(api_key_input) + st.session_state["OPENAI_API_KEY"] = api_key_input st.markdown("---") st.markdown("# About") diff --git a/knowledge_gpt/utils/__init__.py b/knowledge_gpt/core/__init__.py similarity index 100% rename from knowledge_gpt/utils/__init__.py rename to knowledge_gpt/core/__init__.py diff --git a/knowledge_gpt/core/caching.py b/knowledge_gpt/core/caching.py new file mode 100644 index 00000000..a878ebbe --- /dev/null +++ b/knowledge_gpt/core/caching.py @@ -0,0 +1,33 @@ +import streamlit as st +from streamlit.runtime.caching.hashing import HashFuncsDict + +import knowledge_gpt.core.parsing as parsing +import knowledge_gpt.core.chunking as chunking +import knowledge_gpt.core.embedding as embedding +from knowledge_gpt.core.parsing import File + + +def file_hash_func(file: File) -> str: + """Get a unique hash for a file""" + return file.id + + +@st.cache_resource() +def bootstrap_caching(): + """Patch module functions with caching""" + + # Get all substypes of File from module + file_subtypes = [ + cls + for cls in vars(parsing).values() + if isinstance(cls, type) and issubclass(cls, File) and cls != File + ] + file_hash_funcs: HashFuncsDict = {cls: file_hash_func for cls in file_subtypes} + + parsing.read_file = st.cache_data(show_spinner=False)(parsing.read_file) + chunking.chunk_file = st.cache_data(show_spinner=False, hash_funcs=file_hash_funcs)( + chunking.chunk_file + ) + embedding.embed_files = st.cache_data( + show_spinner=False, hash_funcs=file_hash_funcs + )(embedding.embed_files) diff --git a/knowledge_gpt/core/chunking.py b/knowledge_gpt/core/chunking.py new file mode 100644 index 00000000..7739c663 --- /dev/null +++ b/knowledge_gpt/core/chunking.py @@ -0,0 +1,38 @@ +from langchain.docstore.document import Document +from langchain.text_splitter import RecursiveCharacterTextSplitter +from knowledge_gpt.core.parsing import File + + +def chunk_file( + file: File, chunk_size: int, chunk_overlap: int = 0, model_name="gpt-3.5-turbo" +) -> File: + """Chunks each document in a file into smaller documents + according to the specified chunk size and overlap + where the size is determined by the number of token for the specified model. 
+ """ + + # split each document into chunks + chunked_docs = [] + for doc in file.docs: + text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( + model_name=model_name, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + chunks = text_splitter.split_text(doc.page_content) + + for i, chunk in enumerate(chunks): + doc = Document( + page_content=chunk, + metadata={ + "page": doc.metadata.get("page", 1), + "chunk": i + 1, + "source": f"{doc.metadata.get('page', 1)}-{i + 1}", + }, + ) + chunked_docs.append(doc) + + chunked_file = file.copy() + chunked_file.docs = chunked_docs + return chunked_file diff --git a/knowledge_gpt/core/embedding.py b/knowledge_gpt/core/embedding.py new file mode 100644 index 00000000..87b9fadb --- /dev/null +++ b/knowledge_gpt/core/embedding.py @@ -0,0 +1,71 @@ +from langchain.vectorstores import VectorStore +from knowledge_gpt.core.parsing import File +from langchain.vectorstores.faiss import FAISS +from langchain.embeddings import OpenAIEmbeddings +from langchain.embeddings.base import Embeddings +from typing import List, Type +from langchain.docstore.document import Document + + +class FolderIndex: + """Index for a collection of files (a folder)""" + + def __init__(self, files: List[File], index: VectorStore): + self.name: str = "default" + self.files = files + self.index: VectorStore = index + + @staticmethod + def _combine_files(files: List[File]) -> List[Document]: + """Combines all the documents in a list of files into a single list.""" + + all_texts = [] + for file in files: + for doc in file.docs: + doc.metadata["file_name"] = file.name + doc.metadata["file_id"] = file.id + all_texts.append(doc) + + return all_texts + + @classmethod + def from_files( + cls, files: List[File], embeddings: Embeddings, vector_store: Type[VectorStore] + ) -> "FolderIndex": + """Creates an index from files.""" + + all_docs = cls._combine_files(files) + + index = vector_store.from_documents( + documents=all_docs, + embedding=embeddings, + ) + + return cls(files=files, index=index) + + +def embed_files( + files: List[File], embedding: str, vector_store: str, **kwargs +) -> FolderIndex: + """Embeds a collection of files and stores them in a FolderIndex.""" + + supported_embeddings: dict[str, Type[Embeddings]] = { + "openai": OpenAIEmbeddings, + } + supported_vector_stores: dict[str, Type[VectorStore]] = { + "faiss": FAISS, + } + + if embedding in supported_embeddings: + _embeddings = supported_embeddings[embedding](**kwargs) + else: + raise NotImplementedError(f"Embedding {embedding} not supported.") + + if vector_store in supported_vector_stores: + _vector_store = supported_vector_stores[vector_store] + else: + raise NotImplementedError(f"Vector store {vector_store} not supported.") + + return FolderIndex.from_files( + files=files, embeddings=_embeddings, vector_store=_vector_store + ) diff --git a/knowledge_gpt/core/parsing.py b/knowledge_gpt/core/parsing.py new file mode 100644 index 00000000..d092f7ea --- /dev/null +++ b/knowledge_gpt/core/parsing.py @@ -0,0 +1,102 @@ +from io import BytesIO +from typing import List, Any, Optional +import re + +import docx2txt +from langchain.docstore.document import Document +from pypdf import PdfReader +from hashlib import md5 + +from abc import abstractmethod, ABC +from copy import deepcopy + + +class File(ABC): + """Represents an uploaded file comprised of Documents""" + + def __init__( + self, + name: str, + id: str, + metadata: Optional[dict[str, Any]] = None, + docs: Optional[List[Document]] = None, + ): + 
self.name = name + self.id = id + self.metadata = metadata or {} + self.docs = docs or [] + + @classmethod + @abstractmethod + def from_bytes(cls, file: BytesIO) -> "File": + """Creates a File from a BytesIO object""" + + def __repr__(self) -> str: + return ( + f"File(name={self.name}, id={self.id}," + " metadata={self.metadata}, docs={self.docs})" + ) + + def __str__(self) -> str: + return f"File(name={self.name}, id={self.id}, metadata={self.metadata})" + + def copy(self) -> "File": + """Create a deep copy of this File""" + return self.__class__( + name=self.name, + id=self.id, + metadata=deepcopy(self.metadata), + docs=deepcopy(self.docs), + ) + + +def strip_consecutive_newlines(text: str) -> str: + """Strips consecutive newlines from a string + possibly with whitespace in between + """ + return re.sub(r"\s*\n\s*", "\n", text) + + +class DocxFile(File): + @classmethod + def from_bytes(cls, file: BytesIO) -> "DocxFile": + text = docx2txt.process(file) + text = strip_consecutive_newlines(text) + doc = Document(page_content=text.strip()) + return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc]) + + +class PdfFile(File): + @classmethod + def from_bytes(cls, file: BytesIO) -> "PdfFile": + pdf = PdfReader(file) + docs = [] + for i, page in enumerate(pdf.pages): + text = page.extract_text() + text = strip_consecutive_newlines(text) + doc = Document(page_content=text.strip()) + doc.metadata["page"] = i + 1 + docs.append(doc) + return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=docs) + + +class TxtFile(File): + @classmethod + def from_bytes(cls, file: BytesIO) -> "TxtFile": + text = file.read().decode("utf-8") + text = strip_consecutive_newlines(text) + file.seek(0) + doc = Document(page_content=text.strip()) + return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc]) + + +def read_file(file: BytesIO) -> File: + """Reads an uploaded file and returns a File object""" + if file.name.endswith(".docx"): + return DocxFile.from_bytes(file) + elif file.name.endswith(".pdf"): + return PdfFile.from_bytes(file) + elif file.name.endswith(".txt"): + return TxtFile.from_bytes(file) + else: + raise NotImplementedError diff --git a/knowledge_gpt/prompts.py b/knowledge_gpt/core/prompts.py similarity index 100% rename from knowledge_gpt/prompts.py rename to knowledge_gpt/core/prompts.py diff --git a/knowledge_gpt/core/qa.py b/knowledge_gpt/core/qa.py new file mode 100644 index 00000000..faae6a92 --- /dev/null +++ b/knowledge_gpt/core/qa.py @@ -0,0 +1,62 @@ +from typing import Any, List +from langchain.chains.qa_with_sources import load_qa_with_sources_chain +from knowledge_gpt.core.prompts import STUFF_PROMPT +from langchain.docstore.document import Document +from langchain.chat_models import ChatOpenAI +from knowledge_gpt.core.embedding import FolderIndex +from pydantic import BaseModel + + +class AnswerWithSources(BaseModel): + answer: str + sources: List[Document] + + +def query_folder( + query: str, folder_index: FolderIndex, return_all: bool = False, **model_kwargs: Any +) -> AnswerWithSources: + """Queries a folder index for an answer. + + Args: + query (str): The query to search for. + folder_index (FolderIndex): The folder index to search. + return_all (bool): Whether to return all the documents from the embedding or + just the sources for the answer. + **model_kwargs (Any): Keyword arguments for the model. + + Returns: + AnswerWithSources: The answer and the source documents. 
+ """ + + chain = load_qa_with_sources_chain( + llm=ChatOpenAI(**model_kwargs), + chain_type="stuff", + prompt=STUFF_PROMPT, + ) + + relevant_docs = folder_index.index.similarity_search(query, k=5) + result = chain( + {"input_documents": relevant_docs, "question": query}, return_only_outputs=True + ) + sources = relevant_docs + + if not return_all: + sources = get_sources(result["output_text"], folder_index) + + answer = result["output_text"].split("SOURCES: ")[0] + + return AnswerWithSources(answer=answer, sources=sources) + + +def get_sources(answer: str, folder_index: FolderIndex) -> List[Document]: + """Retrieves the docs that were used to answer the question the generated answer.""" + + source_keys = [s for s in answer.split("SOURCES: ")[-1].split(", ")] + + source_docs = [] + for file in folder_index.files: + for doc in file.docs: + if doc.metadata["source"] in source_keys: + source_docs.append(doc) + + return source_docs diff --git a/knowledge_gpt/main.py b/knowledge_gpt/main.py index 18a0ee8b..b7db8a71 100644 --- a/knowledge_gpt/main.py +++ b/knowledge_gpt/main.py @@ -1,83 +1,105 @@ import streamlit as st -from openai.error import OpenAIError from knowledge_gpt.components.sidebar import sidebar -from knowledge_gpt.utils.QA import ( - embed_docs, - get_answer, - get_sources, - parse_file, - text_to_docs, -) - -from knowledge_gpt.utils.UI import wrap_text_in_html, is_valid +from knowledge_gpt.ui import ( + wrap_doc_in_html, + is_query_valid, + is_file_valid, + is_open_ai_key_valid, +) -def clear_submit(): - st.session_state["submit"] = False +from knowledge_gpt.core.caching import bootstrap_caching +from knowledge_gpt.core.parsing import read_file +from knowledge_gpt.core.chunking import chunk_file +from knowledge_gpt.core.embedding import embed_files +from knowledge_gpt.core.qa import query_folder st.set_page_config(page_title="KnowledgeGPT", page_icon="📖", layout="wide") st.header("📖KnowledgeGPT") +# Enable caching for expensive functions +bootstrap_caching() + sidebar() +openai_api_key = st.session_state.get("OPENAI_API_KEY") + + +if not openai_api_key: + st.warning( + "Enter your OpenAI API key in the sidebar. You can get a key at" + " https://platform.openai.com/account/api-keys." + ) + + uploaded_file = st.file_uploader( "Upload a pdf, docx, or txt file", type=["pdf", "docx", "txt"], help="Scanned documents are not supported yet!", - on_change=clear_submit, ) -index = None -texts = None -if uploaded_file is not None: - texts = parse_file(uploaded_file) - docs = text_to_docs(texts) +if not uploaded_file: + st.stop() + - try: - with st.spinner("Indexing document... This may take a while⏳"): - index = embed_docs(docs) - except OpenAIError as e: - st.error(e._message) +file = read_file(uploaded_file) +chunked_file = chunk_file(file, chunk_size=300, chunk_overlap=0) + +if not is_file_valid(file): + st.stop() + +if not is_open_ai_key_valid(openai_api_key): + st.stop() + + +with st.spinner("Indexing document... 
This may take a while⏳"): + folder_index = embed_files( + files=[chunked_file], + embedding="openai", + vector_store="faiss", + openai_api_key=openai_api_key, + ) + +with st.form(key="qa_form"): + query = st.text_area("Ask a question about the document") + submit = st.form_submit_button("Submit") -query = st.text_area("Ask a question about the document", on_change=clear_submit) with st.expander("Advanced Options"): - show_all_chunks = st.checkbox("Show all chunks retrieved from vector search") + return_all_chunks = st.checkbox("Show all chunks retrieved from vector search") show_full_doc = st.checkbox("Show parsed contents of the document") -if show_full_doc and texts: + +if show_full_doc: with st.expander("Document"): # Hack to get around st.markdown rendering LaTeX - st.markdown(f"

<p>{wrap_text_in_html(texts)}</p>", unsafe_allow_html=True) +        st.markdown(f"<p>{wrap_doc_in_html(file.docs)}</p>
", unsafe_allow_html=True) -button = st.button("Submit") -if button or st.session_state.get("submit"): - if not is_valid(index, query): + +if submit: + if not is_query_valid(query): st.stop() - st.session_state["submit"] = True # Output Columns answer_col, sources_col = st.columns(2) - sources = index.similarity_search(query, k=5) # type: ignore - - try: - answer = get_answer(sources, query) - if not show_all_chunks: - # Get the sources for the answer - sources = get_sources(answer, sources) - - with answer_col: - st.markdown("#### Answer") - st.markdown(answer["output_text"].split("SOURCES: ")[0]) - - with sources_col: - st.markdown("#### Sources") - for source in sources: - st.markdown(source.page_content) - st.markdown(source.metadata["source"]) - st.markdown("---") - - except OpenAIError as e: - st.error(e._message) + + result = query_folder( + folder_index=folder_index, + query=query, + return_all=return_all_chunks, + openai_api_key=openai_api_key, + temperature=0, + ) + + with answer_col: + st.markdown("#### Answer") + st.markdown(result.answer) + + with sources_col: + st.markdown("#### Sources") + for source in result.sources: + st.markdown(source.page_content) + st.markdown(source.metadata["source"]) + st.markdown("---") diff --git a/knowledge_gpt/ui.py b/knowledge_gpt/ui.py new file mode 100644 index 00000000..81ba8007 --- /dev/null +++ b/knowledge_gpt/ui.py @@ -0,0 +1,53 @@ +from typing import List +import streamlit as st +from langchain.docstore.document import Document +from knowledge_gpt.core.parsing import File +import openai +from streamlit.logger import get_logger + +logger = get_logger(__name__) + + +def wrap_doc_in_html(docs: List[Document]) -> str: + """Wraps each page in document separated by newlines in

tags""" + text = [doc.page_content for doc in docs] + if isinstance(text, list): + # Add horizontal rules between pages + text = "\n


\n".join(text) + return "".join([f"

{line}

" for line in text.split("\n")]) + + +def is_query_valid(query: str) -> bool: + if not query: + st.error("Please enter a question!") + return False + return True + + +def is_file_valid(file: File) -> bool: + if len(file.docs) == 0 or len(file.docs[0].page_content.strip()) == 0: + st.error( + "Cannot read document! Make sure the document has" + " selectable text or is not password protected." + ) + logger.error("Cannot read document") + return False + return True + + +@st.cache_data(show_spinner=False) +def is_open_ai_key_valid(openai_api_key) -> bool: + if not openai_api_key: + st.error("Please enter your OpenAI API key in the sidebar!") + return False + try: + openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "test"}], + api_key=openai_api_key, + ) + except Exception as e: + st.error(f"{e.__class__.__name__}: {e}") + logger.error(f"{e.__class__.__name__}: {e}") + return False + return True diff --git a/knowledge_gpt/utils/QA.py b/knowledge_gpt/utils/QA.py deleted file mode 100644 index 26dd45f7..00000000 --- a/knowledge_gpt/utils/QA.py +++ /dev/null @@ -1,158 +0,0 @@ -import re -from io import BytesIO -from typing import Any, Dict, List - -import docx2txt -import streamlit as st -from langchain.chains.qa_with_sources import load_qa_with_sources_chain -from langchain.docstore.document import Document -from langchain.chat_models import ChatOpenAI -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.vectorstores import VectorStore -from langchain.vectorstores.faiss import FAISS -from openai.error import AuthenticationError -from pypdf import PdfReader - -from langchain.embeddings import OpenAIEmbeddings -from knowledge_gpt.prompts import STUFF_PROMPT - -from hashlib import md5 - - -def hash_func(doc: Document) -> str: - """Hash function for caching Documents""" - return md5(doc.page_content.encode("utf-8")).hexdigest() - - -@st.cache_data() -def parse_docx(file: BytesIO) -> str: - text = docx2txt.process(file) - # Remove multiple newlines - text = re.sub(r"\n\s*\n", "\n\n", text) - return text - - -@st.cache_data() -def parse_pdf(file: BytesIO) -> List[str]: - pdf = PdfReader(file) - output = [] - for page in pdf.pages: - text = page.extract_text() - # Merge hyphenated words - text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text) - # Fix newlines in the middle of sentences - text = re.sub(r"(? 
str: - text = file.read().decode("utf-8") - # Remove multiple newlines - text = re.sub(r"\n\s*\n", "\n\n", text) - return text - - -@st.cache_data() -def text_to_docs(text: str | List[str]) -> List[Document]: - """Converts a string or list of strings to a list of Documents - with metadata.""" - if isinstance(text, str): - # Take a single string as one page - text = [text] - page_docs = [Document(page_content=page) for page in text] - - # Add page numbers as metadata - for i, doc in enumerate(page_docs): - doc.metadata["page"] = i + 1 - - # Split pages into chunks - doc_chunks = [] - - for doc in page_docs: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=800, - separators=["\n\n", "\n", ",", " ", ""], - chunk_overlap=0, - ) - chunks = text_splitter.split_text(doc.page_content) - for i, chunk in enumerate(chunks): - doc = Document( - page_content=chunk, metadata={"page": doc.metadata["page"], "chunk": i} - ) - # Add sources a metadata - doc.metadata["source"] = f"{doc.metadata['page']}-{doc.metadata['chunk']}" - doc_chunks.append(doc) - return doc_chunks - - -@st.cache_data() -def parse_file(file: BytesIO) -> str | List[str]: - """Parses a file and returns a list of Documents.""" - if file.name.endswith(".pdf"): - return parse_pdf(file) - elif file.name.endswith(".docx"): - return parse_docx(file) - elif file.name.endswith(".txt"): - return parse_txt(file) - else: - raise ValueError("File type not supported!") - - -@st.cache_data(show_spinner=False, hash_funcs={Document: hash_func}) -def embed_docs(docs: List[Document]) -> VectorStore: - """Embeds a list of Documents and returns a FAISS index""" - - if not st.session_state.get("OPENAI_API_KEY"): - raise AuthenticationError( - "Enter your OpenAI API key in the sidebar. You can get a key at" - " https://platform.openai.com/account/api-keys." - ) - else: - # Embed the chunks - embeddings = OpenAIEmbeddings( - openai_api_key=st.session_state.get("OPENAI_API_KEY"), - ) # type: ignore - - index = FAISS.from_documents(docs, embeddings) - - return index - - -@st.cache_data(show_spinner=False, hash_funcs={Document: hash_func}) -def get_answer(docs: List[Document], query: str) -> Dict[str, Any]: - """Gets an answer to a question from a list of Documents.""" - - # Get the answer - chain = load_qa_with_sources_chain( - ChatOpenAI( - temperature=0, openai_api_key=st.session_state.get("OPENAI_API_KEY") - ), # type: ignore - chain_type="stuff", - prompt=STUFF_PROMPT, - ) - - answer = chain( - {"input_documents": docs, "question": query}, return_only_outputs=True - ) - return answer - - -@st.cache_data(show_spinner=False, hash_funcs={Document: hash_func}) -def get_sources(answer: Dict[str, Any], docs: List[Document]) -> List[Document]: - """Gets the source documents for an answer.""" - - # Get sources for the answer - source_keys = [s for s in answer["output_text"].split("SOURCES: ")[-1].split(", ")] - - source_docs = [] - for doc in docs: - if doc.metadata["source"] in source_keys: - source_docs.append(doc) - - return source_docs diff --git a/knowledge_gpt/utils/UI.py b/knowledge_gpt/utils/UI.py deleted file mode 100644 index 5635272e..00000000 --- a/knowledge_gpt/utils/UI.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import List -import streamlit as st - - -def wrap_text_in_html(text: str | List[str]) -> str: - """Wraps each text block separated by newlines in

tags""" - if isinstance(text, list): - # Add horizontal rules between pages - text = "\n


\n".join(text) - return "".join([f"

{line}

" for line in text.split("\n")]) - - -def is_valid(index, query): - if not st.session_state.get("OPENAI_API_KEY"): - st.error("Please configure your OpenAI API key!") - return False - if not index: - st.error("Please upload a document!") - return False - if not query: - st.error("Please enter a question!") - return False - return True diff --git a/data/employment_agreement.pdf b/resources/employment_agreement.pdf similarity index 100% rename from data/employment_agreement.pdf rename to resources/employment_agreement.pdf diff --git a/data/paul_graham_essay.pdf b/resources/paul_graham_essay.pdf similarity index 100% rename from data/paul_graham_essay.pdf rename to resources/paul_graham_essay.pdf diff --git a/data/paul_graham_essay.txt b/resources/paul_graham_essay.txt similarity index 100% rename from data/paul_graham_essay.txt rename to resources/paul_graham_essay.txt diff --git a/data/questions.md b/resources/questions.md similarity index 100% rename from data/questions.md rename to resources/questions.md diff --git a/resources/samples/test_hello.docx b/resources/samples/test_hello.docx new file mode 100644 index 00000000..66a552d8 Binary files /dev/null and b/resources/samples/test_hello.docx differ diff --git a/resources/samples/test_hello.pdf b/resources/samples/test_hello.pdf new file mode 100644 index 00000000..4bb8c942 Binary files /dev/null and b/resources/samples/test_hello.pdf differ diff --git a/resources/samples/test_hello.txt b/resources/samples/test_hello.txt new file mode 100644 index 00000000..5e1c309d --- /dev/null +++ b/resources/samples/test_hello.txt @@ -0,0 +1 @@ +Hello World \ No newline at end of file diff --git a/resources/samples/test_hello_multi.docx b/resources/samples/test_hello_multi.docx new file mode 100644 index 00000000..302e1c17 Binary files /dev/null and b/resources/samples/test_hello_multi.docx differ diff --git a/resources/samples/test_hello_multi.pdf b/resources/samples/test_hello_multi.pdf new file mode 100644 index 00000000..91d92cad Binary files /dev/null and b/resources/samples/test_hello_multi.pdf differ diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py new file mode 100644 index 00000000..a076291f --- /dev/null +++ b/tests/integration_tests/__init__.py @@ -0,0 +1 @@ +"""All integration tests (tests that call out to an external API).""" diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 6636adb4..00000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,14 +0,0 @@ -from langchain.docstore.document import Document - -from knowledge_gpt.utils.QA import get_sources - - -def test_get_sources(): - """Test getting sources from an answer""" - docs = [ - Document(page_content="This is a test document.", metadata={"source": "1-5"}), - Document(page_content="This is a test document.", metadata={"source": "2-6"}), - ] - answer = {"output_text": "This is a test answer. 
SOURCES: 1-5, 3-8"} - sources = get_sources(answer, docs) - assert sources == [docs[0]] diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py new file mode 100644 index 00000000..307b5085 --- /dev/null +++ b/tests/unit_tests/__init__.py @@ -0,0 +1 @@ +"""All unit tests (lightweight tests).""" diff --git a/tests/unit_tests/fake_file.py b/tests/unit_tests/fake_file.py new file mode 100644 index 00000000..30a5153c --- /dev/null +++ b/tests/unit_tests/fake_file.py @@ -0,0 +1,10 @@ +from knowledge_gpt.core.parsing import File +from io import BytesIO + + +class FakeFile(File): + """A fake file for testing purposes""" + + @classmethod + def from_bytes(cls, file: BytesIO) -> "FakeFile": + return NotImplemented diff --git a/tests/unit_tests/fake_vector_store.py b/tests/unit_tests/fake_vector_store.py new file mode 100644 index 00000000..6604513b --- /dev/null +++ b/tests/unit_tests/fake_vector_store.py @@ -0,0 +1,22 @@ +from langchain.vectorstores import VectorStore +from typing import Iterable, List, Any +from langchain.docstore.document import Document + + +class FakeVectorStore(VectorStore): + """Fake vector store for testing purposes.""" + + def add_texts( + self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any + ) -> List[str]: + raise NotImplementedError + + def from_texts( + self, texts: Iterable[str], metadatas: List[dict] | None = None, **kwargs: Any + ) -> List[str]: + raise NotImplementedError + + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + raise NotImplementedError diff --git a/tests/unit_tests/test_chunking.py b/tests/unit_tests/test_chunking.py new file mode 100644 index 00000000..b34aaeee --- /dev/null +++ b/tests/unit_tests/test_chunking.py @@ -0,0 +1,55 @@ +import pytest +from langchain.docstore.document import Document + +from knowledge_gpt.core.chunking import chunk_file +from .fake_file import FakeFile + + +@pytest.fixture +def single_page_file(): + doc = Document(page_content="This is a page.\nIt has stuff.") + file = FakeFile(name="test.txt", id="1", docs=[doc]) + return file + + +@pytest.fixture +def multi_page_file(): + docs = [ + Document(page_content="This is the first page", metadata={"page": 1}), + Document(page_content="This is the second page.", metadata={"page": 2}), + ] + file = FakeFile(name="test.pdf", id="2", docs=docs) + return file + + +def test_chunk_file_single_page_no_overlap(single_page_file): + chunked_file = chunk_file(single_page_file, chunk_size=6, chunk_overlap=0) + assert len(chunked_file.docs) == 2 # The document should be split into 2 chunks + + assert chunked_file.docs[1].page_content == "It has stuff." + assert chunked_file.docs[0].page_content == "This is a page." + + assert chunked_file.docs[0].metadata["chunk"] == 1 + assert chunked_file.docs[1].metadata["chunk"] == 2 + + assert chunked_file.docs[0].metadata["source"] == "1-1" + assert chunked_file.docs[1].metadata["source"] == "1-2" + + +def test_chunk_file_multi_page_no_overlap(multi_page_file): + chunked_file = chunk_file(multi_page_file, chunk_size=10, chunk_overlap=0) + assert ( + len(chunked_file.docs) == 2 + ) # Each of the two documents should be split into 2 chunks + + assert chunked_file.docs[0].page_content == "This is the first page" + assert chunked_file.docs[1].page_content == "This is the second page." 
+ + assert chunked_file.docs[0].metadata["page"] == 1 + assert chunked_file.docs[1].metadata["page"] == 2 + + assert chunked_file.docs[0].metadata["chunk"] == 1 + assert chunked_file.docs[1].metadata["chunk"] == 1 + + assert chunked_file.docs[0].metadata["source"] == "1-1" + assert chunked_file.docs[1].metadata["source"] == "2-1" diff --git a/tests/unit_tests/test_embedding.py b/tests/unit_tests/test_embedding.py new file mode 100644 index 00000000..e624a003 --- /dev/null +++ b/tests/unit_tests/test_embedding.py @@ -0,0 +1,39 @@ +from knowledge_gpt.core.embedding import FolderIndex +from .fake_file import FakeFile +from langchain.docstore.document import Document +from knowledge_gpt.core.parsing import File +from typing import List + + +def test_combining_files(): + """Tests that combining files works.""" + + files: List[File] = [ + FakeFile( + name="file1", + id="1", + docs=[Document(page_content="1"), Document(page_content="2")], + ), + FakeFile( + name="file2", + id="2", + docs=[Document(page_content="3"), Document(page_content="4")], + ), + ] + + all_docs = FolderIndex._combine_files(files) + + assert len(all_docs) == 4 + assert all_docs[0].page_content == "1" + assert all_docs[1].page_content == "2" + assert all_docs[2].page_content == "3" + assert all_docs[3].page_content == "4" + + assert all_docs[0].metadata["file_name"] == "file1" + assert all_docs[0].metadata["file_id"] == "1" + assert all_docs[1].metadata["file_name"] == "file1" + assert all_docs[1].metadata["file_id"] == "1" + assert all_docs[2].metadata["file_name"] == "file2" + assert all_docs[2].metadata["file_id"] == "2" + assert all_docs[3].metadata["file_name"] == "file2" + assert all_docs[3].metadata["file_id"] == "2" diff --git a/tests/unit_tests/test_parsing.py b/tests/unit_tests/test_parsing.py new file mode 100644 index 00000000..5bfa505a --- /dev/null +++ b/tests/unit_tests/test_parsing.py @@ -0,0 +1,163 @@ +import pytest +from io import BytesIO + +from knowledge_gpt.core.parsing import ( + DocxFile, + PdfFile, + TxtFile, + read_file, + strip_consecutive_newlines, +) +from pathlib import Path + +from .fake_file import FakeFile +from langchain.docstore.document import Document + + +UNIT_TESTS_ROOT = Path(__file__).parent.resolve() +TESTS_ROOT = UNIT_TESTS_ROOT.parent.resolve() +PROJECT_ROOT = TESTS_ROOT.parent.resolve() +RESOURCE_ROOT = PROJECT_ROOT / "resources" +SAMPLE_ROOT = RESOURCE_ROOT / "samples" + + +def test_docx_file(): + with open(SAMPLE_ROOT / "test_hello.docx", "rb") as f: + file = BytesIO(f.read()) + file.name = "test.docx" + docx_file = DocxFile.from_bytes(file) + + assert docx_file.name == "test.docx" + assert len(docx_file.docs) == 1 + assert docx_file.docs[0].page_content == "Hello World" + + +def test_docx_file_with_multiple_pages(): + with open(SAMPLE_ROOT / "test_hello_multi.docx", "rb") as f: + file = BytesIO(f.read()) + file.name = "test.docx" + docx_file = DocxFile.from_bytes(file) + + assert docx_file.name == "test.docx" + assert len(docx_file.docs) == 1 + assert ( + docx_file.docs[0].page_content == "Hello World 1\nHello World 2\nHello World 3" + ) + + +def test_pdf_file_with_single_page(): + with open(SAMPLE_ROOT / "test_hello.pdf", "rb") as f: + file = BytesIO(f.read()) + file.name = "test_hello.pdf" + pdf_file = PdfFile.from_bytes(file) + + assert pdf_file.name == "test_hello.pdf" + assert len(pdf_file.docs) == 1 + assert pdf_file.docs[0].page_content == "Hello World" + + +def test_pdf_file_with_multiple_pages(): + with open(SAMPLE_ROOT / "test_hello_multi.pdf", "rb") as f: + file = 
BytesIO(f.read()) + file.name = "test_hello_multiple.pdf" + pdf_file = PdfFile.from_bytes(file) + + assert pdf_file.name == "test_hello_multiple.pdf" + assert len(pdf_file.docs) == 3 + assert pdf_file.docs[0].page_content == "Hello World 1" + assert pdf_file.docs[1].page_content == "Hello World 2" + assert pdf_file.docs[2].page_content == "Hello World 3" + + assert pdf_file.docs[0].metadata["page"] == 1 + assert pdf_file.docs[1].metadata["page"] == 2 + assert pdf_file.docs[2].metadata["page"] == 3 + + +def test_txt_file(): + with open(SAMPLE_ROOT / "test_hello.txt", "rb") as f: + file = BytesIO(f.read()) + file.name = "test.txt" + txt_file = TxtFile.from_bytes(file) + + assert txt_file.name == "test.txt" + assert len(txt_file.docs) == 1 + assert txt_file.docs[0].page_content == "Hello World" + + +def test_read_file(): + # Test the `read_file` function with each file type + for ext, FileClass in [(".docx", DocxFile), (".pdf", PdfFile), (".txt", TxtFile)]: + with open(SAMPLE_ROOT / f"test_hello{ext}", "rb") as f: + file = BytesIO(f.read()) + file.name = f"test_hello{ext}" + file_obj = read_file(file) + + assert isinstance(file_obj, FileClass) + assert file_obj.name == f"test_hello{ext}" + assert len(file_obj.docs) == 1 + assert file_obj.docs[0].page_content == "Hello World" + + +def test_read_file_not_implemented(): + file = BytesIO(b"Hello World") + file.name = "test.unknown" + with pytest.raises(NotImplementedError): + read_file(file) + + +def test_file_copy(): + # Create a Document and FakeFile instance + document = Document(page_content="test content", metadata={"page": "1"}) + file = FakeFile("test_file", "1234", {"author": "test"}, [document]) + + # Create a copy of the file + file_copy = file.copy() + + # Check that the original and copy are distinct objects + assert file is not file_copy + + # Check that the copy has the same attributes as the original + assert file.name == file_copy.name + assert file.id == file_copy.id + + # Check the mutable attributes were deeply copied + assert file.metadata == file_copy.metadata + assert file.metadata is not file_copy.metadata + + # Check the documents were deeply copied + assert file.docs == file_copy.docs + assert file.docs is not file_copy.docs + + # Check individual documents are not the same objects + assert file.docs[0] is not file_copy.docs[0] + + # Check the documents have the same attributes + assert file.docs[0].page_content == file_copy.docs[0].page_content + assert file.docs[0].metadata == file_copy.docs[0].metadata + + +def test_strip_consecutive_newlines(): + # Test with multiple consecutive newlines + text = "\n\n\n" + expected = "\n" + assert strip_consecutive_newlines(text) == expected + + # Test with newlines and spaces + text = "\n \n \n" + expected = "\n" + assert strip_consecutive_newlines(text) == expected + + # Test with newlines and tabs + text = "\n\t\n\t\n" + expected = "\n" + assert strip_consecutive_newlines(text) == expected + + # Test with mixed whitespace characters + text = "\n \t\n \t \n" + expected = "\n" + assert strip_consecutive_newlines(text) == expected + + # Test with no consecutive newlines + text = "\nHello\nWorld\n" + expected = "\nHello\nWorld\n" + assert strip_consecutive_newlines(text) == expected diff --git a/tests/unit_tests/test_qa.py b/tests/unit_tests/test_qa.py new file mode 100644 index 00000000..0b7d153e --- /dev/null +++ b/tests/unit_tests/test_qa.py @@ -0,0 +1,42 @@ +from langchain.docstore.document import Document +from knowledge_gpt.core.qa import get_sources +from 
knowledge_gpt.core.embedding import FolderIndex + +from typing import List +from .fake_file import FakeFile +from knowledge_gpt.core.parsing import File + +from .fake_vector_store import FakeVectorStore + + +def test_getting_sources_from_answer(): + """Test that we can get the sources from an answer.""" + files: List[File] = [ + FakeFile( + name="file1", + id="1", + docs=[ + Document(page_content="1", metadata={"source": "1"}), + Document(page_content="2", metadata={"source": "2"}), + ], + ), + FakeFile( + name="file2", + id="2", + docs=[ + Document(page_content="3", metadata={"source": "3"}), + Document(page_content="4", metadata={"source": "4"}), + ], + ), + ] + folder_index = FolderIndex(files=files, index=FakeVectorStore()) + + answer = "This is the answer. SOURCES: 1, 2, 3, 4" + + sources = get_sources(answer, folder_index) + + assert len(sources) == 4 + assert sources[0].metadata["source"] == "1" + assert sources[1].metadata["source"] == "2" + assert sources[2].metadata["source"] == "3" + assert sources[3].metadata["source"] == "4"
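
For reviewers who want to exercise the refactored pipeline outside Streamlit, a minimal usage sketch follows. It is not part of the diff: the sample file path, the query string, and the environment-variable lookup are illustrative assumptions, while the function calls and parameter values mirror knowledge_gpt/main.py and the new core modules above. It assumes this branch is importable and that an OpenAI API key is available.

import os
from io import BytesIO

from knowledge_gpt.core.parsing import read_file
from knowledge_gpt.core.chunking import chunk_file
from knowledge_gpt.core.embedding import embed_files
from knowledge_gpt.core.qa import query_folder

# Load a sample document from resources/ as an in-memory file.
# read_file dispatches on the file extension, so .name must be set
# (the unit tests in this diff do the same with BytesIO).
with open("resources/paul_graham_essay.pdf", "rb") as f:
    data = BytesIO(f.read())
data.name = "paul_graham_essay.pdf"

file = read_file(data)  # returns a PdfFile with one Document per page
chunked_file = chunk_file(file, chunk_size=300, chunk_overlap=0)

openai_api_key = os.environ["OPENAI_API_KEY"]  # assumed to be set

folder_index = embed_files(
    files=[chunked_file],
    embedding="openai",
    vector_store="faiss",
    openai_api_key=openai_api_key,  # forwarded to OpenAIEmbeddings
)

result = query_folder(
    folder_index=folder_index,
    query="What did the author work on before college?",  # illustrative query
    return_all=False,
    openai_api_key=openai_api_key,  # forwarded to ChatOpenAI
    temperature=0,
)

print(result.answer)
for source in result.sources:
    print(source.metadata["source"], "-", source.page_content[:80])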