diff --git a/docs/reference/query-languages/esql/images/functions/knn.svg b/docs/reference/query-languages/esql/images/functions/knn.svg index 75a104a7cdcfa..6e20dbc217206 100644 --- a/docs/reference/query-languages/esql/images/functions/knn.svg +++ b/docs/reference/query-languages/esql/images/functions/knn.svg @@ -1 +1 @@ -KNN(field,query,options) \ No newline at end of file +KNN(field,query,k,options) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json index 48d3e582eec58..f7350e39abcdf 100644 --- a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json +++ b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json @@ -5,8 +5,8 @@ "description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.", "signatures" : [ ], "examples" : [ - "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc", - "from colors metadata _score\n| where knn(rgb_vector, [0,255,255], {\"k\": 4})\n| sort _score desc" + "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0], 10)\n| sort _score desc, color asc", + "from colors metadata _score\n| where knn(rgb_vector, [0,255,255], 4)\n| sort _score desc, color asc" ], "preview" : true, "snapshot_only" : true diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md index 45d1f294ea0a8..c7af797488ba4 100644 --- a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md +++ b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md @@ -5,6 +5,6 @@ Finds the k nearest vectors to a query vector, as measured by a similarity metri ```esql from colors metadata _score -| where knn(rgb_vector, [0, 120, 0]) -| sort _score desc +| where knn(rgb_vector, [0, 120, 0], 10) +| sort _score desc, color asc ``` diff --git a/muted-tests.yml b/muted-tests.yml index be3866d777e1d..c01ccbd040ccf 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -511,9 +511,6 @@ tests: - class: org.elasticsearch.entitlement.runtime.policy.FileAccessTreeTests method: testWindowsAbsolutPathAccess issue: https://github.com/elastic/elasticsearch/issues/129168 -- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT - method: test {knn-function.KnnSearchWithKOption ASYNC} - issue: https://github.com/elastic/elasticsearch/issues/129447 - class: org.elasticsearch.xpack.ml.integration.ClassificationIT method: testWithDatastreams issue: https://github.com/elastic/elasticsearch/issues/129457 @@ -535,9 +532,6 @@ tests: - class: org.elasticsearch.xpack.security.PermissionsIT method: testWhenUserLimitedByOnlyAliasOfIndexCanWriteToIndexWhichWasRolledoverByILMPolicy issue: https://github.com/elastic/elasticsearch/issues/129481 -- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT - method: test {knn-function.KnnSearchWithKOption SYNC} - issue: https://github.com/elastic/elasticsearch/issues/129512 - class: org.elasticsearch.xpack.logsdb.qa.StandardVersusStandardReindexedIntoLogsDbChallengeRestIT method: testMatchAllQuery issue: https://github.com/elastic/elasticsearch/issues/129527 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 917a0db62da28..427eae224aefe 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -306,6 +306,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00); public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00); public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00); + public static final TransportVersion ESQL_KNN_K_PARAM_MANDATORY = def(9_105_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec index ac6c16f35de01..a0b9558ac0b75 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -3,11 +3,11 @@ # top-n query at the shard level knnSearch -required_capability: knn_function +required_capability: knn_function_v2 // tag::knn-function[] from colors metadata _score -| where knn(rgb_vector, [0, 120, 0]) +| where knn(rgb_vector, [0, 120, 0], 10) | sort _score desc, color asc // end::knn-function[] | keep color, rgb_vector @@ -29,31 +29,12 @@ chartreuse | [127.0, 255.0, 0.0] // end::knn-function-result[] ; -knnSearchWithKOption -required_capability: knn_function - -// tag::knn-function-options[] -from colors metadata _score -| where knn(rgb_vector, [0,255,255], {"k": 4}) -| sort _score desc, color asc -// end::knn-function-options[] -| keep color, rgb_vector -| limit 4 -; - -color:text | rgb_vector:dense_vector -cyan | [0.0, 255.0, 255.0] -turquoise | [64.0, 224.0, 208.0] -aqua marine | [127.0, 255.0, 212.0] -teal | [0.0, 128.0, 128.0] -; - # https://github.com/elastic/elasticsearch/issues/129550 knnSearchWithSimilarityOption-Ignore -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where knn(rgb_vector, [255,192,203], {"k": 140, "similarity": 40}) +| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40}) | sort _score desc, color asc | keep color, rgb_vector ; @@ -63,14 +44,13 @@ pink | [255.0, 192.0, 203.0] peach puff | [255.0, 218.0, 185.0] bisque | [255.0, 228.0, 196.0] wheat | [245.0, 222.0, 179.0] - ; knnHybridSearch -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where match(color, "blue") or knn(rgb_vector, [65,105,225], {"k": 140}) +| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140) | where primary == true | sort _score desc, color asc | keep color, rgb_vector @@ -90,10 +70,10 @@ yellow | [255.0, 255.0, 0.0] ; knnWithMultipleFunctions -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where knn(rgb_vector, [128,128,0], {"k": 140}) and match(color, "olive") +| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive") | sort _score desc, color asc | keep color, rgb_vector ; @@ -103,11 +83,11 @@ olive | [128.0, 128.0, 0.0] ; knnAfterKeep -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score | keep rgb_vector, color, _score -| where knn(rgb_vector, [128,255,0], {"k": 140}) +| where knn(rgb_vector, [128,255,0], 140) | sort _score desc, color asc | keep rgb_vector | limit 5 @@ -122,11 +102,11 @@ rgb_vector:dense_vector ; knnAfterDrop -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score | drop primary -| where knn(rgb_vector, [128,250,0], {"k": 140}) +| where knn(rgb_vector, [128,250,0], 140) | sort _score desc, color asc | keep color, rgb_vector | limit 5 @@ -141,11 +121,11 @@ lime | [0.0, 255.0, 0.0] ; knnAfterEval -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], {"k": 140}) +| where knn(rgb_vector, [128,128,0], 140) | sort _score desc, color asc | keep color, composed_name | limit 5 @@ -160,11 +140,11 @@ golden rod | true ; knnWithConjunction -required_capability: knn_function +required_capability: knn_function_v2 # TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score -| where knn(rgb_vector, [255,255,238], {"k": 140}) and hex_code like "#FFF*" +| where knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*" | sort _score desc, color asc | keep color, hex_code, rgb_vector | limit 10 @@ -181,11 +161,11 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0] ; knnWithDisjunctionAndFiltersConjunction -required_capability: knn_function +required_capability: knn_function_v2 # TODO We need kNN prefiltering here so we get more candidates that pass the filter from colors metadata _score -| where (knn(rgb_vector, [0,255,255], {"k": 140}) or knn(rgb_vector, [128, 0, 255], {"k": 140})) and primary == true +| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 140)) and primary == true | keep color, rgb_vector, _score | sort _score desc, color asc | drop _score @@ -205,11 +185,11 @@ yellow | [255.0, 255.0, 0.0] ; knnWithNonPushableConjunction -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score | eval composed_name = locate(color, " ") > 0 -| where knn(rgb_vector, [128,128,0], {"k": 140}) and composed_name == false +| where knn(rgb_vector, [128,128,0], 140) and composed_name == false | sort _score desc, color asc | keep color, composed_name | limit 10 @@ -230,10 +210,10 @@ maroon | false # https://github.com/elastic/elasticsearch/issues/129550 testKnnWithNonPushableDisjunctions-Ignore -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 30}) or length(color) > 10 +| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10 | sort _score desc, color asc | keep color ; @@ -247,10 +227,10 @@ papaya whip # https://github.com/elastic/elasticsearch/issues/129550 testKnnWithNonPushableDisjunctionsOnComplexExpressions-Ignore -required_capability: knn_function +required_capability: knn_function_v2 from colors metadata _score -| where (knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], {"k": 140, "similarity": 60}) and primary == false) +| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false) | sort _score desc, color asc | keep color, primary ; @@ -262,11 +242,11 @@ indigo | false ; testKnnInStatsNonPushable -required_capability: knn_function +required_capability: knn_function_v2 from colors | where length(color) < 10 -| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) +| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) ; c: long @@ -274,12 +254,12 @@ c: long ; testKnnInStatsWithGrouping -required_capability: knn_function +required_capability: knn_function_v2 required_capability: full_text_functions_in_stats_where from colors | where length(color) < 10 -| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) by primary +| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) by primary ; c: long | primary: boolean diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java index a262943909938..11f9bd6c5aeb5 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -39,7 +39,7 @@ public void testKnnDefaults() { var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s) + | WHERE knn(vector, %s, 10) | KEEP id, floats, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); @@ -73,7 +73,7 @@ public void testKnnOptions() { var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, {"k": 5}) + | WHERE knn(vector, %s, 5) | KEEP id, floats, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); @@ -94,7 +94,7 @@ public void testKnnNonPushedDown() { // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query var query = String.format(Locale.ROOT, """ FROM test METADATA _score - | WHERE knn(vector, %s, {"k": 5}) OR id > 10 + | WHERE knn(vector, %s, 5) OR id > 10 | KEEP id, floats, _score, vector | SORT _score DESC """, Arrays.toString(queryVector)); @@ -111,7 +111,7 @@ public void testKnnNonPushedDown() { @Before public void setup() throws IOException { - assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); var indexName = "test"; var client = client().admin().indices(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index d59ecdaf02e87..85cf21964cefc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1195,7 +1195,7 @@ public enum Cap { /** * Support knn function */ - KNN_FUNCTION(Build.current().isSnapshot()), + KNN_FUNCTION_V2(Build.current().isSnapshot()), LIKE_WITH_LIST_OF_PATTERNS, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 901f364a60041..a3f6d3a089d49 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -259,7 +259,7 @@ private static List fullText() { } private static List vector() { - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { return List.of(Knn.ENTRY); } return List.of(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index b115eb5c33c6e..7b80903faf22f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -487,7 +487,7 @@ private static FunctionDefinition[][] snapshotFunctions() { def(LastOverTime.class, LastOverTime::withUnresolvedTimestamp, "last_over_time"), def(FirstOverTime.class, FirstOverTime::withUnresolvedTimestamp, "first_over_time"), def(Term.class, bi(Term::new), "term"), - def(Knn.class, tri(Knn::new), "knn") } }; + def(Knn.class, Knn::new, "knn") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java index ecce0b069693d..08d6d14901c4c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -41,6 +41,7 @@ import java.util.Objects; import static java.util.Map.entry; +import static org.elasticsearch.TransportVersions.ESQL_KNN_K_PARAM_MANDATORY; import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; @@ -48,6 +49,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; @@ -62,10 +64,10 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom); private final Expression field; + private final Expression k; private final Expression options; public static final Map ALLOWED_OPTIONS = Map.ofEntries( - entry(K_FIELD.getPreferredName(), INTEGER), entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER), entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT), entry(BOOST_FIELD.getPreferredName(), FLOAT), @@ -90,6 +92,13 @@ public Knn( type = { "dense_vector" }, description = "Vector value to find top nearest neighbours for." ) Expression query, + @Param( + name = "k", + type = { "integer" }, + description = "The number of nearest neighbors to return from each shard. " + + "Elasticsearch collects k results from each shard, then merges them to find the global top results. " + + "This value must be less than or equal to num_candidates." + ) Expression k, @MapParam( name = "options", params = { @@ -100,14 +109,6 @@ public Knn( description = "Floating point number used to decrease or increase the relevance scores of the query." + "Defaults to 1.0." ), - @MapParam.MapParamEntry( - name = "k", - type = "integer", - valueHint = { "10" }, - description = "The number of nearest neighbors to return from each shard. " - + "Elasticsearch collects k results from each shard, then merges them to find the global top results. " - + "This value must be less than or equal to num_candidates. Defaults to 10." - ), @MapParam.MapParamEntry( name = "num_candidates", type = "integer", @@ -136,12 +137,13 @@ public Knn( optional = true ) Expression options ) { - this(source, field, query, options, null); + this(source, field, query, k, options, null); } - private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) { - super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder); + private Knn(Source source, Expression field, Expression query, Expression k, Expression options, QueryBuilder queryBuilder) { + super(source, query, options == null ? List.of(field, query, k) : List.of(field, query, k, options), queryBuilder); this.field = field; + this.k = k; this.options = options; } @@ -149,6 +151,10 @@ public Expression field() { return field; } + public Expression k() { + return k; + } + public Expression options() { return options; } @@ -160,7 +166,7 @@ public DataType dataType() { @Override protected TypeResolution resolveParams() { - return resolveField().and(resolveQuery()).and(resolveOptions()); + return resolveField().and(resolveQuery()).and(resolveK()).and(resolveOptions()); } private TypeResolution resolveField() { @@ -173,14 +179,19 @@ private TypeResolution resolveQuery() { ); } + private TypeResolution resolveK() { + return isType(k(), dt -> dt == INTEGER, sourceText(), THIRD, "integer").and(isFoldable(k(), sourceText(), THIRD)) + .and(isNotNull(k(), sourceText(), THIRD)); + } + private TypeResolution resolveOptions() { if (options() != null) { - TypeResolution resolution = isNotNull(options(), sourceText(), THIRD); + TypeResolution resolution = isNotNull(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH); if (resolution.unresolved()) { return resolution; } // MapExpression does not have a DataType associated with it - resolution = isMapExpression(options(), sourceText(), THIRD); + resolution = isMapExpression(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH); if (resolution.unresolved()) { return resolution; } @@ -200,7 +211,7 @@ private Map knnQueryOptions() throws InvalidArgumentException { } Map matchOptions = new HashMap<>(); - populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS); + populateOptionsMap((MapExpression) options(), matchOptions, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS); return matchOptions; } @@ -216,22 +227,24 @@ protected Query translate(TranslatorHandler handler) { for (int i = 0; i < queryFolded.size(); i++) { queryAsFloats[i] = queryFolded.get(i).floatValue(); } + int kValue = ((Number) k().fold(FoldContext.small())).intValue(); + + Map opts = queryOptions(); + opts.put(K_FIELD.getPreferredName(), kValue); - return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions()); + return new KnnQuery(source(), fieldName, queryAsFloats, opts); } @Override public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { - return new Knn(source(), field(), query(), options(), queryBuilder); + return new Knn(source(), field(), query(), k(), options(), queryBuilder); } private Map queryOptions() throws InvalidArgumentException { - if (options() == null) { - return Map.of(); - } - Map options = new HashMap<>(); - populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS); + if (options() != null) { + populateOptionsMap((MapExpression) options(), options, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS); + } return options; } @@ -241,14 +254,15 @@ public Expression replaceChildren(List newChildren) { source(), newChildren.get(0), newChildren.get(1), - newChildren.size() > 2 ? newChildren.get(2) : null, + newChildren.get(2), + newChildren.size() > 3 ? newChildren.get(3) : null, queryBuilder() ); } @Override protected NodeInfo info() { - return NodeInfo.create(this, Knn::new, field(), query(), options()); + return NodeInfo.create(this, Knn::new, field(), query(), k(), options()); } @Override @@ -261,8 +275,11 @@ private static Knn readFrom(StreamInput in) throws IOException { Expression field = in.readNamedWriteable(Expression.class); Expression query = in.readNamedWriteable(Expression.class); QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); - - return new Knn(source, field, query, null, queryBuilder); + Expression k = null; + if (in.getTransportVersion().onOrAfter(ESQL_KNN_K_PARAM_MANDATORY)) { + k = in.readNamedWriteable(Expression.class); + } + return new Knn(source, field, query, k, null, queryBuilder); } @Override @@ -271,6 +288,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(field()); out.writeNamedWriteable(query()); out.writeOptionalNamedWriteable(queryBuilder()); + if (out.getTransportVersion().onOrAfter(ESQL_KNN_K_PARAM_MANDATORY)) { + out.writeNamedWriteable(k()); + } } @Override @@ -281,12 +301,13 @@ public boolean equals(Object o) { Knn knn = (Knn) o; return Objects.equals(field(), knn.field()) && Objects.equals(query(), knn.query()) + && Objects.equals(k(), knn.k()) && Objects.equals(queryBuilder(), knn.queryBuilder()); } @Override public int hashCode() { - return Objects.hash(field(), query(), queryBuilder()); + return Objects.hash(field(), query(), k(), queryBuilder()); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 6935a1efac2ae..9062bdef62d76 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -298,7 +298,7 @@ public final void test() throws Throwable { ); assumeFalse( "can't use KNN function in csv tests", - testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION.capabilityName()) + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V2.capabilityName()) ); assumeFalse( "lookup join disabled for csv tests", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 0c7cdf4855196..f718febfa7db9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -2376,7 +2376,7 @@ public void testDenseVectorImplicitCasting() { Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors")); var plan = analyze(""" - from test | where knn(vector, [0.342, 0.164, 0.234]) + from test | where knn(vector, [0.342, 0.164, 0.234], 10) """, "mapping-dense_vector.json"); var limit = as(plan, Limit.class); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 14e5c0615e2b4..8936a02e599d1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1235,8 +1235,8 @@ public void testFieldBasedFullTextFunctions() throws Exception { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)"); } } @@ -1368,8 +1368,8 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function"); } } @@ -1407,8 +1407,8 @@ public void testFullTextFunctionsDisjunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)"); } } @@ -1472,8 +1472,8 @@ public void testFullTextFunctionsWithNonBooleanFunctions() { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function"); } } @@ -1543,8 +1543,8 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2])"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)"); } } @@ -2071,8 +2071,8 @@ public void testFullTextFunctionOptions() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})"); } } @@ -2159,9 +2159,10 @@ public void testFullTextFunctionsNullArgs() throws Exception { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first"); - checkFullTextFunctionNullArgs("knn(vector, null)", "second"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first"); + checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second"); + checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third"); } } @@ -2172,24 +2173,25 @@ private void checkFullTextFunctionNullArgs(String functionInvocation, String arg ); } - public void testFullTextFunctionsConstantQuery() throws Exception { - checkFullTextFunctionsConstantQuery("match(title, category)", "second"); - checkFullTextFunctionsConstantQuery("qstr(title)", ""); - checkFullTextFunctionsConstantQuery("kql(title)", ""); - checkFullTextFunctionsConstantQuery("match_phrase(title, tags)", "second"); + public void testFullTextFunctionsConstantArg() throws Exception { + checkFullTextFunctionsConstantArg("match(title, category)", "second"); + checkFullTextFunctionsConstantArg("qstr(title)", ""); + checkFullTextFunctionsConstantArg("kql(title)", ""); + checkFullTextFunctionsConstantArg("match_phrase(title, tags)", "second"); if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { - checkFullTextFunctionsConstantQuery("multi_match(category, body)", "first"); - checkFullTextFunctionsConstantQuery("multi_match(concat(title, \"world\"), title)", "first"); + checkFullTextFunctionsConstantArg("multi_match(category, body)", "first"); + checkFullTextFunctionsConstantArg("multi_match(concat(title, \"world\"), title)", "first"); } if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { - checkFullTextFunctionsConstantQuery("term(title, tags)", "second"); + checkFullTextFunctionsConstantArg("term(title, tags)", "second"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionsConstantQuery("knn(vector, vector)", "second"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second"); + checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third"); } } - private void checkFullTextFunctionsConstantQuery(String functionInvocation, String argOrdinal) throws Exception { + private void checkFullTextFunctionsConstantArg(String functionInvocation, String argOrdinal) throws Exception { assertThat( error("from test | where " + functionInvocation, fullTextAnalyzer), containsString(argOrdinal + " argument of [" + functionInvocation + "] must be a constant") @@ -2214,8 +2216,8 @@ public void testFullTextFunctionsInStats() { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)"); } - if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { - checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])"); + if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) { + checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)"); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java index 76db793c4e772..4a5708b398b18 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.core.expression.MapExpression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.vector.Knn; @@ -27,6 +28,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.function.Supplier; import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize; @@ -49,19 +51,33 @@ public static Iterable parameters() { @Before public void checkCapability() { - assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); } private static List testCaseSuppliers() { List suppliers = new ArrayList<>(); suppliers.add( - TestCaseSupplier.testCaseSupplier( - new TestCaseSupplier.TypedDataSupplier("dense_vector field", KnnTests::randomDenseVector, DENSE_VECTOR), - new TestCaseSupplier.TypedDataSupplier("query", KnnTests::randomDenseVector, DENSE_VECTOR, true), - (d1, d2) -> equalTo("string"), - BOOLEAN, - (o1, o2) -> true + new TestCaseSupplier( + List.of(DENSE_VECTOR, DENSE_VECTOR, DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData( + new FieldAttribute( + Source.EMPTY, + randomIdentifier(), + new EsField(randomIdentifier(), DENSE_VECTOR, Map.of(), false) + ), + DENSE_VECTOR, + "dense_vector field" + ), + new TestCaseSupplier.TypedData(randomDenseVector(), DENSE_VECTOR, "query"), + new TestCaseSupplier.TypedData(randomIntBetween(1, 1000), DataType.INTEGER, "k") + ), + equalTo("KnnEvaluator" + KnnTests.class.getSimpleName()), + BOOLEAN, + equalTo(true) + ) ) ); @@ -104,7 +120,7 @@ private static List addFunctionNamedParams(List args) { - Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null); + Knn knn = new Knn(source, args.get(0), args.get(1), args.get(2), args.size() > 3 ? args.get(3) : null); // We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and // thus test the serialization methods. But we can only do this if the parameters make sense . if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index b7d1243759493..66b797afa426c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1363,12 +1363,12 @@ public void testMultiMatchOptionsPushDown() { public void testKnnOptionsPushDown() { assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + assumeTrue("knn capability not available", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()); String query = """ from test - | where KNN(dense_vector, [0.1, 0.2, 0.3], - { "k": 5, "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 }) + | where KNN(dense_vector, [0.1, 0.2, 0.3], 5, + { "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 }) """; var analyzer = makeAnalyzer("mapping-all-types.json"); var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);