Commit b6e519b

updating stuff.

mfmezger committed Jun 1, 2024
1 parent cf7f04a commit b6e519b
Showing 9 changed files with 250 additions and 170 deletions.
agent/backend/LLMBase.py (4 changes: 2 additions & 2 deletions)

@@ -23,15 +23,15 @@ def create_collection(self, name: str) -> bool:
         """Create a new collection in the Vector Database."""
 
     @abstractmethod
-    def search(self, search: SearchParams) -> list:
+    def create_search_chain(self, search: SearchParams) -> list:
         """Searches the documents in the Qdrant DB with semantic search."""
 
     # @abstractmethod
    # def generate(self, prompt: str) -> str:
    #     """Generate text from a prompt."""
 
     @abstractmethod
-    def rag(self, rag: RAGRequest, search: SearchParams) -> tuple:
+    def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
         """Retrieval Augmented Generation."""
 
     @abstractmethod
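Taken together, the two hunks above imply the following refactored interface. This is a sketch reconstructed from the diff, not the full file: members outside the shown context (such as embed_documents and summarize_text) are omitted, and the ABC scaffolding is assumed.

from abc import ABC, abstractmethod

from agent.data_model.request_data_model import RAGRequest, SearchParams


class LLMBase(ABC):
    """Sketch of the renamed interface; only the methods visible in the hunks are shown."""

    @abstractmethod
    def create_search_chain(self, search: SearchParams) -> list:
        """Searches the documents in the Qdrant DB with semantic search."""

    @abstractmethod
    def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
        """Retrieval Augmented Generation."""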
agent/backend/LLMStrategy.py (10 changes: 5 additions & 5 deletions)

@@ -5,7 +5,7 @@
 from agent.backend.gpt4all_service import GPT4AllService
 from agent.backend.LLMBase import LLMBase
 from agent.backend.open_ai_service import OpenAIService
-from agent.data_model.request_data_model import Filtering, LLMProvider, RAGRequest, SearchParams
+from agent.data_model.request_data_model import LLMProvider, RAGRequest, SearchParams
 
 
 class LLMStrategyFactory:
@@ -66,9 +66,9 @@ def change_strategy(self, strategy_type: str, token: str, collection_name: str)
         """Changes the strategy using the Factory."""
         self.llm = LLMStrategyFactory.get_strategy(strategy_type=strategy_type, token=token, collection_name=collection_name)
 
-    def search(self, search: SearchParams, filtering: Filtering) -> list:
+    def search(self, search: SearchParams) -> list:
         """Wrapper for the search."""
-        return self.llm.search(search=search, filtering=filtering)
+        return self.llm.create_search_chain(search=search)
 
     def embed_documents(self, directory: str, file_ending: str) -> None:
         """Wrapper for the Embedding of Documents."""
@@ -82,9 +82,9 @@ def generate(self, prompt: str) -> str:
         """Wrapper for the generation of text."""
         return self.llm.generate(prompt)
 
-    def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
+    def rag(self, rag: RAGRequest, search: SearchParams) -> tuple:
         """Wrapper for the RAG."""
-        return self.llm.rag(rag=rag, search=search, filtering=filtering)
+        return self.llm.create_rag_chain(rag=rag, search=search)
 
     def summarize_text(self, text: str) -> str:
         """Wrapper for the summarization of text."""
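A hypothetical usage sketch of the refactored wrappers: the context-class name LLMContext and the LLMProvider member are assumptions (their definitions sit outside the shown hunks), while the method signatures follow the diff. Note that search() and rag() now hand back runnables to be invoked by the caller rather than finished results.

from agent.data_model.request_data_model import LLMProvider, RAGRequest, SearchParams

# LLMContext is a hypothetical name for the wrapper class holding self.llm;
# LLMProvider.ALEPH_ALPHA is likewise an assumed enum member.
context = LLMContext(strategy_type=LLMProvider.ALEPH_ALPHA, token="", collection_name="demo")

# search() now delegates to create_search_chain() and returns a runnable retriever:
retriever = context.search(search=SearchParams(query="Was ist Attention?", amount=3))
docs = retriever.invoke("Was ist Attention?")

# rag() likewise delegates to create_rag_chain() and returns a composed chain:
rag_chain = context.rag(rag=RAGRequest(), search=SearchParams(query="Was ist Attention?", amount=3))
result = rag_chain.invoke("Was ist Attention?")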
agent/backend/aleph_alpha_service.py (94 changes: 39 additions & 55 deletions)

@@ -17,17 +17,20 @@
 from langchain.text_splitter import NLTKTextSplitter
 from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
 from langchain_community.embeddings import AlephAlphaAsymmetricSemanticEmbedding
+from langchain_community.llms import AlephAlpha
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnableParallel, RunnablePassthrough, chain
 from loguru import logger
 from omegaconf import DictConfig
 from ultra_simple_config import load_config
 
 from agent.backend.LLMBase import LLMBase
 from agent.data_model.request_data_model import (
-    Filtering,
     RAGRequest,
     SearchParams,
 )
-from agent.utils.utility import convert_qdrant_result_to_retrieval_results, generate_prompt
+from agent.utils.utility import extract_text_from_langchain_documents, generate_prompt, load_prompt_template
 from agent.utils.vdb import generate_collection_aleph_alpha, init_vdb
 
 nltk.download("punkt")  # This needs to be installed for the tokenizer to work.
@@ -47,13 +50,7 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:
         super().__init__(token=token, collection_name=collection_name)
         """Initialize the Aleph Alpha Service."""
         if token:
-            self.aleph_alpha_token = token
-        else:
-            self.aleph_alpha_token = os.getenv("ALEPH_ALPHA_API_KEY")
-
-        if not self.aleph_alpha_token:
-            msg = "API Token not provided!"
-            raise ValueError(msg)
+            os.environ["ALEPH_ALPHA_API_KEY"] = token
 
         self.cfg = cfg
 
@@ -68,6 +65,9 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str) -> None:
             normalize=self.cfg.aleph_alpha_embeddings.normalize,
             compress_to_size=self.cfg.aleph_alpha_embeddings.compress_to_size,
         )
+
+        template = load_prompt_template(prompt_name="aleph_alpha_chat.j2", task="chat")
+        self.prompt = ChatPromptTemplate.from_template(template=template, template_format="jinja2")
         self.vector_db = init_vdb(cfg=self.cfg, collection_name=collection_name, embedding=embedding)
 
     def get_tokenizer(self) -> None:
@@ -208,7 +208,7 @@ def embed_documents(self, directory: str, file_ending: str = ".pdf") -> None:
 
         logger.info("SUCCESS: Texts embedded.")
 
-    def search(self, search: SearchParams, filtering: Filtering) -> list[tuple[LangchainDocument, float]]:
+    def create_search_chain(self, search: SearchParams) -> list[tuple[LangchainDocument, float]]:
         """Searches the Aleph Alpha service for similar documents.
 
         Args:
@@ -222,36 +222,35 @@ def search(self, search: SearchParams, filtering: Filtering) -> list[tuple[LangchainDocument, float]]:
             List[Tuple[Document, float]]: A list of tuples containing the documents and their similarity scores.
 
         """
-        docs = self.vector_db.similarity_search_with_score(query=search.query, k=search.amount, score_threshold=filtering.threshold)
-        logger.info(f"SUCCESS: {len(docs)} Documents found.")
-
-        return convert_qdrant_result_to_retrieval_results(docs)
-
-    def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
-        """QA takes a list of documents and returns a list of answers.
-
-        Args:
-        ----
-            rag (RAGRequest): The request for the RAG endpoint.
-            search (SearchRequest): The search request object.
-            filtering (Filtering): The filtering object.
-
-        Returns:
-        -------
-            Tuple[str, str, List[RetrievalResults]]: The answer, the prompt and the metadata.
-
-        """
-        documents = self.search(search=search, filtering=filtering)
-        if search.amount == 0:
-            msg = "No documents found."
-            raise ValueError(msg)
-        text = "\n".join([doc.document for doc in documents]) if len(documents) > 1 else documents[0].document
-
-        prompt = generate_prompt(prompt_name="aleph_alpha_qa.j2", text=text, query=search.query, language=rag.language)
-
-        answer = self.generate(prompt)
-
-        return answer, prompt, documents
+        @chain
+        def retriever_with_score(query: str) -> list[Document]:
+            docs, scores = zip(
+                *self.vector_db.similarity_search_with_score(query, k=search.k, filter=search.filter, score_threshold=search.score_threshold), strict=False
+            )
+            for doc, score in zip(docs, scores, strict=False):
+                doc.metadata["score"] = score
+
+            return docs
+
+        return retriever_with_score
+
+    def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
+        """Retrieval Augmented Generation."""
+        search_chain = self.create_search_chain(search=search)
+        llm = AlephAlpha(
+            model=self.cfg.aleph_alpha_completion.model,
+            maximum_tokens=self.cfg.aleph_alpha_completion.max_tokens,
+            stop_sequences=self.cfg.aleph_alpha_completion.stop_sequences,
+            top_p=self.cfg.aleph_alpha_completion.top_p,
+            temperature=self.cfg.aleph_alpha_completion.temperature,
+            repetition_penalties_include_completion=self.cfg.aleph_alpha_completion.repetition_penalties_include_completion,
+            repetition_penalties_include_prompt=self.cfg.aleph_alpha_completion.repetition_penalties_include_prompt,
+        )
+
+        rag_chain_from_docs = RunnablePassthrough.assign(context=(lambda x: extract_text_from_langchain_documents(x["context"]))) | self.prompt | llm | StrOutputParser()
+
+        return RunnableParallel({"context": search_chain, "question": RunnablePassthrough()}).assign(answer=rag_chain_from_docs)
 
     def explain_qa(self, document: LangchainDocument, explain_threshold: float, query: str) -> tuple:
         """Explian QA WIP."""
@@ -346,29 +345,14 @@ def process_documents_aleph_alpha(self, folder: str, processing_type: str) -> li
 
 
 if __name__ == "__main__":
-    token = os.getenv("ALEPH_ALPHA_API_KEY")
-
-    if not token:
-        msg = "Token cannot be None or empty."
-        raise ValueError(msg)
-
-    aa_service = AlephAlphaService(token=token, collection_name="aleph_alpha")
+    query = "Was ist Attention?"
 
-    aa_service.embed_documents("tests/resources/")
-    # open the text file and read the text
-    docs = aa_service.search(SearchParams(query="Was ist Attention?", amount=3), Filtering(threshold=0.0, collection_name="aleph_alpha"))
+    aa_service = AlephAlphaService(collection_name="", token="")
 
-    logger.info(f"Documents: {docs}")
+    aa_service.embed_documents(directory="tests/resources/")
 
-    answer, prompt, meta_data = aa_service.rag(
-        RAGRequest(language="detect", history={}),
-        SearchParams(
-            query="Was ist Attention?",
-            amount=3,
-        ),
-        Filtering(threshold=0.0, collection_name="aleph_alpha"),
-    )
+    chain = aa_service.create_rag_chain(rag=RAGRequest(), search=SearchParams(query=query, amount=3))
 
-    logger.info(f"Answer: {answer}")
-    logger.info(f"Prompt: {prompt}")
-    logger.info(f"Metadata: {meta_data}")
+    answer = chain.invoke(query)
+
+    logger.info(answer)
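Note that create_rag_chain now returns a LangChain runnable instead of the old (answer, prompt, documents) tuple, so chain.invoke(query) in the __main__ block yields a dict. A minimal sketch of consuming it, assuming a configured service and an ALEPH_ALPHA_API_KEY in the environment; key names follow the RunnableParallel construction in the hunk above:

from agent.backend.aleph_alpha_service import AlephAlphaService
from agent.data_model.request_data_model import RAGRequest, SearchParams

# Setup values are placeholders; cfg is injected by the service's config loader.
aa_service = AlephAlphaService(collection_name="aleph_alpha", token="")
chain = aa_service.create_rag_chain(rag=RAGRequest(), search=SearchParams(query="Was ist Attention?", amount=3))

result = chain.invoke("Was ist Attention?")
documents = result["context"]   # retrieved Documents, each with metadata["score"]
question = result["question"]   # the original query string
answer = result["answer"]       # the generated completion string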
agent/backend/cohere_service.py (18 changes: 6 additions & 12 deletions)

@@ -90,15 +90,13 @@ def create_collection(self, name: str) -> bool:
         generate_collection_cohere(self.cfg, name)
         return True
 
-    def search(self, search: SearchParams) -> BaseRetriever:
+    def create_search_chain(self, search: SearchParams) -> BaseRetriever:
         """Searches the documents in the Qdrant DB with semantic search."""
-        search = dict(search)
-        search.pop("query")
 
         @chain
         def retriever_with_score(query: str) -> list[Document]:
             docs, scores = zip(
-                *self.vector_db.similarity_search_with_score(query, k=search["k"], filter=search["filter"], score_threshold=search["score_threshold"]), strict=False
+                *self.vector_db.similarity_search_with_score(query, k=search.k, filter=search.filter, score_threshold=search.score_threshold), strict=False
             )
             for doc, score in zip(docs, scores, strict=False):
                 doc.metadata["score"] = score
@@ -107,9 +105,9 @@ def retriever_with_score(query: str) -> list[Document]:
 
         return retriever_with_score
 
-    def rag(self, rag: RAGRequest, search: SearchParams) -> tuple:
+    def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
         """Retrieval Augmented Generation."""
-        search_chain = self.search(search=search)
+        search_chain = self.create_search_chain(search=search)
 
         rag_chain_from_docs = (
             RunnablePassthrough.assign(context=(lambda x: extract_text_from_langchain_documents(x["context"]))) | self.prompt | ChatCohere() | StrOutputParser()
@@ -126,13 +124,9 @@ def summarize_text(self, text: str) -> str:
 
     cohere_service = CohereService(collection_name="", token="")
 
-    # cohere_service.embed_documents(directory="tests/resources/")
+    cohere_service.embed_documents(directory="tests/resources/")
 
-    # search_chain = cohere_service.search(search=SearchParams(query=query, amount=3))
-
-    # search_results = search_chain.invoke(query)
-
-    chain = cohere_service.rag(rag=RAGRequest(), search=SearchParams(query=query, amount=3))
+    chain = cohere_service.create_rag_chain(rag=RAGRequest(), search=SearchParams(query=query, amount=3))
 
     answer = chain.invoke(query)
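The cleanup above drops the dict(search) / search.pop("query") conversion in favor of attribute access on SearchParams. The @chain decorator is what turns retriever_with_score into a composable runnable; here is a standalone sketch of that pattern, with a toy fake_search function (invented for illustration) standing in for the Qdrant similarity_search_with_score call:

from langchain_core.documents import Document
from langchain_core.runnables import chain


def fake_search(query: str, k: int) -> list[tuple[Document, float]]:
    # Toy stand-in returning (document, score) pairs, as the vector store would.
    return [(Document(page_content=f"hit {i} for {query!r}"), 1.0 - 0.1 * i) for i in range(k)]


@chain
def retriever_with_score(query: str) -> list[Document]:
    # zip(*...) splits the (doc, score) pairs into two parallel tuples.
    docs, scores = zip(*fake_search(query, k=3), strict=False)
    for doc, score in zip(docs, scores, strict=False):
        doc.metadata["score"] = score  # attach the similarity score to each hit
    return list(docs)


# @chain makes the function a Runnable, so it exposes .invoke() and composes with |
print(retriever_with_score.invoke("Was ist Attention?"))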
agent/backend/gpt4all_service.py (88 changes: 33 additions & 55 deletions)

@@ -5,18 +5,21 @@
 from langchain.text_splitter import NLTKTextSplitter
 from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
 from langchain_community.embeddings import GPT4AllEmbeddings
+from langchain_core.documents import Document
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.runnables import RunnableParallel, RunnablePassthrough, chain
 from loguru import logger
 from omegaconf import DictConfig
 from ultra_simple_config import load_config
 
 from agent.backend.LLMBase import LLMBase
-from agent.data_model.internal_model import RetrievalResults
 from agent.data_model.request_data_model import (
-    Filtering,
     RAGRequest,
     SearchParams,
 )
-from agent.utils.utility import convert_qdrant_result_to_retrieval_results, generate_prompt
+from agent.utils.utility import extract_text_from_langchain_documents, generate_prompt, load_prompt_template
 from agent.utils.vdb import generate_collection_gpt4all, init_vdb
 
 load_dotenv()
@@ -37,7 +40,11 @@ def __init__(self, cfg: DictConfig, collection_name: str, token: str | None) ->
         else:
             self.collection_name = self.cfg.qdrant.collection_name_gpt4all
 
-        embedding = GPT4AllEmbeddings()
+        embedding = GPT4AllEmbeddings(model_name="nomic-embed-text-v1.5.f16.gguf")
+
+        template = load_prompt_template(prompt_name="cohere_chat.j2", task="chat")
+        self.prompt = ChatPromptTemplate.from_template(template=template, template_format="jinja2")
+
         self.vector_db = init_vdb(cfg=self.cfg, collection_name=collection_name, embedding=embedding)
 
     def create_collection(self, name: str) -> bool:
@@ -117,69 +124,40 @@ def generate(self, prompt: str) -> str:
 
         return model.generate(prompt, max_tokens=250)
 
-    def search(self, search: SearchParams, filtering: Filtering) -> list[RetrievalResults]:
-        """Searches the documents in the Qdrant DB with a specific query.
-
-        Args:
-        ----
-            search (SearchRequest): The search request.
-            filtering (Filtering): The filtering parameters.
-
-        Returns:
-        -------
-            List[Tuple[Document, float]]: A list of search results, where each result is a tuple
-                containing a Document object and a float score.
-
-        """
-        docs = self.vector_db.similarity_search_with_score(query=search.query, k=search.amount, score_threshold=filtering.threshold, filter=filtering.filter)
-        logger.info(f"SUCCESS: {len(docs)} Documents found.")
-
-        return convert_qdrant_result_to_retrieval_results(docs)
+    def create_search_chain(self, search: SearchParams) -> BaseRetriever:
+        """Searches the documents in the Qdrant DB with semantic search."""
+
+        @chain
+        def retriever_with_score(query: str) -> list[Document]:
+            docs, scores = zip(
+                *self.vector_db.similarity_search_with_score(query, k=search.k, filter=search.filter, score_threshold=search.score_threshold), strict=False
+            )
+            for doc, score in zip(docs, scores, strict=False):
+                doc.metadata["score"] = score
+
+            return docs
+
+        return retriever_with_score
 
-    def rag(self, rag_request: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
-        """RAG takes a Rag Request Object and performs a semantic search and then a generation.
-
-        Args:
-        ----
-            rag_request (RAGRequest): The RAG Request Object.
-            search (SearchRequest): The search request.
-            filtering (Filtering): The filtering parameters.
-
-        Returns:
-        -------
-            Tuple[str, str, List[RetrievalResults]]: The answer, the prompt and the metadata.
-
-        """
-        llm = GPT4All(self.cfg.gpt4all_completion.completion_model)
-
-        return search_chain | llm
+    def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
+        """Retrieval Augmented Generation."""
+        search_chain = self.create_search_chain(search=search)
+        llm = GPT4All(self.cfg.gpt4all_completion.completion_model)
+
+        rag_chain_from_docs = RunnablePassthrough.assign(context=(lambda x: extract_text_from_langchain_documents(x["context"]))) | self.prompt | llm | StrOutputParser()
+
+        return RunnableParallel({"context": search_chain, "question": RunnablePassthrough()}).assign(answer=rag_chain_from_docs)
 
 
 if __name__ == "__main__":
     query = "Was ist Attention?"
 
-    gpt4all_service = GPT4AllService(collection_name="gpt4all", token="")
-
-    # gpt4all_service.embed_documents(directory="tests/resources/")
-
-    retriever = gpt4all_service.create_search_chain(search_kwargs={"k": 3})
-
-    results = (retriever.invoke(query),)  # config={'callbacks': [ConsoleCallbackHandler()]})
-
-    rag_chain = gpt4all_service.create_rag_chain(search_chain=retriever)
-
-    # docs = gpt4all_service.search(SearchRequest(query, amount=3), Filtering(threshold=0.0, collection_name="gpt4all"))
-
-    # logger.info(f"Documents: {docs}")
-
-    # answer, prompt, meta_data = gpt4all_service.rag(
-    #     RAGRequest(language="detect", history={}),
-    #     SearchRequest(
-    #         query=query,
-    #         amount=3,
-    #     ),
-    #     Filtering(threshold=0.0, collection_name="gpt4all"),
-    # )
+    gpt4all_service = GPT4AllService(collection_name="", token="")
 
-    # logger.info(f"Answer: {answer}")
-    # logger.info(f"Prompt: {prompt}")
-    # logger.info(f"Metadata: {meta_data}")
+    gpt4all_service.embed_documents(directory="tests/resources/")
+
+    chain = gpt4all_service.create_rag_chain(rag=RAGRequest(), search=SearchParams(query=query, amount=3))
+
+    answer = chain.invoke(query)
+
+    logger.info(answer)
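All three services now share the same composition: retrieval and the raw question run in parallel, then the answer is assigned on top of the resulting dict. A generic, runnable sketch of that pattern with toy stand-ins (no vector store or LLM required; every component below is invented for illustration):

from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

# Toy "retriever": returns a list of pseudo-documents for a query.
retriever = RunnableLambda(lambda q: [f"doc about {q}"])

# Toy generation branch: stands in for prompt | llm | StrOutputParser().
rag_chain_from_docs = (
    RunnablePassthrough.assign(context=lambda x: "\n".join(x["context"]))
    | RunnableLambda(lambda x: f"Answer to {x['question']!r} based on {x['context']!r}")
)

# Retrieval and the question run in parallel; the answer is assigned afterwards.
chain = RunnableParallel({"context": retriever, "question": RunnablePassthrough()}).assign(answer=rag_chain_from_docs)

print(chain.invoke("Was ist Attention?"))
# -> {'context': [...], 'question': 'Was ist Attention?', 'answer': '...'}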