From 2a4cf2694797eec4f45d531e07ea1d43ca229365 Mon Sep 17 00:00:00 2001 From: Riya Saxena <riysaxen@amazon.com> Date: Wed, 5 Jun 2024 17:12:41 -0700 Subject: [PATCH] addressing the comments Signed-off-by: Riya Saxena <riysaxen@amazon.com> --- .../correlation/JoinEngine.java | 8 +- .../alert/CorrelationAlertService.java | 127 +++++++++++++++++- .../alert/CorrelationRuleScheduler.java | 26 ++-- .../TransportCorrelateFindingAction.java | 9 +- 4 files changed, 148 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index 20a6de813..03d4a0b73 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -21,6 +21,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.alerting.action.PublishFindingsRequest; import org.opensearch.commons.alerting.model.Finding; +import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.BoolQueryBuilder; @@ -80,9 +81,11 @@ public class JoinEngine { private static final Logger log = LogManager.getLogger(JoinEngine.class); + private final User user; + public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry, long corrTimeWindow, TimeValue indexTimeout, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction, - LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService) { + LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService, User user) { this.client = client; this.request = request; this.xContentRegistry = xContentRegistry; @@ -93,6 +96,7 @@ public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRe this.enableAutoCorrelations = enableAutoCorrelations; this.correlationAlertService = correlationAlertService; this.notificationService = notificationService; + this.user = user; } public void onSearchDetectorResponse(Detector detector, Finding finding) { @@ -555,7 +559,7 @@ private void getCorrelatedFindings(String detectorType, Map<String, List<String> if (!correlatedFindings.isEmpty()) { CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService, notificationService); - correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout); + correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout, user); correlationRuleScheduler.shutdown(); } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java index 4bc67b72c..f7aeb4e4d 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java @@ -11,6 +11,10 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; +import org.opensearch.common.lucene.uid.Versions; +import org.opensearch.commons.alerting.model.ActionExecutionResult; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.commons.authuser.User; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; @@ -19,6 +23,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; @@ -37,6 +42,24 @@ public class CorrelationAlertService { private final NamedXContentRegistry xContentRegistry; private final Client client; + protected static final String CORRELATED_FINDING_IDS = "correlated_finding_ids"; + protected static final String CORRELATION_RULE_ID = "correlation_rule_id"; + protected static final String CORRELATION_RULE_NAME = "correlation_rule_name"; + protected static final String ALERT_ID_FIELD = "id"; + protected static final String SCHEMA_VERSION_FIELD = "schema_version"; + protected static final String ALERT_VERSION_FIELD = "version"; + protected static final String USER_FIELD = "user"; + protected static final String TRIGGER_NAME_FIELD = "trigger_name"; + protected static final String STATE_FIELD = "state"; + protected static final String START_TIME_FIELD = "start_time"; + protected static final String END_TIME_FIELD = "end_time"; + protected static final String ACKNOWLEDGED_TIME_FIELD = "acknowledged_time"; + protected static final String ERROR_MESSAGE_FIELD = "error_message"; + protected static final String SEVERITY_FIELD = "severity"; + protected static final String ACTION_EXECUTION_RESULTS_FIELD = "action_execution_results"; + protected static final String NO_ID = ""; + protected static final long NO_VERSION = Versions.NOT_FOUND; + public CorrelationAlertService(Client client, NamedXContentRegistry xContentRegistry) { this.client = client; this.xContentRegistry = xContentRegistry; @@ -72,7 +95,7 @@ public void getActiveAlerts(String ruleId, long currentTime, ActionListener<Corr listener.onResponse(new CorrelationAlertsList(Collections.emptyList(), 0)); } else { listener.onResponse(new CorrelationAlertsList( - Collections.emptyList(), + parseCorrelationAlerts(searchResponse), searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ? (int) searchResponse.getHits().getTotalHits().value : 0) ); @@ -125,12 +148,112 @@ public List<CorrelationAlert> parseCorrelationAlerts(final SearchResponse respon hit.getSourceAsString() ); xcp.nextToken(); - CorrelationAlert correlationAlert = CorrelationAlert.parse(xcp, hit.getId(), hit.getVersion()); + CorrelationAlert correlationAlert = parse(xcp, hit.getId(), hit.getVersion()); alerts.add(correlationAlert); } return alerts; } + + // logic will be moved to common-utils, once the parsing logic in common-utils is fixed + public static CorrelationAlert parse(XContentParser xcp, String id, long version) throws IOException { + // Parse additional CorrelationAlert-specific fields + List<String> correlatedFindingIds = new ArrayList<>(); + String correlationRuleId = null; + String correlationRuleName = null; + User user = null; + int schemaVersion = 0; + String triggerName = null; + Alert.State state = null; + String errorMessage = null; + String severity = null; + List<ActionExecutionResult> actionExecutionResults = new ArrayList<>(); + Instant startTime = null; + Instant endTime = null; + Instant acknowledgedTime = null; + + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + switch (fieldName) { + case CORRELATED_FINDING_IDS: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + correlatedFindingIds.add(xcp.text()); + } + break; + case CORRELATION_RULE_ID: + correlationRuleId = xcp.text(); + break; + case CORRELATION_RULE_NAME: + correlationRuleName = xcp.text(); + break; + case USER_FIELD: + user = (xcp.currentToken() == XContentParser.Token.VALUE_NULL) ? null : User.parse(xcp); + break; + case ALERT_ID_FIELD: + id = xcp.text(); + break; + case ALERT_VERSION_FIELD: + version = xcp.longValue(); + break; + case SCHEMA_VERSION_FIELD: + schemaVersion = xcp.intValue(); + break; + case TRIGGER_NAME_FIELD: + triggerName = xcp.text(); + break; + case STATE_FIELD: + state = Alert.State.valueOf(xcp.text()); + break; + case ERROR_MESSAGE_FIELD: + errorMessage = xcp.textOrNull(); + break; + case SEVERITY_FIELD: + severity = xcp.text(); + break; + case ACTION_EXECUTION_RESULTS_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + actionExecutionResults.add(ActionExecutionResult.parse(xcp)); + } + break; + case START_TIME_FIELD: + startTime = Instant.parse(xcp.text()); + break; + case END_TIME_FIELD: + endTime = Instant.parse(xcp.text()); + break; + case ACKNOWLEDGED_TIME_FIELD: + if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { + acknowledgedTime = null; + } else { + acknowledgedTime = Instant.parse(xcp.text()); + } + break; + } + } + + // Create and return CorrelationAlert object + return new CorrelationAlert( + correlatedFindingIds, + correlationRuleId, + correlationRuleName, + id, + version, + schemaVersion, + user, + triggerName, + state, + startTime, + endTime, + acknowledgedTime, + errorMessage, + severity, + actionExecutionResults + ); + } } + diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java index bc0ca1809..945407b15 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java @@ -7,13 +7,13 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.commons.alerting.model.Alert; import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; import org.opensearch.securityanalytics.correlation.alert.notifications.CorrelationAlertContext; -import org.opensearch.client.node.NodeClient; import org.opensearch.commons.alerting.model.action.Action; import org.opensearch.core.rest.RestStatus; import org.opensearch.securityanalytics.util.SecurityAnalyticsException; @@ -24,7 +24,6 @@ import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import org.opensearch.script.ScriptService; public class CorrelationRuleScheduler { @@ -33,17 +32,15 @@ public class CorrelationRuleScheduler { private final CorrelationAlertService correlationAlertService; private final NotificationService notificationService; private final ExecutorService executorService; - private static ScriptService scriptService; public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService, NotificationService notificationService) { this.client = client; - this.scriptService = scriptService; this.correlationAlertService = correlationAlertService; this.notificationService = notificationService; this.executorService = Executors.newCachedThreadPool(); } - public void schedule(List<CorrelationRule> correlationRules, Map<String, List<String>> correlatedFindings, String sourceFinding, TimeValue indexTimeout) { + public void schedule(List<CorrelationRule> correlationRules, Map<String, List<String>> correlatedFindings, String sourceFinding, TimeValue indexTimeout, User user) { for (CorrelationRule rule : correlationRules) { CorrelationRuleTrigger trigger = rule.getCorrelationTrigger(); if (trigger != null) { @@ -54,7 +51,7 @@ public void schedule(List<CorrelationRule> correlationRules, Map<String, List<St findingIds.addAll(categoryFindingIds); } } - scheduleRule(rule, findingIds, indexTimeout, sourceFinding); + scheduleRule(rule, findingIds, indexTimeout, sourceFinding, user); } } } @@ -63,10 +60,10 @@ public void shutdown() { executorService.shutdown(); } - private void scheduleRule(CorrelationRule correlationRule, List<String> findingIds, TimeValue indexTimeout, String sourceFindingId) { + private void scheduleRule(CorrelationRule correlationRule, List<String> findingIds, TimeValue indexTimeout, String sourceFindingId, User user) { long startTime = Instant.now().toEpochMilli(); long endTime = startTime + correlationRule.getCorrTimeWindow(); - RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId); + RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId, user); executorService.submit(ruleTask); } @@ -79,8 +76,9 @@ private class RuleTask implements Runnable { private final NotificationService notificationService; private final TimeValue indexTimeout; private final String sourceFindingId; + private final User user; - public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId) { + public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId, User user) { this.correlationRule = correlationRule; this.correlatedFindingIds = correlatedFindingIds; this.startTime = startTime; @@ -89,6 +87,7 @@ public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingI this.notificationService = notificationService; this.indexTimeout = indexTimeout; this.sourceFindingId = sourceFindingId; + this.user = user; } @Override @@ -103,13 +102,14 @@ public void onResponse(CorrelationAlertsList correlationAlertsList) { addCorrelationAlertIntoIndex(); List<Action> actions = correlationRule.getCorrelationTrigger().getActions(); for (Action action : actions) { + String configId = action.getDestinationId(); CorrelationAlertContext ctx = new CorrelationAlertContext(correlatedFindingIds, correlationRule.getName(), correlationRule.getCorrTimeWindow(), sourceFindingId); - String transfomedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate()); + String transformedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate()); String transformedMessage = notificationService.compileTemplate(ctx, action.getMessageTemplate()); try { - notificationService.sendNotification(action.getDestinationId(), correlationRule.getCorrelationTrigger().getSeverity(), transfomedSubject, transformedMessage); + notificationService.sendNotification(configId, correlationRule.getCorrelationTrigger().getSeverity(), transformedSubject, transformedMessage); } catch (Exception e) { - log.error("Failed while sending a notification: " + e.toString()); + log.error("Failed while sending a notification with " + configId + "for correlationRule id " + correlationRule.getId(), e); new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e); } @@ -142,7 +142,7 @@ private void addCorrelationAlertIntoIndex() { UUID.randomUUID().toString(), 1L, 1, - null, + user, correlationRule.getCorrelationTrigger().getName(), Alert.State.ACTIVE, Instant.ofEpochMilli(startTime), diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index 26c6b0e5b..e84d8b3e9 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -35,6 +35,7 @@ import org.opensearch.commons.alerting.action.PublishFindingsRequest; import org.opensearch.commons.alerting.action.SubscribeFindingsResponse; import org.opensearch.commons.alerting.action.AlertingActions; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -140,7 +141,7 @@ public TransportCorrelateFindingAction(TransportService transportService, protected void doExecute(Task task, ActionRequest request, ActionListener<SubscribeFindingsResponse> actionListener) { try { PublishFindingsRequest transformedRequest = transformRequest(request); - AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); + AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, readUserFromThreadContext(this.threadPool), actionListener); if (!this.correlationIndices.correlationIndexExists()) { try { @@ -213,14 +214,12 @@ public class AsyncCorrelateFindingAction { private final AtomicBoolean counter = new AtomicBoolean(); private final Task task; - AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, ActionListener<SubscribeFindingsResponse> listener) { + AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, User user, ActionListener<SubscribeFindingsResponse> listener) { this.task = task; this.request = request; this.listener = listener; - this.response =new AtomicReference<>(); - - this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService, notificationService); + this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService, notificationService, user); this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this); }