From e8575437916c341defee340f51ac988b83442f7e Mon Sep 17 00:00:00 2001
From: Heemin Kim <heemin@amazon.com>
Date: Fri, 13 Dec 2024 20:11:09 -0800
Subject: [PATCH] Switch type of expandNested from boolean to Boolean (#2333)

Signed-off-by: Heemin Kim <heemin@amazon.com>
---
 .../org/opensearch/knn/index/query/BaseQueryFactory.java  | 6 +++++-
 .../org/opensearch/knn/index/query/KNNQueryBuilder.java   | 8 ++++----
 .../org/opensearch/knn/index/query/KNNQueryFactory.java   | 2 +-
 .../knn/index/query/parser/KNNQueryBuilderParser.java     | 8 ++++----
 4 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java
index 984ed00bc..6c5eea08f 100644
--- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java
+++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java
@@ -50,7 +50,7 @@ public static class CreateQueryRequest {
         private QueryBuilder filter;
         private QueryShardContext context;
         private RescoreContext rescoreContext;
-        private boolean expandNested;
+        private Boolean expandNested;
 
         public Optional<QueryBuilder> getFilter() {
             return Optional.ofNullable(filter);
@@ -63,6 +63,10 @@ public Optional<QueryShardContext> getContext() {
         public Optional<RescoreContext> getRescoreContext() {
             return Optional.ofNullable(rescoreContext);
         }
+
+        public Optional<Boolean> getExpandNested() {
+            return Optional.ofNullable(expandNested);
+        }
     }
 
     /**
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
index 063842a7f..d2b169b2a 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
@@ -109,7 +109,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
     @Getter
     private RescoreContext rescoreContext;
     @Getter
-    private boolean expandNested;
+    private Boolean expandNested;
 
     /**
      * Constructs a new query with the given field name and vector
@@ -151,7 +151,7 @@ public static class Builder {
         private String queryName;
         private float boost = DEFAULT_BOOST;
         private RescoreContext rescoreContext;
-        private boolean expandNested;
+        private Boolean expandNested;
 
         public Builder() {}
 
@@ -210,7 +210,7 @@ public Builder rescoreContext(RescoreContext rescoreContext) {
             return this;
         }
 
-        public Builder expandNested(boolean expandNested) {
+        public Builder expandNested(Boolean expandNested) {
             this.expandNested = expandNested;
             return this;
         }
@@ -330,7 +330,7 @@ public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder fil
         this.maxDistance = null;
         this.minScore = null;
         this.rescoreContext = null;
-        this.expandNested = false;
+        this.expandNested = null;
     }
 
     public static void initialize(ModelDao modelDao) {
diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
index 74b864f98..7bac6c126 100644
--- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
+++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
@@ -51,7 +51,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
         final Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
         final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);
         final KNNEngine knnEngine = createQueryRequest.getKnnEngine();
-        final boolean expandNested = createQueryRequest.isExpandNested();
+        final boolean expandNested = createQueryRequest.getExpandNested().orElse(false);
         BitSetProducer parentFilter = null;
         if (createQueryRequest.getContext().isPresent()) {
             QueryShardContext context = createQueryRequest.getContext().get();
diff --git a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java
index 376f60334..57fbed90b 100644
--- a/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java
+++ b/src/main/java/org/opensearch/knn/index/query/parser/KNNQueryBuilderParser.java
@@ -133,7 +133,7 @@ public static KNNQueryBuilder.Builder streamInput(StreamInput in, Function<Strin
         }
 
         if (minClusterVersionCheck.apply(EXPAND_NESTED)) {
-            builder.expandNested(in.readBoolean());
+            builder.expandNested(in.readOptionalBoolean());
         }
 
         return builder;
@@ -169,7 +169,7 @@ public static void streamOutput(StreamOutput out, KNNQueryBuilder builder, Funct
             RescoreParser.streamOutput(out, builder.getRescoreContext());
         }
         if (minClusterVersionCheck.apply(EXPAND_NESTED)) {
-            out.writeBoolean(builder.isExpandNested());
+            out.writeOptionalBoolean(builder.getExpandNested());
         }
     }
 
@@ -245,8 +245,8 @@ public static void toXContent(XContentBuilder builder, ToXContent.Params params,
         if (knnQueryBuilder.queryName() != null) {
             builder.field(NAME_FIELD.getPreferredName(), knnQueryBuilder.queryName());
         }
-        if (knnQueryBuilder.isExpandNested()) {
-            builder.field(EXPAND_NESTED, knnQueryBuilder.isExpandNested());
+        if (knnQueryBuilder.getExpandNested() != null) {
+            builder.field(EXPAND_NESTED, knnQueryBuilder.getExpandNested());
         }
 
         builder.endObject();