diff --git a/docs/src/content/docs/agents/built-in/openai-agent.mdx b/docs/src/content/docs/agents/built-in/openai-agent.mdx index 0d1ddbb2..46ebc63a 100644 --- a/docs/src/content/docs/agents/built-in/openai-agent.mdx +++ b/docs/src/content/docs/agents/built-in/openai-agent.mdx @@ -13,6 +13,8 @@ The `OpenAIAgent` is a powerful agent class in the Multi-Agent Orchestrator fram - Customizable inference configuration - Handles conversation history for context-aware responses - Customizable system prompts +- Optional integration with retrieval systems for enhanced context (Typescript only) +- Support for Tool use within the conversation flow (Typescript only) ## Creating an OpenAIAgent @@ -39,7 +41,41 @@ const agent = new OpenAIAgent({ topP: 0.9, stopSequences: ['Human:', 'AI:'] }, - systemPrompt: 'You are a helpful AI assistant specialized in answering questions about technology.' + systemPrompt: 'You are a helpful AI assistant specialized in answering questions about technology.', + toolConfig: { + tool: { + type: 'function', + function: { + name: "Weather_Tool", + description: "Get the current weather for a given location, based on its WGS84 coordinates.", + parameters: { + additionalProperties: false, + type: "object", + properties: { + latitude: { + type: "string", + description: "Geographical WGS84 latitude of the location.", + }, + longitude: { + type: "string", + description: "Geographical WGS84 longitude of the location.", + }, + }, + required: ["latitude", "longitude"], + }, + strict: true, + }, + }, + useToolHandler: (response, conversation) => { + //process tool response + return { + role: 'tool' as const, + tool_call_id: id, + content: JSON.stringify(response), + }; + }, + toolMaxRecursions: 5, + }, }); ``` @@ -62,6 +98,7 @@ The `OpenAIAgentOptions` extends the base `AgentOptions` and includes the follow - `topP` (optional): Controls diversity of output generation. - `stopSequences` (optional): An array of sequences that, when generated, will stop the generation process. - `systemPrompt` (optional): A string representing the initial system prompt for the agent. +- `toolConfig` (optional): Defines tools the agent can use and how to handle their responses. ## Setting the System Prompt diff --git a/typescript/src/agents/openAIAgent.ts b/typescript/src/agents/openAIAgent.ts index 9465bfe4..35f2f129 100644 --- a/typescript/src/agents/openAIAgent.ts +++ b/typescript/src/agents/openAIAgent.ts @@ -1,7 +1,8 @@ import { Agent, AgentOptions } from './agent'; -import { ConversationMessage, OPENAI_MODEL_ID_GPT_O_MINI, ParticipantRole } from '../types'; +import { ConversationMessage, OPENAI_MODEL_ID_GPT_O_MINI, ParticipantRole, TemplateVariables } from '../types'; import OpenAI from 'openai'; import { Logger } from '../utils/logger'; +import { Retriever } from '../retrievers/retriever'; export interface OpenAIAgentOptions extends AgentOptions { apiKey: string; @@ -13,6 +14,15 @@ export interface OpenAIAgentOptions extends AgentOptions { topP?: number; stopSequences?: string[]; }; + customSystemPrompt?: { + template: string, variables?: TemplateVariables + }; + retriever?: Retriever; + toolConfig?: { + tool: OpenAI.ChatCompletionTool[]; + useToolHandler: (response: any, conversation: any[]) => any; + toolMaxRecursions?: number; + }; } const DEFAULT_MAX_TOKENS = 1000; @@ -28,6 +38,18 @@ export class OpenAIAgent extends Agent { stopSequences?: string[]; }; + protected retriever?: Retriever; + + private toolConfig?: { + tool: OpenAI.ChatCompletionTool[]; + useToolHandler: (response: any, conversation: any[]) => any; + toolMaxRecursions?: number; + }; + + private promptTemplate: string; + private systemPrompt: string; + private customVariables: TemplateVariables; + constructor(options: OpenAIAgentOptions) { super(options); this.openai = new OpenAI({ apiKey: options.apiKey }); @@ -39,6 +61,35 @@ export class OpenAIAgent extends Agent { topP: options.inferenceConfig?.topP, stopSequences: options.inferenceConfig?.stopSequences, }; + + this.retriever = options.retriever; + this.toolConfig = options.toolConfig ?? null; + + this.systemPrompt = ''; + this.customVariables = {}; + + this.promptTemplate = `You are a ${this.name}. ${this.description} Provide helpful and accurate information based on your expertise. + You will engage in an open-ended conversation, providing helpful and accurate information based on your expertise. + The conversation will proceed as follows: + - The human may ask an initial question or provide a prompt on any topic. + - You will provide a relevant and informative response. + - The human may then follow up with additional questions or prompts related to your previous response, allowing for a multi-turn dialogue on that topic. + - Or, the human may switch to a completely new and unrelated topic at any point. + - You will seamlessly shift your focus to the new topic, providing thoughtful and coherent responses based on your broad knowledge base. + Throughout the conversation, you should aim to: + - Understand the context and intent behind each new question or prompt. + - Provide substantive and well-reasoned responses that directly address the query. + - Draw insights and connections from your extensive knowledge when appropriate. + - Ask for clarification if any part of the question or prompt is ambiguous. + - Maintain a consistent, respectful, and engaging tone tailored to the human's communication style. + - Seamlessly transition between topics as the human introduces new subjects.` + + if (options.customSystemPrompt) { + this.setSystemPrompt( + options.customSystemPrompt.template, + options.customSystemPrompt.variables + ); + } } /* eslint-disable @typescript-eslint/no-unused-vars */ @@ -49,8 +100,6 @@ export class OpenAIAgent extends Agent { chatHistory: ConversationMessage[], additionalParams?: Record ): Promise> { - - const messages = [ ...chatHistory.map(msg => ({ role: msg.role.toLowerCase() as OpenAI.Chat.ChatCompletionMessageParam['role'], @@ -59,6 +108,16 @@ export class OpenAIAgent extends Agent { { role: 'user' as const, content: inputText } ] as OpenAI.Chat.ChatCompletionMessageParam[]; + this.updateSystemPrompt() + + let systemPrompt = this.systemPrompt; + + if (this.retriever) { + const response = await this.retriever.retrieveAndCombineResults(inputText); + const contextPrompt = "\nHere is the context to use to answer the user's question:\n" + response; + systemPrompt = systemPrompt + contextPrompt; + } + const { maxTokens, temperature, topP, stopSequences } = this.inferenceConfig; const requestOptions: OpenAI.Chat.ChatCompletionCreateParams = { @@ -69,18 +128,51 @@ export class OpenAIAgent extends Agent { temperature, top_p: topP, stop: stopSequences, + tools: this.toolConfig?.tool || undefined, }; + try { + + if (this.streaming) { + return this.handleStreamingResponse(requestOptions); + } else { + let finalMessage: string = ''; + let toolUse = false; + let recursions = this.toolConfig?.toolMaxRecursions || 5; + + do { + const response = await this.handleSingleResponse(requestOptions); + + if (response.tool_calls) { + messages.push(response); + + if (!this.toolConfig) { + throw new Error('No tools available for tool use'); + } + const toolResponse = await this.toolConfig.useToolHandler(response, messages); + messages.push(toolResponse); + toolUse = true; + } else { + finalMessage = response.content; + toolUse = false; + } - if (this.streaming) { - return this.handleStreamingResponse(requestOptions); - } else { - return this.handleSingleResponse(requestOptions); + recursions--; + } while (toolUse && recursions > 0); + + return { + role: ParticipantRole.ASSISTANT, + content: [{ text: finalMessage }], + }; + } + } catch (error) { + Logger.logger.error('Error in OpenAI API call:', error); + throw error; } } - private async handleSingleResponse(input: any): Promise { + private async handleSingleResponse(input: any): Promise { try { const nonStreamingOptions = { ...input, stream: false }; const chatCompletion = await this.openai.chat.completions.create(nonStreamingOptions); @@ -89,33 +181,138 @@ export class OpenAIAgent extends Agent { throw new Error('No choices returned from OpenAI API'); } - const assistantMessage = chatCompletion.choices[0]?.message?.content; - - if (typeof assistantMessage !== 'string') { - throw new Error('Unexpected response format from OpenAI API'); - } - - return { - role: ParticipantRole.ASSISTANT, - content: [{ text: assistantMessage }], - }; + const message = chatCompletion.choices[0].message; + return message as OpenAI.Chat.ChatCompletionMessage; } catch (error) { Logger.logger.error('Error in OpenAI API call:', error); throw error; } } - private async *handleStreamingResponse(options: OpenAI.Chat.ChatCompletionCreateParams): AsyncIterable { - const stream = await this.openai.chat.completions.create({ ...options, stream: true }); - for await (const chunk of stream) { - const content = chunk.choices[0]?.delta?.content; - if (content) { - yield content; - } + setSystemPrompt(template?: string, variables?: TemplateVariables): void { + if (template) { + this.promptTemplate = template; + } + + if (variables) { + this.customVariables = variables; } + + this.updateSystemPrompt(); } + private async * handleStreamingResponse(options: OpenAI.Chat.ChatCompletionCreateParams): AsyncIterable { + let recursions = this.toolConfig?.toolMaxRecursions || 5; + + while (recursions > 0) { + // Add tool calls to messages before creating stream + const messagesWithToolCalls = [...options.messages]; + + const stream = await this.openai.chat.completions.create({ + ...options, + messages: messagesWithToolCalls, + stream: true + }); + + let currentToolCalls: any[] = []; + let hasToolCalls = false; + + for await (const chunk of stream) { + const toolCalls = chunk.choices[0]?.delta?.tool_calls; + + if (toolCalls) { + for (const toolCall of toolCalls) { + if (toolCall.id) { + currentToolCalls.push({ + id: toolCall.id, + function: toolCall.function, + }); + } + + if (toolCall.function?.arguments) { + const lastToolCall = currentToolCalls[currentToolCalls.length - 1]; + lastToolCall.function.arguments = (lastToolCall.function.arguments || '') + toolCall.function.arguments; + } + } + } + + if (chunk.choices[0]?.finish_reason === 'tool_calls') { + hasToolCalls = true; + const toolCallResults = []; + + // Add tool calls to messages before processing + messagesWithToolCalls.push({ + role: 'assistant', + tool_calls: currentToolCalls.map(tc => ({ + id: tc.id, + type: 'function', + function: tc.function + })) + }); + + for (const toolCall of currentToolCalls) { + try { + const toolResponse = await this.toolConfig.useToolHandler( + { tool_calls: [toolCall] }, + messagesWithToolCalls + ); + + toolCallResults.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(toolResponse) + }); + } catch (error) { + console.error('Tool call error', error); + } + } + // Append tool call results to messages + messagesWithToolCalls.push(...toolCallResults); + // Update options for next iteration + options.messages = messagesWithToolCalls; -} \ No newline at end of file + currentToolCalls = []; + } + + const content = chunk.choices[0]?.delta?.content; + if (content) { + yield content; + } + } + + // Break if no tool calls were found + if (!hasToolCalls) break; + + recursions--; + } + } + + private updateSystemPrompt(): void { + const allVariables: TemplateVariables = { + ...this.customVariables + }; + + this.systemPrompt = this.replaceplaceholders( + this.promptTemplate, + allVariables + ); + } + + private replaceplaceholders( + template: string, + variables: TemplateVariables + ): string { + return template.replace(/{{(\w+)}}/g, (match, key) => { + if (key in variables) { + const value = variables[key]; + if (Array.isArray(value)) { + return value.join("\n"); + } + return value; + } + return match; // If no replacement found, leave the placeholder as is + }); + } +} diff --git a/typescript/tests/agents/Openai.test.ts b/typescript/tests/agents/Openai.test.ts new file mode 100644 index 00000000..f065e48e --- /dev/null +++ b/typescript/tests/agents/Openai.test.ts @@ -0,0 +1,162 @@ +import { OpenAIAgent } from '../../src/agents/openAIAgent'; +import { ConversationMessage, ParticipantRole } from '../../src/types'; +import { OpenAI } from 'openai'; + +jest.mock('openai'); + +describe('OpenAIAgent', () => { + let agent: OpenAIAgent; + let mockCreateCompletion: jest.Mock; + let toolMock: jest.Mock; + let toolHandler: jest.Mock; + + beforeEach(() => { + mockCreateCompletion = jest.fn(); + toolMock = jest.fn(); + toolHandler = jest.fn(); + + (OpenAI as jest.MockedClass).mockImplementation(() => ({ + chat: { + completions: { + create: mockCreateCompletion, + }, + }, + } as unknown as OpenAI)); + + agent = new OpenAIAgent({ + name: 'Test Agent', + description: 'A test agent', + apiKey: 'test-api-key', + model: 'test-model', + streaming: false, + inferenceConfig: { + maxTokens: 1000, + temperature: 0.5, + topP: 0.9, + stopSequences: ['\n'], + }, + toolConfig: { + tool: [{ + function: toolMock, + type: 'function', + }], + useToolHandler: toolHandler, + toolMaxRecursions: 5, + }, + }); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('processRequest', () => { + it('should process request successfully no function calls', async () => { + const inputText = 'test input'; + const userId = 'test-user'; + const sessionId = 'test-session'; + const chatHistory: ConversationMessage[] = []; + + mockCreateCompletion.mockResolvedValue({ + choices: [ + { + message: { + content: 'test response', + role: ParticipantRole.ASSISTANT, + }, + }, + ], + }); + + const response = await agent.processRequest(inputText, userId, sessionId, chatHistory); + + expect(response).toEqual({ + role: ParticipantRole.ASSISTANT, + content: [{ text: 'test response' }], + }); + expect(mockCreateCompletion).toHaveBeenCalledTimes(1); + }); + + it('should process request with multiple tool calls', async () => { + const inputText = 'test input'; + const userId = 'test-user'; + const sessionId = 'test-session'; + const chatHistory: ConversationMessage[] = []; + + // Create a mock implementation that changes on the second call + let callCount = 0; + mockCreateCompletion.mockImplementation(() => { + callCount++; + if (callCount === 1) { + // First call - return tool call response + return { + choices: [ + { + message: { + tool_calls: [ + { + id: 'call1', + function: { + name: 'test_function', + arguments: JSON.stringify({ test: 'argument' }), + }, + }, + ] + }, + }, + ], + }; + } else { + // Subsequent calls - return final response + return { + choices: [ + { + message: { + content: 'Final response after tool call', + }, + }, + ], + }; + } + }); + + // Mock tool handler + toolHandler.mockResolvedValue('function output') + + const response = await agent.processRequest(inputText, userId, sessionId, chatHistory); + + expect(mockCreateCompletion).toHaveBeenCalledTimes(2); // Ensure two calls were made + expect(toolHandler).toHaveBeenCalledTimes(1); + expect(response).toEqual({ + role: ParticipantRole.ASSISTANT, + content: [{ text: 'Final response after tool call' }], + }); + }); + + it('should throw an error if API returns no choices', async () => { + const inputText = 'test input'; + const userId = 'test-user'; + const sessionId = 'test-session'; + const chatHistory: ConversationMessage[] = []; + + mockCreateCompletion.mockResolvedValue({ + choices: [], + }); + + await expect(agent.processRequest(inputText, userId, sessionId, chatHistory)) + .rejects.toThrow('No choices returned from OpenAI API'); + }) + + it('should throw an error if API request fails', async () => { + const inputText = 'test input'; + const userId = 'test-user'; + const sessionId = 'test-session'; + const chatHistory: ConversationMessage[] = []; + + mockCreateCompletion.mockRejectedValue(new Error('API request failed')); + + await expect(agent.processRequest(inputText, userId, sessionId, chatHistory)) + .rejects.toThrow('API request failed'); + }); + }); +}); \ No newline at end of file