Showing 32 changed files with 769 additions and 254 deletions.
.gitignore
@@ -1,5 +1,5 @@
 # Local data
-data/local_data/
+resources/local_data/

 # Prototyping notebooks
 notebooks/
Binary file not shown.
File renamed without changes.
knowledge_gpt/core/caching.py
@@ -0,0 +1,33 @@
import streamlit as st
from streamlit.runtime.caching.hashing import HashFuncsDict

import knowledge_gpt.core.parsing as parsing
import knowledge_gpt.core.chunking as chunking
import knowledge_gpt.core.embedding as embedding
from knowledge_gpt.core.parsing import File


def file_hash_func(file: File) -> str:
    """Get a unique hash for a file"""
    return file.id


@st.cache_resource()
def bootstrap_caching():
    """Patch module functions with caching"""

    # Get all subtypes of File from the parsing module
    file_subtypes = [
        cls
        for cls in vars(parsing).values()
        if isinstance(cls, type) and issubclass(cls, File) and cls != File
    ]
    file_hash_funcs: HashFuncsDict = {cls: file_hash_func for cls in file_subtypes}

    parsing.read_file = st.cache_data(show_spinner=False)(parsing.read_file)
    chunking.chunk_file = st.cache_data(show_spinner=False, hash_funcs=file_hash_funcs)(
        chunking.chunk_file
    )
    embedding.embed_files = st.cache_data(
        show_spinner=False, hash_funcs=file_hash_funcs
    )(embedding.embed_files)
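For context, a minimal sketch of how this caching hook might be wired into a Streamlit entry point. The surrounding file-uploader flow is an assumption for illustration, not part of this commit. Note that bootstrap_caching() must run before the patched functions are first called, and callers should go through the modules so they see the patched attributes rather than stale direct imports.

import streamlit as st

from knowledge_gpt.core.caching import bootstrap_caching
import knowledge_gpt.core.parsing as parsing
import knowledge_gpt.core.chunking as chunking

# Patch read_file / chunk_file / embed_files with st.cache_data wrappers
# so Streamlit reruns reuse earlier results instead of re-parsing.
bootstrap_caching()

uploaded = st.file_uploader("Upload a document", type=["pdf", "docx", "txt"])
if uploaded:
    file = parsing.read_file(uploaded)  # cached on the uploaded data
    chunked = chunking.chunk_file(file, chunk_size=300)  # hashed via file.id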
knowledge_gpt/core/chunking.py
@@ -0,0 +1,38 @@
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from knowledge_gpt.core.parsing import File


def chunk_file(
    file: File, chunk_size: int, chunk_overlap: int = 0, model_name="gpt-3.5-turbo"
) -> File:
    """Chunks each document in a file into smaller documents
    according to the specified chunk size and overlap,
    where size is measured in tokens for the specified model.
    """

    # Split each document into chunks
    chunked_docs = []
    for doc in file.docs:
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            model_name=model_name,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

        chunks = text_splitter.split_text(doc.page_content)

        for i, chunk in enumerate(chunks):
            chunk_doc = Document(
                page_content=chunk,
                metadata={
                    "page": doc.metadata.get("page", 1),
                    "chunk": i + 1,
                    "source": f"{doc.metadata.get('page', 1)}-{i + 1}",
                },
            )
            chunked_docs.append(chunk_doc)

    chunked_file = file.copy()
    chunked_file.docs = chunked_docs
    return chunked_file
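A quick usage sketch for chunk_file; the 300-token chunk size and the pre-parsed file object are illustrative assumptions:

# Hypothetical call: split a parsed File into ~300-token chunks
chunked = chunk_file(file, chunk_size=300, chunk_overlap=0)
for d in chunked.docs:
    print(d.metadata["source"], len(d.page_content))  # e.g. "2-1 897"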
knowledge_gpt/core/embedding.py
@@ -0,0 +1,71 @@
from langchain.vectorstores import VectorStore
from knowledge_gpt.core.parsing import File
from langchain.vectorstores.faiss import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from typing import List, Type
from langchain.docstore.document import Document


class FolderIndex:
    """Index for a collection of files (a folder)"""

    def __init__(self, files: List[File], index: VectorStore):
        self.name: str = "default"
        self.files = files
        self.index: VectorStore = index

    @staticmethod
    def _combine_files(files: List[File]) -> List[Document]:
        """Combines all the documents in a list of files into a single list."""

        all_texts = []
        for file in files:
            for doc in file.docs:
                doc.metadata["file_name"] = file.name
                doc.metadata["file_id"] = file.id
                all_texts.append(doc)

        return all_texts

    @classmethod
    def from_files(
        cls, files: List[File], embeddings: Embeddings, vector_store: Type[VectorStore]
    ) -> "FolderIndex":
        """Creates an index from files."""

        all_docs = cls._combine_files(files)

        index = vector_store.from_documents(
            documents=all_docs,
            embedding=embeddings,
        )

        return cls(files=files, index=index)


def embed_files(
    files: List[File], embedding: str, vector_store: str, **kwargs
) -> FolderIndex:
    """Embeds a collection of files and stores them in a FolderIndex."""

    supported_embeddings: dict[str, Type[Embeddings]] = {
        "openai": OpenAIEmbeddings,
    }
    supported_vector_stores: dict[str, Type[VectorStore]] = {
        "faiss": FAISS,
    }

    if embedding in supported_embeddings:
        _embeddings = supported_embeddings[embedding](**kwargs)
    else:
        raise NotImplementedError(f"Embedding {embedding} not supported.")

    if vector_store in supported_vector_stores:
        _vector_store = supported_vector_stores[vector_store]
    else:
        raise NotImplementedError(f"Vector store {vector_store} not supported.")

    return FolderIndex.from_files(
        files=files, embeddings=_embeddings, vector_store=_vector_store
    )
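For illustration, a hedged sketch of calling embed_files. Extra keyword arguments are forwarded to the embedding class, so an OpenAI key can be passed through; openai_api_key is a real OpenAIEmbeddings parameter, but the value and the chunked_file input are placeholders:

folder_index = embed_files(
    files=[chunked_file],
    embedding="openai",
    vector_store="faiss",
    openai_api_key="sk-...",  # forwarded to OpenAIEmbeddings via **kwargs
)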
knowledge_gpt/core/parsing.py
@@ -0,0 +1,102 @@
from io import BytesIO
from typing import List, Any, Optional
import re

import docx2txt
from langchain.docstore.document import Document
from pypdf import PdfReader
from hashlib import md5

from abc import abstractmethod, ABC
from copy import deepcopy


class File(ABC):
    """Represents an uploaded file comprised of Documents"""

    def __init__(
        self,
        name: str,
        id: str,
        metadata: Optional[dict[str, Any]] = None,
        docs: Optional[List[Document]] = None,
    ):
        self.name = name
        self.id = id
        self.metadata = metadata or {}
        self.docs = docs or []

    @classmethod
    @abstractmethod
    def from_bytes(cls, file: BytesIO) -> "File":
        """Creates a File from a BytesIO object"""

    def __repr__(self) -> str:
        return (
            f"File(name={self.name}, id={self.id},"
            f" metadata={self.metadata}, docs={self.docs})"
        )

    def __str__(self) -> str:
        return f"File(name={self.name}, id={self.id}, metadata={self.metadata})"

    def copy(self) -> "File":
        """Create a deep copy of this File"""
        return self.__class__(
            name=self.name,
            id=self.id,
            metadata=deepcopy(self.metadata),
            docs=deepcopy(self.docs),
        )


def strip_consecutive_newlines(text: str) -> str:
    """Strips consecutive newlines from a string,
    possibly with whitespace in between
    """
    return re.sub(r"\s*\n\s*", "\n", text)


class DocxFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO) -> "DocxFile":
        text = docx2txt.process(file)
        text = strip_consecutive_newlines(text)
        doc = Document(page_content=text.strip())
        file.seek(0)  # rewind so the hash covers the full file contents
        return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])


class PdfFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO) -> "PdfFile":
        pdf = PdfReader(file)
        docs = []
        for i, page in enumerate(pdf.pages):
            text = page.extract_text()
            text = strip_consecutive_newlines(text)
            doc = Document(page_content=text.strip())
            doc.metadata["page"] = i + 1
            docs.append(doc)
        file.seek(0)  # rewind so the hash covers the full file contents
        return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=docs)


class TxtFile(File):
    @classmethod
    def from_bytes(cls, file: BytesIO) -> "TxtFile":
        text = file.read().decode("utf-8")
        text = strip_consecutive_newlines(text)
        file.seek(0)
        doc = Document(page_content=text.strip())
        return cls(name=file.name, id=md5(file.read()).hexdigest(), docs=[doc])


def read_file(file: BytesIO) -> File:
    """Reads an uploaded file and returns a File object"""
    if file.name.endswith(".docx"):
        return DocxFile.from_bytes(file)
    elif file.name.endswith(".pdf"):
        return PdfFile.from_bytes(file)
    elif file.name.endswith(".txt"):
        return TxtFile.from_bytes(file)
    else:
        raise NotImplementedError(f"File type not supported: {file.name}")
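A small usage sketch for read_file, building a named BytesIO by hand to stand in for Streamlit's UploadedFile (read_file dispatches on the stream's .name attribute):

from io import BytesIO

buf = BytesIO(b"hello\n\n\nworld")
buf.name = "notes.txt"  # dispatch happens on this extension
file = read_file(buf)
print(file)  # File(name=notes.txt, id=<md5 of contents>, metadata={})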
File renamed without changes.
knowledge_gpt/core/qa.py
@@ -0,0 +1,62 @@
from typing import Any, List
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from knowledge_gpt.core.prompts import STUFF_PROMPT
from langchain.docstore.document import Document
from langchain.chat_models import ChatOpenAI
from knowledge_gpt.core.embedding import FolderIndex
from pydantic import BaseModel


class AnswerWithSources(BaseModel):
    answer: str
    sources: List[Document]


def query_folder(
    query: str, folder_index: FolderIndex, return_all: bool = False, **model_kwargs: Any
) -> AnswerWithSources:
    """Queries a folder index for an answer.

    Args:
        query (str): The query to search for.
        folder_index (FolderIndex): The folder index to search.
        return_all (bool): Whether to return all the documents from the embedding or
            just the sources for the answer.
        **model_kwargs (Any): Keyword arguments for the model.

    Returns:
        AnswerWithSources: The answer and the source documents.
    """

    chain = load_qa_with_sources_chain(
        llm=ChatOpenAI(**model_kwargs),
        chain_type="stuff",
        prompt=STUFF_PROMPT,
    )

    relevant_docs = folder_index.index.similarity_search(query, k=5)
    result = chain(
        {"input_documents": relevant_docs, "question": query}, return_only_outputs=True
    )
    sources = relevant_docs

    if not return_all:
        sources = get_sources(result["output_text"], folder_index)

    answer = result["output_text"].split("SOURCES: ")[0]

    return AnswerWithSources(answer=answer, sources=sources)


def get_sources(answer: str, folder_index: FolderIndex) -> List[Document]:
    """Retrieves the docs that were used to generate the answer."""

    source_keys = answer.split("SOURCES: ")[-1].split(", ")

    source_docs = []
    for file in folder_index.files:
        for doc in file.docs:
            if doc.metadata["source"] in source_keys:
                source_docs.append(doc)

    return source_docs
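Finally, a hedged end-to-end sketch of query_folder; the model name and temperature are illustrative and are forwarded to ChatOpenAI via **model_kwargs:

result = query_folder(
    query="What is the main conclusion of the document?",
    folder_index=folder_index,
    model_name="gpt-3.5-turbo",  # passed through to ChatOpenAI
    temperature=0,
)
print(result.answer)
for doc in result.sources:
    print(doc.metadata["source"])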