From b29190de1e10fb8f0642b8c2be63be1cc5426435 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Fri, 19 Sep 2025 10:48:07 -0400 Subject: [PATCH 1/3] Tweak checkpoint structure, step value --- ai_chatbots/api.py | 18 +-- ai_chatbots/api_test.py | 6 +- .../backpopulate_tutor_checkpoints.py | 153 ++++++++++++++++++ 3 files changed, 166 insertions(+), 11 deletions(-) create mode 100644 ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py index 0146d377..9f3b743a 100644 --- a/ai_chatbots/api.py +++ b/ai_chatbots/api.py @@ -659,13 +659,12 @@ def _identify_new_messages( def _create_langchain_message(message: dict) -> dict: """Create a message in LangChain format.""" - message_id = str(uuid4()) return { "id": ["langchain", "schema", "messages", message["type"]], "lc": 1, "type": "constructor", "kwargs": { - "id": message_id, + "id": message.get("id", str(uuid4())), "type": message["type"].lower().replace("message", ""), "content": message["content"], }, @@ -684,7 +683,10 @@ def _create_checkpoint_data(checkpoint_id: str, step: int, chat_data: dict) -> d "__start__": {"__start__": step + 1} if step >= 0 else {}, }, "channel_values": { - "messages": chat_data.get("chat_history", []), + "messages": [ + _create_langchain_message(msg) + for msg in chat_data.get("chat_history", []) + ], # Preserve tutor-specific data "intent_history": chat_data.get("intent_history"), "assessment_history": chat_data.get("assessment_history"), @@ -732,6 +734,7 @@ def create_tutor_checkpoints( # Parse and validate chat data chat_data = json.loads(chat_json) if isinstance(chat_json, str) else chat_json messages = chat_data.get("chat_history", []) + if not messages: return [] @@ -759,13 +762,8 @@ def create_tutor_checkpoints( if not new_messages: return [] # No new messages to checkpoint - # Calculate starting step based on length of previous chat history - previous_messages = ( - json.loads(previous_chat_json).get("chat_history", []) - if previous_chat_json - else [] - ) - step = len(previous_messages) + # Calculate starting step based on previous checkpoint if any + step = latest_checkpoint.metadata.get("step", -1) + 1 if latest_checkpoint else 0 checkpoints_created = [] # Create checkpoints only for the NEW messages diff --git a/ai_chatbots/api_test.py b/ai_chatbots/api_test.py index 370add32..8f266bf2 100644 --- a/ai_chatbots/api_test.py +++ b/ai_chatbots/api_test.py @@ -1611,7 +1611,7 @@ def test_create_tutor_checkpoints_includes_metadata(): # The writes should include the tutor metadata key = next(iter(writes.keys())) - assert key == "__start__" if idx == 0 else "agent" + assert key == ("__start__" if idx == 0 else "agent") start_writes = writes.get(key) assert start_writes["user_id"] == "test_user_123" assert start_writes["course_id"] == "course_456" @@ -1619,3 +1619,7 @@ def test_create_tutor_checkpoints_includes_metadata(): # Messages should still be present assert len(start_writes["messages"]) == 1 + assert ( + start_writes["messages"][0] + == checkpoint.checkpoint["channel_values"]["messages"][idx] + ) diff --git a/ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py b/ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py new file mode 100644 index 00000000..4cbbf087 --- /dev/null +++ b/ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py @@ -0,0 +1,153 @@ +"""Management command for backpopulating TutorBotOutput checkpoints""" + +import json +from uuid import UUID, uuid5 + +from django.conf import settings +from django.core.management import BaseCommand +from django.db import transaction + +from ai_chatbots.api import create_tutor_checkpoints +from ai_chatbots.models import DjangoCheckpoint, TutorBotOutput +from main.utils import chunks + + +def same_message(msg1, msg2): + """Check if two messages are the same based on role and content""" + return msg1.get("role") == msg2.get("role") and msg1.get("content") == msg2.get( + "content" + ) + + +def add_message_ids(thread_id: str, output_id: int, messages: list[dict]) -> list[dict]: + """Add unique IDs to messages that don't have one""" + for message in messages: + # Handle both dict and LangChain message objects + message_id = str( + uuid5( + UUID(thread_id), f'{output_id}_{message["type"]}_{message["content"]}' + ) + ) + if isinstance(message, dict): + if not message.get("id"): + message["id"] = message_id + # # LangChain message object + elif not hasattr(message, "id") or not message.id: + message.id = message_id + return messages + + +def _process_thread_batch(thread_ids_batch, *, overwrite: bool = False) -> int: + """Process a batch of thread_ids to reduce memory usage""" + processed = 0 + for thread_id in thread_ids_batch: + with transaction.atomic(): + if overwrite: + # Delete existing checkpoints for this thread if overwriting + DjangoCheckpoint.objects.filter(thread_id=thread_id).delete() + # If any checkpoints already exist for this thread, skip it + if DjangoCheckpoint.objects.filter(thread_id=thread_id).exists(): + continue + + # Use iterator to avoid loading all outputs into memory at once + outputs = ( + TutorBotOutput.objects.filter(thread_id=thread_id) + .order_by("id") + .iterator() + ) + previous_chat_json = None + # iterate through all outputs instead of just the latest, + # because some messages may have been truncated. + for tutorbot_output in outputs: + # Parse the chat data - handle both string and object formats + if isinstance(tutorbot_output.chat_json, str): + chat_data = json.loads(tutorbot_output.chat_json) + else: + chat_data = tutorbot_output.chat_json + + # Update tutorbot_output with message ids + chat_data["chat_history"] = add_message_ids( + thread_id, tutorbot_output.id, chat_data["chat_history"] + ) + tutorbot_output.chat_json = json.dumps(chat_data) + tutorbot_output.save(update_fields=["chat_json"]) + chat_data_str = json.dumps(chat_data) + create_tutor_checkpoints( + thread_id, chat_data_str, previous_chat_json=previous_chat_json + ) + previous_chat_json = chat_data_str + processed += 1 + return processed + + +def convert_tutorbot_to_checkpoints() -> None: + """ + Add message ids to all TutorBotOutput records and create + DjangoCheckpoints for new messages in each. + Memory-efficient version that processes in batches. + """ + + # Use iterator to avoid loading all thread_ids into memory + thread_ids_qs = ( + TutorBotOutput.objects.only("thread_id") + .values_list("thread_id", flat=True) + .distinct() + ) + # Process in batches using chunks utility + for thread_ids_chunk in chunks(thread_ids_qs, chunk_size=settings.QUERY_BATCH_SIZE): + _process_thread_batch(thread_ids_chunk) + + +class Command(BaseCommand): + """ + Add missing TutorbotOutput checkpoints. + """ + + help = "Add missing TutorbotOutput checkpoints." + + def add_arguments(self, parser): + parser.add_argument( + "--batch-size", + type=int, + default=getattr(settings, "QUERY_BATCH_SIZE", 100), + help=f"Checkpoint batch size (default: {settings.QUERY_BATCH_SIZE})", + ) + parser.add_argument( + "--overwrite", + dest="force_overwrite", + action="store_true", + help="Force regenerate existing TutorBotOutput checkpoints", + ) + + def handle(self, *args, **options): # noqa: ARG002 + """Add missing writes and state attributes to checkpoint metadata""" + + batch_size = options["batch_size"] + overwrite = options["force_overwrite"] + self.stdout.write( + f"Starting tutor checkpoint backpopulate (batch size: {batch_size})..." + ) + + # Use iterator to avoid loading all thread_ids into memory + thread_ids_qs = ( + TutorBotOutput.objects.only("thread_id") + .values_list("thread_id", flat=True) + .distinct() + ) + # Process in batches using chunks utility + total_processed = 0 + for idx, thread_ids_chunk in enumerate( + chunks(thread_ids_qs, chunk_size=settings.QUERY_BATCH_SIZE) + ): + self.stdout.write( + f"Processing batch {idx + 1} (size: {len(thread_ids_chunk)})..." + ) + processed = _process_thread_batch(thread_ids_chunk, overwrite=overwrite) + total_processed += processed + + if total_processed == 0: + self.stdout.write("No TutorBotOutputs found that need backpopulating") + else: + self.stdout.write(f"Completed! Processed {total_processed} TutorBotOutputs") + + return 0 From 8f94a77f8ce6996707414b5737392b1d7c6b0de2 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Fri, 19 Sep 2025 11:19:36 -0400 Subject: [PATCH 2/3] Additional tests --- ai_chatbots/api_test.py | 72 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/ai_chatbots/api_test.py b/ai_chatbots/api_test.py index 8f266bf2..f83070eb 100644 --- a/ai_chatbots/api_test.py +++ b/ai_chatbots/api_test.py @@ -1623,3 +1623,75 @@ def test_create_tutor_checkpoints_includes_metadata(): start_writes["messages"][0] == checkpoint.checkpoint["channel_values"]["messages"][idx] ) + + +@pytest.mark.parametrize("has_existing_id", [True, False]) +def test_create_langchain_message_id_handling(has_existing_id): + """Test that _create_langchain_message preserves existing IDs or generates new ones.""" + from ai_chatbots.api import _create_langchain_message + + existing_id = str(uuid4()) + message = { + "type": "HumanMessage", + "content": "Test message", + } + + if has_existing_id: + message["id"] = existing_id + + result = _create_langchain_message(message) + + if has_existing_id: + assert result["kwargs"]["id"] == existing_id + else: + assert len(result["kwargs"]["id"]) == 36 # UUID length + assert result["kwargs"]["id"] != existing_id + + assert result["kwargs"]["type"] == "human" + assert result["kwargs"]["content"] == "Test message" + + +@pytest.mark.django_db +@pytest.mark.parametrize("has_previous_checkpoint", [True, False]) +def test_create_tutor_checkpoints_step_calculation(has_previous_checkpoint): + """Test that step calculation works correctly with or without previous checkpoints.""" + thread_id = str(uuid4()) + + factories.UserChatSessionFactory.create(thread_id=thread_id) + + if has_previous_checkpoint: + initial_chat_data = { + "chat_history": [ + { + "type": "HumanMessage", + "content": "Initial message", + "id": str(uuid4()), + }, + {"type": "AIMessage", "content": "Response", "id": str(uuid4())}, + ], + "user_id": "test_user", + "course_id": "test_course", + } + + previous_checkpoints = create_tutor_checkpoints(thread_id, initial_chat_data) + assert len(previous_checkpoints) == 2 + assert previous_checkpoints[0].metadata["step"] == 0 + else: + initial_chat_data = None + + chat_data = { + "chat_history": [ + {"type": "HumanMessage", "content": "New message", "id": str(uuid4())}, + {"type": "AIMessage", "content": "Response", "id": str(uuid4())}, + ], + "user_id": "test_user", + "course_id": "test_course", + } + + result = create_tutor_checkpoints( + thread_id, chat_data, previous_chat_json=initial_chat_data + ) + + assert len(result) == 2 + assert result[0].metadata["step"] == (2 if has_previous_checkpoint else 0) + assert result[1].metadata["step"] == (3 if has_previous_checkpoint else 1) From f727b52d941b0ae854f9652376eeecab7ac4ab15 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Fri, 19 Sep 2025 11:39:12 -0400 Subject: [PATCH 3/3] More tweaks --- ai_chatbots/api.py | 7 +++---- ai_chatbots/api_test.py | 21 ++++----------------- 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py index 9f3b743a..b7812ce5 100644 --- a/ai_chatbots/api.py +++ b/ai_chatbots/api.py @@ -652,9 +652,7 @@ def _identify_new_messages( } # Find messages with IDs that don't exist in previous chat - return [ - msg for msg in filtered_messages if msg.get("id") not in existing_message_ids - ] + return [msg for msg in filtered_messages if msg["id"] not in existing_message_ids] def _create_langchain_message(message: dict) -> dict: @@ -664,7 +662,7 @@ def _create_langchain_message(message: dict) -> dict: "lc": 1, "type": "constructor", "kwargs": { - "id": message.get("id", str(uuid4())), + "id": message["id"], "type": message["type"].lower().replace("message", ""), "content": message["content"], }, @@ -767,6 +765,7 @@ def create_tutor_checkpoints( checkpoints_created = [] # Create checkpoints only for the NEW messages + for message in new_messages: checkpoint_id = str(uuid4()) diff --git a/ai_chatbots/api_test.py b/ai_chatbots/api_test.py index f83070eb..fb744fe8 100644 --- a/ai_chatbots/api_test.py +++ b/ai_chatbots/api_test.py @@ -1376,7 +1376,7 @@ def test_create_tutor_checkpoints_with_tool_messages(): chat_json = """ { "chat_history": [ - {"type": "ToolMessage", "content": "Tool result"}, + {"type": "ToolMessage", "content": "Tool result", "id": "msg0"}, {"type": "HumanMessage", "content": "Testing 123", "id": "msg1"} ] } @@ -1625,28 +1625,15 @@ def test_create_tutor_checkpoints_includes_metadata(): ) -@pytest.mark.parametrize("has_existing_id", [True, False]) -def test_create_langchain_message_id_handling(has_existing_id): +def test_create_langchain_message_id_handling(): """Test that _create_langchain_message preserves existing IDs or generates new ones.""" from ai_chatbots.api import _create_langchain_message - existing_id = str(uuid4()) - message = { - "type": "HumanMessage", - "content": "Test message", - } - - if has_existing_id: - message["id"] = existing_id + message = {"type": "HumanMessage", "content": "Test message", "id": str(uuid4())} result = _create_langchain_message(message) - if has_existing_id: - assert result["kwargs"]["id"] == existing_id - else: - assert len(result["kwargs"]["id"]) == 36 # UUID length - assert result["kwargs"]["id"] != existing_id - + assert result["kwargs"]["id"] == message["id"] assert result["kwargs"]["type"] == "human" assert result["kwargs"]["content"] == "Test message"