add dnd bot #76

Merged
merged 7 commits on Nov 29, 2023
Changes from 5 commits
3 changes: 3 additions & 0 deletions backend/Makefile
@@ -13,6 +13,9 @@ build_ui:
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

start:
poetry run uvicorn app.server:app --reload --port 8100

test:
# We need to update handling of env variables for tests
YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run pytest $(TEST_FILE)
2 changes: 2 additions & 0 deletions backend/packages/agent-executor/agent_executor/checkpoint.py
@@ -43,13 +43,15 @@ def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
name="User ID",
description=None,
default=None,
is_shared=True,
),
ConfigurableFieldSpec(
id="thread_id",
annotation=str,
name="Thread ID",
description=None,
default="",
is_shared=True,
),
]

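Note: the new is_shared=True flags pair with the configurable_alternatives(prefix_keys=True) change in gizmo_agent/main.py further down; shared fields keep their plain user_id / thread_id config keys instead of being namespaced under the selected bot type. A minimal sketch of how a caller supplies them (placeholder values, not code from this PR), assuming the standard RunnableConfig shape:

# Sketch with placeholder values: the shared fields travel in the "configurable"
# section of the RunnableConfig and are presumably what the checkpoint adapter
# uses to key saved state.
config = {
    "configurable": {
        "user_id": "demo-user",
        "thread_id": "demo-thread",
    }
}
# checkpointed_runnable.invoke(some_input, config=config)  # hypothetical call
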
137 changes: 137 additions & 0 deletions backend/packages/agent-executor/agent_executor/dnd.py
@@ -0,0 +1,137 @@
import json

from permchain import Channel, Pregel, BaseCheckpointAdapter
from permchain.channels import Topic, LastValue
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.pydantic_v1 import BaseModel, Field
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
from langchain_core.language_models import BaseChatModel


character_system_msg = """You are a dungeon master for a game of dungeons and dragons.

You are interacting with the first (and only) player in the game. \
Your job is to collect all needed information about their character. This will be used in the quest. \
Feel free to ask them as many questions as needed to get to the relevant information.
The relevant information is:
- Character's name
- Character's race (or species)
- Character's class
- Character's alignment

Once you have gathered enough information, write that info to `notebook`."""


class CharacterNotebook(BaseModel):
"""Notebook to write information to"""

player_info: str = Field(
description="Information about a player that you will remember over time"
)


character_prompt = ChatPromptTemplate.from_messages(
[("system", character_system_msg), MessagesPlaceholder(variable_name="messages")]
)

gameplay_system_msg = """You are a dungeon master for a game of dungeons and dragons.

You are leading a quest of one person. Their character description is here:

{character}

A summary of the game state is here:

{state}"""

game_prompt = ChatPromptTemplate.from_messages(
[("system", gameplay_system_msg), MessagesPlaceholder(variable_name="messages")]
)


class StateNotebook(BaseModel):
"""Notebook to write information to"""

state: str = Field(description="Information about the current game state")


state_prompt = ChatPromptTemplate.from_messages(
[
("system", gameplay_system_msg),
MessagesPlaceholder(variable_name="messages"),
(
"human",
"If any updates to the game state are neccessary, please update the state notebook. If none are, just say no.",
),
]
)


def _maybe_update_state(message: AnyMessage):
if "function_call" in message.additional_kwargs:
return Channel.write_to(
"messages",
state=json.loads(message.additional_kwargs["function_call"]["arguments"])[
"state"
],
)


def _maybe_update_character(message: AnyMessage):
if "function_call" in message.additional_kwargs:
args = json.loads(message.additional_kwargs["function_call"]["arguments"])
return Channel.write_to(
messages=AIMessage(content="Ready for the quest?"),
character=args["player_info"],
)


def create_dnd_bot(llm: BaseChatModel, checkpoint: BaseCheckpointAdapter):
character_model = llm.bind(
functions=[convert_pydantic_to_openai_function(CharacterNotebook)],
)
game_chain = game_prompt | llm | Channel.write_to("messages", check_update=True)
state_model = llm.bind(
functions=[convert_pydantic_to_openai_function(StateNotebook)],
stream=False,
)
state_chain = (
Channel.subscribe_to(["check_update"]).join(["messages", "character", "state"])
| state_prompt
| state_model
| _maybe_update_state
)
character_chain = (
character_prompt
| character_model
| Channel.write_to("messages")
| _maybe_update_character
)

def _route_to_chain(_input):
messages = _input["messages"]
if not messages:
return
if not _input["character"] and isinstance(messages[-1], HumanMessage):
return character_chain
elif isinstance(messages[-1], HumanMessage):
return game_chain

executor = (
Channel.subscribe_to(["messages"]).join(["character", "state"])
| _route_to_chain
)
dnd = Pregel(
chains={"executor": executor, "update_state": state_chain},
channels={
"messages": Topic(AnyMessage, accumulate=True),
"character": LastValue(str),
"state": LastValue(str),
"check_update": LastValue(bool),
},
input=["messages"],
output=["messages"],
checkpoint=checkpoint,
)
return dnd
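
A minimal sketch of how the new bot could be exercised on its own (not code from this PR). It assumes Pregel apps behave like LangChain Runnables, which the .with_types(...) call in gizmo_agent/main.py below suggests; that dict input keyed by the "messages" channel is accepted, per input=["messages"]; and that RedisCheckpoint can reach a local Redis with default settings. The user and thread ids are placeholders.

from agent_executor.checkpoint import RedisCheckpoint
from agent_executor.dnd import create_dnd_bot
from langchain.chat_models import ChatOpenAI
from langchain_core.messages import HumanMessage

# Build the bot with a real LLM and the same Redis-backed checkpoint used in main.py.
llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0)
dnd = create_dnd_bot(llm, checkpoint=RedisCheckpoint())

# user_id / thread_id are the shared configurable fields declared in checkpoint.py.
config = {"configurable": {"user_id": "demo-user", "thread_id": "demo-thread"}}

# On the first turn the "character" channel is still empty, so _route_to_chain
# should dispatch this message to character_chain.
result = dnd.invoke(
    {"messages": [HumanMessage(content="Hi! I'd like to create a character.")]},
    config=config,
)
print(result)
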
32 changes: 2 additions & 30 deletions backend/packages/agent-executor/agent_executor/permchain.py
@@ -3,19 +3,7 @@
from typing import Sequence

from langchain.schema.agent import AgentAction, AgentActionMessageLog, AgentFinish
from langchain.schema.messages import (
AnyMessage,
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
)
from langchain.schema.messages import AnyMessage, AIMessage, FunctionMessage
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
@@ -28,29 +16,13 @@
from permchain.checkpoint.base import BaseCheckpointAdapter


def map_chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
if not isinstance(chunk, BaseMessageChunk):
return chunk
args = {k: v for k, v in chunk.__dict__.items() if k != "type"}
if isinstance(chunk, HumanMessageChunk):
return HumanMessage(**args)
elif isinstance(chunk, AIMessageChunk):
return AIMessage(**args)
elif isinstance(chunk, FunctionMessageChunk):
return FunctionMessage(**args)
elif isinstance(chunk, ChatMessageChunk):
return ChatMessage(**args)
else:
raise ValueError(f"Unknown chunk type: {chunk}")


def _create_agent_message(
output: AgentAction | AgentFinish
) -> list[AnyMessage] | AnyMessage:
if isinstance(output, AgentAction):
if isinstance(output, AgentActionMessageLog):
output.message_log[-1].additional_kwargs["agent"] = output
messages = [map_chunk_to_msg(m) for m in output.message_log]
messages = output.message_log
output.message_log = [] # avoid circular reference for json dumps
return messages
else:
34 changes: 33 additions & 1 deletion backend/packages/gizmo-agent/gizmo_agent/main.py
@@ -1,14 +1,17 @@
import os
from typing import Any, Mapping, Optional, Sequence

from agent_executor.checkpoint import RedisCheckpoint
from agent_executor.permchain import get_agent_executor
from agent_executor.dnd import create_dnd_bot
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema.messages import AnyMessage
from langchain.schema.runnable import (
ConfigurableField,
ConfigurableFieldMultiOption,
RunnableBinding,
)
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI

from gizmo_agent.agent_types import (
GizmoAgentType,
@@ -82,6 +85,27 @@ class AgentOutput(BaseModel):
messages: Sequence[AnyMessage] = Field(..., extra={"widget": {"type": "chat"}})


dnd_llm = ChatOpenAI(
model="gpt-3.5-turbo-1106", temperature=0, streaming=True
).configurable_alternatives(
ConfigurableField(id="llm", name="LLM"),
default_key="gpt-35-turbo",
azure_openai=AzureChatOpenAI(
temperature=0,
deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
openai_api_base=os.environ["AZURE_OPENAI_API_BASE"],
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
streaming=True,
),
)


dnd_bot = create_dnd_bot(dnd_llm, checkpoint=RedisCheckpoint()).with_types(
input_type=AgentInput, output_type=AgentOutput
)


agent = (
ConfigurableAgent(
agent=GizmoAgentType.GPT_35_TURBO,
@@ -92,14 +116,22 @@ class AgentOutput(BaseModel):
.configurable_fields(
agent=ConfigurableField(id="agent_type", name="Agent Type"),
system_message=ConfigurableField(id="system_message", name="System Message"),
assistant_id=ConfigurableField(id="assistant_id", name="Assistant ID"),
assistant_id=ConfigurableField(
id="assistant_id", name="Assistant ID", is_shared=True
),
tools=ConfigurableFieldMultiOption(
id="tools",
name="Tools",
options=TOOL_OPTIONS,
default=[],
),
)
.configurable_alternatives(
ConfigurableField(id="type", name="Bot Type"),
default_key="agent",
prefix_keys=True,
dungeons_and_dragons=dnd_bot,
)
.with_types(input_type=AgentInput, output_type=AgentOutput)
)
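
For reference, a sketch of how a client could select the new bot through the alternatives added above (placeholder ids, not code from this PR). It assumes the usual prefix_keys=True behavior: alternative-specific fields are namespaced under the "type" key, while is_shared fields such as assistant_id, user_id, and thread_id keep their plain ids.

from langchain_core.messages import HumanMessage

# Pick the D&D alternative by its key under the "type" ConfigurableField,
# and pass the shared checkpoint fields un-prefixed.
dnd_agent = agent.with_config(
    configurable={
        "type": "dungeons_and_dragons",
        "user_id": "demo-user",
        "thread_id": "demo-thread",
    }
)

dnd_agent.invoke({"messages": [HumanMessage(content="I want to play a halfling bard.")]})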
