From 39c01117f30046ea4bac1c3ac8d18ab108a0eb4d Mon Sep 17 00:00:00 2001 From: nicolasburtey Date: Mon, 20 May 2024 14:54:32 -0600 Subject: [PATCH] =?UTF-8?q?fix:=20assistant=20function=20call=20when=20mul?= =?UTF-8?q?tple=20tool=20are=20needed=20in=20the=20same=E2=80=A6=20(#4483)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: assistant function call when multple tool are needed in the same query * fix: ts --------- Co-authored-by: Nicolas Burtey --- core/api/src/domain/support/errors.ts | 2 + core/api/src/graphql/error-map.ts | 1 + core/api/src/services/openai/assistant.ts | 68 +++++++++++++---------- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/core/api/src/domain/support/errors.ts b/core/api/src/domain/support/errors.ts index fb6d95cad3..35afc9e185 100644 --- a/core/api/src/domain/support/errors.ts +++ b/core/api/src/domain/support/errors.ts @@ -13,6 +13,8 @@ export class UnknownPineconeError extends SupportError { export class ChatAssistantError extends SupportError {} export class ChatAssistantNotFoundError extends SupportError {} +export class TimeoutAssistantError extends SupportError {} + export class UnknownChatAssistantError extends ChatAssistantError { level = ErrorLevel.Critical } diff --git a/core/api/src/graphql/error-map.ts b/core/api/src/graphql/error-map.ts index 6b5852c404..2b68ec5748 100644 --- a/core/api/src/graphql/error-map.ts +++ b/core/api/src/graphql/error-map.ts @@ -823,6 +823,7 @@ export const mapError = (error: ApplicationError): CustomGraphQLError => { case "UnknownPineconeError": case "CallbackServiceError": case "ChatAssistantNotFoundError": + case "TimeoutAssistantError": message = `Unknown error occurred (code: ${error.name})` return new UnknownClientError({ message, logger: baseLogger }) diff --git a/core/api/src/services/openai/assistant.ts b/core/api/src/services/openai/assistant.ts index 8f937cccbd..46816a3d33 100644 --- a/core/api/src/services/openai/assistant.ts +++ b/core/api/src/services/openai/assistant.ts @@ -11,6 +11,7 @@ import { sleep } from "@/utils" import { UnknownDomainError } from "@/domain/shared" import { ChatAssistantNotFoundError, + TimeoutAssistantError, UnknownChatAssistantError, } from "@/domain/support/errors" @@ -134,30 +135,36 @@ export const Assistant = (): ChatAssistant => { } } - const processAction = async (run: OpenAI.Beta.Threads.Runs.Run) => { + const processAction = async (run: OpenAI.Beta.Threads.Runs.Run): Promise => { const action = run.required_action assert(action?.type === "submit_tool_outputs") - const name = action.submit_tool_outputs.tool_calls[0].function.name - assert(name === "queryBlinkKnowledgeBase") + const outputs: string[] = [] - const args = action.submit_tool_outputs.tool_calls[0].function.arguments - const query = JSON.parse(args).query_str + for (const toolCall of action.submit_tool_outputs.tool_calls) { + const name = toolCall.function.name + assert(name === "queryBlinkKnowledgeBase") - const vector = await textToVector(query) - if (vector instanceof Error) throw vector + const args = toolCall.function.arguments + const query = JSON.parse(args).query_str - const relatedQueries = await retrieveRelatedQueries(vector) - if (relatedQueries instanceof Error) throw relatedQueries + const vector = await textToVector(query) + if (vector instanceof Error) throw vector - let output = "" - let i = 0 - for (const query of relatedQueries) { - output += `Context chunk ${i}:\n${query}\n-----\n` - i += 1 + const relatedQueries = await retrieveRelatedQueries(vector) + if (relatedQueries instanceof Error) throw relatedQueries + + let output = "" + let i = 0 + for (const query of relatedQueries) { + output += `Context chunk ${i}:\n${query}\n-----\n` + i += 1 + } + + outputs.push(output) } - return output + return outputs } const waitForCompletion = async ({ @@ -166,8 +173,11 @@ export const Assistant = (): ChatAssistant => { }: { runId: string threadId: string - }) => { + }): Promise => { let run: OpenAI.Beta.Threads.Runs.Run + const maxRetries = 60 // Assuming a 30-second timeout with 500ms sleep + let retries = 0 + try { run = await openai.beta.threads.runs.retrieve(threadId, runId) } catch (err) { @@ -177,10 +187,14 @@ export const Assistant = (): ChatAssistant => { while ( ["queued", "in_progress", "cancelling", "requires_action"].includes(run.status) ) { - // TODO: max timer for this loop - // add open telemetry here? or is it already present with the http requests? + if (retries >= maxRetries) { + return new TimeoutAssistantError() + } + // Add telemetry here if needed await sleep(500) + retries += 1 + try { run = await openai.beta.threads.runs.retrieve(threadId, runId) } catch (err) { @@ -188,21 +202,19 @@ export const Assistant = (): ChatAssistant => { } if (run.status === "requires_action") { - let output: string + let outputs: string[] try { - output = await processAction(run) + outputs = await processAction(run) } catch (err) { return new UnknownChatAssistantError(err) } try { await openai.beta.threads.runs.submitToolOutputs(threadId, runId, { - tool_outputs: [ - { - tool_call_id: run.required_action?.submit_tool_outputs.tool_calls[0].id, - output, - }, - ], + tool_outputs: outputs.map((output, index) => ({ + tool_call_id: run.required_action?.submit_tool_outputs.tool_calls[index].id, + output, + })), }) } catch (err) { return new UnknownChatAssistantError(err) @@ -222,12 +234,12 @@ export const Assistant = (): ChatAssistant => { const responseThread = messages.data[0] if (responseThread.content[0]?.type !== "text") { - return new UnknownChatAssistantError("last message is not text") + return new UnknownChatAssistantError("Last message is not text") } return true } else { - return new UnknownChatAssistantError("issue running the assistant") + return new UnknownChatAssistantError("Issue running the assistant") } }