From 724e2e31733bd23f2966b7a7ce917c002b571598 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <github-actions[bot]@users.noreply.github.com>
Date: Tue, 24 Sep 2024 22:40:13 +0000
Subject: [PATCH] threat intel monitor bug fixes (#1317)

* handle exception arising from trying to search with sort on empty index

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* add setting to test max term count in threat intel ioc scan terms query and verify grouped listener wiring

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* remove unused variable

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* avoid grouped listener being initiated with size 0

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

* add verification that empty index scan is handled gracefully

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>

---------

Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>
(cherry picked from commit 39c29d462acea046f636da5923049f3300427e92)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
---
 .../SecurityAnalyticsPlugin.java              |  5 ++--
 .../settings/SecurityAnalyticsSettings.java   | 12 +++++++++
 .../iocscan/service/IoCScanService.java       |  7 +++++
 .../iocscan/service/SaIoCScanService.java     | 27 ++++++++++---------
 ...ansportThreatIntelMonitorFanOutAction.java | 23 ++++++++++++++--
 .../ThreatIntelMonitorRestApiIT.java          | 10 ++++++-
 6 files changed, 67 insertions(+), 17 deletions(-)

diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java
index 4f79dcc7d..dc17d581e 100644
--- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java
+++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java
@@ -327,7 +327,7 @@ public Collection<Object> createComponents(Client client,
         TIFJobRunner.getJobRunnerInstance().initialize(clusterService, tifJobUpdateService, tifJobParameterService, threatIntelLockService, threadPool, detectorThreatIntelService);
         IocFindingService iocFindingService = new IocFindingService(client, clusterService, xContentRegistry);
         ThreatIntelAlertService threatIntelAlertService = new ThreatIntelAlertService(client, clusterService, xContentRegistry);
-        SaIoCScanService ioCScanService = new SaIoCScanService(client, xContentRegistry, iocFindingService, threatIntelAlertService, notificationService);
+        SaIoCScanService ioCScanService = new SaIoCScanService(client, clusterService, xContentRegistry, iocFindingService, threatIntelAlertService, notificationService);
         DefaultTifSourceConfigLoaderService defaultTifSourceConfigLoaderService = new DefaultTifSourceConfigLoaderService(builtInTIFMetadataLoader, client, saTifSourceConfigManagementService);
         return List.of(
                 detectorIndices, correlationIndices, correlationRuleIndices, ruleTopicIndices, customLogTypeIndices, ruleIndices, threatIntelAlertService,
@@ -502,7 +502,8 @@ public List<Setting<?>> getSettings() {
                 SecurityAnalyticsSettings.BATCH_SIZE,
                 SecurityAnalyticsSettings.THREAT_INTEL_TIMEOUT,
                 SecurityAnalyticsSettings.IOC_INDEX_RETENTION_PERIOD,
-                SecurityAnalyticsSettings.IOC_MAX_INDICES_PER_INDEX_PATTERN
+                SecurityAnalyticsSettings.IOC_MAX_INDICES_PER_INDEX_PATTERN,
+                SecurityAnalyticsSettings.IOC_SCAN_MAX_TERMS_COUNT
         );
     }
 
diff --git a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java
index b0f0fed74..8bcc66d40 100644
--- a/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java
+++ b/src/main/java/org/opensearch/securityanalytics/settings/SecurityAnalyticsSettings.java
@@ -10,6 +10,8 @@
 import java.util.List;
 import java.util.concurrent.TimeUnit;
 
+import static org.opensearch.index.IndexSettings.MAX_TERMS_COUNT_SETTING;
+
 public class SecurityAnalyticsSettings {
     public static final String CORRELATION_INDEX = "index.correlation";
 
@@ -237,4 +239,14 @@ public static final List<Setting<?>> settings() {
             Setting.Property.NodeScope, Setting.Property.Dynamic
     );
 
+    /**
+     * Maximum terms in Terms query search query submitted during ioc scan
+     */
+    public static final Setting<Integer> IOC_SCAN_MAX_TERMS_COUNT  = Setting.intSetting(
+            "plugins.security_analytics.ioc.scan_max_terms_count",
+            65536,
+            1,
+            Setting.Property.NodeScope, Setting.Property.Dynamic
+    );
+
 }
\ No newline at end of file
diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/iocscan/service/IoCScanService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/iocscan/service/IoCScanService.java
index f3238460a..861880da9 100644
--- a/src/main/java/org/opensearch/securityanalytics/threatIntel/iocscan/service/IoCScanService.java
+++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/iocscan/service/IoCScanService.java
@@ -40,6 +40,13 @@ public void scanIoCs(IocScanContext<Data> iocScanContext,
 
             long startTime = System.currentTimeMillis();
             IocLookupDtos iocLookupDtos = extractIocsPerType(data, iocScanContext);
+            if (iocLookupDtos.getIocsPerIocTypeMap().isEmpty()) {
+                log.error("Threat intel monitor {}: Unexpected scenario that non-zero number of docs are fetched from indices containing iocs but iocs-per-type map constructed is empty",
+                        iocScanContext.getMonitor().getId()
+                );
+                scanCallback.accept(Collections.emptyList(), null);
+                return;
+            }
             BiConsumer<List<STIX2IOC>, Exception> iocScanResultConsumer = (List<STIX2IOC> maliciousIocs, Exception e) -> {
                 long scanEndTime = System.currentTimeMillis();
                 long timeTaken = scanEndTime - startTime;
diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/iocscan/service/SaIoCScanService.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/iocscan/service/SaIoCScanService.java
index 7181dda2e..109fc0bcb 100644
--- a/src/main/java/org/opensearch/securityanalytics/threatIntel/iocscan/service/SaIoCScanService.java
+++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/iocscan/service/SaIoCScanService.java
@@ -7,6 +7,7 @@
 import org.opensearch.action.search.ShardSearchFailure;
 import org.opensearch.action.support.GroupedActionListener;
 import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.document.DocumentField;
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
 import org.opensearch.common.xcontent.XContentType;
@@ -28,6 +29,7 @@
 import org.opensearch.securityanalytics.model.STIX2IOC;
 import org.opensearch.securityanalytics.model.threatintel.IocFinding;
 import org.opensearch.securityanalytics.model.threatintel.ThreatIntelAlert;
+import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
 import org.opensearch.securityanalytics.threatIntel.iocscan.dao.IocFindingService;
 import org.opensearch.securityanalytics.threatIntel.iocscan.dao.ThreatIntelAlertService;
 import org.opensearch.securityanalytics.threatIntel.iocscan.dto.IocScanContext;
@@ -54,16 +56,17 @@
 public class SaIoCScanService extends IoCScanService<SearchHit> {
 
     private static final Logger log = LogManager.getLogger(SaIoCScanService.class);
-    public static final int MAX_TERMS = 65536; //TODO make ioc index setting based. use same setting value to create index
     private final Client client;
+    private final ClusterService clusterService;
     private final NamedXContentRegistry xContentRegistry;
     private final IocFindingService iocFindingService;
     private final ThreatIntelAlertService threatIntelAlertService;
     private final NotificationService notificationService;
 
-    public SaIoCScanService(Client client, NamedXContentRegistry xContentRegistry, IocFindingService iocFindingService,
+    public SaIoCScanService(Client client, ClusterService clusterService, NamedXContentRegistry xContentRegistry, IocFindingService iocFindingService,
                             ThreatIntelAlertService threatIntelAlertService, NotificationService notificationService) {
         this.client = client;
+        this.clusterService = clusterService;
         this.xContentRegistry = xContentRegistry;
         this.iocFindingService = iocFindingService;
         this.threatIntelAlertService = threatIntelAlertService;
@@ -329,12 +332,13 @@ private void performScanForMaliciousIocsPerIocType(
             GroupedActionListener<SearchHitsOrException> listener) {
         // TODO change ioc indices max terms count to 100k and experiment
         // TODO add fuzzy postings on ioc value field to enable bloomfilter on iocs as an index data structure and benchmark performance
-        GroupedActionListener<SearchHitsOrException> perIocTypeListener = getGroupedListenerForIocScanPerIocType(iocs, monitor, iocType, listener);
+        int maxTerms = clusterService.getClusterSettings().get(SecurityAnalyticsSettings.IOC_SCAN_MAX_TERMS_COUNT);
+        GroupedActionListener<SearchHitsOrException> perIocTypeListener = getGroupedListenerForIocScanPerIocType(iocs, monitor, iocType, listener, maxTerms);
         List<String> iocList = new ArrayList<>(iocs);
         int totalIocs = iocList.size();
 
-        for (int start = 0; start < totalIocs; start += MAX_TERMS) {
-            int end = Math.min(start + MAX_TERMS, totalIocs);
+        for (int start = 0; start < totalIocs; start += maxTerms) {
+            int end = Math.min(start + maxTerms, totalIocs);
             List<String> iocsSublist = iocList.subList(start, end);
             SearchRequest searchRequest = getSearchRequestForIocType(indices, iocType, iocsSublist);
             client.search(searchRequest, ActionListener.wrap(
@@ -356,7 +360,7 @@ private void performScanForMaliciousIocsPerIocType(
                                 );
                             }
                         }
-                        listener.onResponse(new SearchHitsOrException(
+                        perIocTypeListener.onResponse(new SearchHitsOrException(
                                 searchResponse.getHits() == null || searchResponse.getHits().getHits() == null ?
                                         emptyList() : Arrays.asList(searchResponse.getHits().getHits()), null));
                     },
@@ -366,7 +370,7 @@ private void performScanForMaliciousIocsPerIocType(
                                 iocsSublist.size(),
                                 iocType), e
                         );
-                        listener.onResponse(new SearchHitsOrException(emptyList(), e));
+                        perIocTypeListener.onResponse(new SearchHitsOrException(emptyList(), e));
                     }
             ));
         }
@@ -387,7 +391,7 @@ private static SearchRequest getSearchRequestForIocType(List<String> indices, St
      * grouped listener for a given ioc type to listen and collate malicious iocs in search hits from batched search calls.
      * batching done for every 65536 or MAX_TERMS setting number of iocs in a list.
      */
-    private GroupedActionListener<SearchHitsOrException> getGroupedListenerForIocScanPerIocType(Set<String> iocs, Monitor monitor, String iocType, GroupedActionListener<SearchHitsOrException> groupedListenerForAllIocTypes) {
+    private GroupedActionListener<SearchHitsOrException> getGroupedListenerForIocScanPerIocType(Set<String> iocs, Monitor monitor, String iocType, GroupedActionListener<SearchHitsOrException> groupedListenerForAllIocTypes, int maxTerms) {
         return new GroupedActionListener<>(
                 ActionListener.wrap(
                         (Collection<SearchHitsOrException> searchHitsOrExceptions) -> {
@@ -419,8 +423,7 @@ private GroupedActionListener<SearchHitsOrException> getGroupedListenerForIocSca
                             groupedListenerForAllIocTypes.onResponse(new SearchHitsOrException(emptyList(), e));
                         }
                 ),
-                //TODO fix groupsize
-                getGroupSizeForIocs(iocs) // batch into #MAX_TERMS setting
+                getGroupSizeForIocs(iocs, maxTerms)
         );
     }
 
@@ -436,8 +439,8 @@ private Exception buildException(Collection<SearchHitsOrException> searchHitsOrE
         return e;
     }
 
-    private static int getGroupSizeForIocs(Set<String> iocs) {
-        return iocs.size() / MAX_TERMS + (iocs.size() % MAX_TERMS == 0 ? 0 : 1);
+    private static int getGroupSizeForIocs(Set<String> iocs, int maxTerms) {
+        return iocs.size() / maxTerms + (iocs.size() % maxTerms == 0 ? 0 : 1);
     }
 
     @Override
diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/model/monitor/TransportThreatIntelMonitorFanOutAction.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/model/monitor/TransportThreatIntelMonitorFanOutAction.java
index 6864f7a98..2421e5e5c 100644
--- a/src/main/java/org/opensearch/securityanalytics/threatIntel/model/monitor/TransportThreatIntelMonitorFanOutAction.java
+++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/model/monitor/TransportThreatIntelMonitorFanOutAction.java
@@ -106,7 +106,7 @@ protected void doExecute(Task task, DocLevelMonitorFanOutRequest request, Action
                     iocTypeToIndicesMap -> {
                         onGetIocTypeToIndices(iocTypeToIndicesMap, request, actionListener);
                     }, e -> {
-                        log.error(() -> new ParameterizedMessage("Unexpected Failure in threat intel monitor {} fan out action", request.getMonitor().getId()), e);
+                        log.error(() -> new ParameterizedMessage("Unexpected Failure in threat intel monitor {} fan out action while fetching threat intel ioc indices", request.getMonitor().getId()), e);
                         actionListener.onResponse(
                                 new DocLevelMonitorFanOutResponse(
                                         clusterService.localNode().getId(),
@@ -162,6 +162,20 @@ private void onGetIocTypeToIndices(Map<String, List<String>> iocTypeToIndicesMap
         };
         ActionListener<List<SearchHit>> searchHitsListener = ActionListener.wrap(
                 (List<SearchHit> hits) -> {
+                    if (hits.isEmpty()) {
+                        actionListener.onResponse(
+                                new DocLevelMonitorFanOutResponse(
+                                        clusterService.localNode().getId(),
+                                        request.getExecutionId(),
+                                        request.getMonitor().getId(),
+                                        updatedLastRunContext,
+                                        new InputRunResults(Collections.emptyList(), null, null),
+                                        Collections.emptyMap(),
+                                        null
+                                )
+                        );
+                        return;
+                    }
                     BiConsumer<Object, Exception> resultConsumer = (r, e) -> {
                         if (e == null) {
                             actionListener.onResponse(
@@ -195,7 +209,7 @@ private void onGetIocTypeToIndices(Map<String, List<String>> iocTypeToIndicesMap
                     ), resultConsumer);
                 },
                 e -> {
-                    log.error("unexpected error while", e);
+                    log.error("unexpected error while trying to query shards and fetch docs before scanning for malicious IoC's", e);
                     actionListener.onFailure(e);
                 }
         );
@@ -290,6 +304,11 @@ private void fetchLatestDocsFromShard(
                                     // recursive call to fetch docs with updated seq no.
                                     fetchLatestDocsFromShard(shardId, fromSeqNo, updatedToSeqNo, searchHitsSoFar, monitor, shardLastSeenMapForIndex, updateLastRunContext, fieldsToFetch, listener);
                                 }, e -> {
+                                    if(e.getMessage().contains("all shards failed") && e.getCause().getMessage().contains("No mapping found for [_seq_no] in order to sort on")) {
+                                        // this implies that the index being queried doesn't have any docs and hence doesn't understand the in-built _seq_no field mapping
+                                        listener.onResponse(new SearchHitsOrException(Collections.emptyList(), null));
+                                        return;
+                                    }
                                     log.error(() -> new ParameterizedMessage("Threat intel Monitor {}: Failed to search shard {} in index {}", monitor.getId(), shard, shardId.getIndexName()), e);
                                     listener.onResponse(new SearchHitsOrException(searchHitsSoFar, e));
                                 }
diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/ThreatIntelMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/ThreatIntelMonitorRestApiIT.java
index a13f3252f..6af6c4275 100644
--- a/src/test/java/org/opensearch/securityanalytics/resthandler/ThreatIntelMonitorRestApiIT.java
+++ b/src/test/java/org/opensearch/securityanalytics/resthandler/ThreatIntelMonitorRestApiIT.java
@@ -17,6 +17,7 @@
 import org.opensearch.search.SearchHit;
 import org.opensearch.securityanalytics.SecurityAnalyticsPlugin;
 import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase;
+import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
 import org.opensearch.securityanalytics.threatIntel.action.ListIOCsActionRequest;
 import org.opensearch.securityanalytics.commons.model.IOCType;
 import org.opensearch.securityanalytics.model.Detector;
@@ -47,6 +48,7 @@
 import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers;
 import static org.opensearch.securityanalytics.TestHelpers.randomIndex;
 import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping;
+import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ALERT_HISTORY_MAX_DOCS;
 import static org.opensearch.securityanalytics.threatIntel.resthandler.monitor.RestSearchThreatIntelMonitorAction.SEARCH_THREAT_INTEL_MONITOR_PATH;
 
 public class ThreatIntelMonitorRestApiIT extends SecurityAnalyticsRestTestCase {
@@ -111,6 +113,7 @@ private String indexTifSourceConfig(List<STIX2IOCDto> testIocDtos) throws IOExce
     }
 
     public void testCreateThreatIntelMonitor_monitorAliases() throws IOException {
+        updateClusterSetting(SecurityAnalyticsSettings.IOC_SCAN_MAX_TERMS_COUNT.getKey(), "1");
         Response iocFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.THREAT_INTEL_BASE_URI + "/findings/_search",
                 Map.of(), null);
         Map<String, Object> responseAsMap = responseAsMap(iocFindingsResponse);
@@ -138,6 +141,8 @@ public void testCreateThreatIntelMonitor_monitorAliases() throws IOException {
 
         final String monitorId = responseBody.get("id").toString();
         Assert.assertNotEquals("response is missing Id", Monitor.NO_ID, monitorId);
+        Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());
+        Assert.assertEquals(200, executeResponse.getStatusLine().getStatusCode());
 
         Response alertingMonitorResponse = getAlertingMonitor(client(), monitorId);
         Assert.assertEquals(200, alertingMonitorResponse.getStatusLine().getStatusCode());
@@ -151,7 +156,7 @@ public void testCreateThreatIntelMonitor_monitorAliases() throws IOException {
             }
         }
 
-        Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());
+        executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());
         Map<String, Object> executeResults = entityAsMap(executeResponse);
 
         String matchAllRequest = getMatchAllRequest();
@@ -247,6 +252,7 @@ public void testCreateThreatIntelMonitor_monitorAliases() throws IOException {
     }
 
     public void testCreateThreatIntelMonitor() throws IOException {
+        updateClusterSetting(SecurityAnalyticsSettings.IOC_SCAN_MAX_TERMS_COUNT.getKey(), "1");
         Response iocFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.THREAT_INTEL_BASE_URI + "/findings/_search",
                 Map.of(), null);
         Map<String, Object> responseAsMap = responseAsMap(iocFindingsResponse);
@@ -281,6 +287,8 @@ public void testCreateThreatIntelMonitor() throws IOException {
             String doc = String.format("{\"ip\":\"%s\", \"ip1\":\"%s\"}", val, val);
             try {
                 indexDoc(index, "" + i++, doc);
+                indexDoc(index, "" + i++, String.format("{\"ip\":\"1.2.3.4\", \"ip1\":\"1.2.3.4\"}", val, val));
+                indexDoc(index, "" + i++, String.format("{\"random\":\"%s\", \"random1\":\"%s\"}", val, val));
             } catch (IOException e) {
                 fail();
             }