diff --git a/components/use-x-chat/__tests__/index.test.tsx b/components/use-x-chat/__tests__/index.test.tsx index 444a033a..e4369c4f 100644 --- a/components/use-x-chat/__tests__/index.test.tsx +++ b/components/use-x-chat/__tests__/index.test.tsx @@ -211,7 +211,7 @@ describe('useXChat', () => { fireEvent.change(container.querySelector('input')!, { target: { value: 'little' } }); expect(getMessages(container)).toEqual([ expectMessage('little', 'local'), - expectMessage('bamboo', 'loading'), + expectMessage('bamboo', 'updating'), ]); await waitFakeTimer(); @@ -230,6 +230,8 @@ describe('useXChat', () => { expect(getMessages(container)).toEqual([ expectMessage('0_light', 'local'), expectMessage('1_light', 'local'), + expectMessage('0_', 'loading'), + expectMessage('1_', 'loading'), ]); }); @@ -246,7 +248,7 @@ describe('useXChat', () => { fireEvent.change(container.querySelector('input')!, { target: { value: 'little' } }); expect(getMessages(container)).toEqual([ expectMessage('little', 'local'), - expectMessage('bamboo', 'loading'), + expectMessage('bamboo', 'updating'), ]); await waitFakeTimer(); @@ -284,7 +286,10 @@ describe('useXChat', () => { expect.any(TransformStream), ); - expect(getMessages(container)).toEqual([expectMessage('little', 'local')]); + expect(getMessages(container)).toEqual([ + expectMessage('little', 'local'), + expectMessage('', 'loading'), + ]); }); it('custom require called resolveAbortController', (done) => { const transformStream = new TransformStream(); @@ -319,7 +324,10 @@ describe('useXChat', () => { />, ); fireEvent.change(container.querySelector('input')!, { target: { value: 'little' } }); - expect(getMessages(container)).toEqual([expectMessage('little', 'local')]); + expect(getMessages(container)).toEqual([ + expectMessage('little', 'local'), + expectMessage('', 'loading'), + ]); }); describe('transformMessage', () => { @@ -350,7 +358,7 @@ describe('useXChat', () => { expect(getMessages(container)).toEqual([ expectMessage('little', 'local'), - expectMessage('bamboo', 'loading'), + expectMessage('bamboo', 'updating'), ]); await waitFakeTimer(); expect(getMessages(container)).toEqual([ @@ -374,7 +382,7 @@ describe('useXChat', () => { expect(getMessages(container)).toEqual([ expectMessage('little', 'local'), - expectMessage('bamboo', 'loading'), + expectMessage('bamboo', 'updating'), ]); await waitFakeTimer(); expect(getMessages(container)).toEqual([ @@ -395,7 +403,10 @@ describe('useXChat', () => { ); fireEvent.change(container.querySelector('input')!, { target: { value: 'little' } }); - expect(getMessages(container)).toEqual([expectMessage('little', 'local')]); + expect(getMessages(container)).toEqual([ + expectMessage('little', 'local'), + expectMessage('', 'loading'), + ]); await waitFakeTimer(); expect(getMessages(container)).toEqual([ expectMessage('little', 'local'), diff --git a/components/use-x-chat/demo/stream.tsx b/components/use-x-chat/demo/stream.tsx index 8474eb31..e8f8221a 100644 --- a/components/use-x-chat/demo/stream.tsx +++ b/components/use-x-chat/demo/stream.tsx @@ -46,6 +46,7 @@ const App = () => { style={{ maxHeight: 300 }} items={messages.map(({ id, message, status }) => ({ key: id, + loading: status === 'loading', role: status === 'local' ? 'local' : 'ai', content: message, }))} diff --git a/components/use-x-chat/index.ts b/components/use-x-chat/index.ts index 2c59cb8f..b40d424b 100644 --- a/components/use-x-chat/index.ts +++ b/components/use-x-chat/index.ts @@ -8,7 +8,7 @@ import useSyncState from './useSyncState'; export type SimpleType = string | number | boolean | object; -export type MessageStatus = 'local' | 'loading' | 'success' | 'error'; +export type MessageStatus = 'local' | 'loading' | 'updating' | 'success' | 'error'; type RequestPlaceholderFn = ( message: Message, @@ -178,11 +178,7 @@ export default function useXChat< let message: AgentMessage; let otherRequestParams = {}; - if ( - requestParams && - typeof requestParams === 'object' && - 'message' in requestParams - ) { + if (requestParams && typeof requestParams === 'object' && 'message' in requestParams) { const { message: requestParamsMessage, ...other } = requestParams as RequestParams; message = requestParamsMessage; @@ -193,8 +189,8 @@ export default function useXChat< // Add placeholder message setMessages((ori) => { let nextMessages = [...ori, createMessage(message, 'local')]; + let placeholderMsg = '' as AgentMessage; if (requestPlaceholder) { - let placeholderMsg: AgentMessage; if (typeof requestPlaceholder === 'function') { // typescript has bug that not get real return type when use `typeof function` check placeholderMsg = (requestPlaceholder as RequestPlaceholderFn)(message, { @@ -203,11 +199,11 @@ export default function useXChat< } else { placeholderMsg = requestPlaceholder; } - const loadingMsg = createMessage(placeholderMsg, 'loading'); - loadingMsgId = loadingMsg.id; - - nextMessages = [...nextMessages, loadingMsg]; } + const loadingMsg = createMessage(placeholderMsg, 'loading'); + loadingMsgId = loadingMsg.id; + + nextMessages = [...nextMessages, loadingMsg]; return nextMessages; }); @@ -257,7 +253,7 @@ export default function useXChat< } as Input, { onUpdate: (chunk) => { - updateMessage('loading', chunk, []); + updateMessage('updating', chunk, []); }, onSuccess: (chunks) => { updateMessage('success', undefined as Output, chunks);