Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Throw proper exception to invalid k-NN query (#1380) #1381

Merged
merged 1 commit into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Increase Lucene max dimension limit to 16,000 [#1346](https://github.com/opensearch-project/k-NN/pull/1346)
* Tuned default values for ef_search and ef_construction for better indexing and search performance for vector search [#1353](https://github.com/opensearch-project/k-NN/pull/1353)
* Enabled Filtering on Nested Vector fields with top level filters [#1372](https://github.com/opensearch-project/k-NN/pull/1372)
* Throw proper exception to invalid k-NN query [#1380](https://github.com/opensearch-project/k-NN/pull/1380)
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,14 @@ public static void initialize(ModelDao modelDao) {
}

private static float[] ObjectsToFloats(List<Object> objs) {
if (Objects.isNull(objs) || objs.isEmpty()) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME));
}
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
if ((objs.get(i) instanceof Number) == false) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME));
}
vec[i] = ((Number) objs.get(i)).floatValue();
}
return vec;
Expand Down
51 changes: 51 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.script.Script;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -425,6 +426,56 @@ public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception {
);
}

@SneakyThrows
public void testSearchWithInvalidSearchVectorType() {
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
ingestL2FloatTestData();
Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
List<Object> invalidTypeQueryVector = new ArrayList<>();
invalidTypeQueryVector.add(1.5);
invalidTypeQueryVector.add(2.5);
invalidTypeQueryVector.add("a");
invalidTypeQueryVector.add(null);
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("vector", invalidTypeQueryVector)
.field("k", 4)
.endObject()
.endObject()
.endObject()
.endObject();
request.setJsonEntity(builder.toString());

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertEquals(400, ex.getResponse().getStatusLine().getStatusCode());
assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
}

@SneakyThrows
public void testSearchWithMissingQueryVector() {
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
ingestL2FloatTestData();
Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(FIELD_NAME)
.field("k", 4)
.endObject()
.endObject()
.endObject()
.endObject();
request.setJsonEntity(builder.toString());

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertEquals(400, ex.getResponse().getStatusLine().getStatusCode());
assertTrue(ex.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
}

@SneakyThrows
private void ingestL2ByteTestData() {
Byte[] b1 = { 6, 6 };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

Expand Down Expand Up @@ -149,6 +150,70 @@ public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Excep
expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParser));
}

public void testFromXContent_invalidQueryVectorType() throws Exception {
final ClusterService clusterService = mockClusterService(Version.CURRENT);

final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

List<Object> invalidTypeQueryVector = new ArrayList<>();
invalidTypeQueryVector.add(1.5);
invalidTypeQueryVector.add(2.5);
invalidTypeQueryVector.add("a");
invalidTypeQueryVector.add(null);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject(FIELD_NAME);
builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector);
builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builder.endObject();
builder.endObject();
XContentParser contentParser = createParser(builder);
contentParser.nextToken();
IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> KNNQueryBuilder.fromXContent(contentParser)
);
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
}

public void testFromXContent_missingQueryVector() throws Exception {
final ClusterService clusterService = mockClusterService(Version.CURRENT);

final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

// Test without vector field
XContentBuilder builderWithoutVectorField = XContentFactory.jsonBuilder();
builderWithoutVectorField.startObject();
builderWithoutVectorField.startObject(FIELD_NAME);
builderWithoutVectorField.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builderWithoutVectorField.endObject();
builderWithoutVectorField.endObject();
XContentParser contentParserWithoutVectorField = createParser(builderWithoutVectorField);
contentParserWithoutVectorField.nextToken();
IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> KNNQueryBuilder.fromXContent(contentParserWithoutVectorField)
);
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));

// Test empty vector field
List<Object> emptyQueryVector = new ArrayList<>();
XContentBuilder builderWithEmptyVector = XContentFactory.jsonBuilder();
builderWithEmptyVector.startObject();
builderWithEmptyVector.startObject(FIELD_NAME);
builderWithEmptyVector.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), emptyQueryVector);
builderWithEmptyVector.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builderWithEmptyVector.endObject();
builderWithEmptyVector.endObject();
XContentParser contentParserWithEmptyVector = createParser(builderWithEmptyVector);
contentParserWithEmptyVector.nextToken();
exception = expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParserWithEmptyVector));
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> list = ClusterModule.getNamedXWriteables();
Expand Down
Loading