Skip to content

Commit

Permalink
restructuring the request pydantic classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfmezger committed Jun 1, 2024
1 parent 18cf2c7 commit 6e06173
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 51 deletions.
13 changes: 5 additions & 8 deletions agent/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
CustomPromptCompletion,
EmbeddTextRequest,
ExplainQARequest,
Filtering,
LLMBackend,
LLMProvider,
RAGRequest,
SearchRequest,
SearchParams,
)
from agent.data_model.response_data_model import (
EmbeddingResponse,
Expand Down Expand Up @@ -178,14 +177,13 @@ async def embedd_text(embedding: EmbeddTextRequest, llm_backend: LLMBackend) ->


@app.post("/semantic/search", tags=["search"])
def search(search: SearchRequest, llm_backend: LLMBackend, filtering: Filtering) -> list[SearchResponse]:
def search(search: SearchParams, llm_backend: LLMBackend) -> list[SearchResponse]:
"""Searches for a query in the vector database.
Args:
----
search (SearchRequest): The search request.
llm_backend (LLMBackend): The LLM Backend.
filtering (Filtering): The Filtering Parameters.
Raises:
------
Expand All @@ -201,7 +199,7 @@ def search(search: SearchRequest, llm_backend: LLMBackend, filtering: Filtering)

service = LLMContext(LLMStrategyFactory.get_strategy(strategy_type=llm_backend.llm_provider, token=llm_backend.token, collection_name=llm_backend.collection_name))

docs = service.search(search=search, filtering=filtering)
docs = service.search(search=search)

if not docs:
logger.info("No Documents found.")
Expand All @@ -221,14 +219,13 @@ def search(search: SearchRequest, llm_backend: LLMBackend, filtering: Filtering)


@app.post("/rag", tags=["rag"])
def question_answer(rag: RAGRequest, llm_backend: LLMBackend, filtering: Filtering) -> QAResponse:
def question_answer(rag: RAGRequest, llm_backend: LLMBackend) -> QAResponse:
"""Answer a question based on the documents in the database.
Args:
----
rag (RAGRequest): The request parameters.
llm_backend (LLMBackend): The LLM Backend.
filtering (Filtering): The Filtering Parameters.
Raises:
------
Expand All @@ -255,7 +252,7 @@ def question_answer(rag: RAGRequest, llm_backend: LLMBackend, filtering: Filteri
text = combine_text_from_list(rag.history)
service.summarize_text(text=text, token="")

answer, prompt, meta_data = service.rag(rag=rag, llm_backend=llm_backend, filtering=filtering)
answer, prompt, meta_data = service.rag(rag=rag, llm_backend=llm_backend)

return QAResponse(answer=answer, prompt=prompt, meta_data=meta_data)

Expand Down
6 changes: 3 additions & 3 deletions agent/backend/LLMBase.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Strategy Pattern."""
from abc import ABC, abstractmethod

from agent.data_model.request_data_model import Filtering, LLMBackend, RAGRequest, SearchRequest
from agent.data_model.request_data_model import Filtering, LLMBackend, RAGRequest, SearchParams


class LLMBase(ABC):
Expand All @@ -23,15 +23,15 @@ def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database."""

@abstractmethod
def search(self, search: SearchRequest, filtering: Filtering) -> list:
def search(self, search: SearchParams, filtering: Filtering) -> list:
"""Searches the documents in the Qdrant DB with semantic search."""

@abstractmethod
def generate(self, prompt: str) -> str:
"""Generate text from a prompt."""

@abstractmethod
def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> tuple:
def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
"""Retrieval Augmented Generation."""

@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions agent/backend/LLMStrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from agent.backend.gpt4all_service import GPT4AllService
from agent.backend.LLMBase import LLMBase
from agent.backend.open_ai_service import OpenAIService
from agent.data_model.request_data_model import Filtering, LLMProvider, RAGRequest, SearchRequest
from agent.data_model.request_data_model import Filtering, LLMProvider, RAGRequest, SearchParams


class LLMStrategyFactory:
Expand Down Expand Up @@ -66,7 +66,7 @@ def change_strategy(self, strategy_type: str, token: str, collection_name: str)
"""Changes the strategy using the Factory."""
self.llm = LLMStrategyFactory.get_strategy(strategy_type=strategy_type, token=token, collection_name=collection_name)

def search(self, search: SearchRequest, filtering: Filtering) -> list:
def search(self, search: SearchParams, filtering: Filtering) -> list:
"""Wrapper for the search."""
return self.llm.search(search=search, filtering=filtering)

Expand All @@ -82,7 +82,7 @@ def generate(self, prompt: str) -> str:
"""Wrapper for the generation of text."""
return self.llm.generate(prompt)

def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> tuple:
def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
"""Wrapper for the RAG."""
return self.llm.rag(rag=rag, search=search, filtering=filtering)

Expand Down
10 changes: 5 additions & 5 deletions agent/backend/aleph_alpha_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from agent.data_model.request_data_model import (
Filtering,
RAGRequest,
SearchRequest,
SearchParams,
)
from agent.utils.utility import convert_qdrant_result_to_retrieval_results, generate_prompt
from agent.utils.vdb import generate_collection_aleph_alpha, init_vdb
Expand Down Expand Up @@ -208,7 +208,7 @@ def embed_documents(self, directory: str, file_ending: str = ".pdf") -> None:

logger.info("SUCCESS: Texts embedded.")

def search(self, search: SearchRequest, filtering: Filtering) -> list[tuple[LangchainDocument, float]]:
def search(self, search: SearchParams, filtering: Filtering) -> list[tuple[LangchainDocument, float]]:
"""Searches the Aleph Alpha service for similar documents.
Args:
Expand All @@ -227,7 +227,7 @@ def search(self, search: SearchRequest, filtering: Filtering) -> list[tuple[Lang

return convert_qdrant_result_to_retrieval_results(docs)

def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> tuple:
def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
"""QA takes a list of documents and returns a list of answers.
Args:
Expand Down Expand Up @@ -356,13 +356,13 @@ def process_documents_aleph_alpha(self, folder: str, processing_type: str) -> li

aa_service.embed_documents("tests/resources/")
# open the text file and read the text
docs = aa_service.search(SearchRequest(query="Was ist Attention?", amount=3), Filtering(threshold=0.0, collection_name="aleph_alpha"))
docs = aa_service.search(SearchParams(query="Was ist Attention?", amount=3), Filtering(threshold=0.0, collection_name="aleph_alpha"))

logger.info(f"Documents: {docs}")

answer, prompt, meta_data = aa_service.rag(
RAGRequest(language="detect", history={}),
SearchRequest(
SearchParams(
query="Was ist Attention?",
amount=3,
),
Expand Down
17 changes: 6 additions & 11 deletions agent/backend/cohere_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

from agent.backend.LLMBase import LLMBase
from agent.data_model.request_data_model import (
Filtering,
RAGRequest,
SearchRequest,
SearchParams,
)
from agent.utils.utility import extract_text_from_langchain_documents, load_prompt_template
from agent.utils.vdb import init_vdb
Expand Down Expand Up @@ -87,20 +86,16 @@ def embed_documents(self, directory: str, file_ending: str = ".pdf") -> None:
def create_collection(self, name: str) -> bool:
"""Create a new collection in the Vector Database."""

def search(self, search: SearchRequest, filtering: Filtering) -> list:
def search(self, search: SearchParams) -> list:
"""Searches the documents in the Qdrant DB with semantic search."""
search = dict(search)
filtering = dict(filtering)

search.update(filtering)

search.pop("query")

return self.vector_db.as_retriever(search_kwargs=search)

def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> tuple:
def rag(self, rag: RAGRequest, search: SearchParams) -> tuple:
"""Retrieval Augmented Generation."""
search_chain = self.search(search=search, filtering=filtering)
search_chain = self.search(search=search)
return {"context": search_chain | extract_text_from_langchain_documents, "question": RunnablePassthrough()} | self.prompt | ChatCohere() | StrOutputParser()

def summarize_text(self, text: str) -> str:
Expand All @@ -114,8 +109,8 @@ def summarize_text(self, text: str) -> str:

# cohere_service.embed_documents(directory="tests/resources/")

search_chain = cohere_service.search(search=SearchRequest(query=query, amount=3), filtering=Filtering())
search_chain = cohere_service.search(search=SearchParams(query=query, amount=3))

chain = cohere_service.generate(search_chain=search_chain)

chain = cohere_service.rag(rag=RAGRequest(), search=SearchRequest(query=query, amount=3), filtering=Filtering())
chain = cohere_service.rag(rag=RAGRequest(), search=SearchParams(query=query, amount=3))
8 changes: 3 additions & 5 deletions agent/backend/gpt4all_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@
from agent.data_model.request_data_model import (
Filtering,
RAGRequest,
SearchRequest,
SearchParams,
)
from agent.utils.utility import convert_qdrant_result_to_retrieval_results, generate_prompt
from agent.utils.vdb import generate_collection_gpt4all, init_vdb



load_dotenv()


Expand Down Expand Up @@ -135,7 +133,7 @@ def generate(self, prompt: str) -> str:

return model.generate(prompt, max_tokens=250)

def search(self, search: SearchRequest, filtering: Filtering) -> list[RetrievalResults]:
def search(self, search: SearchParams, filtering: Filtering) -> list[RetrievalResults]:
"""Searches the documents in the Qdrant DB with a specific query.
Args:
Expand All @@ -154,7 +152,7 @@ def search(self, search: SearchRequest, filtering: Filtering) -> list[RetrievalR

return convert_qdrant_result_to_retrieval_results(docs)

def rag(self, rag_request: RAGRequest, search: SearchRequest, filtering: Filtering) -> tuple:
def rag(self, rag_request: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
"""RAG takes a Rag Request Object and performs a semantic search and then a generation.
Args:
Expand Down
8 changes: 4 additions & 4 deletions agent/backend/open_ai_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ultra_simple_config import load_config

from agent.backend.LLMBase import LLMBase
from agent.data_model.request_data_model import Filtering, RAGRequest, SearchRequest
from agent.data_model.request_data_model import Filtering, RAGRequest, SearchParams
from agent.utils.utility import generate_prompt
from agent.utils.vdb import generate_collection_openai, init_vdb

Expand Down Expand Up @@ -118,7 +118,7 @@ def embed_documents(self, directory: str, file_ending: str = ".pdf") -> None:

logger.info("SUCCESS: Texts embedded.")

def search(self, search: SearchRequest, filtering: Filtering) -> list[tuple[Document, float]]:
def search(self, search: SearchParams, filtering: Filtering) -> list[tuple[Document, float]]:
"""Searches the documents in the Qdrant DB with a specific query.
Args:
Expand Down Expand Up @@ -193,7 +193,7 @@ def generate(self, prompt: str) -> str:

return response.choices[0].message.content

def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> tuple:
def rag(self, rag: RAGRequest, search: SearchParams, filtering: Filtering) -> tuple:
"""QA Function for OpenAI LLMs.
Args:
Expand Down Expand Up @@ -225,7 +225,7 @@ def rag(self, rag: RAGRequest, search: SearchRequest, filtering: Filtering) -> t
query = "Was ist Attention?"
logger.info(f"Token: {token}")

from agent.data_model.request_data_model import Filtering, SearchRequest
from agent.data_model.request_data_model import Filtering, SearchParams

if not token:
msg = "OPENAI_API_KEY is not set."
Expand Down
12 changes: 3 additions & 9 deletions agent/data_model/request_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,6 @@ class LLMBackend(BaseModel):
collection_name: str | None = Field("", description="The name of the Qdrant Collection.")


class Filtering(BaseModel):

"""The Filtering Model."""

score_threshold: float = Field(0.0, title="Threshold", description="The threshold to use for the search.")
filter: dict | None = Field(None, title="Filter", description="Filter for the database search with metadata.")


class EmbeddTextFilesRequest(BaseModel):

"""The request for the Embedd Text Files endpoint."""
Expand All @@ -61,12 +53,14 @@ class EmbeddTextFilesRequest(BaseModel):
seperator: str = Field("###", description="The seperator to use between embedded texts.")


class SearchRequest(BaseModel):
class SearchParams(BaseModel):

"""The request parameters for searching the database."""

query: str = Field(..., title="Query", description="The search query.")
k: int = Field(3, title="Amount", description="The number of search results to return.")
score_threshold: float = Field(0.0, title="Threshold", description="The threshold to use for the search.")
filter: dict | None = Field(None, title="Filter", description="Filter for the database search with metadata.")


class EmbeddTextRequest(BaseModel):
Expand Down
6 changes: 3 additions & 3 deletions tests/e2e_tests/test_e2e_aleph_alpha_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from agent.data_model.request_data_model import (
Filtering,
RAGRequest,
SearchRequest,
SearchParams,
)


Expand Down Expand Up @@ -50,7 +50,7 @@ def test_embed_documents(service: AlephAlphaService) -> None:

def test_search(service: AlephAlphaService) -> None:
"""Test the search function."""
response = service.search(SearchRequest(query="Was ist Attention?", amount=3), Filtering(threshold=0.0, collection_name="aleph_alpha"))
response = service.search(SearchParams(query="Was ist Attention?", amount=3), Filtering(threshold=0.0, collection_name="aleph_alpha"))
assert response is not None
assert len(response) > 0

Expand All @@ -59,7 +59,7 @@ def test_rag(service: AlephAlphaService) -> None:
"""Test the rag function."""
response = service.rag(
RAGRequest(language="detect", history={}),
SearchRequest(
SearchParams(
query="Was ist Attention?",
amount=3,
),
Expand Down

0 comments on commit 6e06173

Please sign in to comment.