From 6e21d81091506b7d3bcfc4bfae640d7160a5e960 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Fri, 27 Oct 2023 13:34:48 +0200 Subject: [PATCH] [Obs AI Assistant] Fixes issue w/ duplicated recalls (#169927) Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com> --- .../get_apm_timeseries/fetch_timeseries.ts | 17 ++- .../public/components/insight/insight.tsx | 122 +++++++++++------- .../public/hooks/use_conversation.ts | 25 +++- .../public/hooks/use_timeline.test.ts | 91 +++++++++---- .../public/hooks/use_timeline.ts | 11 +- .../server/service/client/index.ts | 8 +- .../server/service/kb_service/index.ts | 2 +- 7 files changed, 184 insertions(+), 92 deletions(-) diff --git a/x-pack/plugins/apm/server/routes/assistant_functions/get_apm_timeseries/fetch_timeseries.ts b/x-pack/plugins/apm/server/routes/assistant_functions/get_apm_timeseries/fetch_timeseries.ts index cbb9a54d31354..423f4908ea902 100644 --- a/x-pack/plugins/apm/server/routes/assistant_functions/get_apm_timeseries/fetch_timeseries.ts +++ b/x-pack/plugins/apm/server/routes/assistant_functions/get_apm_timeseries/fetch_timeseries.ts @@ -122,13 +122,22 @@ export async function fetchSeries({ } return response.aggregations.groupBy.buckets.map((bucket) => { + let value = + bucket.value?.value === undefined || bucket.value?.value === null + ? null + : Number(bucket.value.value); + + if (value !== null) { + value = + Math.abs(value) < 100 + ? Number(value.toPrecision(3)) + : Math.round(value); + } + return { groupBy: bucket.key_as_string || String(bucket.key), data: bucket.timeseries.buckets, - value: - bucket.value?.value === undefined || bucket.value?.value === null - ? null - : Math.round(bucket.value.value), + value, change_point: bucket.change_point, unit, }; diff --git a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx index a48d53c942055..8c6463ca9318c 100644 --- a/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx +++ b/x-pack/plugins/observability_ai_assistant/public/components/insight/insight.tsx @@ -5,7 +5,7 @@ * 2.0. */ import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import { last } from 'lodash'; +import { first } from 'lodash'; import { EuiFlexGroup, EuiFlexItem } from '@elastic/eui'; import { AbortError } from '@kbn/kibana-utils-plugin/common'; import { isObservable, Subscription } from 'rxjs'; @@ -42,19 +42,24 @@ function ChatContent({ const [pendingMessage, setPendingMessage] = useState(); - const [recalledMessages, setRecalledMessages] = useState(undefined); - const [loading, setLoading] = useState(false); const [subscription, setSubscription] = useState(); const [conversationId, setConversationId] = useState(); - const { conversation, displayedMessages, setDisplayedMessages, save, saveTitle } = - useConversation({ - conversationId, - connectorId, - chatService, - }); + const { + conversation, + displayedMessages, + setDisplayedMessages, + getSystemMessage, + save, + saveTitle, + } = useConversation({ + conversationId, + connectorId, + chatService, + initialMessages, + }); const conversationTitle = conversationId ? conversation.value?.conversation.title || '' @@ -62,21 +67,16 @@ function ChatContent({ const controllerRef = useRef(new AbortController()); - const reloadRecalledMessages = useCallback(async () => { - setLoading(true); - - setDisplayedMessages(initialMessages); - - setRecalledMessages(undefined); + const reloadRecalledMessages = useCallback( + async (messages: Message[]) => { + controllerRef.current.abort(); - controllerRef.current.abort(); + const controller = (controllerRef.current = new AbortController()); - const controller = (controllerRef.current = new AbortController()); + const isStartOfConversation = + messages.some((message) => message.message.role === MessageRole.Assistant) === false; - let appendedMessages: Message[] = []; - - if (chatService.hasFunction('recall')) { - try { + if (isStartOfConversation && chatService.hasFunction('recall')) { // manually execute recall function and append to list of // messages const functionCall = { @@ -86,7 +86,7 @@ function ChatContent({ const response = await chatService.executeFunction({ ...functionCall, - messages: initialMessages, + messages, signal: controller.signal, connectorId, }); @@ -95,7 +95,7 @@ function ChatContent({ throw new Error('Recall function unexpectedly returned an Observable'); } - appendedMessages = [ + return [ { '@timestamp': new Date().toISOString(), message: { @@ -117,43 +117,60 @@ function ChatContent({ }, }, ]; - - setRecalledMessages(appendedMessages); - } catch (err) { - // eslint-disable-next-line no-console - console.error(err); - setRecalledMessages([]); } - } - }, [chatService, connectorId, initialMessages, setDisplayedMessages]); - useEffect(() => { - let lastPendingMessage: PendingMessage | undefined; + return []; + }, + [chatService, connectorId] + ); + + const reloadConversation = useCallback(async () => { + setLoading(true); + + setDisplayedMessages(initialMessages); + setPendingMessage(undefined); + + const messages = [getSystemMessage(), ...initialMessages]; - if (recalledMessages === undefined) { - // don't do anything, it's loading - return; - } + const recalledMessages = await reloadRecalledMessages(messages); + const next = messages.concat(recalledMessages); + + setDisplayedMessages(next); + + let lastPendingMessage: PendingMessage | undefined; const nextSubscription = chatService - .chat({ messages: displayedMessages.concat(recalledMessages), connectorId, function: 'none' }) + .chat({ messages: next, connectorId, function: 'none' }) .subscribe({ next: (msg) => { lastPendingMessage = msg; setPendingMessage(() => msg); }, complete: () => { + setDisplayedMessages((prev) => + prev.concat({ + '@timestamp': new Date().toISOString(), + ...lastPendingMessage!, + }) + ); setPendingMessage(lastPendingMessage); setLoading(false); }, }); setSubscription(nextSubscription); - }, [chatService, connectorId, displayedMessages, setDisplayedMessages, recalledMessages]); + }, [ + reloadRecalledMessages, + chatService, + connectorId, + initialMessages, + getSystemMessage, + setDisplayedMessages, + ]); useEffect(() => { - reloadRecalledMessages(); - }, [reloadRecalledMessages]); + reloadConversation(); + }, [reloadConversation]); useEffect(() => { setDisplayedMessages(initialMessages); @@ -163,17 +180,22 @@ function ChatContent({ const messagesWithPending = useMemo(() => { return pendingMessage - ? displayedMessages.concat(recalledMessages || []).concat({ + ? displayedMessages.concat({ '@timestamp': new Date().toISOString(), message: { ...pendingMessage.message, }, }) - : displayedMessages.concat(recalledMessages || []); - }, [pendingMessage, displayedMessages, recalledMessages]); - - const lastAssistantMessage = last( - messagesWithPending.filter((message) => message.message.role === MessageRole.Assistant) + : displayedMessages; + }, [pendingMessage, displayedMessages]); + + const firstAssistantMessage = first( + messagesWithPending.filter( + (message) => + message.message.role === MessageRole.Assistant && + (!message.message.function_call?.trigger || + message.message.function_call.trigger === MessageRole.Assistant) + ) ); return ( @@ -181,7 +203,7 @@ function ChatContent({ {}} /> @@ -216,7 +238,7 @@ function ChatContent({ { - reloadRecalledMessages(); + reloadConversation(); }} /> @@ -237,7 +259,7 @@ function ChatContent({ onClose={() => { setIsOpen(() => false); }} - messages={messagesWithPending} + messages={displayedMessages} conversationId={conversationId} startedFrom="contextualInsight" onChatComplete={(nextMessages) => { diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts index 965a8b899879a..6970c53e28bf1 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_conversation.ts @@ -6,7 +6,7 @@ */ import { i18n } from '@kbn/i18n'; import { merge, omit } from 'lodash'; -import { Dispatch, SetStateAction, useMemo, useState } from 'react'; +import { Dispatch, SetStateAction, useCallback, useMemo, useState } from 'react'; import { type Conversation, type Message } from '../../common'; import { ConversationCreateRequest, MessageRole } from '../../common/types'; import { getAssistantSetupMessage } from '../service/get_assistant_setup_message'; @@ -20,14 +20,17 @@ export function useConversation({ conversationId, chatService, connectorId, + initialMessages = [], }: { conversationId?: string; chatService?: ObservabilityAIAssistantChatService; // will eventually resolve to a non-nullish value connectorId: string | undefined; + initialMessages?: Message[]; }): { conversation: AbortableAsyncState; displayedMessages: Message[]; setDisplayedMessages: Dispatch>; + getSystemMessage: () => Message; save: (messages: Message[], handleRefreshConversations?: () => void) => Promise; saveTitle: ( title: string, @@ -40,20 +43,25 @@ export function useConversation({ services: { notifications }, } = useKibana(); - const [displayedMessages, setDisplayedMessages] = useState([]); + const [displayedMessages, setDisplayedMessages] = useState(initialMessages); + + const getSystemMessage = useCallback(() => { + return getAssistantSetupMessage({ contexts: chatService?.getContexts() || [] }); + }, [chatService]); const displayedMessagesWithHardcodedSystemMessage = useMemo(() => { if (!chatService) { return displayedMessages; } - const systemMessage = getAssistantSetupMessage({ contexts: chatService?.getContexts() || [] }); + + const systemMessage = getSystemMessage(); if (displayedMessages[0]?.message.role === MessageRole.User) { return [systemMessage, ...displayedMessages]; } return [systemMessage, ...displayedMessages.slice(1)]; - }, [displayedMessages, chatService]); + }, [displayedMessages, chatService, getSystemMessage]); const conversation: AbortableAsyncState = useAbortableAsync( @@ -87,6 +95,7 @@ export function useConversation({ conversation, displayedMessages: displayedMessagesWithHardcodedSystemMessage, setDisplayedMessages, + getSystemMessage, save: (messages: Message[], handleRefreshConversations?: () => void) => { const conversationObject = conversation.value!; @@ -106,7 +115,13 @@ export function useConversation({ id: conversationId, }, }, - omit(conversationObject, 'conversation.last_updated', 'namespace', 'user'), + omit( + conversationObject, + 'conversation.last_updated', + 'namespace', + 'user', + 'messages' + ), { messages } ), }, diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts index 8d8afe6fb9cca..6ad1d0746a517 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.test.ts @@ -81,6 +81,12 @@ describe('useTimeline', () => { hookResult = renderHook((props) => useTimeline(props), { initialProps: { messages: [ + { + message: { + role: MessageRole.System, + content: 'You are a helpful assistant for Elastic Observability', + }, + }, { message: { role: MessageRole.User, @@ -122,6 +128,7 @@ describe('useTimeline', () => { chatService: { chat: () => {}, hasRenderFunction: () => {}, + hasFunction: () => {}, }, } as unknown as HookProps, }); @@ -308,35 +315,71 @@ describe('useTimeline', () => { canGiveFeedback: false, }, }); + }); - act(() => { - subject.next({ message: { role: MessageRole.Assistant, content: 'Goodbye' } }); + describe('and it pushes the next part', () => { + beforeEach(() => { + act(() => { + subject.next({ message: { role: MessageRole.Assistant, content: 'Goodbye' } }); + }); }); - expect(hookResult.result.current.items[2]).toMatchObject({ - role: MessageRole.Assistant, - content: 'Goodbye', - loading: true, - actions: { - canRegenerate: false, - canGiveFeedback: false, - }, + it('adds the partial response', () => { + expect(hookResult.result.current.items[2]).toMatchObject({ + role: MessageRole.Assistant, + content: 'Goodbye', + loading: true, + actions: { + canRegenerate: false, + canGiveFeedback: false, + }, + }); }); - act(() => { - subject.complete(); - }); + describe('and it completes', () => { + beforeEach(async () => { + act(() => { + subject.complete(); + }); - await hookResult.waitForNextUpdate(WAIT_OPTIONS); + await hookResult.waitForNextUpdate(WAIT_OPTIONS); + }); - expect(hookResult.result.current.items[2]).toMatchObject({ - role: MessageRole.Assistant, - content: 'Goodbye', - loading: false, - actions: { - canRegenerate: true, - canGiveFeedback: false, - }, + it('adds the completed message', () => { + expect(hookResult.result.current.items[2]).toMatchObject({ + role: MessageRole.Assistant, + content: 'Goodbye', + loading: false, + actions: { + canRegenerate: true, + canGiveFeedback: false, + }, + }); + }); + + describe('and the user edits a message', () => { + beforeEach(() => { + act(() => { + hookResult.result.current.onEdit( + hookResult.result.current.items[1] as ChatTimelineItem, + { + '@timestamp': new Date().toISOString(), + message: { content: 'Edited message', role: MessageRole.User }, + } + ); + subject.next({ message: { role: MessageRole.Assistant, content: '' } }); + subject.complete(); + }); + }); + + it('calls onChatUpdate with the edited message', () => { + expect(hookResult.result.current.items.length).toEqual(4); + expect((hookResult.result.current.items[2] as ChatTimelineItem).content).toEqual( + 'Edited message' + ); + expect((hookResult.result.current.items[3] as ChatTimelineItem).content).toEqual(''); + }); + }); }); }); @@ -379,7 +422,7 @@ describe('useTimeline', () => { }); }); - describe('and it being regenerated', () => { + describe('and it is being regenerated', () => { beforeEach(() => { act(() => { hookResult.result.current.onRegenerate( @@ -390,6 +433,8 @@ describe('useTimeline', () => { }); it('updates the last item in the array to be loading', () => { + expect(hookResult.result.current.items.length).toEqual(3); + expect(hookResult.result.current.items[2]).toEqual({ display: { hide: false, diff --git a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts index 7db568a07a99e..64d82cabb9437 100644 --- a/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts +++ b/x-pack/plugins/observability_ai_assistant/public/hooks/use_timeline.ts @@ -8,7 +8,7 @@ import { i18n } from '@kbn/i18n'; import { AbortError } from '@kbn/kibana-utils-plugin/common'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; -import { last } from 'lodash'; +import { flatten, last } from 'lodash'; import { useEffect, useMemo, useRef, useState } from 'react'; import usePrevious from 'react-use/lib/usePrevious'; import { isObservable, Observable, Subscription } from 'rxjs'; @@ -333,15 +333,16 @@ export function useTimeline({ return { items, onEdit: async (item, newMessage) => { - const indexOf = items.indexOf(item); - const sliced = messages.slice(0, indexOf - 1); + const indexOf = flatten(items).indexOf(item); + const sliced = messages.slice(0, indexOf); const nextMessages = await chat(sliced.concat(newMessage)); onChatComplete(nextMessages); }, onFeedback: (item, feedback) => {}, onRegenerate: (item) => { - const indexOf = items.indexOf(item); - chat(messages.slice(0, indexOf - 1)).then((nextMessages) => onChatComplete(nextMessages)); + const indexOf = flatten(items).indexOf(item); + + chat(messages.slice(0, indexOf)).then((nextMessages) => onChatComplete(nextMessages)); }, onStopGenerating: () => { subscription?.unsubscribe(); diff --git a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts index f0df01d87168d..2332f63a54c78 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/client/index.ts @@ -100,7 +100,7 @@ export class ObservabilityAIAssistantClient { await this.dependencies.esClient.delete({ id: conversation._id, index: conversation._index, - refresh: 'wait_for', + refresh: true, }); }; @@ -244,7 +244,7 @@ export class ObservabilityAIAssistantClient { id: document._id, index: document._index, doc: updatedConversation, - refresh: 'wait_for', + refresh: true, }); return updatedConversation; @@ -334,7 +334,7 @@ export class ObservabilityAIAssistantClient { id: document._id, index: document._index, doc: { conversation: { title } }, - refresh: 'wait_for', + refresh: true, }); return updatedConversation; @@ -356,7 +356,7 @@ export class ObservabilityAIAssistantClient { await this.dependencies.esClient.index({ index: this.dependencies.resources.aliases.conversations, document: createdConversation, - refresh: 'wait_for', + refresh: true, }); return createdConversation; diff --git a/x-pack/plugins/observability_ai_assistant/server/service/kb_service/index.ts b/x-pack/plugins/observability_ai_assistant/server/service/kb_service/index.ts index d70879bf46d3e..be10c3eaaa5d5 100644 --- a/x-pack/plugins/observability_ai_assistant/server/service/kb_service/index.ts +++ b/x-pack/plugins/observability_ai_assistant/server/service/kb_service/index.ts @@ -218,7 +218,7 @@ export class KnowledgeBaseService { >({ index: this.dependencies.resources.aliases.kb, query, - size: 10, + size: 5, _source: { includes: ['text', 'is_correction', 'labels'], },