Skip to content

Commit

Permalink
feat(deriver) Turn off derivations by editing session medatadata
Browse files Browse the repository at this point in the history
  • Loading branch information
Vineeth Voruganti committed Sep 14, 2024
1 parent 22f68fb commit 560e4a6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
12 changes: 8 additions & 4 deletions src/deriver/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ async def process_ai_message(
result = await db.execute(messages_stmt)
messages = result.scalars().all()[::-1]

chat_history_str = "\n".join([
f"human: {m.content}" if m.is_user else f"ai: {m.content}" for m in messages
])
chat_history_str = "\n".join(
[f"human: {m.content}" if m.is_user else f"ai: {m.content}" for m in messages]
)
# append current message to chat history
chat_history_str = f"{chat_history_str}\nai: {content}"

Expand Down Expand Up @@ -203,6 +203,8 @@ async def process_user_message(
Process a user message. If there are revised user predictions to run VoE against, run it. Otherwise pass.
"""
rprint(f"[orange1]Processing User Message: {content}")

# Get the AI message directly preceding this User message
subquery = (
select(models.Message.created_at)
.where(models.Message.id == message_id)
Expand All @@ -223,6 +225,7 @@ async def process_user_message(

if ai_message and ai_message.content:
rprint(f"[orange1]AI Message: {ai_message.content}")
# Get the User Thought Revision Associated with this AI Message
metamessages_stmt = (
select(models.Metamessage)
.where(models.Metamessage.message_id == ai_message.id)
Expand Down Expand Up @@ -292,7 +295,8 @@ async def process_user_message(
)
rprint(f"[orange1]Returned Document: {doc.content}")
else:
raise Exception("\033[91mUser Thought Prediction Revision NOT READY YET")
rprint("[red] No Prediction Associated with this Message")
return
else:
rprint("[red]No AI message before this user message[/red]")
return
Expand Down
2 changes: 1 addition & 1 deletion src/deriver/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def schedule_session(
async with semaphore, SessionLocal() as db:
try:
available_slots = semaphore._value
print(available_slots)
# print(available_slots)
new_sessions = await get_available_sessions(db, available_slots)

if new_sessions:
Expand Down
17 changes: 15 additions & 2 deletions src/routers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@

async def enqueue(payload: dict):
async with SessionLocal() as db:
# Get Session and Check metadata
session = await crud.get_session(
db,
app_id=payload["app_id"],
user_id=payload["user_id"],
session_id=payload["session_id"],
)
# Check if metadata has a "deriver" key
deriver_enabled = session.h_metadata.get("deriver_enabled")
if deriver_enabled is not None and deriver_enabled is not True:
print("=====================")
print(f"Deriver is not enabled on session {payload['session_id']}")
print("=====================")
# If deriver is not enabled, do not enqueue
return
try:
processed_payload = {
k: str(v) if isinstance(v, uuid.UUID) else v for k, v in payload.items()
Expand Down Expand Up @@ -68,8 +83,6 @@ async def create_message_for_session(
honcho_message = await crud.create_message(
db, message=message, app_id=app_id, user_id=user_id, session_id=session_id
)
print("=======")
print("Should be enqueued")
payload = {
"app_id": app_id,
"user_id": user_id,
Expand Down

0 comments on commit 560e4a6

Please sign in to comment.