removing the generation from the basic services.
Marc Fabian Mezger committed Jun 29, 2024
1 parent 4aebcfd commit 8e49067
Showing 15 changed files with 1,936 additions and 273 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -137,3 +137,4 @@ htmlcov/
vector_db/
test.py
reports.xlsx
phoenix_data/
2 changes: 2 additions & 0 deletions agent/api.py
@@ -6,6 +6,7 @@
from fastapi import FastAPI, File, UploadFile
from fastapi.openapi.utils import get_openapi
from loguru import logger
from phoenix.trace.langchain import LangChainInstrumentor
from qdrant_client import models
from qdrant_client.http.models.models import UpdateResult
from starlette.responses import JSONResponse
@@ -32,6 +33,7 @@
)
from agent.utils.vdb import initialize_all_vector_dbs, load_vec_db_conn

LangChainInstrumentor().instrument()
nltk.download("punkt")
# add file logger for loguru
# logger.add("logs/file_{time}.log", backtrace=False, diagnose=False)
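The two added lines hook Arize Phoenix tracing into the API: `LangChainInstrumentor().instrument()` patches LangChain so every chain run is exported as a trace, which also explains the `phoenix_data/` entry added to `.gitignore` above (presumably Phoenix's local data directory). A minimal sketch of how this is typically paired with a local Phoenix UI, assuming the `arize-phoenix` package is installed; `launch_app()` and its default port are Phoenix conventions, not part of this commit:

import phoenix as px
from phoenix.trace.langchain import LangChainInstrumentor

px.launch_app()                       # serve the local trace UI, by default at http://localhost:6006
LangChainInstrumentor().instrument()  # patch LangChain so runs are exported to Phoenix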
10 changes: 1 addition & 9 deletions agent/backend/LLMBase.py
@@ -1,7 +1,7 @@
"""Strategy Pattern."""
from abc import ABC, abstractmethod

from agent.data_model.request_data_model import LLMBackend, RAGRequest, SearchParams
from agent.data_model.request_data_model import LLMBackend, SearchParams


class LLMBase(ABC):
@@ -25,14 +25,6 @@ def create_collection(self, name: str) -> bool:
def create_search_chain(self, search: SearchParams) -> list:
"""Searches the documents in the Qdrant DB with semantic search."""

# @abstractmethod
# def generate(self, prompt: str) -> str:
# """Generate text from a prompt."""

@abstractmethod
def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
"""Retrieval Augmented Generation."""

@abstractmethod
def summarize_text(self, text: str) -> str:
"""Summarize text."""
10 changes: 1 addition & 9 deletions agent/backend/LLMStrategy.py
@@ -7,7 +7,7 @@
from agent.backend.LLMBase import LLMBase
from agent.backend.ollama_service import OllamaService
from agent.backend.open_ai_service import OpenAIService
from agent.data_model.request_data_model import LLMProvider, RAGRequest, SearchParams
from agent.data_model.request_data_model import LLMProvider, SearchParams


class LLMStrategyFactory:
@@ -82,14 +82,6 @@ def create_collection(self, name: str) -> None:
"""Wrapper for creating a collection."""
return self.llm.create_collection(name)

def generate(self, prompt: str) -> str:
"""Wrapper for the generation of text."""
return self.llm.generate(prompt)

def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
"""Wrapper for the RAG."""
return self.llm.create_rag_chain(rag=rag, search=search)

def summarize_text(self, text: str) -> str:
"""Wrapper for the summarization of text."""
return self.llm.summarize_text(text)
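The same two wrappers disappear here, so callers can no longer reach generation through `LLMStrategy`; RAG now runs through the LangGraph agent added in `agent/backend/graph.py` below. A hypothetical migration sketch (the old call is reconstructed from this diff; `strategy`, `rag`, and `search` are stand-in objects):

# before (removed in this commit):
# chain = strategy.create_rag_chain(rag=rag, search=search)

# after: generation goes through the compiled LangGraph agent instead
from agent.backend.graph import build_graph

graph = build_graph()
answer = graph.invoke({"messages": [{"role": "human", "content": "who is the father of luke skywalker?"}]})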
17 changes: 3 additions & 14 deletions agent/backend/cohere_service.py
@@ -1,13 +1,12 @@
"""Cohere Backend."""

from dotenv import load_dotenv
from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_cohere import CohereEmbeddings
from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, chain
from langchain_core.runnables import chain
from langchain_text_splitters import NLTKTextSplitter
from loguru import logger
from omegaconf import DictConfig
@@ -18,7 +17,7 @@
RAGRequest,
SearchParams,
)
from agent.utils.utility import extract_text_from_langchain_documents, load_prompt_template
from agent.utils.utility import load_prompt_template
from agent.utils.vdb import generate_collection, init_vdb

load_dotenv()
@@ -111,16 +110,6 @@ def retriever_with_score(query: str) -> list[Document]:

return retriever_with_score

def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
"""Retrieval Augmented Generation."""
search_chain = self.create_search_chain(search=search)

chat = ChatCohere(model_name=self.cfg.cohere_completions.model_name, maximum_tokens=self.cfg.cohere_completions.maximum_tokens)

rag_chain_from_docs = RunnablePassthrough.assign(context=(lambda x: extract_text_from_langchain_documents(x["context"]))) | self.prompt | chat | StrOutputParser()

return RunnableParallel({"context": search_chain, "question": RunnablePassthrough()}).assign(answer=rag_chain_from_docs)

def summarize_text(self, text: str) -> str:
"""Summarize text."""

16 changes: 0 additions & 16 deletions agent/backend/gpt4all_service.py
@@ -104,22 +104,6 @@ def summarize_text(self, text: str) -> str:

return model.generate(prompt, max_tokens=300)

def generate(self, prompt: str) -> str:
"""Complete text with GPT4ALL.
Args:
----
prompt (str): The prompt to be completed.
Returns:
-------
str: The completed text.
"""
model = GPT4All(self.cfg.gpt4all_completion.completion_model)

return model.generate(prompt, max_tokens=250)

def create_search_chain(self, search: SearchParams) -> BaseRetriever:
"""Searches the documents in the Qdrant DB with semantic search."""

195 changes: 195 additions & 0 deletions agent/backend/graph.py
@@ -0,0 +1,195 @@
import os
from collections.abc import Sequence
from typing import Annotated, Literal, TypedDict

from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
convert_to_messages,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
PromptTemplate,
)
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import ConfigurableField, RunnableConfig
from langchain_openai import ChatOpenAI
from langchain_qdrant import Qdrant
from langgraph.graph import END, StateGraph, add_messages
from qdrant_client import QdrantClient

from agent.backend.prompts import COHERE_RESPONSE_TEMPLATE, REPHRASE_TEMPLATE, RESPONSE_TEMPLATE
from agent.utils.utility import format_docs_for_citations

OPENAI_MODEL_KEY = "openai_gpt_3_5_turbo"
COHERE_MODEL_KEY = "cohere_command"
OLLAMA_MODEL_KEY = "phi3_ollama"


class AgentState(TypedDict):
    """State carried through the graph: the user query, retrieved documents, and message history."""

    query: str
    documents: list[Document]
    messages: Annotated[list[BaseMessage], add_messages]


# define models
gpt4o = ChatOpenAI(model="gpt-4o", temperature=0, streaming=True)

cohere_command = ChatCohere(
model="command",
temperature=0,
cohere_api_key=os.environ.get("COHERE_API_KEY", "not_provided"),
streaming=True,
)

ollama_chat = ChatOllama()


# define model alternatives
llm = gpt4o.configurable_alternatives(
ConfigurableField(id="model_name"),
default_key=OPENAI_MODEL_KEY,
**{
COHERE_MODEL_KEY: cohere_command,
},
).with_fallbacks([cohere_command, ollama_chat])


def get_retriever() -> BaseRetriever:
embedding = CohereEmbeddings(model="embed-multilingual-v3.0")

qdrant_client = QdrantClient("http://localhost", port=6333, api_key=os.getenv("QDRANT_API_KEY"), prefer_grpc=False)

vector_db = Qdrant(client=qdrant_client, collection_name="cohere", embeddings=embedding)
return vector_db.as_retriever(search_kwargs={"k": 4})


def retrieve_documents(state: AgentState) -> AgentState:
retriever = get_retriever()
messages = convert_to_messages(state["messages"])
query = messages[-1].content
relevant_documents = retriever.invoke(query)
return {"query": query, "documents": relevant_documents}


def retrieve_documents_with_chat_history(state: AgentState) -> AgentState:
retriever = get_retriever()
model = llm.with_config(tags=["nostream"])

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
condense_question_chain = (CONDENSE_QUESTION_PROMPT | model | StrOutputParser()).with_config(
run_name="CondenseQuestion",
)

messages = convert_to_messages(state["messages"])
query = messages[-1].content
retriever_with_condensed_question = condense_question_chain | retriever
relevant_documents = retriever_with_condensed_question.invoke({"question": query, "chat_history": get_chat_history(messages[:-1])})
return {"query": query, "documents": relevant_documents}


def route_to_retriever(
state: AgentState,
) -> Literal["retriever", "retriever_with_chat_history"]:
# at this point in the graph execution there is exactly one (i.e. first) message from the user,
# so use basic retriever without chat history
if len(state["messages"]) == 1:
return "retriever"
else:
return "retriever_with_chat_history"


def get_chat_history(messages: Sequence[BaseMessage]) -> Sequence[BaseMessage]:
chat_history = []
for message in messages:
if (isinstance(message, AIMessage) and not message.tool_calls) or isinstance(message, HumanMessage):
chat_history.append({"content": message.content, "role": message.type})
return chat_history


def generate_response(state: AgentState, model: LanguageModelLike, prompt_template: str) -> AgentState:
"""Args:
----
state (AgentState): _description_
model (LanguageModelLike): _description_
prompt_template (str): _description_.
Returns
-------
AgentState: _description_
"""
prompt = ChatPromptTemplate.from_messages(
[
("system", prompt_template),
("placeholder", "{chat_history}"),
("human", "{question}"),
]
)
response_synthesizer = prompt | model
synthesized_response = response_synthesizer.invoke(
{
"question": state["query"],
"context": format_docs_for_citations(state["documents"]),
# NOTE: we're ignoring the last message here, as it's going to contain the most recent
# query and we don't want that to be included in the chat history
"chat_history": get_chat_history(convert_to_messages(state["messages"][:-1])),
}
)
return {
"messages": [synthesized_response],
}


def generate_response_default(state: AgentState) -> AgentState:
return generate_response(state, llm, RESPONSE_TEMPLATE)


def generate_response_cohere(state: AgentState) -> AgentState:
model = llm.bind(documents=state["documents"])
return generate_response(state, model, COHERE_RESPONSE_TEMPLATE)


def route_to_response_synthesizer(state: AgentState, config: RunnableConfig) -> Literal["response_synthesizer", "response_synthesizer_cohere"]:
model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
if model_name == COHERE_MODEL_KEY:
return "response_synthesizer_cohere"
else:
return "response_synthesizer"


def build_graph():
"""Build the graph for the agent.
Returns
-------
Graph: The generated graph for RAG.
"""
workflow = StateGraph(AgentState)

# define nodes
workflow.add_node("retriever", retrieve_documents)
workflow.add_node("retriever_with_chat_history", retrieve_documents_with_chat_history)
workflow.add_node("response_synthesizer", generate_response_default)
workflow.add_node("response_synthesizer_cohere", generate_response_cohere)

# set entry point to retrievers
workflow.set_conditional_entry_point(route_to_retriever)

# connect retrievers and response synthesizers
workflow.add_conditional_edges("retriever", route_to_response_synthesizer)
workflow.add_conditional_edges("retriever_with_chat_history", route_to_response_synthesizer)

# connect synthesizers to terminal node
workflow.add_edge("response_synthesizer", END)
workflow.add_edge("response_synthesizer_cohere", END)

return workflow.compile()


# answer = graph.invoke({"messages": [{"role": "human", "content": "who is the father of luke skywalker?"}, {"role": "assistant", "content": "The father of Luke Skywalker was Anakin Skywalker."}, {"role": "human", "content": "and who is his mother?"}]})
# logger.info(answer)
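Because the chat model is registered with `configurable_alternatives` under the `model_name` field, and `route_to_response_synthesizer` reads the same key from the run config, callers can switch backends per invocation. A usage sketch, assuming the Qdrant instance and API keys expected by `get_retriever` are available:

from agent.backend.graph import COHERE_MODEL_KEY, build_graph

graph = build_graph()

# default: GPT-4o with the generic response synthesizer
default_answer = graph.invoke({"messages": [{"role": "human", "content": "who is the father of luke skywalker?"}]})

# per-call override: route both the model and the synthesizer to Cohere
cohere_answer = graph.invoke(
    {"messages": [{"role": "human", "content": "who is the father of luke skywalker?"}]},
    config={"configurable": {"model_name": COHERE_MODEL_KEY}},
)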
19 changes: 2 additions & 17 deletions agent/backend/ollama_service.py
@@ -1,14 +1,12 @@
"""Ollama Backend."""

from dotenv import load_dotenv
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import DirectoryLoader, PyPDFium2Loader, TextLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, chain
from langchain_core.runnables import chain
from langchain_text_splitters import NLTKTextSplitter
from loguru import logger
from omegaconf import DictConfig
@@ -19,7 +17,7 @@
RAGRequest,
SearchParams,
)
from agent.utils.utility import extract_text_from_langchain_documents, load_prompt_template
from agent.utils.utility import load_prompt_template
from agent.utils.vdb import generate_collection, init_vdb

load_dotenv()
@@ -112,19 +110,6 @@ def retriever_with_score(query: str) -> list[Document]:

return retriever_with_score

def create_rag_chain(self, rag: RAGRequest, search: SearchParams) -> tuple:
"""Retrieval Augmented Generation."""
search_chain = self.create_search_chain(search=search)

rag_chain_from_docs = (
RunnablePassthrough.assign(context=(lambda x: extract_text_from_langchain_documents(x["context"])))
| self.prompt
| ChatOllama(model=self.cfg.ollama.model)
| StrOutputParser()
)

return RunnableParallel({"context": search_chain, "question": RunnablePassthrough()}).assign(answer=rag_chain_from_docs)

def summarize_text(self, text: str) -> str:
"""Summarize text."""
