From 2a4cf2694797eec4f45d531e07ea1d43ca229365 Mon Sep 17 00:00:00 2001
From: Riya Saxena <riysaxen@amazon.com>
Date: Wed, 5 Jun 2024 17:12:41 -0700
Subject: [PATCH] addressing the comments

Signed-off-by: Riya Saxena <riysaxen@amazon.com>
---
 .../correlation/JoinEngine.java               |   8 +-
 .../alert/CorrelationAlertService.java        | 127 +++++++++++++++++-
 .../alert/CorrelationRuleScheduler.java       |  26 ++--
 .../TransportCorrelateFindingAction.java      |   9 +-
 4 files changed, 148 insertions(+), 22 deletions(-)

diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java
index 20a6de813..03d4a0b73 100644
--- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java
+++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java
@@ -21,6 +21,7 @@
 import org.opensearch.common.xcontent.XContentType;
 import org.opensearch.commons.alerting.action.PublishFindingsRequest;
 import org.opensearch.commons.alerting.model.Finding;
+import org.opensearch.commons.authuser.User;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.index.query.BoolQueryBuilder;
@@ -80,9 +81,11 @@ public class JoinEngine {
 
     private static final Logger log = LogManager.getLogger(JoinEngine.class);
 
+    private final User user;
+
     public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry,
                       long corrTimeWindow, TimeValue indexTimeout, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction,
-                      LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService) {
+                      LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService, User user) {
         this.client = client;
         this.request = request;
         this.xContentRegistry = xContentRegistry;
@@ -93,6 +96,7 @@ public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRe
         this.enableAutoCorrelations = enableAutoCorrelations;
         this.correlationAlertService = correlationAlertService;
         this.notificationService = notificationService;
+        this.user = user;
     }
 
     public void onSearchDetectorResponse(Detector detector, Finding finding) {
@@ -555,7 +559,7 @@ private void getCorrelatedFindings(String detectorType, Map<String, List<String>
 
                 if (!correlatedFindings.isEmpty()) {
                      CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService, notificationService);
-                     correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout);
+                     correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout, user);
                      correlationRuleScheduler.shutdown();
                 }
 
diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java
index 4bc67b72c..f7aeb4e4d 100644
--- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java
+++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java
@@ -11,6 +11,10 @@
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
 import org.opensearch.client.Client;
+import org.opensearch.common.lucene.uid.Versions;
+import org.opensearch.commons.alerting.model.ActionExecutionResult;
+import org.opensearch.commons.alerting.model.Alert;
+import org.opensearch.commons.authuser.User;
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
 import org.opensearch.common.xcontent.XContentFactory;
@@ -19,6 +23,7 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.core.xcontent.XContentParserUtils;
 import org.opensearch.index.query.BoolQueryBuilder;
 import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.search.SearchHit;
@@ -37,6 +42,24 @@ public class CorrelationAlertService {
     private final NamedXContentRegistry xContentRegistry;
     private final Client client;
 
+    protected static final String CORRELATED_FINDING_IDS = "correlated_finding_ids";
+    protected static final String CORRELATION_RULE_ID = "correlation_rule_id";
+    protected static final String CORRELATION_RULE_NAME = "correlation_rule_name";
+    protected static final String ALERT_ID_FIELD = "id";
+    protected static final String SCHEMA_VERSION_FIELD = "schema_version";
+    protected static final String ALERT_VERSION_FIELD = "version";
+    protected static final String USER_FIELD = "user";
+    protected static final String TRIGGER_NAME_FIELD = "trigger_name";
+    protected static final String STATE_FIELD = "state";
+    protected static final String START_TIME_FIELD = "start_time";
+    protected static final String END_TIME_FIELD = "end_time";
+    protected static final String ACKNOWLEDGED_TIME_FIELD = "acknowledged_time";
+    protected static final String ERROR_MESSAGE_FIELD = "error_message";
+    protected static final String SEVERITY_FIELD = "severity";
+    protected static final String ACTION_EXECUTION_RESULTS_FIELD = "action_execution_results";
+    protected static final String NO_ID = "";
+    protected static final long NO_VERSION = Versions.NOT_FOUND;
+
     public CorrelationAlertService(Client client, NamedXContentRegistry xContentRegistry) {
         this.client = client;
         this.xContentRegistry = xContentRegistry;
@@ -72,7 +95,7 @@ public void getActiveAlerts(String ruleId, long currentTime, ActionListener<Corr
                         listener.onResponse(new CorrelationAlertsList(Collections.emptyList(), 0));
                     } else {
                         listener.onResponse(new CorrelationAlertsList(
-                                Collections.emptyList(),
+                                parseCorrelationAlerts(searchResponse),
                                 searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ?
                                         (int) searchResponse.getHits().getTotalHits().value : 0)
                         );
@@ -125,12 +148,112 @@ public List<CorrelationAlert> parseCorrelationAlerts(final SearchResponse respon
                     hit.getSourceAsString()
             );
             xcp.nextToken();
-            CorrelationAlert correlationAlert = CorrelationAlert.parse(xcp, hit.getId(), hit.getVersion());
+            CorrelationAlert correlationAlert = parse(xcp, hit.getId(), hit.getVersion());
             alerts.add(correlationAlert);
         }
         return alerts;
     }
+
+    // logic will be moved to common-utils, once the parsing logic in common-utils is fixed
+    public static CorrelationAlert parse(XContentParser xcp, String id, long version) throws IOException {
+        // Parse additional CorrelationAlert-specific fields
+        List<String> correlatedFindingIds = new ArrayList<>();
+        String correlationRuleId = null;
+        String correlationRuleName = null;
+        User user = null;
+        int schemaVersion = 0;
+        String triggerName = null;
+        Alert.State state = null;
+        String errorMessage = null;
+        String severity = null;
+        List<ActionExecutionResult> actionExecutionResults = new ArrayList<>();
+        Instant startTime = null;
+        Instant endTime = null;
+        Instant acknowledgedTime = null;
+
+        while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
+            String fieldName = xcp.currentName();
+            xcp.nextToken();
+            switch (fieldName) {
+                case CORRELATED_FINDING_IDS:
+                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
+                    while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
+                        correlatedFindingIds.add(xcp.text());
+                    }
+                    break;
+                case CORRELATION_RULE_ID:
+                    correlationRuleId = xcp.text();
+                    break;
+                case CORRELATION_RULE_NAME:
+                    correlationRuleName = xcp.text();
+                    break;
+                case USER_FIELD:
+                    user = (xcp.currentToken() == XContentParser.Token.VALUE_NULL) ? null : User.parse(xcp);
+                    break;
+                case ALERT_ID_FIELD:
+                    id = xcp.text();
+                    break;
+                case ALERT_VERSION_FIELD:
+                    version = xcp.longValue();
+                    break;
+                case SCHEMA_VERSION_FIELD:
+                    schemaVersion = xcp.intValue();
+                    break;
+                case TRIGGER_NAME_FIELD:
+                    triggerName = xcp.text();
+                    break;
+                case STATE_FIELD:
+                    state = Alert.State.valueOf(xcp.text());
+                    break;
+                case ERROR_MESSAGE_FIELD:
+                    errorMessage = xcp.textOrNull();
+                    break;
+                case SEVERITY_FIELD:
+                    severity = xcp.text();
+                    break;
+                case ACTION_EXECUTION_RESULTS_FIELD:
+                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
+                    while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
+                        actionExecutionResults.add(ActionExecutionResult.parse(xcp));
+                    }
+                    break;
+                case START_TIME_FIELD:
+                    startTime = Instant.parse(xcp.text());
+                    break;
+                case END_TIME_FIELD:
+                    endTime = Instant.parse(xcp.text());
+                    break;
+                case ACKNOWLEDGED_TIME_FIELD:
+                    if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) {
+                        acknowledgedTime = null;
+                    } else {
+                        acknowledgedTime = Instant.parse(xcp.text());
+                    }
+                    break;
+            }
+        }
+
+            // Create and return CorrelationAlert object
+            return new CorrelationAlert(
+                    correlatedFindingIds,
+                    correlationRuleId,
+                    correlationRuleName,
+                    id,
+                    version,
+                    schemaVersion,
+                    user,
+                    triggerName,
+                    state,
+                    startTime,
+                    endTime,
+                    acknowledgedTime,
+                    errorMessage,
+                    severity,
+                    actionExecutionResults
+            );
+    }
 }
 
 
 
+
diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java
index bc0ca1809..945407b15 100644
--- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java
+++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java
@@ -7,13 +7,13 @@
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.commons.alerting.model.Alert;
 import org.opensearch.commons.alerting.model.CorrelationAlert;
+import org.opensearch.commons.authuser.User;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.securityanalytics.model.CorrelationQuery;
 import org.opensearch.securityanalytics.model.CorrelationRule;
 import org.opensearch.securityanalytics.model.CorrelationRuleTrigger;
 import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService;
 import org.opensearch.securityanalytics.correlation.alert.notifications.CorrelationAlertContext;
-import org.opensearch.client.node.NodeClient;
 import org.opensearch.commons.alerting.model.action.Action;
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.securityanalytics.util.SecurityAnalyticsException;
@@ -24,7 +24,6 @@
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import org.opensearch.script.ScriptService;
 
 public class CorrelationRuleScheduler {
 
@@ -33,17 +32,15 @@ public class CorrelationRuleScheduler {
     private final CorrelationAlertService correlationAlertService;
     private final NotificationService notificationService;
     private final ExecutorService executorService;
-    private static ScriptService scriptService;
 
     public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService, NotificationService notificationService) {
         this.client = client;
-        this.scriptService = scriptService;
         this.correlationAlertService = correlationAlertService;
         this.notificationService = notificationService;
         this.executorService = Executors.newCachedThreadPool();
     }
 
-    public void schedule(List<CorrelationRule> correlationRules, Map<String, List<String>> correlatedFindings, String sourceFinding, TimeValue indexTimeout) {
+    public void schedule(List<CorrelationRule> correlationRules, Map<String, List<String>> correlatedFindings, String sourceFinding, TimeValue indexTimeout, User user) {
         for (CorrelationRule rule : correlationRules) {
             CorrelationRuleTrigger trigger = rule.getCorrelationTrigger();
             if (trigger != null) {
@@ -54,7 +51,7 @@ public void schedule(List<CorrelationRule> correlationRules, Map<String, List<St
                         findingIds.addAll(categoryFindingIds);
                     }
                 }
-                scheduleRule(rule, findingIds, indexTimeout, sourceFinding);
+                scheduleRule(rule, findingIds, indexTimeout, sourceFinding, user);
             }
         }
     }
@@ -63,10 +60,10 @@ public void shutdown() {
         executorService.shutdown();
     }
 
-    private void scheduleRule(CorrelationRule correlationRule, List<String> findingIds, TimeValue indexTimeout, String sourceFindingId) {
+    private void scheduleRule(CorrelationRule correlationRule, List<String> findingIds, TimeValue indexTimeout, String sourceFindingId, User user) {
         long startTime = Instant.now().toEpochMilli();
         long endTime = startTime + correlationRule.getCorrTimeWindow();
-        RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId);
+        RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId, user);
         executorService.submit(ruleTask);
     }
 
@@ -79,8 +76,9 @@ private class RuleTask implements Runnable {
         private final NotificationService notificationService;
         private final TimeValue indexTimeout;
         private final String sourceFindingId;
+        private final User user;
 
-        public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId) {
+        public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId, User user) {
             this.correlationRule = correlationRule;
             this.correlatedFindingIds = correlatedFindingIds;
             this.startTime = startTime;
@@ -89,6 +87,7 @@ public RuleTask(CorrelationRule correlationRule, List<String> correlatedFindingI
             this.notificationService = notificationService;
             this.indexTimeout = indexTimeout;
             this.sourceFindingId = sourceFindingId;
+            this.user = user;
         }
 
         @Override
@@ -103,13 +102,14 @@ public void onResponse(CorrelationAlertsList correlationAlertsList) {
                                 addCorrelationAlertIntoIndex();
                                 List<Action> actions = correlationRule.getCorrelationTrigger().getActions();
                                 for (Action action : actions) {
+                                    String configId = action.getDestinationId();
                                     CorrelationAlertContext ctx = new CorrelationAlertContext(correlatedFindingIds, correlationRule.getName(), correlationRule.getCorrTimeWindow(), sourceFindingId);
-                                    String transfomedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate());
+                                    String transformedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate());
                                     String transformedMessage = notificationService.compileTemplate(ctx, action.getMessageTemplate());
                                     try {
-                                        notificationService.sendNotification(action.getDestinationId(), correlationRule.getCorrelationTrigger().getSeverity(), transfomedSubject, transformedMessage);
+                                        notificationService.sendNotification(configId, correlationRule.getCorrelationTrigger().getSeverity(), transformedSubject, transformedMessage);
                                     } catch (Exception e) {
-                                        log.error("Failed while sending a notification: " + e.toString());
+                                        log.error("Failed while sending a notification with " + configId + "for correlationRule id " + correlationRule.getId(), e);
                                         new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e);
                                     }
 
@@ -142,7 +142,7 @@ private void addCorrelationAlertIntoIndex() {
                     UUID.randomUUID().toString(),
                     1L,
                     1,
-                    null,
+                    user,
                     correlationRule.getCorrelationTrigger().getName(),
                     Alert.State.ACTIVE,
                     Instant.ofEpochMilli(startTime),
diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java
index 26c6b0e5b..e84d8b3e9 100644
--- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java
+++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java
@@ -35,6 +35,7 @@
 import org.opensearch.commons.alerting.action.PublishFindingsRequest;
 import org.opensearch.commons.alerting.action.SubscribeFindingsResponse;
 import org.opensearch.commons.alerting.action.AlertingActions;
+import org.opensearch.commons.authuser.User;
 import org.opensearch.core.common.io.stream.InputStreamStreamInput;
 import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -140,7 +141,7 @@ public TransportCorrelateFindingAction(TransportService transportService,
     protected void doExecute(Task task, ActionRequest request, ActionListener<SubscribeFindingsResponse> actionListener) {
         try {
             PublishFindingsRequest transformedRequest = transformRequest(request);
-            AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener);
+            AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, readUserFromThreadContext(this.threadPool), actionListener);
 
             if (!this.correlationIndices.correlationIndexExists()) {
                 try {
@@ -213,14 +214,12 @@ public class AsyncCorrelateFindingAction {
         private final AtomicBoolean counter = new AtomicBoolean();
         private final Task task;
 
-        AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, ActionListener<SubscribeFindingsResponse> listener) {
+        AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, User user, ActionListener<SubscribeFindingsResponse> listener) {
             this.task = task;
             this.request = request;
             this.listener = listener;
-
             this.response =new AtomicReference<>();
-
-            this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService, notificationService);
+            this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService, notificationService, user);
             this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this);
         }