Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.ALL_DENSE_VECTOR_INDEX_TYPES;
import static org.elasticsearch.xpack.esql.DenseVectorFieldTypeIT.NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
Expand Down Expand Up @@ -226,6 +227,35 @@ public void testKnnWithLookupJoin() {
);
}

// NOTE(review): method name has a typo ("WitNo" -> "WithNo"); left unchanged here since
// test names are discovered by the runner and renaming is out of scope for this fix.
/**
 * Runs a knn query with a zero vector while keeping only {@code id} and {@code _score}
 * (the vector field itself is never retrieved), and verifies that exactly 10 rows come
 * back with strictly positive scores in non-increasing (descending) order.
 */
public void testKnnWitNoRetrievedVector() {
    float[] queryVector = new float[numDims];
    Arrays.fill(queryVector, 0.0f);

    var query = String.format(Locale.ROOT, """
        FROM test METADATA _score
        | WHERE knn(vector, %s)
        | SORT _score DESC
        | LIMIT 10
        | KEEP id, _score
        """, Arrays.toString(queryVector));

    try (var resp = run(query)) {
        assertColumnNames(resp.columns(), List.of("id", "_score"));
        assertColumnTypes(resp.columns(), List.of("integer", "double"));

        List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
        assertEquals(10, valuesList.size());
        // Seed with the double maximum: previousScore is a double, so Float.MAX_VALUE
        // would be the wrong (too small) upper bound for the first comparison.
        double previousScore = Double.MAX_VALUE;
        for (List<Object> row : valuesList) {
            // Rows must arrive in descending score order per the SORT _score DESC clause.
            double currentScore = (Double) row.get(1);
            assertThat(currentScore, greaterThan(0.0));
            assertThat(currentScore, lessThanOrEqualTo(previousScore));
            previousScore = currentScore;
        }
    }
}

@Before
public void setup() throws IOException {
var indexName = "test";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

import org.elasticsearch.index.IndexMode;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.util.CollectionUtils;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalOptimizerContext;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.InsertFieldExtraction;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.PushTopNToSource;
Expand All @@ -30,8 +32,10 @@
import org.elasticsearch.xpack.esql.planner.mapper.LocalMapper;
import org.elasticsearch.xpack.esql.stats.SearchStats;

import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -74,7 +78,7 @@
* The above actually reads the {@code x} field "unnecessarily", since it's only needed to conform to the output schema of the original
* plan. See #134363 for a way to optimize this little problem.
*/
class LateMaterializationPlanner {
public class LateMaterializationPlanner {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change the visibility?

public static Optional<ReductionPlan> planReduceDriverTopN(
Function<SearchStats, LocalPhysicalOptimizerContext> contextFactory,
ExchangeSinkExec originalPlan
Expand All @@ -95,8 +99,15 @@ public static Optional<ReductionPlan> planReduceDriverTopN(
}

LocalPhysicalOptimizerContext context = contextFactory.apply(SEARCH_STATS_TOP_N_REPLACEMENT);
List<Attribute> expectedDataOutput = toPhysical(topN, context).output();
Attribute doc = expectedDataOutput.stream().filter(EsQueryExec::isDocAttribute).findFirst().orElse(null);

AttributeSet expectedDataOutputAttrSet = AttributeSet.builder().addAll(topLevelProject.output()).build();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be simplified:

AttributeSet expectedDataOutputAttrSet = AttributeSet.of(topLevelProject.output());

for (Order order : topN.order()) {
expectedDataOutputAttrSet = expectedDataOutputAttrSet.combine(order.references());
}

Set<Attribute> topLevelProjectAttrs = new HashSet<>(expectedDataOutputAttrSet);
List<Attribute> physicalDataOutput = toPhysical(topN, context).output();
Attribute doc = physicalDataOutput.stream().filter(EsQueryExec::isDocAttribute).findFirst().orElse(null);
if (doc == null) {
return Optional.empty();
}
Expand All @@ -114,8 +125,23 @@ public static Optional<ReductionPlan> planReduceDriverTopN(
return Optional.empty();
}

// Calculate the expected output attributes for the data driver plan.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: redundant comment (You can rename the variable to expectedDataDriverOutputAttrs if you wanted to, or extract this code to a method if you think it warrants a header.)

AttributeSet.Builder expectedDataOutputAttrs = AttributeSet.builder();
// We need to add the doc attribute to the project since otherwise when the fragment is converted to a physical plan for the data
// driver, the resulting ProjectExec won't have the doc attribute in its output, which is needed by the reduce driver.
expectedDataOutputAttrs.add(doc);
// Add all references used in the ordering
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: this can be shortened to one line:

AttributeSet orderRefsSet = AttributeSet.of(topN.order().stream().flatMap(o -> o.references().stream()).toList());

AttributeSet.Builder orderRefs = AttributeSet.builder();
for (Order order : topN.order()) {
orderRefs.addAll(order.references());
}
AttributeSet orderRefsSet = orderRefs.build();
// Get the output from the physical plan below the TopN, and filter it to only the attributes needed for the final output (either
// because they are in the top-level Project's output, or because they are needed for ordering)
expectedDataOutputAttrs.addAll(
physicalDataOutput.stream().filter(a -> topLevelProject.outputSet().contains(a) || orderRefsSet.contains(a)).toList()
);
List<Attribute> expectedDataOutput = expectedDataOutputAttrs.build().stream().toList();
var updatedFragment = new Project(Source.EMPTY, withAddedDocToRelation, expectedDataOutput);
FragmentExec updatedFragmentExec = fragmentExec.withFragment(updatedFragment);
ExchangeSinkExec updatedDataPlan = originalPlan.replaceChild(updatedFragmentExec);
Expand Down