diff --git a/.env.template b/.env.template index 6b2dfb0b35b..d8573a499c7 100644 --- a/.env.template +++ b/.env.template @@ -57,7 +57,8 @@ MODELS=`[ "repetition_penalty": 1.2, "top_k": 50, "truncate": 3072, - "max_new_tokens": 1024 + "max_new_tokens": 1024, + "stop" : ["", " [INST] "] } }, { @@ -116,7 +117,8 @@ MODELS=`[ "repetition_penalty": 1.2, "top_k": 50, "truncate": 4096, - "max_new_tokens": 4096 + "max_new_tokens": 4096, + "stop": [" [INST] "] } }, { diff --git a/src/lib/buildPrompt.ts b/src/lib/buildPrompt.ts index 15e3a450e64..96f5e3c7580 100644 --- a/src/lib/buildPrompt.ts +++ b/src/lib/buildPrompt.ts @@ -13,6 +13,7 @@ interface buildPromptOptions { webSearch?: WebSearch; preprompt?: string; files?: File[]; + continue?: boolean; } export async function buildPrompt({ @@ -22,37 +23,38 @@ export async function buildPrompt({ preprompt, id, }: buildPromptOptions): Promise { + let modifiedMessages = [...messages]; + if (webSearch && webSearch.context) { - const lastMsg = messages.slice(-1)[0]; - const messagesWithoutLastUsrMsg = messages.slice(0, -1); - const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1); + // find index of the last user message + const lastUsrMsgIndex = modifiedMessages.map((el) => el.from).lastIndexOf("user"); + // combine all the other previous questions into one string + const previousUserMessages = modifiedMessages.filter((el) => el.from === "user").slice(0, -1); const previousQuestions = previousUserMessages.length > 0 ? `Previous questions: \n${previousUserMessages .map(({ content }) => `- ${content}`) .join("\n")}` : ""; + const currentDate = format(new Date(), "MMMM d, yyyy"); - messages = [ - ...messagesWithoutLastUsrMsg, - { - from: "user", - content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results: + + // update the last user message directly (that way if the last message is an assistant partial answer, we keep the beginning of that answer) + modifiedMessages[lastUsrMsgIndex] = { + from: "user", + content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results: ===================== ${webSearch.context} ===================== ${previousQuestions} - Answer the question: ${lastMsg.content} - `, - }, - ]; + Answer the question: ${messages[lastUsrMsgIndex].content} `, + }; } - // section to handle potential files input if (model.multimodal) { - messages = await Promise.all( - messages.map(async (el) => { + modifiedMessages = await Promise.all( + modifiedMessages.map(async (el) => { let content = el.content; if (el.from === "user") { @@ -83,7 +85,7 @@ export async function buildPrompt({ return ( model - .chatPromptRender({ messages, preprompt }) + .chatPromptRender({ messages: modifiedMessages, preprompt }) // Not super precise, but it's truncated in the model's backend anyway .split(" ") .slice(-(model.parameters?.truncate ?? 0)) diff --git a/src/lib/components/ContinueBtn.svelte b/src/lib/components/ContinueBtn.svelte new file mode 100644 index 00000000000..aa72fd16042 --- /dev/null +++ b/src/lib/components/ContinueBtn.svelte @@ -0,0 +1,13 @@ + + + diff --git a/src/lib/components/chat/ChatMessage.svelte b/src/lib/components/chat/ChatMessage.svelte index 158d8728863..32efd09e83e 100644 --- a/src/lib/components/chat/ChatMessage.svelte +++ b/src/lib/components/chat/ChatMessage.svelte @@ -13,6 +13,7 @@ import CarbonDownload from "~icons/carbon/download"; import CarbonThumbsUp from "~icons/carbon/thumbs-up"; import CarbonThumbsDown from "~icons/carbon/thumbs-down"; + import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken"; import type { Model } from "$lib/types/Model"; diff --git a/src/lib/components/chat/ChatMessages.svelte b/src/lib/components/chat/ChatMessages.svelte index 9ce0b115a3b..5031d280efc 100644 --- a/src/lib/components/chat/ChatMessages.svelte +++ b/src/lib/components/chat/ChatMessages.svelte @@ -54,11 +54,12 @@ webSearchMessages={i === messages.length - 1 ? webSearchMessages : []} on:retry on:vote + on:continue /> {:else} {/each} - {#if pending} + {#if pending && messages[messages.length - 1]?.from === "user"} (); const handleSubmit = () => { @@ -124,6 +126,7 @@ } }} on:vote + on:continue on:retry={(ev) => { if (!loading) dispatch("retry", ev.detail); }} @@ -173,8 +176,20 @@ content: messages[messages.length - 1].content, })} /> - {:else if currentModel.multimodal} - + {:else} +
+ {#if currentModel.multimodal} + + {/if} + {#if messages && messages[messages.length - 1]?.interrupted && !isReadOnly} + + dispatch("continue", { + id: messages[messages.length - 1].id, + })} + /> + {/if} +
{/if}
): Endpoint { const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input); - return async ({ conversation }) => { - const prompt = await buildPrompt({ + + return async ({ conversation, continue: messageContinue }) => { + let prompt = await buildPrompt({ messages: conversation.messages, webSearch: conversation.messages[conversation.messages.length - 1].webSearch, preprompt: conversation.preprompt, @@ -24,6 +25,16 @@ export function endpointTgi(input: z.input): id: conversation._id, }); + if (messageContinue) { + // start with the full prompt, and for each stop token, try to remove it from the end of the prompt + prompt = model.parameters.stop.reduce((acc: string, curr: string) => { + if (acc.endsWith(curr)) { + return acc.slice(0, acc.length - curr.length); + } + return acc; + }, prompt.trimEnd()); + } + return textGenerationStream( { parameters: { ...model.parameters, return_full_text: false }, diff --git a/src/lib/types/Message.ts b/src/lib/types/Message.ts index e485dfc444f..4318034a418 100644 --- a/src/lib/types/Message.ts +++ b/src/lib/types/Message.ts @@ -11,4 +11,5 @@ export type Message = Partial & { webSearch?: WebSearch; score?: -1 | 0 | 1; files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading + interrupted?: boolean; }; diff --git a/src/routes/conversation/[id]/+page.svelte b/src/routes/conversation/[id]/+page.svelte index b8633809653..075a66eb3de 100644 --- a/src/routes/conversation/[id]/+page.svelte +++ b/src/routes/conversation/[id]/+page.svelte @@ -64,9 +64,17 @@ } } // this function is used to send new message to the backends - async function writeMessage(message: string, messageId = randomUUID()) { - if (!message.trim()) return; - + async function writeMessage({ + prompt, + messageId = randomUUID(), + isRetry = false, + isContinue = false, + }: { + prompt?: string; + messageId?: ReturnType; + isRetry?: boolean; + isContinue?: boolean; + }): Promise { try { $isAborted = false; loading = true; @@ -74,13 +82,21 @@ // first we check if the messageId already exists, indicating a retry - let retryMessageIndex = messages.findIndex((msg) => msg.id === messageId); - const isRetry = retryMessageIndex !== -1; - // if it's not a retry we just use the whole array - if (!isRetry) { - retryMessageIndex = messages.length; + let msgIndex = messages.findIndex((msg) => msg.id === messageId); + + if (msgIndex === -1) { + msgIndex = messages.length - 1; + } + if (isRetry && messages[msgIndex].from === "assistant") { + throw new Error("Trying to retry a message that is not from user"); + } + + if (isContinue && messages[msgIndex].from === "user") { + throw new Error("Trying to continue a message that is not from assistant"); } + // const isNewMessage = !isRetry && !isContinue; + const module = await import("browser-image-resizer"); // currently, only IDEFICS is supported by TGI @@ -99,15 +115,31 @@ ); // slice up to the point of the retry - messages = [ - ...messages.slice(0, retryMessageIndex), - { - from: "user", - content: message, - id: messageId, - files: isRetry ? messages[retryMessageIndex].files : resizedImages, - }, - ]; + if (isRetry) { + messages = [ + ...messages.slice(0, msgIndex), + { + from: "user", + content: messages[msgIndex].content, + id: messageId, + files: messages[msgIndex].files, + }, + ]; + } else if (!isContinue) { + // or add a new message if its not a continue request + if (!prompt) { + throw new Error("Prompt is undefined"); + } + messages = [ + ...messages, + { + from: "user", + content: prompt ?? "", + id: messageId, + files: resizedImages, + }, + ]; + } files = []; @@ -115,9 +147,10 @@ method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ - inputs: message, + inputs: prompt, id: messageId, is_retry: isRetry, + is_continue: isContinue, web_search: $webSearchParameters.useSearch, files: isRetry ? undefined : resizedImages, }), @@ -282,37 +315,54 @@ // only used in case of creating new conversations (from the parent POST endpoint) if ($pendingMessage) { files = $pendingMessage.files; - await writeMessage($pendingMessage.content); + await writeMessage({ prompt: $pendingMessage.content }); $pendingMessage = undefined; } }); async function onMessage(event: CustomEvent) { if (!data.shared) { - writeMessage(event.detail); + await writeMessage({ prompt: event.detail }); } else { - convFromShared() + await convFromShared() .then(async (convId) => { await goto(`${base}/conversation/${convId}`, { invalidateAll: true }); }) - .then(() => writeMessage(event.detail)) + .then(async () => await writeMessage({ prompt: event.detail })) .finally(() => (loading = false)); } } async function onRetry(event: CustomEvent<{ id: Message["id"]; content: string }>) { if (!data.shared) { - writeMessage(event.detail.content, event.detail.id); + await writeMessage({ + prompt: event.detail.content, + messageId: event.detail.id, + isRetry: true, + }); } else { - convFromShared() + await convFromShared() .then(async (convId) => { await goto(`${base}/conversation/${convId}`, { invalidateAll: true }); }) - .then(() => writeMessage(event.detail.content, event.detail.id)) + .then( + async () => + await writeMessage({ + prompt: event.detail.content, + messageId: event.detail.id, + isRetry: true, + }) + ) .finally(() => (loading = false)); } } + async function onContinue(event: CustomEvent<{ id: Message["id"] }>) { + if (!data.shared) { + writeMessage({ messageId: event.detail.id, isContinue: true }); + } + } + $: $page.params.id, (($isAborted = true), (loading = false)); $: title = data.conversations.find((conv) => conv.id === $page.params.id)?.title ?? data.title; @@ -337,6 +387,7 @@ bind:files on:message={onMessage} on:retry={onRetry} + on:continue={onContinue} on:vote={(event) => voteMessage(event.detail.score, event.detail.id)} on:share={() => shareConversation($page.params.id, data.title)} on:stop={() => (($isAborted = true), (loading = false))} diff --git a/src/routes/conversation/[id]/+server.ts b/src/routes/conversation/[id]/+server.ts index 2e8c7484ec8..db0b2a9eec9 100644 --- a/src/routes/conversation/[id]/+server.ts +++ b/src/routes/conversation/[id]/+server.ts @@ -91,14 +91,16 @@ export async function POST({ request, locals, params, getClientAddress }) { const { inputs: newPrompt, id: messageId, - is_retry, + is_retry: isRetry, + is_continue: isContinue, web_search: webSearch, files: b64files, } = z .object({ - inputs: z.string().trim().min(1), + inputs: z.optional(z.string().trim().min(1)), id: z.optional(z.string().uuid()), is_retry: z.optional(z.boolean()), + is_continue: z.optional(z.boolean()), web_search: z.optional(z.boolean()), files: z.optional(z.array(z.string())), }) @@ -136,38 +138,50 @@ export async function POST({ request, locals, params, getClientAddress }) { hashes = await Promise.all(files.map(async (file) => await uploadFile(file, conv))); } + // can only call isContinue on the last message id + if (isContinue && conv.messages[conv.messages.length - 1].id !== messageId) { + throw error(400, "Can only continue the last message"); + } + // get the list of messages // while checking for retries let messages = (() => { - if (is_retry && messageId) { + // for retries we slice and rewrite the last user message + if (isRetry && messageId) { // if the message is a retry, replace the message and remove the messages after it let retryMessageIdx = conv.messages.findIndex((message) => message.id === messageId); + if (retryMessageIdx === -1) { retryMessageIdx = conv.messages.length; } + return [ ...conv.messages.slice(0, retryMessageIdx), { - content: newPrompt, + content: conv.messages[retryMessageIdx]?.content, from: "user", id: messageId as Message["id"], updatedAt: new Date(), files: conv.messages[retryMessageIdx]?.files, }, ]; + } else if (isContinue && messageId) { + // for continue we do nothing and expand the last assistant message + return conv.messages; + } else { + // in normal conversation we add an extra user message + return [ + ...conv.messages, + { + content: newPrompt ?? "", + from: "user", + id: (messageId as Message["id"]) || crypto.randomUUID(), + createdAt: new Date(), + updatedAt: new Date(), + files: hashes, + }, + ]; } // else append the message at the bottom - - return [ - ...conv.messages, - { - content: newPrompt, - from: "user", - id: (messageId as Message["id"]) || crypto.randomUUID(), - createdAt: new Date(), - updatedAt: new Date(), - files: hashes, - }, - ]; })() satisfies Message[]; await collections.conversations.updateOne( @@ -183,10 +197,14 @@ export async function POST({ request, locals, params, getClientAddress }) { } ); + let doneStreaming = false; + // we now build the stream const stream = new ReadableStream({ async start(controller) { - const updates: MessageUpdate[] = []; + const updates: MessageUpdate[] = isContinue + ? conv.messages[conv.messages.length - 1].updates ?? [] + : []; function update(newUpdate: MessageUpdate) { if (newUpdate.type !== "stream") { @@ -209,7 +227,7 @@ export async function POST({ request, locals, params, getClientAddress }) { const summarizeIfNeeded = (async () => { if (conv.title === "New Chat" && messages.length === 1) { try { - conv.title = (await summarize(newPrompt)) ?? conv.title; + conv.title = (await summarize(messages[0].content)) ?? conv.title; update({ type: "status", status: "title", message: conv.title }); } catch (e) { console.error(e); @@ -232,17 +250,22 @@ export async function POST({ request, locals, params, getClientAddress }) { let webSearchResults: WebSearch | undefined; - if (webSearch) { - webSearchResults = await runWebSearch(conv, newPrompt, update); + if (webSearch && !isContinue) { + webSearchResults = await runWebSearch(conv, messages[messages.length - 1].content, update); + messages[messages.length - 1].webSearch = webSearchResults; + } else if (isContinue) { + webSearchResults = messages[messages.length - 1].webSearch; } - messages[messages.length - 1].webSearch = webSearchResults; - conv.messages = messages; + const previousContent = isContinue + ? conv.messages.find((message) => message.id === messageId)?.content ?? "" + : ""; + try { const endpoint = await model.getEndpoint(); - for await (const output of await endpoint({ conversation: conv })) { + for await (const output of await endpoint({ conversation: conv, continue: isContinue })) { // if not generated_text is here it means the generation is not done if (!output.generated_text) { // else we get the next token @@ -292,7 +315,8 @@ export async function POST({ request, locals, params, getClientAddress }) { ...messages.slice(0, -1), { ...messages[messages.length - 1], - content: output.generated_text, + content: previousContent + output.generated_text, + interrupted: !output.token.special, // if its a special token it finished on its own, else it was interrupted updates, updatedAt: new Date(), }, @@ -302,6 +326,7 @@ export async function POST({ request, locals, params, getClientAddress }) { } catch (e) { update({ type: "status", status: "error", message: (e as Error).message }); } + await collections.conversations.updateOne( { _id: convId, @@ -315,6 +340,9 @@ export async function POST({ request, locals, params, getClientAddress }) { } ); + // used to detect if cancel() is called bc of interrupt or just because the connection closes + doneStreaming = true; + update({ type: "finalAnswer", text: messages[messages.length - 1].content, @@ -324,18 +352,20 @@ export async function POST({ request, locals, params, getClientAddress }) { return; }, async cancel() { - await collections.conversations.updateOne( - { - _id: convId, - }, - { - $set: { - messages, - title: conv.title, - updatedAt: new Date(), + if (!doneStreaming) { + await collections.conversations.updateOne( + { + _id: convId, }, - } - ); + { + $set: { + messages, + title: conv.title, + updatedAt: new Date(), + }, + } + ); + } }, });