From 69853fe99c4ee940fc739373e0ba04418bda71c6 Mon Sep 17 00:00:00 2001 From: Manasvini B Suryanarayana Date: Fri, 23 Aug 2024 14:19:06 -0700 Subject: [PATCH] Add pit for pagination query (#2940) * Add pit for join queries (#2703) * Add search after for join Signed-off-by: Rupal Mahajan * Enable search after by default Signed-off-by: Rupal Mahajan * Add pit Signed-off-by: Rupal Mahajan * nit Signed-off-by: Rupal Mahajan * Fix tests Signed-off-by: Rupal Mahajan * ignore joinWithGeoIntersectNL Signed-off-by: Rupal Mahajan * Rerun CI with scroll Signed-off-by: Rupal Mahajan * Remove unused code and retrigger CI with search_after true Signed-off-by: Rupal Mahajan * Address comments Signed-off-by: Rupal Mahajan * Remove unused code change Signed-off-by: Rupal Mahajan * Update pit keep alive time with SQL_CURSOR_KEEP_ALIVE Signed-off-by: Rupal Mahajan * Fix scroll condition Signed-off-by: Rupal Mahajan * nit Signed-off-by: Rupal Mahajan * Add pit before query execution Signed-off-by: Rupal Mahajan * nit Signed-off-by: Rupal Mahajan * Move pit from join request builder to executor Signed-off-by: Rupal Mahajan * Remove unused methods Signed-off-by: Rupal Mahajan * Add pit in parent class's run() Signed-off-by: Rupal Mahajan * Add comment for fetching subsequent result in NestedLoopsElasticExecutor Signed-off-by: Rupal Mahajan * Update comment Signed-off-by: Rupal Mahajan * Add javadoc for pit handler Signed-off-by: Rupal Mahajan * Add pit interface Signed-off-by: Rupal Mahajan * Add pit handler unit test Signed-off-by: Rupal Mahajan * Fix failed unit test CI Signed-off-by: Rupal Mahajan * Fix spotless error Signed-off-by: Rupal Mahajan * Rename pit class and add logs Signed-off-by: Rupal Mahajan * Fix pit delete unit test Signed-off-by: Rupal Mahajan --------- Signed-off-by: Rupal Mahajan * Add pit for multi query (#2753) * Add search after for join Signed-off-by: Rupal Mahajan * Enable search after by default Signed-off-by: Rupal Mahajan * Add pit Signed-off-by: Rupal Mahajan * nit Signed-off-by: Rupal Mahajan * Fix tests Signed-off-by: Rupal Mahajan * ignore joinWithGeoIntersectNL Signed-off-by: Rupal Mahajan * Rerun CI with scroll Signed-off-by: Rupal Mahajan * draft Signed-off-by: Rupal Mahajan * Remove unused code and retrigger CI with search_after true Signed-off-by: Rupal Mahajan * Address comments Signed-off-by: Rupal Mahajan * Remove unused code change Signed-off-by: Rupal Mahajan * Update pit keep alive time with SQL_CURSOR_KEEP_ALIVE Signed-off-by: Rupal Mahajan * Fix scroll condition Signed-off-by: Rupal Mahajan * nit Signed-off-by: Rupal Mahajan * Add pit before query execution Signed-off-by: Rupal Mahajan * Refactor get response with pit method Signed-off-by: Rupal Mahajan * Update remaining scroll search calls Signed-off-by: Rupal Mahajan * Fix integ test failures Signed-off-by: Rupal Mahajan * nit Signed-off-by: Rupal Mahajan * Move pit from join request builder to executor Signed-off-by: Rupal Mahajan * Remove unused methods Signed-off-by: Rupal Mahajan * Move pit from request to executor Signed-off-by: Rupal Mahajan * Fix pit.delete call missed while merge Signed-off-by: Rupal Mahajan * Move getResponseWithHits method to util class Signed-off-by: Rupal Mahajan * add try catch for create delete pit in minus executor Signed-off-by: Rupal Mahajan * move all common fields to ElasticHitsExecutor Signed-off-by: Rupal Mahajan * add javadoc for ElasticHitsExecutor Signed-off-by: Rupal Mahajan * Add missing javadoc Signed-off-by: Rupal Mahajan * Forcing an empty commit as last commit is stuck processing updates Signed-off-by: Rupal Mahajan --------- Signed-off-by: Rupal Mahajan * Add pit to default cursor Signed-off-by: Rupal Mahajan * Run CI without pit unit test Signed-off-by: Rupal Mahajan * Rerun CI without pit unit test Signed-off-by: Rupal Mahajan * FIx unit tests for PIT changes Signed-off-by: Manasvini B S * Addressed comments Signed-off-by: Manasvini B S --------- Signed-off-by: Rupal Mahajan Signed-off-by: Manasvini B S Co-authored-by: Rupal Mahajan --- .../sql/legacy/cursor/DefaultCursor.java | 122 +++++++++++++++++- .../executor/cursor/CursorCloseExecutor.java | 30 ++++- .../executor/cursor/CursorResultExecutor.java | 62 +++++++-- .../format/PrettyFormatRestExecutor.java | 45 ++++++- .../executor/format/SelectResultSet.java | 28 +++- .../executor/join/ElasticJoinExecutor.java | 9 +- .../legacy/executor/multi/MinusExecutor.java | 9 +- .../legacy/pit/PointInTimeHandlerImpl.java | 63 ++++----- .../sql/legacy/query/DefaultQueryAction.java | 16 ++- .../pit/PointInTimeHandlerImplTest.java | 118 ++++++++++++----- .../unittest/cursor/DefaultCursorTest.java | 58 ++++++++- .../query/DefaultQueryActionTest.java | 63 +++++++-- 12 files changed, 513 insertions(+), 110 deletions(-) diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java index c5be0066fc..2b0de9022c 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java @@ -5,8 +5,19 @@ package org.opensearch.sql.legacy.cursor; +import static org.opensearch.core.xcontent.DeprecationHandler.IGNORE_DEPRECATIONS; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Strings; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; import java.util.Base64; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -18,6 +29,16 @@ import lombok.Setter; import org.json.JSONArray; import org.json.JSONObject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.search.SearchModule; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.executor.format.Schema; /** @@ -40,6 +61,10 @@ public class DefaultCursor implements Cursor { private static final String SCROLL_ID = "s"; private static final String SCHEMA_COLUMNS = "c"; private static final String FIELD_ALIAS_MAP = "a"; + private static final String PIT_ID = "p"; + private static final String SEARCH_REQUEST = "r"; + private static final String SORT_FIELDS = "h"; + private static final ObjectMapper objectMapper = new ObjectMapper(); /** * To get mappings for index to check if type is date needed for @@ -70,11 +95,28 @@ public class DefaultCursor implements Cursor { /** To get next batch of result */ private String scrollId; + /** To get Point In Time */ + private String pitId; + + /** To get next batch of result with search after api */ + private SearchSourceBuilder searchSourceBuilder; + + /** To get last sort values * */ + private Object[] sortFields; + /** To reduce the number of rows left by fetchSize */ @NonNull private Integer fetchSize; private Integer limit; + /** + * {@link NamedXContentRegistry} from {@link SearchModule} used for construct {@link QueryBuilder} + * from DSL query string. + */ + private static final NamedXContentRegistry xContentRegistry = + new NamedXContentRegistry( + new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + @Override public CursorType getType() { return type; @@ -82,19 +124,56 @@ public CursorType getType() { @Override public String generateCursorId() { - if (rowsLeft <= 0 || Strings.isNullOrEmpty(scrollId)) { + if (rowsLeft <= 0 || isCursorIdNullOrEmpty()) { return null; } JSONObject json = new JSONObject(); json.put(FETCH_SIZE, fetchSize); json.put(ROWS_LEFT, rowsLeft); json.put(INDEX_PATTERN, indexPattern); - json.put(SCROLL_ID, scrollId); json.put(SCHEMA_COLUMNS, getSchemaAsJson()); json.put(FIELD_ALIAS_MAP, fieldAliasMap); + if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + json.put(PIT_ID, pitId); + String sortFieldValue = + AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { + return objectMapper.writeValueAsString(sortFields); + } catch (JsonProcessingException e) { + throw new RuntimeException( + "Failed to parse sort fields from JSON string.", e); + } + }); + json.put(SORT_FIELDS, sortFieldValue); + setSearchRequestString(json, searchSourceBuilder); + } else { + json.put(SCROLL_ID, scrollId); + } return String.format("%s:%s", type.getId(), encodeCursor(json)); } + private void setSearchRequestString(JSONObject cursorJson, SearchSourceBuilder sourceBuilder) { + try { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + XContentBuilder builder = XContentFactory.jsonBuilder(outputStream); + sourceBuilder.toXContent(builder, null); + builder.close(); + + String searchRequestBase64 = Base64.getEncoder().encodeToString(outputStream.toByteArray()); + cursorJson.put("searchSourceBuilder", searchRequestBase64); + } catch (IOException ex) { + throw new RuntimeException("Failed to set search request string on cursor json.", ex); + } + } + + private boolean isCursorIdNullOrEmpty() { + return LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER) + ? Strings.isNullOrEmpty(pitId) + : Strings.isNullOrEmpty(scrollId); + } + public static DefaultCursor from(String cursorId) { /** * It is assumed that cursorId here is the second part of the original cursor passed by the @@ -105,13 +184,50 @@ public static DefaultCursor from(String cursorId) { cursor.setFetchSize(json.getInt(FETCH_SIZE)); cursor.setRowsLeft(json.getLong(ROWS_LEFT)); cursor.setIndexPattern(json.getString(INDEX_PATTERN)); - cursor.setScrollId(json.getString(SCROLL_ID)); + if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + populateCursorForPit(json, cursor); + } else { + cursor.setScrollId(json.getString(SCROLL_ID)); + } cursor.setColumns(getColumnsFromSchema(json.getJSONArray(SCHEMA_COLUMNS))); cursor.setFieldAliasMap(fieldAliasMap(json.getJSONObject(FIELD_ALIAS_MAP))); return cursor; } + private static void populateCursorForPit(JSONObject json, DefaultCursor cursor) { + cursor.setPitId(json.getString(PIT_ID)); + + cursor.setSortFields(getSortFieldsFromJson(json)); + + // Retrieve and set the SearchSourceBuilder from the JSON field + String searchSourceBuilderBase64 = json.getString("searchSourceBuilder"); + byte[] bytes = Base64.getDecoder().decode(searchSourceBuilderBase64); + ByteArrayInputStream streamInput = new ByteArrayInputStream(bytes); + try { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser(xContentRegistry, IGNORE_DEPRECATIONS, streamInput); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.fromXContent(parser); + cursor.setSearchSourceBuilder(sourceBuilder); + } catch (IOException ex) { + throw new RuntimeException("Failed to get searchSourceBuilder from cursor Id", ex); + } + } + + private static Object[] getSortFieldsFromJson(JSONObject json) { + return AccessController.doPrivileged( + (PrivilegedAction) + () -> { + try { + return objectMapper.readValue(json.getString(SORT_FIELDS), Object[].class); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse sort fields from JSON string.", e); + } + }); + } + private JSONArray getSchemaAsJson() { JSONArray schemaJson = new JSONArray(); diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorCloseExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorCloseExecutor.java index 7282eaed4c..222ca5d9fc 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorCloseExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorCloseExecutor.java @@ -6,6 +6,7 @@ package org.opensearch.sql.legacy.executor.cursor; import static org.opensearch.core.rest.RestStatus.OK; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER; import java.util.Map; import org.apache.logging.log4j.LogManager; @@ -18,8 +19,11 @@ import org.opensearch.rest.RestChannel; import org.opensearch.sql.legacy.cursor.CursorType; import org.opensearch.sql.legacy.cursor.DefaultCursor; +import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.legacy.pit.PointInTimeHandler; +import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl; import org.opensearch.sql.legacy.rewriter.matchtoterm.VerificationException; public class CursorCloseExecutor implements CursorRestExecutor { @@ -79,14 +83,26 @@ public String execute(Client client, Map params) throws Exceptio } private String handleDefaultCursorCloseRequest(Client client, DefaultCursor cursor) { - String scrollId = cursor.getScrollId(); - ClearScrollResponse clearScrollResponse = - client.prepareClearScroll().addScrollId(scrollId).get(); - if (clearScrollResponse.isSucceeded()) { - return SUCCEEDED_TRUE; + if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + String pitId = cursor.getPitId(); + PointInTimeHandler pit = new PointInTimeHandlerImpl(client, pitId); + try { + pit.delete(); + return SUCCEEDED_TRUE; + } catch (RuntimeException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + return SUCCEEDED_FALSE; + } } else { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - return SUCCEEDED_FALSE; + String scrollId = cursor.getScrollId(); + ClearScrollResponse clearScrollResponse = + client.prepareClearScroll().addScrollId(scrollId).get(); + if (clearScrollResponse.isSucceeded()) { + return SUCCEEDED_TRUE; + } else { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + return SUCCEEDED_FALSE; + } } } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorResultExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorResultExecutor.java index 66c69f3430..0af3ca243b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorResultExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/cursor/CursorResultExecutor.java @@ -6,6 +6,8 @@ package org.opensearch.sql.legacy.executor.cursor; import static org.opensearch.core.rest.RestStatus.OK; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER; import java.util.Arrays; import java.util.Map; @@ -14,6 +16,7 @@ import org.json.JSONException; import org.opensearch.OpenSearchException; import org.opensearch.action.search.ClearScrollResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.common.unit.TimeValue; @@ -21,7 +24,8 @@ import org.opensearch.rest.RestChannel; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import org.opensearch.sql.common.setting.Settings; +import org.opensearch.search.builder.PointInTimeBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.legacy.cursor.CursorType; import org.opensearch.sql.legacy.cursor.DefaultCursor; import org.opensearch.sql.legacy.esdomain.LocalClusterState; @@ -29,6 +33,8 @@ import org.opensearch.sql.legacy.executor.format.Protocol; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.legacy.pit.PointInTimeHandler; +import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl; import org.opensearch.sql.legacy.rewriter.matchtoterm.VerificationException; public class CursorResultExecutor implements CursorRestExecutor { @@ -91,14 +97,27 @@ public String execute(Client client, Map params) throws Exceptio } private String handleDefaultCursorRequest(Client client, DefaultCursor cursor) { - String previousScrollId = cursor.getScrollId(); LocalClusterState clusterState = LocalClusterState.state(); - TimeValue scrollTimeout = clusterState.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); - SearchResponse scrollResponse = - client.prepareSearchScroll(previousScrollId).setScroll(scrollTimeout).get(); + TimeValue paginationTimeout = clusterState.getSettingValue(SQL_CURSOR_KEEP_ALIVE); + + SearchResponse scrollResponse = null; + if (clusterState.getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + String pitId = cursor.getPitId(); + SearchSourceBuilder source = cursor.getSearchSourceBuilder(); + source.searchAfter(cursor.getSortFields()); + source.pointInTimeBuilder(new PointInTimeBuilder(pitId)); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(source); + scrollResponse = client.search(searchRequest).actionGet(); + } else { + String previousScrollId = cursor.getScrollId(); + scrollResponse = + client.prepareSearchScroll(previousScrollId).setScroll(paginationTimeout).get(); + } SearchHits searchHits = scrollResponse.getHits(); SearchHit[] searchHitArray = searchHits.getHits(); String newScrollId = scrollResponse.getScrollId(); + String newPitId = scrollResponse.pointInTimeId(); int rowsLeft = (int) cursor.getRowsLeft(); int fetch = cursor.getFetchSize(); @@ -124,16 +143,37 @@ private String handleDefaultCursorRequest(Client client, DefaultCursor cursor) { if (rowsLeft <= 0) { /** Clear the scroll context on last page */ - ClearScrollResponse clearScrollResponse = - client.prepareClearScroll().addScrollId(newScrollId).get(); - if (!clearScrollResponse.isSucceeded()) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.info("Error closing the cursor context {} ", newScrollId); + if (newScrollId != null) { + ClearScrollResponse clearScrollResponse = + client.prepareClearScroll().addScrollId(newScrollId).get(); + if (!clearScrollResponse.isSucceeded()) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.info("Error closing the cursor context {} ", newScrollId); + } + } + if (newPitId != null) { + PointInTimeHandler pit = new PointInTimeHandlerImpl(client, newPitId); + try { + pit.delete(); + } catch (RuntimeException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.info("Error deleting point in time {} ", newPitId); + } } } cursor.setRowsLeft(rowsLeft); - cursor.setScrollId(newScrollId); + if (clusterState.getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + cursor.setPitId(newPitId); + cursor.setSearchSourceBuilder(cursor.getSearchSourceBuilder()); + cursor.setSortFields( + scrollResponse + .getHits() + .getAt(scrollResponse.getHits().getHits().length - 1) + .getSortValues()); + } else { + cursor.setScrollId(newScrollId); + } Protocol protocol = new Protocol(client, searchHits, format.name().toLowerCase(), cursor); return protocol.cursorFormat(); } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java index 00feabf5d8..5f758e7d87 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java @@ -5,23 +5,31 @@ package org.opensearch.sql.legacy.executor.format; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER; + import java.util.Map; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequestBuilder; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; +import org.opensearch.search.builder.PointInTimeBuilder; import org.opensearch.sql.legacy.cursor.Cursor; import org.opensearch.sql.legacy.cursor.DefaultCursor; +import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.executor.QueryActionElasticExecutor; import org.opensearch.sql.legacy.executor.RestExecutor; +import org.opensearch.sql.legacy.pit.PointInTimeHandler; +import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl; import org.opensearch.sql.legacy.query.DefaultQueryAction; import org.opensearch.sql.legacy.query.QueryAction; +import org.opensearch.sql.legacy.query.SqlOpenSearchRequestBuilder; import org.opensearch.sql.legacy.query.join.BackOffRetryStrategy; public class PrettyFormatRestExecutor implements RestExecutor { @@ -90,15 +98,32 @@ public String execute(Client client, Map params, QueryAction que private Protocol buildProtocolForDefaultQuery(Client client, DefaultQueryAction queryAction) throws SqlParseException { - SearchResponse response = (SearchResponse) queryAction.explain().get(); - String scrollId = response.getScrollId(); + PointInTimeHandler pit = null; + SearchResponse response; + SqlOpenSearchRequestBuilder sqlOpenSearchRequestBuilder = queryAction.explain(); + if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + pit = new PointInTimeHandlerImpl(client, queryAction.getSelect().getIndexArr()); + pit.create(); + SearchRequestBuilder searchRequest = queryAction.getRequestBuilder(); + searchRequest.setPointInTime(new PointInTimeBuilder(pit.getPitId())); + response = searchRequest.get(); + } else { + response = (SearchResponse) sqlOpenSearchRequestBuilder.get(); + } Protocol protocol; - if (!Strings.isNullOrEmpty(scrollId)) { + if (isDefaultCursor(response, queryAction)) { DefaultCursor defaultCursor = new DefaultCursor(); - defaultCursor.setScrollId(scrollId); defaultCursor.setLimit(queryAction.getSelect().getRowCount()); defaultCursor.setFetchSize(queryAction.getSqlRequest().fetchSize()); + if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + defaultCursor.setPitId(pit.getPitId()); + defaultCursor.setSearchSourceBuilder(queryAction.getRequestBuilder().request().source()); + defaultCursor.setSortFields( + response.getHits().getAt(response.getHits().getHits().length - 1).getSortValues()); + } else { + defaultCursor.setScrollId(response.getScrollId()); + } protocol = new Protocol(client, queryAction, response.getHits(), format, defaultCursor); } else { protocol = new Protocol(client, queryAction, response.getHits(), format, Cursor.NULL_CURSOR); @@ -106,4 +131,16 @@ private Protocol buildProtocolForDefaultQuery(Client client, DefaultQueryAction return protocol; } + + private boolean isDefaultCursor(SearchResponse searchResponse, DefaultQueryAction queryAction) { + if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + if (searchResponse.getHits().getTotalHits().value < queryAction.getSqlRequest().fetchSize()) { + return false; + } else { + return true; + } + } else { + return !Strings.isNullOrEmpty(searchResponse.getScrollId()); + } + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java index c60691cb7c..261816cddc 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java @@ -40,6 +40,7 @@ import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation; import org.opensearch.search.aggregations.metrics.Percentile; import org.opensearch.search.aggregations.metrics.Percentiles; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.legacy.cursor.Cursor; import org.opensearch.sql.legacy.cursor.DefaultCursor; import org.opensearch.sql.legacy.domain.ColumnTypeProvider; @@ -49,11 +50,14 @@ import org.opensearch.sql.legacy.domain.Query; import org.opensearch.sql.legacy.domain.Select; import org.opensearch.sql.legacy.domain.TableOnJoinSelect; +import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.esdomain.mapping.FieldMapping; import org.opensearch.sql.legacy.exception.SqlFeatureNotImplementedException; import org.opensearch.sql.legacy.executor.Format; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.metrics.Metrics; +import org.opensearch.sql.legacy.pit.PointInTimeHandler; +import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl; import org.opensearch.sql.legacy.utils.SQLFunctions; public class SelectResultSet extends ResultSet { @@ -564,12 +568,24 @@ private void populateDefaultCursor(DefaultCursor cursor) { long rowsLeft = rowsLeft(cursor.getFetchSize(), cursor.getLimit()); if (rowsLeft <= 0) { // close the cursor - String scrollId = cursor.getScrollId(); - ClearScrollResponse clearScrollResponse = - client.prepareClearScroll().addScrollId(scrollId).get(); - if (!clearScrollResponse.isSucceeded()) { - Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); - LOG.error("Error closing the cursor context {} ", scrollId); + if (LocalClusterState.state().getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)) { + String pitId = cursor.getPitId(); + PointInTimeHandler pit = new PointInTimeHandlerImpl(client, pitId); + try { + pit.delete(); + } catch (RuntimeException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.info("Error deleting point in time {} ", pitId); + } + } else { + // close the cursor + String scrollId = cursor.getScrollId(); + ClearScrollResponse clearScrollResponse = + client.prepareClearScroll().addScrollId(scrollId).get(); + if (!clearScrollResponse.isSucceeded()) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.error("Error closing the cursor context {} ", scrollId); + } } return; } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java index 061868c9b5..c589edcf50 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java @@ -31,6 +31,8 @@ import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.executor.ElasticHitsExecutor; +import org.opensearch.sql.legacy.metrics.MetricName; +import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl; import org.opensearch.sql.legacy.query.SqlElasticRequestBuilder; import org.opensearch.sql.legacy.query.join.HashJoinElasticRequestBuilder; @@ -99,7 +101,12 @@ public void run() throws IOException, SqlParseException { LOG.error("Failed during join query run.", e); } finally { if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { - pit.delete(); + try { + pit.delete(); + } catch (RuntimeException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.info("Error deleting point in time {} ", pit); + } } } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/multi/MinusExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/multi/MinusExecutor.java index f58b25e821..06186d0695 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/multi/MinusExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/multi/MinusExecutor.java @@ -33,6 +33,8 @@ import org.opensearch.sql.legacy.esdomain.LocalClusterState; import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.executor.ElasticHitsExecutor; +import org.opensearch.sql.legacy.metrics.MetricName; +import org.opensearch.sql.legacy.metrics.Metrics; import org.opensearch.sql.legacy.pit.PointInTimeHandlerImpl; import org.opensearch.sql.legacy.query.DefaultQueryAction; import org.opensearch.sql.legacy.query.multi.MultiQueryRequestBuilder; @@ -120,7 +122,12 @@ public void run() throws SqlParseException { LOG.error("Failed during multi query run.", e); } finally { if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { - pit.delete(); + try { + pit.delete(); + } catch (RuntimeException e) { + Metrics.getInstance().getNumericalMetric(MetricName.FAILED_REQ_COUNT_SYS).increment(); + LOG.info("Error deleting point in time {} ", pit); + } } } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImpl.java b/legacy/src/main/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImpl.java index 64535749e8..8d61c03388 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImpl.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImpl.java @@ -7,16 +7,19 @@ import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; +import java.util.concurrent.ExecutionException; import lombok.Getter; import lombok.Setter; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.CreatePitAction; import org.opensearch.action.search.CreatePitRequest; import org.opensearch.action.search.CreatePitResponse; +import org.opensearch.action.search.DeletePitAction; import org.opensearch.action.search.DeletePitRequest; import org.opensearch.action.search.DeletePitResponse; import org.opensearch.client.Client; -import org.opensearch.core.action.ActionListener; +import org.opensearch.common.action.ActionFuture; import org.opensearch.sql.legacy.esdomain.LocalClusterState; /** Handler for Point In Time */ @@ -37,47 +40,45 @@ public PointInTimeHandlerImpl(Client client, String[] indices) { this.indices = indices; } + /** + * Constructor for class + * + * @param client OpenSearch client + * @param pitId Point In Time ID + */ + public PointInTimeHandlerImpl(Client client, String pitId) { + this.client = client; + this.pitId = pitId; + } + /** Create PIT for given indices */ @Override public void create() { CreatePitRequest createPitRequest = new CreatePitRequest( LocalClusterState.state().getSettingValue(SQL_CURSOR_KEEP_ALIVE), false, indices); - client.createPit( - createPitRequest, - new ActionListener<>() { - @Override - public void onResponse(CreatePitResponse createPitResponse) { - pitId = createPitResponse.getId(); - LOG.info("Created Point In Time {} successfully.", pitId); - } - - @Override - public void onFailure(Exception e) { - LOG.error("Error occurred while creating PIT", e); - } - }); + ActionFuture execute = + client.execute(CreatePitAction.INSTANCE, createPitRequest); + try { + CreatePitResponse pitResponse = execute.get(); + pitId = pitResponse.getId(); + LOG.info("Created Point In Time {} successfully.", pitId); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Error occurred while creating PIT.", e); + } } /** Delete PIT */ @Override public void delete() { DeletePitRequest deletePitRequest = new DeletePitRequest(pitId); - client.deletePits( - deletePitRequest, - new ActionListener<>() { - @Override - public void onResponse(DeletePitResponse deletePitResponse) { - LOG.info( - "Delete Point In Time {} status: {}", - pitId, - deletePitResponse.status().getStatus()); - } - - @Override - public void onFailure(Exception e) { - LOG.error("Error occurred while deleting PIT", e); - } - }); + ActionFuture execute = + client.execute(DeletePitAction.INSTANCE, deletePitRequest); + try { + DeletePitResponse deletePitResponse = execute.get(); + LOG.info("Delete Point In Time {} status: {}", pitId, deletePitResponse.status().getStatus()); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Error occurred while deleting PIT.", e); + } } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/DefaultQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/DefaultQueryAction.java index 18c9708df8..9877b17a8f 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/DefaultQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/DefaultQueryAction.java @@ -5,6 +5,8 @@ package org.opensearch.sql.legacy.query; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER; + import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr; import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator; @@ -100,7 +102,19 @@ public void checkAndSetScroll() { .getNumericalMetric(MetricName.DEFAULT_CURSOR_REQUEST_COUNT_TOTAL) .increment(); Metrics.getInstance().getNumericalMetric(MetricName.DEFAULT_CURSOR_REQUEST_TOTAL).increment(); - request.setSize(fetchSize).setScroll(timeValue); + request.setSize(fetchSize); + // Set scroll or search after for pagination + if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { + // search after requires results to be in specific order + // set sort field for search_after + boolean ordered = select.isOrderdSelect(); + if (!ordered) { + request.addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC); + } + // Request also requires PointInTime, but we should create pit while execution. + } else { + request.setScroll(timeValue); + } } else { request.setSearchType(SearchType.DFS_QUERY_THEN_FETCH); setLimit(select.getOffset(), rowCount != null ? rowCount : Select.DEFAULT_LIMIT); diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImplTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImplTest.java index 42f1af4563..eba4ae0346 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImplTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImplTest.java @@ -2,80 +2,132 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.legacy.pit; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; -import java.util.concurrent.CompletableFuture; +import java.util.Collections; +import java.util.concurrent.ExecutionException; +import lombok.SneakyThrows; import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.CreatePitAction; +import org.opensearch.action.search.CreatePitRequest; import org.opensearch.action.search.CreatePitResponse; +import org.opensearch.action.search.DeletePitAction; +import org.opensearch.action.search.DeletePitRequest; import org.opensearch.action.search.DeletePitResponse; import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.legacy.esdomain.LocalClusterState; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; public class PointInTimeHandlerImplTest { @Mock private Client mockClient; private String[] indices = {"index1", "index2"}; private PointInTimeHandlerImpl pointInTimeHandlerImpl; - @Captor private ArgumentCaptor> listenerCaptor; private final String PIT_ID = "testId"; + private CreatePitResponse mockCreatePitResponse; + private DeletePitResponse mockDeletePitResponse; + private ActionFuture mockActionFuture; + private ActionFuture mockActionFutureDelete; + + @Mock private OpenSearchSettings settings; @Before public void setUp() { MockitoAnnotations.initMocks(this); pointInTimeHandlerImpl = new PointInTimeHandlerImpl(mockClient, indices); - } - @Test - public void testCreate() { - when(LocalClusterState.state().getSettingValue(SQL_CURSOR_KEEP_ALIVE)) + doReturn(Collections.emptyList()).when(settings).getSettings(); + when(settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)) .thenReturn(new TimeValue(10000)); + LocalClusterState.state().setPluginSettings(settings); - CreatePitResponse mockCreatePitResponse = mock(CreatePitResponse.class); + mockCreatePitResponse = mock(CreatePitResponse.class); + mockDeletePitResponse = mock(DeletePitResponse.class); + mockActionFuture = mock(ActionFuture.class); + mockActionFutureDelete = mock(ActionFuture.class); + when(mockClient.execute(any(CreatePitAction.class), any(CreatePitRequest.class))) + .thenReturn(mockActionFuture); + when(mockClient.execute(any(DeletePitAction.class), any(DeletePitRequest.class))) + .thenReturn(mockActionFutureDelete); + RestStatus mockRestStatus = mock(RestStatus.class); + when(mockDeletePitResponse.status()).thenReturn(mockRestStatus); + when(mockDeletePitResponse.status().getStatus()).thenReturn(200); when(mockCreatePitResponse.getId()).thenReturn(PIT_ID); + } - CompletableFuture completableFuture = - CompletableFuture.completedFuture(mockCreatePitResponse); + @SneakyThrows + @Test + public void testCreate() { + when(mockActionFuture.get()).thenReturn(mockCreatePitResponse); + try { + pointInTimeHandlerImpl.create(); + } catch (RuntimeException e) { + fail("Expected no exception while creating PIT, but got: " + e.getMessage()); + } + verify(mockClient).execute(any(CreatePitAction.class), any(CreatePitRequest.class)); + verify(mockActionFuture).get(); + verify(mockCreatePitResponse).getId(); + } - doAnswer( - invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(mockCreatePitResponse); - return completableFuture; - }) - .when(mockClient) - .createPit(any(), any()); + @SneakyThrows + @Test + public void testCreateForFailure() { + ExecutionException executionException = + new ExecutionException("Error occurred while creating PIT.", new Throwable()); + when(mockActionFuture.get()).thenThrow(executionException); - pointInTimeHandlerImpl.create(); + RuntimeException thrownException = + assertThrows(RuntimeException.class, () -> pointInTimeHandlerImpl.create()); - assertEquals(PIT_ID, pointInTimeHandlerImpl.getPitId()); + verify(mockClient).execute(any(CreatePitAction.class), any(CreatePitRequest.class)); + assertNotNull(thrownException.getCause()); + assertEquals("Error occurred while creating PIT.", thrownException.getMessage()); + verify(mockActionFuture).get(); } + @SneakyThrows @Test public void testDelete() { - DeletePitResponse mockedResponse = mock(DeletePitResponse.class); - RestStatus mockRestStatus = mock(RestStatus.class); - when(mockedResponse.status()).thenReturn(mockRestStatus); - when(mockedResponse.status().getStatus()).thenReturn(200); - pointInTimeHandlerImpl.setPitId(PIT_ID); - pointInTimeHandlerImpl.delete(); - verify(mockClient).deletePits(any(), listenerCaptor.capture()); - listenerCaptor.getValue().onResponse(mockedResponse); + when(mockActionFutureDelete.get()).thenReturn(mockDeletePitResponse); + try { + pointInTimeHandlerImpl.delete(); + } catch (RuntimeException e) { + fail("Expected no exception while deleting PIT, but got: " + e.getMessage()); + } + verify(mockClient).execute(any(DeletePitAction.class), any(DeletePitRequest.class)); + verify(mockActionFutureDelete).get(); + } + + @SneakyThrows + @Test + public void testDeleteForFailure() { + ExecutionException executionException = + new ExecutionException("Error occurred while deleting PIT.", new Throwable()); + when(mockActionFutureDelete.get()).thenThrow(executionException); + + RuntimeException thrownException = + assertThrows(RuntimeException.class, () -> pointInTimeHandlerImpl.delete()); + + verify(mockClient).execute(any(DeletePitAction.class), any(DeletePitRequest.class)); + assertNotNull(thrownException.getCause()); + assertEquals("Error occurred while deleting PIT.", thrownException.getMessage()); + verify(mockActionFutureDelete).get(); } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java index 1b9662035d..deff7132b0 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java @@ -9,14 +9,47 @@ import static org.hamcrest.Matchers.emptyOrNullString; import static org.hamcrest.Matchers.startsWith; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.when; +import java.io.ByteArrayOutputStream; import java.util.ArrayList; import java.util.Collections; +import org.junit.Before; import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.legacy.cursor.CursorType; import org.opensearch.sql.legacy.cursor.DefaultCursor; +import org.opensearch.sql.legacy.esdomain.LocalClusterState; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; public class DefaultCursorTest { + @Mock private OpenSearchSettings settings; + + @Mock private SearchSourceBuilder sourceBuilder; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + // Required for Pagination queries using PIT instead of Scroll + doReturn(Collections.emptyList()).when(settings).getSettings(); + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(true); + LocalClusterState.state().setPluginSettings(settings); + + // Mock the toXContent method of SearchSourceBuilder + try { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(new ByteArrayOutputStream()); + when(sourceBuilder.toXContent(any(XContentBuilder.class), any())).thenReturn(xContentBuilder); + } catch (Exception e) { + throw new RuntimeException(e); + } + } @Test public void checkCursorType() { @@ -25,7 +58,26 @@ public void checkCursorType() { } @Test - public void cursorShouldStartWithCursorTypeID() { + public void cursorShouldStartWithCursorTypeIDForPIT() { + DefaultCursor cursor = new DefaultCursor(); + cursor.setRowsLeft(50); + cursor.setPitId("dbdskbcdjksbcjkdsbcjk+//"); + cursor.setIndexPattern("myIndex"); + cursor.setFetchSize(500); + cursor.setFieldAliasMap(Collections.emptyMap()); + cursor.setColumns(new ArrayList<>()); + + // Set the mocked SearchSourceBuilder to the cursor + cursor.setSearchSourceBuilder(sourceBuilder); + + assertThat(cursor.generateCursorId(), startsWith(cursor.getType().getId() + ":")); + } + + @Test + public void cursorShouldStartWithCursorTypeIDForScroll() { + // Disable PIT for pagination and use scroll instead + when(settings.getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER)).thenReturn(false); + DefaultCursor cursor = new DefaultCursor(); cursor.setRowsLeft(50); cursor.setScrollId("dbdskbcdjksbcjkdsbcjk+//"); @@ -33,6 +85,10 @@ public void cursorShouldStartWithCursorTypeID() { cursor.setFetchSize(500); cursor.setFieldAliasMap(Collections.emptyMap()); cursor.setColumns(new ArrayList<>()); + + // Set the mocked SearchSourceBuilder to the cursor + cursor.setSearchSourceBuilder(sourceBuilder); + assertThat(cursor.generateCursorId(), startsWith(cursor.getType().getId() + ":")); } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/query/DefaultQueryActionTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/query/DefaultQueryActionTest.java index 755d604a65..ec6ab00f97 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/query/DefaultQueryActionTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/query/DefaultQueryActionTest.java @@ -8,16 +8,9 @@ import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; - -import java.util.ArrayList; -import java.util.LinkedList; -import java.util.List; -import java.util.Optional; +import static org.mockito.Mockito.*; + +import java.util.*; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -27,6 +20,8 @@ import org.opensearch.client.Client; import org.opensearch.common.unit.TimeValue; import org.opensearch.script.Script; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.legacy.domain.Field; import org.opensearch.sql.legacy.domain.KVValue; @@ -149,6 +144,12 @@ public void testIfScrollShouldBeOpenWithDifferentFormats() { queryAction.setFormat(Format.JDBC); queryAction.checkAndSetScroll(); Mockito.verify(mockRequestBuilder).setSize(settingFetchSize); + Mockito.verify(mockRequestBuilder).addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC); + Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); + + // Verify setScroll when SQL_PAGINATION_API_SEARCH_AFTER is set to false + mockLocalClusterStateAndIntializeMetricsForScroll(timeValue); + queryAction.checkAndSetScroll(); Mockito.verify(mockRequestBuilder).setScroll(timeValue); } @@ -168,6 +169,12 @@ public void testIfScrollShouldBeOpen() { mockLocalClusterStateAndInitializeMetrics(timeValue); queryAction.checkAndSetScroll(); Mockito.verify(mockRequestBuilder).setSize(settingFetchSize); + Mockito.verify(mockRequestBuilder).addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC); + Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); + + // Verify setScroll when SQL_PAGINATION_API_SEARCH_AFTER is set to false + mockLocalClusterStateAndIntializeMetricsForScroll(timeValue); + queryAction.checkAndSetScroll(); Mockito.verify(mockRequestBuilder).setScroll(timeValue); } @@ -195,6 +202,12 @@ public void testIfScrollShouldBeOpenWithDifferentFetchSize() { doReturn(mockRequestBuilder).when(mockRequestBuilder).setSize(userFetchSize); queryAction.checkAndSetScroll(); Mockito.verify(mockRequestBuilder).setSize(20); + Mockito.verify(mockRequestBuilder).addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC); + Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); + + // Verify setScroll when SQL_PAGINATION_API_SEARCH_AFTER is set to false + mockLocalClusterStateAndIntializeMetricsForScroll(timeValue); + queryAction.checkAndSetScroll(); Mockito.verify(mockRequestBuilder).setScroll(timeValue); } @@ -216,7 +229,9 @@ public void testIfScrollShouldBeOpenWithDifferentValidFetchSizeAndLimit() { queryAction.checkAndSetScroll(); Mockito.verify(mockRequestBuilder).setSize(userFetchSize); - Mockito.verify(mockRequestBuilder).setScroll(timeValue); + Mockito.verify(mockRequestBuilder).addSort(FieldSortBuilder.DOC_FIELD_NAME, SortOrder.ASC); + // Skip setScroll when SQL_PAGINATION_API_SEARCH_AFTER is set to false + Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); /** fetchSize > LIMIT - no scroll */ userFetchSize = 5000; @@ -226,6 +241,14 @@ public void testIfScrollShouldBeOpenWithDifferentValidFetchSizeAndLimit() { queryAction.checkAndSetScroll(); Mockito.verify(mockRequestBuilder).setSize(limit); Mockito.verify(mockRequestBuilder, never()).setScroll(timeValue); + + // Verify setScroll when SQL_PAGINATION_API_SEARCH_AFTER is set to false + mockLocalClusterStateAndIntializeMetricsForScroll(timeValue); + /** fetchSize <= LIMIT - open scroll */ + userFetchSize = 1500; + doReturn(userFetchSize).when(mockSqlRequest).fetchSize(); + queryAction.checkAndSetScroll(); + Mockito.verify(mockRequestBuilder).setScroll(timeValue); } private void mockLocalClusterStateAndInitializeMetrics(TimeValue time) { @@ -236,6 +259,24 @@ private void mockLocalClusterStateAndInitializeMetrics(TimeValue time) { .when(mockLocalClusterState) .getSettingValue(Settings.Key.METRICS_ROLLING_WINDOW); doReturn(2L).when(mockLocalClusterState).getSettingValue(Settings.Key.METRICS_ROLLING_INTERVAL); + doReturn(true) + .when(mockLocalClusterState) + .getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER); + + Metrics.getInstance().registerDefaultMetrics(); + } + + private void mockLocalClusterStateAndIntializeMetricsForScroll(TimeValue time) { + LocalClusterState mockLocalClusterState = mock(LocalClusterState.class); + LocalClusterState.state(mockLocalClusterState); + doReturn(time).when(mockLocalClusterState).getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + doReturn(3600L) + .when(mockLocalClusterState) + .getSettingValue(Settings.Key.METRICS_ROLLING_WINDOW); + doReturn(2L).when(mockLocalClusterState).getSettingValue(Settings.Key.METRICS_ROLLING_INTERVAL); + doReturn(false) + .when(mockLocalClusterState) + .getSettingValue(Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER); Metrics.getInstance().registerDefaultMetrics(); }