diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index e7fe43106..f97afcb60 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -17,6 +17,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.core.action.ActionResponse; import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodes; @@ -51,8 +52,33 @@ import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; -import org.opensearch.securityanalytics.action.*; +import org.opensearch.securityanalytics.action.GetAlertsAction; +import org.opensearch.securityanalytics.action.DeleteCorrelationRuleAction; +import org.opensearch.securityanalytics.action.AckAlertsAction; +import org.opensearch.securityanalytics.action.CreateIndexMappingsAction; +import org.opensearch.securityanalytics.action.CorrelatedFindingAction; +import org.opensearch.securityanalytics.action.DeleteCustomLogTypeAction; +import org.opensearch.securityanalytics.action.DeleteDetectorAction; +import org.opensearch.securityanalytics.action.DeleteRuleAction; +import org.opensearch.securityanalytics.action.GetAllRuleCategoriesAction; +import org.opensearch.securityanalytics.action.GetDetectorAction; +import org.opensearch.securityanalytics.action.GetFindingsAction; +import org.opensearch.securityanalytics.action.GetIndexMappingsAction; +import org.opensearch.securityanalytics.action.GetMappingsViewAction; +import org.opensearch.securityanalytics.action.IndexCorrelationRuleAction; +import org.opensearch.securityanalytics.action.IndexCustomLogTypeAction; +import org.opensearch.securityanalytics.action.IndexDetectorAction; +import org.opensearch.securityanalytics.action.IndexRuleAction; +import org.opensearch.securityanalytics.action.ListCorrelationsAction; +import org.opensearch.securityanalytics.action.SearchCorrelationRuleAction; +import org.opensearch.securityanalytics.action.SearchCustomLogTypeAction; +import org.opensearch.securityanalytics.action.SearchDetectorAction; +import org.opensearch.securityanalytics.action.SearchRuleAction; +import org.opensearch.securityanalytics.action.UpdateIndexMappingsAction; +import org.opensearch.securityanalytics.action.ValidateRulesAction; import org.opensearch.securityanalytics.correlation.index.codec.CorrelationCodecService; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; import org.opensearch.securityanalytics.correlation.index.mapper.CorrelationVectorFieldMapper; import org.opensearch.securityanalytics.correlation.index.query.CorrelationQueryBuilder; import org.opensearch.securityanalytics.indexmanagment.DetectorIndexManagementService; @@ -165,13 +191,14 @@ public Collection createComponents(Client client, TIFJobParameterService tifJobParameterService = new TIFJobParameterService(client, clusterService); TIFJobUpdateService tifJobUpdateService = new TIFJobUpdateService(clusterService, tifJobParameterService, threatIntelFeedDataService, builtInTIFMetadataLoader); TIFLockService threatIntelLockService = new TIFLockService(clusterService, client); - + CorrelationAlertService correlationAlertService = new CorrelationAlertService(client, xContentRegistry); + NotificationService notificationServiceService = new NotificationService((NodeClient)client, scriptService); TIFJobRunner.getJobRunnerInstance().initialize(clusterService, tifJobUpdateService, tifJobParameterService, threatIntelLockService, threadPool, detectorThreatIntelService); return List.of( detectorIndices, correlationIndices, correlationRuleIndices, ruleTopicIndices, customLogTypeIndices, ruleIndices, mapperService, indexTemplateManager, builtinLogTypeLoader, builtInTIFMetadataLoader, threatIntelFeedDataService, detectorThreatIntelService, - tifJobUpdateService, tifJobParameterService, threatIntelLockService); + tifJobUpdateService, tifJobParameterService, threatIntelLockService, correlationAlertService, notificationServiceService); } @Override diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index 3b4314e12..20cff273a 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -21,6 +21,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.alerting.action.PublishFindingsRequest; import org.opensearch.commons.alerting.model.Finding; +import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.BoolQueryBuilder; @@ -32,9 +33,13 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.config.monitors.DetectorMonitorConfig; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.correlation.alert.CorrelationRuleScheduler; +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; import org.opensearch.securityanalytics.logtype.LogTypeService; import org.opensearch.securityanalytics.model.CorrelationQuery; import org.opensearch.securityanalytics.model.CorrelationRule; +import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.transport.TransportCorrelateFindingAction; import org.opensearch.securityanalytics.util.AutoCorrelationsRepo; @@ -68,18 +73,30 @@ public class JoinEngine { private final LogTypeService logTypeService; + private final CorrelationAlertService correlationAlertService; + + private final NotificationService notificationService; + + private volatile TimeValue indexTimeout; + private static final Logger log = LogManager.getLogger(JoinEngine.class); + private final User user; + public JoinEngine(Client client, PublishFindingsRequest request, NamedXContentRegistry xContentRegistry, - long corrTimeWindow, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction, - LogTypeService logTypeService, boolean enableAutoCorrelations) { + long corrTimeWindow, TimeValue indexTimeout, TransportCorrelateFindingAction.AsyncCorrelateFindingAction correlateFindingAction, + LogTypeService logTypeService, boolean enableAutoCorrelations, CorrelationAlertService correlationAlertService, NotificationService notificationService, User user) { this.client = client; this.request = request; this.xContentRegistry = xContentRegistry; this.corrTimeWindow = corrTimeWindow; + this.indexTimeout = indexTimeout; this.correlateFindingAction = correlateFindingAction; this.logTypeService = logTypeService; this.enableAutoCorrelations = enableAutoCorrelations; + this.correlationAlertService = correlationAlertService; + this.notificationService = notificationService; + this.user = user; } public void onSearchDetectorResponse(Detector detector, Finding finding) { @@ -349,7 +366,7 @@ private void getValidDocuments(String detectorType, List indices, List it.correlationRule).map(CorrelationRule::getId).collect(Collectors.toList()), + filteredCorrelationRules.stream().map(it -> it.correlationRule).collect(Collectors.toList()), autoCorrelations ); }, this::onFailure)); @@ -362,7 +379,7 @@ private void getValidDocuments(String detectorType, List indices, List> categoryToQueriesMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void searchFindingsByTimestamp(String detectorType, Map> categoryToQueriesMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List>> categoryToQueriesPairs = new ArrayList<>(); @@ -418,14 +435,14 @@ private void searchFindingsByTimestamp(String detectorType, Map relatedDocsMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void searchDocsWithFilterKeys(String detectorType, Map relatedDocsMap, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List categories = new ArrayList<>(); @@ -476,7 +493,7 @@ private void searchDocsWithFilterKeys(String detectorType, Map> filteredRelatedDocIds, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { + private void getCorrelatedFindings(String detectorType, Map> filteredRelatedDocIds, Map categoryToTimeWindowMap, List correlationRules, Map> autoCorrelations) { long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli(); MultiSearchRequest mSearchRequest = new MultiSearchRequest(); List categories = new ArrayList<>(); @@ -540,6 +557,11 @@ private void getCorrelatedFindings(String detectorType, Map ++idx; } + if (!correlatedFindings.isEmpty()) { + CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService, notificationService); + correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout, user); + } + for (Map.Entry> autoCorrelation: autoCorrelations.entrySet()) { if (correlatedFindings.containsKey(autoCorrelation.getKey())) { Set alreadyCorrelatedFindings = new HashSet<>(correlatedFindings.get(autoCorrelation.getKey())); @@ -549,10 +571,10 @@ private void getCorrelatedFindings(String detectorType, Map correlatedFindings.put(autoCorrelation.getKey(), autoCorrelation.getValue()); } } - correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules); + correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList())); }, this::onFailure)); } else { - getTimestampFeature(detectorType, correlationRules, autoCorrelations); + getTimestampFeature(detectorType, correlationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()), autoCorrelations); } } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java new file mode 100644 index 000000000..f7aeb4e4d --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java @@ -0,0 +1,259 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.correlation.alert; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.lucene.uid.Versions; +import org.opensearch.commons.alerting.model.ActionExecutionResult; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.commons.authuser.User; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.securityanalytics.util.CorrelationIndices; +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.ArrayList; +import java.util.Collections; + +public class CorrelationAlertService { + private static final Logger log = LogManager.getLogger(CorrelationAlertService.class); + + private final NamedXContentRegistry xContentRegistry; + private final Client client; + + protected static final String CORRELATED_FINDING_IDS = "correlated_finding_ids"; + protected static final String CORRELATION_RULE_ID = "correlation_rule_id"; + protected static final String CORRELATION_RULE_NAME = "correlation_rule_name"; + protected static final String ALERT_ID_FIELD = "id"; + protected static final String SCHEMA_VERSION_FIELD = "schema_version"; + protected static final String ALERT_VERSION_FIELD = "version"; + protected static final String USER_FIELD = "user"; + protected static final String TRIGGER_NAME_FIELD = "trigger_name"; + protected static final String STATE_FIELD = "state"; + protected static final String START_TIME_FIELD = "start_time"; + protected static final String END_TIME_FIELD = "end_time"; + protected static final String ACKNOWLEDGED_TIME_FIELD = "acknowledged_time"; + protected static final String ERROR_MESSAGE_FIELD = "error_message"; + protected static final String SEVERITY_FIELD = "severity"; + protected static final String ACTION_EXECUTION_RESULTS_FIELD = "action_execution_results"; + protected static final String NO_ID = ""; + protected static final long NO_VERSION = Versions.NOT_FOUND; + + public CorrelationAlertService(Client client, NamedXContentRegistry xContentRegistry) { + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + /** + * Searches for active Alerts in the correlation alerts index within a specified time range. + * + * @param ruleId The correlation rule ID to filter the alerts + * @param currentTime The current time of the search range + * @return The search response containing active alerts + */ + public void getActiveAlerts(String ruleId, long currentTime, ActionListener listener) { + Instant currentTimeDate = Instant.ofEpochMilli(currentTime); + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("correlation_rule_id", ruleId)) + .must(QueryBuilders.rangeQuery("start_time").lte(currentTimeDate)) + .must(QueryBuilders.rangeQuery("end_time").gte(currentTimeDate)) + .must(QueryBuilders.termQuery("state", "ACTIVE")); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .seqNoAndPrimaryTerm(true) + .version(true) + .size(10000) // set the size to 10,000 + .query(queryBuilder); + + SearchRequest searchRequest = new SearchRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getTotalHits().equals(0)) { + listener.onResponse(new CorrelationAlertsList(Collections.emptyList(), 0)); + } else { + listener.onResponse(new CorrelationAlertsList( + parseCorrelationAlerts(searchResponse), + searchResponse.getHits() != null && searchResponse.getHits().getTotalHits() != null ? + (int) searchResponse.getHits().getTotalHits().value : 0) + ); + } + }, + e -> { + log.error("Search request to fetch correlation alerts failed", e); + listener.onFailure(e); + } + )); + } + + public void indexCorrelationAlert(CorrelationAlert correlationAlert, TimeValue indexTimeout, ActionListener listener) { + // Convert CorrelationAlert to a map + try { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("correlated_finding_ids", correlationAlert.getCorrelatedFindingIds()); + builder.field("correlation_rule_id", correlationAlert.getCorrelationRuleId()); + builder.field("correlation_rule_name", correlationAlert.getCorrelationRuleName()); + builder.field("id", correlationAlert.getId()); + builder.field("user", correlationAlert.getUser()); // Convert User object to map + builder.field("schema_version", correlationAlert.getSchemaVersion()); + builder.field("severity", correlationAlert.getSeverity()); + builder.field("state", correlationAlert.getState()); + builder.field("trigger_name", correlationAlert.getTriggerName()); + builder.field("version", correlationAlert.getVersion()); + builder.field("start_time", correlationAlert.getStartTime()); + builder.field("end_time", correlationAlert.getEndTime()); + builder.field("action_execution_results", correlationAlert.getActionExecutionResults()); + builder.field("error_message", correlationAlert.getErrorMessage()); + builder.field("acknowledged_time", correlationAlert.getAcknowledgedTime()); + builder.endObject(); + IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .id(correlationAlert.getId()) + .source(builder) + .timeout(indexTimeout); + + client.index(indexRequest, listener); + } catch (IOException ex) { + log.error("Exception while adding alerts in .opensearch-sap-correlation-alerts index", ex); + } + } + + public List parseCorrelationAlerts(final SearchResponse response) throws IOException { + List alerts = new ArrayList<>(); + for (SearchHit hit : response.getHits()) { + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceAsString() + ); + xcp.nextToken(); + CorrelationAlert correlationAlert = parse(xcp, hit.getId(), hit.getVersion()); + alerts.add(correlationAlert); + } + return alerts; + } + + // logic will be moved to common-utils, once the parsing logic in common-utils is fixed + public static CorrelationAlert parse(XContentParser xcp, String id, long version) throws IOException { + // Parse additional CorrelationAlert-specific fields + List correlatedFindingIds = new ArrayList<>(); + String correlationRuleId = null; + String correlationRuleName = null; + User user = null; + int schemaVersion = 0; + String triggerName = null; + Alert.State state = null; + String errorMessage = null; + String severity = null; + List actionExecutionResults = new ArrayList<>(); + Instant startTime = null; + Instant endTime = null; + Instant acknowledgedTime = null; + + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + switch (fieldName) { + case CORRELATED_FINDING_IDS: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + correlatedFindingIds.add(xcp.text()); + } + break; + case CORRELATION_RULE_ID: + correlationRuleId = xcp.text(); + break; + case CORRELATION_RULE_NAME: + correlationRuleName = xcp.text(); + break; + case USER_FIELD: + user = (xcp.currentToken() == XContentParser.Token.VALUE_NULL) ? null : User.parse(xcp); + break; + case ALERT_ID_FIELD: + id = xcp.text(); + break; + case ALERT_VERSION_FIELD: + version = xcp.longValue(); + break; + case SCHEMA_VERSION_FIELD: + schemaVersion = xcp.intValue(); + break; + case TRIGGER_NAME_FIELD: + triggerName = xcp.text(); + break; + case STATE_FIELD: + state = Alert.State.valueOf(xcp.text()); + break; + case ERROR_MESSAGE_FIELD: + errorMessage = xcp.textOrNull(); + break; + case SEVERITY_FIELD: + severity = xcp.text(); + break; + case ACTION_EXECUTION_RESULTS_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + actionExecutionResults.add(ActionExecutionResult.parse(xcp)); + } + break; + case START_TIME_FIELD: + startTime = Instant.parse(xcp.text()); + break; + case END_TIME_FIELD: + endTime = Instant.parse(xcp.text()); + break; + case ACKNOWLEDGED_TIME_FIELD: + if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { + acknowledgedTime = null; + } else { + acknowledgedTime = Instant.parse(xcp.text()); + } + break; + } + } + + // Create and return CorrelationAlert object + return new CorrelationAlert( + correlatedFindingIds, + correlationRuleId, + correlationRuleName, + id, + version, + schemaVersion, + user, + triggerName, + state, + startTime, + endTime, + acknowledgedTime, + errorMessage, + severity, + actionExecutionResults + ); + } +} + + + + diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java new file mode 100644 index 000000000..a6cdda9a6 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.correlation.alert; + +import org.opensearch.commons.alerting.model.CorrelationAlert; + +import java.util.List; + +/** + * Wrapper class that holds list of correlation alerts and total number of alerts available. + * Useful for pagination. + */ +public class CorrelationAlertsList { + + private final List correlationAlertList; + private final Integer totalAlerts; + + public CorrelationAlertsList(List correlationAlertList, Integer totalAlerts) { + this.correlationAlertList = correlationAlertList; + this.totalAlerts = totalAlerts; + } + + public List getCorrelationAlertList() { + return correlationAlertList; + } + + public Integer getTotalAlerts() { + return totalAlerts; + } + +} diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java new file mode 100644 index 000000000..ba42e252b --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java @@ -0,0 +1,186 @@ +package org.opensearch.securityanalytics.correlation.alert; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.client.Client; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.alerting.model.Alert; +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.securityanalytics.model.CorrelationQuery; +import org.opensearch.securityanalytics.model.CorrelationRule; +import org.opensearch.securityanalytics.model.CorrelationRuleTrigger; +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; +import org.opensearch.securityanalytics.correlation.alert.notifications.CorrelationAlertContext; +import org.opensearch.commons.alerting.model.action.Action; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import java.time.Instant; +import java.util.UUID; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; + +public class CorrelationRuleScheduler { + + private final Logger log = LogManager.getLogger(CorrelationRuleScheduler.class); + private final Client client; + private final CorrelationAlertService correlationAlertService; + private final NotificationService notificationService; + + public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService, NotificationService notificationService) { + this.client = client; + this.correlationAlertService = correlationAlertService; + this.notificationService = notificationService; + } + + public void schedule(List correlationRules, Map> correlatedFindings, String sourceFinding, TimeValue indexTimeout, User user) { + for (CorrelationRule rule : correlationRules) { + CorrelationRuleTrigger trigger = rule.getCorrelationTrigger(); + if (trigger != null) { + List findingIds = new ArrayList<>(); + for (CorrelationQuery query : rule.getCorrelationQueries()) { + List categoryFindingIds = correlatedFindings.get(query.getCategory()); + if (categoryFindingIds != null) { + findingIds.addAll(categoryFindingIds); + } + } + scheduleRule(rule, findingIds, indexTimeout, sourceFinding, user); + } + } + } + + private void scheduleRule(CorrelationRule correlationRule, List findingIds, TimeValue indexTimeout, String sourceFindingId, User user) { + long startTime = Instant.now().toEpochMilli(); + long endTime = startTime + correlationRule.getCorrTimeWindow(); + RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId, user); + ruleTask.run(); + } + + private class RuleTask implements Runnable { + private final CorrelationRule correlationRule; + private final long startTime; + private final long endTime; + private final List correlatedFindingIds; + private final CorrelationAlertService correlationAlertService; + private final NotificationService notificationService; + private final TimeValue indexTimeout; + private final String sourceFindingId; + private final User user; + + public RuleTask(CorrelationRule correlationRule, List correlatedFindingIds, long startTime, long endTime, CorrelationAlertService correlationAlertService, NotificationService notificationService, TimeValue indexTimeout, String sourceFindingId, User user) { + this.correlationRule = correlationRule; + this.correlatedFindingIds = correlatedFindingIds; + this.startTime = startTime; + this.endTime = endTime; + this.correlationAlertService = correlationAlertService; + this.notificationService = notificationService; + this.indexTimeout = indexTimeout; + this.sourceFindingId = sourceFindingId; + this.user = user; + } + + @Override + public void run() { + long currentTime = Instant.now().toEpochMilli(); + if (currentTime >= startTime && currentTime <= endTime) { + try { + correlationAlertService.getActiveAlerts(correlationRule.getId(), currentTime, new ActionListener<>() { + @Override + public void onResponse(CorrelationAlertsList correlationAlertsList) { + if (correlationAlertsList.getTotalAlerts() == 0) { + addCorrelationAlertIntoIndex(); + List actions = correlationRule.getCorrelationTrigger().getActions(); + for (Action action : actions) { + String configId = action.getDestinationId(); + CorrelationAlertContext ctx = new CorrelationAlertContext(correlatedFindingIds, correlationRule.getName(), correlationRule.getCorrTimeWindow(), sourceFindingId); + String transformedSubject = notificationService.compileTemplate(ctx, action.getSubjectTemplate()); + String transformedMessage = notificationService.compileTemplate(ctx, action.getMessageTemplate()); + try { + notificationService.sendNotification(configId, correlationRule.getCorrelationTrigger().getSeverity(), transformedSubject, transformedMessage); + } catch (Exception e) { + log.error("Failed while sending a notification with " + configId + "for correlationRule id " + correlationRule.getId(), e); + new SecurityAnalyticsException("Failed to send notification", RestStatus.INTERNAL_SERVER_ERROR, e); + } + + } + } else { + for (CorrelationAlert correlationAlert: correlationAlertsList.getCorrelationAlertList()) { + updateCorrelationAlert(correlationAlert); + } + } + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to search active correlation alert", e); + new SecurityAnalyticsException("Failed to search active correlation alert", RestStatus.INTERNAL_SERVER_ERROR, e); + } + }); + } catch (Exception e) { + log.error("Failed to fetch active alerts in the time window", e); + new SecurityAnalyticsException("Failed to get active alerts in the correlationRuletimewindow", RestStatus.INTERNAL_SERVER_ERROR, e); + } + } + } + + private void addCorrelationAlertIntoIndex() { + CorrelationAlert correlationAlert = new CorrelationAlert( + correlatedFindingIds, + correlationRule.getId(), + correlationRule.getName(), + UUID.randomUUID().toString(), + 1L, + 1, + user, + correlationRule.getCorrelationTrigger().getName(), + Alert.State.ACTIVE, + Instant.ofEpochMilli(startTime), + Instant.ofEpochMilli(endTime), + null, + null, + correlationRule.getCorrelationTrigger().getSeverity(), + new ArrayList<>() + ); + insertCorrelationAlert(correlationAlert); + } + + private void updateCorrelationAlert(CorrelationAlert correlationAlert) { + CorrelationAlert newCorrelationAlert = new CorrelationAlert( + correlatedFindingIds, + correlationAlert.getCorrelationRuleId(), + correlationAlert.getCorrelationRuleName(), + correlationAlert.getId(), + 1L, + 1, + correlationAlert.getUser(), + correlationRule.getCorrelationTrigger().getName(), + Alert.State.ACTIVE, + Instant.ofEpochMilli(startTime), + Instant.ofEpochMilli(endTime), + null, + null, + correlationRule.getCorrelationTrigger().getSeverity(), + new ArrayList<>() + ); + insertCorrelationAlert(newCorrelationAlert); + } + + private void insertCorrelationAlert(CorrelationAlert correlationAlert) { + correlationAlertService.indexCorrelationAlert(correlationAlert, indexTimeout, new ActionListener<>() { + @Override + public void onResponse(IndexResponse indexResponse) { + log.info("Successfully updated the index .opensearch-sap-correlation-alerts: {}", indexResponse); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to index correlation alert", e); + } + }); + } + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java new file mode 100644 index 000000000..90b8ded25 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/CorrelationAlertContext.java @@ -0,0 +1,35 @@ +package org.opensearch.securityanalytics.correlation.alert.notifications; + +import org.opensearch.securityanalytics.model.CorrelationRule; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class CorrelationAlertContext { + + private final List correlatedFindingIds; + private final String sourceFinding; + private final String correlationRuleName; + private final long timeWindow; + public CorrelationAlertContext(List correlatedFindingIds, String correlationRuleName, long timeWindow, String sourceFinding) { + this.correlatedFindingIds = correlatedFindingIds; + this.correlationRuleName = correlationRuleName; + this.timeWindow = timeWindow; + this.sourceFinding = sourceFinding; + } + + /** + * Mustache templates need special permissions to reflectively introspect field names. To avoid doing this we + * translate the context to a Map of Strings to primitive types, which can be accessed without reflection. + */ + public Map asTemplateArg() { + Map templateArg = new HashMap<>(); + templateArg.put("correlatedFindingIds", correlatedFindingIds); + templateArg.put("sourceFinding", sourceFinding); + templateArg.put("correlationRuleName", correlationRuleName); + templateArg.put("timeWindow", timeWindow); + return templateArg; + } + +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java new file mode 100644 index 000000000..34478a063 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/notifications/NotificationService.java @@ -0,0 +1,112 @@ +package org.opensearch.securityanalytics.correlation.alert.notifications; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.node.NodeClient; +import org.opensearch.commons.notifications.NotificationsPluginInterface; +import org.opensearch.commons.notifications.action.*; +import org.opensearch.commons.notifications.model.ChannelMessage; +import org.opensearch.commons.notifications.model.EventSource; +import org.opensearch.commons.notifications.model.SeverityType; +import org.opensearch.commons.notifications.model.NotificationConfigInfo; +import org.opensearch.commons.notifications.action.GetNotificationConfigRequest; +import org.opensearch.commons.notifications.action.GetNotificationConfigResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import org.opensearch.script.ScriptService; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.Set; +import java.util.HashSet; +import java.util.Collections; +import org.opensearch.script.Script; +import org.opensearch.script.TemplateScript; +import org.opensearch.commons.notifications.model.SeverityType; + +public class NotificationService { + + private static final Logger logger = LogManager.getLogger(NotificationService.class); + + private static ScriptService scriptService; + private final NodeClient client; + + public NotificationService(NodeClient client, ScriptService scriptService) { + this.client = client; + this.scriptService = scriptService; + } + /** + * Extension function for publishing a notification to a channel in the Notification plugin. + */ + public void sendNotification(String configId, String severity, String subject, String notificationMessageText) throws IOException { + ChannelMessage message = generateMessage(notificationMessageText); + List channelIds = new ArrayList<>(); + channelIds.add(configId); + SeverityType severityType = SeverityType.Companion.fromTagOrDefault(severity); + NotificationsPluginInterface.INSTANCE.sendNotification(client, new EventSource(subject, configId, severityType, Collections.emptyList()), message, channelIds, new ActionListener() { + @Override + public void onResponse(SendNotificationResponse sendNotificationResponse) { + if(sendNotificationResponse.getStatus() == RestStatus.OK) { + logger.info("Successfully sent a notification, Notification Event: " + sendNotificationResponse.getNotificationEvent()); + } + else { + logger.error("Error while sending a notification, Notification Event: " + sendNotificationResponse.getNotificationEvent()); + } + + } + @Override + public void onFailure(Exception e) { + logger.error("Failed while sending a notification with " + configId, e); + } + }); + } + + /** + * Gets a NotificationConfigInfo object by ID if it exists. + */ + public GetNotificationConfigResponse getNotificationConfigInfo(String id) { + + Set idSet = new HashSet(); + idSet.add(id); + GetNotificationConfigRequest getNotificationConfigRequest = new GetNotificationConfigRequest(idSet, 0, 10, null, null, new HashMap<>()); + GetNotificationConfigResponse configResp = null; + NotificationsPluginInterface.INSTANCE.getNotificationConfig(client, getNotificationConfigRequest, new ActionListener() { + @Override + public void onResponse(GetNotificationConfigResponse getNotificationConfigResponse) { + if (getNotificationConfigResponse.getStatus() == RestStatus.OK) { + getNotificationConfigResponse = configResp; + } else { + logger.error("Successfully sent a notification, Notification Event: " + getNotificationConfigResponse); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Notification config [" + id + "] was not found"); + new SecurityAnalyticsException("Failed to fetch notification config", RestStatus.INTERNAL_SERVER_ERROR, e); + } + }); + logger.info("Notification config response is: {} ", configResp); + return configResp; + } + + public static ChannelMessage generateMessage(String message) { + return new ChannelMessage( + message, + null, + null + ); + } + + public static String compileTemplate(CorrelationAlertContext ctx, Script template) { + TemplateScript.Factory factory = scriptService.compile(template, TemplateScript.CONTEXT); + Map params = new HashMap<>(template.getParams()); + params.put("ctx", ctx.asTemplateArg()); + TemplateScript templateScript = factory.newInstance(params); + return templateScript.execute(); + } + +} diff --git a/src/main/java/org/opensearch/securityanalytics/model/CorrelationRule.java b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRule.java index b7f5a4f70..c4a1d4e2c 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/CorrelationRule.java +++ b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRule.java @@ -10,6 +10,7 @@ import java.util.Objects; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -29,6 +30,7 @@ public class CorrelationRule implements Writeable, ToXContentObject { public static final Long NO_VERSION = 1L; private static final String CORRELATION_QUERIES = "correlate"; private static final String CORRELATION_TIME_WINDOW = "time_window"; + private static final String TRIGGER_FIELD = "trigger"; private String id; @@ -40,16 +42,19 @@ public class CorrelationRule implements Writeable, ToXContentObject { private Long corrTimeWindow; - public CorrelationRule(String id, Long version, String name, List correlationQueries, Long corrTimeWindow) { + private CorrelationRuleTrigger trigger; + + public CorrelationRule(String id, Long version, String name, List correlationQueries, Long corrTimeWindow, CorrelationRuleTrigger trigger) { this.id = id != null ? id : NO_ID; this.version = version != null ? version : NO_VERSION; this.name = name; this.correlationQueries = correlationQueries; this.corrTimeWindow = corrTimeWindow != null? corrTimeWindow: 300000L; + this.trigger = trigger; } public CorrelationRule(StreamInput sin) throws IOException { - this(sin.readString(), sin.readLong(), sin.readString(), sin.readList(CorrelationQuery::readFrom), sin.readLong()); + this(sin.readString(), sin.readLong(), sin.readString(), sin.readList(CorrelationQuery::readFrom), sin.readLong(), sin.readBoolean() ? new CorrelationRuleTrigger(sin) : null); } @Override @@ -62,6 +67,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws correlationQueries = this.correlationQueries.toArray(correlationQueries); builder.field(CORRELATION_QUERIES, correlationQueries); builder.field(CORRELATION_TIME_WINDOW, corrTimeWindow); + builder.field(TRIGGER_FIELD, trigger); return builder.endObject(); } @@ -74,6 +80,11 @@ public void writeTo(StreamOutput out) throws IOException { for (CorrelationQuery query : correlationQueries) { query.writeTo(out); } + + out.writeBoolean(trigger != null); + if (trigger != null) { + trigger.writeTo(out); + } out.writeLong(corrTimeWindow); } @@ -88,7 +99,7 @@ public static CorrelationRule parse(XContentParser xcp, String id, Long version) String name = null; List correlationQueries = new ArrayList<>(); Long corrTimeWindow = null; - + CorrelationRuleTrigger trigger = null; XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp); while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = xcp.currentName(); @@ -108,11 +119,18 @@ public static CorrelationRule parse(XContentParser xcp, String id, Long version) case CORRELATION_TIME_WINDOW: corrTimeWindow = xcp.longValue(); break; + case TRIGGER_FIELD: + if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { + trigger = null; + } else { + trigger = CorrelationRuleTrigger.parse(xcp); + } + break; default: xcp.skipChildren(); } } - return new CorrelationRule(id, version, name, correlationQueries, corrTimeWindow); + return new CorrelationRule(id, version, name, correlationQueries, corrTimeWindow, trigger); } public static CorrelationRule readFrom(StreamInput sin) throws IOException { @@ -151,6 +169,10 @@ public Long getCorrTimeWindow() { return corrTimeWindow; } + public CorrelationRuleTrigger getCorrelationTrigger() { + return trigger; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -159,7 +181,8 @@ public boolean equals(Object o) { return id.equals(that.id) && version.equals(that.version) && name.equals(that.name) - && correlationQueries.equals(that.correlationQueries); + && correlationQueries.equals(that.correlationQueries) + && trigger.equals(that.trigger); } @Override diff --git a/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java new file mode 100644 index 000000000..3426c7eb1 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/model/CorrelationRuleTrigger.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.model; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.UUIDs; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.commons.alerting.model.action.Action; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class CorrelationRuleTrigger implements Writeable, ToXContentObject { + + private static final Logger log = LogManager.getLogger(DetectorTrigger.class); + + private String id; + + private String name; + + private String severity; + + private List actions; + + private static final String ID_FIELD = "id"; + + private static final String SEVERITY_FIELD = "severity"; + private static final String ACTIONS_FIELD = "actions"; + + private static final String NAME_FIELD = "name"; + + public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( + CorrelationRuleTrigger.class, + new ParseField(ID_FIELD), + CorrelationRuleTrigger::parse + ); + + public CorrelationRuleTrigger(String id, + String name, + String severity, + List actions) { + this.id = id == null ? UUIDs.base64UUID() : id; + this.name = name; + this.severity = severity; + this.actions = actions; + } + + public CorrelationRuleTrigger(StreamInput sin) throws IOException { + this( + sin.readString(), + sin.readString(), + sin.readString(), + sin.readList(Action::readFrom) + ); + } + + public Map asTemplateArg() { + return Map.of( + ACTIONS_FIELD, actions.stream().map(Action::asTemplateArg) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeString(name); + out.writeString(severity); + out.writeCollection(actions); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + + Action[] actionArray = new Action[]{}; + actionArray = actions.toArray(actionArray); + + return builder.startObject() + .field(ID_FIELD, id) + .field(NAME_FIELD, name) + .field(SEVERITY_FIELD, severity) + .field(ACTIONS_FIELD, actionArray) + .endObject(); + } + + public static CorrelationRuleTrigger parse(XContentParser xcp) throws IOException { + String id = null; + String name = null; + String severity = null; + List actions = new ArrayList<>(); + + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + + switch (fieldName) { + case ID_FIELD: + id = xcp.text(); + break; + case NAME_FIELD: + name = xcp.text(); + break; + case SEVERITY_FIELD: + severity = xcp.text(); + break; + case ACTIONS_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + Action action = Action.parse(xcp); + actions.add(action); + } + break; + default: + xcp.skipChildren(); + } + } + return new CorrelationRuleTrigger(id, name, severity, actions); + } + + public static CorrelationRuleTrigger readFrom(StreamInput sin) throws IOException { + return new CorrelationRuleTrigger(sin); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CorrelationRuleTrigger that = (CorrelationRuleTrigger) o; + return Objects.equals(id, that.id) && Objects.equals(name, that.name) && Objects.equals(severity, that.severity) && Objects.equals(actions, that.actions); + } + + @Override + public int hashCode() { + return Objects.hash(id, name, severity, actions); + } + + public String getId() { + return id; + } + + public String getName() { + return name; + } + + public String getSeverity() { + return severity; + } + + public List getActions() { +// List transformedActions = new ArrayList<>(); +// +// if (actions != null) { +// for (Action action : actions) { +// String subjectTemplate = action.getSubjectTemplate() != null ? action.getSubjectTemplate().getIdOrCode() : ""; +// CorrelationContext ctx = CorrelationContext(rule, sourceFindingId); +// no +// +// action.getMessageTemplate(); +// String messageTemplate = action.getMessageTemplate().getIdOrCode(); +// messageTemplate = messageTemplate.replace("{{ctx.detector", "{{ctx.monitor"); +// +// Action transformedAction = new Action(action.getName(), action.getDestinationId(), +// new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, subjectTemplate, Collections.emptyMap()), +// new Script(ScriptType.INLINE, Script.DEFAULT_TEMPLATE_LANG, messageTemplate, Collections.emptyMap()), +// action.getThrottleEnabled(), action.getThrottle(), +// action.getId(), action.getActionExecutionPolicy()); +// +// transformedActions.add(transformedAction); +// } +// } + return actions; + } + +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java index 910794556..e84d8b3e9 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportCorrelateFindingAction.java @@ -35,6 +35,7 @@ import org.opensearch.commons.alerting.action.PublishFindingsRequest; import org.opensearch.commons.alerting.action.SubscribeFindingsResponse; import org.opensearch.commons.alerting.action.AlertingActions; +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -49,6 +50,8 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.securityanalytics.correlation.JoinEngine; import org.opensearch.securityanalytics.correlation.VectorEmbeddingsEngine; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.correlation.alert.notifications.NotificationService; import org.opensearch.securityanalytics.logtype.LogTypeService; import org.opensearch.securityanalytics.model.CustomLogType; import org.opensearch.securityanalytics.model.Detector; @@ -99,6 +102,10 @@ public class TransportCorrelateFindingAction extends HandledTransportAction actionListener) { try { PublishFindingsRequest transformedRequest = transformRequest(request); - AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener); + AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, readUserFromThreadContext(this.threadPool), actionListener); if (!this.correlationIndices.correlationIndexExists()) { try { @@ -146,7 +155,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { @@ -168,6 +176,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + if (createIndexResponse.isAcknowledged()) { + IndexUtils.correlationAlertIndexUpdated(); + } else { + correlateFindingAction.onFailures(new OpenSearchStatusException("Failed to create correlation metadata Index", RestStatus.INTERNAL_SERVER_ERROR)); + } + }, correlateFindingAction::onFailures)); + } catch (Exception ex) { + correlateFindingAction.onFailures(ex); + } + } } else { correlateFindingAction.onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR)); } @@ -193,14 +214,12 @@ public class AsyncCorrelateFindingAction { private final AtomicBoolean counter = new AtomicBoolean(); private final Task task; - AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, ActionListener listener) { + AsyncCorrelateFindingAction(Task task, PublishFindingsRequest request, User user, ActionListener listener) { this.task = task; this.request = request; this.listener = listener; - this.response =new AtomicReference<>(); - - this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this, logTypeService, enableAutoCorrelation); + this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, indexTimeout, this, logTypeService, enableAutoCorrelation, correlationAlertService, notificationService, user); this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this); } diff --git a/src/main/java/org/opensearch/securityanalytics/util/CorrelationIndices.java b/src/main/java/org/opensearch/securityanalytics/util/CorrelationIndices.java index 624d76d58..375342d09 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/CorrelationIndices.java +++ b/src/main/java/org/opensearch/securityanalytics/util/CorrelationIndices.java @@ -36,6 +36,8 @@ public class CorrelationIndices { public static final String CORRELATION_HISTORY_INDEX_PATTERN_REGEXP = ".opensearch-sap-correlation-history*"; public static final String CORRELATION_HISTORY_WRITE_INDEX = ".opensearch-sap-correlation-history-write"; + + public static final String CORRELATION_ALERT_INDEX = ".opensearch-sap-correlation-alerts"; public static final long FIXED_HISTORICAL_INTERVAL = 24L * 60L * 60L * 20L * 1000L; private final Client client; @@ -84,6 +86,11 @@ public boolean correlationMetadataIndexExists() { return clusterState.metadata().hasIndex(CORRELATION_METADATA_INDEX); } + public boolean correlationAlertIndexExists() { + ClusterState clusterState = clusterService.state(); + return clusterState.metadata().hasIndex(CORRELATION_ALERT_INDEX); + } + public void setupCorrelationIndex(TimeValue indexTimeout, Long setupTimestamp, ActionListener listener) throws IOException { try { long currentTimestamp = System.currentTimeMillis(); @@ -122,4 +129,17 @@ public void setupCorrelationIndex(TimeValue indexTimeout, Long setupTimestamp, A throw ex; } } + + public static String correlationAlertIndexMappings() throws IOException { + return new String(Objects.requireNonNull(CorrelationIndices.class.getClassLoader().getResourceAsStream("mappings/correlation_alert_mapping.json")).readAllBytes(), Charset.defaultCharset()); + } + public void initCorrelationAlertIndex(ActionListener actionListener) throws IOException { + Settings correlationAlertSettings = Settings.builder() + .put("index.hidden", true) + .build(); + CreateIndexRequest indexRequest = new CreateIndexRequest(CORRELATION_ALERT_INDEX) + .mapping(correlationAlertIndexMappings()) + .settings(correlationAlertSettings); + client.admin().indices().create(indexRequest, actionListener); + } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/util/IndexUtils.java b/src/main/java/org/opensearch/securityanalytics/util/IndexUtils.java index ce358591e..a24286fda 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/IndexUtils.java +++ b/src/main/java/org/opensearch/securityanalytics/util/IndexUtils.java @@ -45,6 +45,8 @@ public class IndexUtils { public static String lastUpdatedCorrelationHistoryIndex = null; public static Boolean correlationRuleIndexUpdated = false; + public static Boolean correlationAlertIndexUpdated = false; + public static Boolean customLogTypeIndexUpdated = false; public static void detectorIndexUpdated() { @@ -65,6 +67,10 @@ public static void correlationMetadataIndexUpdated() { correlationMetadataIndexUpdated = true; } + public static void correlationAlertIndexUpdated() { + correlationAlertIndexUpdated = true; + } + public static void correlationRuleIndexUpdated() { correlationRuleIndexUpdated = true; } diff --git a/src/main/resources/mappings/correlation_alert_mapping.json b/src/main/resources/mappings/correlation_alert_mapping.json new file mode 100644 index 000000000..585a036c6 --- /dev/null +++ b/src/main/resources/mappings/correlation_alert_mapping.json @@ -0,0 +1,102 @@ +{ + "_meta": { + "schema_version": 1 + }, + "properties": { + "acknowledged_time": { + "type": "date" + }, + "action_execution_results": { + "type": "nested", + "properties": { + "action_id": { + "type": "keyword" + }, + "last_execution_time": { + "type": "date" + }, + "throttled_count": { + "type": "integer" + } + } + }, + "error_message": { + "type": "text" + }, + "correlated_finding_ids": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "correlation_rule_id": { + "type": "keyword" + }, + "correlation_rule_name": { + "type": "text" + }, + "user": { + "properties": { + "backend_roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "custom_attribute_names": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + }, + "name": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword", + "ignore_above": 256 + } + } + }, + "roles": { + "type": "text", + "fields": { + "keyword": { + "type": "keyword" + } + } + } + } + }, + "schema_version": { + "type": "integer" + }, + "severity": { + "type": "keyword" + }, + "state": { + "type": "keyword" + }, + "id": { + "type": "keyword" + }, + "trigger_name": { + "type": "text" + }, + "version": { + "type": "long" + }, + "start_time": { + "type": "date" + }, + "end_time": { + "type": "date" + } + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index a1987138d..33d0de4cc 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -227,7 +227,7 @@ public static CorrelationRule randomCorrelationRule(String name) { List.of( new CorrelationQuery("vpc_flow1", "dstaddr:192.168.1.*", "network", null), new CorrelationQuery("ad_logs1", "azure.platformlogs.result_type:50126", "ad_ldap", null) - ), 300000L); + ), 300000L, null); } public static String randomRule() { diff --git a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java index a2979a231..a7eda56aa 100644 --- a/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/correlation/CorrelationEngineRestApiIT.java @@ -968,7 +968,7 @@ private String createNetworkToWindowsFieldBasedRule(LogIndices indices) throws I CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, null, "network", "srcaddr"); CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, null, "test_windows", "SourceIp"); - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to windows", List.of(query1, query4), 300000L); + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to windows", List.of(query1, query4), 300000L, null); Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); request.setJsonEntity(toJsonString(rule)); Response response = client().performRequest(request); @@ -981,7 +981,7 @@ private String createNetworkToWindowsFilterQueryBasedRule(LogIndices indices) th CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, "srcaddr:1.2.3.4", "network", null); CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, "SourceIp:1.2.3.4", "test_windows", null); - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to windows", List.of(query1, query4), 300000L); + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to windows", List.of(query1, query4), 300000L, null); Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); request.setJsonEntity(toJsonString(rule)); Response response = client().performRequest(request); @@ -994,7 +994,7 @@ private String createNetworkToCustomLogTypeFieldBasedRule(LogIndices indices, St CorrelationQuery query1 = new CorrelationQuery(indices.vpcFlowsIndex, null, "network", "srcaddr"); CorrelationQuery query4 = new CorrelationQuery(customLogTypeIndex, null, customLogTypeName, "SourceIp"); - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to custom log type", List.of(query1, query4), 300000L); + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to custom log type", List.of(query1, query4), 300000L, null); Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); request.setJsonEntity(toJsonString(rule)); Response response = client().performRequest(request); @@ -1008,7 +1008,7 @@ private String createNetworkToAdLdapToWindowsRule(LogIndices indices) throws IOE CorrelationQuery query2 = new CorrelationQuery(indices.adLdapLogsIndex, "ResultType:50126", "ad_ldap", null); CorrelationQuery query4 = new CorrelationQuery(indices.windowsIndex, "Domain:NTAUTHORI*", "test_windows", null); - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to ad_ldap to windows", List.of(query1, query2, query4), 300000L); + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "network to ad_ldap to windows", List.of(query1, query2, query4), 300000L, null); Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); request.setJsonEntity(toJsonString(rule)); Response response = client().performRequest(request); @@ -1022,7 +1022,7 @@ private String createWindowsToAppLogsToS3LogsRule(LogIndices indices) throws IOE CorrelationQuery query2 = new CorrelationQuery(indices.appLogsIndex, "endpoint:\\/customer_records.txt", "others_application", null); CorrelationQuery query4 = new CorrelationQuery(indices.s3AccessLogsIndex, "aws.cloudtrail.eventName:ReplicateObject", "s3", null); - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "windows to app_logs to s3 logs", List.of(query1, query2, query4), 300000L); + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "windows to app_logs to s3 logs", List.of(query1, query2, query4), 300000L, null); Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); request.setJsonEntity(toJsonString(rule)); Response response = client().performRequest(request); @@ -1035,7 +1035,7 @@ private String createCloudtrailFieldBasedRule(String index, String field, Long t CorrelationQuery query1 = new CorrelationQuery(index, "EventName:CreateUser", "cloudtrail", field); CorrelationQuery query2 = new CorrelationQuery(index, "EventName:DeleteUser", "cloudtrail", field); - CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "cloudtrail field based", List.of(query1, query2), timeWindow); + CorrelationRule rule = new CorrelationRule(CorrelationRule.NO_ID, CorrelationRule.NO_VERSION, "cloudtrail field based", List.of(query1, query2), timeWindow, null); Request request = new Request("POST", "/_plugins/_security_analytics/correlation/rules"); request.setJsonEntity(toJsonString(rule)); Response response = client().performRequest(request);