Skip to content

Commit

Permalink
Throw proper exception to invalid k-NN query
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jan 8, 2024
1 parent 271df52 commit a48c909
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)
* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367)
* Throw proper exception to invalid k-NN query [#1380](https://github.com/opensearch-project/k-NN/pull/1380)
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,14 @@ public static void initialize(ModelDao modelDao) {
}

private static float[] ObjectsToFloats(List<Object> objs) {
if (Objects.isNull(objs)) {
throw new IllegalArgumentException("[" + NAME + "] requires 'vector' to be non-null");
}
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
if (!(objs.get(i) instanceof Number)) {
throw new IllegalArgumentException("[" + NAME + "] requires 'vector' to be an array of numbers");
}
vec[i] = ((Number) objs.get(i)).floatValue();
}
return vec;
Expand Down
45 changes: 45 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,51 @@ public void testKNNScriptScoreWithInvalidByteQueryVector() throws Exception {
);
}

@SneakyThrows
public void testSearchWithInvalidSearchVectorType() {
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
ingestL2FloatTestData();
Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
request.setJsonEntity(
"{\n"
+ " \"query\": {\n"
+ " \"knn\": {\n"
+ " \"test-field-vec-dt\": {\n"
+ " \"vector\": [\"a\", \"b\", \"c\", \"d\"],\n"
+ " \"k\": 4\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}"
);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertTrue(ex.getResponse().getStatusLine().getReasonPhrase().contains("Bad Request"));
assertTrue(ex.getMessage().contains(String.format(Locale.ROOT, "[knn] requires 'vector' to be an array of numbers")));
}

@SneakyThrows
public void testSearchWithMissingQueryVector() {
createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.FLOAT.getValue());
ingestL2FloatTestData();
Request request = new Request("POST", String.format("/%s/_search", INDEX_NAME));
request.setJsonEntity(
"{\n"
+ " \"query\": {\n"
+ " \"knn\": {\n"
+ " \"test-field-vec-dt\": {\n"
+ " \"k\": 4\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}"
);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertTrue(ex.getResponse().getStatusLine().getReasonPhrase().contains("Bad Request"));
assertTrue(ex.getMessage().contains(String.format(Locale.ROOT, "[knn] requires 'vector' to be non-null")));
}

@SneakyThrows
private void ingestL2ByteTestData() {
Byte[] b1 = { 6, 6 };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,43 @@ public void testFromXcontent_WithFilter() throws Exception {
actualBuilder.equals(knnQueryBuilder);
}

public void testFromXContent_invalidQueryVectorType() throws Exception {
final ClusterService clusterService = mockClusterService(Version.CURRENT);

final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

String[] invalidTypeQueryVector = { "a", "b", "c", "d" };

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject(FIELD_NAME);
builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), invalidTypeQueryVector);
builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builder.endObject();
builder.endObject();
XContentParser contentParser = createParser(builder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParser));
}

public void testFromXContent_missingQueryVector() throws Exception {
final ClusterService clusterService = mockClusterService(Version.CURRENT);

final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject(FIELD_NAME);
builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builder.endObject();
builder.endObject();
XContentParser contentParser = createParser(builder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParser));
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> list = ClusterModule.getNamedXWriteables();
Expand Down

0 comments on commit a48c909

Please sign in to comment.