Prefetch PostingList #133009

Merged 17 commits on Aug 19, 2025

DefaultIVFVectorsReader.java
@@ -47,9 +47,50 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) {
super(state, rawVectorsReader);
}

CentroidIterator getPostingListPrefetchIterator(CentroidIterator centroidIterator, IndexInput postingListSlice) throws IOException {
return new CentroidIterator() {
CentroidOffsetAndLength nextOffsetAndLength = centroidIterator.hasNext()
? centroidIterator.nextPostingListOffsetAndLength()
: null;

{
// prefetch the first one
if (nextOffsetAndLength != null) {
prefetch(nextOffsetAndLength);
}
}

void prefetch(CentroidOffsetAndLength offsetAndLength) throws IOException {
postingListSlice.prefetch(offsetAndLength.offset(), offsetAndLength.length());
}

@Override
public boolean hasNext() {
return nextOffsetAndLength != null;
}

@Override
public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
CentroidOffsetAndLength offsetAndLength = nextOffsetAndLength;
if (centroidIterator.hasNext()) {
nextOffsetAndLength = centroidIterator.nextPostingListOffsetAndLength();
prefetch(nextOffsetAndLength);
} else {
nextOffsetAndLength = null; // indicate we reached the end
}
return offsetAndLength;
}
};
}
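
The wrapper above implements a one-entry look-ahead: handing out entry N also issues `IndexInput#prefetch` for entry N + 1, so the page-in of the next posting list can overlap with scoring the current one. Below is a minimal sketch of a consuming loop, reusing this PR's types; `drain` is an illustrative name, and the real `search` method additionally enforces visit budgets:

```java
// Hedged sketch: by the time resetPostingsScorer(offset) touches the bytes of
// list N, they were prefetched while list N - 1 was being scored (on inputs
// that honor the prefetch hint, e.g. memory-mapped directories).
void drain(CentroidIterator it, PostingVisitor visitor, KnnCollector collector) throws IOException {
    while (it.hasNext()) {
        // returns entry N and prefetches entry N + 1 as a side effect
        CentroidOffsetAndLength next = it.nextPostingListOffsetAndLength();
        visitor.resetPostingsScorer(next.offset()); // bytes likely already resident
        visitor.visit(collector);                   // CPU work overlaps the next page-in
    }
}
```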

@Override
CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
throws IOException {
CentroidIterator getCentroidIterator(
FieldInfo fieldInfo,
int numCentroids,
IndexInput centroids,
float[] targetQuery,
IndexInput postingListSlice
) throws IOException {
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
final float globalCentroidDp = fieldEntry.globalCentroidDp();
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
@@ -71,8 +112,9 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
centroids.seek(0L);
int numParents = centroids.readVInt();
CentroidIterator centroidIterator;
if (numParents > 0) {
return getCentroidIteratorWithParents(
centroidIterator = getCentroidIteratorWithParents(
fieldInfo,
centroids,
numParents,
@@ -82,8 +124,18 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
queryParams,
globalCentroidDp
);
} else {
centroidIterator = getCentroidIteratorNoParent(
fieldInfo,
centroids,
numCentroids,
scorer,
quantized,
queryParams,
globalCentroidDp
);
}
return getCentroidIteratorNoParent(fieldInfo, centroids, numCentroids, scorer, quantized, queryParams, globalCentroidDp);
return getPostingListPrefetchIterator(centroidIterator, postingListSlice);
}

private static CentroidIterator getCentroidIteratorNoParent(
@@ -115,10 +167,12 @@ public boolean hasNext() {
}

@Override
public long nextPostingListOffset() throws IOException {
public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
int centroidOrdinal = neighborQueue.pop();
centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
return centroids.readLong();
centroids.seek(offset + (long) Long.BYTES * 2 * centroidOrdinal);
long postingListOffset = centroids.readLong();
long postingListLength = centroids.readLong();
return new CentroidOffsetAndLength(postingListOffset, postingListLength);
}
};
}
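
The seek above reflects the new layout of the per-centroid table: each entry grew from a single long (the posting-list offset) to a pair of longs (offset, then length), hence the `2 * Long.BYTES` stride. A small sketch of the implied addressing, where `tableBase` is a stand-in for the table's start position within the centroids file:

```java
// Hedged sketch: entries are now 16 bytes each, [offset: long][length: long],
// so entry i begins at tableBase + 16 * i (the old layout used 8-byte entries).
static CentroidOffsetAndLength readEntry(IndexInput in, long tableBase, int ordinal) throws IOException {
    in.seek(tableBase + 2L * Long.BYTES * ordinal);
    return new CentroidOffsetAndLength(in.readLong(), in.readLong());
}
```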
@@ -185,11 +239,13 @@ public boolean hasNext() {
}

@Override
public long nextPostingListOffset() throws IOException {
public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
int centroidOrdinal = neighborQueue.pop();
updateQueue(); // add one child if available so the queue remains fully populated
centroids.seek(childrenFileOffsets + (long) Long.BYTES * centroidOrdinal);
return centroids.readLong();
centroids.seek(childrenFileOffsets + (long) Long.BYTES * 2 * centroidOrdinal);
long postingListOffset = centroids.readLong();
long postingListLength = centroids.readLong();
return new CentroidOffsetAndLength(postingListOffset, postingListLength);
}

private void updateQueue() throws IOException {
@@ -452,6 +508,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
int scoredDocs = 0;
int limit = vectors - BULK_SIZE + 1;
int i = 0;

for (; i < limit; i += BULK_SIZE) {
final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, i, acceptDocs);
if (docsToBulkScore == 0) {

DefaultIVFVectorsWriter.java
@@ -18,7 +18,6 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.LongValues;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.apache.lucene.util.packed.PackedInts;
@@ -60,7 +59,7 @@ public DefaultIVFVectorsWriter(
}

@Override
LongValues buildAndWritePostingsLists(
CentroidOffsetAndLength buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
@@ -102,6 +101,7 @@ LongValues buildAndWritePostingsLists(
postingsOutput.writeVInt(maxPostingListSize);
// write the posting lists
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
final PackedLongValues.Builder lengths = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
floatVectorValues,
@@ -116,7 +116,8 @@ LongValues buildAndWritePostingsLists(
for (int c = 0; c < centroidSupplier.size(); c++) {
float[] centroid = centroidSupplier.centroid(c);
int[] cluster = assignmentsByCluster[c];
offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset);
long offset = postingsOutput.alignFilePointer(Float.BYTES) - fileOffset;
offsets.add(offset);
buffer.asFloatBuffer().put(centroid);
// write raw centroid for quantizing the query vectors
postingsOutput.writeBytes(buffer.array(), buffer.array().length);
@@ -142,17 +143,18 @@ LongValues buildAndWritePostingsLists(
idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
// write vectors
bulkWriter.writeVectors(onHeapQuantizedVectors);
lengths.add(postingsOutput.getFilePointer() - fileOffset - offset);
}

if (logger.isDebugEnabled()) {
printClusterQualityStatistics(assignmentsByCluster);
}

return offsets.build();
return new CentroidOffsetAndLength(offsets.build(), lengths.build());
}
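
The writer now tracks a second packed column next to the offsets: it notes the aligned start of each posting list, writes the list, then records the number of bytes it occupied. A condensed sketch of that bookkeeping follows; `ListWriter` and `recordOne` are defined here purely for illustration and stand in for the centroid, doc-id, and vector writes above:

```java
// Both columns are relative to the section start (fileOffset above), which
// keeps values small and friendly to PackedLongValues' monotonic encoding.
interface ListWriter {
    void write(IndexOutput out) throws IOException;
}

static void recordOne(IndexOutput out, long sectionStart,
                      PackedLongValues.Builder offsets, PackedLongValues.Builder lengths,
                      ListWriter list) throws IOException {
    long offset = out.alignFilePointer(Float.BYTES) - sectionStart; // aligned start of this list
    offsets.add(offset);
    list.write(out);                                                // centroid + doc ids + vectors
    lengths.add(out.getFilePointer() - sectionStart - offset);      // bytes this list occupied
}
```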

@Override
LongValues buildAndWritePostingsLists(
CentroidOffsetAndLength buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
@@ -243,6 +245,7 @@ LongValues buildAndWritePostingsLists(
// now we can read the quantized vectors from the temporary file
try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
final PackedLongValues.Builder lengths = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
quantizedVectorsInput,
fieldInfo.getVectorDimension()
@@ -260,7 +263,8 @@ LongValues buildAndWritePostingsLists(
float[] centroid = centroidSupplier.centroid(c);
int[] cluster = assignmentsByCluster[c];
boolean[] isOverspill = isOverspillByCluster[c];
offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset);
long offset = postingsOutput.alignFilePointer(Float.BYTES) - fileOffset;
offsets.add(offset);
// write raw centroid for quantizing the query vectors
buffer.asFloatBuffer().put(centroid);
postingsOutput.writeBytes(buffer.array(), buffer.array().length);
@@ -286,12 +290,14 @@ LongValues buildAndWritePostingsLists(
idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
// write vectors
bulkWriter.writeVectors(offHeapQuantizedVectors);
lengths.add(postingsOutput.getFilePointer() - fileOffset - offset);
}

if (logger.isDebugEnabled()) {
printClusterQualityStatistics(assignmentsByCluster);
}
return offsets.build();
return new CentroidOffsetAndLength(offsets.build(), lengths.build());
}
}
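
Both write paths now return a `CentroidOffsetAndLength`. Note that this is a writer-side type distinct from the reader-side record of the same name introduced in `IVFVectorsReader` below: its definition is not part of this diff, but judging from the construction site (`offsets.build(), lengths.build()`) and the accessors used later (`offsets().get(i)`, `lengths().get(i)`), it plausibly carries the two whole columns:

```java
// Assumed writer-side shape (not shown in this diff): one column of offsets
// and one of lengths, each indexed by centroid ordinal.
record CentroidOffsetAndLength(LongValues offsets, LongValues lengths) {}
```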

@@ -335,24 +341,24 @@ void writeCentroids(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
float[] globalCentroid,
LongValues offsets,
CentroidOffsetAndLength centroidOffsetAndLength,
IndexOutput centroidOutput
) throws IOException {
// TODO do we want to store these distances as well for future use?
// TODO: sort centroids by global centroid (was doing so previously here)
// TODO: sorting tanks recall, possibly because centroid ordinals are no longer aligned
if (centroidSupplier.size() > centroidsPerParentCluster * centroidsPerParentCluster) {
writeCentroidsWithParents(fieldInfo, centroidSupplier, globalCentroid, offsets, centroidOutput);
writeCentroidsWithParents(fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, centroidOutput);
} else {
writeCentroidsWithoutParents(fieldInfo, centroidSupplier, globalCentroid, offsets, centroidOutput);
writeCentroidsWithoutParents(fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, centroidOutput);
}
}

private void writeCentroidsWithParents(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
float[] globalCentroid,
LongValues offsets,
CentroidOffsetAndLength centroidOffsetAndLength,
IndexOutput centroidOutput
) throws IOException {
DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter(
@@ -392,7 +398,8 @@ private void writeCentroidsWithParents(
for (int i = 0; i < centroidGroups.centroids().length; i++) {
final int[] centroidAssignments = centroidGroups.vectors()[i];
for (int assignment : centroidAssignments) {
centroidOutput.writeLong(offsets.get(assignment));
centroidOutput.writeLong(centroidOffsetAndLength.offsets().get(assignment));
centroidOutput.writeLong(centroidOffsetAndLength.lengths().get(assignment));
}
}
}
@@ -401,7 +408,7 @@ private void writeCentroidsWithoutParents(
FieldInfo fieldInfo,
CentroidSupplier centroidSupplier,
float[] globalCentroid,
LongValues offsets,
CentroidOffsetAndLength centroidOffsetAndLength,
IndexOutput centroidOutput
) throws IOException {
centroidOutput.writeVInt(0);
@@ -419,7 +426,8 @@ private void writeCentroidsWithoutParents(
bulkWriter.writeVectors(quantizedCentroids);
// write the centroid offsets at the end of the file
for (int i = 0; i < centroidSupplier.size(); i++) {
centroidOutput.writeLong(offsets.get(i));
centroidOutput.writeLong(centroidOffsetAndLength.offsets().get(i));
centroidOutput.writeLong(centroidOffsetAndLength.lengths().get(i));
}
}
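
Together with the parent-aware variant above, both centroid writers now end the centroids file with one (offset, length) pair per centroid, matching the 16-byte stride the reader seeks by. A layout sketch of that tail, with encoding details elided:

```java
// Tail of the centroids file implied by the two loops above:
//   [offset(0)][length(0)][offset(1)][length(1)] ... [offset(n-1)][length(n-1)]
// Every field is an 8-byte long, so entry i starts 16 * i bytes into the table.
```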


IVFVectorsReader.java
@@ -85,8 +85,13 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) {
}
}

abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
throws IOException;
abstract CentroidIterator getCentroidIterator(
FieldInfo fieldInfo,
int numCentroids,
IndexInput centroids,
float[] target,
IndexInput postingListSlice
) throws IOException;

private static IndexInput openDataInput(
SegmentReadState state,
@@ -241,31 +246,37 @@ public final void search(String field, float[] target, KnnCollector knnCollector
}
// we account for soar vectors here. We can potentially visit a vector twice so we multiply by 2 here.
long maxVectorVisited = (long) (2.0 * visitRatio * numVectors);
CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, acceptDocs);

IndexInput postListSlice = entry.postingListSlice(ivfClusters);
CentroidIterator centroidPrefetchingIterator = getCentroidIterator(
fieldInfo,
entry.numCentroids,
entry.centroidSlice(ivfCentroids),
target,
postListSlice
);
PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocs);
long expectedDocs = 0;
long actualDocs = 0;
// initially we visit only the "centroids to search"
// Note, numCollected is doing the bare minimum here.
// TODO do we need to handle nested doc counts similarly to how we handle
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
while (centroidIterator.hasNext()
while (centroidPrefetchingIterator.hasNext()
&& (maxVectorVisited > expectedDocs || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
// todo do we actually need to know the score???
long offset = centroidIterator.nextPostingListOffset();
CentroidOffsetAndLength offsetAndLength = centroidPrefetchingIterator.nextPostingListOffsetAndLength();
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
// is enough?
expectedDocs += scorer.resetPostingsScorer(offset);
expectedDocs += scorer.resetPostingsScorer(offsetAndLength.offset());
actualDocs += scorer.visit(knnCollector);
}
if (acceptDocs != null) {
float unfilteredRatioVisited = (float) expectedDocs / numVectors;
int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
while (centroidIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
long offset = centroidIterator.nextPostingListOffset();
scorer.resetPostingsScorer(offset);
while (centroidPrefetchingIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
CentroidOffsetAndLength offsetAndLength = centroidPrefetchingIterator.nextPostingListOffsetAndLength();
scorer.resetPostingsScorer(offsetAndLength.offset());
actualDocs += scorer.visit(knnCollector);
}
}
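
The visit budget near the top of this hunk is unchanged in spirit but worth a worked example: with SOAR assignment a vector can land in two posting lists, so the cap doubles the nominal ratio. The numbers below are illustrative only:

```java
// e.g. a 5% visit ratio over one million vectors:
double visitRatio = 0.05;                                       // assumed example value
long numVectors = 1_000_000;
long maxVectorVisited = (long) (2.0 * visitRatio * numVectors); // = 100_000
```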
@@ -312,10 +323,12 @@ IndexInput postingListSlice(IndexInput postingListFile) throws IOException {
abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, Bits needsScoring)
throws IOException;

record CentroidOffsetAndLength(long offset, long length) {}

interface CentroidIterator {
boolean hasNext();

long nextPostingListOffset() throws IOException;
CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException;
}
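
The revised `CentroidIterator` contract hands back the new record instead of a bare offset. A hypothetical in-memory implementation, say for unit tests, needs only `java.util.List` and `java.util.Iterator`; no I/O or prefetching is involved:

```java
// Illustrative helper, not part of the PR: serves precomputed entries in order.
static CentroidIterator fromList(List<CentroidOffsetAndLength> entries) {
    Iterator<CentroidOffsetAndLength> it = entries.iterator();
    return new CentroidIterator() {
        @Override
        public boolean hasNext() {
            return it.hasNext();
        }

        @Override
        public CentroidOffsetAndLength nextPostingListOffsetAndLength() {
            return it.next(); // pure in-memory lookup, nothing to prefetch
        }
    };
}
```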

interface PostingVisitor {