feat(assistant): stop a generation (#28810)
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Michael Matloka <[email protected]>
3 people authored Feb 20, 2025
1 parent b35281e commit a250542
Showing 22 changed files with 540 additions and 166 deletions.
27 changes: 23 additions & 4 deletions ee/api/conversation.py
@@ -1,15 +1,19 @@
from typing import cast

import pydantic
from django.http import StreamingHttpResponse
from pydantic import ValidationError
from rest_framework import serializers
from rest_framework import serializers, status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.renderers import BaseRenderer
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet

from ee.hogai.assistant import Assistant
from ee.models.assistant import Conversation
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.exceptions import Conflict
from posthog.models.user import User
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle
from posthog.schema import HumanMessage
@@ -24,7 +28,7 @@ def validate(self, data):
try:
message = HumanMessage(content=data["content"])
data["message"] = message
except ValidationError:
except pydantic.ValidationError:
raise serializers.ValidationError("Invalid message content.")
return data

@@ -40,7 +44,6 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
class ConversationViewSet(TeamAndOrgViewSetMixin, GenericViewSet):
scope_object = "INTERNAL"
serializer_class = MessageSerializer
renderer_classes = [ServerSentEventRenderer]
queryset = Conversation.objects.all()
lookup_url_kwarg = "conversation"

@@ -51,6 +54,11 @@ def safely_get_queryset(self, queryset):
def get_throttles(self):
return [AIBurstRateThrottle(), AISustainedRateThrottle()]

def get_renderers(self):
if self.action == "create":
return [ServerSentEventRenderer()]
return super().get_renderers()

def create(self, request: Request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@@ -60,6 +68,8 @@ def create(self, request: Request, *args, **kwargs):
conversation = self.get_object()
else:
conversation = self.get_queryset().create(user=request.user, team=self.team)
if conversation.is_locked:
raise Conflict("Conversation is locked.")
assistant = Assistant(
self.team,
conversation,
@@ -69,3 +79,12 @@ def create(self, request: Request, *args, **kwargs):
trace_id=serializer.validated_data["trace_id"],
)
return StreamingHttpResponse(assistant.stream(), content_type=ServerSentEventRenderer.media_type)

@action(detail=True, methods=["PATCH"])
def cancel(self, request: Request, *args, **kwargs):
conversation = self.get_object()
if conversation.status == Conversation.Status.CANCELING:
raise ValidationError("Generation has already been cancelled.")
conversation.status = Conversation.Status.CANCELING
conversation.save()
return Response(status=status.HTTP_204_NO_CONTENT)
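
A quick illustration of how a client might drive the new cancel action. This is a hedged sketch only: the URL shape mirrors the tests in ee/api/test/test_conversation.py, while the host, team id, conversation id, API key, and Bearer auth header are placeholder assumptions rather than values from this commit.

import requests

POSTHOG_HOST = "https://us.posthog.com"  # assumption: your PostHog instance
TEAM_ID = 1  # assumption: the environment/project id
CONVERSATION_ID = "11111111-2222-3333-4444-555555555555"  # assumption: an existing conversation
API_KEY = "phx_example"  # assumption: a personal API key with access to the project

# Ask the assistant to stop the in-flight generation for this conversation.
response = requests.patch(
    f"{POSTHOG_HOST}/api/environments/{TEAM_ID}/conversations/{CONVERSATION_ID}/cancel/",
    headers={"Authorization": f"Bearer {API_KEY}"},
)

# 204: cancellation requested; 400: it was already being cancelled;
# 404: the conversation belongs to another user or another team.
print(response.status_code)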
57 changes: 57 additions & 0 deletions ee/api/test/test_conversation.py
@@ -169,3 +169,60 @@ def raise_error():
with self.assertRaises(Exception) as context:
b"".join(response.streaming_content)
self.assertTrue("Streaming error" in str(context.exception))

def test_cancel_conversation(self):
conversation = Conversation.objects.create(user=self.user, team=self.team)
response = self.client.patch(
f"/api/environments/{self.team.id}/conversations/{conversation.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
conversation.refresh_from_db()
self.assertEqual(conversation.status, Conversation.Status.CANCELING)

def test_cancel_already_canceling_conversation(self):
conversation = Conversation.objects.create(user=self.user, team=self.team, status=Conversation.Status.CANCELING)
response = self.client.patch(
f"/api/environments/{self.team.id}/conversations/{conversation.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json()["detail"], "Generation has already been cancelled.")

def test_cancel_other_users_conversation(self):
conversation = Conversation.objects.create(user=self.other_user, team=self.team)
response = self.client.patch(
f"/api/environments/{self.team.id}/conversations/{conversation.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_cancel_other_teams_conversation(self):
conversation = Conversation.objects.create(user=self.user, team=self.other_team)
response = self.client.patch(
f"/api/environments/{self.team.id}/conversations/{conversation.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_cant_use_locked_conversation(self):
conversation = Conversation.objects.create(
user=self.user, team=self.team, status=Conversation.Status.IN_PROGRESS
)
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{
"conversation": str(conversation.id),
"content": "test query",
"trace_id": str(uuid.uuid4()),
},
)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)

conversation.status = Conversation.Status.CANCELING
conversation.save()
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{
"conversation": str(conversation.id),
"content": "test query",
"trace_id": str(uuid.uuid4()),
},
)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
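
The view and tests above rely on Conversation model fields defined in files not expanded in this excerpt (the model lives in ee/models/assistant.py). A minimal sketch of the assumed shape, for orientation only:

from django.db import models


class Conversation(models.Model):
    # Sketch of the fields the view and tests above assume; the real model and
    # its migration are among the changed files not shown here.
    class Status(models.TextChoices):
        IDLE = "idle", "Idle"
        IN_PROGRESS = "in_progress", "In progress"
        CANCELING = "canceling", "Canceling"

    status = models.CharField(max_length=20, choices=Status.choices, default=Status.IDLE)

    @property
    def is_locked(self) -> bool:
        # A conversation is locked while a generation is running or being cancelled,
        # which is why a second POST to /conversations/ returns 409 Conflict.
        return self.status in (Conversation.Status.IN_PROGRESS, Conversation.Status.CANCELING)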
106 changes: 70 additions & 36 deletions ee/hogai/assistant.py
@@ -1,5 +1,6 @@
import json
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from typing import Any, Optional, cast
from uuid import UUID, uuid4

@@ -19,6 +20,7 @@
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode
from ee.hogai.trends.nodes import TrendsGeneratorNode
from ee.hogai.utils.asgi import SyncIterableToAsync
from ee.hogai.utils.exceptions import GenerationCanceled
from ee.hogai.utils.state import (
GraphMessageUpdateTuple,
GraphTaskStartedUpdateTuple,
@@ -124,39 +126,55 @@ def _stream(self) -> Generator[str, None, None]:
state, config=config, stream_mode=["messages", "values", "updates", "debug"]
)

# Assign the conversation id to the client.
if self._is_new_conversation:
yield self._serialize_conversation()

# Send the last message with the initialized id.
yield self._serialize_message(self._latest_message)

try:
last_viz_message = None
for update in generator:
if message := self._process_update(update):
if isinstance(message, VisualizationMessage):
last_viz_message = message
yield self._serialize_message(message)

# Check if the assistant has requested help.
state = self._graph.get_state(config)
if state.next:
interrupt_value = state.tasks[0].interrupts[0].value
feedback_message = (
AssistantMessage(content=interrupt_value, id=str(uuid4()))
if isinstance(interrupt_value, str)
else interrupt_value
)
self._graph.update_state(config, PartialAssistantState(messages=[feedback_message]))
yield self._serialize_message(feedback_message)
else:
self._report_conversation_state(last_viz_message)
except Exception as e:
logger.exception("Error in assistant stream", error=e)
# This is an unhandled error, so we just stop further generation at this point
yield self._serialize_message(FailureMessage())
raise # Re-raise, so that the error is printed or goes into Sentry
with self._lock_conversation():
# Assign the conversation id to the client.
if self._is_new_conversation:
yield self._serialize_conversation()

# Send the last message with the initialized id.
yield self._serialize_message(self._latest_message)

try:
last_viz_message = None
for update in generator:
if message := self._process_update(update):
if isinstance(message, VisualizationMessage):
last_viz_message = message
yield self._serialize_message(message)

# Check if the assistant has requested help.
state = self._graph.get_state(config)
if state.next:
interrupt_messages = []
for task in state.tasks:
for interrupt in task.interrupts:
interrupt_message = (
AssistantMessage(content=interrupt.value, id=str(uuid4()))
if isinstance(interrupt.value, str)
else interrupt.value
)
interrupt_messages.append(interrupt_message)
yield self._serialize_message(interrupt_message)

self._graph.update_state(
config,
PartialAssistantState(
messages=interrupt_messages,
# LangGraph for some reason doesn't store the interrupt exceptions in checkpoints.
graph_status="interrupted",
),
)
else:
self._report_conversation_state(last_viz_message)
except Exception as e:
# Reset the state, so that the next generation starts from the beginning.
self._graph.update_state(config, PartialAssistantState.get_reset_state())

if not isinstance(e, GenerationCanceled):
logger.exception("Error in assistant stream", error=e)
# This is an unhandled error, so we just stop further generation at this point
yield self._serialize_message(FailureMessage())
raise # Re-raise, so that the error is printed or goes into Sentry

@property
def _initial_state(self) -> AssistantState:
Expand All @@ -174,12 +192,18 @@ def _get_config(self) -> RunnableConfig:
def _init_or_update_state(self):
config = self._get_config()
snapshot = self._graph.get_state(config)

# If the graph hasn't previously reset the state, this is an interrupt, so we resume from the point of interruption.
if snapshot.next:
saved_state = validate_state_update(snapshot.values)
self._state = saved_state
self._graph.update_state(config, PartialAssistantState(messages=[self._latest_message], resumed=True))
if saved_state.graph_status == "interrupted":
self._state = saved_state
self._graph.update_state(
config, PartialAssistantState(messages=[self._latest_message], graph_status="resumed")
)
# Return None to indicate that we want to continue the execution from the interrupted point.
return None

return None
initial_state = self._initial_state
self._state = initial_state
return initial_state
@@ -317,3 +341,13 @@ def _report_conversation_state(self, message: Optional[VisualizationMessage]):
"chat with ai",
{"prompt": human_message.content, "response": message.model_dump_json(exclude_none=True)},
)

@contextmanager
def _lock_conversation(self):
try:
self._conversation.status = Conversation.Status.IN_PROGRESS
self._conversation.save()
yield
finally:
self._conversation.status = Conversation.Status.IDLE
self._conversation.save()
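
The _lock_conversation context manager marks the conversation IN_PROGRESS for the duration of the stream and resets it to IDLE afterwards. The counterpart that actually interrupts generation is not in the hunks shown here: GenerationCanceled is imported from ee.hogai.utils.exceptions, and the except branch above resets the graph state and ends the stream without a FailureMessage when it is raised. A hedged sketch of how a periodic status check might raise it; the real check lives elsewhere in this commit:

from ee.hogai.utils.exceptions import GenerationCanceled
from ee.models.assistant import Conversation


def raise_if_cancellation_requested(conversation: Conversation) -> None:
    # Illustration only: re-read the status and abort if the cancel endpoint
    # flipped it to CANCELING; Assistant._stream catches GenerationCanceled,
    # resets the graph state, and stops streaming without a FailureMessage.
    conversation.refresh_from_db(fields=["status"])
    if conversation.status == Conversation.Status.CANCELING:
        raise GenerationCanceled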