Skip to content

Commit

Permalink
Save messages on success
Browse files Browse the repository at this point in the history
  • Loading branch information
VVoruganti committed Dec 18, 2023
1 parent cda27b5 commit c6d0aef
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from collections.abc import AsyncIterator
from cache import Conversation
from typing import List

from openai import BadRequestError

import sentry_sdk

load_dotenv()
Expand Down Expand Up @@ -52,11 +55,14 @@ def think(cls, cache: Conversation, input: str):
])
chain = thought_prompt | cls.llm

cache.add_message("thought", HumanMessage(content=input))

def save_new_messages(ai_response):
    """Persist this turn to the "thought" log only after the stream succeeds.

    Passed as the completion callback to `Streamable`, which invokes it with
    the fully accumulated AI output once the stream ends (on
    StopAsyncIteration). Deferring both writes here means a failed or
    content-filtered stream leaves no partial messages in the cache.

    ai_response: the complete accumulated AI text for this stream.
    NOTE(review): closes over `cache` and `input` from the enclosing
    `think` classmethod.
    """
    cache.add_message("thought", HumanMessage(content=input))
    cache.add_message("thought", AIMessage(content=ai_response))

return Streamable(
chain.astream({}, {"tags": ["thought"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}),
lambda thought: cache.add_message("thought", AIMessage(content=thought))
save_new_messages
)

@classmethod
Expand All @@ -75,11 +81,13 @@ def revise_thought(cls, cache: Conversation, input: str, thought: str):
])
chain = messages | cls.llm

cache.add_message("thought_revision", HumanMessage(content=input))
def save_new_messages(ai_response):
    """Persist this turn to the "thought_revision" log after a successful stream.

    Completion callback handed to `Streamable`; it runs only when the stream
    finishes cleanly, so the human input and the revised thought are saved
    together or not at all.

    ai_response: the complete accumulated revised-thought text.
    NOTE(review): closes over `cache` and `input` from the enclosing
    `revise_thought` classmethod.
    """
    cache.add_message("thought_revision", HumanMessage(content=input))
    cache.add_message("thought_revision", AIMessage(content=ai_response))

return Streamable(
chain.astream({ "thought": thought, "retrieved_vectors": "\n".join(doc.page_content for doc in docs)}, {"tags": ["thought_revision"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}),
lambda thought_revision: cache.add_message("thought_revision", AIMessage(content=thought_revision)) # add the revised thought to thought memory
save_new_messages
)

@classmethod
Expand All @@ -93,11 +101,13 @@ def respond(cls, cache: Conversation, thought: str, input: str):
])
chain = response_prompt | cls.llm

cache.add_message("response", HumanMessage(content=input))
def save_new_messages(ai_response):
    """Persist this turn to the "response" log after a successful stream.

    Completion callback handed to `Streamable`; invoked with the full
    accumulated AI response once streaming ends, so nothing is written to
    the cache if the stream errors out part-way.

    ai_response: the complete accumulated response text.
    NOTE(review): closes over `cache` and `input` from the enclosing
    `respond` classmethod.
    """
    cache.add_message("response", HumanMessage(content=input))
    cache.add_message("response", AIMessage(content=ai_response))

return Streamable(
chain.astream({ "thought": thought }, {"tags": ["response"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}),
lambda response: cache.add_message("response", AIMessage(content=response))
save_new_messages
)

@classmethod
Expand Down Expand Up @@ -337,6 +347,12 @@ async def __anext__(self):
except StopAsyncIteration as e:
self.callback(self.content)
raise StopAsyncIteration
except BadRequestError as e:
if e.code == "content_filter":
self.stream_error = True
self.message = "Sorry, your message was flagged as inappropriate. Please try again."

return self.message
except Exception as e:
raise e

Expand Down

0 comments on commit c6d0aef

Please sign in to comment.