adding visualization.
Marc Fabian Mezger committed Jul 1, 2024
1 parent 8e49067 commit 4ed3f4d
Showing 4 changed files with 125 additions and 17 deletions.
116 changes: 99 additions & 17 deletions agent/backend/graph.py
@@ -1,8 +1,10 @@
"""Defining the graph."""
import os
from collections.abc import Sequence
from typing import Annotated, Literal, TypedDict

from langchain_cohere import ChatCohere, CohereEmbeddings
from langchain_community.chat_models.ollama import ChatOllama
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
@@ -23,7 +25,7 @@
from langgraph.graph import END, StateGraph, add_messages
from qdrant_client import QdrantClient

from agent.backend.prompts import COHERE_RESPONSE_TEMPLATE, REPHRASE_TEMPLATE
from agent.backend.prompts import COHERE_RESPONSE_TEMPLATE, REPHRASE_TEMPLATE, RESPONSE_TEMPLATE
from agent.utils.utility import format_docs_for_citations

OPENAI_MODEL_KEY = "openai_gpt_3_5_turbo"
@@ -32,6 +34,9 @@


class AgentState(TypedDict):

"""State of the Agent."""

query: str
documents: list[Document]
messages: Annotated[list[BaseMessage], add_messages]
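Annotating messages with add_messages marks the field as a reducer: updates returned by a node are appended to the existing message list rather than replacing it. A minimal illustration, not part of the commit:

# Illustration only: add_messages merges two message lists, appending in order.
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph import add_messages

merged = add_messages([HumanMessage(content="hi")], [AIMessage(content="hello")])
# merged contains both messages, in order.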
@@ -61,6 +66,12 @@ class AgentState(TypedDict):


def get_retriever() -> BaseRetriever:
"""Create a Vector Database retriever.
Returns
-------
BaseRetriever: Qdrant + Cohere Embeddings Retriever
"""
embedding = CohereEmbeddings(model="embed-multilingual-v3.0")

qdrant_client = QdrantClient("http://localhost", port=6333, api_key=os.getenv("QDRANT_API_KEY"), prefer_grpc=False)
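The hunk is collapsed before get_retriever returns. A sketch of how the remainder presumably assembles the retriever from the client and embeddings above; the collection name "documents" and k=4 are illustrative assumptions, not values from the commit:

import os

from langchain_cohere import CohereEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient

def get_retriever_sketch():
    # Mirror the visible setup, then wrap the collection as a retriever.
    embedding = CohereEmbeddings(model="embed-multilingual-v3.0")
    client = QdrantClient("http://localhost", port=6333, api_key=os.getenv("QDRANT_API_KEY"), prefer_grpc=False)
    vector_db = Qdrant(client=client, collection_name="documents", embeddings=embedding)
    return vector_db.as_retriever(search_kwargs={"k": 4})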
@@ -70,6 +81,16 @@ def get_retriever() -> BaseRetriever:


def retrieve_documents(state: AgentState) -> AgentState:
"""Retrieve documents from the retriever.
Args:
----
state (AgentState): Graph State.
Returns:
-------
AgentState: Modified Graph State.
"""
retriever = get_retriever()
messages = convert_to_messages(state["messages"])
query = messages[-1].content
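The rest of this function is collapsed. Under the usual retriever pattern, the tail presumably fetches documents for the latest user message and writes them into the state; a sketch, not the verbatim commit:

def retrieve_documents_sketch(state: AgentState) -> AgentState:
    retriever = get_retriever()
    messages = convert_to_messages(state["messages"])
    query = messages[-1].content
    relevant_documents = retriever.invoke(query)  # assumed collapsed tail
    return {"query": query, "documents": relevant_documents}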
@@ -78,11 +99,21 @@ def retrieve_documents(state: AgentState) -> AgentState:


def retrieve_documents_with_chat_history(state: AgentState) -> AgentState:
"""Retrieve documents from the retriever with chat history.
Args:
----
state (AgentState): Graph State.
Returns:
-------
AgentState: Modified Graph State.
"""
retriever = get_retriever()
model = llm.with_config(tags=["nostream"])

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
condense_question_chain = (CONDENSE_QUESTION_PROMPT | model | StrOutputParser()).with_config(
condense_question_prompt = PromptTemplate.from_template(REPHRASE_TEMPLATE)
condense_question_chain = (condense_question_prompt | model | StrOutputParser()).with_config(
run_name="CondenseQuestion",
)
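The collapsed remainder presumably pipes the condensed standalone question into the retriever, as in the classic condense-question pattern. A sketch; the input keys "question" and "chat_history" are assumptions about REPHRASE_TEMPLATE, and messages comes from convert_to_messages(state["messages"]) as in the sibling node:

conversation_chain = condense_question_chain | retriever
relevant_documents = conversation_chain.invoke(
    {"question": messages[-1].content, "chat_history": get_chat_history(messages[:-1])}
)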

@@ -96,6 +127,12 @@ def retrieve_documents_with_chat_history(state: AgentState) -> AgentState:
def route_to_retriever(
state: AgentState,
) -> Literal["retriever", "retriever_with_chat_history"]:
"""Route to the appropriate retriever based on the state.
Returns
-------
Literal["retriever", "retriever_with_chat_history"]: Choosen retriever method.
"""
# at this point in the graph execution there is exactly one (i.e. first) message from the user,
# so use basic retriever without chat history
if len(state["messages"]) == 1:
@@ -105,23 +142,35 @@ def route_to_retriever(


def get_chat_history(messages: Sequence[BaseMessage]) -> Sequence[BaseMessage]:
chat_history = []
for message in messages:
if (isinstance(message, AIMessage) and not message.tool_calls) or isinstance(message, HumanMessage):
chat_history.append({"content": message.content, "role": message.type})
return chat_history
"""Append the chat history to the messages.
Args:
----
messages (Sequence[BaseMessage]): Messages from the frontend.
Returns:
-------
Sequence[BaseMessage]: Chat history as Langchain messages.
"""
return [
{"content": message.content, "role": message.type}
for message in messages
if (isinstance(message, AIMessage) and not message.tool_calls) or isinstance(message, HumanMessage)
]
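An illustrative call (not in the commit) showing what the rewritten helper produces; AI messages with tool calls are dropped, everything else becomes a role/content dict:

from langchain_core.messages import AIMessage, HumanMessage

history = get_chat_history([
    HumanMessage(content="Who is Luke Skywalker's father?"),
    AIMessage(content="Anakin Skywalker."),
])
# -> [{"content": "Who is Luke Skywalker's father?", "role": "human"},
#     {"content": "Anakin Skywalker.", "role": "ai"}]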


def generate_response(state: AgentState, model: LanguageModelLike, prompt_template: str) -> AgentState:
"""Args:
"""Create a response from the model.
Args:
----
state (AgentState): _description_
model (LanguageModelLike): _description_
prompt_template (str): _description_.
state (AgentState): Graph State.
model (LanguageModelLike): Language Model.
prompt_template (str): Template for the prompt.
Returns
Returns:
-------
AgentState: _description_
AgentState: Modified Graph State.
"""
prompt = ChatPromptTemplate.from_messages(
[
@@ -146,23 +195,55 @@ def generate_response(state: AgentState, model: LanguageModelLike, prompt_templa


def generate_response_default(state: AgentState) -> AgentState:
"""Generate a response using non cohere model.
Args:
----
state (AgentState): Graph State.
Returns:
-------
AgentState: Modified Graph State.
"""
return generate_response(state, llm, RESPONSE_TEMPLATE)


def generate_response_cohere(state: AgentState) -> AgentState:
"""Generate a response using the Cohere model.
Args:
----
state (AgentState): Graph State.
Returns:
-------
AgentState: Modified Graph State.
"""
model = llm.bind(documents=state["documents"])
return generate_response(state, model, COHERE_RESPONSE_TEMPLATE)
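Here llm.bind(documents=...) pins the retrieved documents as a call-time argument, so the Cohere model receives them on every invocation and can ground its answer in them. A standalone equivalent; the model name is an assumption, not taken from the commit:

# Sketch; "command-r" is an assumed model name.
grounded_llm = ChatCohere(model="command-r").bind(documents=state["documents"])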


def route_to_response_synthesizer(state: AgentState, config: RunnableConfig) -> Literal["response_synthesizer", "response_synthesizer_cohere"]:
def route_to_response_synthesizer(state: AgentState, config: RunnableConfig) -> Literal["response_synthesizer", "response_synthesizer_cohere"]: # noqa: ARG001
"""Route to the appropriate response synthesizer based on the config.
Args:
----
state (AgentState): Graph State.
config (RunnableConfig): Runnable Config.
Returns:
-------
Literal["response_synthesizer", "response_synthesizer_cohere"]: Choosen response synthesizer method.
"""
model_name = config.get("configurable", {}).get("model_name", OPENAI_MODEL_KEY)
if model_name == COHERE_MODEL_KEY:
return "response_synthesizer_cohere"
else:
return "response_synthesizer"


def build_graph():
def build_graph() -> StateGraph:
"""Build the graph for the agent.
Returns
@@ -191,5 +272,6 @@ def build_graph():
return workflow.compile()
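The body of build_graph is collapsed above. Judging from the node and router names visible elsewhere in this diff, the wiring presumably resembles the following sketch (not a verbatim copy of the commit):

workflow = StateGraph(AgentState)
workflow.add_node("retriever", retrieve_documents)
workflow.add_node("retriever_with_chat_history", retrieve_documents_with_chat_history)
workflow.add_node("response_synthesizer", generate_response_default)
workflow.add_node("response_synthesizer_cohere", generate_response_cohere)
workflow.set_conditional_entry_point(route_to_retriever)  # first turn vs. follow-up
workflow.add_conditional_edges("retriever", route_to_response_synthesizer)
workflow.add_conditional_edges("retriever_with_chat_history", route_to_response_synthesizer)
workflow.add_edge("response_synthesizer", END)
workflow.add_edge("response_synthesizer_cohere", END)
graph = workflow.compile()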


# answer = graph.invoke({"messages": [{"role": "human", "content": "wer ist der vater von luke skywalker?"}, {"role": "assistant", "content": "Der Vater von Luke Skywalker war Anakin Skywalker."}, {"role": "human", "content": "und wer ist seine mutter?"}]})
# answer = graph.invoke({"messages": [{"role": "human", "content": "wer ist der vater von luke skywalker?"}, {"role": "assistant", "content": "Der Vater von Luke
# Skywalker war Anakin Skywalker."}, {"role": "human", "content": "und wer ist seine mutter?"}]})
# logger.info(answer)
15 changes: 15 additions & 0 deletions agent/scripts/visualize_graph.py
@@ -0,0 +1,15 @@
"""Visualizing the Langgraph Graph."""
from pathlib import Path

from langchain_core.runnables.graph import MermaidDrawMethod

from agent.backend.graph import build_graph

workflow = build_graph()


mermaid_graph = workflow.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)

# save as png
with Path("graph.png").open("wb") as f:
f.write(mermaid_graph)
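MermaidDrawMethod.API renders the PNG through a remote Mermaid service, so this script needs network access. For offline use, the raw Mermaid definition can be written out instead; a sketch:

# Offline alternative: dump the Mermaid source as text.
mermaid_source = workflow.get_graph().draw_mermaid()
Path("graph.mmd").write_text(mermaid_source)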
11 changes: 11 additions & 0 deletions agent/utils/utility.py
@@ -1,5 +1,6 @@
"""Utility module."""
import uuid
from collections.abc import Sequence
from pathlib import Path

from langchain.prompts import PromptTemplate
@@ -181,6 +182,16 @@ def extract_text_from_langchain_documents(docs: list[Document]) -> str:


def format_docs_for_citations(docs: Sequence[Document]) -> str:
"""Format the documents for citations.
Args:
----
docs (Sequence[Document]): Langchain documents from a vectordatabase.
Returns:
-------
str: Combined documents in a format suitable for citations.
"""
formatted_docs = []
for i, doc in enumerate(docs):
doc_string = f"<doc id='{i}'>{doc.page_content}</doc>"
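The loop is collapsed here; presumably each doc_string is appended and the pieces are joined, as in this compact sketch (an assumption, not the verbatim commit):

def format_docs_for_citations_sketch(docs: Sequence[Document]) -> str:
    # Equivalent one-liner for the collapsed append-and-join loop.
    return "\n".join(f"<doc id='{i}'>{doc.page_content}</doc>" for i, doc in enumerate(docs))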
Binary file added graph.png
