diff --git a/templates/components/agents/python/blog/app/workflows/models.py b/templates/components/agents/python/blog/app/workflows/models.py index 92f6aae4..da9dfad4 100644 --- a/templates/components/agents/python/blog/app/workflows/models.py +++ b/templates/components/agents/python/blog/app/workflows/models.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Literal -from app.api.routers.models import SourceNodes -from llama_index.core.schema import Node, NodeWithScore +from llama_index.core.schema import NodeWithScore from llama_index.core.workflow import Event +from app.api.routers.models import SourceNodes + # Workflow events class PlanResearchEvent(Event): @@ -13,7 +14,7 @@ class PlanResearchEvent(Event): class ResearchEvent(Event): question_id: str question: str - context_nodes: List[NodeWithScore | Node] + context_nodes: List[NodeWithScore] class CollectAnswersEvent(Event): diff --git a/templates/components/agents/python/blog/app/workflows/writer.py b/templates/components/agents/python/blog/app/workflows/writer.py index 516d13d3..0e16fc6c 100644 --- a/templates/components/agents/python/blog/app/workflows/writer.py +++ b/templates/components/agents/python/blog/app/workflows/writer.py @@ -3,16 +3,6 @@ import uuid from typing import Any, Dict, List, Optional -from app.engine.index import IndexConfig, get_index -from app.workflows.agents import plan_research, research, write_report -from app.workflows.models import ( - CollectAnswersEvent, - DataEvent, - PlanResearchEvent, - ResearchEvent, - SourceNodesEvent, - WriteReportEvent, -) from llama_index.core.indices.base import BaseIndex from llama_index.core.memory import ChatMemoryBuffer from llama_index.core.memory.simple_composable_memory import SimpleComposableMemory @@ -26,7 +16,18 @@ step, ) -logger = logging.getLogger(__name__) +from app.engine.index import IndexConfig, get_index +from app.workflows.agents import plan_research, research, write_report +from app.workflows.models import ( + CollectAnswersEvent, + DataEvent, + PlanResearchEvent, + ResearchEvent, + SourceNodesEvent, + WriteReportEvent, +) + +logger = logging.getLogger("uvicorn") logger.setLevel(logging.INFO) @@ -45,7 +46,7 @@ def create_workflow( return WriterWorkflow( index=index, chat_history=chat_history, - **kwargs, + timeout=120.0, ) @@ -67,14 +68,13 @@ class WriterWorkflow(Workflow): context_nodes: List[Node] index: BaseIndex user_request: str - stream: bool = False + stream: bool = True def __init__( self, index: BaseIndex, chat_history: Optional[List[ChatMessage]] = None, - stream: bool = False, - timeout: Optional[float] = 120.0, + stream: bool = True, **kwargs, ): super().__init__(**kwargs) @@ -82,7 +82,6 @@ def __init__( self.context_nodes = [] self.stream = stream self.chat_history = chat_history - self.timeout = timeout self.memory = SimpleComposableMemory.from_defaults( primary_memory=ChatMemoryBuffer.from_defaults( chat_history=chat_history, @@ -142,7 +141,7 @@ async def analyze( """ Analyze the retrieved information """ - print("Analyzing the retrieved information") + logger.info("Analyzing the retrieved information") ctx.write_event_to_stream( DataEvent( type="analyze", @@ -262,7 +261,7 @@ async def collect_answers( content=f"{result.question}\n{result.answer}", ) ) - ctx.set("n_questions", 0) + await ctx.set("n_questions", 0) self.memory.put( message=ChatMessage( role=MessageRole.ASSISTANT, @@ -276,6 +275,7 @@ async def report(self, ctx: Context, ev: WriteReportEvent) -> StopEvent: """ Report the answers """ + logger.info("Writing the report") res = await write_report( memory=self.memory, user_request=self.user_request, diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx index fa421692..db43ed9e 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx @@ -25,7 +25,7 @@ export function ChatMessageContent() { }, // add the writer card { - position: ContentPosition.AFTER_EVENTS, + position: ContentPosition.CHAT_EVENTS, component: , }, { diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/custom/writer-card.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/custom/writer-card.tsx index 17023f3e..b2ad7c5b 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/custom/writer-card.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/custom/writer-card.tsx @@ -8,7 +8,7 @@ import { NotebookPen, Search, } from "lucide-react"; -import { useEffect, useState } from "react"; +import { useMemo } from "react"; import { Collapsible, CollapsibleContent, @@ -45,42 +45,49 @@ type WriterState = { }; }; -// Update the state based on the event -const updateState = (state: WriterState, event: WriterEvent): WriterState => { +const stateIcon: Record = { + pending: , + inprogress: , + done: , + error: , +}; + +// Transform the state based on the event without mutations +const transformState = ( + state: WriterState, + event: WriterEvent, +): WriterState => { switch (event.type) { case "answer": { const { id, question, answer } = event.data; if (!id || !question) return state; - const questions = state.analyze.questions; - const existingQuestion = questions.find((q) => q.id === id); - - const updatedQuestions = existingQuestion - ? questions.map((q) => - q.id === id - ? { - ...existingQuestion, - state: event.state, - answer: answer || existingQuestion.answer, - } - : q, - ) - : [ - ...questions, + const updatedQuestions = state.analyze.questions.map((q) => { + if (q.id !== id) return q; + return { + ...q, + state: event.state, + answer: answer ?? q.answer, + }; + }); + + const newQuestion = !state.analyze.questions.some((q) => q.id === id) + ? [ { id, question, - answer: answer || null, + answer: answer ?? null, state: event.state, isOpen: false, }, - ]; + ] + : []; return { ...state, analyze: { ...state.analyze, - questions: updatedQuestions, + questions: [...updatedQuestions, ...newQuestion], }, }; } @@ -100,36 +107,30 @@ const updateState = (state: WriterState, event: WriterEvent): WriterState => { } }; -export function WriterCard({ message }: { message: Message }) { - const [state, setState] = useState({ +// Convert writer events to state +const writeEventsToState = (events: WriterEvent[] | undefined): WriterState => { + if (!events?.length) { + return { + retrieve: { state: null }, + analyze: { state: null, questions: [] }, + }; + } + + const initialState: WriterState = { retrieve: { state: null }, analyze: { state: null, questions: [] }, - }); + }; + return events.reduce( + (acc: WriterState, event: WriterEvent) => transformState(acc, event), + initialState, + ); +}; + +export function WriterCard({ message }: { message: Message }) { const writerEvents = message.annotations as WriterEvent[] | undefined; - useEffect(() => { - if (writerEvents?.length) { - writerEvents.forEach((event) => { - setState((currentState) => updateState(currentState, event)); - }); - } - }, [writerEvents]); - - const getStateIcon = (state: EventState | null) => { - switch (state) { - case "pending": - return ; - case "inprogress": - return ; - case "done": - return ; - case "error": - return ; - default: - return null; - } - }; + const state = useMemo(() => writeEventsToState(writerEvents), [writerEvents]); if (!writerEvents?.length) { return null; @@ -167,7 +168,7 @@ export function WriterCard({ message }: { message: Message }) { - {getStateIcon(question.state)} + {stateIcon[question.state]} {question.question}