diff --git a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java index af2cae9c5..4022aeff2 100644 --- a/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java +++ b/src/main/java/org/opensearch/securityanalytics/findings/FindingsService.java @@ -4,28 +4,21 @@ */ package org.opensearch.securityanalytics.findings; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.search.join.ScoreMode; import org.opensearch.OpenSearchStatusException; -import org.opensearch.core.action.ActionListener; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.commons.alerting.AlertingPluginInterface; import org.opensearch.commons.alerting.model.DocLevelQuery; import org.opensearch.commons.alerting.model.FindingWithDocs; import org.opensearch.commons.alerting.model.Table; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.PrefixQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.PrefixQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.securityanalytics.action.FindingDto; @@ -37,6 +30,16 @@ import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.util.SecurityAnalyticsException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.opensearch.securityanalytics.transport.TransportIndexDetectorAction.CHAINED_FINDINGS_MONITOR_STRING; + /** * Implements searching/fetching of findings */ @@ -99,8 +102,8 @@ public void onFailure(Exception e) { Map monitorToDetectorMapping = new HashMap<>(); detector.getMonitorIds().forEach( monitorId -> { - if (detector.getRuleIdMonitorIdMap().containsKey("chained_findings_monitor")) { - if (!detector.getRuleIdMonitorIdMap().get("chained_findings_monitor").equals(monitorId)) { + if (detector.getRuleIdMonitorIdMap().containsKey(CHAINED_FINDINGS_MONITOR_STRING)) { + if (!detector.getRuleIdMonitorIdMap().get(CHAINED_FINDINGS_MONITOR_STRING).equals(monitorId)) { monitorToDetectorMapping.put(monitorId, detector); } } else { diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index 7bdd12816..ccea2d1e0 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -131,6 +131,7 @@ public class TransportIndexDetectorAction extends HandledTransportAction> ruleFieldMappings) { new ActionListener<>() { @Override public void onResponse(Collection indexMonitorRequests) { + if (detector.getRuleIdMonitorIdMap().containsKey(CHAINED_FINDINGS_MONITOR_STRING)) { + String cmfId = detector.getRuleIdMonitorIdMap().get(CHAINED_FINDINGS_MONITOR_STRING); + if (shouldAddChainedFindingDocMonitor(indexMonitorRequests.isEmpty(), rulesById)) { + monitorsToBeUpdated.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, cmfId, Method.PUT, rulesById)); + } + } else { + if (shouldAddChainedFindingDocMonitor(indexMonitorRequests.isEmpty(), rulesById)) { + monitorsToBeAdded.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST, rulesById)); + } + } onIndexMonitorRequestCreation( monitorsToBeUpdated, monitorsToBeAdded, @@ -563,6 +574,10 @@ public void onFailure(Exception e) { }); } + private boolean shouldAddChainedFindingDocMonitor(boolean bucketLevelMonitorsExist, List> rulesById) { + return enabledWorkflowUsage && !bucketLevelMonitorsExist && rulesById.stream().anyMatch(it -> it.getRight().isAggregationRule()); + } + private void onIndexMonitorRequestCreation(List monitorsToBeUpdated, List monitorsToBeAdded, List> rulesById, @@ -909,7 +924,7 @@ public void onResponse(Map> ruleFieldMappings) { @Override public void onResponse(Collection indexMonitorRequests) { // if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger - if (enabledWorkflowUsage && !monitorRequests.isEmpty() && queries.stream().anyMatch(it -> it.getRight().isAggregationRule())) { + if (shouldAddChainedFindingDocMonitor(monitorRequests.isEmpty(), queries)) { monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST, queries)); } listener.onResponse(monitorRequests); @@ -1766,7 +1781,7 @@ private Map mapMonitorIds(List monitorResp return it.getMonitor().getTriggers().get(0).getId(); } else { if (it.getMonitor().getName().contains("_chained_findings")) { - return "chained_findings_monitor"; + return CHAINED_FINDINGS_MONITOR_STRING; } else { return Detector.DOC_LEVEL_MONITOR; } diff --git a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java index 968b341c3..c666a1d27 100644 --- a/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java +++ b/src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java @@ -8,6 +8,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -38,9 +39,11 @@ import org.opensearch.securityanalytics.model.DetectorRule; import org.opensearch.securityanalytics.model.DetectorTrigger; +import static java.util.Collections.emptyList; import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings; import static org.opensearch.securityanalytics.TestHelpers.randomAction; import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule; +import static org.opensearch.securityanalytics.TestHelpers.randomDetector; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel; @@ -794,6 +797,49 @@ public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException Map getAlertsBody = asMap(getAlertsResponse); // TODO enable asserts here when able Assert.assertEquals(3, getAlertsBody.get("total_alerts")); // 2 doc level alerts for each doc, 1 bucket level alert + + input = new DetectorInput("updated", List.of("windows"), detectorRules, + Collections.emptyList()); + Detector updatedDetector = randomDetectorWithInputsAndTriggers(List.of(input), + List.of(new DetectorTrigger("updated", "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of())) + ); + /** update detector and verify chained findings monitor should still exist*/ + Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector)); + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + + assertEquals(2, ((List) (updatedDetectorMap).get("monitor_id")).size()); + indexDoc(index, "3", randomDoc(2, 5, infoOpCode)); + indexDoc(index, "4", randomDoc(3, 5, infoOpCode)); + + hits = executeSearch(Detector.DETECTORS_INDEX, request); + hit = hits.get(0); + updatedDetectorMap = (HashMap) (hit.getSourceAsMap().get("detector")); + + monitorIds = ((List) (updatedDetectorMap).get("monitor_id")); + numberOfMonitorTypes = new HashMap<>(); + for (String monitorId : monitorIds) { + Map monitor = (Map) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); + numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum); + Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); + + // Assert monitor executions + Map executeResults = entityAsMap(executeResponse); + + if (Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type"))) { + ArrayList triggerResults = new ArrayList(((Map) executeResults.get("trigger_results")).values()); + assertEquals(triggerResults.size(), 1); + Map triggerResult = (Map) triggerResults.get(0); + assertTrue(triggerResult.containsKey("agg_result_buckets")); + HashMap aggResultBuckets = (HashMap) triggerResult.get("agg_result_buckets"); + assertTrue(aggResultBuckets.containsKey("4")); + assertTrue(aggResultBuckets.containsKey("5")); + } + } + + assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue()); + assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue()); } @Ignore