Skip to content

Commit

Permalink
optimize chat
Browse files Browse the repository at this point in the history
  • Loading branch information
greywen committed Dec 18, 2024
1 parent f86c17a commit 8a22b17
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 128 deletions.
11 changes: 7 additions & 4 deletions src/FE/apis/clientApis.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { useFetch } from '@/hooks/useFetch';

import { calculateMessages } from '@/utils/message';
import { formatMessages } from '@/utils/message';

import { AdminModelDto, PostPromptParams } from '@/types/adminApis';
import { ChatMessage } from '@/types/chatMessage';
Expand Down Expand Up @@ -33,9 +33,7 @@ export const changeUserPassword = (params: PostUserPassword) => {

export const getUserMessages = (chatId: string): Promise<ChatMessage[]> => {
const fetchService = useFetch();
return fetchService.get(`/api/messages/${chatId}`).then((data: any) => {
return calculateMessages(data) as any;
});
return fetchService.get(`/api/messages/${chatId}`);
};

export const getChatsByPaging = (
Expand Down Expand Up @@ -68,6 +66,11 @@ export const deleteChats = (id: string) => {
return fetchService.delete(`/api/user/chats/${id}`);
};

export const stopChat = (id: string) => {
const fetchService = useFetch();
return fetchService.post(`/api/chats/stop/${id}`);
};

export const getCsrfToken = (): Promise<{ csrfToken: string }> => {
const fetchServer = useFetch();
return fetchServer.get('/api/auth/csrf');
Expand Down
8 changes: 7 additions & 1 deletion src/FE/pages/home/_actions/chat.actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
SetIsChatsLoadingType,
SetMessageIsStreamingType,
SetSelectedChatType,
SetStopIdsType,
} from '../_reducers/chat.reducer';

export const setChats = (chats: SetChatsType): ChatAction => ({
Expand Down Expand Up @@ -39,8 +40,13 @@ export const setMessageIsStreaming = (
export const setIsChatsLoading = (
isChatsLoading: SetIsChatsLoadingType,
): ChatAction => ({
type: ChatActionTypes.SET_MESSAGE_IS_STREAMING,
type: ChatActionTypes.SET_IS_CHATS_LOADING,
payload: isChatsLoading,
});

export const setStopIds = (stopIds: SetStopIdsType): ChatAction => ({
type: ChatActionTypes.SET_STOP_IDS,
payload: stopIds,
});

export default function () {}
13 changes: 11 additions & 2 deletions src/FE/pages/home/_actions/message.actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,23 @@ import {
SetCurrentMessageIdType,
SetCurrentMessagesType,
SetLastMessageIdType,
SetMessagesType,
SetSelectedMessagesType,
} from '../_reducers/message.reducer';

export const setMessages = (messages: SetMessagesType): MessageAction => ({
export const setMessages = (
messages: SetSelectedMessagesType,
): MessageAction => ({
type: MessageActionTypes.SET_MESSAGES,
payload: messages,
});

export const setSelectedMessages = (
selectedMessages: SetSelectedMessagesType,
): MessageAction => ({
type: MessageActionTypes.SET_SELECTED_MESSAGES,
payload: selectedMessages,
});

export const setCurrentMessages = (
currentMessages: SetCurrentMessagesType,
): MessageAction => ({
Expand Down
140 changes: 67 additions & 73 deletions src/FE/pages/home/_components/Chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,29 @@ import toast from 'react-hot-toast';
import useTranslation from '@/hooks/useTranslation';

import { getApiUrl } from '@/utils/common';
import { formatMessages, getSelectMessages } from '@/utils/message';
import { throttle } from '@/utils/throttle';
import { getUserSession } from '@/utils/user';

import { ChatBody, Content, ContentRequest, Message, Role } from '@/types/chat';
import { SseResponseKind, SseResponseLine } from '@/types/chatMessage';
import { Prompt } from '@/types/prompt';

import ChangeModel from '@/components/ChangeModel/ChangeModel';
import TemperatureSlider from '@/components/TemperatureSlider/TemperatureSlider';

import { setMessageIsStreaming } from '../../_actions/chat.actions';
import { setMessageIsStreaming, setStopIds } from '../../_actions/chat.actions';
import {
setCurrentMessages,
setLastMessageId,
setMessages,
setSelectedMessages,
} from '../../_actions/message.actions';
import { setSelectedModel } from '../../_actions/model.actions';
import {
setEnableSearch,
setPrompt,
setTemperature,
setUserModelConfig,
} from '../../_actions/userModelConfig.actions';
import HomeContext from '../../_contexts/home.context';
import ModeToggle from '../ModeToggle/ModeToggle';
Expand All @@ -41,9 +44,8 @@ import ModelSelect from './ModelSelect';
import NoModel from './NoModel';
import SystemPrompt from './SystemPrompt';

import { getChat, putUserChatModel } from '@/apis/clientApis';
import { putUserChatModel } from '@/apis/clientApis';
import { cn } from '@/lib/utils';
import { SseResponseKind, SseResponseLine } from '@/types/chatMessage';

const Chat = memo(() => {
const { t } = useTranslation();
Expand All @@ -59,6 +61,7 @@ const Chat = memo(() => {
chatError,
messageIsStreaming,

messages,
selectMessages,
currentChatMessageId,
currentMessages,
Expand All @@ -73,10 +76,9 @@ const Chat = memo(() => {
handleStartChat,
handleChatIsError,
handleUpdateChatStatus,
handleUpdateChats,
handleStopChats,

handleUpdateSelectMessage,
handleUpdateCurrentMessage,
hasModel,
chatDispatch,
messageDispatch,
Expand All @@ -89,8 +91,6 @@ const Chat = memo(() => {

const messagesEndRef = useRef<HTMLDivElement>(null);
const chatContainerRef = useRef<HTMLDivElement>(null);
const stopConversationRef = useRef<boolean>(false);

const getSelectMessagesLast = () => {
const selectMessageLength = selectMessages.length - 1;
const lastMessage = { ...selectMessages[selectMessageLength] };
Expand All @@ -104,32 +104,37 @@ const Chat = memo(() => {
isRegenerate: boolean,
modelId?: number,
) => {
const isChatEmpty = selectMessages.length === 0;
handleUpdateChatStatus(false);
let selectChatId = selectChat?.id;
let selectMessageList = [...selectMessages];
let newMessages = [...messages];
let newSelectedMessages = [...selectMessages];
let assistantParentId = messageId;
if (!selectChatId) {
const newChat = await handleCreateNewChat();
selectChatId = newChat.id;
}
let selectedMessageId = messageId;
const MESSAGE_TEMP_ID = 'userMessageTempId';
if (messageId && isRegenerate) {
const messageIndex = selectMessageList.findIndex(
const messageIndex = newSelectedMessages.findIndex(
(x) => x.id === messageId,
);
selectMessageList.splice(messageIndex + 1, selectMessageList.length);
newSelectedMessages.splice(
messageIndex + 1,
newSelectedMessages.length,
);
} else {
const messageTempId = 'userMessageTempId';
assistantParentId = messageTempId;
selectedMessageId = messageTempId;
const parentMessage = selectMessageList.find((x) => x.id == messageId);
parentMessage && parentMessage?.childrenIds.unshift(messageTempId);
const parentMessageIndex = selectMessageList.findIndex(
assistantParentId = MESSAGE_TEMP_ID;
selectedMessageId = MESSAGE_TEMP_ID;
const parentMessage = newSelectedMessages.find(
(x) => x.id == messageId,
);
parentMessage && parentMessage?.childrenIds?.unshift(MESSAGE_TEMP_ID);
const parentMessageIndex = newSelectedMessages.findIndex(
(x) => x.id == messageId,
);
const newUserMessage = {
id: messageTempId,
id: MESSAGE_TEMP_ID,
role: 'user' as Role,
parentId: messageId,
childrenIds: [],
Expand All @@ -142,24 +147,25 @@ const Chat = memo(() => {
};
let removeCount = -1;
if (parentMessageIndex !== -1)
removeCount = selectMessageList.length - 1;
removeCount = newSelectedMessages.length - 1;
if (!messageId) {
removeCount = selectMessageList.length;
removeCount = newSelectedMessages.length;
messageDispatch(
setCurrentMessages([...currentMessages, newUserMessage]),
);
}

selectMessageList.splice(
newSelectedMessages.splice(
parentMessageIndex + 1,
removeCount,
newUserMessage,
);
newMessages.push(newUserMessage);
}

const assistantMessageTempId = 'assistantMessageTempId';
const ASSISTANT_MESSAGE_TEMP_ID = 'assistantMessageTempId';
const newAssistantMessage = {
id: assistantMessageTempId,
id: ASSISTANT_MESSAGE_TEMP_ID,
role: 'assistant' as Role,
parentId: assistantParentId,
childrenIds: [],
Expand All @@ -174,11 +180,12 @@ const Chat = memo(() => {
outputPrice: 0,
};

selectMessageList.push(newAssistantMessage);
newSelectedMessages.push(newAssistantMessage);
newMessages.push(newAssistantMessage);
handleStartChat(
selectMessageList,
newSelectedMessages,
selectedMessageId,
assistantMessageTempId,
ASSISTANT_MESSAGE_TEMP_ID,
);

const messageContent: ContentRequest = {
Expand All @@ -202,14 +209,12 @@ const Chat = memo(() => {
};
let body = JSON.stringify(chatBody);

const controller = new AbortController();
const response = await fetch(`${getApiUrl()}/api/chats`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${getUserSession()}`,
},
signal: controller.signal,
body,
});

Expand All @@ -225,22 +230,22 @@ const Chat = memo(() => {
return;
}

let errorChat = false;
let isErrorChat = false;
let text = '';
const reader = data.getReader();
const decoder = new TextDecoder();
let buffer = '';

function setSelectMessages(content: Content) {
let lastMessages = selectMessageList[selectMessageList.length - 1];
let lastMessages = newSelectedMessages[newSelectedMessages.length - 1];
lastMessages = {
...lastMessages,
content,
};

selectMessageList.splice(-1, 1, lastMessages);
newSelectedMessages.splice(-1, 1, lastMessages);

messageDispatch(setMessages([...selectMessageList]));
messageDispatch(setSelectedMessages([...newSelectedMessages]));
}
async function* processBuffer() {
while (true) {
Expand Down Expand Up @@ -271,49 +276,40 @@ const Chat = memo(() => {
const value: SseResponseLine = JSON.parse(message);

if (value.k === SseResponseKind.StopId) {
const stopId = value.r;
console.log('stopId', stopId);
}
else if (value.k === SseResponseKind.Segment) {
chatDispatch(setStopIds([value.r]));
} else if (value.k === SseResponseKind.Segment) {
text += value.r;
setSelectMessages({ text });
}
else if (value.k === SseResponseKind.Error) { // error
errorChat = true;
handleUpdateChatStatus(errorChat);
controller.abort();
} else if (value.k === SseResponseKind.Error) {
isErrorChat = true;
handleUpdateChatStatus(isErrorChat);
handleStopChats();
setSelectMessages({ text, error: value.r });
break;
}
else if (value.k === SseResponseKind.End) {
console.log('End', value.r);
}
if (stopConversationRef.current === true) {
controller.abort();
break;
}
}

if (isChatEmpty) {
setTimeout(async () => {
const data = await getChat(selectChatId);
const _chats = chats.map((x) => {
if (x.id === data.id) {
return data;
} else if (value.k === SseResponseKind.End) {
const { requestMessage, responseMessage } = value.r;
newMessages = newMessages.map((m) => {
if (requestMessage && m.id === MESSAGE_TEMP_ID) {
m = { ...m, ...requestMessage };
}
if (m.id === ASSISTANT_MESSAGE_TEMP_ID) {
m = { ...m, ...responseMessage };
}
return x;
return m;
});
userModelConfigDispatch(setUserModelConfig(data.userModelConfig));
handleUpdateChats(_chats);
}, 100);
const messageList = formatMessages(newMessages);
messageDispatch(setMessages(newMessages));
messageDispatch(setCurrentMessages(messageList));
const lastMessage = messageList[messageList.length - 1];
const selectMessageList = getSelectMessages(
messageList,
lastMessage.id,
);
messageDispatch(setSelectedMessages(selectMessageList));
messageDispatch(setLastMessageId(lastMessage.id));
}
}

chatDispatch(setMessageIsStreaming(false));
!errorChat &&
setTimeout(() => {
handleUpdateCurrentMessage(selectChatId);
}, 200);
stopConversationRef.current = false;
},
[
prompt,
Expand All @@ -325,7 +321,6 @@ const Chat = memo(() => {
currentMessages,
selectMessages,
selectModel,
stopConversationRef,
],
);

Expand Down Expand Up @@ -480,10 +475,10 @@ const Chat = memo(() => {
(x) => x === current.id,
)}
parentId={current.parentId}
childrenIds={current.childrenIds}
childrenIds={current.childrenIds!}
parentChildrenIds={parentChildrenIds}
assistantChildrenIds={current.assistantChildrenIds}
assistantCurrentSelectIndex={current.assistantChildrenIds.findIndex(
assistantChildrenIds={current.assistantChildrenIds!}
assistantCurrentSelectIndex={current.assistantChildrenIds!.findIndex(
(x) => x === current.id,
)}
lastMessageId={lastMessage.id}
Expand Down Expand Up @@ -530,7 +525,6 @@ const Chat = memo(() => {
</div>
{hasModel() && (
<ChatInput
stopConversationRef={stopConversationRef}
onSend={(message) => {
const { lastMessage } = getSelectMessagesLast();
handleSend(message, lastMessage?.id, false);
Expand Down
Loading

0 comments on commit 8a22b17

Please sign in to comment.