Skip to content

Commit a7ac7bf

Browse files
committed
adding more unit tests
Signed-off-by: Dhrubo Saha <[email protected]>
1 parent 5fe48b2 commit a7ac7bf

File tree

5 files changed

+1330
-222
lines changed

5 files changed

+1330
-222
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,8 @@ private void executeAgent(
337337
}
338338
}
339339

340-
@SuppressWarnings("removal")
341-
private ActionListener<Object> createAgentActionListener(
340+
@VisibleForTesting
341+
ActionListener<Object> createAgentActionListener(
342342
ActionListener<Output> listener,
343343
List<ModelTensors> outputs,
344344
List<ModelTensor> modelTensors,
@@ -357,7 +357,8 @@ private ActionListener<Object> createAgentActionListener(
357357
});
358358
}
359359

360-
private ActionListener<Object> createAsyncTaskUpdater(MLTask mlTask, List<ModelTensors> outputs, List<ModelTensor> modelTensors) {
360+
@VisibleForTesting
361+
ActionListener<Object> createAsyncTaskUpdater(MLTask mlTask, List<ModelTensors> outputs, List<ModelTensor> modelTensors) {
361362
String taskId = mlTask.getTaskId();
362363
Map<String, Object> agentResponse = new HashMap<>();
363364
Map<String, Object> updatedTask = new HashMap<>();
@@ -534,7 +535,8 @@ private void retrieveAgent(String agentId, String tenantId, ActionListener<MLAge
534535
}
535536
}
536537

537-
private void handleAgentRetrievalError(Throwable throwable, String agentId, ActionListener<MLAgent> listener) {
538+
@VisibleForTesting
539+
void handleAgentRetrievalError(Throwable throwable, String agentId, ActionListener<MLAgent> listener) {
538540
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
539541
if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) {
540542
log.error("Failed to get Agent index", cause);
@@ -545,7 +547,8 @@ private void handleAgentRetrievalError(Throwable throwable, String agentId, Acti
545547
}
546548
}
547549

548-
private void parseAgentResponse(Object response, String agentId, String tenantId, ActionListener<MLAgent> listener) {
550+
@VisibleForTesting
551+
void parseAgentResponse(Object response, String agentId, String tenantId, ActionListener<MLAgent> listener) {
549552
try {
550553
// Cast to GetDataObjectResponse to access parser method
551554
org.opensearch.remote.metadata.client.GetDataObjectResponse getDataObjectResponse =
@@ -582,7 +585,8 @@ private void parseAgentResponse(Object response, String agentId, String tenantId
582585
}
583586
}
584587

585-
private MLMemorySpec configureMemorySpec(MLAgent mlAgent, AgentMLInput agentMLInput, RemoteInferenceInputDataSet inputDataSet) {
588+
@VisibleForTesting
589+
MLMemorySpec configureMemorySpec(MLAgent mlAgent, AgentMLInput agentMLInput, RemoteInferenceInputDataSet inputDataSet) {
586590
MLMemorySpec memorySpec = mlAgent.getMemory();
587591

588592
log.info("Request parameters keys: {}", inputDataSet.getParameters().keySet());
@@ -605,7 +609,8 @@ private MLMemorySpec configureMemorySpec(MLAgent mlAgent, AgentMLInput agentMLIn
605609
return memorySpec;
606610
}
607611

608-
private MLMemorySpec configureMemoryFromInput(Map<String, Object> memoryMap, RemoteInferenceInputDataSet inputDataSet) {
612+
@VisibleForTesting
613+
MLMemorySpec configureMemoryFromInput(Map<String, Object> memoryMap, RemoteInferenceInputDataSet inputDataSet) {
609614
String memoryType = (String) memoryMap.get("type");
610615
if (memoryType == null)
611616
return null;
@@ -642,7 +647,8 @@ private MLMemorySpec configureMemoryFromInput(Map<String, Object> memoryMap, Rem
642647
return MLMemorySpec.builder().type(memoryType).build();
643648
}
644649

645-
private void handleMemoryCreation(
650+
@VisibleForTesting
651+
void handleMemoryCreation(
646652
MLMemorySpec memorySpec,
647653
AgentMLInput agentMLInput,
648654
String agentId,
@@ -748,7 +754,8 @@ private void handleConversationMemory(
748754
}));
749755
}
750756

751-
private void handleBedrockMemory(
757+
@VisibleForTesting
758+
void handleBedrockMemory(
752759
BedrockAgentCoreMemory.Factory factory,
753760
AgentMLInput agentMLInput,
754761
String agentId,

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 108 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -337,101 +337,118 @@ public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<O
337337
}), messageHistoryLimit);
338338
}, listener::onFailure));
339339
} else if (memoryFactory instanceof BedrockAgentCoreMemory.Factory) {
340-
BedrockAgentCoreMemory.Factory bedrockMemoryFactory = (BedrockAgentCoreMemory.Factory) memoryFactory;
341-
342-
// Build parameters for BedrockAgentCoreMemory from request parameters
343-
Map<String, Object> memoryParams = new HashMap<>();
344-
345-
// Extract memory configuration from parameters passed by MLAgentExecutor
346-
String memoryArn = allParams.get("memory_arn");
347-
String memoryRegion = allParams.get("memory_region");
348-
String accessKey = allParams.get("memory_access_key");
349-
String secretKey = allParams.get("memory_secret_key");
350-
String sessionToken = allParams.get("memory_session_token");
340+
handleBedrockAgentCoreMemory(
341+
(BedrockAgentCoreMemory.Factory) memoryFactory,
342+
mlAgent,
343+
allParams,
344+
memoryId,
345+
masterSessionId,
346+
listener
347+
);
348+
} else {
349+
// For other memory types, skip chat history
350+
log.info("Skipping chat history for memory type: {}", memoryType);
351+
List<String> completedSteps = new ArrayList<>();
352+
setToolsAndRunAgent(mlAgent, allParams, completedSteps, null, memoryId, listener);
353+
}
354+
}
351355

352-
if (memoryArn != null) {
353-
memoryParams.put("memory_arn", memoryArn);
354-
}
355-
if (memoryRegion != null) {
356-
memoryParams.put("region", memoryRegion);
357-
}
356+
@VisibleForTesting
357+
void handleBedrockAgentCoreMemory(
358+
BedrockAgentCoreMemory.Factory bedrockMemoryFactory,
359+
MLAgent mlAgent,
360+
Map<String, String> allParams,
361+
String memoryId,
362+
String masterSessionId,
363+
ActionListener<Object> listener
364+
) {
365+
// Build parameters for BedrockAgentCoreMemory from request parameters
366+
Map<String, Object> memoryParams = new HashMap<>();
367+
368+
// Extract memory configuration from parameters passed by MLAgentExecutor
369+
String memoryArn = allParams.get("memory_arn");
370+
String memoryRegion = allParams.get("memory_region");
371+
String accessKey = allParams.get("memory_access_key");
372+
String secretKey = allParams.get("memory_secret_key");
373+
String sessionToken = allParams.get("memory_session_token");
374+
375+
if (memoryArn != null) {
376+
memoryParams.put("memory_arn", memoryArn);
377+
}
378+
if (memoryRegion != null) {
379+
memoryParams.put("region", memoryRegion);
380+
}
358381

359-
// Use masterSessionId for BedrockAgentCoreMemory to maintain conversation continuity
360-
String sessionIdToUse = masterSessionId != null ? masterSessionId : memoryId;
361-
if (sessionIdToUse != null) {
362-
memoryParams.put("session_id", sessionIdToUse);
363-
log.info("DEBUG: Using session ID for BedrockAgentCoreMemory: {}", sessionIdToUse);
364-
}
382+
// Use masterSessionId for BedrockAgentCoreMemory to maintain conversation continuity
383+
String sessionIdToUse = masterSessionId != null ? masterSessionId : memoryId;
384+
if (sessionIdToUse != null) {
385+
memoryParams.put("session_id", sessionIdToUse);
386+
log.info("DEBUG: Using session ID for BedrockAgentCoreMemory: {}", sessionIdToUse);
387+
}
365388

366-
// Use agent ID from parameters (the actual agent execution ID) as agent_id - MANDATORY
367-
String agentIdToUse = allParams.get("agent_id");
368-
if (agentIdToUse == null) {
369-
throw new IllegalArgumentException(
370-
"Agent ID is mandatory but not found in parameters. This indicates a configuration issue - please check agent setup."
371-
);
372-
}
373-
memoryParams.put("agent_id", agentIdToUse);
374-
log.info("DEBUG: Using mandatory agent ID for BedrockAgentCoreMemory actorId: {}", agentIdToUse);
375-
376-
// Add credentials if available
377-
if (accessKey != null && secretKey != null) {
378-
Map<String, String> credentials = new HashMap<>();
379-
credentials.put("access_key", accessKey);
380-
credentials.put("secret_key", secretKey);
381-
if (sessionToken != null) {
382-
credentials.put("session_token", sessionToken);
383-
}
384-
memoryParams.put("credentials", credentials);
389+
// Use agent ID from parameters (the actual agent execution ID) as agent_id - MANDATORY
390+
String agentIdToUse = allParams.get("agent_id");
391+
if (agentIdToUse == null) {
392+
throw new IllegalArgumentException(
393+
"Agent ID is mandatory but not found in parameters. This indicates a configuration issue - please check agent setup."
394+
);
395+
}
396+
memoryParams.put("agent_id", agentIdToUse);
397+
log.info("DEBUG: Using mandatory agent ID for BedrockAgentCoreMemory actorId: {}", agentIdToUse);
398+
399+
// Add credentials if available
400+
if (accessKey != null && secretKey != null) {
401+
Map<String, String> credentials = new HashMap<>();
402+
credentials.put("access_key", accessKey);
403+
credentials.put("secret_key", secretKey);
404+
if (sessionToken != null) {
405+
credentials.put("session_token", sessionToken);
385406
}
407+
memoryParams.put("credentials", credentials);
408+
}
386409

387-
log.info("Creating BedrockAgentCoreMemory with params: memory_arn={}, region={}", memoryArn, memoryRegion);
410+
log.info("Creating BedrockAgentCoreMemory with params: memory_arn={}, region={}", memoryArn, memoryRegion);
388411

389-
bedrockMemoryFactory.create(memoryParams, ActionListener.wrap(bedrockMemory -> {
390-
// Get conversation history from Bedrock AgentCore using master session ID
391-
String sessionForHistory = masterSessionId != null ? masterSessionId : memoryId;
392-
bedrockMemory.getConversationHistory(sessionForHistory, ActionListener.wrap(records -> {
393-
List<String> completedSteps = new ArrayList<>();
412+
bedrockMemoryFactory.create(memoryParams, ActionListener.wrap(bedrockMemory -> {
413+
// Get conversation history from Bedrock AgentCore using master session ID
414+
String sessionForHistory = masterSessionId != null ? masterSessionId : memoryId;
415+
bedrockMemory.getConversationHistory(sessionForHistory, ActionListener.wrap(records -> {
416+
List<String> completedSteps = new ArrayList<>();
394417

395-
// Convert BedrockAgentCoreMemoryRecords to completed steps format (similar to ConversationIndexMemory)
396-
for (BedrockAgentCoreMemoryRecord record : records) {
397-
if (record != null && record.getContent() != null && record.getResponse() != null) {
398-
completedSteps.add(record.getContent()); // Question
399-
completedSteps.add(record.getResponse()); // Response
400-
}
418+
// Convert BedrockAgentCoreMemoryRecords to completed steps format (similar to ConversationIndexMemory)
419+
for (BedrockAgentCoreMemoryRecord record : records) {
420+
if (record != null && record.getContent() != null && record.getResponse() != null) {
421+
completedSteps.add(record.getContent()); // Question
422+
completedSteps.add(record.getResponse()); // Response
401423
}
424+
}
402425

403-
if (!completedSteps.isEmpty()) {
404-
addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD);
405-
usePlannerWithHistoryPromptTemplate(allParams);
406-
}
426+
if (!completedSteps.isEmpty()) {
427+
addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD);
428+
usePlannerWithHistoryPromptTemplate(allParams);
429+
}
407430

408-
setToolsAndRunAgent(
409-
mlAgent,
410-
allParams,
411-
completedSteps,
412-
bedrockMemory,
413-
masterSessionId != null ? masterSessionId : memoryId,
414-
listener
415-
);
416-
}, e -> {
417-
log.warn("Failed to get conversation history from BedrockAgentCoreMemory, proceeding without history", e);
418-
List<String> completedSteps = new ArrayList<>();
419-
setToolsAndRunAgent(
420-
mlAgent,
421-
allParams,
422-
completedSteps,
423-
bedrockMemory,
424-
masterSessionId != null ? masterSessionId : memoryId,
425-
listener
426-
);
427-
}));
428-
}, listener::onFailure));
429-
} else {
430-
// For other memory types, skip chat history
431-
log.info("Skipping chat history for memory type: {}", memoryType);
432-
List<String> completedSteps = new ArrayList<>();
433-
setToolsAndRunAgent(mlAgent, allParams, completedSteps, null, memoryId, listener);
434-
}
431+
setToolsAndRunAgent(
432+
mlAgent,
433+
allParams,
434+
completedSteps,
435+
bedrockMemory,
436+
masterSessionId != null ? masterSessionId : memoryId,
437+
listener
438+
);
439+
}, e -> {
440+
log.warn("Failed to get conversation history from BedrockAgentCoreMemory, proceeding without history", e);
441+
List<String> completedSteps = new ArrayList<>();
442+
setToolsAndRunAgent(
443+
mlAgent,
444+
allParams,
445+
completedSteps,
446+
bedrockMemory,
447+
masterSessionId != null ? masterSessionId : memoryId,
448+
listener
449+
);
450+
}));
451+
}, listener::onFailure));
435452
}
436453

437454
private void setToolsAndRunAgent(
@@ -993,7 +1010,8 @@ private Map<String, String> setupAllParameters(MLAgent mlAgent, Map<String, Stri
9931010
return allParams;
9941011
}
9951012

996-
private String configureMemoryType(MLAgent mlAgent, Map<String, String> allParams) {
1013+
@VisibleForTesting
1014+
String configureMemoryType(MLAgent mlAgent, Map<String, String> allParams) {
9971015
String memoryType = null;
9981016

9991017
// Get memory type from agent configuration (with null check)
@@ -1021,7 +1039,8 @@ private String configureMemoryType(MLAgent mlAgent, Map<String, String> allParam
10211039
return memoryType;
10221040
}
10231041

1024-
private void cacheBedrockMemoryConfig(MLAgent mlAgent, Map<String, String> allParams) {
1042+
@VisibleForTesting
1043+
void cacheBedrockMemoryConfig(MLAgent mlAgent, Map<String, String> allParams) {
10251044
String cacheKey = mlAgent.getName() + "_bedrock_config";
10261045
Map<String, String> bedrockConfig = new HashMap<>();
10271046
bedrockConfig.put("memory_type", "bedrock_agentcore_memory");
@@ -1034,7 +1053,8 @@ private void cacheBedrockMemoryConfig(MLAgent mlAgent, Map<String, String> allPa
10341053
log.info("DEBUG: Cached BedrockAgentCoreMemory config for agent: {}", mlAgent.getName());
10351054
}
10361055

1037-
private void restoreBedrockMemoryConfig(MLAgent mlAgent, Map<String, String> allParams) {
1056+
@VisibleForTesting
1057+
void restoreBedrockMemoryConfig(MLAgent mlAgent, Map<String, String> allParams) {
10381058
String cacheKey = mlAgent.getName() + "_bedrock_config";
10391059
Map<String, String> cachedConfig = bedrockMemoryConfigCache.get(cacheKey);
10401060

0 commit comments

Comments
 (0)