Skip to content

Commit

Permalink
Refactoring to LCEL.
Browse files Browse the repository at this point in the history
  • Loading branch information
Marc Fabian Mezger committed May 23, 2024
1 parent ac3525e commit 690f390
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 42 deletions.
7 changes: 2 additions & 5 deletions agent/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@
QAResponse,
SearchResponse,
)
from agent.utils.vdb import (
from agent.utils.utility import (
combine_text_from_list,
create_tmp_folder,
initialize_aleph_alpha_vector_db,
initialize_gpt4all_vector_db,
initialize_open_ai_vector_db,
load_vec_db_conn,
validate_token,
)
from agent.utils.vdb import initialize_aleph_alpha_vector_db, initialize_gpt4all_vector_db, initialize_open_ai_vector_db, load_vec_db_conn

# add file logger for loguru
# logger.add("logs/file_{time}.log", backtrace=False, diagnose=False)
Expand Down
8 changes: 2 additions & 6 deletions agent/backend/aleph_alpha_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@
RAGRequest,
SearchRequest,
)
from agent.utils.utility import (
convert_qdrant_result_to_retrieval_results,
generate_collection_aleph_alpha,
generate_prompt,
)
from agent.utils.vdb import init_vdb
from agent.utils.utility import convert_qdrant_result_to_retrieval_results, generate_prompt
from agent.utils.vdb import generate_collection_aleph_alpha, init_vdb

nltk.download("punkt") # This needs to be installed for the tokenizer to work.
load_dotenv()
Expand Down
62 changes: 42 additions & 20 deletions agent/backend/gpt4all_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
RAGRequest,
SearchRequest,
)
from agent.utils.utility import (
convert_qdrant_result_to_retrieval_results,
generate_collection_gpt4all,
generate_prompt,
)
from agent.utils.vdb import init_vdb
from agent.utils.utility import convert_qdrant_result_to_retrieval_results, generate_prompt
from agent.utils.vdb import generate_collection_gpt4all, init_vdb

# nltk.download("punkt")

load_dotenv()

Expand All @@ -44,6 +42,22 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str | None) ->
embedding = GPT4AllEmbeddings()
self.vector_db = init_vdb(cfg=self.cfg, collection_name=collection_name, embedding=embedding)

# create retriever from the vector database

# query = "Was ist Attention?"
# results = retriever.invoke(query=query)
# print(results)
# self.chain =

def create_search_chain(self, search_kwargs: dict[str, any] | None = None):
if search_kwargs is None:
search_kwargs = {}
return self.vector_db.as_retriever(search_kwargs=search_kwargs)

def create_rag_chain(self, search_chain):
    """Compose the retriever with a GPT4All completion model into a RAG chain.

    NOTE(review): the retriever output is piped *directly* into the LLM with
    no prompt template or document formatting — unlike the OpenAI variant,
    which formats documents into a context prompt first. Confirm this raw
    piping is intended and that the LLM accepts a list of Documents as input.

    Args:
        search_chain: A retriever runnable (e.g. from create_search_chain).

    Returns:
        A runnable: retriever piped into a GPT4All LLM.
    """
    llm = GPT4All(self.cfg.gpt4all_completion.completion_model)
    return search_chain | llm

def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database."""
generate_collection_gpt4all(self.vector_db.client, name)
Expand Down Expand Up @@ -169,23 +183,31 @@ def rag(self, rag_request: RAGRequest, search: SearchRequest, filtering: Filteri


if __name__ == "__main__":
    # Smoke-test driver for GPT4AllService: build a retriever, run a query.
    query = "Was ist Attention?"

    gpt4all_service = GPT4AllService(collection_name="gpt4all", token="")

    # One-time ingestion; uncomment to (re)build the collection.
    # gpt4all_service.embed_documents(directory="tests/resources/")

    retriever = gpt4all_service.create_search_chain(search_kwargs={"k": 3})

    # Fix: a stray trailing comma previously wrapped this call in a 1-tuple
    # -- `results = (retriever.invoke(query),)` -- which was almost certainly
    # unintended. Pass config={'callbacks': [ConsoleCallbackHandler()]} to
    # trace the run.
    results = retriever.invoke(query)
    logger.info(f"Documents: {results}")

    # Chain is built but not invoked here: invoking would download/load the
    # local GPT4All model, which is too heavy for a quick smoke test.
    rag_chain = gpt4all_service.create_rag_chain(search_chain=retriever)
    logger.info(f"RAG chain ready: {rag_chain!r}")
57 changes: 46 additions & 11 deletions agent/backend/open_ai_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@
from langchain.docstore.document import Document
from langchain.text_splitter import NLTKTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from loguru import logger
from omegaconf import DictConfig
from ultra_simple_config import load_config

from agent.backend.LLMBase import LLMBase
from agent.data_model.request_data_model import Filtering, RAGRequest, SearchRequest
from agent.utils.utility import generate_collection_openai, generate_prompt
from agent.utils.vdb import init_vdb
from agent.utils.utility import generate_prompt
from agent.utils.vdb import generate_collection_openai, init_vdb

load_dotenv()

Expand Down Expand Up @@ -57,6 +61,26 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:

self.vector_db = init_vdb(self.cfg, self.collection_name, embedding=embedding)

def create_search_chain(self, search_kwargs: dict[str, any] | None = None):
if search_kwargs is None:
search_kwargs = {}
return self.vector_db.as_retriever(search_kwargs=search_kwargs)

def create_rag_chain(self, search_chain):
    """Build an LCEL RAG chain: retrieve -> format -> prompt -> LLM -> str.

    Args:
        search_chain: A retriever runnable (e.g. from create_search_chain)
            mapping a query string to a list of Documents.

    Returns:
        A runnable that takes a question string and returns the model's
        answer as a plain string (via StrOutputParser).
    """
    template = """Answer the question based only on the following context:
    {context}
    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template=template)

    def format_docs(docs):
        # Collapse the retrieved Documents into one context string.
        return "\n\n".join(doc.page_content for doc in docs)

    # The input question flows through unchanged (RunnablePassthrough) while
    # the retriever output is formatted into the prompt's {context} slot.
    return (
        {"context": search_chain | format_docs, "question": RunnablePassthrough()} | prompt | ChatOpenAI(model=self.cfg.openai_completion.model) | StrOutputParser()
    )

def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database.
Expand Down Expand Up @@ -206,6 +230,7 @@ def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> t

if __name__ == "__main__":
token = os.getenv("OPENAI_API_KEY")
query = "Was ist Attention?"
logger.info(f"Token: {token}")

from agent.data_model.request_data_model import Filtering, SearchRequest
Expand All @@ -216,14 +241,24 @@ def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> t

# Smoke-test driver for OpenAIService: build a retriever, run the RAG chain.
openai_service = OpenAIService(collection_name="openai", token=token)

# One-time ingestion; uncomment to (re)build the collection.
# openai_service.embed_documents(directory="tests/resources/")

retriever = openai_service.create_search_chain(search_kwargs={"k": 3})

# Fix: a stray trailing comma previously wrapped this call in a 1-tuple:
# `results = (retriever.invoke(query),)`. Pass
# config={'callbacks': [ConsoleCallbackHandler()]} to trace the run.
results = retriever.invoke(query)
logger.info(f"Documents: {results}")

rag_chain = openai_service.create_rag_chain(search_chain=retriever)

answer = rag_chain.invoke(query)

# Fix: was `logger.info("Answer: {answer}")` — missing the f-prefix, so the
# literal text "{answer}" was logged instead of the value.
logger.info(f"Answer: {answer}")

0 comments on commit 690f390

Please sign in to comment.