diff --git a/agent/backend/graph.py b/agent/backend/graph.py
index 86543c2..f0eb662 100644
--- a/agent/backend/graph.py
+++ b/agent/backend/graph.py
@@ -1,8 +1,10 @@
+"""Defining the graph."""
 import os
 from collections.abc import Sequence
 from typing import Annotated, Literal, TypedDict

 from langchain_cohere import ChatCohere, CohereEmbeddings
+from langchain_community.chat_models.ollama import ChatOllama
 from langchain_core.documents import Document
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.messages import (
@@ -23,7 +25,7 @@
 from langgraph.graph import END, StateGraph, add_messages
 from qdrant_client import QdrantClient

-from agent.backend.prompts import COHERE_RESPONSE_TEMPLATE, REPHRASE_TEMPLATE
+from agent.backend.prompts import COHERE_RESPONSE_TEMPLATE, REPHRASE_TEMPLATE, RESPONSE_TEMPLATE
 from agent.utils.utility import format_docs_for_citations

 OPENAI_MODEL_KEY = "openai_gpt_3_5_turbo"
@@ -32,6 +34,9 @@


 class AgentState(TypedDict):
+
+    """State of the Agent."""
+
     query: str
     documents: list[Document]
     messages: Annotated[list[BaseMessage], add_messages]
@@ -61,6 +66,12 @@


 def get_retriever() -> BaseRetriever:
+    """Create a Vector Database retriever.
+
+    Returns
+    -------
+    BaseRetriever: Qdrant + Cohere Embeddings Retriever
+    """
     embedding = CohereEmbeddings(model="embed-multilingual-v3.0")
     qdrant_client = QdrantClient("http://localhost", port=6333, api_key=os.getenv("QDRANT_API_KEY"), prefer_grpc=False)
@@ -70,6 +81,16 @@


 def retrieve_documents(state: AgentState) -> AgentState:
+    """Retrieve documents from the retriever.
+
+    Args:
+    ----
+        state (AgentState): Graph State.
+
+    Returns:
+    -------
+        AgentState: Modified Graph State.
+    """
     retriever = get_retriever()
     messages = convert_to_messages(state["messages"])
     query = messages[-1].content
@@ -78,11 +99,21 @@


 def retrieve_documents_with_chat_history(state: AgentState) -> AgentState:
+    """Retrieve documents from the retriever with chat history.
+
+    Args:
+    ----
+        state (AgentState): Graph State.
+
+    Returns:
+    -------
+        AgentState: Modified Graph State.
+    """
     retriever = get_retriever()
     model = llm.with_config(tags=["nostream"])
-    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
-    condense_question_chain = (CONDENSE_QUESTION_PROMPT | model | StrOutputParser()).with_config(
+    condense_question_prompt = PromptTemplate.from_template(REPHRASE_TEMPLATE)
+    condense_question_chain = (condense_question_prompt | model | StrOutputParser()).with_config(
         run_name="CondenseQuestion",
     )
@@ -96,6 +127,12 @@
 def route_to_retriever(
     state: AgentState,
 ) -> Literal["retriever", "retriever_with_chat_history"]:
+    """Route to the appropriate retriever based on the state.
+
+    Returns
+    -------
+    Literal["retriever", "retriever_with_chat_history"]: Chosen retriever method.
+    """
     # at this point in the graph execution there is exactly one (i.e. first) message from the user,
     # so use basic retriever without chat history
     if len(state["messages"]) == 1:
@@ -105,23 +142,35 @@


 def get_chat_history(messages: Sequence[BaseMessage]) -> Sequence[BaseMessage]:
-    chat_history = []
-    for message in messages:
-        if (isinstance(message, AIMessage) and not message.tool_calls) or isinstance(message, HumanMessage):
-            chat_history.append({"content": message.content, "role": message.type})
-    return chat_history
+    """Extract the chat history from the messages.
+
+    Args:
+    ----
+        messages (Sequence[BaseMessage]): Messages from the frontend.
+
+    Returns:
+    -------
+        Sequence[BaseMessage]: Chat history as Langchain messages.
+    """
+    return [
+        {"content": message.content, "role": message.type}
+        for message in messages
+        if (isinstance(message, AIMessage) and not message.tool_calls) or isinstance(message, HumanMessage)
+    ]


 def generate_response(state: AgentState, model: LanguageModelLike, prompt_template: str) -> AgentState:
-    """Args:
+    """Create a response from the model.
+
+    Args:
     ----
-        state (AgentState): _description_
-        model (LanguageModelLike): _description_
-        prompt_template (str): _description_.
+        state (AgentState): Graph State.
+        model (LanguageModelLike): Language Model.
+        prompt_template (str): Template for the prompt.

-    Returns
+    Returns:
     -------
-        AgentState: _description_
+        AgentState: Modified Graph State.
     """
     prompt = ChatPromptTemplate.from_messages(
         [
@@ -146,15 +195,47 @@ def generate_response(state: AgentState, model: LanguageModelLike, prompt_templa


 def generate_response_default(state: AgentState) -> AgentState:
+    """Generate a response using a non-Cohere model.
+
+    Args:
+    ----
+        state (AgentState): Graph State.
+
+    Returns:
+    -------
+        AgentState: Modified Graph State.
+    """
     return generate_response(state, llm, RESPONSE_TEMPLATE)


 def generate_response_cohere(state: AgentState) -> AgentState:
+    """Generate a response using the Cohere model.
+
+    Args:
+    ----
+        state (AgentState): Graph State.
+
+    Returns:
+    -------
+        AgentState: Modified Graph State.
+    """
     model = llm.bind(documents=state["documents"])
     return generate_response(state, model, COHERE_RESPONSE_TEMPLATE)


-def route_to_response_synthesizer(state: AgentState, config: RunnableConfig) -> Literal["response_synthesizer", "response_synthesizer_cohere"]:
+def route_to_response_synthesizer(state: AgentState, config: RunnableConfig) -> Literal["response_synthesizer", "response_synthesizer_cohere"]:  # noqa: ARG001
+    """Route to the appropriate response synthesizer based on the config.
+
+    Args:
+    ----
+        state (AgentState): Graph State.
+        config (RunnableConfig): Runnable Config.
+
+    Returns:
+    -------
+        Literal["response_synthesizer", "response_synthesizer_cohere"]: Chosen response synthesizer method.
+    """
     model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
     if model_name == COHERE_MODEL_KEY:
         return "response_synthesizer_cohere"
@@ -162,7 +243,7 @@ def route_to_response_synthesizer(state: AgentState, config: RunnableConfig) ->
     return "response_synthesizer"


-def build_graph():
+def build_graph() -> StateGraph:
     """Build the graph for the agent.

     Returns
@@ -191,5 +272,6 @@
     return workflow.compile()


-# answer = graph.invoke({"messages": [{"role": "human", "content": "wer ist der vater von luke skywalker?"}, {"role": "assistant", "content": "Der Vater von Luke Skywalker war Anakin Skywalker."}, {"role": "human", "content": "und wer ist seine mutter?"}]})
+# answer = graph.invoke({"messages": [{"role": "human", "content": "wer ist der vater von luke skywalker?"}, {"role": "assistant", "content": "Der Vater von Luke
+# Skywalker war Anakin Skywalker."}, {"role": "human", "content": "und wer ist seine mutter?"}]})
 # logger.info(answer)
diff --git a/agent/scripts/visualize_graph.py b/agent/scripts/visualize_graph.py
new file mode 100644
index 0000000..9347f2b
--- /dev/null
+++ b/agent/scripts/visualize_graph.py
@@ -0,0 +1,15 @@
+"""Visualize the LangGraph graph."""
+from pathlib import Path
+
+from langchain_core.runnables.graph import MermaidDrawMethod
+
+from agent.backend.graph import build_graph
+
+workflow = build_graph()
+
+
+mermaid_graph = workflow.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
+
+# save as png
+with Path("graph.png").open("wb") as f:
+    f.write(mermaid_graph)
diff --git a/agent/utils/utility.py b/agent/utils/utility.py
index 3ccec97..4e34322 100644
--- a/agent/utils/utility.py
+++ b/agent/utils/utility.py
@@ -1,5 +1,6 @@
 """Utility module."""
 import uuid
+from collections.abc import Sequence
 from pathlib import Path

 from langchain.prompts import PromptTemplate
@@ -181,6 +182,16 @@ def extract_text_from_langchain_documents(docs: list[Document]) -> str:
 def format_docs_for_citations(docs: Sequence[Document]) -> str:
+    """Format the documents for citations.
+
+    Args:
+    ----
+        docs (Sequence[Document]): Langchain documents from a vector database.
+
+    Returns:
+    -------
+        str: Combined documents in a format suitable for citations.
+    """
     formatted_docs = []
     for i, doc in enumerate(docs):
         doc_string = f"{doc.page_content}"
diff --git a/graph.png b/graph.png
new file mode 100644
index 0000000..8f1863b
Binary files /dev/null and b/graph.png differ
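
A minimal usage sketch of the compiled graph, for reference only (not part of this diff). It mirrors the commented-out example at the bottom of graph.py and assumes a running Qdrant instance plus the module-level llm from graph.py; the "cohere_command" value passed for model_name is an assumption, since the diff shows only the COHERE_MODEL_KEY constant's name, not its value.

# Hypothetical usage sketch; assumes Qdrant on localhost:6333 and a configured llm.
from agent.backend.graph import build_graph

graph = build_graph()

# A single human message, so route_to_retriever takes the history-free retriever path.
answer = graph.invoke(
    {"messages": [{"role": "human", "content": "Who is the father of Luke Skywalker?"}]},
    # "model_name" selects the response synthesizer in route_to_response_synthesizer;
    # the value of COHERE_MODEL_KEY is assumed here.
    config={"configurable": {"model_name": "cohere_command"}},
)
print(answer["messages"][-1].content)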