From 87faf783a7ea459cebaf4f47e9462f1c5c28d59a Mon Sep 17 00:00:00 2001 From: Nick Larew Date: Tue, 1 Oct 2024 16:18:09 -0500 Subject: [PATCH 1/3] Minimize renders with useMemo/useCallback --- .../mongodb-chatbot-ui/src/useChatbot.tsx | 119 +++-- .../src/useConversation.tsx | 499 ++++++++++-------- 2 files changed, 339 insertions(+), 279 deletions(-) diff --git a/packages/mongodb-chatbot-ui/src/useChatbot.tsx b/packages/mongodb-chatbot-ui/src/useChatbot.tsx index fd97f8672..b02d28787 100644 --- a/packages/mongodb-chatbot-ui/src/useChatbot.tsx +++ b/packages/mongodb-chatbot-ui/src/useChatbot.tsx @@ -1,4 +1,4 @@ -import { useRef, useState } from "react"; +import { useCallback, useEffect, useRef, useState } from "react"; import { useConversation, type UseConversationParams } from "./useConversation"; export type OpenCloseHandlers = { @@ -47,7 +47,7 @@ export function useChatbot({ const [awaitingReply, setAwaitingReply] = useState(false); const inputBarRef = useRef(null); - async function openChat() { + const openChat = useCallback(async () => { if (open) { return; } @@ -56,16 +56,16 @@ export function useChatbot({ if (!conversation.conversationId) { await conversation.createConversation(); } - } + }, [open, onOpen, conversation]); - function closeChat() { + const closeChat = useCallback(() => { if (!open) { return false; } onClose?.(); setOpen(false); return true; - } + }, [open, onClose]); const [inputData, setInputData] = useState({ text: "", @@ -73,58 +73,67 @@ export function useChatbot({ }); const inputText = inputData.text; const inputTextError = inputData.error; - function setInputText(text: string) { - const isValid = maxInputCharacters - ? text.length <= maxInputCharacters - : true; - setInputData({ - text, - error: isValid - ? "" - : `Input must be less than ${maxInputCharacters} characters`, - }); - } + const setInputText = useCallback( + (text: string) => { + const isValid = maxInputCharacters + ? text.length <= maxInputCharacters + : true; + setInputData({ + text, + error: isValid + ? "" + : `Input must be less than ${maxInputCharacters} characters`, + }); + }, + [maxInputCharacters] + ); - function canSubmit(text: string) { - // Don't let users submit a message if the conversation hasn't fully loaded - if (!conversation.conversationId) { - console.error(`Cannot add message without a conversationId`); - return false; - } - // Don't let users submit a message if something is wrong with their input text - if (inputData.error) { - console.error(`Cannot add message with invalid input text`); - return false; - } - // Don't let users submit a message that is empty or only whitespace - if (text.replace(/\s/g, "").length === 0) { - console.error(`Cannot add message with no text`); - return false; - } - // Don't let users submit a message if we're already waiting for a reply - if (awaitingReply) { - console.error(`Cannot add message while awaiting a reply`); - return false; - } - return true; - } + const canSubmit = useCallback( + (text: string) => { + // Don't let users submit a message if the conversation hasn't fully loaded + if (!conversation.conversationId) { + console.error(`Cannot add message without a conversationId`); + return false; + } + // Don't let users submit a message if something is wrong with their input text + if (inputData.error) { + console.error(`Cannot add message with invalid input text`); + return false; + } + // Don't let users submit a message that is empty or only whitespace + if (text.replace(/\s/g, "").length === 0) { + console.error(`Cannot add message with no text`); + return false; + } + // Don't let users submit a message if we're already waiting for a reply + if (awaitingReply) { + console.error(`Cannot add message while awaiting a reply`); + return false; + } + return true; + }, + [conversation.conversationId, inputData.error, awaitingReply] + ); - async function handleSubmit(text: string) { - if (!canSubmit(text)) return; - try { - setInputText(""); - setAwaitingReply(true); - openChat(); - await conversation.addMessage({ - role: "user", - content: text, - }); - } catch (e) { - console.error(e); - } finally { - setAwaitingReply(false); - } - } + const handleSubmit = useCallback( + async (text: string) => { + if (!canSubmit(text)) return; + try { + setInputText(""); + setAwaitingReply(true); + openChat(); + await conversation.addMessage({ + role: "user", + content: text, + }); + } catch (e) { + console.error(e); + } finally { + setAwaitingReply(false); + } + }, + [canSubmit, setInputText, setAwaitingReply, openChat, conversation] + ); return { awaitingReply, diff --git a/packages/mongodb-chatbot-ui/src/useConversation.tsx b/packages/mongodb-chatbot-ui/src/useConversation.tsx index 33d441ba9..df5ae7f27 100644 --- a/packages/mongodb-chatbot-ui/src/useConversation.tsx +++ b/packages/mongodb-chatbot-ui/src/useConversation.tsx @@ -1,4 +1,4 @@ -import { useMemo, useReducer } from "react"; +import { useCallback, useMemo, useReducer } from "react"; import { type References } from "mongodb-rag-core"; import { MessageData, @@ -393,12 +393,16 @@ export function useConversation(params: UseConversationParams) { conversationReducer, defaultConversationState ); - const dispatch = (...args: Parameters) => { - if (import.meta.env.MODE !== "production") { - console.log(`dispatch`, ...args); - } - _dispatch(...args); - }; + + const dispatch = useCallback( + (...args: Parameters) => { + if (import.meta.env.MODE !== "production") { + console.log(`dispatch`, ...args); + } + _dispatch(...args); + }, + [_dispatch] + ); // Use a custom sort function if provided. If undefined and we're on a // well-known MongoDB domain, then prioritize links to the current domain. @@ -407,15 +411,21 @@ export function useConversation(params: UseConversationParams) { params.sortMessageReferences ?? makePrioritizeCurrentMongoDbReferenceDomain(); - const setConversation = (conversation: Required) => { - dispatch({ type: "setConversation", conversation }); - }; + const setConversation = useCallback( + (conversation: Required) => { + dispatch({ type: "setConversation", conversation }); + }, + [dispatch] + ); - const endConversationWithError = (errorMessage: string) => { - dispatch({ type: "setConversationError", errorMessage }); - }; + const endConversationWithError = useCallback( + (errorMessage: string) => { + dispatch({ type: "setConversationError", errorMessage }); + }, + [dispatch] + ); - const createConversation = async () => { + const createConversation = useCallback(async () => { try { const conversation = await conversationService.createConversation(); setConversation({ @@ -432,197 +442,219 @@ export function useConversation(params: UseConversationParams) { console.error(errorMessage); endConversationWithError(errorMessage); } - }; - - const addMessage: ConversationActor["addMessage"] = async ({ - role, - content, - }) => { - if (!state.conversationId) { - console.error(`Cannot addMessage without a conversationId`); - return; - } + }, [conversationService, setConversation, endConversationWithError]); + + const setMessageMetadata = useCallback< + ConversationActor["setMessageMetadata"] + >( + async ({ messageId, metadata }) => { + if (!state.conversationId) { + console.error(`Cannot setMessageMetadata without a conversationId`); + return; + } + dispatch({ type: "setMessageMetadata", messageId, metadata }); + }, + [state.conversationId, dispatch] + ); - const shouldStream = - canUseServerSentEvents() && (params.shouldStream ?? true); - - // Stream control - const abortController = new AbortController(); - let finishedStreaming = false; - let finishedBuffering = !shouldStream; - let streamedMessageId: string | null = null; - let references: References | null = null; - let bufferedTokens: string[] = []; - let streamedTokens: string[] = []; - const streamingIntervalMs = 50; - const streamingInterval = setInterval(() => { - const [nextToken, ...remainingTokens] = bufferedTokens; - - bufferedTokens = remainingTokens; - - if (nextToken) { - dispatch({ type: "appendStreamingResponse", data: nextToken }); + const addMessage = useCallback( + async ({ role, content }) => { + if (!state.conversationId) { + console.error(`Cannot addMessage without a conversationId`); + return; } - const allBufferedTokensDispatched = - finishedStreaming && bufferedTokens.length === 0; + const shouldStream = + canUseServerSentEvents() && (params.shouldStream ?? true); + + // Stream control + const abortController = new AbortController(); + let finishedStreaming = false; + let finishedBuffering = !shouldStream; + let streamedMessageId: string | null = null; + let references: References | null = null; + let bufferedTokens: string[] = []; + let streamedTokens: string[] = []; + const streamingIntervalMs = 50; + const streamingInterval = setInterval(() => { + const [nextToken, ...remainingTokens] = bufferedTokens; + + bufferedTokens = remainingTokens; + + if (nextToken) { + dispatch({ type: "appendStreamingResponse", data: nextToken }); + } + + const allBufferedTokensDispatched = + finishedStreaming && bufferedTokens.length === 0; + + if (references && allBufferedTokensDispatched) { + // Count the number of markdown code fences in the response. If + // it's odd, the streaming message stopped in the middle of a + // code block and we need to escape from it. + const numCodeFences = countRegexMatches( + /```/g, + streamedTokens.join("") + ); + if (numCodeFences % 2 !== 0) { + dispatch({ + type: "appendStreamingResponse", + data: "\n```\n\n", + }); + } - if (references && allBufferedTokensDispatched) { - // Count the number of markdown code fences in the response. If - // it's odd, the streaming message stopped in the middle of a - // code block and we need to escape from it. - const numCodeFences = countRegexMatches( - /```/g, - streamedTokens.join("") - ); - if (numCodeFences % 2 !== 0) { dispatch({ - type: "appendStreamingResponse", - data: "\n```\n\n", + type: "appendStreamingReferences", + data: references.sort(sortMessageReferences), }); + references = null; } - - dispatch({ - type: "appendStreamingReferences", - data: references.sort(sortMessageReferences), - }); - references = null; - } - if (!finishedBuffering && allBufferedTokensDispatched) { - if (!streamedMessageId) { - streamedMessageId = createMessageId(); + if (!finishedBuffering && allBufferedTokensDispatched) { + if (!streamedMessageId) { + streamedMessageId = createMessageId(); + } + dispatch({ + type: "finishStreamingResponse", + messageId: streamedMessageId, + }); + finishedBuffering = true; } - dispatch({ - type: "finishStreamingResponse", - messageId: streamedMessageId, - }); - finishedBuffering = true; - } - }, streamingIntervalMs); + }, streamingIntervalMs); - try { - dispatch({ type: "addMessage", role, content }); - if (shouldStream) { - dispatch({ type: "createStreamingResponse", data: "" }); - await conversationService.addMessageStreaming({ - conversationId: state.conversationId, - message: content, - maxRetries: 0, - onResponseDelta: async (data: string) => { - bufferedTokens = [...bufferedTokens, data]; - streamedTokens = [...streamedTokens, data]; - }, - onReferences: async (data: References) => { - if (references === null) { - references = []; - } - references.push(...data); - }, - onMetadata: async (metadata) => { - setMessageMetadata({ - messageId: STREAMING_MESSAGE_ID, - metadata, - }); - }, - onResponseFinished: async (messageId: string) => { - streamedMessageId = messageId; - finishedStreaming = true; - }, - signal: abortController.signal, - }); - } else { - // We start a streaming response to indicate the loading state - // but we'll never append to it since the response message comes - // in all at once. - dispatch({ type: "createStreamingResponse", data: "" }); - const response = await conversationService.addMessage({ - conversationId: state.conversationId, - message: content, - }); + try { + dispatch({ type: "addMessage", role, content }); + if (shouldStream) { + dispatch({ type: "createStreamingResponse", data: "" }); + await conversationService.addMessageStreaming({ + conversationId: state.conversationId, + message: content, + maxRetries: 0, + onResponseDelta: async (data: string) => { + bufferedTokens = [...bufferedTokens, data]; + streamedTokens = [...streamedTokens, data]; + }, + onReferences: async (data: References) => { + if (references === null) { + references = []; + } + references.push(...data); + }, + onMetadata: async (metadata) => { + setMessageMetadata({ + messageId: STREAMING_MESSAGE_ID, + metadata, + }); + }, + onResponseFinished: async (messageId: string) => { + streamedMessageId = messageId; + finishedStreaming = true; + }, + signal: abortController.signal, + }); + } else { + // We start a streaming response to indicate the loading state + // but we'll never append to it since the response message comes + // in all at once. + dispatch({ type: "createStreamingResponse", data: "" }); + const response = await conversationService.addMessage({ + conversationId: state.conversationId, + message: content, + }); + dispatch({ type: "cancelStreamingResponse" }); + dispatch({ + type: "addMessage", + role: "assistant", + content: response.content, + references: response.references?.sort(sortMessageReferences), + metadata: response.metadata, + }); + } + } catch (error) { + abortController.abort(); + console.error(`Failed to add message: ${error}`); + const errorMessage = + error instanceof Error ? error.message : String(error); dispatch({ type: "cancelStreamingResponse" }); - dispatch({ - type: "addMessage", - role: "assistant", - content: response.content, - references: response.references?.sort(sortMessageReferences), - metadata: response.metadata, - }); - } - } catch (error) { - abortController.abort(); - console.error(`Failed to add message: ${error}`); - const errorMessage = - error instanceof Error ? error.message : String(error); - dispatch({ type: "cancelStreamingResponse" }); - clearInterval(streamingInterval); + clearInterval(streamingInterval); - endConversationWithError(errorMessage); - throw error; - } + endConversationWithError(errorMessage); + throw error; + } - let cleanupInterval: ReturnType | undefined; - return new Promise((resolve) => { - cleanupInterval = setInterval(() => { - if (finishedBuffering) { - clearInterval(streamingInterval); - clearInterval(cleanupInterval); - resolve(); - } - }, streamingIntervalMs); - }); - }; + let cleanupInterval: ReturnType | undefined; + return new Promise((resolve) => { + cleanupInterval = setInterval(() => { + if (finishedBuffering) { + clearInterval(streamingInterval); + clearInterval(cleanupInterval); + resolve(); + } + }, streamingIntervalMs); + }); + }, + [ + state.conversationId, + params.shouldStream, + dispatch, + conversationService, + sortMessageReferences, + setMessageMetadata, + endConversationWithError, + ] + ); - const setMessageContent = async (messageId: string, content: string) => { - if (!state.conversationId) { - console.error(`Cannot setMessageContent without a conversationId`); - return; - } - dispatch({ type: "setMessageContent", messageId, content }); - }; - - const setMessageMetadata: ConversationActor["setMessageMetadata"] = async ({ - messageId, - metadata, - }) => { - if (!state.conversationId) { - console.error(`Cannot setMessageMetadata without a conversationId`); - return; - } - dispatch({ type: "setMessageMetadata", messageId, metadata }); - }; + const setMessageContent = useCallback( + async (messageId: string, content: string) => { + if (!state.conversationId) { + console.error(`Cannot setMessageContent without a conversationId`); + return; + } + dispatch({ type: "setMessageContent", messageId, content }); + }, + [state.conversationId, dispatch] + ); - const deleteMessage = async (messageId: string) => { - if (!state.conversationId) { - console.error(`Cannot deleteMessage without a conversationId`); - return; - } - dispatch({ type: "deleteMessage", messageId }); - }; + const deleteMessage = useCallback( + async (messageId: string) => { + if (!state.conversationId) { + console.error(`Cannot deleteMessage without a conversationId`); + return; + } + dispatch({ type: "deleteMessage", messageId }); + }, + [state.conversationId, dispatch] + ); - const rateMessage = async (messageId: string, rating: boolean) => { - if (!state.conversationId) { - console.error(`Cannot rateMessage without a conversationId`); - return; - } - await conversationService.rateMessage({ - conversationId: state.conversationId, - messageId, - rating, - }); - dispatch({ type: "rateMessage", messageId, rating }); - }; + const rateMessage = useCallback( + async (messageId: string, rating: boolean) => { + if (!state.conversationId) { + console.error(`Cannot rateMessage without a conversationId`); + return; + } + await conversationService.rateMessage({ + conversationId: state.conversationId, + messageId, + rating, + }); + dispatch({ type: "rateMessage", messageId, rating }); + }, + [state.conversationId, conversationService, dispatch] + ); - const commentMessage = async (messageId: string, comment: string) => { - if (!state.conversationId) { - console.error(`Cannot commentMessage without a conversationId`); - return; - } - await conversationService.commentMessage({ - conversationId: state.conversationId, - messageId, - comment, - }); - }; + const commentMessage = useCallback( + async (messageId: string, comment: string) => { + if (!state.conversationId) { + console.error(`Cannot commentMessage without a conversationId`); + return; + } + await conversationService.commentMessage({ + conversationId: state.conversationId, + messageId, + comment, + }); + }, + [state.conversationId, conversationService] + ); const streamingMessage = state.messages.find( (m) => m.id === STREAMING_MESSAGE_ID @@ -631,39 +663,58 @@ export function useConversation(params: UseConversationParams) { /** * Switch to a different, existing conversation. */ - const switchConversation = async (conversationId: string) => { - try { - const conversation = await conversationService.getConversation( - conversationId - ); - setConversation({ - ...conversation, - error: "", - }); - } catch (error) { - const errorMessage = - typeof error === "string" - ? error - : error instanceof Error - ? error.message - : "Failed to switch conversation."; - console.error(errorMessage); - // Rethrow the error so that we can handle it in the UI - throw error; - } - }; - - return { - ...state, - createConversation, - endConversationWithError, - streamingMessage, - addMessage, - setMessageContent, - setMessageMetadata, - deleteMessage, - rateMessage, - commentMessage, - switchConversation, - } satisfies Conversation; + const switchConversation = useCallback( + async (conversationId: string) => { + try { + const conversation = await conversationService.getConversation( + conversationId + ); + setConversation({ + ...conversation, + error: "", + }); + } catch (error) { + const errorMessage = + typeof error === "string" + ? error + : error instanceof Error + ? error.message + : "Failed to switch conversation."; + console.error(errorMessage); + // Rethrow the error so that we can handle it in the UI + throw error; + } + }, + [conversationService, setConversation] + ); + + return useMemo( + () => + ({ + ...state, + createConversation, + endConversationWithError, + streamingMessage, + addMessage, + setMessageContent, + setMessageMetadata, + deleteMessage, + rateMessage, + commentMessage, + switchConversation, + } satisfies Conversation), + [ + state, + createConversation, + endConversationWithError, + streamingMessage, + addMessage, + setMessageContent, + setMessageMetadata, + deleteMessage, + rateMessage, + commentMessage, + switchConversation, + ] + ); } From 21159cab86e08c7c7bcafa48ccc5d673daa825c3 Mon Sep 17 00:00:00 2001 From: Nick Larew Date: Tue, 1 Oct 2024 16:24:58 -0500 Subject: [PATCH 2/3] remove unused import --- packages/mongodb-chatbot-ui/src/useChatbot.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/mongodb-chatbot-ui/src/useChatbot.tsx b/packages/mongodb-chatbot-ui/src/useChatbot.tsx index b02d28787..e00fb5e4b 100644 --- a/packages/mongodb-chatbot-ui/src/useChatbot.tsx +++ b/packages/mongodb-chatbot-ui/src/useChatbot.tsx @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useRef, useState } from "react"; +import { useCallback, useRef, useState } from "react"; import { useConversation, type UseConversationParams } from "./useConversation"; export type OpenCloseHandlers = { From bd058c42fadbc1be652a924ce4d61dae82fd76b0 Mon Sep 17 00:00:00 2001 From: Nick Larew Date: Tue, 1 Oct 2024 17:26:28 -0500 Subject: [PATCH 3/3] replace useReducer with useState --- .../src/ConversationProvider.tsx | 15 + .../src/useConversation.tsx | 771 +++++++++--------- 2 files changed, 398 insertions(+), 388 deletions(-) diff --git a/packages/mongodb-chatbot-ui/src/ConversationProvider.tsx b/packages/mongodb-chatbot-ui/src/ConversationProvider.tsx index c8c94c6a3..5d1696d7a 100644 --- a/packages/mongodb-chatbot-ui/src/ConversationProvider.tsx +++ b/packages/mongodb-chatbot-ui/src/ConversationProvider.tsx @@ -30,6 +30,21 @@ export const ConversationContext = createContext({ switchConversation: async () => { return; }, + createStreamingResponse: async () => { + return; + }, + appendStreamingResponse: async () => { + return; + }, + appendStreamingReferences: async () => { + return; + }, + finishStreamingResponse: async () => { + return; + }, + cancelStreamingResponse: async () => { + return; + }, }); export default function ConversationProvider({ diff --git a/packages/mongodb-chatbot-ui/src/useConversation.tsx b/packages/mongodb-chatbot-ui/src/useConversation.tsx index df5ae7f27..ca7a0a0a9 100644 --- a/packages/mongodb-chatbot-ui/src/useConversation.tsx +++ b/packages/mongodb-chatbot-ui/src/useConversation.tsx @@ -1,4 +1,4 @@ -import { useCallback, useMemo, useReducer } from "react"; +import { useCallback, useMemo, useState } from "react"; import { type References } from "mongodb-rag-core"; import { MessageData, @@ -71,7 +71,16 @@ type ConversationActorArgs = Omit< type ConversationActor = { createConversation: () => Promise; endConversationWithError: (errorMessage: string) => void; - addMessage: (args: ConversationActorArgs<"addMessage">) => Promise; + addMessage: ( + args: + | { role: "user"; content: string } + | { + role: "assistant"; + content: string; + references?: References; + metadata?: AssistantMessageMetadata; + } + ) => Promise; setMessageContent: (messageId: string, content: string) => Promise; setMessageMetadata: ( args: ConversationActorArgs<"setMessageMetadata"> @@ -80,6 +89,13 @@ type ConversationActor = { rateMessage: (messageId: string, rating: boolean) => Promise; commentMessage: (messageId: string, comment: string) => Promise; switchConversation: (conversationId: string) => Promise; + + // Streaming State + createStreamingResponse: (data: string) => void; + appendStreamingResponse: (data: string) => void; + appendStreamingReferences: (references: References) => void; + finishStreamingResponse: (messageId: string) => void; + cancelStreamingResponse: () => void; }; export type Conversation = ConversationState & ConversationActor; @@ -90,288 +106,437 @@ export const defaultConversationState = { isStreamingMessage: false, } satisfies ConversationState; -function conversationReducer( - state: ConversationState, - action: ConversationAction -): ConversationState { - function getMessageIndex(messageId: MessageData["id"]) { - const messageIndex = state.messages.findIndex( - (message) => message.id === messageId - ); - return messageIndex; - } - function getStreamingMessage() { - const streamingMessageIndex = getMessageIndex(STREAMING_MESSAGE_ID); - const streamingMessage = - streamingMessageIndex === -1 - ? null - : state.messages[streamingMessageIndex]; - return { - streamingMessageIndex, - streamingMessage, - }; - } - switch (action.type) { - case "setConversation": { - return { - ...action.conversation, - error: "", - }; - } - case "setConversationError": { - return { - ...state, - error: action.errorMessage, - }; - } - case "addMessage": { - if (!state.conversationId) { - console.error(`Cannot addMessage without a conversationId`); - } +function getMessageIndex( + messages: MessageData[], + messageId: MessageData["id"] +) { + const messageIndex = messages.findIndex( + (message) => message.id === messageId + ); + return messageIndex; +} +function getStreamingMessage(messages: MessageData[]) { + const streamingMessageIndex = getMessageIndex(messages, STREAMING_MESSAGE_ID); + const streamingMessage = + streamingMessageIndex === -1 ? null : messages[streamingMessageIndex]; + return { + streamingMessageIndex, + streamingMessage, + }; +} - const newMessage = createMessage(action); +function useConversationState(conversationService: ConversationService): { + state: ConversationState; + actions: ConversationActor; +} { + const [state, setState] = useState({ + conversationId: undefined, + messages: [], + error: undefined, + isStreamingMessage: false, + streamingMessage: undefined, + }); - return { - ...state, - messages: [...state.messages, newMessage], - }; + const setConversation = useCallback( + (conversation: Required) => { + setState((_) => conversation); + }, + [] + ); + + const endConversationWithError = useCallback((errorMessage: string) => { + setState((prevState) => ({ + ...prevState, + error: errorMessage, + })); + }, []); + + const createConversation = useCallback(async () => { + try { + const conversation = await conversationService.createConversation(); + setConversation({ + ...conversation, + error: "", + }); + } catch (error) { + const errorMessage = + typeof error === "string" + ? error + : error instanceof Error + ? error.message + : "Failed to create conversation."; + console.error(errorMessage); + endConversationWithError(errorMessage); } - case "setMessageContent": { - if (!state.conversationId) { - console.error(`Cannot setMessageContent without a conversationId`); - return state; - } - const messageIndex = getMessageIndex(action.messageId); - if (messageIndex === -1) { - console.error( - `Cannot setMessageContent because message with id ${action.messageId} does not exist` - ); - return state; - } - const modifiedMessage = { - ...state.messages[messageIndex], - content: action.content, - }; - return { - ...state, + }, [conversationService, setConversation, endConversationWithError]); + + const addMessage = useCallback( + async ({ role, content }) => { + setState((prevState) => ({ + ...prevState, messages: [ - ...state.messages.slice(0, messageIndex), - modifiedMessage, - ...state.messages.slice(messageIndex + 1), + ...prevState.messages, + createMessage({ + role, + content, + }), ], - }; - } - case "setMessageMetadata": { - if (!state.conversationId) { + })); + }, + [] + ); + + const setMessageContent = useCallback( + async (messageId: string, content: string) => { + setState((prevState) => { + if (!prevState.conversationId) { + console.error(`Cannot setMessageContent without a conversationId`); + return prevState; + } + const messageIndex = getMessageIndex(prevState.messages, messageId); + if (messageIndex === -1) { + console.error( + `Cannot setMessageContent because message with id ${messageId} does not exist` + ); + return prevState; + } + return { + ...prevState, + messages: [ + ...prevState.messages.slice(0, messageIndex), + { + ...prevState.messages[messageIndex], + content, + }, + ...prevState.messages.slice(messageIndex + 1), + ], + }; + }); + }, + [] + ); + + const setMessageMetadata = useCallback< + ConversationActor["setMessageMetadata"] + >(async ({ messageId, metadata }) => { + setState((prevState) => { + if (!prevState.conversationId) { console.error(`Cannot setMessageMetadata without a conversationId`); - return state; + return prevState; } - const messageIndex = getMessageIndex(action.messageId); + const messageIndex = getMessageIndex(prevState.messages, messageId); if (messageIndex === -1) { console.error( - `Cannot setMessageMetadata because message with id ${action.messageId} does not exist` + `Cannot setMessageMetadata because message with id ${messageId} does not exist` ); - return state; + return prevState; } - const existingMessage = state.messages[messageIndex]; + const existingMessage = prevState.messages[messageIndex]; if (existingMessage.role !== "assistant") { console.error( - `Cannot setMessageMetadata because message with id ${action.messageId} is not an assistant message` + `Cannot setMessageMetadata because message with id ${messageId} is not an assistant message` ); - return state; + return prevState; } - const modifiedMessage = { - ...existingMessage, - metadata: action.metadata, - }; return { - ...state, + ...prevState, messages: [ - ...state.messages.slice(0, messageIndex), - modifiedMessage, - ...state.messages.slice(messageIndex + 1), + ...prevState.messages.slice(0, messageIndex), + { + ...existingMessage, + metadata, + }, + ...prevState.messages.slice(messageIndex + 1), ], }; - } - case "deleteMessage": { - if (!state.conversationId) { - console.error(`Cannot deleteMessage without a conversationId`); - return state; - } - const messageIndex = getMessageIndex(action.messageId); - if (messageIndex === -1) { - console.error( - `Cannot deleteMessage because message with id ${action.messageId} does not exist` - ); - return state; - } - return { - ...state, - messages: [ - ...state.messages.slice(0, messageIndex), - ...state.messages.slice(messageIndex + 1), - ], - }; - } - case "rateMessage": { - if (!state.conversationId) { - console.error(`Cannot rateMessage without a conversationId`); - return state; - } - const messageIndex = getMessageIndex(action.messageId); - if (messageIndex === -1) { - console.error( - `Cannot rateMessage because message with id ${action.messageId} does not exist` - ); - return state; - } + }); + }, []); - const ratedMessage = { - ...state.messages[messageIndex], - rating: action.rating, - }; - return { - ...state, - messages: [ - ...state.messages.slice(0, messageIndex), - ratedMessage, - ...state.messages.slice(messageIndex + 1), - ], - }; - } - case "createStreamingResponse": { - if (!state.conversationId) { + const createStreamingResponse = useCallback((data: string) => { + setState((prevState) => { + if (!prevState.conversationId) { console.error( `Cannot createStreamingResponse without a conversationId` ); - return state; + return prevState; } - let { streamingMessage } = getStreamingMessage(); + let { streamingMessage } = getStreamingMessage(prevState.messages); if (streamingMessage) { console.error( `Cannot createStreamingResponse because a streamingMessage already exists` ); - return state; + return prevState; } streamingMessage = { ...createMessage({ role: "assistant", - content: action.data, + content: data, }), id: STREAMING_MESSAGE_ID, }; return { - ...state, + ...prevState, isStreamingMessage: true, streamingMessage, - messages: [...state.messages, streamingMessage], + messages: [...prevState.messages, streamingMessage], }; - } - case "appendStreamingResponse": { - if (!state.conversationId) { + }); + }, []); + + const appendStreamingResponse = useCallback((data: string) => { + setState((prevState) => { + if (!prevState.conversationId) { console.error( `Cannot appendStreamingResponse without a conversationId` ); - return state; + return prevState; } - const { streamingMessage, streamingMessageIndex } = getStreamingMessage(); + const { streamingMessage, streamingMessageIndex } = getStreamingMessage( + prevState.messages + ); if (!streamingMessage) { console.error( - `Cannot appendStreamingResponse without a streamingMessage. Make sure to dispatch createStreamingResponse first.` + `Cannot appendStreamingResponse without a streamingMessage. Make sure to call createStreamingResponse first.` ); - return state; + return prevState; } - const modifiedMessage = { - ...streamingMessage, - content: streamingMessage.content + action.data, - }; return { - ...state, + ...prevState, messages: updateArrayElementAt( - state.messages, + prevState.messages, streamingMessageIndex, - modifiedMessage + { + ...streamingMessage, + content: streamingMessage.content + data, + } ), }; - } - case "appendStreamingReferences": { - if (!state.conversationId) { + }); + }, []); + + const appendStreamingReferences = useCallback((data: References) => { + setState((prevState) => { + if (!prevState.conversationId) { console.error( - `Cannot appendStreamingResponse without a conversationId` + `Cannot appendStreamingReferences without a conversationId` ); - return state; + return prevState; } - const { streamingMessage, streamingMessageIndex } = getStreamingMessage(); + const { streamingMessage, streamingMessageIndex } = getStreamingMessage( + prevState.messages + ); if (!streamingMessage) { console.error( - `Cannot appendStreamingResponse without a streamingMessage. Make sure to dispatch createStreamingResponse first.` + `Cannot appendStreamingReferences without a streamingMessage. Make sure to call createStreamingResponse first.` ); - return state; + return prevState; } - const modifiedMessage = { - ...streamingMessage, - references: [...(streamingMessage.references ?? []), ...action.data], - } satisfies MessageData; return { - ...state, + ...prevState, messages: updateArrayElementAt( - state.messages, + prevState.messages, streamingMessageIndex, - modifiedMessage + { + ...streamingMessage, + references: [...(streamingMessage.references ?? []), ...data], + } ), }; - } - case "finishStreamingResponse": { - const { streamingMessage, streamingMessageIndex } = getStreamingMessage(); + }); + }, []); + + const finishStreamingResponse = useCallback((messageId: string) => { + setState((prevState) => { + if (!prevState.conversationId) { + console.error( + `Cannot finishStreamingResponse without a conversationId` + ); + return prevState; + } + const { streamingMessage, streamingMessageIndex } = getStreamingMessage( + prevState.messages + ); if (!streamingMessage) { console.error( `Cannot finishStreamingResponse without a streamingMessage` ); - return state; + return prevState; } - const finalMessage = { - ...streamingMessage, - id: action.messageId, - }; - return { - ...state, + ...prevState, isStreamingMessage: false, streamingMessage: undefined, messages: updateArrayElementAt( - state.messages, + prevState.messages, streamingMessageIndex, - finalMessage + { + ...streamingMessage, + id: messageId ?? createMessageId(), + } ), }; - } - case "cancelStreamingResponse": { - const { streamingMessage, streamingMessageIndex } = getStreamingMessage(); + }); + }, []); + + const cancelStreamingResponse = useCallback(() => { + setState((prevState) => { + if (!prevState.conversationId) { + console.error( + `Cannot cancelStreamingResponse without a conversationId` + ); + return prevState; + } + const { streamingMessage, streamingMessageIndex } = getStreamingMessage( + prevState.messages + ); if (!streamingMessage) { console.error( `Cannot cancelStreamingResponse without a streamingMessage` ); - return state; + return prevState; } - - const messages = removeArrayElementAt( - state.messages, - streamingMessageIndex - ); - return { - ...state, + ...prevState, isStreamingMessage: false, streamingMessage: undefined, - messages, + messages: removeArrayElementAt( + prevState.messages, + streamingMessageIndex + ), }; - } - default: { - console.error("Unhandled action", action); - throw new Error(`Unhandled action type`); - } - } + }); + }, []); + + const deleteMessage = useCallback(async (messageId: string) => { + setState((prevState) => { + if (!prevState.conversationId) { + console.error(`Cannot deleteMessage without a conversationId`); + return prevState; + } + const messageIndex = getMessageIndex(prevState.messages, messageId); + if (messageIndex === -1) { + console.error( + `Cannot deleteMessage because message with id ${messageId} does not exist` + ); + return prevState; + } + return { + ...prevState, + messages: removeArrayElementAt(prevState.messages, messageIndex), + }; + }); + }, []); + + const rateMessage = useCallback( + async (messageId: string, rating: boolean) => { + setState((prevState) => { + if (!prevState.conversationId) { + console.error(`Cannot rateMessage without a conversationId`); + return prevState; + } + const messageIndex = getMessageIndex(prevState.messages, messageId); + if (messageIndex === -1) { + console.error( + `Cannot rateMessage because message with id ${messageId} does not exist` + ); + return prevState; + } + const ratedMessage = { + ...prevState.messages[messageIndex], + rating, + }; + return { + ...prevState, + messages: updateArrayElementAt( + prevState.messages, + messageIndex, + ratedMessage + ), + }; + }); + }, + [] + ); + + const commentMessage = useCallback( + async (messageId: string, comment: string) => { + setState((prevState) => { + if (!prevState.conversationId) { + console.error(`Cannot commentMessage without a conversationId`); + return prevState; + } + const messageIndex = getMessageIndex(prevState.messages, messageId); + if (messageIndex === -1) { + console.error( + `Cannot commentMessage because message with id ${messageId} does not exist` + ); + return prevState; + } + const commentedMessage = { + ...prevState.messages[messageIndex], + comment, + }; + return { + ...prevState, + messages: updateArrayElementAt( + prevState.messages, + messageIndex, + commentedMessage + ), + }; + }); + }, + [] + ); + + /** + * Switch to a different, existing conversation. + */ + const switchConversation = useCallback( + async (conversationId: string) => { + try { + const conversation = await conversationService.getConversation( + conversationId + ); + setConversation({ + ...conversation, + error: "", + }); + } catch (error) { + const errorMessage = + typeof error === "string" + ? error + : error instanceof Error + ? error.message + : "Failed to switch conversation."; + console.error(errorMessage); + // Rethrow the error so that we can handle it in the UI + throw error; + } + }, + [conversationService, setConversation] + ); + + return { + state, + actions: { + createConversation, + endConversationWithError, + addMessage, + setMessageContent, + setMessageMetadata, + deleteMessage, + rateMessage, + commentMessage, + switchConversation, + createStreamingResponse, + appendStreamingResponse, + appendStreamingReferences, + finishStreamingResponse, + cancelStreamingResponse, + }, + }; } export type UseConversationParams = { @@ -389,20 +554,7 @@ export function useConversation(params: UseConversationParams) { }); }, [params.serverBaseUrl, params.fetchOptions]); - const [state, _dispatch] = useReducer( - conversationReducer, - defaultConversationState - ); - - const dispatch = useCallback( - (...args: Parameters) => { - if (import.meta.env.MODE !== "production") { - console.log(`dispatch`, ...args); - } - _dispatch(...args); - }, - [_dispatch] - ); + const { state, actions } = useConversationState(conversationService); // Use a custom sort function if provided. If undefined and we're on a // well-known MongoDB domain, then prioritize links to the current domain. @@ -411,52 +563,6 @@ export function useConversation(params: UseConversationParams) { params.sortMessageReferences ?? makePrioritizeCurrentMongoDbReferenceDomain(); - const setConversation = useCallback( - (conversation: Required) => { - dispatch({ type: "setConversation", conversation }); - }, - [dispatch] - ); - - const endConversationWithError = useCallback( - (errorMessage: string) => { - dispatch({ type: "setConversationError", errorMessage }); - }, - [dispatch] - ); - - const createConversation = useCallback(async () => { - try { - const conversation = await conversationService.createConversation(); - setConversation({ - ...conversation, - error: "", - }); - } catch (error) { - const errorMessage = - typeof error === "string" - ? error - : error instanceof Error - ? error.message - : "Failed to create conversation."; - console.error(errorMessage); - endConversationWithError(errorMessage); - } - }, [conversationService, setConversation, endConversationWithError]); - - const setMessageMetadata = useCallback< - ConversationActor["setMessageMetadata"] - >( - async ({ messageId, metadata }) => { - if (!state.conversationId) { - console.error(`Cannot setMessageMetadata without a conversationId`); - return; - } - dispatch({ type: "setMessageMetadata", messageId, metadata }); - }, - [state.conversationId, dispatch] - ); - const addMessage = useCallback( async ({ role, content }) => { if (!state.conversationId) { @@ -482,7 +588,7 @@ export function useConversation(params: UseConversationParams) { bufferedTokens = remainingTokens; if (nextToken) { - dispatch({ type: "appendStreamingResponse", data: nextToken }); + actions.appendStreamingResponse(nextToken); } const allBufferedTokensDispatched = @@ -497,34 +603,26 @@ export function useConversation(params: UseConversationParams) { streamedTokens.join("") ); if (numCodeFences % 2 !== 0) { - dispatch({ - type: "appendStreamingResponse", - data: "\n```\n\n", - }); + actions.appendStreamingResponse("\n```\n\n"); } - - dispatch({ - type: "appendStreamingReferences", - data: references.sort(sortMessageReferences), - }); + actions.appendStreamingReferences( + references.sort(sortMessageReferences) + ); references = null; } if (!finishedBuffering && allBufferedTokensDispatched) { if (!streamedMessageId) { streamedMessageId = createMessageId(); } - dispatch({ - type: "finishStreamingResponse", - messageId: streamedMessageId, - }); + actions.finishStreamingResponse(streamedMessageId); finishedBuffering = true; } }, streamingIntervalMs); try { - dispatch({ type: "addMessage", role, content }); + await actions.addMessage({ role, content }); if (shouldStream) { - dispatch({ type: "createStreamingResponse", data: "" }); + actions.createStreamingResponse(""); await conversationService.addMessageStreaming({ conversationId: state.conversationId, message: content, @@ -540,7 +638,7 @@ export function useConversation(params: UseConversationParams) { references.push(...data); }, onMetadata: async (metadata) => { - setMessageMetadata({ + actions.setMessageMetadata({ messageId: STREAMING_MESSAGE_ID, metadata, }); @@ -555,14 +653,13 @@ export function useConversation(params: UseConversationParams) { // We start a streaming response to indicate the loading state // but we'll never append to it since the response message comes // in all at once. - dispatch({ type: "createStreamingResponse", data: "" }); + actions.createStreamingResponse(""); const response = await conversationService.addMessage({ conversationId: state.conversationId, message: content, }); - dispatch({ type: "cancelStreamingResponse" }); - dispatch({ - type: "addMessage", + actions.cancelStreamingResponse(); + actions.addMessage({ role: "assistant", content: response.content, references: response.references?.sort(sortMessageReferences), @@ -574,10 +671,10 @@ export function useConversation(params: UseConversationParams) { console.error(`Failed to add message: ${error}`); const errorMessage = error instanceof Error ? error.message : String(error); - dispatch({ type: "cancelStreamingResponse" }); + actions.cancelStreamingResponse(); clearInterval(streamingInterval); - endConversationWithError(errorMessage); + actions.endConversationWithError(errorMessage); throw error; } @@ -593,128 +690,26 @@ export function useConversation(params: UseConversationParams) { }); }, [ - state.conversationId, params.shouldStream, - dispatch, conversationService, + state.conversationId, + actions, sortMessageReferences, - setMessageMetadata, - endConversationWithError, ] ); - const setMessageContent = useCallback( - async (messageId: string, content: string) => { - if (!state.conversationId) { - console.error(`Cannot setMessageContent without a conversationId`); - return; - } - dispatch({ type: "setMessageContent", messageId, content }); - }, - [state.conversationId, dispatch] - ); - - const deleteMessage = useCallback( - async (messageId: string) => { - if (!state.conversationId) { - console.error(`Cannot deleteMessage without a conversationId`); - return; - } - dispatch({ type: "deleteMessage", messageId }); - }, - [state.conversationId, dispatch] - ); - - const rateMessage = useCallback( - async (messageId: string, rating: boolean) => { - if (!state.conversationId) { - console.error(`Cannot rateMessage without a conversationId`); - return; - } - await conversationService.rateMessage({ - conversationId: state.conversationId, - messageId, - rating, - }); - dispatch({ type: "rateMessage", messageId, rating }); - }, - [state.conversationId, conversationService, dispatch] - ); - - const commentMessage = useCallback( - async (messageId: string, comment: string) => { - if (!state.conversationId) { - console.error(`Cannot commentMessage without a conversationId`); - return; - } - await conversationService.commentMessage({ - conversationId: state.conversationId, - messageId, - comment, - }); - }, - [state.conversationId, conversationService] - ); - const streamingMessage = state.messages.find( (m) => m.id === STREAMING_MESSAGE_ID ); - /** - * Switch to a different, existing conversation. - */ - const switchConversation = useCallback( - async (conversationId: string) => { - try { - const conversation = await conversationService.getConversation( - conversationId - ); - setConversation({ - ...conversation, - error: "", - }); - } catch (error) { - const errorMessage = - typeof error === "string" - ? error - : error instanceof Error - ? error.message - : "Failed to switch conversation."; - console.error(errorMessage); - // Rethrow the error so that we can handle it in the UI - throw error; - } - }, - [conversationService, setConversation] - ); - return useMemo( () => ({ ...state, - createConversation, - endConversationWithError, - streamingMessage, + ...actions, addMessage, - setMessageContent, - setMessageMetadata, - deleteMessage, - rateMessage, - commentMessage, - switchConversation, + streamingMessage, } satisfies Conversation), - [ - state, - createConversation, - endConversationWithError, - streamingMessage, - addMessage, - setMessageContent, - setMessageMetadata, - deleteMessage, - rateMessage, - commentMessage, - switchConversation, - ] + [state, actions, addMessage, streamingMessage] ); }