From b638d5f9a8365f7cfc3355fe66145a0f0041ebca Mon Sep 17 00:00:00 2001 From: Saurabh Singh Date: Sat, 25 Oct 2025 19:12:06 -0700 Subject: [PATCH] Support for traversing BKD tree with prefetching --- .../org/apache/lucene/index/PointValues.java | 69 +++++++++++++ .../org/apache/lucene/util/bkd/BKDReader.java | 50 +++++++++- .../tests/index/AssertingLeafReader.java | 97 ++++++++++++++++++- 3 files changed, 214 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/PointValues.java b/lucene/core/src/java/org/apache/lucene/index/PointValues.java index c77eec0e5ffd..05851ea84102 100644 --- a/lucene/core/src/java/org/apache/lucene/index/PointValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/PointValues.java @@ -20,6 +20,7 @@ import java.io.UncheckedIOException; import java.math.BigInteger; import java.net.InetAddress; +import java.util.List; import org.apache.lucene.document.BinaryPoint; import org.apache.lucene.document.DoublePoint; import org.apache.lucene.document.Field; @@ -276,6 +277,15 @@ public interface PointTree extends Cloneable { void visitDocValues(IntersectVisitor visitor) throws IOException; } + public interface PrefetchablePointTree extends PointTree { + + /** Visit all the docs below the node at position pos */ + void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException; + + /** call prefetch for docs below the current node */ + void prepareVisitDocIDs(PrefetchCapableVisitor prefetchCapableVisitor) throws IOException; + } + /** * We recurse the {@link PointTree}, using a provided instance of this to guide the recursion. 
* @@ -341,6 +351,65 @@ default void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOExcep default void grow(int count) {} } + public interface PrefetchCapableVisitor extends IntersectVisitor { + + /** + * return the last matched block ordinal - this is used to avoid prefetching call for contiguous + * ordinals assuming contiguous ordinals prefetching can be taken care by readaheads. + */ + int lastMatchedBlock(); + + /** set last matched block ordinal */ + void setLastMatchedBlock(int leafNodeOrdinal); + + /** save prefetched block for visiting later on */ + void savePrefetchedBlockForLaterVisit(long leafFp); + + /** returns the file pointers of the saved prefetched blocks */ + List<Long> savedPrefetchedBlocks(); + } + + public final void intersectWithPrefetch(PrefetchCapableVisitor visitor) throws IOException { + final PointTree pointTree = getPointTree(); + assert pointTree instanceof PrefetchablePointTree; + PrefetchablePointTree prefetchablePointTree = (PrefetchablePointTree) pointTree; + intersectWithPrefetch(visitor, prefetchablePointTree); + List<Long> fps = visitor.savedPrefetchedBlocks(); + for (int fp = 0; fp < fps.size(); ++fp) { + prefetchablePointTree.visitDocIDs(fps.get(fp), visitor); + } + + assert prefetchablePointTree.moveToParent() == false; + } + + private static void intersectWithPrefetch( + PrefetchCapableVisitor visitor, PrefetchablePointTree pointTree) throws IOException { + while (true) { + Relation compare = + visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + if (compare == Relation.CELL_INSIDE_QUERY) { + // This cell is fully inside the query shape: prefetch all leaf blocks in this cell; their + // doc IDs are visited without filtering after the traversal + pointTree.prepareVisitDocIDs(visitor); + } else if (compare == Relation.CELL_CROSSES_QUERY) { + // The cell crosses the shape boundary, or the cell fully contains the query, so we fall + // through and do full filtering: + if (pointTree.moveToChild()) { + continue; + } + // TODO: we can assert that the first value 
here in fact matches what the pointTree + // claimed? + // Leaf node; scan and filter all points in this block: + pointTree.visitDocValues(visitor); + } + while (pointTree.moveToSibling() == false) { + if (pointTree.moveToParent() == false) { + return; + } + } + } + } + /** * Finds all documents and points matching the provided visitor. This method does not enforce live * documents, so it's up to the caller to test whether each document is deleted, if necessary. diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java index 9c991e6b1b4a..2314bde2d244 100644 --- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java +++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java @@ -195,7 +195,7 @@ public PointTree getPointTree() throws IOException { isTreeBalanced); } - private static class BKDPointTree implements PointTree { + private static class BKDPointTree implements PrefetchablePointTree { private int nodeID; // during clone, the node root can be different to 1 private final int nodeRoot; @@ -589,6 +589,54 @@ public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException addAll(visitor, false); } + @Override + public void prepareVisitDocIDs(PrefetchCapableVisitor prefetchCapableVisitor) + throws IOException { + resetNodeDataPosition(); + prefetchAll(prefetchCapableVisitor); + } + + @Override + public void visitDocIDs(long position, IntersectVisitor visitor) throws IOException { + leafNodes.seek(position); + int count = leafNodes.readVInt(); + if (count <= Integer.MAX_VALUE) { + visitor.grow(count); + } + docIdsWriter.readInts(leafNodes, count, visitor, scratchIterator.docIDs); + } + + private int getLeafNodeOrdinal() { + assert isLeafNode() : "nodeID=" + nodeID + " is not a leaf"; + return nodeID - leafNodeOffset; + } + + public void prefetchAll(PrefetchCapableVisitor prefetchCapableVisitor) throws IOException { + if (isLeafNode()) { + // int 
count = isLastLeaf() ? config.maxPointsInLeafNode() : lastLeafNodePointCount; + long leafFp = getLeafBlockFP(); + int leafNodeOrdinal = getLeafNodeOrdinal(); + // Only call prefetch if this is the first leaf node ordinal or the first match in + // contiguous sequence of matches for leaf nodes + // boolean prefetched = false; + if (prefetchCapableVisitor.lastMatchedBlock() == -1 + || prefetchCapableVisitor.lastMatchedBlock() + 1 < leafNodeOrdinal) { + // System.out.println("Prefetched called on " + leafNodeOrdinal); + leafNodes.prefetch(leafFp, 1); + // prefetched = true; + } + prefetchCapableVisitor.setLastMatchedBlock(leafNodeOrdinal); + prefetchCapableVisitor.savePrefetchedBlockForLaterVisit(leafFp); + } else { + pushLeft(); + prefetchAll(prefetchCapableVisitor); + pop(); + pushRight(); + prefetchAll(prefetchCapableVisitor); + pop(); + } + } + public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException { if (grown == false) { final long size = size(); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java index 1958cd4b9588..e5fd151e1921 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/AssertingLeafReader.java @@ -1485,7 +1485,11 @@ private void assertStats(int maxDoc) { @Override public PointTree getPointTree() throws IOException { assertThread("Points", creationThread); - return new AssertingPointTree(in, in.getPointTree()); + final PointTree tree = in.getPointTree(); + if (tree instanceof PrefetchablePointTree) { + return new AssertingPrefetchablePointTree(in, (PrefetchablePointTree) tree); + } + return new AssertingPointTree(in, tree); } @Override @@ -1599,6 +1603,97 @@ public void visitDocValues(IntersectVisitor visitor) throws IOException { } } + static class AssertingPrefetchablePointTree 
implements PointValues.PrefetchablePointTree { + + final PointValues pointValues; + final PointValues.PrefetchablePointTree in; + + AssertingPrefetchablePointTree(PointValues pointValues, PointValues.PrefetchablePointTree in) { + this.pointValues = pointValues; + this.in = in; + } + + @Override + public PointValues.PointTree clone() { + return new AssertingPrefetchablePointTree( + pointValues, (PointValues.PrefetchablePointTree) in.clone()); + } + + @Override + public boolean moveToChild() throws IOException { + return in.moveToChild(); + } + + @Override + public boolean moveToSibling() throws IOException { + return in.moveToSibling(); + } + + @Override + public boolean moveToParent() throws IOException { + return in.moveToParent(); + } + + @Override + public byte[] getMinPackedValue() { + return in.getMinPackedValue(); + } + + @Override + public byte[] getMaxPackedValue() { + return in.getMaxPackedValue(); + } + + @Override + public long size() { + final long size = in.size(); + assert size > 0; + return size; + } + + @Override + public void visitDocIDs(IntersectVisitor visitor) throws IOException { + in.visitDocIDs( + new AssertingIntersectVisitor( + pointValues.getNumDimensions(), + pointValues.getNumIndexDimensions(), + pointValues.getBytesPerDimension(), + visitor)); + } + + @Override + public void visitDocValues(IntersectVisitor visitor) throws IOException { + in.visitDocValues( + new AssertingIntersectVisitor( + pointValues.getNumDimensions(), + pointValues.getNumIndexDimensions(), + pointValues.getBytesPerDimension(), + visitor)); + } + + /** + * Visit all the docs below the node at position pos + * + * @param pos position of block from where to start reading doc ids + * @param visitor visitor that will visit doc ids. 
+ */ + @Override + public void visitDocIDs(long pos, IntersectVisitor visitor) throws IOException { + in.visitDocIDs(pos, visitor); + } + + /** + * call prefetch for docs below the current node + * + * @param prefetchCapableVisitor prefetch capable visitors + */ + @Override + public void prepareVisitDocIDs(PointValues.PrefetchCapableVisitor prefetchCapableVisitor) + throws IOException { + in.prepareVisitDocIDs(prefetchCapableVisitor); + } + } + /** * Validates in the 1D case that all points are visited in order, and point values are in bounds * of the last cell checked