diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java index 984ed00bc..6c5eea08f 100644 --- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -50,7 +50,7 @@ public static class CreateQueryRequest { private QueryBuilder filter; private QueryShardContext context; private RescoreContext rescoreContext; - private boolean expandNested; + private Boolean expandNested; public Optional getFilter() { return Optional.ofNullable(filter); @@ -63,6 +63,10 @@ public Optional getContext() { public Optional getRescoreContext() { return Optional.ofNullable(rescoreContext); } + + public Optional getExpandNested() { + return Optional.ofNullable(expandNested); + } } /** diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 063842a7f..d2b169b2a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -109,7 +109,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { @Getter private RescoreContext rescoreContext; @Getter - private boolean expandNested; + private Boolean expandNested; /** * Constructs a new query with the given field name and vector @@ -151,7 +151,7 @@ public static class Builder { private String queryName; private float boost = DEFAULT_BOOST; private RescoreContext rescoreContext; - private boolean expandNested; + private Boolean expandNested; public Builder() {} @@ -210,7 +210,7 @@ public Builder rescoreContext(RescoreContext rescoreContext) { return this; } - public Builder expandNested(boolean expandNested) { + public Builder expandNested(Boolean expandNested) { this.expandNested = expandNested; return this; } @@ -330,7 +330,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil this.maxDistance = null; this.minScore = null; this.rescoreContext = null; - this.expandNested = false; + this.expandNested = null; } public static void initialize(ModelDao modelDao) { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 74b864f98..7bac6c126 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -51,7 +51,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { final Map methodParameters = createQueryRequest.getMethodParameters(); final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null); final KNNEngine knnEngine = createQueryRequest.getKnnEngine(); - final boolean expandNested = createQueryRequest.isExpandNested(); + final boolean expandNested = createQueryRequest.getExpandNested().orElse(false); BitSetProducer parentFilter = null; if (createQueryRequest.getContext().isPresent()) { QueryShardContext context = createQueryRequest.getContext().get(); diff --git a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java index 60365c5af..4d085da53 100644 --- a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java +++ b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java @@ -136,7 +136,7 @@ public static KNNQueryBuilder.Builder streamInput(StreamInput in, Function