Skip to content

Commit

Permalink
[Obs AI Assistant] Fixes issue w/ duplicated recalls (elastic#169927)
Browse files Browse the repository at this point in the history
Co-authored-by: Kibana Machine <[email protected]>
  • Loading branch information
dgieselaar and kibanamachine authored Oct 27, 2023
1 parent 54dd98d commit 6e21d81
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,22 @@ export async function fetchSeries<T extends ValueAggregationMap>({
}

return response.aggregations.groupBy.buckets.map((bucket) => {
let value =
bucket.value?.value === undefined || bucket.value?.value === null
? null
: Number(bucket.value.value);

if (value !== null) {
value =
Math.abs(value) < 100
? Number(value.toPrecision(3))
: Math.round(value);
}

return {
groupBy: bucket.key_as_string || String(bucket.key),
data: bucket.timeseries.buckets,
value:
bucket.value?.value === undefined || bucket.value?.value === null
? null
: Math.round(bucket.value.value),
value,
change_point: bucket.change_point,
unit,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/
import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { last } from 'lodash';
import { first } from 'lodash';
import { EuiFlexGroup, EuiFlexItem } from '@elastic/eui';
import { AbortError } from '@kbn/kibana-utils-plugin/common';
import { isObservable, Subscription } from 'rxjs';
Expand Down Expand Up @@ -42,41 +42,41 @@ function ChatContent({

const [pendingMessage, setPendingMessage] = useState<PendingMessage | undefined>();

const [recalledMessages, setRecalledMessages] = useState<Message[] | undefined>(undefined);

const [loading, setLoading] = useState(false);
const [subscription, setSubscription] = useState<Subscription | undefined>();

const [conversationId, setConversationId] = useState<string>();

const { conversation, displayedMessages, setDisplayedMessages, save, saveTitle } =
useConversation({
conversationId,
connectorId,
chatService,
});
const {
conversation,
displayedMessages,
setDisplayedMessages,
getSystemMessage,
save,
saveTitle,
} = useConversation({
conversationId,
connectorId,
chatService,
initialMessages,
});

const conversationTitle = conversationId
? conversation.value?.conversation.title || ''
: defaultTitle;

const controllerRef = useRef(new AbortController());

const reloadRecalledMessages = useCallback(async () => {
setLoading(true);

setDisplayedMessages(initialMessages);

setRecalledMessages(undefined);
const reloadRecalledMessages = useCallback(
async (messages: Message[]) => {
controllerRef.current.abort();

controllerRef.current.abort();
const controller = (controllerRef.current = new AbortController());

const controller = (controllerRef.current = new AbortController());
const isStartOfConversation =
messages.some((message) => message.message.role === MessageRole.Assistant) === false;

let appendedMessages: Message[] = [];

if (chatService.hasFunction('recall')) {
try {
if (isStartOfConversation && chatService.hasFunction('recall')) {
// manually execute recall function and append to list of
// messages
const functionCall = {
Expand All @@ -86,7 +86,7 @@ function ChatContent({

const response = await chatService.executeFunction({
...functionCall,
messages: initialMessages,
messages,
signal: controller.signal,
connectorId,
});
Expand All @@ -95,7 +95,7 @@ function ChatContent({
throw new Error('Recall function unexpectedly returned an Observable');
}

appendedMessages = [
return [
{
'@timestamp': new Date().toISOString(),
message: {
Expand All @@ -117,43 +117,60 @@ function ChatContent({
},
},
];

setRecalledMessages(appendedMessages);
} catch (err) {
// eslint-disable-next-line no-console
console.error(err);
setRecalledMessages([]);
}
}
}, [chatService, connectorId, initialMessages, setDisplayedMessages]);

useEffect(() => {
let lastPendingMessage: PendingMessage | undefined;
return [];
},
[chatService, connectorId]
);

const reloadConversation = useCallback(async () => {
setLoading(true);

setDisplayedMessages(initialMessages);
setPendingMessage(undefined);

const messages = [getSystemMessage(), ...initialMessages];

if (recalledMessages === undefined) {
// don't do anything, it's loading
return;
}
const recalledMessages = await reloadRecalledMessages(messages);
const next = messages.concat(recalledMessages);

setDisplayedMessages(next);

let lastPendingMessage: PendingMessage | undefined;

const nextSubscription = chatService
.chat({ messages: displayedMessages.concat(recalledMessages), connectorId, function: 'none' })
.chat({ messages: next, connectorId, function: 'none' })
.subscribe({
next: (msg) => {
lastPendingMessage = msg;
setPendingMessage(() => msg);
},
complete: () => {
setDisplayedMessages((prev) =>
prev.concat({
'@timestamp': new Date().toISOString(),
...lastPendingMessage!,
})
);
setPendingMessage(lastPendingMessage);
setLoading(false);
},
});

setSubscription(nextSubscription);
}, [chatService, connectorId, displayedMessages, setDisplayedMessages, recalledMessages]);
}, [
reloadRecalledMessages,
chatService,
connectorId,
initialMessages,
getSystemMessage,
setDisplayedMessages,
]);

useEffect(() => {
reloadRecalledMessages();
}, [reloadRecalledMessages]);
reloadConversation();
}, [reloadConversation]);

useEffect(() => {
setDisplayedMessages(initialMessages);
Expand All @@ -163,25 +180,30 @@ function ChatContent({

const messagesWithPending = useMemo(() => {
return pendingMessage
? displayedMessages.concat(recalledMessages || []).concat({
? displayedMessages.concat({
'@timestamp': new Date().toISOString(),
message: {
...pendingMessage.message,
},
})
: displayedMessages.concat(recalledMessages || []);
}, [pendingMessage, displayedMessages, recalledMessages]);

const lastAssistantMessage = last(
messagesWithPending.filter((message) => message.message.role === MessageRole.Assistant)
: displayedMessages;
}, [pendingMessage, displayedMessages]);

const firstAssistantMessage = first(
messagesWithPending.filter(
(message) =>
message.message.role === MessageRole.Assistant &&
(!message.message.function_call?.trigger ||
message.message.function_call.trigger === MessageRole.Assistant)
)
);

return (
<>
<MessagePanel
body={
<MessageText
content={lastAssistantMessage?.message.content ?? ''}
content={firstAssistantMessage?.message.content ?? ''}
loading={loading}
onActionClick={async () => {}}
/>
Expand Down Expand Up @@ -216,7 +238,7 @@ function ChatContent({
<EuiFlexItem grow={false}>
<RegenerateResponseButton
onClick={() => {
reloadRecalledMessages();
reloadConversation();
}}
/>
</EuiFlexItem>
Expand All @@ -237,7 +259,7 @@ function ChatContent({
onClose={() => {
setIsOpen(() => false);
}}
messages={messagesWithPending}
messages={displayedMessages}
conversationId={conversationId}
startedFrom="contextualInsight"
onChatComplete={(nextMessages) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/
import { i18n } from '@kbn/i18n';
import { merge, omit } from 'lodash';
import { Dispatch, SetStateAction, useMemo, useState } from 'react';
import { Dispatch, SetStateAction, useCallback, useMemo, useState } from 'react';
import { type Conversation, type Message } from '../../common';
import { ConversationCreateRequest, MessageRole } from '../../common/types';
import { getAssistantSetupMessage } from '../service/get_assistant_setup_message';
Expand All @@ -20,14 +20,17 @@ export function useConversation({
conversationId,
chatService,
connectorId,
initialMessages = [],
}: {
conversationId?: string;
chatService?: ObservabilityAIAssistantChatService; // will eventually resolve to a non-nullish value
connectorId: string | undefined;
initialMessages?: Message[];
}): {
conversation: AbortableAsyncState<ConversationCreateRequest | Conversation | undefined>;
displayedMessages: Message[];
setDisplayedMessages: Dispatch<SetStateAction<Message[]>>;
getSystemMessage: () => Message;
save: (messages: Message[], handleRefreshConversations?: () => void) => Promise<Conversation>;
saveTitle: (
title: string,
Expand All @@ -40,20 +43,25 @@ export function useConversation({
services: { notifications },
} = useKibana();

const [displayedMessages, setDisplayedMessages] = useState<Message[]>([]);
const [displayedMessages, setDisplayedMessages] = useState<Message[]>(initialMessages);

const getSystemMessage = useCallback(() => {
return getAssistantSetupMessage({ contexts: chatService?.getContexts() || [] });
}, [chatService]);

const displayedMessagesWithHardcodedSystemMessage = useMemo(() => {
if (!chatService) {
return displayedMessages;
}
const systemMessage = getAssistantSetupMessage({ contexts: chatService?.getContexts() || [] });

const systemMessage = getSystemMessage();

if (displayedMessages[0]?.message.role === MessageRole.User) {
return [systemMessage, ...displayedMessages];
}

return [systemMessage, ...displayedMessages.slice(1)];
}, [displayedMessages, chatService]);
}, [displayedMessages, chatService, getSystemMessage]);

const conversation: AbortableAsyncState<ConversationCreateRequest | Conversation | undefined> =
useAbortableAsync(
Expand Down Expand Up @@ -87,6 +95,7 @@ export function useConversation({
conversation,
displayedMessages: displayedMessagesWithHardcodedSystemMessage,
setDisplayedMessages,
getSystemMessage,
save: (messages: Message[], handleRefreshConversations?: () => void) => {
const conversationObject = conversation.value!;

Expand All @@ -106,7 +115,13 @@ export function useConversation({
id: conversationId,
},
},
omit(conversationObject, 'conversation.last_updated', 'namespace', 'user'),
omit(
conversationObject,
'conversation.last_updated',
'namespace',
'user',
'messages'
),
{ messages }
),
},
Expand Down
Loading

0 comments on commit 6e21d81

Please sign in to comment.