diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java index 7eb75e24c..eaef4f458 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java @@ -10,7 +10,6 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; -import org.opensearch.client.ResponseException; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentParser; @@ -24,7 +23,6 @@ import org.opensearch.search.SearchHit; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Map; @@ -55,7 +53,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase { private static int DOC_ID_TEST_MODEL_INDEX = 0; private static int DOC_ID_TEST_MODEL_INDEX_DEFAULT = 0; private static final int DELAY_MILLI_SEC = 1000; - private static final int EXP_NUM_OF_MODELS = 3; + private static final int EXP_NUM_OF_MODELS = 2; private static final int K = 5; private static final int NUM_DOCS = 10; private static final int NUM_DOCS_TEST_MODEL_INDEX = 100; @@ -83,6 +81,7 @@ public void testKNNModel() throws Exception { createKnnIndex(testIndex, modelIndexMapping(TEST_FIELD, TEST_MODEL_ID)); addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS); } else { + Thread.sleep(1000); DOC_ID = NUM_DOCS; addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS); QUERY_COUNT = 2 * NUM_DOCS; @@ -115,6 +114,7 @@ public void testKNNModelDefault() throws Exception { createKnnIndex(testIndex, modelIndexMapping(TEST_FIELD, TEST_MODEL_ID_DEFAULT)); addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS); } else { + Thread.sleep(1000); DOC_ID = NUM_DOCS; addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, NUM_DOCS); QUERY_COUNT = 2 * NUM_DOCS; @@ -139,22 +139,6 @@ public void testKNNModelDefault() throws Exception { } } - // KNN Delete Model test for model in Training State - public void testDeleteTrainingModel() throws Exception { - byte[] testModelBlob = "hello".getBytes(StandardCharsets.UTF_8); - ModelMetadata testModelMetadata = getModelMetadata(); - testModelMetadata.setState(ModelState.TRAINING); - if (isRunningAgainstOldCluster()) { - addModelToSystemIndex(TEST_MODEL_ID_TRAINING, testModelMetadata, testModelBlob); - } else { - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, TEST_MODEL_ID_TRAINING); - Request request = new Request("DELETE", restURI); - - ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); - assertEquals(RestStatus.CONFLICT.getStatus(), ex.getResponse().getStatusLine().getStatusCode()); - } - } - // Delete Models and ".opensearch-knn-models" index to clear cluster metadata @AfterClass public static void wipeAllModels() throws IOException { diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index 05adb1cf4..04c94de7e 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -11,10 +11,13 @@ package org.opensearch.knn.plugin.transport; +import org.mockito.MockedStatic; +import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; @@ -23,6 +26,10 @@ import java.io.IOException; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; + public class GetModelResponseTests extends KNNTestCase { private ModelMetadata getModelMetadata(ModelState state) { @@ -41,25 +48,35 @@ public void testStreams() throws IOException { } public void testXContent() throws IOException { - String modelId = "test-model"; - byte[] testModelBlob = "hello".getBytes(); - Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); - GetModelResponse getModelResponse = new GetModelResponse(model); - String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\"}"; - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); - getModelResponse.toXContent(xContentBuilder, null); - assertEquals(expectedResponseString, xContentBuilder.toString()); + try (MockedStatic knnClusterUtilMockedStatic = mockStatic(KNNClusterUtil.class)) { + final KNNClusterUtil knnClusterUtil = mock(KNNClusterUtil.class); + when(knnClusterUtil.getClusterMinVersion()).thenReturn(Version.CURRENT); + knnClusterUtilMockedStatic.when(KNNClusterUtil::instance).thenReturn(knnClusterUtil); + String modelId = "test-model"; + byte[] testModelBlob = "hello".getBytes(); + Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); + GetModelResponse getModelResponse = new GetModelResponse(model); + String expectedResponseString = + "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\"}"; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + getModelResponse.toXContent(xContentBuilder, null); + assertEquals(expectedResponseString, xContentBuilder.toString()); + } } public void testXContentWithNoModelBlob() throws IOException { - String modelId = "test-model"; - Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); - GetModelResponse getModelResponse = new GetModelResponse(model); - String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\"}"; - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); - getModelResponse.toXContent(xContentBuilder, null); - assertEquals(expectedResponseString, xContentBuilder.toString()); + try (MockedStatic knnClusterUtilMockedStatic = mockStatic(KNNClusterUtil.class)) { + final KNNClusterUtil knnClusterUtil = mock(KNNClusterUtil.class); + when(knnClusterUtil.getClusterMinVersion()).thenReturn(Version.CURRENT); + knnClusterUtilMockedStatic.when(KNNClusterUtil::instance).thenReturn(knnClusterUtil); + String modelId = "test-model"; + Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); + GetModelResponse getModelResponse = new GetModelResponse(model); + String expectedResponseString = + "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\"}"; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + getModelResponse.toXContent(xContentBuilder, null); + assertEquals(expectedResponseString, xContentBuilder.toString()); + } } }