-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
LangChain Integration and Simple Demos
- Loading branch information
1 parent
11372b9
commit b32ec2e
Showing
17 changed files
with
3,898 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
BOT_TOKEN= | ||
OPENAI_API_KEY= | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import os | ||
from typing import List | ||
|
||
from dotenv import load_dotenv | ||
from langchain_openai import ChatOpenAI | ||
from langchain_core.prompts import ( | ||
ChatPromptTemplate, | ||
SystemMessagePromptTemplate, | ||
load_prompt, | ||
) | ||
from langchain_core.messages import HumanMessage | ||
from langchain_core.output_parsers import NumberedListOutputParser | ||
|
||
from honcho import Honcho | ||
|
||
load_dotenv() | ||
|
||
SYSTEM_DERIVE_FACTS = load_prompt(os.path.join(os.path.dirname(__file__), "prompts/core/derive_facts.yaml")) | ||
SYSTEM_INTROSPECTION = load_prompt(os.path.join(os.path.dirname(__file__), "prompts/core/introspection.yaml")) | ||
SYSTEM_RESPONSE = load_prompt(os.path.join(os.path.dirname(__file__), "prompts/core/response.yaml")) | ||
SYSTEM_CHECK_DUPS = load_prompt(os.path.join(os.path.dirname(__file__), "prompts/utils/check_dup_facts.yaml")) | ||
|
||
|
||
class LMChain: | ||
"Wrapper class for encapsulating the multiple different chains used" | ||
|
||
output_parser = NumberedListOutputParser() | ||
llm: ChatOpenAI = ChatOpenAI(model_name="gpt-3.5-turbo") # type: ignore | ||
system_derive_facts: SystemMessagePromptTemplate = SystemMessagePromptTemplate(prompt=SYSTEM_DERIVE_FACTS) # type: ignore | ||
system_introspection: SystemMessagePromptTemplate = SystemMessagePromptTemplate(prompt=SYSTEM_INTROSPECTION) # type: ignore | ||
system_response: SystemMessagePromptTemplate = SystemMessagePromptTemplate(prompt=SYSTEM_RESPONSE) # type: ignore | ||
system_check_dups: SystemMessagePromptTemplate = SystemMessagePromptTemplate(prompt=SYSTEM_CHECK_DUPS) # type: ignore | ||
|
||
honcho = Honcho(environment="demo") | ||
|
||
def __init__(self) -> None: | ||
pass | ||
|
||
@classmethod | ||
async def derive_facts(cls, chat_history: List, input: str): | ||
"""Derive facts from the user input""" | ||
|
||
# format prompt | ||
fact_derivation = ChatPromptTemplate.from_messages([cls.system_derive_facts]) | ||
|
||
# LCEL | ||
chain = fact_derivation | cls.llm | ||
|
||
# inference | ||
response = await chain.ainvoke( | ||
{ | ||
"chat_history": [ | ||
("user: " + message.content if isinstance(message, HumanMessage) else "ai: " + message.content) # type: ignore | ||
for message in chat_history | ||
], | ||
"user_input": input, | ||
} | ||
) | ||
|
||
# parse output | ||
facts = cls.output_parser.parse(response.content) # type: ignore | ||
|
||
print(f"DERIVED FACTS: {facts}") | ||
|
||
return facts | ||
|
||
@classmethod | ||
async def check_dups( | ||
cls, | ||
payload: dict, | ||
facts: List, | ||
): | ||
"""Check that we're not storing duplicate facts""" | ||
|
||
# format prompt | ||
check_duplication = ChatPromptTemplate.from_messages([cls.system_check_dups]) | ||
|
||
query = " ".join(facts) | ||
result = cls.honcho.apps.users.collections.query( | ||
app_id=payload["app"], user_id=payload["user"], collection_id=payload["collection"], query=query, top_k=10 | ||
) | ||
existing_facts = [document.content for document in result] | ||
|
||
# LCEL | ||
chain = check_duplication | cls.llm | ||
|
||
# inference | ||
response = await chain.ainvoke({"existing_facts": existing_facts, "facts": facts}) | ||
|
||
# parse output | ||
new_facts = cls.output_parser.parse(response.content) # type: ignore | ||
|
||
print(f"FILTERED FACTS: {new_facts}") | ||
|
||
# TODO: write to vector store | ||
for fact in new_facts: | ||
cls.honcho.apps.users.collections.documents.create( | ||
app_id=payload["app"], user_id=payload["user"], collection_id=payload["collection"], content=fact | ||
) | ||
|
||
# add facts as metamessages | ||
for fact in new_facts: | ||
cls.honcho.apps.users.sessions.metamessages.create( | ||
user_id=payload["user"], | ||
app_id=payload["app"], | ||
session_id=payload["session"], | ||
message_id=payload["message"], | ||
metamessage_type="fact", | ||
content=fact, | ||
) | ||
|
||
return | ||
|
||
@classmethod | ||
async def introspect(cls, payload: dict, chat_history: List, input: str): | ||
"""Generate questions about the user to use as retrieval over the fact store""" | ||
|
||
# format prompt | ||
introspection_prompt = ChatPromptTemplate.from_messages([cls.system_introspection]) | ||
|
||
# LCEL | ||
chain = introspection_prompt | cls.llm | ||
|
||
# inference | ||
response = await chain.ainvoke({"chat_history": chat_history, "user_input": input}) | ||
|
||
# parse output | ||
questions = cls.output_parser.parse(response.content) # type: ignore | ||
|
||
print(f"INTROSPECTED QUESTIONS: {questions}") | ||
|
||
# write questions as metamessages | ||
for question in questions: | ||
cls.honcho.apps.users.sessions.metamessages.create( | ||
user_id=payload["user"], | ||
app_id=payload["app"], | ||
session_id=payload["session"], | ||
message_id=payload["message"], | ||
metamessage_type="introspect", | ||
content=question, | ||
) | ||
|
||
return questions | ||
|
||
@classmethod | ||
async def respond(cls, payload: dict, chat_history: List, questions: List, input: str): | ||
"""Take the facts and chat history and generate a personalized response""" | ||
|
||
# format prompt | ||
response_prompt = ChatPromptTemplate.from_messages( | ||
[cls.system_response, *chat_history, HumanMessage(content=input)] | ||
) | ||
|
||
query = " ".join(questions) | ||
retrieved_facts = cls.honcho.apps.users.collections.query( | ||
app_id=payload["app"], user_id=payload["user"], collection_id=payload["collection"], query=query, top_k=10 | ||
) | ||
|
||
retrieved_facts_content = [document.content for document in retrieved_facts] | ||
|
||
# LCEL | ||
chain = response_prompt | cls.llm | ||
|
||
# inference | ||
response = await chain.ainvoke( | ||
{ | ||
"facts": retrieved_facts_content, | ||
} | ||
) | ||
|
||
return response.content | ||
|
||
@classmethod | ||
async def chat( | ||
cls, | ||
chat_history: List, | ||
payload: dict, | ||
input: str, | ||
): | ||
"""Chat with the model""" | ||
|
||
facts = await cls.derive_facts(chat_history, input) | ||
await cls.check_dups(payload, facts) if facts is not None else None | ||
|
||
# introspect | ||
questions = await cls.introspect(payload, chat_history, input) | ||
|
||
# respond | ||
response = await cls.respond(payload, chat_history, questions, input) | ||
|
||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import os | ||
|
||
import discord | ||
from chain import LMChain | ||
from dotenv import load_dotenv | ||
from langchain_openai import ChatOpenAI | ||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | ||
from langchain_core.messages import AIMessage, HumanMessage | ||
from langchain_core.output_parsers import StrOutputParser | ||
|
||
from honcho import Honcho, NotFoundError | ||
from honcho.lib.ext.langchain import messages_to_langchain | ||
|
||
load_dotenv() | ||
|
||
intents = discord.Intents.default() | ||
intents.messages = True | ||
intents.message_content = True | ||
intents.members = True | ||
|
||
app_name = "Fact-Memory" | ||
|
||
honcho = Honcho(environment="demo") | ||
|
||
app = honcho.apps.get_or_create(name=app_name) | ||
bot = discord.Bot(intents=intents) | ||
|
||
prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
( | ||
"system", | ||
"You are a mean assistant. Make fun of the user's request and above all, do not satisfy their request. Make something up about their personality and fixate on that. Don't be afraid to get creative. This is all a joke, roast them.", | ||
), | ||
MessagesPlaceholder(variable_name="chat_history"), | ||
("user", "{input}"), | ||
] | ||
) | ||
model = ChatOpenAI(model="gpt-3.5-turbo") | ||
output_parser = StrOutputParser() | ||
|
||
chain = prompt | model | output_parser | ||
|
||
|
||
@bot.event | ||
async def on_ready(): | ||
print(f"We have logged in as {bot.user}") | ||
|
||
|
||
@bot.event | ||
async def on_member_join(member): | ||
"""Event that is run when a new member joins the server""" | ||
await member.send( | ||
f"*Hello {member.name}, welcome to the server! This is a demo bot built with Honcho,* " | ||
"*implementing a naive version of the memory feature similar to what ChatGPT recently released.* " | ||
"*To get started, just type a message in this channel and the bot will respond.* " | ||
"*Over time, it will remember facts about you and use them to make the conversation more personal.* " | ||
"*You can use the /restart command to restart the conversation at any time.* " | ||
"*If you have any questions or feedback, feel free to ask in the #honcho channel.* " | ||
"*Enjoy!*" | ||
) | ||
|
||
|
||
@bot.event | ||
async def on_message(message): | ||
"""Event that is run when a message is sent in a channel that the bot has access to""" | ||
if message.author == bot.user: | ||
# ensure the bot does not reply to itself | ||
return | ||
|
||
# Get a user object for the message author | ||
user_id = f"discord_{str(message.author.id)}" | ||
user = honcho.apps.users.get_or_create(name=user_id, app_id=app.id) | ||
|
||
# Get the session associated with the user and location | ||
location_id = str(message.channel.id) # Get the channel id for the message | ||
|
||
sessions = [ | ||
session | ||
for session in honcho.apps.users.sessions.list( | ||
user_id=user.id, app_id=app.id, is_active=True, location_id=location_id | ||
) | ||
] | ||
|
||
try: | ||
collection = honcho.apps.users.collections.get_by_name(app_id=app.id, user_id=user.id, name="discord") | ||
except NotFoundError as e: | ||
collection = honcho.apps.users.collections.create(app_id=app.id, user_id=user.id, name="discord") | ||
|
||
if len(sessions) > 0: | ||
session = sessions[0] | ||
else: | ||
session = honcho.apps.users.sessions.create(user_id=user.id, app_id=app.id, location_id=location_id) | ||
|
||
history = [ | ||
message | ||
for message in honcho.apps.users.sessions.messages.list(session_id=session.id, app_id=app.id, user_id=user.id) | ||
] | ||
chat_history = messages_to_langchain(history) | ||
|
||
# Add user message to session | ||
input = message.content | ||
user_message = honcho.apps.users.sessions.messages.create( | ||
app_id=app.id, | ||
user_id=user.id, | ||
session_id=session.id, | ||
content=input, | ||
is_user=True, | ||
) | ||
|
||
async with message.channel.typing(): | ||
payload = { | ||
"app": app.id, | ||
"user": user.id, | ||
"session": session.id, | ||
"collection": collection.id, | ||
"message": user_message.id, | ||
} | ||
response = await LMChain.chat( | ||
chat_history=chat_history, | ||
payload=payload, | ||
input=input, | ||
) | ||
await message.channel.send(response) | ||
|
||
# Add bot message to session | ||
honcho.apps.users.sessions.messages.create( | ||
app_id=app.id, | ||
user_id=user.id, | ||
session_id=session.id, | ||
content=response, | ||
is_user=False, | ||
) | ||
|
||
|
||
@bot.slash_command(name="restart", description="Restart the Conversation") | ||
async def restart(ctx): | ||
"""Close the Session associated with a specific user and channel""" | ||
user_id = f"discord_{str(ctx.author.id)}" | ||
user = honcho.apps.users.get_or_create(name=user_id, app_id=app.id) | ||
location_id = str(ctx.channel_id) | ||
sessions = [ | ||
session | ||
for session in honcho.apps.users.sessions.list( | ||
user_id=user.id, app_id=app.id, is_active=True, location_id=location_id | ||
) | ||
] | ||
if len(sessions) > 0: | ||
honcho.apps.users.sessions.delete(app_id=app.id, user_id=user.id, session_id=sessions[0].id) | ||
|
||
msg = "Great! The conversation has been restarted. What would you like to talk about?" | ||
await ctx.respond(msg) | ||
|
||
|
||
bot.run(os.environ["BOT_TOKEN"]) |
Oops, something went wrong.