diff --git a/src/agents/agent.ts b/src/agents/agent.ts index 33966809..bf2c1ca6 100644 --- a/src/agents/agent.ts +++ b/src/agents/agent.ts @@ -47,6 +47,9 @@ export interface AgentOptions { // Optional: The geographic region where the agent should be deployed or run region?: string; + + // Optional: Determines whether to save the chat, defaults to true + saveChat?: boolean; } /** @@ -63,6 +66,9 @@ export abstract class Agent { /** A description of the agent's capabilities and expertise. */ description: string; + /** Whether to save the chat or not. */ + saveChat: boolean; + /** * Constructs a new Agent instance. * @param options - Configuration options for the agent. @@ -71,6 +77,7 @@ export abstract class Agent { this.name = options.name; this.id = this.generateKeyFromName(options.name); this.description = options.description; + this.saveChat = options.saveChat ?? true; // Default to true if not provided } /** diff --git a/src/orchestrator.ts b/src/orchestrator.ts index f392adff..b1948884 100644 --- a/src/orchestrator.ts +++ b/src/orchestrator.ts @@ -8,7 +8,7 @@ import { BedrockLLMAgent } from "./agents/bedrockLLMAgent"; import { ChatStorage } from "./storage/chatStorage"; import { InMemoryChatStorage } from "./storage/memoryChatStorage"; import { AccumulatorTransform } from "./utils/helpers"; -import { saveChat } from "./utils/chatUtils"; +import { saveConversationExchange } from "./utils/chatUtils"; import { Logger } from "./utils/logger"; import { BedrockClassifier } from "./classifiers/bedrockClassifier"; import { Classifier } from "./classifiers/classifier"; @@ -283,7 +283,7 @@ export class MultiAgentOrchestrator { return obj != null && typeof obj[Symbol.asyncIterator] === "function"; } - private async dispatchToAgent( + async dispatchToAgent( params: DispatchToAgentsParams ): Promise> { const { @@ -355,6 +355,7 @@ export class MultiAgentOrchestrator { "Classifying user intent", () => this.classifier.classify(userInput, chatHistory) ); + this.logger.printIntent(userInput, classifierResult); } catch (error) { this.logger.error("Error during intent classification:", error); @@ -364,7 +365,7 @@ export class MultiAgentOrchestrator { streaming: false, }; } - + // Handle case where no agent was selected if (!classifierResult.selectedAgent) { if (this.config.USE_DEFAULT_AGENT_IF_NONE_IDENTIFIED) { @@ -398,7 +399,7 @@ export class MultiAgentOrchestrator { userInput, userId, sessionId, - metadata + classifierResult.selectedAgent ); return { metadata, @@ -407,16 +408,20 @@ export class MultiAgentOrchestrator { }; } - await saveChat( - userInput, - agentResponse, - this.storage, - userId, - sessionId, - classifierResult.selectedAgent.id, - this.config.MAX_MESSAGE_PAIRS_PER_AGENT - ); - + // Check if we should save the conversation + if (classifierResult?.selectedAgent.saveChat) { + await saveConversationExchange( + userInput, + agentResponse, + this.storage, + userId, + sessionId, + classifierResult?.selectedAgent.id, + this.config.MAX_MESSAGE_PAIRS_PER_AGENT + ); + } + + return { metadata, output: agentResponse, @@ -440,7 +445,7 @@ export class MultiAgentOrchestrator { userInput: string, userId: string, sessionId: string, - metadata: any + agent: Agent ): Promise { const streamStartTime = Date.now(); let chunkCount = 0; @@ -462,14 +467,20 @@ export class MultiAgentOrchestrator { const fullResponse = accumulatorTransform.getAccumulatedData(); if (fullResponse) { - await saveChat( + + + + if (agent.saveChat) { + await saveConversationExchange( userInput, fullResponse, this.storage, userId, sessionId, - metadata.agentId + agent.id ); + } + } else { this.logger.warn("No data accumulated, messages not saved"); } diff --git a/src/utils/chatUtils.ts b/src/utils/chatUtils.ts index 81274fcb..ac357dad 100644 --- a/src/utils/chatUtils.ts +++ b/src/utils/chatUtils.ts @@ -2,7 +2,7 @@ import { ChatStorage } from '../storage/chatStorage'; import { ConversationMessage, ParticipantRole } from '../types'; -export async function saveChat( +export async function saveConversationExchange( userInput: string, agentResponse: string, storage: ChatStorage, diff --git a/tests/Orchestrator.test.ts b/tests/Orchestrator.test.ts index 9506d488..4cebdbe7 100644 --- a/tests/Orchestrator.test.ts +++ b/tests/Orchestrator.test.ts @@ -6,12 +6,16 @@ import { ParticipantRole } from '../src/types/index'; import { ClassifierResult } from '../src/classifiers/classifier'; import { ConversationMessage } from '../src/types/index'; import { AccumulatorTransform } from '../src/utils/helpers'; +import * as chatUtils from '../src/utils/chatUtils'; // Mock the dependencies -jest.mock('../src/agents/agent'); jest.mock('../src/classifiers/bedrockClassifier'); jest.mock('../src/storage/memoryChatStorage'); jest.mock('../src/utils/helpers'); +jest.mock('../src/utils/chatUtils', () => ({ + saveConversationExchange: jest.fn(), +})); + // Create a mock Agent class class MockAgent extends Agent { @@ -34,20 +38,20 @@ class MockAgent extends Agent { } } + describe('MultiAgentOrchestrator', () => { + + let orchestrator: MultiAgentOrchestrator; - let mockAgent: jest.Mocked; + let mockAgent: Agent; let mockClassifier: jest.Mocked; let mockStorage: jest.Mocked; beforeEach(() => { - // Create mock instances - mockAgent = { - id: 'test-agent', + mockAgent = new MockAgent({ name: 'Test Agent', - description: 'A test agent', - processRequest: jest.fn(), - } as unknown as jest.Mocked; + description: 'A test agent' + }); mockClassifier = new BedrockClassifier() as jest.Mocked; mockStorage = new InMemoryChatStorage() as jest.Mocked; @@ -68,6 +72,7 @@ describe('MultiAgentOrchestrator', () => { }); test('addAgent adds an agent and updates classifier', () => { + const newAgent: Agent = { id: 'new-agent', name: 'New Agent', @@ -80,7 +85,9 @@ describe('MultiAgentOrchestrator', () => { }); test('getAllAgents returns all added agents', () => { + const agents = orchestrator.getAllAgents(); + expect(agents).toHaveProperty('test-agent'); expect(agents['test-agent']).toEqual({ name: 'Test Agent', @@ -89,6 +96,8 @@ describe('MultiAgentOrchestrator', () => { }); test('routeRequest with identified agent', async () => { + + const processRequestSpy = jest.spyOn(mockAgent, 'processRequest'); const userInput = 'Test input'; const userId = 'user1'; const sessionId = 'session1'; @@ -100,64 +109,30 @@ describe('MultiAgentOrchestrator', () => { mockClassifier.classify.mockResolvedValue(mockClassifierResult); - (mockAgent.processRequest as jest.Mock).mockImplementation(async () => ({ - role: 'assistant', - content: [{ text: 'Agent response' }], - })); - mockStorage.fetchAllChats.mockResolvedValue([]); const response = await orchestrator.routeRequest(userInput, userId, sessionId); - expect(response.output).toBe('Agent response'); + expect(response.output).toBe('Mock response'); expect(response.metadata.agentId).toBe(mockAgent.id); expect(mockStorage.fetchAllChats).toHaveBeenCalledWith(userId, sessionId); expect(mockClassifier.classify).toHaveBeenCalledWith(userInput, []); - expect(mockAgent.processRequest).toHaveBeenCalled(); + expect(processRequestSpy).toHaveBeenCalled(); }); -// test('routeRequest with no identified agent', async () => { -// const userInput = 'Unclassifiable input'; -// const userId = 'user1'; -// const sessionId = 'session1'; - -// const mockClassifierResult: ClassifierResult = { -// selectedAgent: null, -// confidence: 0, -// }; - -// mockClassifier.classify.mockResolvedValue(mockClassifierResult); - -// const defaultAgent = orchestrator.getDefaultAgent(); -// jest.spyOn(defaultAgent, 'processRequest').mockImplementation(async () => ({ -// role: 'assistant', -// content: [{ text: 'Default agent response' }], -// })); - -// const response = await orchestrator.routeRequest(userInput, userId, sessionId); - -// expect(response.output).toBe('Default agent response'); -// expect(response.metadata.agentId).toBe(AgentTypes.DEFAULT); -// }); - -// test('routeRequest with classification error', async () => { -// const userInput = 'Error-causing input'; -// const userId = 'user1'; -// const sessionId = 'session1'; - -// mockClassifier.classify.mockRejectedValue(new Error('Classification error')); - -// const response = await orchestrator.routeRequest(userInput, userId, sessionId); - -// expect(response.output).toBe(orchestrator['config'].CLASSIFICATION_ERROR_MESSAGE); -// expect(response.metadata.errorType).toBe('classification_failed'); -// }); test('routeRequest with streaming response', async () => { const userInput = 'Stream input'; const userId = 'user1'; const sessionId = 'session1'; + const mockAgent = { + id: 'test-agent', + name: 'Test Agent', + description: 'A test agent', + processRequest: jest.fn(), + } as unknown as jest.Mocked; + const mockClassifierResult: ClassifierResult = { selectedAgent: mockAgent, confidence: 0.8, @@ -226,5 +201,126 @@ describe('MultiAgentOrchestrator', () => { // Verify that the classifier's setAgents method was only called once (for the first agent) expect(mockClassifier.setAgents).toHaveBeenCalledTimes(2); // Once in beforeEach, once for existingAgent }); + + +}); + +describe('MultiAgentOrchestrator saveConversationExchange', () => { + let orchestrator: MultiAgentOrchestrator; + let mockAgent: MockAgent; + let mockClassifier: jest.Mocked; + let mockStorage: jest.Mocked; + + beforeEach(() => { + jest.clearAllMocks(); + + mockAgent = new MockAgent({ + name: 'Mock Agent', + description: 'A mock agent', + saveChat: true + }); + + mockClassifier = new BedrockClassifier() as jest.Mocked; + mockStorage = new InMemoryChatStorage() as jest.Mocked; + + const options: OrchestratorOptions = { + storage: mockStorage, + classifier: mockClassifier + }; + + orchestrator = new MultiAgentOrchestrator(options); + orchestrator.addAgent(mockAgent); + }); + + test('routeRequest calls saveConversationExchange for default saveChat', async () => { + + const mockAgent = new MockAgent({ + name: 'Mock Agent', + description: 'A mock agent' + }); + + + const mockClassifierResult: ClassifierResult = { + selectedAgent: mockAgent, + confidence: 0.5, + }; + + mockClassifier.classify.mockResolvedValue(mockClassifierResult); + mockStorage.fetchAllChats.mockResolvedValue([]); + + jest.spyOn(orchestrator, 'dispatchToAgent').mockResolvedValue('Mock agent response'); + + await orchestrator.routeRequest('Test input', 'user1', 'session1'); + + expect(chatUtils.saveConversationExchange).toHaveBeenCalledWith( + 'Test input', + 'Mock agent response', + mockStorage, + 'user1', + 'session1', + 'mock-agent', + expect.any(Number) + ); + }); + + + test('routeRequest calls saveConversationExchange for saveChat=true', async () => { + + const mockAgent = new MockAgent({ + name: 'Mock Agent', + description: 'A mock agent', + saveChat: true + }); + + + const mockClassifierResult: ClassifierResult = { + selectedAgent: mockAgent, + confidence: 0.5, + }; + + mockClassifier.classify.mockResolvedValue(mockClassifierResult); + mockStorage.fetchAllChats.mockResolvedValue([]); + + jest.spyOn(orchestrator, 'dispatchToAgent').mockResolvedValue('Mock agent response'); + + await orchestrator.routeRequest('Test input', 'user1', 'session1'); + + expect(chatUtils.saveConversationExchange).toHaveBeenCalledWith( + 'Test input', + 'Mock agent response', + mockStorage, + 'user1', + 'session1', + 'mock-agent', + expect.any(Number) + ); + }); + + test('routeRequest do not calls saveConversationExchange for saveChat=false', async () => { + + const mockAgent = new MockAgent({ + name: 'Mock Agent', + description: 'A mock agent', + saveChat: false + }); + + + const mockClassifierResult: ClassifierResult = { + selectedAgent: mockAgent, + confidence: 0.5, + }; + + mockClassifier.classify.mockResolvedValue(mockClassifierResult); + mockStorage.fetchAllChats.mockResolvedValue([]); + + jest.spyOn(orchestrator, 'dispatchToAgent').mockResolvedValue('Mock agent response'); + + await orchestrator.routeRequest('Test input', 'user1', 'session1'); + + expect(chatUtils.saveConversationExchange).not.toHaveBeenCalled(); + }); + + + }); \ No newline at end of file