Skip to content

Commit

Permalink
update endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 7, 2025
1 parent 2e7385b commit 9af8053
Showing 1 changed file with 144 additions and 35 deletions.
179 changes: 144 additions & 35 deletions agixt/graphqlendpoints/Conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
import strawberry
from fastapi import HTTPException
from endpoints.Conversation import (
get_conversations_list as rest_get_conversations_list,
get_conversations as rest_get_conversations,
get_conversation_history as rest_get_conversation_history,
get_conversation_data as rest_get_conversation_data,
new_conversation_history as rest_new_conversation,
delete_conversation_history as rest_delete_conversation,
delete_history_message as rest_delete_message,
Expand All @@ -26,7 +24,7 @@
async def get_user_and_auth_from_context(info):
request = info.context["request"]
try:
user = await verify_api_key(request)
user = verify_api_key(request)
auth = request.headers.get("authorization")
return user, auth
except HTTPException as e:
Expand Down Expand Up @@ -110,16 +108,6 @@ class ConversationMetadata:
attachment_count: int


@strawberry.type
class ConversationNotification:
conversation_id: str
conversation_name: str
message_id: str
message: str
role: str
timestamp: datetime


@strawberry.type
class ConversationIdentifier:
id: str
Expand All @@ -132,13 +120,18 @@ class ConversationList:


@strawberry.type
class ConversationDetail:
conversations: List["ConversationMetadata"]
class ConversationHistory:
messages: List[ConversationMessage]


@strawberry.type
class ConversationHistory:
messages: List[ConversationMessage]
class ConversationNotification:
conversation_id: str
conversation_name: str
message_id: str
message: str
role: str
timestamp: datetime


@strawberry.type
Expand Down Expand Up @@ -194,19 +187,69 @@ class ConversationForkInput:
message_id: str


# Updated Query type
# Pagination Input
@strawberry.input
class PaginationInput:
page: int = 1
limit: int = 100


# Pagination Info
@strawberry.type
class PageInfo:
has_next_page: bool
has_previous_page: bool
total_pages: int
total_items: int
current_page: int
items_per_page: int


# Types
@strawberry.type
class ConversationMessage:
id: str
role: str
message: str
timestamp: datetime
updated_at: datetime
updated_by: Optional[str]
feedback_received: bool


@strawberry.type
class ConversationDetail:
metadata: ConversationMetadata
messages: List[ConversationMessage]


@strawberry.type
class ConversationConnection:
page_info: PageInfo
edges: List[ConversationMetadata]


@strawberry.type
class NotificationConnection:
page_info: PageInfo
edges: List["ConversationNotification"]


# Query type with pagination
@strawberry.type
class Query:
@strawberry.field
async def conversations(self, info) -> ConversationDetail:
"""Get detailed conversations list"""
async def conversations(
self, info, pagination: Optional[PaginationInput] = None
) -> ConversationConnection:
"""Get paginated list of conversations with details"""
user = await verify_api_key(info.context["request"])
result = await rest_get_conversations(user=user)

# Convert dictionary to strongly typed objects
conversation_metadata = [
# Convert dictionary to list and sort by updated_at
conversations = [
ConversationMetadata(
id=details["id"],
id=id,
name=details["name"],
agent_id=details["agent_id"],
created_at=details["created_at"],
Expand All @@ -215,20 +258,60 @@ async def conversations(self, info) -> ConversationDetail:
summary=details["summary"],
attachment_count=details["attachment_count"],
)
for details in result.conversations.values()
for id, details in result.conversations.items()
]
conversations.sort(key=lambda x: x.updated_at, reverse=True)

# Handle pagination
page = pagination.page if pagination else 1
limit = pagination.limit if pagination else 100
total_items = len(conversations)
total_pages = -(-total_items // limit) # Ceiling division
start_idx = (page - 1) * limit
end_idx = start_idx + limit

page_info = PageInfo(
has_next_page=end_idx < total_items,
has_previous_page=page > 1,
total_pages=total_pages,
total_items=total_items,
current_page=page,
items_per_page=limit,
)

return ConversationDetail(conversations=conversation_metadata)
return ConversationConnection(
page_info=page_info, edges=conversations[start_idx:end_idx]
)

@strawberry.field
async def conversation(self, info, conversation_id: str) -> ConversationHistory:
"""Get conversation history by ID"""
async def conversation(
self, info, conversation_id: str, pagination: Optional[PaginationInput] = None
) -> ConversationDetail:
"""Get conversation details and paginated messages"""
user, auth = await get_user_and_auth_from_context(info)
result = await rest_get_conversation_history(

# Get conversation metadata
result = await rest_get_conversations(user=user)
if conversation_id not in result.conversations:
raise Exception(f"Conversation {conversation_id} not found")

details = result.conversations[conversation_id]
metadata = ConversationMetadata(
id=conversation_id,
name=details["name"],
agent_id=details["agent_id"],
created_at=details["created_at"],
updated_at=details["updated_at"],
has_notifications=details["has_notifications"],
summary=details["summary"],
attachment_count=details["attachment_count"],
)

# Get messages with pagination
history_result = await rest_get_conversation_history(
conversation_id=conversation_id, user=user, authorization=auth
)

# Convert dictionary to strongly typed objects
messages = [
ConversationMessage(
id=msg["id"],
Expand All @@ -239,18 +322,25 @@ async def conversation(self, info, conversation_id: str) -> ConversationHistory:
updated_by=msg["updated_by"],
feedback_received=msg["feedback_received"],
)
for msg in result.conversation_history
for msg in history_result.conversation_history
]

return ConversationHistory(messages=messages)
# Apply pagination if provided
if pagination:
start_idx = (pagination.page - 1) * pagination.limit
end_idx = start_idx + pagination.limit
messages = messages[start_idx:end_idx]

return ConversationDetail(metadata=metadata, messages=messages)

@strawberry.field
async def notifications(self, info) -> NotificationList:
"""Get user notifications"""
async def notifications(
self, info, pagination: Optional[PaginationInput] = None
) -> NotificationConnection:
"""Get paginated notifications"""
user = await verify_api_key(info.context["request"])
result = await rest_get_notifications(user=user)

# Convert dictionary to strongly typed objects
notifications = [
ConversationNotification(
conversation_id=notif["conversation_id"],
Expand All @@ -263,7 +353,26 @@ async def notifications(self, info) -> NotificationList:
for notif in result.notifications
]

return NotificationList(notifications=notifications)
# Handle pagination
page = pagination.page if pagination else 1
limit = pagination.limit if pagination else 100
total_items = len(notifications)
total_pages = -(-total_items // limit) # Ceiling division
start_idx = (page - 1) * limit
end_idx = start_idx + limit

page_info = PageInfo(
has_next_page=end_idx < total_items,
has_previous_page=page > 1,
total_pages=total_pages,
total_items=total_items,
current_page=page,
items_per_page=limit,
)

return NotificationConnection(
page_info=page_info, edges=notifications[start_idx:end_idx]
)


# Response types for mutations
Expand Down

0 comments on commit 9af8053

Please sign in to comment.