Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions ai_chatbots/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,20 +652,17 @@ def _identify_new_messages(
}

# Find messages with IDs that don't exist in previous chat
return [
msg for msg in filtered_messages if msg.get("id") not in existing_message_ids
]
return [msg for msg in filtered_messages if msg["id"] not in existing_message_ids]


def _create_langchain_message(message: dict) -> dict:
"""Create a message in LangChain format."""
message_id = str(uuid4())
return {
"id": ["langchain", "schema", "messages", message["type"]],
"lc": 1,
"type": "constructor",
"kwargs": {
"id": message_id,
"id": message["id"],
"type": message["type"].lower().replace("message", ""),
"content": message["content"],
},
Expand All @@ -684,7 +681,10 @@ def _create_checkpoint_data(checkpoint_id: str, step: int, chat_data: dict) -> d
"__start__": {"__start__": step + 1} if step >= 0 else {},
},
"channel_values": {
"messages": chat_data.get("chat_history", []),
"messages": [
_create_langchain_message(msg)
for msg in chat_data.get("chat_history", [])
],
# Preserve tutor-specific data
"intent_history": chat_data.get("intent_history"),
"assessment_history": chat_data.get("assessment_history"),
Expand Down Expand Up @@ -732,6 +732,7 @@ def create_tutor_checkpoints(
# Parse and validate chat data
chat_data = json.loads(chat_json) if isinstance(chat_json, str) else chat_json
messages = chat_data.get("chat_history", [])

if not messages:
return []

Expand Down Expand Up @@ -759,16 +760,12 @@ def create_tutor_checkpoints(
if not new_messages:
return [] # No new messages to checkpoint

# Calculate starting step based on length of previous chat history
previous_messages = (
json.loads(previous_chat_json).get("chat_history", [])
if previous_chat_json
else []
)
step = len(previous_messages)
# Calculate starting step based on previous checkpoint if any
step = latest_checkpoint.metadata.get("step", -1) + 1 if latest_checkpoint else 0
checkpoints_created = []

# Create checkpoints only for the NEW messages

for message in new_messages:
checkpoint_id = str(uuid4())

Expand Down
67 changes: 65 additions & 2 deletions ai_chatbots/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,7 +1376,7 @@ def test_create_tutor_checkpoints_with_tool_messages():
chat_json = """
{
"chat_history": [
{"type": "ToolMessage", "content": "Tool result"},
{"type": "ToolMessage", "content": "Tool result", "id": "msg0"},
{"type": "HumanMessage", "content": "Testing 123", "id": "msg1"}
]
}
Expand Down Expand Up @@ -1611,11 +1611,74 @@ def test_create_tutor_checkpoints_includes_metadata():

# The writes should include the tutor metadata
key = next(iter(writes.keys()))
assert key == "__start__" if idx == 0 else "agent"
assert key == ("__start__" if idx == 0 else "agent")
start_writes = writes.get(key)
assert start_writes["user_id"] == "test_user_123"
assert start_writes["course_id"] == "course_456"
assert start_writes["custom_field"] == "custom_value"

# Messages should still be present
assert len(start_writes["messages"]) == 1
assert (
start_writes["messages"][0]
== checkpoint.checkpoint["channel_values"]["messages"][idx]
)


def test_create_langchain_message_id_handling():
    """Test that _create_langchain_message preserves an existing message ID."""
    from ai_chatbots.api import _create_langchain_message

    # Message already carries an id; the conversion must keep it verbatim.
    message = {"type": "HumanMessage", "content": "Test message", "id": str(uuid4())}

    result = _create_langchain_message(message)

    # The original id must be preserved (not replaced with a new uuid),
    # and the type is lowercased with the "message" suffix stripped.
    assert result["kwargs"]["id"] == message["id"]
    assert result["kwargs"]["type"] == "human"
    assert result["kwargs"]["content"] == "Test message"


@pytest.mark.django_db
@pytest.mark.parametrize("has_previous_checkpoint", [True, False])
def test_create_tutor_checkpoints_step_calculation(has_previous_checkpoint):
    """Test that step calculation works correctly with or without previous checkpoints."""
    thread_id = str(uuid4())
    factories.UserChatSessionFactory.create(thread_id=thread_id)

    def _chat_payload(first_message):
        # Two-message exchange (human + AI), each with a unique id.
        return {
            "chat_history": [
                {"type": "HumanMessage", "content": first_message, "id": str(uuid4())},
                {"type": "AIMessage", "content": "Response", "id": str(uuid4())},
            ],
            "user_id": "test_user",
            "course_id": "test_course",
        }

    initial_chat_data = None
    if has_previous_checkpoint:
        # Seed the thread with an earlier exchange so later steps must offset.
        initial_chat_data = _chat_payload("Initial message")
        earlier_checkpoints = create_tutor_checkpoints(thread_id, initial_chat_data)
        assert len(earlier_checkpoints) == 2
        assert earlier_checkpoints[0].metadata["step"] == 0

    chat_data = _chat_payload("New message")
    result = create_tutor_checkpoints(
        thread_id, chat_data, previous_chat_json=initial_chat_data
    )

    # With a prior exchange the new steps continue from 2; otherwise from 0.
    step_offset = 2 if has_previous_checkpoint else 0
    assert len(result) == 2
    assert result[0].metadata["step"] == step_offset
    assert result[1].metadata["step"] == step_offset + 1
153 changes: 153 additions & 0 deletions ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""Management command for backpopulating TutorBotOutput checkpoints"""

import json
from uuid import UUID, uuid5

from django.conf import settings
from django.core.management import BaseCommand
from django.db import transaction

from ai_chatbots.api import create_tutor_checkpoints
from ai_chatbots.models import DjangoCheckpoint, TutorBotOutput
from main.utils import chunks


def same_message(msg1, msg2):
    """Return True when two chat messages agree on both role and content."""
    return all(msg1.get(key) == msg2.get(key) for key in ("role", "content"))


def add_message_ids(thread_id: str, output_id: int, messages: list[dict]) -> list[dict]:
    """
    Ensure every chat message has a unique, stable ID.

    Messages that already carry an ``id`` are left untouched. Generated IDs
    use ``uuid5`` keyed on the thread UUID, the TutorBotOutput pk, and the
    message type/content, so re-running the backpopulate command assigns the
    same IDs (idempotent). Messages are mutated in place.

    Args:
        thread_id: UUID string of the chat thread (used as uuid5 namespace).
        output_id: Primary key of the TutorBotOutput these messages belong to.
        messages: Chat history entries — plain dicts or LangChain message
            objects (presumably with ``type``/``content`` attributes — TODO
            confirm against actual stored data).

    Returns:
        The same ``messages`` list, every entry now carrying an ID.
    """
    namespace = UUID(thread_id)
    for message in messages:
        if isinstance(message, dict):
            if not message.get("id"):
                message["id"] = str(
                    uuid5(
                        namespace,
                        f'{output_id}_{message["type"]}_{message["content"]}',
                    )
                )
        elif not getattr(message, "id", None):
            # LangChain message object: must use attribute access here.
            # BUG FIX: the previous version built the uuid5 seed with
            # message["type"]/message["content"] before the isinstance check,
            # which raised TypeError for non-dict messages.
            message.id = str(
                uuid5(namespace, f"{output_id}_{message.type}_{message.content}")
            )
    return messages


def _process_thread_batch(thread_ids_batch, *, overwrite: bool = False) -> int:
    """
    Backfill checkpoints for a batch of chat thread ids.

    Each thread is handled in its own transaction: optionally wipe its
    existing DjangoCheckpoint rows, skip it entirely if checkpoints already
    exist, otherwise walk its TutorBotOutput rows oldest-first, stamp message
    ids into the stored chat JSON, persist it, and create checkpoints for the
    messages new in each output relative to the previous one.

    Args:
        thread_ids_batch: Iterable of thread_id strings to process.
        overwrite: When True, delete any pre-existing checkpoints for each
            thread first so they are fully regenerated.

    Returns:
        Number of TutorBotOutput rows processed across the batch.
    """
    processed = 0
    for thread_id in thread_ids_batch:
        with transaction.atomic():
            if overwrite:
                # Delete existing checkpoints for this thread if overwriting;
                # the exists() check below then sees none and proceeds.
                DjangoCheckpoint.objects.filter(thread_id=thread_id).delete()
            # If any checkpoints already exist for this thread, skip it
            if DjangoCheckpoint.objects.filter(thread_id=thread_id).exists():
                continue

            # Use iterator to avoid loading all outputs into memory at once
            outputs = (
                TutorBotOutput.objects.filter(thread_id=thread_id)
                .order_by("id")
                .iterator()
            )
            previous_chat_json = None
            # iterate through all outputs instead of just the latest,
            # because some messages may have been truncated.
            for tutorbot_output in outputs:
                # Parse the chat data - handle both string and object formats
                if isinstance(tutorbot_output.chat_json, str):
                    chat_data = json.loads(tutorbot_output.chat_json)
                else:
                    chat_data = tutorbot_output.chat_json

                # Update tutorbot_output with message ids (deterministic, so
                # reruns assign the same ids), then persist the stamped JSON.
                chat_data["chat_history"] = add_message_ids(
                    thread_id, tutorbot_output.id, chat_data["chat_history"]
                )
                tutorbot_output.chat_json = json.dumps(chat_data)
                tutorbot_output.save(update_fields=["chat_json"])
                chat_data_str = json.dumps(chat_data)
                # previous_chat_json lets create_tutor_checkpoints checkpoint
                # only the messages added since the prior output.
                create_tutor_checkpoints(
                    thread_id, chat_data_str, previous_chat_json=previous_chat_json
                )
                previous_chat_json = chat_data_str
                processed += 1
    return processed


def convert_tutorbot_to_checkpoints() -> None:
    """
    Add message ids to all TutorBotOutput records and create
    DjangoCheckpoints for new messages in each.
    Memory-efficient version that processes in batches.
    """
    # Lazily stream distinct thread ids rather than loading them all at once.
    distinct_thread_ids = (
        TutorBotOutput.objects.only("thread_id")
        .values_list("thread_id", flat=True)
        .distinct()
    )
    # Hand the ids to the batch processor one chunk at a time.
    for batch in chunks(distinct_thread_ids, chunk_size=settings.QUERY_BATCH_SIZE):
        _process_thread_batch(batch)


class Command(BaseCommand):
    """
    Add missing TutorbotOutput checkpoints.
    """

    help = "Add missing TutorbotOutput checkpoints."

    def add_arguments(self, parser):
        """Register the --batch-size and --overwrite CLI options."""
        # Use getattr consistently so a missing setting can't crash the
        # help text (the default= line already guarded against that).
        default_batch_size = getattr(settings, "QUERY_BATCH_SIZE", 100)
        parser.add_argument(
            "--batch-size",
            type=int,
            default=default_batch_size,
            help=f"Checkpoint batch size (default: {default_batch_size})",
        )
        parser.add_argument(
            "--overwrite",
            dest="force_overwrite",
            action="store_true",
            help="Force regenerate existing TutorBotOutput checkpoints",
        )

    def handle(self, *args, **options):  # noqa: ARG002
        """Backpopulate checkpoints for every TutorBotOutput thread, in batches."""
        batch_size = options["batch_size"]
        overwrite = options["force_overwrite"]
        self.stdout.write(
            f"Starting tutor checkpoint backpopulate (batch size: {batch_size})..."
        )

        # Use iterator to avoid loading all thread_ids into memory
        thread_ids_qs = (
            TutorBotOutput.objects.only("thread_id")
            .values_list("thread_id", flat=True)
            .distinct()
        )
        # Process in batches using chunks utility.
        # BUG FIX: honor the --batch-size option; previously the option was
        # parsed and printed but settings.QUERY_BATCH_SIZE was always used
        # for chunking, silently ignoring the user's value.
        total_processed = 0
        for idx, thread_ids_chunk in enumerate(
            chunks(thread_ids_qs, chunk_size=batch_size)
        ):
            self.stdout.write(
                f"Processing batch {idx + 1} (size: {len(thread_ids_chunk)})..."
            )
            processed = _process_thread_batch(thread_ids_chunk, overwrite=overwrite)
            total_processed += processed

        if total_processed == 0:
            self.stdout.write("No TutorBotOutputs found that need backpopulating")
        else:
            self.stdout.write(f"Completed! Processed {total_processed} TutorBotOutputs")

        return 0
Loading