diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java index 3b8fc802d1..65c3725ad9 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/cursor/DefaultCursor.java @@ -13,10 +13,11 @@ import com.google.common.base.Strings; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; -import java.util.ArrayList; import java.util.Base64; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -26,7 +27,6 @@ import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.Setter; -import lombok.SneakyThrows; import org.json.JSONArray; import org.json.JSONObject; import org.opensearch.common.settings.Settings; @@ -99,7 +99,7 @@ public class DefaultCursor implements Cursor { private String pitId; /** To get next batch of result with search after api */ - public SearchSourceBuilder searchSourceBuilder; + private SearchSourceBuilder searchSourceBuilder; /** To get last sort values * */ private Object[] sortFields; @@ -115,7 +115,7 @@ public class DefaultCursor implements Cursor { */ private static final NamedXContentRegistry xContentRegistry = new NamedXContentRegistry( - new SearchModule(Settings.builder().build(), new ArrayList<>()).getNamedXContents()); + new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); @Override public CursorType getType() { @@ -124,11 +124,7 @@ public CursorType getType() { @Override public String generateCursorId() { - boolean isCursorValid = - LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER) - ? Strings.isNullOrEmpty(pitId) - : Strings.isNullOrEmpty(scrollId); - if (rowsLeft <= 0 || isCursorValid) { + if (rowsLeft <= 0 || isCursorIdNullOrEmpty()) { return null; } JSONObject json = new JSONObject(); @@ -146,24 +142,44 @@ public String generateCursorId() { try { return objectMapper.writeValueAsString(sortFields); } catch (JsonProcessingException e) { - throw new RuntimeException(e); + throw new RuntimeException( + "Failed to parse sort fields from JSON string.", e); } }); json.put(SORT_FIELDS, sortFieldValue); } else { json.put(SCROLL_ID, scrollId); } - return String.format("%s:%s", type.getId(), encodeCursor(json, searchSourceBuilder)); + setSearchRequestString(json, searchSourceBuilder); + return String.format("%s:%s", type.getId(), encodeCursor(json)); + } + + private void setSearchRequestString(JSONObject cursorJson, SearchSourceBuilder sourceBuilder) { + try { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + XContentBuilder builder = XContentFactory.jsonBuilder(outputStream); + sourceBuilder.toXContent(builder, null); + builder.close(); + + String searchRequestBase64 = Base64.getEncoder().encodeToString(outputStream.toByteArray()); + cursorJson.put("searchSourceBuilder", searchRequestBase64); + } catch (IOException ex) { + throw new RuntimeException("Failed to set search request string on cursor json.", ex); + } + } + + private boolean isCursorIdNullOrEmpty() { + return LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER) + ? Strings.isNullOrEmpty(pitId) + : Strings.isNullOrEmpty(scrollId); } - @SneakyThrows public static DefaultCursor from(String cursorId) { /** * It is assumed that cursorId here is the second part of the original cursor passed by the * client after removing first part which identifies cursor type */ - String[] parts = cursorId.split(":::"); - JSONObject json = decodeCursor(parts[0]); + JSONObject json = decodeCursor(cursorId); DefaultCursor cursor = new DefaultCursor(); cursor.setFetchSize(json.getInt(FETCH_SIZE)); cursor.setRowsLeft(json.getLong(ROWS_LEFT)); @@ -178,19 +194,26 @@ public static DefaultCursor from(String cursorId) { try { return objectMapper.readValue(json.getString(SORT_FIELDS), Object[].class); } catch (JsonProcessingException e) { - throw new RuntimeException(e); + throw new RuntimeException( + "Failed to parse sort fields from JSON string.", e); } }); cursor.setSortFields(sortFieldValue); - byte[] bytes = Base64.getDecoder().decode(parts[1]); + // Retrieve the SearchSourceBuilder from the JSON field + String searchSourceBuilderBase64 = json.getString("searchSourceBuilder"); + byte[] bytes = Base64.getDecoder().decode(searchSourceBuilderBase64); ByteArrayInputStream streamInput = new ByteArrayInputStream(bytes); - XContentParser parser = - XContentType.JSON - .xContent() - .createParser(xContentRegistry, IGNORE_DEPRECATIONS, streamInput); - SearchSourceBuilder sourceBuilder = SearchSourceBuilder.fromXContent(parser); - cursor.searchSourceBuilder = sourceBuilder; + try { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser(xContentRegistry, IGNORE_DEPRECATIONS, streamInput); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.fromXContent(parser); + cursor.setSearchSourceBuilder(sourceBuilder); + } catch (IOException ex) { + throw new RuntimeException("Failed to get searchSourceBuilder from cursor Id", ex); + } } else { cursor.setScrollId(json.getString(SCROLL_ID)); } @@ -220,18 +243,8 @@ private JSONObject schemaEntry(String name, String alias, String type) { return entry; } - @SneakyThrows - private static String encodeCursor(JSONObject cursorJson, SearchSourceBuilder sourceBuilder) { - String jsonBase64 = Base64.getEncoder().encodeToString(cursorJson.toString().getBytes()); - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - XContentBuilder builder = XContentFactory.jsonBuilder(outputStream); - sourceBuilder.toXContent(builder, null); - builder.close(); - - String searchRequestBase64 = Base64.getEncoder().encodeToString(outputStream.toByteArray()); - - return jsonBase64 + ":::" + searchRequestBase64; + private static String encodeCursor(JSONObject cursorJson) { + return Base64.getEncoder().encodeToString(cursorJson.toString().getBytes()); } private static JSONObject decodeCursor(String cursorId) { diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImpl.java b/legacy/src/main/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImpl.java index b5b9827a9b..dc2c8aadf2 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImpl.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImpl.java @@ -7,16 +7,19 @@ import static org.opensearch.sql.common.setting.Settings.Key.SQL_CURSOR_KEEP_ALIVE; +import java.util.concurrent.ExecutionException; import lombok.Getter; import lombok.Setter; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.CreatePitAction; import org.opensearch.action.search.CreatePitRequest; import org.opensearch.action.search.CreatePitResponse; +import org.opensearch.action.search.DeletePitAction; import org.opensearch.action.search.DeletePitRequest; import org.opensearch.action.search.DeletePitResponse; import org.opensearch.client.Client; -import org.opensearch.core.action.ActionListener; +import org.opensearch.common.action.ActionFuture; import org.opensearch.sql.legacy.esdomain.LocalClusterState; /** Handler for Point In Time */ @@ -24,8 +27,6 @@ public class PointInTimeHandlerImpl implements PointInTimeHandler { private Client client; private String[] indices; @Getter @Setter private String pitId; - private Boolean deleteStatus = null; - private Boolean createStatus = null; private static final Logger LOG = LogManager.getLogger(); /** @@ -60,31 +61,16 @@ public boolean create() { CreatePitRequest createPitRequest = new CreatePitRequest( LocalClusterState.state().getSettingValue(SQL_CURSOR_KEEP_ALIVE), false, indices); - client.createPit( - createPitRequest, - new ActionListener<>() { - - @Override - public void onResponse(CreatePitResponse createPitResponse) { - pitId = createPitResponse.getId(); - createStatus = true; - LOG.info("Created Point In Time {} successfully.", pitId); - } - - @Override - public void onFailure(Exception e) { - createStatus = false; - LOG.error("Error occurred while creating PIT", e); - } - }); - while (createStatus == null) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - LOG.error("Error occurred while creating PIT", e); - } + ActionFuture execute = + client.execute(CreatePitAction.INSTANCE, createPitRequest); + try { + CreatePitResponse pitResponse = execute.get(); + pitId = pitResponse.getId(); + LOG.info("Created Point In Time {} successfully.", pitId); + return true; + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Error occurred while creating PIT.", e); } - return createStatus; } /** @@ -95,32 +81,14 @@ public void onFailure(Exception e) { @Override public boolean delete() { DeletePitRequest deletePitRequest = new DeletePitRequest(pitId); - client.deletePits( - deletePitRequest, - new ActionListener<>() { - @Override - public void onResponse(DeletePitResponse deletePitResponse) { - deleteStatus = true; - LOG.info( - "Delete Point In Time {} status: {}", - pitId, - deletePitResponse.status().getStatus()); - } - - @Override - public void onFailure(Exception e) { - deleteStatus = false; - LOG.error("Error occurred while deleting PIT", e); - } - }); - - while (deleteStatus == null) { - try { - Thread.sleep(100); - } catch (InterruptedException e) { - LOG.error("Error occurred while deleting PIT", e); - } + ActionFuture execute = + client.execute(DeletePitAction.INSTANCE, deletePitRequest); + try { + DeletePitResponse deletePitResponse = execute.get(); + LOG.info("Delete Point In Time {} status: {}", pitId, deletePitResponse.status().getStatus()); + return true; + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Error occurred while deleting PIT.", e); } - return deleteStatus; } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImplTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImplTest.java index e17e554952..304147a519 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImplTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/pit/PointInTimeHandlerImplTest.java @@ -4,24 +4,32 @@ */ package org.opensearch.sql.legacy.pit; -import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.Collections; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import lombok.SneakyThrows; import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.CreatePitAction; +import org.opensearch.action.search.CreatePitRequest; import org.opensearch.action.search.CreatePitResponse; +import org.opensearch.action.search.DeletePitAction; +import org.opensearch.action.search.DeletePitRequest; import org.opensearch.action.search.DeletePitResponse; import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.legacy.esdomain.LocalClusterState; @@ -32,14 +40,11 @@ public class PointInTimeHandlerImplTest { @Mock private Client mockClient; private String[] indices = {"index1", "index2"}; private PointInTimeHandlerImpl pointInTimeHandlerImpl; - @Captor private ArgumentCaptor> listenerCaptorForDelete; - @Captor private ArgumentCaptor> listenerCaptorForCreate; private final String PIT_ID = "testId"; private CreatePitResponse mockCreatePitResponse; - private CompletableFuture completableFuture; - private CompletableFuture completableFutureForDelete; - private Exception exception; private DeletePitResponse mockDeletePitResponse; + private ActionFuture mockActionFuture; + private ActionFuture mockActionFutureDelete; @Mock private OpenSearchSettings settings; @@ -55,82 +60,69 @@ public void setUp() { mockCreatePitResponse = mock(CreatePitResponse.class); mockDeletePitResponse = mock(DeletePitResponse.class); + mockActionFuture = mock(ActionFuture.class); + mockActionFutureDelete = mock(ActionFuture.class); + when(mockClient.execute(any(CreatePitAction.class), any(CreatePitRequest.class))) + .thenReturn(mockActionFuture); + when(mockClient.execute(any(DeletePitAction.class), any(DeletePitRequest.class))) + .thenReturn(mockActionFutureDelete); RestStatus mockRestStatus = mock(RestStatus.class); when(mockDeletePitResponse.status()).thenReturn(mockRestStatus); when(mockDeletePitResponse.status().getStatus()).thenReturn(200); when(mockCreatePitResponse.getId()).thenReturn(PIT_ID); - - completableFuture = CompletableFuture.completedFuture(mockCreatePitResponse); - completableFutureForDelete = CompletableFuture.completedFuture(mockDeletePitResponse); - exception = mock(Exception.class); } + @SneakyThrows @Test public void testCreate() { - doAnswer( - invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(mockCreatePitResponse); - return completableFuture; - }) - .when(mockClient) - .createPit(any(), listenerCaptorForCreate.capture()); - + when(mockActionFuture.get()).thenReturn(mockCreatePitResponse); boolean status = pointInTimeHandlerImpl.create(); - verify(mockClient).createPit(any(), listenerCaptorForCreate.capture()); - listenerCaptorForCreate.getValue().onResponse(mockCreatePitResponse); - verify(mockCreatePitResponse, times(2)).getId(); + verify(mockClient).execute(any(CreatePitAction.class), any(CreatePitRequest.class)); + verify(mockActionFuture).get(); + verify(mockCreatePitResponse).getId(); assertTrue(status); } + @SneakyThrows @Test public void testCreateForFailure() { - doAnswer( - invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(exception); - return completableFuture; - }) - .when(mockClient) - .createPit(any(), listenerCaptorForCreate.capture()); + ExecutionException executionException = + new ExecutionException("Error occurred while creating PIT.", new Throwable()); + when(mockActionFuture.get()).thenThrow(executionException); - boolean status = pointInTimeHandlerImpl.create(); - verify(mockClient).createPit(any(), listenerCaptorForCreate.capture()); - listenerCaptorForCreate.getValue().onResponse(mockCreatePitResponse); - assertFalse(status); + RuntimeException thrownException = + assertThrows(RuntimeException.class, () -> pointInTimeHandlerImpl.create()); + + verify(mockClient).execute(any(CreatePitAction.class), any(CreatePitRequest.class)); + assertNotNull(thrownException.getCause()); + assertEquals("Error occurred while creating PIT.", thrownException.getMessage()); + verify(mockActionFuture).get(); } + @SneakyThrows @Test public void testDelete() { - doAnswer( - invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(mockDeletePitResponse); - return completableFutureForDelete; - }) - .when(mockClient) - .deletePits(any(), listenerCaptorForDelete.capture()); + when(mockActionFutureDelete.get()).thenReturn(mockDeletePitResponse); boolean status = pointInTimeHandlerImpl.delete(); + verify(mockClient).execute(any(DeletePitAction.class), any(DeletePitRequest.class)); + verify(mockActionFutureDelete).get(); assertTrue(status); - verify(mockClient).deletePits(any(), listenerCaptorForDelete.capture()); - listenerCaptorForDelete.getValue().onResponse(mockDeletePitResponse); } + @SneakyThrows @Test public void testDeleteForFailure() { - doAnswer( - invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(exception); - return completableFutureForDelete; - }) - .when(mockClient) - .deletePits(any(), listenerCaptorForDelete.capture()); + ExecutionException executionException = + new ExecutionException("Error occurred while deleting PIT.", new Throwable()); + when(mockActionFutureDelete.get()).thenThrow(executionException); - boolean status = pointInTimeHandlerImpl.delete(); - assertFalse(status); - verify(mockClient).deletePits(any(), listenerCaptorForDelete.capture()); - listenerCaptorForDelete.getValue().onResponse(mockDeletePitResponse); + RuntimeException thrownException = + assertThrows(RuntimeException.class, () -> pointInTimeHandlerImpl.delete()); + + verify(mockClient).execute(any(DeletePitAction.class), any(DeletePitRequest.class)); + assertNotNull(thrownException.getCause()); + assertEquals("Error occurred while deleting PIT.", thrownException.getMessage()); + verify(mockActionFutureDelete).get(); } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java index 8506c7a1e5..deff7132b0 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/cursor/DefaultCursorTest.java @@ -68,7 +68,7 @@ public void cursorShouldStartWithCursorTypeIDForPIT() { cursor.setColumns(new ArrayList<>()); // Set the mocked SearchSourceBuilder to the cursor - cursor.searchSourceBuilder = sourceBuilder; + cursor.setSearchSourceBuilder(sourceBuilder); assertThat(cursor.generateCursorId(), startsWith(cursor.getType().getId() + ":")); } @@ -87,7 +87,7 @@ public void cursorShouldStartWithCursorTypeIDForScroll() { cursor.setColumns(new ArrayList<>()); // Set the mocked SearchSourceBuilder to the cursor - cursor.searchSourceBuilder = sourceBuilder; + cursor.setSearchSourceBuilder(sourceBuilder); assertThat(cursor.generateCursorId(), startsWith(cursor.getType().getId() + ":")); }