
feat(assistant): stop a generation #28810

Merged Feb 20, 2025 · 24 commits

Changes from 20 commits

Commits
a58c528
fix: multiple interrupts
skoob13 Feb 17, 2025
ace1f97
feat: api
skoob13 Feb 17, 2025
0254ac7
feat: frontend
skoob13 Feb 17, 2025
c20f951
feat: use call of a node
skoob13 Feb 17, 2025
f23604f
feat: prevent double texting
skoob13 Feb 17, 2025
6358028
test: resuming
skoob13 Feb 17, 2025
7e56dcf
fix: better state updates
skoob13 Feb 17, 2025
32fcb4c
fix: remove reasoning messages when stopped
skoob13 Feb 17, 2025
b6030ab
Merge branch 'master' of github.com:PostHog/posthog into feat/max-sto…
skoob13 Feb 17, 2025
e0246f7
Update query snapshots
github-actions[bot] Feb 17, 2025
35b6ec0
fix: icons
skoob13 Feb 17, 2025
d855f79
Merge branch 'feat/max-stop-btn' of github.com:PostHog/posthog into f…
skoob13 Feb 17, 2025
137e6c2
Merge branch 'master' of github.com:PostHog/posthog into feat/max-sto…
skoob13 Feb 17, 2025
59acb3a
Update query snapshots
github-actions[bot] Feb 17, 2025
962d394
feat: lock for conversations
skoob13 Feb 17, 2025
4ce38a9
Merge branch 'feat/max-stop-btn' of github.com:PostHog/posthog into f…
skoob13 Feb 17, 2025
94beeb9
Merge branch 'master' of github.com:PostHog/posthog into feat/max-sto…
skoob13 Feb 18, 2025
e74386a
fix: useless test
skoob13 Feb 18, 2025
b30e92a
fix: interrupts are not stored in checkpoints
skoob13 Feb 18, 2025
077d3a0
Merge branch 'master' into feat/max-stop-btn
Twixes Feb 18, 2025
cfe032f
Update query snapshots
github-actions[bot] Feb 18, 2025
9d24f62
fix: feedback
skoob13 Feb 19, 2025
16138b9
Merge branch 'master' of github.com:PostHog/posthog into feat/max-sto…
skoob13 Feb 19, 2025
96c41b0
fix: categorize icons
skoob13 Feb 19, 2025
27 changes: 23 additions & 4 deletions ee/api/conversation.py
@@ -1,15 +1,19 @@
from typing import cast

import pydantic
from django.http import StreamingHttpResponse
from pydantic import ValidationError
from rest_framework import serializers
from rest_framework import serializers, status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.renderers import BaseRenderer
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet

from ee.hogai.assistant import Assistant
from ee.models.assistant import Conversation
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.exceptions import Conflict
from posthog.models.user import User
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle
from posthog.schema import HumanMessage
@@ -24,7 +28,7 @@ def validate(self, data):
try:
message = HumanMessage(content=data["content"])
data["message"] = message
except ValidationError:
except pydantic.ValidationError:
raise serializers.ValidationError("Invalid message content.")
return data

@@ -40,7 +44,6 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
class ConversationViewSet(TeamAndOrgViewSetMixin, GenericViewSet):
scope_object = "INTERNAL"
serializer_class = MessageSerializer
renderer_classes = [ServerSentEventRenderer]
queryset = Conversation.objects.all()
lookup_url_kwarg = "conversation"

@@ -51,6 +54,11 @@ def safely_get_queryset(self, queryset):
def get_throttles(self):
return [AIBurstRateThrottle(), AISustainedRateThrottle()]

def get_renderers(self):
if self.action == "create":
return [ServerSentEventRenderer()]
return super().get_renderers()

def create(self, request: Request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@@ -60,6 +68,8 @@ def create(self, request: Request, *args, **kwargs):
conversation = self.get_object()
else:
conversation = self.get_queryset().create(user=request.user, team=self.team)
if conversation.is_locked:
raise Conflict("Conversation is locked.")
assistant = Assistant(
self.team,
conversation,
@@ -69,3 +79,12 @@ def create(self, request: Request, *args, **kwargs):
trace_id=serializer.validated_data["trace_id"],
)
return StreamingHttpResponse(assistant.stream(), content_type=ServerSentEventRenderer.media_type)

@action(detail=True, methods=["PATCH"])
def cancel(self, request: Request, *args, **kwargs):
conversation = self.get_object()
if conversation.status == Conversation.Status.CANCELLING:
raise ValidationError("Generation has already cancelled.")
conversation.status = Conversation.Status.CANCELLING
conversation.save()
return Response(status=status.HTTP_204_NO_CONTENT)
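
For context, here is a minimal client-side sketch of how the new cancel endpoint could be called. The base URL, team ID, conversation ID, and credentials below are placeholders, not part of this diff:

import requests

# Placeholder values -- substitute your own instance, environment (team) ID,
# conversation ID, and personal API key.
BASE_URL = "https://posthog.example.com"
TEAM_ID = 1
CONVERSATION_ID = "00000000-0000-0000-0000-000000000000"
HEADERS = {"Authorization": "Bearer <personal_api_key>"}

# PATCH .../conversations/<id>/cancel/ marks the conversation as CANCELLING.
# The endpoint returns 204 with no body on success and 400 if the generation
# is already being cancelled.
response = requests.patch(
    f"{BASE_URL}/api/environments/{TEAM_ID}/conversations/{CONVERSATION_ID}/cancel/",
    headers=HEADERS,
)
assert response.status_code == 204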
59 changes: 59 additions & 0 deletions ee/api/test/test_conversation.py
@@ -169,3 +169,62 @@ def raise_error():
with self.assertRaises(Exception) as context:
b"".join(response.streaming_content)
self.assertTrue("Streaming error" in str(context.exception))

def test_cancel_conversation(self):
conversation = Conversation.objects.create(user=self.user, team=self.team)
response = self.client.patch(
f"/api/environments/{self.team.id}/conversations/{conversation.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
conversation.refresh_from_db()
self.assertEqual(conversation.status, Conversation.Status.CANCELLING)

def test_cancel_already_cancelling_conversation(self):
conversation = Conversation.objects.create(
user=self.user, team=self.team, status=Conversation.Status.CANCELLING
)
response = self.client.patch(
f"/api/environments/{self.team.id}/conversations/{conversation.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.json()["detail"], "Generation has already cancelled.")

def test_cancel_other_users_conversation(self):
conversation = Conversation.objects.create(user=self.other_user, team=self.team)
response = self.client.patch(
f"/api/environments/{self.team.id}/conversations/{conversation.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_cancel_other_teams_conversation(self):
conversation = Conversation.objects.create(user=self.user, team=self.other_team)
response = self.client.patch(
f"/api/environments/{self.team.id}/conversations/{conversation.id}/cancel/",
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_cant_use_locked_conversation(self):
conversation = Conversation.objects.create(
user=self.user, team=self.team, status=Conversation.Status.IN_PROGRESS
)
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{
"conversation": str(conversation.id),
"content": "test query",
"trace_id": str(uuid.uuid4()),
},
)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)

conversation.status = Conversation.Status.CANCELLING
conversation.save()
response = self.client.post(
f"/api/environments/{self.team.id}/conversations/",
{
"conversation": str(conversation.id),
"content": "test query",
"trace_id": str(uuid.uuid4()),
},
)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
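
The Conversation model changes referenced above (the Status choices and the is_locked property in ee/models/assistant.py) are not part of this excerpt. A rough sketch of the shape implied by the tests, where the stored values, field definition, and default are assumptions:

from django.db import models


class Conversation(models.Model):
    class Status(models.TextChoices):
        # Member names match the code above; the stored values are assumptions.
        NOT_STARTED = "not_started"
        IN_PROGRESS = "in_progress"
        CANCELLING = "cancelling"

    status = models.CharField(max_length=20, choices=Status.choices, default=Status.NOT_STARTED)

    @property
    def is_locked(self) -> bool:
        # A conversation is locked while a generation is running or being cancelled,
        # which is why the create endpoint returns 409 Conflict in both cases above.
        return self.status in (self.Status.IN_PROGRESS, self.Status.CANCELLING)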
106 changes: 70 additions & 36 deletions ee/hogai/assistant.py
@@ -1,5 +1,6 @@
import json
from collections.abc import Generator, Iterator
from contextlib import contextmanager
from typing import Any, Optional, cast
from uuid import UUID, uuid4

@@ -19,6 +20,7 @@
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode
from ee.hogai.trends.nodes import TrendsGeneratorNode
from ee.hogai.utils.asgi import SyncIterableToAsync
from ee.hogai.utils.exceptions import GenerationCanceled
from ee.hogai.utils.state import (
GraphMessageUpdateTuple,
GraphTaskStartedUpdateTuple,
@@ -124,39 +126,55 @@ def _stream(self) -> Generator[str, None, None]:
state, config=config, stream_mode=["messages", "values", "updates", "debug"]
)

# Assign the conversation id to the client.
if self._is_new_conversation:
yield self._serialize_conversation()

# Send the last message with the initialized id.
yield self._serialize_message(self._latest_message)

try:
last_viz_message = None
for update in generator:
if message := self._process_update(update):
if isinstance(message, VisualizationMessage):
last_viz_message = message
yield self._serialize_message(message)

# Check if the assistant has requested help.
state = self._graph.get_state(config)
if state.next:
interrupt_value = state.tasks[0].interrupts[0].value
feedback_message = (
AssistantMessage(content=interrupt_value, id=str(uuid4()))
if isinstance(interrupt_value, str)
else interrupt_value
)
self._graph.update_state(config, PartialAssistantState(messages=[feedback_message]))
yield self._serialize_message(feedback_message)
else:
self._report_conversation_state(last_viz_message)
except Exception as e:
logger.exception("Error in assistant stream", error=e)
# This is an unhandled error, so we just stop further generation at this point
yield self._serialize_message(FailureMessage())
raise # Re-raise, so that the error is printed or goes into Sentry
with self._lock_conversation():
# Assign the conversation id to the client.
if self._is_new_conversation:
yield self._serialize_conversation()

# Send the last message with the initialized id.
yield self._serialize_message(self._latest_message)

try:
last_viz_message = None
for update in generator:
if message := self._process_update(update):
if isinstance(message, VisualizationMessage):
last_viz_message = message
yield self._serialize_message(message)

# Check if the assistant has requested help.
state = self._graph.get_state(config)
if state.next:
interrupt_messages = []
for task in state.tasks:
for interrupt in task.interrupts:
interrupt_message = (
AssistantMessage(content=interrupt.value, id=str(uuid4()))
if isinstance(interrupt.value, str)
else interrupt.value
)
interrupt_messages.append(interrupt_message)
yield self._serialize_message(interrupt_message)

self._graph.update_state(
config,
PartialAssistantState(
messages=interrupt_messages,
# LangGraph for some reason doesn't store the interrupt exceptions in checkpoints.
graph_status="interrupted",
),
)
else:
self._report_conversation_state(last_viz_message)
except Exception as e:
# Reset the state, so that the next generation starts from the beginning.
self._graph.update_state(config, PartialAssistantState.get_reset_state())

if not isinstance(e, GenerationCanceled):
logger.exception("Error in assistant stream", error=e)
# This is an unhandled error, so we just stop further generation at this point
yield self._serialize_message(FailureMessage())
raise # Re-raise, so that the error is printed or goes into Sentry

@property
def _initial_state(self) -> AssistantState:
@@ -174,12 +192,18 @@ def _get_config(self) -> RunnableConfig:
def _init_or_update_state(self):
config = self._get_config()
snapshot = self._graph.get_state(config)

# If the graph hasn't previously reset the state, this is an interrupt. We resume from the point of interruption.
if snapshot.next:
saved_state = validate_state_update(snapshot.values)
self._state = saved_state
self._graph.update_state(config, PartialAssistantState(messages=[self._latest_message], resumed=True))
if saved_state.graph_status == "interrupted":
self._state = saved_state
self._graph.update_state(
config, PartialAssistantState(messages=[self._latest_message], graph_status="resumed")
)
# Return None to indicate that we want to continue the execution from the interrupted point.
return None

return None
initial_state = self._initial_state
self._state = initial_state
return initial_state
@@ -317,3 +341,13 @@ def _report_conversation_state(self, message: Optional[VisualizationMessage]):
"chat with ai",
{"prompt": human_message.content, "response": message.model_dump_json(exclude_none=True)},
)

@contextmanager
def _lock_conversation(self):
try:
self._conversation.status = Conversation.Status.IN_PROGRESS
self._conversation.save()
yield
finally:
self._conversation.status = Conversation.Status.NOT_STARTED
self._conversation.save()
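
The code path that actually raises GenerationCanceled once the status flips to CANCELLING lives in the graph/node layer and is not shown in this excerpt. A minimal sketch of such a check, with the helper name and call site as assumptions:

from ee.hogai.utils.exceptions import GenerationCanceled
from ee.models.assistant import Conversation


def raise_if_cancelled(conversation: Conversation) -> None:
    # Re-read the status so a cancel request issued mid-generation is picked up.
    conversation.refresh_from_db(fields=["status"])
    if conversation.status == Conversation.Status.CANCELLING:
        # Per the diff above, Assistant._stream resets the graph state for any
        # exception, and skips the log line and FailureMessage for this exception type.
        raise GenerationCanceled()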