
Merge pull request #76 from langchain-ai/harrison/dnd
add dnd bot
nfcampos authored Nov 29, 2023
2 parents 896cc5a + 0d4f3fd commit 2666579
Showing 12 changed files with 633 additions and 305 deletions.
3 changes: 3 additions & 0 deletions backend/Makefile
@@ -13,6 +13,9 @@ build_ui:
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

start:
poetry run uvicorn app.server:app --reload --port 8100

test:
# We need to update handling of env variables for tests
YDC_API_KEY=placeholder OPENAI_API_KEY=placeholder poetry run pytest $(TEST_FILE)
6 changes: 4 additions & 2 deletions backend/packages/agent-executor/agent_executor/checkpoint.py
@@ -1,7 +1,7 @@
import os
import pickle
from functools import partial
from typing import Any, Mapping, Sequence
from typing import Any, Mapping

from langchain.pydantic_v1 import Field
from langchain.schema.runnable import RunnableConfig
@@ -35,21 +35,23 @@ class Config:
arbitrary_types_allowed = True

@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
def config_specs(self) -> list[ConfigurableFieldSpec]:
return [
ConfigurableFieldSpec(
id="user_id",
annotation=str,
name="User ID",
description=None,
default=None,
is_shared=True,
),
ConfigurableFieldSpec(
id="thread_id",
annotation=str,
name="Thread ID",
description=None,
default="",
is_shared=True,
),
]

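For context: the two specs above surface as keys under config["configurable"] when a runnable built with this checkpoint is invoked, and marking both is_shared=True lets them pass through configurable alternatives unprefixed. A minimal sketch of supplying them at call time — the checkpointed_bot object and the id values are illustrative placeholders, not part of this diff:

from langchain.schema.messages import HumanMessage

# `checkpointed_bot` stands in for any runnable built with checkpoint=RedisCheckpoint();
# it is a placeholder for illustration, not something defined in this PR.
config = {
    "configurable": {
        "user_id": "user-123",      # matches ConfigurableFieldSpec(id="user_id")
        "thread_id": "thread-456",  # matches ConfigurableFieldSpec(id="thread_id")
    }
}
# checkpointed_bot.invoke({"messages": [HumanMessage(content="hello")]}, config=config)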
136 changes: 136 additions & 0 deletions backend/packages/agent-executor/agent_executor/dnd.py
@@ -0,0 +1,136 @@
import json

from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.pydantic_v1 import BaseModel, Field
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from permchain import BaseCheckpointAdapter, Channel, Pregel
from permchain.channels import LastValue, Topic

character_system_msg = """You are a dungeon master for a game of dungeons and dragons.
You are interacting with the first (and only) player in the game. \
Your job is to collect all needed information about their character. This will be used in the quest. \
Feel free to ask them as many questions as needed to get to the relevant information.
The relevant information is:
- Character's name
- Character's race (or species)
- Character's class
- Character's alignment
Once you have gathered enough information, write that info to `notebook`."""


class CharacterNotebook(BaseModel):
"""Notebook to write information to"""

player_info: str = Field(
description="Information about a player that you will remember over time"
)


character_prompt = ChatPromptTemplate.from_messages(
[("system", character_system_msg), MessagesPlaceholder(variable_name="messages")]
)

gameplay_system_msg = """You are a dungeon master for a game of dungeons and dragons.
You are leading a quest of one person. Their character description is here:
{character}
A summary of the game state is here:
{state}"""

game_prompt = ChatPromptTemplate.from_messages(
[("system", gameplay_system_msg), MessagesPlaceholder(variable_name="messages")]
)


class StateNotebook(BaseModel):
"""Notebook to write information to"""

state: str = Field(description="Information about the current game state")


state_prompt = ChatPromptTemplate.from_messages(
[
("system", gameplay_system_msg),
MessagesPlaceholder(variable_name="messages"),
(
"human",
"If any updates to the game state are neccessary, please update the state notebook. If none are, just say no.",
),
]
)


def _maybe_update_state(message: AnyMessage):
if "function_call" in message.additional_kwargs:
return Channel.write_to(
"messages",
state=json.loads(message.additional_kwargs["function_call"]["arguments"])[
"state"
],
)


def _maybe_update_character(message: AnyMessage):
if "function_call" in message.additional_kwargs:
args = json.loads(message.additional_kwargs["function_call"]["arguments"])
return Channel.write_to(
messages=AIMessage(content="Ready for the quest?"),
character=args["player_info"],
)


def create_dnd_bot(llm: BaseChatModel, checkpoint: BaseCheckpointAdapter):
character_model = llm.bind(
functions=[convert_pydantic_to_openai_function(CharacterNotebook)],
)
game_chain = game_prompt | llm | Channel.write_to("messages", check_update=True)
state_model = llm.bind(
functions=[convert_pydantic_to_openai_function(StateNotebook)],
stream=False,
)
state_chain = (
Channel.subscribe_to(["check_update"]).join(["messages", "character", "state"])
| state_prompt
| state_model
| _maybe_update_state
)
character_chain = (
character_prompt
| character_model
| Channel.write_to("messages")
| _maybe_update_character
)

def _route_to_chain(_input):
messages = _input["messages"]
if not messages:
return
if not _input["character"] and isinstance(messages[-1], HumanMessage):
return character_chain
elif isinstance(messages[-1], HumanMessage):
return game_chain

executor = (
Channel.subscribe_to(["messages"]).join(["character", "state"])
| _route_to_chain
)
dnd = Pregel(
chains={"executor": executor, "update_state": state_chain},
channels={
"messages": Topic(AnyMessage, accumulate=True),
"character": LastValue(str),
"state": LastValue(str),
"check_update": LastValue(bool),
},
input=["messages"],
output=["messages"],
checkpoint=checkpoint,
)
return dnd
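
A hedged usage sketch for create_dnd_bot, assuming Pregel exposes the standard Runnable invoke interface and that the Redis checkpoint's user_id / thread_id keys are supplied through config (see checkpoint.py above); the id values and the opening message are illustrative placeholders:

from agent_executor.checkpoint import RedisCheckpoint
from agent_executor.dnd import create_dnd_bot
from langchain.chat_models import ChatOpenAI
from langchain_core.messages import HumanMessage

bot = create_dnd_bot(
    ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0, streaming=True),
    checkpoint=RedisCheckpoint(),
)
config = {"configurable": {"user_id": "user-123", "thread_id": "thread-456"}}
# Human turns route to character_chain until the CharacterNotebook is written;
# later turns route to game_chain, and state_chain updates "state" whenever the
# game chain writes check_update.
result = bot.invoke(
    {"messages": [HumanMessage(content="Hi, I'd like to create a character.")]},
    config=config,
)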
32 changes: 2 additions & 30 deletions backend/packages/agent-executor/agent_executor/permchain.py
@@ -3,19 +3,7 @@
from typing import Sequence

from langchain.schema.agent import AgentAction, AgentActionMessageLog, AgentFinish
from langchain.schema.messages import (
AnyMessage,
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
)
from langchain.schema.messages import AIMessage, AnyMessage, FunctionMessage
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
@@ -28,29 +16,13 @@
from permchain.checkpoint.base import BaseCheckpointAdapter


def map_chunk_to_msg(chunk: BaseMessageChunk) -> BaseMessage:
if not isinstance(chunk, BaseMessageChunk):
return chunk
args = {k: v for k, v in chunk.__dict__.items() if k != "type"}
if isinstance(chunk, HumanMessageChunk):
return HumanMessage(**args)
elif isinstance(chunk, AIMessageChunk):
return AIMessage(**args)
elif isinstance(chunk, FunctionMessageChunk):
return FunctionMessage(**args)
elif isinstance(chunk, ChatMessageChunk):
return ChatMessage(**args)
else:
raise ValueError(f"Unknown chunk type: {chunk}")


def _create_agent_message(
output: AgentAction | AgentFinish
) -> list[AnyMessage] | AnyMessage:
if isinstance(output, AgentAction):
if isinstance(output, AgentActionMessageLog):
output.message_log[-1].additional_kwargs["agent"] = output
messages = [map_chunk_to_msg(m) for m in output.message_log]
messages = output.message_log
output.message_log = [] # avoid circular reference for json dumps
return messages
else:
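For illustration, a hedged sketch of the simplified _create_agent_message path above: the action's message_log is now returned as-is (no chunk-to-message conversion), the action is stashed on the last message's additional_kwargs, and message_log is cleared to avoid a circular reference when dumping to JSON. The function is private and the values below are placeholders, so treat this as a sketch rather than a supported API:

from agent_executor.permchain import _create_agent_message
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import AIMessage

action = AgentActionMessageLog(
    tool="search",
    tool_input={"query": "dragons"},
    log="Invoking search",
    message_log=[
        AIMessage(
            content="",
            additional_kwargs={
                "function_call": {"name": "search", "arguments": '{"query": "dragons"}'}
            },
        )
    ],
)
messages = _create_agent_message(action)
assert messages[-1].additional_kwargs["agent"] is action  # action attached to the last message
assert action.message_log == []  # cleared so json dumps do not recurse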
34 changes: 33 additions & 1 deletion backend/packages/gizmo-agent/gizmo_agent/main.py
@@ -1,7 +1,10 @@
import os
from typing import Any, Mapping, Optional, Sequence

from agent_executor.checkpoint import RedisCheckpoint
from agent_executor.dnd import create_dnd_bot
from agent_executor.permchain import get_agent_executor
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema.messages import AnyMessage
from langchain.schema.runnable import (
@@ -82,6 +85,27 @@ class AgentOutput(BaseModel):
messages: Sequence[AnyMessage] = Field(..., extra={"widget": {"type": "chat"}})


dnd_llm = ChatOpenAI(
model="gpt-3.5-turbo-1106", temperature=0, streaming=True
).configurable_alternatives(
ConfigurableField(id="llm", name="LLM"),
default_key="gpt-35-turbo",
azure_openai=AzureChatOpenAI(
temperature=0,
deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
openai_api_base=os.environ["AZURE_OPENAI_API_BASE"],
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
streaming=True,
),
)


dnd_bot = create_dnd_bot(dnd_llm, checkpoint=RedisCheckpoint()).with_types(
input_type=AgentInput, output_type=AgentOutput
)


agent = (
ConfigurableAgent(
agent=GizmoAgentType.GPT_35_TURBO,
@@ -92,14 +116,22 @@ class AgentOutput(BaseModel):
.configurable_fields(
agent=ConfigurableField(id="agent_type", name="Agent Type"),
system_message=ConfigurableField(id="system_message", name="System Message"),
assistant_id=ConfigurableField(id="assistant_id", name="Assistant ID"),
assistant_id=ConfigurableField(
id="assistant_id", name="Assistant ID", is_shared=True
),
tools=ConfigurableFieldMultiOption(
id="tools",
name="Tools",
options=TOOL_OPTIONS,
default=[],
),
)
.configurable_alternatives(
ConfigurableField(id="type", name="Bot Type"),
default_key="agent",
prefix_keys=True,
dungeons_and_dragons=dnd_bot,
)
.with_types(input_type=AgentInput, output_type=AgentOutput)
)

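For context, a hedged sketch of how the new alternative is selected at request time, assuming the standard configurable_alternatives mechanics: setting "type" to "dungeons_and_dragons" routes the call to dnd_bot, while the default "agent" key keeps the existing gizmo agent. The import requires the module's OpenAI/Azure/Redis environment variables, and the id values are illustrative placeholders:

from gizmo_agent.main import agent
from langchain.schema.messages import HumanMessage

config = {
    "configurable": {
        "type": "dungeons_and_dragons",  # omit (or use "agent") for the default gizmo agent
        "user_id": "user-123",           # shared keys from RedisCheckpoint's config specs
        "thread_id": "thread-456",
    }
}
agent.invoke({"messages": [HumanMessage(content="Let's play!")]}, config=config)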
(Remaining changed files not shown.)
