Skip to content

Commit

Permalink
Continue generation feature (#707)
Browse files Browse the repository at this point in the history
* Initial work on continue feature

* Move continue button

* Fix websearch with continue

* Make it work with every model

* Update src/routes/conversation/[id]/+server.ts

Co-authored-by: Mishig <[email protected]>

* fixes

* async all the things

* add reduce comment

* remove log

* Only show loading indicator if not continuing

---------

Co-authored-by: Mishig <[email protected]>
  • Loading branch information
nsarrazin and Mishig authored Jan 22, 2024
1 parent 6e0b0ea commit 77399ca
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 83 deletions.
6 changes: 4 additions & 2 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ MODELS=`[
"repetition_penalty": 1.2,
"top_k": 50,
"truncate": 3072,
"max_new_tokens": 1024
"max_new_tokens": 1024,
"stop" : ["</s>", " </s><s>[INST] "]
}
},
{
Expand Down Expand Up @@ -116,7 +117,8 @@ MODELS=`[
"repetition_penalty": 1.2,
"top_k": 50,
"truncate": 4096,
"max_new_tokens": 4096
"max_new_tokens": 4096,
"stop": [" </s><s>[INST] "]
}
},
{
Expand Down
34 changes: 18 additions & 16 deletions src/lib/buildPrompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ interface buildPromptOptions {
webSearch?: WebSearch;
preprompt?: string;
files?: File[];
continue?: boolean;
}

export async function buildPrompt({
Expand All @@ -22,37 +23,38 @@ export async function buildPrompt({
preprompt,
id,
}: buildPromptOptions): Promise<string> {
let modifiedMessages = [...messages];

if (webSearch && webSearch.context) {
const lastMsg = messages.slice(-1)[0];
const messagesWithoutLastUsrMsg = messages.slice(0, -1);
const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
// find index of the last user message
const lastUsrMsgIndex = modifiedMessages.map((el) => el.from).lastIndexOf("user");

// combine all the other previous questions into one string
const previousUserMessages = modifiedMessages.filter((el) => el.from === "user").slice(0, -1);
const previousQuestions =
previousUserMessages.length > 0
? `Previous questions: \n${previousUserMessages
.map(({ content }) => `- ${content}`)
.join("\n")}`
: "";

const currentDate = format(new Date(), "MMMM d, yyyy");
messages = [
...messagesWithoutLastUsrMsg,
{
from: "user",
content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:

// update the last user message directly (that way if the last message is an assistant partial answer, we keep the beginning of that answer)
modifiedMessages[lastUsrMsgIndex] = {
from: "user",
content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
=====================
${webSearch.context}
=====================
${previousQuestions}
Answer the question: ${lastMsg.content}
`,
},
];
Answer the question: ${messages[lastUsrMsgIndex].content} `,
};
}

// section to handle potential files input
if (model.multimodal) {
messages = await Promise.all(
messages.map(async (el) => {
modifiedMessages = await Promise.all(
modifiedMessages.map(async (el) => {
let content = el.content;

if (el.from === "user") {
Expand Down Expand Up @@ -83,7 +85,7 @@ export async function buildPrompt({

return (
model
.chatPromptRender({ messages, preprompt })
.chatPromptRender({ messages: modifiedMessages, preprompt })
// Not super precise, but it's truncated in the model's backend anyway
.split(" ")
.slice(-(model.parameters?.truncate ?? 0))
Expand Down
13 changes: 13 additions & 0 deletions src/lib/components/ContinueBtn.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<script lang="ts">
import CarbonContinue from "~icons/carbon/continue";
export let classNames = "";
</script>

<button
type="button"
on:click
class="btn flex h-8 rounded-lg border bg-white px-3 py-1 text-gray-500 shadow-sm transition-all hover:bg-gray-100 dark:border-gray-600 dark:bg-gray-700 dark:text-gray-300 dark:hover:bg-gray-600 {classNames}"
>
<CarbonContinue class="mr-2 text-xs " /> Continue
</button>
1 change: 1 addition & 0 deletions src/lib/components/chat/ChatMessage.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import CarbonDownload from "~icons/carbon/download";
import CarbonThumbsUp from "~icons/carbon/thumbs-up";
import CarbonThumbsDown from "~icons/carbon/thumbs-down";
import { PUBLIC_SEP_TOKEN } from "$lib/constants/publicSepToken";
import type { Model } from "$lib/types/Model";
Expand Down
3 changes: 2 additions & 1 deletion src/lib/components/chat/ChatMessages.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@
webSearchMessages={i === messages.length - 1 ? webSearchMessages : []}
on:retry
on:vote
on:continue
/>
{:else}
<ChatIntroduction {models} {currentModel} on:message />
{/each}
{#if pending}
{#if pending && messages[messages.length - 1]?.from === "user"}
<ChatMessage
message={{ from: "assistant", content: "", id: randomUUID() }}
model={currentModel}
Expand Down
19 changes: 17 additions & 2 deletions src/lib/components/chat/ChatWindow.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import UploadBtn from "../UploadBtn.svelte";
import file2base64 from "$lib/utils/file2base64";
import { useSettingsStore } from "$lib/stores/settings";
import ContinueBtn from "../ContinueBtn.svelte";
export let messages: Message[] = [];
export let loading = false;
Expand All @@ -48,6 +49,7 @@
share: void;
stop: void;
retry: { id: Message["id"]; content: string };
continue: { id: Message["id"] };
}>();
const handleSubmit = () => {
Expand Down Expand Up @@ -124,6 +126,7 @@
}
}}
on:vote
on:continue
on:retry={(ev) => {
if (!loading) dispatch("retry", ev.detail);
}}
Expand Down Expand Up @@ -173,8 +176,20 @@
content: messages[messages.length - 1].content,
})}
/>
{:else if currentModel.multimodal}
<UploadBtn bind:files classNames="ml-auto" />
{:else}
<div class="ml-auto gap-2">
{#if currentModel.multimodal}
<UploadBtn bind:files classNames="ml-auto" />
{/if}
{#if messages && messages[messages.length - 1]?.interrupted && !isReadOnly}
<ContinueBtn
on:click={() =>
dispatch("continue", {
id: messages[messages.length - 1].id,
})}
/>
{/if}
</div>
{/if}
</div>
<form
Expand Down
1 change: 1 addition & 0 deletions src/lib/server/endpoints/endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ interface EndpointParameters {
preprompt?: Conversation["preprompt"];
_id?: Conversation["_id"];
};
continue?: boolean;
}

interface CommonEndpoint {
Expand Down
15 changes: 13 additions & 2 deletions src/lib/server/endpoints/tgi/endpointTgi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@ export const endpointTgiParametersSchema = z.object({

export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
return async ({ conversation }) => {
const prompt = await buildPrompt({

return async ({ conversation, continue: messageContinue }) => {
let prompt = await buildPrompt({
messages: conversation.messages,
webSearch: conversation.messages[conversation.messages.length - 1].webSearch,
preprompt: conversation.preprompt,
model,
id: conversation._id,
});

if (messageContinue) {
// start with the full prompt, and for each stop token, try to remove it from the end of the prompt
prompt = model.parameters.stop.reduce((acc: string, curr: string) => {
if (acc.endsWith(curr)) {
return acc.slice(0, acc.length - curr.length);
}
return acc;
}, prompt.trimEnd());
}

return textGenerationStream(
{
parameters: { ...model.parameters, return_full_text: false },
Expand Down
1 change: 1 addition & 0 deletions src/lib/types/Message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ export type Message = Partial<Timestamps> & {
webSearch?: WebSearch;
score?: -1 | 0 | 1;
files?: string[]; // can contain either the hash of the file or the b64 encoded image data on the client side when uploading
interrupted?: boolean;
};
101 changes: 76 additions & 25 deletions src/routes/conversation/[id]/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,39 @@
}
}
// this function is used to send new message to the backends
async function writeMessage(message: string, messageId = randomUUID()) {
if (!message.trim()) return;
async function writeMessage({
prompt,
messageId = randomUUID(),
isRetry = false,
isContinue = false,
}: {
prompt?: string;
messageId?: ReturnType<typeof randomUUID>;
isRetry?: boolean;
isContinue?: boolean;
}): Promise<void> {
try {
$isAborted = false;
loading = true;
pending = true;
// first we check if the messageId already exists, indicating a retry
let retryMessageIndex = messages.findIndex((msg) => msg.id === messageId);
const isRetry = retryMessageIndex !== -1;
// if it's not a retry we just use the whole array
if (!isRetry) {
retryMessageIndex = messages.length;
let msgIndex = messages.findIndex((msg) => msg.id === messageId);
if (msgIndex === -1) {
msgIndex = messages.length - 1;
}
if (isRetry && messages[msgIndex].from === "assistant") {
throw new Error("Trying to retry a message that is not from user");
}
if (isContinue && messages[msgIndex].from === "user") {
throw new Error("Trying to continue a message that is not from assistant");
}
// const isNewMessage = !isRetry && !isContinue;
const module = await import("browser-image-resizer");
// currently, only IDEFICS is supported by TGI
Expand All @@ -99,25 +115,42 @@
);
// slice up to the point of the retry
messages = [
...messages.slice(0, retryMessageIndex),
{
from: "user",
content: message,
id: messageId,
files: isRetry ? messages[retryMessageIndex].files : resizedImages,
},
];
if (isRetry) {
messages = [
...messages.slice(0, msgIndex),
{
from: "user",
content: messages[msgIndex].content,
id: messageId,
files: messages[msgIndex].files,
},
];
} else if (!isContinue) {
// or add a new message if its not a continue request
if (!prompt) {
throw new Error("Prompt is undefined");
}
messages = [
...messages,
{
from: "user",
content: prompt ?? "",
id: messageId,
files: resizedImages,
},
];
}
files = [];
const response = await fetch(`${base}/conversation/${$page.params.id}`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
inputs: message,
inputs: prompt,
id: messageId,
is_retry: isRetry,
is_continue: isContinue,
web_search: $webSearchParameters.useSearch,
files: isRetry ? undefined : resizedImages,
}),
Expand Down Expand Up @@ -282,37 +315,54 @@
// only used in case of creating new conversations (from the parent POST endpoint)
if ($pendingMessage) {
files = $pendingMessage.files;
await writeMessage($pendingMessage.content);
await writeMessage({ prompt: $pendingMessage.content });
$pendingMessage = undefined;
}
});
async function onMessage(event: CustomEvent<string>) {
if (!data.shared) {
writeMessage(event.detail);
await writeMessage({ prompt: event.detail });
} else {
convFromShared()
await convFromShared()
.then(async (convId) => {
await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
})
.then(() => writeMessage(event.detail))
.then(async () => await writeMessage({ prompt: event.detail }))
.finally(() => (loading = false));
}
}
async function onRetry(event: CustomEvent<{ id: Message["id"]; content: string }>) {
if (!data.shared) {
writeMessage(event.detail.content, event.detail.id);
await writeMessage({
prompt: event.detail.content,
messageId: event.detail.id,
isRetry: true,
});
} else {
convFromShared()
await convFromShared()
.then(async (convId) => {
await goto(`${base}/conversation/${convId}`, { invalidateAll: true });
})
.then(() => writeMessage(event.detail.content, event.detail.id))
.then(
async () =>
await writeMessage({
prompt: event.detail.content,
messageId: event.detail.id,
isRetry: true,
})
)
.finally(() => (loading = false));
}
}
async function onContinue(event: CustomEvent<{ id: Message["id"] }>) {
if (!data.shared) {
writeMessage({ messageId: event.detail.id, isContinue: true });
}
}
$: $page.params.id, (($isAborted = true), (loading = false));
$: title = data.conversations.find((conv) => conv.id === $page.params.id)?.title ?? data.title;
</script>
Expand All @@ -337,6 +387,7 @@
bind:files
on:message={onMessage}
on:retry={onRetry}
on:continue={onContinue}
on:vote={(event) => voteMessage(event.detail.score, event.detail.id)}
on:share={() => shareConversation($page.params.id, data.title)}
on:stop={() => (($isAborted = true), (loading = false))}
Expand Down
Loading

0 comments on commit 77399ca

Please sign in to comment.