Skip to content

Commit

Permalink
feat: 更新 Request 支持更多自定义设置 (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
ONLY-yours authored Jul 9, 2024
1 parent 486484a commit c9ebd88
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/ProChat/mocks/streamResponse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export class MockResponse {

constructor(
private data: string,
private delay: number = 300,
private delay: number = 100,
error: boolean = false, // 新增参数,默认为false
) {
this.error = error;
Expand Down
46 changes: 39 additions & 7 deletions src/ProChat/store/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { StateCreator } from 'zustand/vanilla';

import { LOADING_FLAT } from '@/ProChat/const/message';
import { ChatStore } from '@/ProChat/store/index';
import { fetchSSE, SSEFinishType } from '@/ProChat/utils/fetch';
import { fetchSSE, MixRequestResponse, SSEFinishType } from '@/ProChat/utils/fetch';
import { isFunctionMessage } from '@/ProChat/utils/message';
import { setNamespace } from '@/ProChat/utils/storeDebug';
import { nanoid } from '@/ProChat/utils/uuid';
Expand Down Expand Up @@ -96,7 +96,7 @@ export interface ChatAction {
defaultModelFetcher: (
params: Partial<ChatStreamPayload>,
options?: FetchChatModelOptions,
) => Promise<Response>;
) => MixRequestResponse;

/**
* 生成消息 ID
Expand Down Expand Up @@ -124,6 +124,7 @@ export interface ChatAction {
stopAnimation: () => void;
outputQueue: string[];
isAnimationActive: boolean;
mixRequestResponse: MixRequestResponse;
};

/**
Expand Down Expand Up @@ -187,6 +188,7 @@ export const chatAction: StateCreator<ChatStore, [['zustand/devtools', never]],
config,
defaultModelFetcher,
createSmoothMessage,
deleteMessage,
transformToChatMessage,
onChatEnd,
onChatStart,
Expand Down Expand Up @@ -245,12 +247,18 @@ export const chatAction: StateCreator<ChatStore, [['zustand/devtools', never]],
let output = '';
let isFunctionCall = false;

const { startAnimation, stopAnimation, outputQueue, isAnimationActive } =
const { startAnimation, stopAnimation, outputQueue, isAnimationActive, mixRequestResponse } =
createSmoothMessage(assistantId);

await fetchSSE(fetcher, {
signal: abortController?.signal,
onCancel: () => {
// cancel 时候删除 Loading 态的消息
deleteMessage(assistantId);
},
onErrorHandle: (error) => {
console.log('error!');

dispatchMessage({ id: assistantId, key: 'error', type: 'updateMessage', value: error });
},
onAbort: async () => {
Expand All @@ -268,9 +276,15 @@ export const chatAction: StateCreator<ChatStore, [['zustand/devtools', never]],
await startAnimation(15);
}
},
onMessageHandle: async (text) => {
onMessageHandle: async (text, response) => {
output += text;

if (response && typeof response === 'object' && 'content' in response) {
for (const [key, value] of Object.entries(response)) {
mixRequestResponse[key] = value;
}
}

if (!isAnimationActive && !isFunctionCall) startAnimation();

if (abortController?.signal.aborted) {
Expand Down Expand Up @@ -319,6 +333,8 @@ export const chatAction: StateCreator<ChatStore, [['zustand/devtools', never]],
// 因为如果顺序反了,messages 中将包含新增的 ai message
const mid = await getMessageId(messages, userMessageId);

console.log('fetch Real');

dispatchMessage({
id: mid,
message: LOADING_FLAT,
Expand Down Expand Up @@ -462,6 +478,8 @@ export const chatAction: StateCreator<ChatStore, [['zustand/devtools', never]],
// why use queue: https://shareg.pt/GLBrjpK
let outputQueue: string[] = [];

let mixRequestResponse = {};

// eslint-disable-next-line no-undef
let animationTimeoutId: NodeJS.Timeout | null = null;
let isAnimationActive = false;
Expand All @@ -484,6 +502,10 @@ export const chatAction: StateCreator<ChatStore, [['zustand/devtools', never]],
return;
}

console.log('mixRequestResponse', mixRequestResponse);

console.log('outputQueue', outputQueue);

isAnimationActive = true;

const updateText = () => {
Expand All @@ -501,8 +523,18 @@ export const chatAction: StateCreator<ChatStore, [['zustand/devtools', never]],
const charsToAdd = outputQueue.splice(0, speed).join('');
buffer += charsToAdd;

// 更新消息内容,这里可能需要结合实际情况调整
dispatchMessage({ id, key: 'content', type: 'updateMessage', value: buffer });
if (typeof mixRequestResponse === 'object' && 'content' in mixRequestResponse) {
dispatchMessage({
...mixRequestResponse,
id,
key: 'content',
type: 'updateMessage',
value: buffer,
});
} else {
// 更新消息内容,这里可能需要结合实际情况调整
dispatchMessage({ id, key: 'content', type: 'updateMessage', value: buffer });
}

// 设置下一个字符的延迟
animationTimeoutId = setTimeout(updateText, 16); // 16 毫秒的延迟模拟打字机效果
Expand All @@ -517,7 +549,7 @@ export const chatAction: StateCreator<ChatStore, [['zustand/devtools', never]],
updateText();
});

return { startAnimation, stopAnimation, outputQueue, isAnimationActive };
return { startAnimation, stopAnimation, outputQueue, isAnimationActive, mixRequestResponse };
},

getChatLoadingId: () => {
Expand Down
4 changes: 2 additions & 2 deletions src/ProChat/store/initialState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import { ReactNode } from 'react';
import { FlexBasicProps } from 'react-layout-kit/lib/FlexBasic';
import { Locale } from '../../locale';
import { ProChatChatReference } from '../container/StoreUpdater';
import { SSEFinishType } from '../utils/fetch';
import { MixRequestResponse, SSEFinishType } from '../utils/fetch';

export type ChatRequest = (
messages: ChatMessage[],
config: ModelConfig,
signal: AbortSignal | undefined,
) => Promise<Response>;
) => MixRequestResponse;

export interface ChatPropsState<T extends Record<string, any> = Record<string, any>> {
/**
Expand Down
10 changes: 9 additions & 1 deletion src/ProChat/store/reducers/message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,18 @@ export const messagesReducer = (

case 'updateMessage': {
return produce(state, (draftState) => {
const { id, key, value } = payload;
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { id, key, value, type: _, ...rest } = payload;
const message = draftState.find((m) => m.id === id);
if (!message) return;

// 遍历 rest 对象并更新 message 对象
for (const [restKey, restValue] of Object.entries(rest)) {
console.log('restKey', restKey, restValue);

message[restKey] = restValue;
}

// @ts-ignore
message[key] = value;
message.updateAt = Date.now();
Expand Down
42 changes: 32 additions & 10 deletions src/ProChat/utils/fetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ export const getMessageError = async (response: Response) => {

export type SSEFinishType = 'done' | 'error' | 'abort';

export type MixRequestResponse = Response | { content?: Response; [key: string]: any } | string;

export interface FetchSSEOptions {
onErrorHandle?: (error: ChatMessageError) => void;
onMessageHandle?: (text: string, response: Response) => void;
onMessageHandle?: (text: string, response: MixRequestResponse) => void;
onAbort?: (text: string) => Promise<void>;
onFinish?: (type: SSEFinishType) => Promise<void>;
onCancel?: () => void;
signal?: AbortSignal;
}

Expand All @@ -26,21 +29,41 @@ export interface FetchSSEOptions {
* @param fetchFn
* @param options
*/
export const fetchSSE = async (fetchFn: () => Promise<Response>, options: FetchSSEOptions = {}) => {
export const fetchSSE = async (
fetchFn: () => MixRequestResponse,
options: FetchSSEOptions = {},
) => {
const response = await fetchFn();

if (!response) {
options.onCancel?.();
return;
}

let returnRes = null;

let realResponse = null;

if (typeof response === 'object' && 'content' in response) {
returnRes = response?.content.clone();
realResponse = response?.content;
} else if (typeof response === 'string') {
returnRes = new Response(response);
realResponse = returnRes;
} else {
returnRes = response?.clone();
realResponse = response;
}

// 如果不 ok 说明有请求错误
if (!response.ok) {
if (!realResponse.ok) {
// TODO: need a message error custom parser
const chatMessageError = await getMessageError(response);

const chatMessageError = await getMessageError(realResponse);
options.onErrorHandle?.(chatMessageError);
return;
}

const returnRes = response.clone();

const data = response.body;
const data = realResponse.body;

if (!data) return;

Expand All @@ -60,8 +83,7 @@ export const fetchSSE = async (fetchFn: () => Promise<Response>, options: FetchS
done = doneReading;
if (value) {
const chunkValue = decoder.decode(value, { stream: !doneReading });
options.onMessageHandle?.(chunkValue, returnRes);
console.log('reader', chunkValue);
options.onMessageHandle?.(chunkValue, response);
}

if (done) {
Expand Down

0 comments on commit c9ebd88

Please sign in to comment.