Commit ef7fa34

Showing 5 changed files with 105 additions and 91 deletions.
30 changes: 19 additions & 11 deletions gui/src/context/IdeMessenger.ts
@@ -37,14 +37,14 @@ export interface IIdeMessenger {
     messageType: T,
     data: FromWebviewProtocol[T][0],
     cancelToken?: AbortSignal,
-  ): FromWebviewProtocol[T][1];
+  ): AsyncGenerator<unknown[]>;

   llmStreamChat(
     modelTitle: string,
     cancelToken: AbortSignal | undefined,
     messages: ChatMessage[],
     options?: LLMFullCompletionOptions,
-  ): AsyncGenerator<ChatMessage, PromptLog, unknown>;
+  ): AsyncGenerator<ChatMessage[], PromptLog, unknown>;

   ide: IDE;
 }
@@ -147,11 +147,20 @@ export class IdeMessenger implements IIdeMessenger {
     });
   }

+  /**
+   * Because of weird type stuff, we're actually yielding an array of the things
+   * that are streamed. For example, if the return type here says
+   * AsyncGenerator<ChatMessage>, then it's actually AsyncGenerator<ChatMessage[]>.
+   * This needs to be handled by the caller.
+   *
+   * Using unknown for now to make this more explicit
+   */
   async *streamRequest<T extends keyof FromWebviewProtocol>(
     messageType: T,
     data: FromWebviewProtocol[T][0],
     cancelToken?: AbortSignal,
-  ): FromWebviewProtocol[T][1] {
+  ): AsyncGenerator<unknown[]> {
+  // ): FromWebviewProtocol[T][1] {
     const messageId = uuidv4();

     this.post(messageType, data, messageId);
@@ -181,17 +190,16 @@ export class IdeMessenger implements IIdeMessenger {

     while (!done) {
       if (buffer.length > index) {
-        const chunk = buffer[index];
-        index++;
-        yield chunk;
+        const chunks = buffer.slice(index);
+        index = buffer.length;
+        yield chunks;
       }
       await new Promise((resolve) => setTimeout(resolve, 50));
     }

-    while (buffer.length > index) {
-      const chunk = buffer[index];
-      index++;
-      yield chunk;
+    if (buffer.length > index) {
+      const chunks = buffer.slice(index);
+      yield chunks;
     }

     return returnVal;
@@ -202,7 +210,7 @@ export class IdeMessenger implements IIdeMessenger {
     cancelToken: AbortSignal | undefined,
     messages: ChatMessage[],
     options: LLMFullCompletionOptions = {},
-  ): AsyncGenerator<ChatMessage, PromptLog> {
+  ): AsyncGenerator<ChatMessage[], PromptLog> {
     const gen = this.streamRequest(
       "llm/streamChat",
       {
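
The heart of the commit is the drain loop in streamRequest above: instead of yielding one buffered chunk per iteration, it yields everything that has accumulated since the last 50ms poll as a single array. A minimal sketch of that pattern in isolation (hypothetical names, not code from this commit):

async function* drainInBatches<T>(
  buffer: T[],
  isDone: () => boolean,
): AsyncGenerator<T[]> {
  let index = 0;
  while (!isDone()) {
    if (buffer.length > index) {
      const batch = buffer.slice(index); // everything received since the last yield
      index = buffer.length;
      yield batch;
    }
    // Poll roughly every 50ms: slightly more latency, far fewer yields to handle
    await new Promise((resolve) => setTimeout(resolve, 50));
  }
  if (buffer.length > index) {
    yield buffer.slice(index); // flush chunks that arrived after the final poll
  }
}

This also explains the while-to-if change in the flush step of the diff: buffer.slice(index) consumes every remaining element in one call, so a single pass replaces the old one-chunk-per-iteration loop.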
134 changes: 68 additions & 66 deletions gui/src/redux/slices/sessionSlice.ts
@@ -10,21 +10,21 @@ import {
   ApplyState,
   ChatHistoryItem,
   ChatMessage,
+  CodeToEdit,
+  ContextItem,
   ContextItemWithId,
   FileSymbolMap,
-  Session,
+  MessageModes,
   PromptLog,
-  CodeToEdit,
+  Session,
   ToolCall,
-  ContextItem,
-  MessageModes,
 } from "core";
 import { incrementalParseJson } from "core/util/incrementalParseJson";
 import { renderChatMessage } from "core/util/messageContent";
 import { v4 as uuidv4 } from "uuid";
+import { RootState } from "../store";
 import { streamResponseThunk } from "../thunks/streamResponse";
 import { findCurrentToolCall } from "../util";
-import { RootState } from "../store";

 // We need this to handle reorderings (e.g. a mid-array deletion) of the messages array.
 // The proper fix is adding a UUID to all chat messages, but this is the temp workaround.
@@ -284,72 +284,74 @@ export const sessionSlice = createSlice({
       state.streamAborter.abort();
       state.streamAborter = new AbortController();
     },
-    streamUpdate: (state, action: PayloadAction<ChatMessage>) => {
+    streamUpdate: (state, action: PayloadAction<ChatMessage[]>) => {
       if (state.history.length) {
-        const lastMessage = state.history[state.history.length - 1];
-
-        if (
-          action.payload.role &&
-          (lastMessage.message.role !== action.payload.role ||
-            // This is when a tool call comes after assistant text
-            (lastMessage.message.content !== "" &&
-              action.payload.role === "assistant" &&
-              action.payload.toolCalls?.length))
-        ) {
-          const baseHistoryItem = getBaseHistoryItem();
-
-          // Create a new message
-          const historyItem: ChatHistoryItemWithMessageId = {
-            ...baseHistoryItem,
-            message: { ...baseHistoryItem.message, ...action.payload },
-          };
-
-          if (action.payload.role === "assistant" && action.payload.toolCalls) {
-            const [_, parsedArgs] = incrementalParseJson(
-              action.payload.toolCalls[0].function.arguments,
-            );
-            historyItem.toolCallState = {
-              status: "generating",
-              toolCall: action.payload.toolCalls[0] as ToolCall,
-              toolCallId: action.payload.toolCalls[0].id,
-              parsedArgs,
+        for (const message of action.payload) {
+          const lastMessage = state.history[state.history.length - 1];
+
+          if (
+            message.role &&
+            (lastMessage.message.role !== message.role ||
+              // This is when a tool call comes after assistant text
+              (lastMessage.message.content !== "" &&
+                message.role === "assistant" &&
+                message.toolCalls?.length))
+          ) {
+            const baseHistoryItem = getBaseHistoryItem();
+
+            // Create a new message
+            const historyItem: ChatHistoryItemWithMessageId = {
+              ...baseHistoryItem,
+              message: { ...baseHistoryItem.message, ...message },
             };
-          }

-          state.history.push(historyItem);
-        } else {
-          // Add to the existing message
-          const msg = state.history[state.history.length - 1].message;
-          if (action.payload.content) {
-            msg.content += renderChatMessage(action.payload);
-          } else if (
-            action.payload.role === "assistant" &&
-            action.payload.toolCalls &&
-            msg.role === "assistant"
-          ) {
-            if (!msg.toolCalls) {
-              msg.toolCalls = [];
+            if (message.role === "assistant" && message.toolCalls) {
+              const [_, parsedArgs] = incrementalParseJson(
+                message.toolCalls[0].function.arguments,
+              );
+              historyItem.toolCallState = {
+                status: "generating",
+                toolCall: message.toolCalls[0] as ToolCall,
+                toolCallId: message.toolCalls[0].id,
+                parsedArgs,
+              };
             }
-            action.payload.toolCalls.forEach((toolCall, i) => {
-              if (msg.toolCalls.length <= i) {
-                msg.toolCalls.push(toolCall);
-              } else {
-                msg.toolCalls[i].function.arguments +=
-                  toolCall.function.arguments;
-
-                const [_, parsedArgs] = incrementalParseJson(
-                  msg.toolCalls[i].function.arguments,
-                );
-
-                state.history[
-                  state.history.length - 1
-                ].toolCallState.parsedArgs = parsedArgs;
-                state.history[
-                  state.history.length - 1
-                ].toolCallState.toolCall.function.arguments +=
-                  toolCall.function.arguments;
+
+            state.history.push(historyItem);
+          } else {
+            // Add to the existing message
+            const msg = state.history[state.history.length - 1].message;
+            if (message.content) {
+              msg.content += renderChatMessage(message);
+            } else if (
+              message.role === "assistant" &&
+              message.toolCalls &&
+              msg.role === "assistant"
+            ) {
+              if (!msg.toolCalls) {
+                msg.toolCalls = [];
               }
-            });
+              message.toolCalls.forEach((toolCall, i) => {
+                if (msg.toolCalls.length <= i) {
+                  msg.toolCalls.push(toolCall);
+                } else {
+                  msg.toolCalls[i].function.arguments +=
+                    toolCall.function.arguments;
+
+                  const [_, parsedArgs] = incrementalParseJson(
+                    msg.toolCalls[i].function.arguments,
+                  );
+
+                  state.history[
+                    state.history.length - 1
+                  ].toolCallState.parsedArgs = parsedArgs;
+                  state.history[
+                    state.history.length - 1
+                  ].toolCallState.toolCall.function.arguments +=
+                    toolCall.function.arguments;
+                }
+              });
+            }
           }
         }
       }
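
Moving the iteration into the reducer is what makes the batching pay off: every dispatch is a full Redux round trip (reducer run, state commit, subscriber notification, re-render), so carrying a whole batch in one streamUpdate action collapses N round trips per 50ms window into one. A rough sketch of the caller's side under the new contract (the store import is an assumption, not code from this commit):

import { ChatMessage } from "core";
import { store } from "../store"; // assumed: the app's configured Redux store
import { streamUpdate } from "./sessionSlice";

// Everything streamed within one ~50ms window, oldest first
const chunks: ChatMessage[] = [
  { role: "assistant", content: "Hel" },
  { role: "assistant", content: "lo" },
];

// Old signature: one round trip per chunk
// chunks.forEach((chunk) => store.dispatch(streamUpdate(chunk)));

// New signature: one round trip per batch; the reducer's inner
// for (const message of action.payload) replays the chunks in order,
// so the last history item still ends up containing "Hello"
store.dispatch(streamUpdate(chunks));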
11 changes: 3 additions & 8 deletions gui/src/redux/thunks/streamNormalInput.ts
@@ -1,6 +1,7 @@
 import { createAsyncThunk } from "@reduxjs/toolkit";
 import { ChatMessage, PromptLog } from "core";
 import { selectCurrentToolCall } from "../selectors/selectCurrentToolCall";
+import { selectDefaultModel } from "../slices/configSlice";
 import {
   abortStream,
   addPromptCompletionPair,
@@ -10,7 +11,6 @@ import {
 } from "../slices/sessionSlice";
 import { ThunkApiType } from "../store";
 import { callTool } from "./callTool";
-import { selectDefaultModel } from "../slices/configSlice";

 export const streamNormalInput = createAsyncThunk<
   void,
@@ -56,14 +56,9 @@ export const streamNormalInput = createAsyncThunk<
         break;
       }

-      const update = next.value as ChatMessage;
-      dispatch(streamUpdate(update));
+      const updates = next.value as ChatMessage[];
+      dispatch(streamUpdate(updates));
       next = await gen.next();
-
-      // There has been lag when streaming tool calls. This is a temporary solution
-      if (update.role === "assistant" && update.toolCalls) {
-        await new Promise((resolve) => setTimeout(resolve, 10));
-      }
     }

     // Attach prompt log
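
A detail worth noting in the loop above: the PromptLog arrives as the generator's return value, not as a yielded item, which is why the thunk steps the generator manually with gen.next() rather than using for await...of (for-await discards a generator's return value). The shape in isolation (assumed typings, not code from this commit):

// gen: AsyncGenerator<ChatMessage[], PromptLog>
let next = await gen.next();
while (!next.done) {
  // next.value is ChatMessage[] here — one batch per ~50ms poll
  dispatch(streamUpdate(next.value));
  next = await gen.next();
}
// Once next.done is true, next.value is the return value instead
const promptLog: PromptLog = next.value;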
4 changes: 2 additions & 2 deletions gui/src/redux/thunks/streamResponseAfterToolCall.ts
@@ -2,6 +2,7 @@ import { createAsyncThunk } from "@reduxjs/toolkit";
 import { ChatMessage, ContextItem } from "core";
 import { constructMessages } from "core/llm/constructMessages";
 import { renderContextItems } from "core/util/messageContent";
+import { selectDefaultModel } from "../slices/configSlice";
 import {
   addContextItemsAtIndex,
   setActive,
@@ -11,7 +12,6 @@ import {
 import { handleErrors } from "./handleErrors";
 import { resetStateForNewMessage } from "./resetStateForNewMessage";
 import { streamNormalInput } from "./streamNormalInput";
-import { selectDefaultModel } from "../slices/configSlice";

 export const streamResponseAfterToolCall = createAsyncThunk<
   void,
@@ -39,7 +39,7 @@ export const streamResponseAfterToolCall = createAsyncThunk<
       toolCallId,
     };

-    dispatch(streamUpdate(newMessage));
+    dispatch(streamUpdate([newMessage]));
     dispatch(
       addContextItemsAtIndex({
         index: initialHistory.length,
17 changes: 13 additions & 4 deletions gui/src/redux/thunks/streamSlashCommand.ts
@@ -5,9 +5,9 @@ import {
   RangeInFile,
   SlashCommandDescription,
 } from "core";
-import { ThunkApiType } from "../store";
-import { abortStream, streamUpdate } from "../slices/sessionSlice";
 import { selectDefaultModel } from "../slices/configSlice";
+import { abortStream, streamUpdate } from "../slices/sessionSlice";
+import { ThunkApiType } from "../store";

 export const streamSlashCommand = createAsyncThunk<
   void,
@@ -62,8 +62,17 @@ export const streamSlashCommand = createAsyncThunk<
       dispatch(abortStream());
       break;
     }
-    if (typeof update === "string") {
-      dispatch(streamUpdate(update));
+    for (const item of update) {
+      if (typeof item === "string") {
+        dispatch(
+          streamUpdate([
+            {
+              role: "assistant",
+              content: item,
+            },
+          ]),
+        );
+      }
     }
   }
   clearInterval(checkActiveInterval);
