Skip to content

Commit

Permalink
gpt4all works now.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfmezger committed Apr 13, 2024
1 parent b4fbf07 commit 9ed6ab8
Showing 1 changed file with 17 additions and 26 deletions.
43 changes: 17 additions & 26 deletions agent/backend/gpt4all_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
from agent.data_model.internal_model import RetrievalResults
from agent.data_model.request_data_model import (
Filtering,
LLMBackend,
LLMProvider,
RAGRequest,
SearchRequest,
)
from agent.utils.utility import (
convert_qdrant_result_to_retrieval_results,
generate_prompt,
generate_collection_gpt4all,
generate_prompt,
)
from agent.utils.vdb import init_vdb

Expand Down Expand Up @@ -140,6 +138,7 @@ def search(self, search: SearchRequest, filtering: Filtering) -> list[RetrievalR
Args:
----
search (SearchRequest): The search request.
filtering (Filtering): The filtering parameters.
Returns:
-------
Expand All @@ -151,24 +150,27 @@ def search(self, search: SearchRequest, filtering: Filtering) -> list[RetrievalR

return convert_qdrant_result_to_retrieval_results(docs)

def rag(self, rag_request: RAGRequest) -> tuple:
def rag(self, rag_request: RAGRequest, search: SearchRequest, filtering: Filtering) -> tuple:
"""RAG takes a Rag Request Object and performs a semantic search and then a generation.
Args:
----
rag_request (RAGRequest): The RAG Request Object.
search (SearchRequest): The search request.
filtering (Filtering): The filtering parameters.
Returns:
-------
Tuple[str, str, List[RetrievalResults]]: The answer, the prompt and the metadata.
"""
documents = self.search(rag_request.search)
if rag_request.search.amount == 0:
documents = self.search(search=search, filtering=filtering)
if search.amount == 0:
msg = "No documents found."
raise ValueError(msg)
text = "\n".join([doc.document for doc in documents]) if rag_request.search.amount > 1 else documents[0].document
text = "\n".join([doc.document for doc in documents]) if search.amount > 1 else documents[0].document

prompt = generate_prompt(prompt_name="gpt4all-completion.j2", text=text, query=rag_request.search.query)
# TODO: Add the history to the prompt
prompt = generate_prompt(prompt_name="gpt4all-completion.j2", text=text, query=search.query, language=rag_request.language)

answer = self.generate(prompt)

Expand All @@ -180,28 +182,17 @@ def rag(self, rag_request: RAGRequest) -> tuple:

gpt4all_service.embed_documents(directory="tests/resources/")

docs = gpt4all_service.search(
SearchRequest(
query="Was ist Attention?",
amount=3,
filtering=Filtering(threshold=0.0, collection_name="gpt4all"),
llm_backend=LLMBackend(token="gpt4all", provider=LLMProvider.GPT4ALL),
)
)
docs = gpt4all_service.search(SearchRequest(query="Was ist Attention?", amount=3), Filtering(threshold=0.0, collection_name="gpt4all"))

logger.info(f"Documents: {docs}")

answer, prompt, meta_data = gpt4all_service.rag(
RAGRequest(
search=SearchRequest(
query="Was ist Attention?",
amount=3,
filtering=Filtering(threshold=0.0, collection_name="gpt4all"),
llm_backend=LLMBackend(token="gpt4all", provider=LLMProvider.GPT4ALL),
),
documents=docs,
query="Was ist das?",
)
RAGRequest(language="detect", history={}),
SearchRequest(
query="Was ist Attention?",
amount=3,
),
Filtering(threshold=0.0, collection_name="gpt4all"),
)

logger.info(f"Answer: {answer}")
Expand Down

0 comments on commit 9ed6ab8

Please sign in to comment.