diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
index d80fc216e556c..e191ce96ea2ed 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
@@ -47,9 +47,50 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
         super(state, rawVectorsReader);
     }
 
+    CentroidIterator getPostingListPrefetchIterator(CentroidIterator centroidIterator, IndexInput postingListSlice) throws IOException {
+        return new CentroidIterator() {
+            CentroidOffsetAndLength nextOffsetAndLength = centroidIterator.hasNext()
+                ? centroidIterator.nextPostingListOffsetAndLength()
+                : null;
+
+            {
+                // prefetch the first one
+                if (nextOffsetAndLength != null) {
+                    prefetch(nextOffsetAndLength);
+                }
+            }
+
+            void prefetch(CentroidOffsetAndLength offsetAndLength) throws IOException {
+                postingListSlice.prefetch(offsetAndLength.offset(), offsetAndLength.length());
+            }
+
+            @Override
+            public boolean hasNext() {
+                return nextOffsetAndLength != null;
+            }
+
+            @Override
+            public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
+                CentroidOffsetAndLength offsetAndLength = nextOffsetAndLength;
+                if (centroidIterator.hasNext()) {
+                    nextOffsetAndLength = centroidIterator.nextPostingListOffsetAndLength();
+                    prefetch(nextOffsetAndLength);
+                } else {
+                    nextOffsetAndLength = null; // indicate we reached the end
+                }
+                return offsetAndLength;
+            }
+        };
+    }
+
     @Override
-    CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
-        throws IOException {
+    CentroidIterator getCentroidIterator(
+        FieldInfo fieldInfo,
+        int numCentroids,
+        IndexInput centroids,
+        float[] targetQuery,
+        IndexInput postingListSlice
+    ) throws IOException {
         final FieldEntry fieldEntry = fields.get(fieldInfo.number);
         final float globalCentroidDp = fieldEntry.globalCentroidDp();
         final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
@@ -71,8 +112,9 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
         final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension());
         centroids.seek(0L);
         int numParents = centroids.readVInt();
+        CentroidIterator centroidIterator;
         if (numParents > 0) {
-            return getCentroidIteratorWithParents(
+            centroidIterator = getCentroidIteratorWithParents(
                 fieldInfo,
                 centroids,
                 numParents,
@@ -82,8 +124,18 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde
                 queryParams,
                 globalCentroidDp
             );
+        } else {
+            centroidIterator = getCentroidIteratorNoParent(
+                fieldInfo,
+                centroids,
+                numCentroids,
+                scorer,
+                quantized,
+                queryParams,
+                globalCentroidDp
+            );
         }
-        return getCentroidIteratorNoParent(fieldInfo, centroids, numCentroids, scorer, quantized, queryParams, globalCentroidDp);
+        return getPostingListPrefetchIterator(centroidIterator, postingListSlice);
     }
 
     private static CentroidIterator getCentroidIteratorNoParent(
@@ -115,10 +167,12 @@ public boolean hasNext() {
             }
 
             @Override
-            public long nextPostingListOffset() throws IOException {
+            public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
                 int centroidOrdinal = neighborQueue.pop();
-                centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
-                return centroids.readLong();
+                centroids.seek(offset + (long) Long.BYTES * 2 * centroidOrdinal);
+                long postingListOffset = centroids.readLong();
+                long postingListLength = centroids.readLong();
+                return new CentroidOffsetAndLength(postingListOffset, postingListLength);
             }
         };
     }
@@ -185,11 +239,13 @@ public boolean hasNext() {
             }
 
             @Override
-            public long nextPostingListOffset() throws IOException {
+            public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException {
                 int centroidOrdinal = neighborQueue.pop();
                 updateQueue(); // add one children if available so the queue remains fully populated
-                centroids.seek(childrenFileOffsets + (long) Long.BYTES * centroidOrdinal);
-                return centroids.readLong();
+                centroids.seek(childrenFileOffsets + (long) Long.BYTES * 2 * centroidOrdinal);
+                long postingListOffset = centroids.readLong();
+                long postingListLength = centroids.readLong();
+                return new CentroidOffsetAndLength(postingListOffset, postingListLength);
             }
 
             private void updateQueue() throws IOException {
@@ -452,6 +508,7 @@ public int visit(KnnCollector knnCollector) throws IOException {
             int scoredDocs = 0;
             int limit = vectors - BULK_SIZE + 1;
             int i = 0;
+
             for (; i < limit; i += BULK_SIZE) {
                 final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, i, acceptDocs);
                 if (docsToBulkScore == 0) {
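
The reader-side change above decorates whichever centroid iterator is chosen (with or without parents) with a one-entry look-ahead: the next posting list's (offset, length) pair is pulled from the underlying iterator eagerly and handed to IndexInput#prefetch, so the OS can start paging those bytes in while the current posting list is still being scored. A minimal pure-JDK sketch of the same look-ahead pattern; Span, PrefetchingIterator, and the Consumer hook are illustrative stand-ins, not the PR's types:

```java
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Consumer;

public class PrefetchDemo {
    /** Byte range of one posting list inside the clusters file. */
    record Span(long offset, long length) {}

    /** Wraps an iterator so the *next* span is prefetched while the current one is consumed. */
    static final class PrefetchingIterator implements Iterator<Span> {
        private final Iterator<Span> delegate;
        private final Consumer<Span> prefetcher; // stands in for IndexInput#prefetch(offset, length)
        private Span next;

        PrefetchingIterator(Iterator<Span> delegate, Consumer<Span> prefetcher) {
            this.delegate = delegate;
            this.prefetcher = prefetcher;
            this.next = delegate.hasNext() ? delegate.next() : null;
            if (next != null) {
                prefetcher.accept(next); // prefetch the first span up front
            }
        }

        @Override
        public boolean hasNext() {
            return next != null;
        }

        @Override
        public Span next() {
            if (next == null) throw new NoSuchElementException();
            Span current = next;
            if (delegate.hasNext()) {
                next = delegate.next();
                prefetcher.accept(next); // kick off I/O for the span after this one
            } else {
                next = null; // reached the end
            }
            return current;
        }
    }

    public static void main(String[] args) {
        List<Span> spans = List.of(new Span(0, 128), new Span(128, 256), new Span(384, 64));
        Iterator<Span> it = new PrefetchingIterator(spans.iterator(), s -> System.out.println("prefetch " + s));
        while (it.hasNext()) {
            System.out.println("score    " + it.next()); // each span is scored one prefetch behind
        }
    }
}
```
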
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index 5e696b74530a8..b0de42854dc6a 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -18,7 +18,6 @@
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.IntroSorter;
-import org.apache.lucene.util.LongValues;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.IntToIntFunction;
 import org.apache.lucene.util.packed.PackedInts;
@@ -60,7 +59,7 @@ public DefaultIVFVectorsWriter(
     }
 
     @Override
-    LongValues buildAndWritePostingsLists(
+    CentroidOffsetAndLength buildAndWritePostingsLists(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
@@ -102,6 +101,7 @@ LongValues buildAndWritePostingsLists(
         postingsOutput.writeVInt(maxPostingListSize);
         // write the posting lists
         final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
+        final PackedLongValues.Builder lengths = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
         DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
         OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
@@ -116,7 +116,8 @@ LongValues buildAndWritePostingsLists(
         for (int c = 0; c < centroidSupplier.size(); c++) {
             float[] centroid = centroidSupplier.centroid(c);
             int[] cluster = assignmentsByCluster[c];
-            offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset);
+            long offset = postingsOutput.alignFilePointer(Float.BYTES) - fileOffset;
+            offsets.add(offset);
             buffer.asFloatBuffer().put(centroid);
             // write raw centroid for quantizing the query vectors
             postingsOutput.writeBytes(buffer.array(), buffer.array().length);
@@ -142,17 +143,18 @@ LongValues buildAndWritePostingsLists(
             idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
             // write vectors
             bulkWriter.writeVectors(onHeapQuantizedVectors);
+            lengths.add(postingsOutput.getFilePointer() - fileOffset - offset);
         }
 
         if (logger.isDebugEnabled()) {
             printClusterQualityStatistics(assignmentsByCluster);
         }
 
-        return offsets.build();
+        return new CentroidOffsetAndLength(offsets.build(), lengths.build());
     }
 
     @Override
-    LongValues buildAndWritePostingsLists(
+    CentroidOffsetAndLength buildAndWritePostingsLists(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
@@ -243,6 +245,7 @@ LongValues buildAndWritePostingsLists(
         // now we can read the quantized vectors from the temporary file
         try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
             final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
+            final PackedLongValues.Builder lengths = PackedLongValues.monotonicBuilder(PackedInts.COMPACT);
             OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
                 quantizedVectorsInput,
                 fieldInfo.getVectorDimension()
@@ -260,7 +263,8 @@ LongValues buildAndWritePostingsLists(
                 float[] centroid = centroidSupplier.centroid(c);
                 int[] cluster = assignmentsByCluster[c];
                 boolean[] isOverspill = isOverspillByCluster[c];
-                offsets.add(postingsOutput.alignFilePointer(Float.BYTES) - fileOffset);
+                long offset = postingsOutput.alignFilePointer(Float.BYTES) - fileOffset;
+                offsets.add(offset);
                 // write raw centroid for quantizing the query vectors
                 buffer.asFloatBuffer().put(centroid);
                 postingsOutput.writeBytes(buffer.array(), buffer.array().length);
@@ -286,12 +290,13 @@ LongValues buildAndWritePostingsLists(
                 idsWriter.writeDocIds(i -> docDeltas[i], size, postingsOutput);
                 // write vectors
                 bulkWriter.writeVectors(offHeapQuantizedVectors);
+                lengths.add(postingsOutput.getFilePointer() - fileOffset - offset);
             }
 
             if (logger.isDebugEnabled()) {
                 printClusterQualityStatistics(assignmentsByCluster);
             }
 
-            return offsets.build();
+            return new CentroidOffsetAndLength(offsets.build(), lengths.build());
         }
     }
@@ -335,16 +340,16 @@ void writeCentroids(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
         float[] globalCentroid,
-        LongValues offsets,
+        CentroidOffsetAndLength centroidOffsetAndLength,
         IndexOutput centroidOutput
     ) throws IOException {
         // TODO do we want to store these distances as well for future use?
         // TODO: sort centroids by global centroid (was doing so previously here)
         // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
         if (centroidSupplier.size() > centroidsPerParentCluster * centroidsPerParentCluster) {
-            writeCentroidsWithParents(fieldInfo, centroidSupplier, globalCentroid, offsets, centroidOutput);
+            writeCentroidsWithParents(fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, centroidOutput);
         } else {
-            writeCentroidsWithoutParents(fieldInfo, centroidSupplier, globalCentroid, offsets, centroidOutput);
+            writeCentroidsWithoutParents(fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, centroidOutput);
         }
     }
 
@@ -352,7 +357,7 @@ private void writeCentroidsWithParents(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
         float[] globalCentroid,
-        LongValues offsets,
+        CentroidOffsetAndLength centroidOffsetAndLength,
         IndexOutput centroidOutput
     ) throws IOException {
         DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter(
@@ -392,7 +397,8 @@ private void writeCentroidsWithParents(
         for (int i = 0; i < centroidGroups.centroids().length; i++) {
             final int[] centroidAssignments = centroidGroups.vectors()[i];
             for (int assignment : centroidAssignments) {
-                centroidOutput.writeLong(offsets.get(assignment));
+                centroidOutput.writeLong(centroidOffsetAndLength.offsets().get(assignment));
+                centroidOutput.writeLong(centroidOffsetAndLength.lengths().get(assignment));
             }
         }
     }
@@ -401,7 +407,7 @@ private void writeCentroidsWithoutParents(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
         float[] globalCentroid,
-        LongValues offsets,
+        CentroidOffsetAndLength centroidOffsetAndLength,
         IndexOutput centroidOutput
     ) throws IOException {
         centroidOutput.writeVInt(0);
@@ -419,7 +425,8 @@ private void writeCentroidsWithoutParents(
         bulkWriter.writeVectors(quantizedCentroids);
         // write the centroid offsets at the end of the file
         for (int i = 0; i < centroidSupplier.size(); i++) {
-            centroidOutput.writeLong(offsets.get(i));
+            centroidOutput.writeLong(centroidOffsetAndLength.offsets().get(i));
+            centroidOutput.writeLong(centroidOffsetAndLength.lengths().get(i));
         }
     }
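
On the writer side, each posting list's byte length is now recorded alongside its start offset: the offset is taken after aligning the file pointer, the list is written, and the length is the file-pointer delta. Both sequences grow monotonically, which is why the diff accumulates them in monotonic PackedLongValues builders. A toy stream-based version of that bookkeeping, pure JDK with hypothetical names; DataOutputStream#size() stands in for IndexOutput#getFilePointer(), and alignment plus the fileOffset base are omitted for brevity:

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class OffsetLengthDemo {
    record Span(long offset, long length) {}

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes); // size() plays the role of getFilePointer()
        List<Span> spans = new ArrayList<>();
        byte[][] postingLists = { new byte[100], new byte[37], new byte[250] };
        for (byte[] list : postingLists) {
            long offset = out.size();          // like alignFilePointer(Float.BYTES) - fileOffset in the diff
            out.write(list);                   // centroid + doc ids + quantized vectors in the diff
            long length = out.size() - offset; // like getFilePointer() - fileOffset - offset in the diff
            spans.add(new Span(offset, length));
        }
        spans.forEach(System.out::println); // Span[offset=0, length=100], Span[offset=100, length=37], ...
    }
}
```
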
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
index 44188a58618f2..9dcbcc6cb9054 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
@@ -85,8 +85,13 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
         }
     }
 
-    abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
-        throws IOException;
+    abstract CentroidIterator getCentroidIterator(
+        FieldInfo fieldInfo,
+        int numCentroids,
+        IndexInput centroids,
+        float[] target,
+        IndexInput postingListSlice
+    ) throws IOException;
 
     private static IndexInput openDataInput(
         SegmentReadState state,
@@ -241,31 +246,37 @@ public final void search(String field, float[] target, KnnCollector knnCollector
         }
         // we account for soar vectors here. We can potentially visit a vector twice so we multiply by 2 here.
         long maxVectorVisited = (long) (2.0 * visitRatio * numVectors);
-        CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
-        PostingVisitor scorer = getPostingVisitor(fieldInfo, entry.postingListSlice(ivfClusters), target, acceptDocs);
-
+        IndexInput postListSlice = entry.postingListSlice(ivfClusters);
+        CentroidIterator centroidPrefetchingIterator = getCentroidIterator(
+            fieldInfo,
+            entry.numCentroids,
+            entry.centroidSlice(ivfCentroids),
+            target,
+            postListSlice
+        );
+        PostingVisitor scorer = getPostingVisitor(fieldInfo, postListSlice, target, acceptDocs);
         long expectedDocs = 0;
         long actualDocs = 0;
         // initially we visit only the "centroids to search"
         // Note, numCollected is doing the bare minimum here.
         // TODO do we need to handle nested doc counts similarly to how we handle
         // filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
-        while (centroidIterator.hasNext()
+        while (centroidPrefetchingIterator.hasNext()
             && (maxVectorVisited > expectedDocs || knnCollector.minCompetitiveSimilarity() == Float.NEGATIVE_INFINITY)) {
             // todo do we actually need to know the score???
-            long offset = centroidIterator.nextPostingListOffset();
+            CentroidOffsetAndLength offsetAndLength = centroidPrefetchingIterator.nextPostingListOffsetAndLength();
             // todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
             // is enough?
-            expectedDocs += scorer.resetPostingsScorer(offset);
+            expectedDocs += scorer.resetPostingsScorer(offsetAndLength.offset());
             actualDocs += scorer.visit(knnCollector);
         }
         if (acceptDocs != null) {
             float unfilteredRatioVisited = (float) expectedDocs / numVectors;
             int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
             float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
-            while (centroidIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
-                long offset = centroidIterator.nextPostingListOffset();
-                scorer.resetPostingsScorer(offset);
+            while (centroidPrefetchingIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
+                CentroidOffsetAndLength offsetAndLength = centroidPrefetchingIterator.nextPostingListOffsetAndLength();
+                scorer.resetPostingsScorer(offsetAndLength.offset());
                 actualDocs += scorer.visit(knnCollector);
             }
         }
@@ -312,10 +323,12 @@ IndexInput postingListSlice(IndexInput postingListFile) throws IOException {
     abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, Bits needsScoring)
         throws IOException;
 
+    record CentroidOffsetAndLength(long offset, long length) {}
+
     interface CentroidIterator {
         boolean hasNext();
 
-        long nextPostingListOffset() throws IOException;
+        CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException;
    }
 
     interface PostingVisitor {
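
With lengths stored, each centroid's table entry is two longs wide, so both reader iterators seek to `offset + 2 * Long.BYTES * centroidOrdinal` and read the pair back as a CentroidOffsetAndLength, which search then consumes instead of a bare offset. A ByteBuffer sketch of that fixed-stride table lookup; the helper and its names are illustrative, not the PR's code:

```java
import java.nio.ByteBuffer;

public class CentroidTableDemo {
    record Span(long offset, long length) {}

    /** Reads entry #ordinal from a table of (offset, length) pairs, 16 bytes per centroid. */
    static Span read(ByteBuffer table, long base, int ordinal) {
        int pos = (int) (base + 2L * Long.BYTES * ordinal); // was Long.BYTES * ordinal before this change
        return new Span(table.getLong(pos), table.getLong(pos + Long.BYTES));
    }

    public static void main(String[] args) {
        ByteBuffer table = ByteBuffer.allocate(3 * 2 * Long.BYTES);
        // three centroids: (0, 100), (100, 37), (137, 250)
        table.putLong(0).putLong(100).putLong(100).putLong(37).putLong(137).putLong(250);
        System.out.println(read(table, 0, 1)); // Span[offset=100, length=37]
    }
}
```
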
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
index 308ee391b5f4a..63253e1cd8db7 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
@@ -123,15 +123,17 @@ public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOExc
     abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
         throws IOException;
 
+    record CentroidOffsetAndLength(LongValues offsets, LongValues lengths) {}
+
     abstract void writeCentroids(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
         float[] globalCentroid,
-        LongValues centroidOffset,
+        CentroidOffsetAndLength centroidOffsetAndLength,
         IndexOutput centroidOutput
     ) throws IOException;
 
-    abstract LongValues buildAndWritePostingsLists(
+    abstract CentroidOffsetAndLength buildAndWritePostingsLists(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
@@ -141,7 +143,7 @@ abstract LongValues buildAndWritePostingsLists(
         int[] overspillAssignments
     ) throws IOException;
 
-    abstract LongValues buildAndWritePostingsLists(
+    abstract CentroidOffsetAndLength buildAndWritePostingsLists(
         FieldInfo fieldInfo,
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
@@ -172,7 +174,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
             final CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.centroids());
             // write posting lists
             final long postingListOffset = ivfClusters.alignFilePointer(Float.BYTES);
-            final LongValues offsets = buildAndWritePostingsLists(
+            final CentroidOffsetAndLength centroidOffsetAndLength = buildAndWritePostingsLists(
                 fieldWriter.fieldInfo,
                 centroidSupplier,
                 floatVectorValues,
@@ -184,7 +186,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
             final long postingListLength = ivfClusters.getFilePointer() - postingListOffset;
             // write centroids
             final long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
-            writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, offsets, ivfCentroids);
+            writeCentroids(fieldWriter.fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, ivfCentroids);
             final long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
             // write meta file
             writeMeta(
@@ -354,7 +356,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
             );
             // write posting lists
             postingListOffset = ivfClusters.alignFilePointer(Float.BYTES);
-            final LongValues offsets = buildAndWritePostingsLists(
+            final CentroidOffsetAndLength centroidOffsetAndLength = buildAndWritePostingsLists(
                 fieldInfo,
                 centroidSupplier,
                 floatVectorValues,
@@ -367,7 +369,7 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
             postingListLength = ivfClusters.getFilePointer() - postingListOffset;
             // write centroids
             centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
-            writeCentroids(fieldInfo, centroidSupplier, calculatedGlobalCentroid, offsets, ivfCentroids);
+            writeCentroids(fieldInfo, centroidSupplier, calculatedGlobalCentroid, centroidOffsetAndLength, ivfCentroids);
             centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
             // write meta
             writeMeta(
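
Note that the change introduces two records sharing the name CentroidOffsetAndLength: a reader-side (long offset, long length) pair describing a single centroid's posting list, and this writer-side (LongValues offsets, LongValues lengths) pair holding the whole table, which writeCentroids serializes as interleaved longs at the end of the centroid file. A small sketch of that serialization handshake, using List<Long> in place of Lucene's LongValues; illustrative only:

```java
import java.util.List;

public class CentroidTableWriterDemo {
    /** Writer-side pairing of per-centroid offsets and lengths (the diff uses Lucene's packed LongValues). */
    record CentroidOffsetAndLength(List<Long> offsets, List<Long> lengths) {}

    /** Interleaves offset and length per centroid, matching the reader's 2 * Long.BYTES stride. */
    static void writeTable(CentroidOffsetAndLength table, java.io.DataOutput out) throws java.io.IOException {
        for (int i = 0; i < table.offsets().size(); i++) {
            out.writeLong(table.offsets().get(i));
            out.writeLong(table.lengths().get(i));
        }
    }

    public static void main(String[] args) throws Exception {
        var bytes = new java.io.ByteArrayOutputStream();
        writeTable(new CentroidOffsetAndLength(List.of(0L, 100L), List.of(100L, 37L)), new java.io.DataOutputStream(bytes));
        System.out.println(bytes.size() + " bytes"); // 32 bytes: two (offset, length) pairs
    }
}
```
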