Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions python/packages/chatkit/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,17 @@ class MyChatKitServer(ChatKitServer[dict[str, Any]]):
if input_user_message is None:
return

# Convert ChatKit message to Agent Framework format
agent_messages = await simple_to_agent_input(input_user_message)
# Load full thread history to maintain conversation context
thread_items_page = await self.store.load_thread_items(
thread_id=thread.id,
after=None,
limit=1000,
order="asc",
context=context,
)

# Convert all ChatKit messages to Agent Framework format
agent_messages = await simple_to_agent_input(thread_items_page.data)

# Run the agent and stream responses
response_stream = agent.run_stream(agent_messages)
Expand Down
157 changes: 124 additions & 33 deletions python/samples/demos/chatkit-integration/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,43 @@
from typing import Annotated, Any

import uvicorn

# Agent Framework imports
from agent_framework import AgentRunResponseUpdate, ChatAgent, ChatMessage, FunctionResultContent, Role
from agent_framework.azure import AzureOpenAIChatClient

# Agent Framework ChatKit integration
from agent_framework_chatkit import ThreadItemConverter, stream_agent_response

# Local imports
from attachment_store import FileBasedAttachmentStore
from azure.identity import AzureCliCredential

# ChatKit imports
from chatkit.actions import Action
from chatkit.server import ChatKitServer
from chatkit.store import StoreItemType, default_generate_id
from chatkit.types import (
ThreadItem,
ThreadItemDoneEvent,
ThreadMetadata,
ThreadStreamEvent,
UserMessageItem,
WidgetItem,
)
from chatkit.widgets import WidgetRoot
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, Response, StreamingResponse
from pydantic import Field
from store import SQLiteStore
from weather_widget import (
WeatherData,
city_selector_copy_text,
render_city_selector_widget,
render_weather_widget,
weather_widget_copy_text,
)

# ============================================================================
# Configuration Constants
Expand Down Expand Up @@ -56,37 +88,6 @@
)
logger = logging.getLogger(__name__)

# Agent Framework imports
from agent_framework import AgentRunResponseUpdate, ChatAgent, ChatMessage, FunctionResultContent, Role
from agent_framework.azure import AzureOpenAIChatClient

# Agent Framework ChatKit integration
from agent_framework_chatkit import ThreadItemConverter, stream_agent_response

# Local imports
from attachment_store import FileBasedAttachmentStore

# ChatKit imports
from chatkit.actions import Action
from chatkit.server import ChatKitServer
from chatkit.store import StoreItemType, default_generate_id
from chatkit.types import (
ThreadItemDoneEvent,
ThreadMetadata,
ThreadStreamEvent,
UserMessageItem,
WidgetItem,
)
from chatkit.widgets import WidgetRoot
from store import SQLiteStore
from weather_widget import (
WeatherData,
city_selector_copy_text,
render_city_selector_widget,
render_weather_widget,
weather_widget_copy_text,
)


class WeatherResponse(str):
"""A string response that also carries WeatherData for widget creation."""
Expand Down Expand Up @@ -238,6 +239,81 @@ async def _fetch_attachment_data(self, attachment_id: str) -> bytes:
"""
return await attachment_store.read_attachment_bytes(attachment_id)

async def _update_thread_title(
self, thread: ThreadMetadata, thread_items: list[ThreadItem], context: dict[str, Any]
) -> None:
"""Update thread title using LLM to generate a concise summary.

Args:
thread: The thread metadata to update.
thread_items: All items in the thread.
context: The context dictionary.
"""
logger.info(f"Attempting to update thread title for thread: {thread.id}")

if not thread_items:
logger.debug("No thread items available for title generation")
return

# Collect user messages to understand the conversation topic
user_messages: list[str] = []
for item in thread_items:
if isinstance(item, UserMessageItem) and item.content:
for content_part in item.content:
if hasattr(content_part, "text") and isinstance(content_part.text, str):
user_messages.append(content_part.text)
break

if not user_messages:
logger.debug("No user messages found for title generation")
return

logger.debug(f"Found {len(user_messages)} user message(s) for title generation")

try:
# Use the agent's chat client to generate a concise title
# Combine first few messages to capture the conversation topic
conversation_context = "\n".join(user_messages[:3])

title_prompt = [
ChatMessage(
role=Role.USER,
text=(
f"Generate a very short, concise title (max 40 characters) for a conversation "
f"that starts with:\n\n{conversation_context}\n\n"
"Respond with ONLY the title, nothing else."
),
)
]

# Use the chat client directly for a quick, lightweight call
response = await self.weather_agent.chat_client.get_response(
messages=title_prompt,
temperature=0.3,
max_tokens=20,
)

if response.messages and response.messages[-1].text:
title = response.messages[-1].text.strip().strip('"').strip("'")
# Ensure it's not too long
if len(title) > 50:
title = title[:47] + "..."

thread.title = title
await self.store.save_thread(thread, context)
logger.info(f"Updated thread {thread.id} title to: {title}")

except Exception as e:
logger.warning(f"Failed to generate thread title, using fallback: {e}")
# Fallback to simple truncation
first_message: str = user_messages[0]
title: str = first_message[:50].strip()
if len(first_message) > 50:
title += "..."
thread.title = title
await self.store.save_thread(thread, context)
logger.info(f"Updated thread {thread.id} title to (fallback): {title}")

async def respond(
self,
thread: ThreadMetadata,
Expand All @@ -263,8 +339,19 @@ async def respond(
weather_data: WeatherData | None = None
show_city_selector = False

# Convert ChatKit user message to Agent Framework ChatMessage using ThreadItemConverter
agent_messages = await self.converter.to_agent_input(input_user_message)
# Load full thread history from the store
thread_items_page = await self.store.load_thread_items(
thread_id=thread.id,
after=None,
limit=1000,
order="asc",
context=context,
)
thread_items = thread_items_page.data

# Convert ALL thread items to Agent Framework ChatMessages using ThreadItemConverter
# This ensures the agent has the full conversation context
agent_messages = await self.converter.to_agent_input(thread_items)

if not agent_messages:
logger.warning("No messages after conversion")
Expand Down Expand Up @@ -330,6 +417,10 @@ async def intercept_stream() -> AsyncIterator[AgentRunResponseUpdate]:
yield widget_event
logger.debug("City selector widget streamed successfully")

# Update thread title based on first user message if not already set
if not thread.title or thread.title == "New thread":
await self._update_thread_title(thread, thread_items, context)

logger.info(f"Completed processing message for thread: {thread.id}")

except Exception as e:
Expand Down
14 changes: 6 additions & 8 deletions python/samples/demos/chatkit-integration/attachment_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

from pathlib import Path
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from chatkit.store import AttachmentStore
from chatkit.types import Attachment, AttachmentCreateParams, FileAttachment, ImageAttachment
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
self.uploads_dir = Path(uploads_dir)
self.base_url = base_url.rstrip("/")
self.data_store = data_store

# Create uploads directory if it doesn't exist
self.uploads_dir.mkdir(parents=True, exist_ok=True)

Expand All @@ -65,25 +65,23 @@ async def delete_attachment(self, attachment_id: str, context: dict[str, Any]) -
if file_path.exists():
file_path.unlink()

async def create_attachment(
self, input: AttachmentCreateParams, context: dict[str, Any]
) -> Attachment:
async def create_attachment(self, input: AttachmentCreateParams, context: dict[str, Any]) -> Attachment:
"""Create an attachment with upload URL for two-phase upload.

This creates the attachment metadata and returns upload URLs that
the client will use to POST the actual file bytes.
"""
# Generate unique ID for this attachment
attachment_id = self.generate_attachment_id(input.mime_type, context)

# Generate upload URL that points to our FastAPI upload endpoint
upload_url = f"{self.base_url}/upload/{attachment_id}"

# Create appropriate attachment type based on MIME type
if input.mime_type.startswith("image/"):
# For images, also provide a preview URL
preview_url = f"{self.base_url}/preview/{attachment_id}"

attachment = ImageAttachment(
id=attachment_id,
type="image",
Expand Down Expand Up @@ -117,5 +115,5 @@ async def read_attachment_bytes(self, attachment_id: str) -> bytes:
file_path = self.get_file_path(attachment_id)
if not file_path.exists():
raise FileNotFoundError(f"Attachment {attachment_id} not found on disk")

return file_path.read_bytes()
33 changes: 10 additions & 23 deletions python/samples/demos/chatkit-integration/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import uuid
from typing import Any

from chatkit.store import Store, NotFoundError
from chatkit.store import NotFoundError, Store
from chatkit.types import (
Attachment,
Page,
Expand All @@ -22,16 +22,19 @@

class ThreadData(BaseModel):
"""Model for serializing thread data to SQLite."""

thread: ThreadMetadata


class ItemData(BaseModel):
"""Model for serializing thread item data to SQLite."""

item: ThreadItem


class AttachmentData(BaseModel):
"""Model for serializing attachment data to SQLite."""

attachment: Attachment


Expand Down Expand Up @@ -185,19 +188,13 @@ async def load_thread_items(
params.append(limit + 1)

items_cursor = conn.execute(query, params).fetchall()
items = [
ItemData.model_validate_json(row[0]).item for row in items_cursor
]
items = [ItemData.model_validate_json(row[0]).item for row in items_cursor]

has_more = len(items) > limit
if has_more:
items = items[:limit]

return Page[ThreadItem](
data=items,
has_more=has_more,
after=items[-1].id if items else None
)
return Page[ThreadItem](data=items, has_more=has_more, after=items[-1].id if items else None)

async def save_attachment(self, attachment: Attachment, context: dict[str, Any]) -> None:
user_id = context.get("user_id", "demo_user")
Expand Down Expand Up @@ -270,23 +267,15 @@ async def load_threads(
params.append(limit + 1)

threads_cursor = conn.execute(query, params).fetchall()
threads = [
ThreadData.model_validate_json(row[0]).thread for row in threads_cursor
]
threads = [ThreadData.model_validate_json(row[0]).thread for row in threads_cursor]

has_more = len(threads) > limit
if has_more:
threads = threads[:limit]

return Page[ThreadMetadata](
data=threads,
has_more=has_more,
after=threads[-1].id if threads else None
)
return Page[ThreadMetadata](data=threads, has_more=has_more, after=threads[-1].id if threads else None)

async def add_thread_item(
self, thread_id: str, item: ThreadItem, context: dict[str, Any]
) -> None:
async def add_thread_item(self, thread_id: str, item: ThreadItem, context: dict[str, Any]) -> None:
user_id = context.get("user_id", "demo_user")

with self._create_connection() as conn:
Expand Down Expand Up @@ -348,9 +337,7 @@ async def delete_thread(self, thread_id: str, context: dict[str, Any]) -> None:
)
conn.commit()

async def delete_thread_item(
self, thread_id: str, item_id: str, context: dict[str, Any]
) -> None:
async def delete_thread_item(self, thread_id: str, item_id: str, context: dict[str, Any]) -> None:
user_id = context.get("user_id", "demo_user")

with self._create_connection() as conn:
Expand Down
1 change: 0 additions & 1 deletion python/samples/demos/chatkit-integration/weather_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
CITY_VALUE_TO_NAME = {city["value"]: city["label"] for city in POPULAR_CITIES}



def _sun_svg() -> str:
"""Generate SVG for sunny weather icon."""
color = WEATHER_ICON_COLOR
Expand Down
Loading
Loading