Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix chained findings monitor logic in update detector flow #1019

Merged
merged 2 commits into from
May 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -4,28 +4,21 @@
*/
package org.opensearch.securityanalytics.findings;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.core.action.ActionListener;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.commons.alerting.AlertingPluginInterface;
import org.opensearch.commons.alerting.model.DocLevelQuery;
import org.opensearch.commons.alerting.model.FindingWithDocs;
import org.opensearch.commons.alerting.model.Table;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.PrefixQueryBuilder;
import org.opensearch.index.query.NestedQueryBuilder;
import org.opensearch.index.query.PrefixQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.securityanalytics.action.FindingDto;
@@ -37,6 +30,16 @@
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.util.SecurityAnalyticsException;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.securityanalytics.transport.TransportIndexDetectorAction.CHAINED_FINDINGS_MONITOR_STRING;

/**
* Implements searching/fetching of findings
*/
@@ -99,8 +102,8 @@ public void onFailure(Exception e) {
Map<String, Detector> monitorToDetectorMapping = new HashMap<>();
detector.getMonitorIds().forEach(
monitorId -> {
if (detector.getRuleIdMonitorIdMap().containsKey("chained_findings_monitor")) {
if (!detector.getRuleIdMonitorIdMap().get("chained_findings_monitor").equals(monitorId)) {
if (detector.getRuleIdMonitorIdMap().containsKey(CHAINED_FINDINGS_MONITOR_STRING)) {
if (!detector.getRuleIdMonitorIdMap().get(CHAINED_FINDINGS_MONITOR_STRING).equals(monitorId)) {
monitorToDetectorMapping.put(monitorId, detector);
}
} else {
Original file line number Diff line number Diff line change
@@ -132,6 +132,7 @@ public class TransportIndexDetectorAction extends HandledTransportAction<IndexDe
public static final String PLUGIN_OWNER_FIELD = "security_analytics";
private static final Logger log = LogManager.getLogger(TransportIndexDetectorAction.class);
public static final String TIMESTAMP_FIELD_ALIAS = "timestamp";
public static final String CHAINED_FINDINGS_MONITOR_STRING = "chained_findings_monitor";

private final Client client;

@@ -465,6 +466,16 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
new ActionListener<>() {
@Override
public void onResponse(Collection<IndexMonitorRequest> indexMonitorRequests) {
if (detector.getRuleIdMonitorIdMap().containsKey(CHAINED_FINDINGS_MONITOR_STRING)) {
String cmfId = detector.getRuleIdMonitorIdMap().get(CHAINED_FINDINGS_MONITOR_STRING);
if (shouldAddChainedFindingDocMonitor(indexMonitorRequests.isEmpty(), rulesById)) {
monitorsToBeUpdated.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, cmfId, Method.PUT, rulesById));
}
} else {
if (shouldAddChainedFindingDocMonitor(indexMonitorRequests.isEmpty(), rulesById)) {
monitorsToBeAdded.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST, rulesById));
}
}
onIndexMonitorRequestCreation(
monitorsToBeUpdated,
monitorsToBeAdded,
@@ -564,6 +575,10 @@ public void onFailure(Exception e) {
});
}

private boolean shouldAddChainedFindingDocMonitor(boolean bucketLevelMonitorsExist, List<Pair<String, Rule>> rulesById) {
return enabledWorkflowUsage && !bucketLevelMonitorsExist && rulesById.stream().anyMatch(it -> it.getRight().isAggregationRule());
}

private void onIndexMonitorRequestCreation(List<IndexMonitorRequest> monitorsToBeUpdated,
List<IndexMonitorRequest> monitorsToBeAdded,
List<Pair<String, Rule>> rulesById,
@@ -910,7 +925,7 @@ public void onResponse(Map<String, Map<String, String>> ruleFieldMappings) {
@Override
public void onResponse(Collection<IndexMonitorRequest> indexMonitorRequests) {
// if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger
if (enabledWorkflowUsage && !monitorRequests.isEmpty() && queries.stream().anyMatch(it -> it.getRight().isAggregationRule())) {
if (shouldAddChainedFindingDocMonitor(monitorRequests.isEmpty(), queries)) {
monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId() + "_chained_findings", Method.POST, queries));
}
listener.onResponse(monitorRequests);
@@ -1771,7 +1786,7 @@ private Map<String, String> mapMonitorIds(List<IndexMonitorResponse> monitorResp
return it.getMonitor().getTriggers().get(0).getId();
} else {
if (it.getMonitor().getName().contains("_chained_findings")) {
return "chained_findings_monitor";
return CHAINED_FINDINGS_MONITOR_STRING;
} else {
return Detector.DOC_LEVEL_MONITOR;
}
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@@ -39,9 +40,11 @@
import org.opensearch.securityanalytics.model.DetectorRule;
import org.opensearch.securityanalytics.model.DetectorTrigger;

import static java.util.Collections.emptyList;
import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings;
import static org.opensearch.securityanalytics.TestHelpers.randomAction;
import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule;
import static org.opensearch.securityanalytics.TestHelpers.randomDetector;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndThreatIntel;
@@ -795,6 +798,49 @@ public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException
Map<String, Object> getAlertsBody = asMap(getAlertsResponse);
// TODO enable asserts here when able
Assert.assertEquals(3, getAlertsBody.get("total_alerts")); // 2 doc level alerts for each doc, 1 bucket level alert

input = new DetectorInput("updated", List.of("windows"), detectorRules,
Collections.emptyList());
Detector updatedDetector = randomDetectorWithInputsAndTriggers(List.of(input),
List.of(new DetectorTrigger("updated", "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))
);
/** update detector and verify chained findings monitor should still exist*/
Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector));
hits = executeSearch(Detector.DETECTORS_INDEX, request);
hit = hits.get(0);
updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));

assertEquals(2, ((List<String>) (updatedDetectorMap).get("monitor_id")).size());
indexDoc(index, "3", randomDoc(2, 5, infoOpCode));
indexDoc(index, "4", randomDoc(3, 5, infoOpCode));

hits = executeSearch(Detector.DETECTORS_INDEX, request);
hit = hits.get(0);
updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));

monitorIds = ((List<String>) (updatedDetectorMap).get("monitor_id"));
numberOfMonitorTypes = new HashMap<>();
for (String monitorId : monitorIds) {
Map<String, String> monitor = (Map<String, String>) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor");
numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum);
Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());

// Assert monitor executions
Map<String, Object> executeResults = entityAsMap(executeResponse);

if (Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type"))) {
ArrayList triggerResults = new ArrayList(((Map<String, Object>) executeResults.get("trigger_results")).values());
assertEquals(triggerResults.size(), 1);
Map<String, Object> triggerResult = (Map<String, Object>) triggerResults.get(0);
assertTrue(triggerResult.containsKey("agg_result_buckets"));
HashMap<String, Object> aggResultBuckets = (HashMap<String, Object>) triggerResult.get("agg_result_buckets");
assertTrue(aggResultBuckets.containsKey("4"));
assertTrue(aggResultBuckets.containsKey("5"));
}
}

assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue());
assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());
}

@Ignore