Merge branch 'update' of github.com:mfmezger/conversational-agent-langchain into update
mfmezger committed May 23, 2024
2 parents fff49e7 + 690f390 commit d4ef0da
Showing 10 changed files with 169 additions and 292 deletions.
7 changes: 2 additions & 5 deletions agent/api.py
@@ -27,15 +27,12 @@
QAResponse,
SearchResponse,
)
from agent.utils.vdb import (
from agent.utils.utility import (
combine_text_from_list,
create_tmp_folder,
initialize_aleph_alpha_vector_db,
initialize_gpt4all_vector_db,
initialize_open_ai_vector_db,
load_vec_db_conn,
validate_token,
)
from agent.utils.vdb import initialize_aleph_alpha_vector_db, initialize_gpt4all_vector_db, initialize_open_ai_vector_db, load_vec_db_conn

# add file logger for loguru
# logger.add("logs/file_{time}.log", backtrace=False, diagnose=False)
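Net effect of the api.py change: helpers that touch only text and tokens (combine_text_from_list, create_tmp_folder, validate_token) now come from agent.utils.utility, while everything that opens a vector-database connection stays in agent.utils.vdb. The bodies of those vdb helpers are not part of this commit; below is a minimal sketch of the surface the new import assumes, where the cfg.qdrant.* field names are guesses rather than values shown in the diff.

# Hypothetical sketch of the agent.utils.vdb connection helpers assumed by
# the import above; the cfg.qdrant.* fields are assumptions.
from langchain_community.vectorstores import Qdrant
from langchain_core.embeddings import Embeddings
from omegaconf import DictConfig
from qdrant_client import QdrantClient


def load_vec_db_conn(cfg: DictConfig) -> QdrantClient:
    """Open a client connection to the Qdrant instance named in the config."""
    return QdrantClient(url=cfg.qdrant.url, port=cfg.qdrant.port, prefer_grpc=cfg.qdrant.prefer_grpc)


def init_vdb(cfg: DictConfig, collection_name: str, embedding: Embeddings) -> Qdrant:
    """Wrap an existing Qdrant collection in a LangChain vector store."""
    return Qdrant(client=load_vec_db_conn(cfg), collection_name=collection_name, embeddings=embedding)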
44 changes: 16 additions & 28 deletions agent/backend/aleph_alpha_service.py
@@ -17,7 +17,6 @@
from langchain.text_splitter import NLTKTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
from langchain_community.embeddings import AlephAlphaAsymmetricSemanticEmbedding
from langchain_community.vectorstores import Qdrant
from loguru import logger
from omegaconf import DictConfig
from ultra_simple_config import load_config
@@ -28,12 +27,8 @@
RAGRequest,
SearchRequest,
)
from agent.utils.utility import (
convert_qdrant_result_to_retrieval_results,
generate_collection_aleph_alpha,
generate_prompt,
)
from agent.utils.vdb import init_vdb
from agent.utils.utility import convert_qdrant_result_to_retrieval_results, generate_prompt
from agent.utils.vdb import generate_collection_aleph_alpha, init_vdb

nltk.download("punkt") # This needs to be installed for the tokenizer to work.
load_dotenv()
@@ -67,7 +62,13 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:
else:
self.collection_name = self.cfg.qdrant.collection_name_aa

self.vector_db = self.get_db_connection(self.collection_name)
embedding = AlephAlphaAsymmetricSemanticEmbedding(
model=self.cfg.aleph_alpha_embeddings.model_name,
aleph_alpha_api_key=self.aleph_alpha_token,
normalize=self.cfg.aleph_alpha_embeddings.normalize,
compress_to_size=self.cfg.aleph_alpha_embeddings.compress_to_size,
)
self.vector_db = init_vdb(cfg=self.cfg, collection_name=collection_name, embedding=embedding)

def get_tokenizer(self) -> None:
"""Initialize the tokenizer."""
@@ -84,30 +85,11 @@ def count_tokens(self, text: str) -> int:
Returns:
-------
int: Number of tokens.
"""
tokens = self.tokenizer.encode(text)
return len(tokens)

def get_db_connection(self, collection_name: str) -> Qdrant:
"""Initializes a connection to the Qdrant DB.
Args:
----
collection_name (str): The name of the collection in the Qdrant DB.
Returns:
-------
Qdrant: The Qdrant DB connection.
"""
embedding = AlephAlphaAsymmetricSemanticEmbedding(
model=self.cfg.aleph_alpha_embeddings.model_name,
aleph_alpha_api_key=self.aleph_alpha_token,
normalize=self.cfg.aleph_alpha_embeddings.normalize,
compress_to_size=self.cfg.aleph_alpha_embeddings.compress_to_size,
)

return init_vdb(self.cfg, collection_name, embedding)

def create_collection(self, name: str) -> bool:
"""Create a new collection in the Qdrant DB.
@@ -130,6 +112,7 @@ def summarize_text(self, text: str) -> str:
Returns:
-------
str: The summary of the text.
"""
# TODO: rewrite because deprecated.
client = Client(token=self.aleph_alpha_token)
@@ -154,6 +137,7 @@ def generate(self, text: str) -> str:
Raises:
------
ValueError: If the text or token is None or empty, or if the response or completion is empty.
"""
if not text:
msg = "Text cannot be None or empty."
@@ -196,6 +180,7 @@ def embed_documents(self, directory: str, file_ending: str = ".pdf") -> None:
Returns:
-------
None
"""
if file_ending == ".pdf":
loader = DirectoryLoader(directory, glob="*" + file_ending, loader_cls=PyPDFium2Loader)
@@ -235,6 +220,7 @@ def search(self, search: SearchRequest, filtering: Filtering) -> list[tuple[Lang
Returns:
-------
List[Tuple[Document, float]]: A list of tuples containing the documents and their similarity scores.
"""
docs = self.vector_db.similarity_search_with_score(query=search.query, k=search.amount, score_threshold=filtering.threshold)
logger.info(f"SUCCESS: {len(docs)} Documents found.")
@@ -253,6 +239,7 @@ def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> t
Returns:
-------
Tuple[str, str, List[RetrievalResults]]: The answer, the prompt and the metadata.
"""
documents = self.search(search=search, filtering=filtering)
if search.amount == 0:
@@ -323,6 +310,7 @@ def process_documents_aleph_alpha(self, folder: str, processing_type: str) -> li
Raises:
------
ValueError: If the type is not one of 'qa', 'summarization', or 'invoice'.
"""
# load the documents
loader = DirectoryLoader(folder, glob="*.pdf", loader_cls=PyPDFium2Loader)
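Two things move in this file: the Qdrant connection logic collapses from a separate get_db_connection method into a direct init_vdb call in __init__, and generate_collection_aleph_alpha is now imported from agent.utils.vdb rather than agent.utils.utility. The commit shows only the import move, not the function body; the sketch below is a plausible reconstruction, with the signature and the recreate_collection call as assumptions.

# Hypothetical reconstruction of the relocated helper; only the import move
# is shown in the commit, so everything here is an assumption.
from qdrant_client import QdrantClient
from qdrant_client.http import models


def generate_collection_aleph_alpha(qdrant_client: QdrantClient, collection_name: str, embeddings_size: int) -> None:
    """Create (or reset) a Qdrant collection sized for Aleph Alpha embeddings."""
    qdrant_client.recreate_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
    )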
33 changes: 33 additions & 0 deletions agent/backend/cohere_service.py
@@ -1 +1,34 @@
"""Cohere Backend."""
from dotenv import load_dotenv

from agent.backend.LLMBase import LLMBackend, LLMBase
from agent.data_model.request_data_model import (
Filtering,
RAGRequest,
SearchRequest,
)

load_dotenv()


class CohereService(LLMBase):

"""Wrapper for cohere llms."""

def embed_documents(self, directory: str, llm_backend: LLMBackend) -> None:
"""Embedd new docments in the Qdrant DB."""

def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database."""

def search(self, search: SearchRequest, filtering: Filtering) -> list:
"""Searches the documents in the Qdrant DB with semantic search."""

def generate(self, prompt: str) -> str:
"""Generate text from a prompt."""

def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> tuple:
"""Retrieval Augmented Generation."""

def summarize_text(self, text: str) -> str:
"""Summarize text."""
75 changes: 40 additions & 35 deletions agent/backend/gpt4all_service.py
@@ -5,7 +5,6 @@
from langchain.text_splitter import NLTKTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores.qdrant import Qdrant
from loguru import logger
from omegaconf import DictConfig
from ultra_simple_config import load_config
@@ -17,12 +16,10 @@
RAGRequest,
SearchRequest,
)
from agent.utils.utility import (
convert_qdrant_result_to_retrieval_results,
generate_collection_gpt4all,
generate_prompt,
)
from agent.utils.vdb import init_vdb
from agent.utils.utility import convert_qdrant_result_to_retrieval_results, generate_prompt
from agent.utils.vdb import generate_collection_gpt4all, init_vdb

# nltk.download("punkt")

load_dotenv()

@@ -42,24 +39,24 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str | None) ->
else:
self.collection_name = self.cfg.qdrant.collection_name_gpt4all

        self.vector_db = self.get_db_connection()
        embedding = GPT4AllEmbeddings()
        self.vector_db = init_vdb(cfg=self.cfg, collection_name=collection_name, embedding=embedding)

    def get_db_connection(self) -> Qdrant:
        """Initializes a connection to the Qdrant DB.

        Args:
        ----
            cfg (DictConfig): The configuration file loaded via OmegaConf.
            aleph_alpha_token (str): The Aleph Alpha API token.

        Returns:
        -------
            Qdrant: The Qdrant DB connection.

        """
        embedding = GPT4AllEmbeddings()

        return init_vdb(self.cfg, self.collection_name, embedding)

    # create retriever from the vector database
    # query = "Was ist Attention?"
    # results = retriever.invoke(query=query)
    # print(results)
    # self.chain =

    def create_search_chain(self, search_kwargs: dict[str, any] | None = None):
        """Build a retriever over the vector database, forwarding any search kwargs."""
        if search_kwargs is None:
            search_kwargs = {}
        return self.vector_db.as_retriever(search_kwargs=search_kwargs)

    def create_rag_chain(self, search_chain):
        """Compose the search chain with a GPT4All LLM into a runnable chain."""
        llm = GPT4All(self.cfg.gpt4all_completion.completion_model)
        return search_chain | llm

def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database."""
@@ -186,23 +183,31 @@ def rag(self, rag_request: RAGRequest, search: SearchRequest, filtering: Filteri


if __name__ == "__main__":
query = "Was ist Attention?"

gpt4all_service = GPT4AllService(collection_name="gpt4all", token="")

gpt4all_service.embed_documents(directory="tests/resources/")
# gpt4all_service.embed_documents(directory="tests/resources/")

retriever = gpt4all_service.create_search_chain(search_kwargs={"k": 3})

    results = retriever.invoke(query)  # config={'callbacks': [ConsoleCallbackHandler()]}

rag_chain = gpt4all_service.create_rag_chain(search_chain=retriever)

docs = gpt4all_service.search(SearchRequest(query="Was ist Attention?", amount=3), Filtering(threshold=0.0, collection_name="gpt4all"))
# docs = gpt4all_service.search(SearchRequest(query, amount=3), Filtering(threshold=0.0, collection_name="gpt4all"))

logger.info(f"Documents: {docs}")
# logger.info(f"Documents: {docs}")

answer, prompt, meta_data = gpt4all_service.rag(
RAGRequest(language="detect", history={}),
SearchRequest(
query="Was ist Attention?",
amount=3,
),
Filtering(threshold=0.0, collection_name="gpt4all"),
)
# answer, prompt, meta_data = gpt4all_service.rag(
# RAGRequest(language="detect", history={}),
# SearchRequest(
# query=query,
# amount=3,
# ),
# Filtering(threshold=0.0, collection_name="gpt4all"),
# )

logger.info(f"Answer: {answer}")
logger.info(f"Prompt: {prompt}")
logger.info(f"Metadata: {meta_data}")
# logger.info(f"Answer: {answer}")
# logger.info(f"Prompt: {prompt}")
# logger.info(f"Metadata: {meta_data}")
