create more advanced chains with filtering and score returns.
mfmezger committed Jun 1, 2024
1 parent 6e06173 commit cf7f04a
Showing 6 changed files with 64 additions and 96 deletions.
12 changes: 6 additions & 6 deletions agent/backend/LLMBase.py
@@ -1,7 +1,7 @@
"""Strategy Pattern."""
from abc import ABC, abstractmethod

-from agent.data_model.request_data_model import Filtering, LLMBackend, RAGRequest, SearchParams
+from agent.data_model.request_data_model import LLMBackend, RAGRequest, SearchParams


class LLMBase(ABC):
@@ -23,15 +23,15 @@ def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database."""

@abstractmethod
-def search(self, search: SearchParams, filtering: Filtering) -> list:
+def search(self, search: SearchParams) -> list:
"""Searches the documents in the Qdrant DB with semantic search."""

-@abstractmethod
-def generate(self, prompt: str) -> str:
-    """Generate text from a prompt."""
+# @abstractmethod
+# def generate(self, prompt: str) -> str:
+#     """Generate text from a prompt."""

@abstractmethod
-def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
+def rag(self, rag: RAGRequest, search: SearchParams) -> tuple:
"""Retrieval Augmented Generation."""

@abstractmethod
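With Filtering dropped from the interface, callers now pass everything through SearchParams; per the data-model change later in this diff, the filter and score_threshold fields live on that model. A minimal caller sketch under the new signatures (the field values here are illustrative, not from the commit):

from agent.data_model.request_data_model import LLMBackend, RAGRequest, SearchParams

# Filtering options now ride along on SearchParams itself.
params = SearchParams(
    query="Was ist Attention?",
    k=3,                                 # number of hits to return
    score_threshold=0.5,                 # drop weak matches (illustrative value)
    filter={"source": "attention.pdf"},  # metadata filter (illustrative value)
)

# Any LLMBase implementation is driven the same way:
# chain = backend.rag(rag=RAGRequest(), search=params)
# answer = chain.invoke(params.query)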
37 changes: 30 additions & 7 deletions agent/backend/cohere_service.py
@@ -4,9 +4,11 @@
from dotenv import load_dotenv
from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
+from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.runnables import RunnablePassthrough
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import RunnableParallel, RunnablePassthrough, chain
from langchain_text_splitters import NLTKTextSplitter
from loguru import logger
from omegaconf import DictConfig
@@ -18,7 +20,7 @@
SearchParams,
)
from agent.utils.utility import extract_text_from_langchain_documents, load_prompt_template
-from agent.utils.vdb import init_vdb
+from agent.utils.vdb import generate_collection_cohere, init_vdb

load_dotenv()

@@ -85,18 +87,35 @@ def embed_documents(self, directory: str, file_ending: str = ".pdf") -> None:

def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database."""
+generate_collection_cohere(self.cfg, name)
+return True

-def search(self, search: SearchParams) -> list:
+def search(self, search: SearchParams) -> BaseRetriever:
"""Searches the documents in the Qdrant DB with semantic search."""
search = dict(search)
search.pop("query")

-return self.vector_db.as_retriever(search_kwargs=search)
+@chain
+def retriever_with_score(query: str) -> list[Document]:
+    docs, scores = zip(
+        *self.vector_db.similarity_search_with_score(query, k=search["k"], filter=search["filter"], score_threshold=search["score_threshold"]), strict=False
+    )
+    for doc, score in zip(docs, scores, strict=False):
+        doc.metadata["score"] = score
+
+    return docs
+
+return retriever_with_score

def rag(self, rag: RAGRequest, search: SearchParams) -> tuple:
"""Retrieval Augmented Generation."""
search_chain = self.search(search=search)
return {"context": search_chain | extract_text_from_langchain_documents, "question": RunnablePassthrough()} | self.prompt | ChatCohere() | StrOutputParser()

rag_chain_from_docs = (
RunnablePassthrough.assign(context=(lambda x: extract_text_from_langchain_documents(x["context"]))) | self.prompt | ChatCohere() | StrOutputParser()
)

return RunnableParallel({"context": search_chain, "question": RunnablePassthrough()}).assign(answer=rag_chain_from_docs)

def summarize_text(self, text: str) -> str:
"""Summarize text."""
@@ -109,8 +128,12 @@ def summarize_text(self, text: str) -> str:

# cohere_service.embed_documents(directory="tests/resources/")

-search_chain = cohere_service.search(search=SearchParams(query=query, amount=3))
+# search_chain = cohere_service.search(search=SearchParams(query=query, amount=3))

-chain = cohere_service.generate(search_chain=search_chain)
+# search_results = search_chain.invoke(query)

+chain = cohere_service.rag(rag=RAGRequest(), search=SearchParams(query=query, amount=3))

answer = chain.invoke(query)

logger.info(answer)
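The score-returning retriever above is the core of this change: similarity_search_with_score yields (Document, score) pairs, and the @chain-decorated function copies each score into the document's metadata so it survives the rest of the pipeline. Because rag() then wraps the retriever in a RunnableParallel and .assign()s the answer, invoking the chain yields a dict of roughly {"context": [...], "question": ..., "answer": ...} rather than a bare string, so callers can surface sources and scores next to the generated answer. A standalone sketch of the retriever pattern, assuming any LangChain vector store whose similarity_search_with_score accepts k, filter, and score_threshold (the helper name is hypothetical):

from langchain_core.documents import Document
from langchain_core.runnables import chain

def build_retriever_with_score(vector_db, k: int = 3, score_threshold: float = 0.0, search_filter: dict | None = None):
    """Hypothetical helper mirroring the inline retriever in search() above."""

    @chain
    def retriever_with_score(query: str) -> list[Document]:
        results = vector_db.similarity_search_with_score(
            query, k=k, filter=search_filter, score_threshold=score_threshold
        )
        for doc, score in results:
            doc.metadata["score"] = score  # keep the score attached to its document
        return [doc for doc, _ in results]

    return retriever_with_score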
30 changes: 2 additions & 28 deletions agent/backend/gpt4all_service.py
@@ -40,22 +40,6 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str | None) -> None:
embedding = GPT4AllEmbeddings()
self.vector_db = init_vdb(cfg=self.cfg, collection_name=collection_name, embedding=embedding)

-# create retriever from the vector database
-
-# query = "Was ist Attention?"
-# results = retriever.invoke(query=query)
-# print(results)
-# self.chain =
-
-def create_search_chain(self, search_kwargs: dict[str, any] | None = None):
-    if search_kwargs is None:
-        search_kwargs = {}
-    return self.vector_db.as_retriever(search_kwargs=search_kwargs)
-
-def create_rag_chain(self, search_chain):
-    llm = GPT4All(self.cfg.gpt4all_completion.completion_model)
-    return search_chain | llm

def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database."""
generate_collection_gpt4all(self.vector_db.client, name)
@@ -166,18 +150,8 @@ def rag(self, rag_request: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
Tuple[str, str, List[RetrievalResults]]: The answer, the prompt and the metadata.
"""
-documents = self.search(search=search, filtering=filtering)
-if search.amount == 0:
-    msg = "No documents found."
-    raise ValueError(msg)
-text = "\n".join([doc.document for doc in documents]) if search.amount > 1 else documents[0].document
-
-# TODO: Add the history to the prompt
-prompt = generate_prompt(prompt_name="gpt4all-completion.j2", text=text, query=search.query, language=rag_request.language)
-
-answer = self.generate(prompt)
-
-return answer, prompt, documents
+llm = GPT4All(self.cfg.gpt4all_completion.completion_model)
+return search_chain | llm


if __name__ == "__main__":
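The new rag() body pipes search_chain into a GPT4All model, but no search_chain is bound in the lines shown; presumably it should come from self.search(...) as in the other backends. A hedged sketch of a working composition under that assumption, with an explicit formatting step added so the LLM receives a string rather than a list of documents:

from langchain_community.llms import GPT4All
from langchain_core.runnables import RunnableLambda

def rag(self, rag_request, search):
    """Sketch only; assumes self.search() returns a retriever-like runnable."""
    search_chain = self.search(search=search)
    # Join retrieved documents into one prompt string before hitting the LLM.
    format_docs = RunnableLambda(lambda docs: "\n\n".join(d.page_content for d in docs))
    llm = GPT4All(model=self.cfg.gpt4all_completion.completion_model)
    return search_chain | format_docs | llm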
63 changes: 16 additions & 47 deletions agent/backend/open_ai_service.py
@@ -3,11 +3,11 @@

import openai
from dotenv import load_dotenv
-from langchain.docstore.document import Document
from langchain.text_splitter import NLTKTextSplitter
from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
@@ -16,7 +16,7 @@
from ultra_simple_config import load_config

from agent.backend.LLMBase import LLMBase
-from agent.data_model.request_data_model import Filtering, RAGRequest, SearchParams
+from agent.data_model.request_data_model import RAGRequest, SearchParams
from agent.utils.utility import generate_prompt
from agent.utils.vdb import generate_collection_openai, init_vdb

@@ -49,6 +49,9 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:
else:
self.collection_name = self.cfg.qdrant.collection_name_openai

+template = load_prompt_template(prompt_name="cohere_chat.j2", task="chat")
+self.prompt = ChatPromptTemplate.from_template(template=template, template_format="jinja2")

if self.cfg.openai_embeddings.azure:
embedding = AzureOpenAIEmbeddings(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
@@ -61,18 +64,6 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:

self.vector_db = init_vdb(self.cfg, self.collection_name, embedding=embedding)

-def create_search_chain(self, search_kwargs: dict[str, any] | None = None):
-    if search_kwargs is None:
-        search_kwargs = {}
-    return self.vector_db.as_retriever(search_kwargs=search_kwargs)
-
-def create_rag_chain(self, search_chain):
-    prompt = ChatPromptTemplate.from_template(template=template)
-
-    return (
-        {"context": search_chain | format_docs, "question": RunnablePassthrough()} | prompt | ChatOpenAI(model=self.cfg.openai_completion.model) | StrOutputParser()
-    )

def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database.
@@ -118,7 +109,7 @@ def embed_documents(self, directory: str, file_ending: str = ".pdf") -> None:

logger.info("SUCCESS: Texts embedded.")

-def search(self, search: SearchParams, filtering: Filtering) -> list[tuple[Document, float]]:
+def search(self, search: SearchParams) -> BaseRetriever:
"""Searches the documents in the Qdrant DB with a specific query.
Args:
@@ -132,9 +123,10 @@ def search(self, search: SearchParams, filtering: Filtering) -> list[tuple[Document, float]]:
containing a Document object and a float score.
"""
-docs = self.vector_db.similarity_search_with_score(search.query, k=search.amount, score_threshold=filtering.threshold, filter=filtering.filter)
-logger.info("SUCCESS: Documents found.")
-return docs
+search = dict(search)
+search.pop("query")
+
+return self.vector_db.as_retriever(search_kwargs=search)

def summarize_text(self, text: str) -> str:
"""Summarizes the given text using the OpenAI API.
@@ -193,7 +185,7 @@ def generate(self, prompt: str) -> str:

return response.choices[0].message.content

-def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
+def rag(self, rag: RAGRequest, search: SearchParams) -> tuple:
"""QA Function for OpenAI LLMs.
Args:
@@ -207,50 +199,27 @@ def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
tuple: answer, prompt, meta_data
"""
-documents = self.search(search=search, filtering=filtering)
-if len(documents) == 0:
-    msg = "No documents found."
-    raise ValueError(msg)
-text = "\n".join([doc[0].page_content for doc in documents]) if len(documents) > 1 else documents[0].document
-
-prompt = generate_prompt(prompt_name="openai-qa.j2", text=text, query=search.query, language=rag.language)
-
-answer = self.generate(prompt)
-
-return answer, prompt, documents
+search_chain = self.search(search=search)
+return (
+    {"context": search_chain | format_docs, "question": RunnablePassthrough()} | prompt | ChatOpenAI(model=self.cfg.openai_completion.model) | StrOutputParser()
+)


if __name__ == "__main__":
token = os.getenv("OPENAI_API_KEY")
query = "Was ist Attention?"
logger.info(f"Token: {token}")

-from agent.data_model.request_data_model import Filtering, SearchParams
+from agent.data_model.request_data_model import SearchParams

if not token:
msg = "OPENAI_API_KEY is not set."
raise ValueError(msg)

openai_service = OpenAIService(collection_name="openai", token=token)

# openai_service.embed_documents(directory="tests/resources/")

-retriever = openai_service.create_search_chain(search_kwargs={"k": 3})
-
-results = (retriever.invoke(query),)  # config={'callbacks': [ConsoleCallbackHandler()]})
-
-rag_chain = openai_service.create_rag_chain(search_chain=retriever)
-
-answer = rag_chain.invoke(query)
-
-logger.info("Answer: {answer}")

-# answer, prompt, meta_data = openai_service.rag(
-#     RAGRequest(language="detect", filter={}),
-#     SearchRequest(query="Was ist Attention", amount=3),
-#     Filtering(threshold=0.0, collection_name="openai"),
-# )
-
-# logger.info(f"Answer: {answer}")
-# logger.info(f"Prompt: {prompt}")
-# logger.info(f"Metadata: {meta_data}")
17 changes: 9 additions & 8 deletions agent/data_model/request_data_model.py
@@ -60,9 +60,18 @@ class SearchParams(BaseModel):
query: str = Field(..., title="Query", description="The search query.")
k: int = Field(3, title="Amount", description="The number of search results to return.")
score_threshold: float = Field(0.0, title="Threshold", description="The threshold to use for the search.")
+# TODO: renaming due to python keyword
+filter: dict | None = Field(None, title="Filter", description="Filter for the database search with metadata.")


+class RAGRequest(BaseModel):
+
+    """Request for the QA endpoint."""
+
+    language: Language | None = Field(Language.DETECT, title="Language", description="The language to use for the answer.")
+    history: dict[str, str] | None = Field([], title="History", description="A list of previous questions and answers to include in the context.")


class EmbeddTextRequest(BaseModel):

"""The request parameters for embedding text."""
@@ -84,14 +93,6 @@ class CustomPromptCompletion(BaseModel):
stop_sequences: list[str] = Field([], title="Stop Sequences", description="The stop sequences to use for the completion.")


-class RAGRequest(BaseModel):
-
-    """Request for the QA endpoint."""
-
-    language: Language | None = Field(Language.DETECT, title="Language", description="The language to use for the answer.")
-    history: dict[str, str] | None = Field([], title="History", description="A list of previous questions and answers to include in the context.")


class ExplainQARequest(BaseModel):

"""Request for the QA endpoint."""
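Both services convert SearchParams into retriever kwargs by dict()-ing the model and popping the query; pydantic models iterate as field/value pairs, so dict() works on them directly. A small sketch of that round trip with a stand-in model:

from pydantic import BaseModel, Field

# Stand-in mirroring the SearchParams fields above.
class SearchParams(BaseModel):
    query: str = Field(..., description="The search query.")
    k: int = Field(3, description="Number of results to return.")
    score_threshold: float = Field(0.0, description="Similarity cut-off.")
    filter: dict | None = Field(None, description="Metadata filter.")

params = SearchParams(query="Was ist Attention?", k=5)
kwargs = dict(params)   # {"query": ..., "k": 5, "score_threshold": 0.0, "filter": None}
kwargs.pop("query")     # the retriever receives the query at invoke time instead
# retriever = vector_db.as_retriever(search_kwargs=kwargs)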
1 change: 1 addition & 0 deletions agent/utils/utility.py
@@ -209,6 +209,7 @@ def create_tmp_folder() -> str:


def extract_text_from_langchain_documents(docs):
logger.info(f"Loaded {len(docs)} documents.")
return "\n\n".join(f"Context {i+1}:\n{doc.page_content}" for i, doc in enumerate(docs))


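For reference, extract_text_from_langchain_documents numbers each document into a single context string; an illustrative call (the documents here are made up):

from langchain_core.documents import Document

docs = [
    Document(page_content="Attention is all you need.", metadata={"score": 0.87}),
    Document(page_content="Transformers stack self-attention layers.", metadata={"score": 0.74}),
]
# extract_text_from_langchain_documents(docs) now logs "Loaded 2 documents."
# and returns:
# Context 1:
# Attention is all you need.
#
# Context 2:
# Transformers stack self-attention layers.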
