diff --git a/docs/reference/query-languages/esql/images/functions/knn.svg b/docs/reference/query-languages/esql/images/functions/knn.svg
index 75a104a7cdcfa..6e20dbc217206 100644
--- a/docs/reference/query-languages/esql/images/functions/knn.svg
+++ b/docs/reference/query-languages/esql/images/functions/knn.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
\ No newline at end of file
diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json
index 48d3e582eec58..f7350e39abcdf 100644
--- a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json
+++ b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json
@@ -5,8 +5,8 @@
"description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.",
"signatures" : [ ],
"examples" : [
- "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc",
- "from colors metadata _score\n| where knn(rgb_vector, [0,255,255], {\"k\": 4})\n| sort _score desc"
+ "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0], 10)\n| sort _score desc, color asc",
+ "from colors metadata _score\n| where knn(rgb_vector, [0,255,255], 4)\n| sort _score desc, color asc"
],
"preview" : true,
"snapshot_only" : true
diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md
index 45d1f294ea0a8..c7af797488ba4 100644
--- a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md
+++ b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md
@@ -5,6 +5,6 @@ Finds the k nearest vectors to a query vector, as measured by a similarity metri
```esql
from colors metadata _score
-| where knn(rgb_vector, [0, 120, 0])
-| sort _score desc
+| where knn(rgb_vector, [0, 120, 0], 10)
+| sort _score desc, color asc
```
diff --git a/muted-tests.yml b/muted-tests.yml
index be3866d777e1d..c01ccbd040ccf 100644
--- a/muted-tests.yml
+++ b/muted-tests.yml
@@ -511,9 +511,6 @@ tests:
- class: org.elasticsearch.entitlement.runtime.policy.FileAccessTreeTests
method: testWindowsAbsolutPathAccess
issue: https://github.com/elastic/elasticsearch/issues/129168
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
- method: test {knn-function.KnnSearchWithKOption ASYNC}
- issue: https://github.com/elastic/elasticsearch/issues/129447
- class: org.elasticsearch.xpack.ml.integration.ClassificationIT
method: testWithDatastreams
issue: https://github.com/elastic/elasticsearch/issues/129457
@@ -535,9 +532,6 @@ tests:
- class: org.elasticsearch.xpack.security.PermissionsIT
method: testWhenUserLimitedByOnlyAliasOfIndexCanWriteToIndexWhichWasRolledoverByILMPolicy
issue: https://github.com/elastic/elasticsearch/issues/129481
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
- method: test {knn-function.KnnSearchWithKOption SYNC}
- issue: https://github.com/elastic/elasticsearch/issues/129512
- class: org.elasticsearch.xpack.logsdb.qa.StandardVersusStandardReindexedIntoLogsDbChallengeRestIT
method: testMatchAllQuery
issue: https://github.com/elastic/elasticsearch/issues/129527
diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java
index 917a0db62da28..427eae224aefe 100644
--- a/server/src/main/java/org/elasticsearch/TransportVersions.java
+++ b/server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -306,6 +306,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00);
public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00);
+ public static final TransportVersion ESQL_KNN_K_PARAM_MANDATORY = def(9_105_0_00);
/*
* STOP! READ THIS FIRST! No, really,
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec
index ac6c16f35de01..a0b9558ac0b75 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec
@@ -3,11 +3,11 @@
# top-n query at the shard level
knnSearch
-required_capability: knn_function
+required_capability: knn_function_v2
// tag::knn-function[]
from colors metadata _score
-| where knn(rgb_vector, [0, 120, 0])
+| where knn(rgb_vector, [0, 120, 0], 10)
| sort _score desc, color asc
// end::knn-function[]
| keep color, rgb_vector
@@ -29,31 +29,12 @@ chartreuse | [127.0, 255.0, 0.0]
// end::knn-function-result[]
;
-knnSearchWithKOption
-required_capability: knn_function
-
-// tag::knn-function-options[]
-from colors metadata _score
-| where knn(rgb_vector, [0,255,255], {"k": 4})
-| sort _score desc, color asc
-// end::knn-function-options[]
-| keep color, rgb_vector
-| limit 4
-;
-
-color:text | rgb_vector:dense_vector
-cyan | [0.0, 255.0, 255.0]
-turquoise | [64.0, 224.0, 208.0]
-aqua marine | [127.0, 255.0, 212.0]
-teal | [0.0, 128.0, 128.0]
-;
-
# https://github.com/elastic/elasticsearch/issues/129550
knnSearchWithSimilarityOption-Ignore
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
-| where knn(rgb_vector, [255,192,203], {"k": 140, "similarity": 40})
+| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
| sort _score desc, color asc
| keep color, rgb_vector
;
@@ -63,14 +44,13 @@ pink | [255.0, 192.0, 203.0]
peach puff | [255.0, 218.0, 185.0]
bisque | [255.0, 228.0, 196.0]
wheat | [245.0, 222.0, 179.0]
-
;
knnHybridSearch
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
-| where match(color, "blue") or knn(rgb_vector, [65,105,225], {"k": 140})
+| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140)
| where primary == true
| sort _score desc, color asc
| keep color, rgb_vector
@@ -90,10 +70,10 @@ yellow | [255.0, 255.0, 0.0]
;
knnWithMultipleFunctions
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
-| where knn(rgb_vector, [128,128,0], {"k": 140}) and match(color, "olive")
+| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive")
| sort _score desc, color asc
| keep color, rgb_vector
;
@@ -103,11 +83,11 @@ olive | [128.0, 128.0, 0.0]
;
knnAfterKeep
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
| keep rgb_vector, color, _score
-| where knn(rgb_vector, [128,255,0], {"k": 140})
+| where knn(rgb_vector, [128,255,0], 140)
| sort _score desc, color asc
| keep rgb_vector
| limit 5
@@ -122,11 +102,11 @@ rgb_vector:dense_vector
;
knnAfterDrop
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
| drop primary
-| where knn(rgb_vector, [128,250,0], {"k": 140})
+| where knn(rgb_vector, [128,250,0], 140)
| sort _score desc, color asc
| keep color, rgb_vector
| limit 5
@@ -141,11 +121,11 @@ lime | [0.0, 255.0, 0.0]
;
knnAfterEval
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
| eval composed_name = locate(color, " ") > 0
-| where knn(rgb_vector, [128,128,0], {"k": 140})
+| where knn(rgb_vector, [128,128,0], 140)
| sort _score desc, color asc
| keep color, composed_name
| limit 5
@@ -160,11 +140,11 @@ golden rod | true
;
knnWithConjunction
-required_capability: knn_function
+required_capability: knn_function_v2
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
from colors metadata _score
-| where knn(rgb_vector, [255,255,238], {"k": 140}) and hex_code like "#FFF*"
+| where knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*"
| sort _score desc, color asc
| keep color, hex_code, rgb_vector
| limit 10
@@ -181,11 +161,11 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0]
;
knnWithDisjunctionAndFiltersConjunction
-required_capability: knn_function
+required_capability: knn_function_v2
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
from colors metadata _score
-| where (knn(rgb_vector, [0,255,255], {"k": 140}) or knn(rgb_vector, [128, 0, 255], {"k": 140})) and primary == true
+| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 140)) and primary == true
| keep color, rgb_vector, _score
| sort _score desc, color asc
| drop _score
@@ -205,11 +185,11 @@ yellow | [255.0, 255.0, 0.0]
;
knnWithNonPushableConjunction
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
| eval composed_name = locate(color, " ") > 0
-| where knn(rgb_vector, [128,128,0], {"k": 140}) and composed_name == false
+| where knn(rgb_vector, [128,128,0], 140) and composed_name == false
| sort _score desc, color asc
| keep color, composed_name
| limit 10
@@ -230,10 +210,10 @@ maroon | false
# https://github.com/elastic/elasticsearch/issues/129550
testKnnWithNonPushableDisjunctions-Ignore
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
-| where knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 30}) or length(color) > 10
+| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10
| sort _score desc, color asc
| keep color
;
@@ -247,10 +227,10 @@ papaya whip
# https://github.com/elastic/elasticsearch/issues/129550
testKnnWithNonPushableDisjunctionsOnComplexExpressions-Ignore
-required_capability: knn_function
+required_capability: knn_function_v2
from colors metadata _score
-| where (knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], {"k": 140, "similarity": 60}) and primary == false)
+| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false)
| sort _score desc, color asc
| keep color, primary
;
@@ -262,11 +242,11 @@ indigo | false
;
testKnnInStatsNonPushable
-required_capability: knn_function
+required_capability: knn_function_v2
from colors
| where length(color) < 10
-| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140})
+| stats c = count(*) where knn(rgb_vector, [128,128,255], 140)
;
c: long
@@ -274,12 +254,12 @@ c: long
;
testKnnInStatsWithGrouping
-required_capability: knn_function
+required_capability: knn_function_v2
required_capability: full_text_functions_in_stats_where
from colors
| where length(color) < 10
-| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) by primary
+| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) by primary
;
c: long | primary: boolean
diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java
index a262943909938..11f9bd6c5aeb5 100644
--- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java
+++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java
@@ -39,7 +39,7 @@ public void testKnnDefaults() {
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
- | WHERE knn(vector, %s)
+ | WHERE knn(vector, %s, 10)
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
@@ -73,7 +73,7 @@ public void testKnnOptions() {
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
- | WHERE knn(vector, %s, {"k": 5})
+ | WHERE knn(vector, %s, 5)
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
@@ -94,7 +94,7 @@ public void testKnnNonPushedDown() {
// TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
- | WHERE knn(vector, %s, {"k": 5}) OR id > 10
+ | WHERE knn(vector, %s, 5) OR id > 10
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
@@ -111,7 +111,7 @@ public void testKnnNonPushedDown() {
@Before
public void setup() throws IOException {
- assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled());
+ assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
var indexName = "test";
var client = client().admin().indices();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index d59ecdaf02e87..85cf21964cefc 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -1195,7 +1195,7 @@ public enum Cap {
/**
* Support knn function
*/
- KNN_FUNCTION(Build.current().isSnapshot()),
+ KNN_FUNCTION_V2(Build.current().isSnapshot()),
LIKE_WITH_LIST_OF_PATTERNS,
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
index 901f364a60041..a3f6d3a089d49 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
@@ -259,7 +259,7 @@ private static List fullText() {
}
private static List vector() {
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
return List.of(Knn.ENTRY);
}
return List.of();
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
index b115eb5c33c6e..7b80903faf22f 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
@@ -487,7 +487,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(LastOverTime.class, LastOverTime::withUnresolvedTimestamp, "last_over_time"),
def(FirstOverTime.class, FirstOverTime::withUnresolvedTimestamp, "first_over_time"),
def(Term.class, bi(Term::new), "term"),
- def(Knn.class, tri(Knn::new), "knn") } };
+ def(Knn.class, Knn::new, "knn") } };
}
public EsqlFunctionRegistry snapshotRegistry() {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
index ecce0b069693d..08d6d14901c4c 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java
@@ -41,6 +41,7 @@
import java.util.Objects;
import static java.util.Map.entry;
+import static org.elasticsearch.TransportVersions.ESQL_KNN_K_PARAM_MANDATORY;
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
@@ -48,6 +49,7 @@
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
@@ -62,10 +64,10 @@ public class Knn extends FullTextFunction implements OptionalArgument, VectorFun
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
private final Expression field;
+ private final Expression k;
private final Expression options;
public static final Map ALLOWED_OPTIONS = Map.ofEntries(
- entry(K_FIELD.getPreferredName(), INTEGER),
entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
entry(BOOST_FIELD.getPreferredName(), FLOAT),
@@ -90,6 +92,13 @@ public Knn(
type = { "dense_vector" },
description = "Vector value to find top nearest neighbours for."
) Expression query,
+ @Param(
+ name = "k",
+ type = { "integer" },
+ description = "The number of nearest neighbors to return from each shard. "
+ + "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
+ + "This value must be less than or equal to num_candidates."
+ ) Expression k,
@MapParam(
name = "options",
params = {
@@ -100,14 +109,6 @@ public Knn(
description = "Floating point number used to decrease or increase the relevance scores of the query."
+ "Defaults to 1.0."
),
- @MapParam.MapParamEntry(
- name = "k",
- type = "integer",
- valueHint = { "10" },
- description = "The number of nearest neighbors to return from each shard. "
- + "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
- + "This value must be less than or equal to num_candidates. Defaults to 10."
- ),
@MapParam.MapParamEntry(
name = "num_candidates",
type = "integer",
@@ -136,12 +137,13 @@ public Knn(
optional = true
) Expression options
) {
- this(source, field, query, options, null);
+ this(source, field, query, k, options, null);
}
- private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) {
- super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder);
+ private Knn(Source source, Expression field, Expression query, Expression k, Expression options, QueryBuilder queryBuilder) {
+ super(source, query, options == null ? List.of(field, query, k) : List.of(field, query, k, options), queryBuilder);
this.field = field;
+ this.k = k;
this.options = options;
}
@@ -149,6 +151,10 @@ public Expression field() {
return field;
}
+ public Expression k() {
+ return k;
+ }
+
public Expression options() {
return options;
}
@@ -160,7 +166,7 @@ public DataType dataType() {
@Override
protected TypeResolution resolveParams() {
- return resolveField().and(resolveQuery()).and(resolveOptions());
+ return resolveField().and(resolveQuery()).and(resolveK()).and(resolveOptions());
}
private TypeResolution resolveField() {
@@ -173,14 +179,19 @@ private TypeResolution resolveQuery() {
);
}
+ private TypeResolution resolveK() {
+ return isType(k(), dt -> dt == INTEGER, sourceText(), THIRD, "integer").and(isFoldable(k(), sourceText(), THIRD))
+ .and(isNotNull(k(), sourceText(), THIRD));
+ }
+
private TypeResolution resolveOptions() {
if (options() != null) {
- TypeResolution resolution = isNotNull(options(), sourceText(), THIRD);
+ TypeResolution resolution = isNotNull(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH);
if (resolution.unresolved()) {
return resolution;
}
// MapExpression does not have a DataType associated with it
- resolution = isMapExpression(options(), sourceText(), THIRD);
+ resolution = isMapExpression(options(), sourceText(), TypeResolutions.ParamOrdinal.FOURTH);
if (resolution.unresolved()) {
return resolution;
}
@@ -200,7 +211,7 @@ private Map knnQueryOptions() throws InvalidArgumentException {
}
Map matchOptions = new HashMap<>();
- populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS);
+ populateOptionsMap((MapExpression) options(), matchOptions, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS);
return matchOptions;
}
@@ -216,22 +227,24 @@ protected Query translate(TranslatorHandler handler) {
for (int i = 0; i < queryFolded.size(); i++) {
queryAsFloats[i] = queryFolded.get(i).floatValue();
}
+ int kValue = ((Number) k().fold(FoldContext.small())).intValue();
+
+ Map opts = queryOptions();
+ opts.put(K_FIELD.getPreferredName(), kValue);
- return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions());
+ return new KnnQuery(source(), fieldName, queryAsFloats, opts);
}
@Override
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
- return new Knn(source(), field(), query(), options(), queryBuilder);
+ return new Knn(source(), field(), query(), k(), options(), queryBuilder);
}
private Map queryOptions() throws InvalidArgumentException {
- if (options() == null) {
- return Map.of();
- }
-
Map options = new HashMap<>();
- populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS);
+ if (options() != null) {
+ populateOptionsMap((MapExpression) options(), options, TypeResolutions.ParamOrdinal.FOURTH, sourceText(), ALLOWED_OPTIONS);
+ }
return options;
}
@@ -241,14 +254,15 @@ public Expression replaceChildren(List newChildren) {
source(),
newChildren.get(0),
newChildren.get(1),
- newChildren.size() > 2 ? newChildren.get(2) : null,
+ newChildren.get(2),
+ newChildren.size() > 3 ? newChildren.get(3) : null,
queryBuilder()
);
}
@Override
protected NodeInfo extends Expression> info() {
- return NodeInfo.create(this, Knn::new, field(), query(), options());
+ return NodeInfo.create(this, Knn::new, field(), query(), k(), options());
}
@Override
@@ -261,8 +275,11 @@ private static Knn readFrom(StreamInput in) throws IOException {
Expression field = in.readNamedWriteable(Expression.class);
Expression query = in.readNamedWriteable(Expression.class);
QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
-
- return new Knn(source, field, query, null, queryBuilder);
+ Expression k = null;
+ if (in.getTransportVersion().onOrAfter(ESQL_KNN_K_PARAM_MANDATORY)) {
+ k = in.readNamedWriteable(Expression.class);
+ }
+ return new Knn(source, field, query, k, null, queryBuilder);
}
@Override
@@ -271,6 +288,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(field());
out.writeNamedWriteable(query());
out.writeOptionalNamedWriteable(queryBuilder());
+ if (out.getTransportVersion().onOrAfter(ESQL_KNN_K_PARAM_MANDATORY)) {
+ out.writeNamedWriteable(k());
+ }
}
@Override
@@ -281,12 +301,13 @@ public boolean equals(Object o) {
Knn knn = (Knn) o;
return Objects.equals(field(), knn.field())
&& Objects.equals(query(), knn.query())
+ && Objects.equals(k(), knn.k())
&& Objects.equals(queryBuilder(), knn.queryBuilder());
}
@Override
public int hashCode() {
- return Objects.hash(field(), query(), queryBuilder());
+ return Objects.hash(field(), query(), k(), queryBuilder());
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
index 6935a1efac2ae..9062bdef62d76 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java
@@ -298,7 +298,7 @@ public final void test() throws Throwable {
);
assumeFalse(
"can't use KNN function in csv tests",
- testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION.capabilityName())
+ testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION_V2.capabilityName())
);
assumeFalse(
"lookup join disabled for csv tests",
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index 0c7cdf4855196..f718febfa7db9 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -2376,7 +2376,7 @@ public void testDenseVectorImplicitCasting() {
Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
var plan = analyze("""
- from test | where knn(vector, [0.342, 0.164, 0.234])
+ from test | where knn(vector, [0.342, 0.164, 0.234], 10)
""", "mapping-dense_vector.json");
var limit = as(plan, Limit.class);
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
index 14e5c0615e2b4..8936a02e599d1 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
@@ -1235,8 +1235,8 @@ public void testFieldBasedFullTextFunctions() throws Exception {
checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function");
checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3], 10)");
}
}
@@ -1368,8 +1368,8 @@ public void testFullTextFunctionsOnlyAllowedInWhere() throws Exception {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2], 10)", "function");
}
}
@@ -1407,8 +1407,8 @@ public void testFullTextFunctionsDisjunctions() {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3], 10)");
}
}
@@ -1472,8 +1472,8 @@ public void testFullTextFunctionsWithNonBooleanFunctions() {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3], 10)", "function");
}
}
@@ -1543,8 +1543,8 @@ public void testFullTextFunctionsTargetsExistingField() throws Exception {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2])");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2], 10)");
}
}
@@ -2071,8 +2071,8 @@ public void testFullTextFunctionOptions() {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], 10, {\"%s\": %s})");
}
}
@@ -2159,9 +2159,10 @@ public void testFullTextFunctionsNullArgs() throws Exception {
checkFullTextFunctionNullArgs("term(null, \"query\")", "first");
checkFullTextFunctionNullArgs("term(title, null)", "second");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first");
- checkFullTextFunctionNullArgs("knn(vector, null)", "second");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ checkFullTextFunctionNullArgs("knn(null, [0, 1, 2], 10)", "first");
+ checkFullTextFunctionNullArgs("knn(vector, null, 10)", "second");
+ checkFullTextFunctionNullArgs("knn(vector, [0, 1, 2], null)", "third");
}
}
@@ -2172,24 +2173,25 @@ private void checkFullTextFunctionNullArgs(String functionInvocation, String arg
);
}
- public void testFullTextFunctionsConstantQuery() throws Exception {
- checkFullTextFunctionsConstantQuery("match(title, category)", "second");
- checkFullTextFunctionsConstantQuery("qstr(title)", "");
- checkFullTextFunctionsConstantQuery("kql(title)", "");
- checkFullTextFunctionsConstantQuery("match_phrase(title, tags)", "second");
+ public void testFullTextFunctionsConstantArg() throws Exception {
+ checkFullTextFunctionsConstantArg("match(title, category)", "second");
+ checkFullTextFunctionsConstantArg("qstr(title)", "");
+ checkFullTextFunctionsConstantArg("kql(title)", "");
+ checkFullTextFunctionsConstantArg("match_phrase(title, tags)", "second");
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
- checkFullTextFunctionsConstantQuery("multi_match(category, body)", "first");
- checkFullTextFunctionsConstantQuery("multi_match(concat(title, \"world\"), title)", "first");
+ checkFullTextFunctionsConstantArg("multi_match(category, body)", "first");
+ checkFullTextFunctionsConstantArg("multi_match(concat(title, \"world\"), title)", "first");
}
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
- checkFullTextFunctionsConstantQuery("term(title, tags)", "second");
+ checkFullTextFunctionsConstantArg("term(title, tags)", "second");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- checkFullTextFunctionsConstantQuery("knn(vector, vector)", "second");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ checkFullTextFunctionsConstantArg("knn(vector, vector, 10)", "second");
+ checkFullTextFunctionsConstantArg("knn(vector, [0, 1, 2], category)", "third");
}
}
- private void checkFullTextFunctionsConstantQuery(String functionInvocation, String argOrdinal) throws Exception {
+ private void checkFullTextFunctionsConstantArg(String functionInvocation, String argOrdinal) throws Exception {
assertThat(
error("from test | where " + functionInvocation, fullTextAnalyzer),
containsString(argOrdinal + " argument of [" + functionInvocation + "] must be a constant")
@@ -2214,8 +2216,8 @@ public void testFullTextFunctionsInStats() {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)");
}
- if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
- checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])");
+ if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
+ checkFullTextFunctionsInStats("knn(vector, [0, 1, 2], 10)");
}
}
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java
index 76db793c4e772..4a5708b398b18 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java
@@ -18,6 +18,7 @@
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.core.type.EsField;
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
@@ -27,6 +28,7 @@
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize;
@@ -49,19 +51,33 @@ public static Iterable