diff --git a/CHANGELOG.md b/CHANGELOG.md index 16772fd00..ea7c67943 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Update Default Rescore Context based on Dimension [#2149](https://github.com/opensearch-project/k-NN/pull/2149) * KNNIterators should support with and without filters [#2155](https://github.com/opensearch-project/k-NN/pull/2155) * Adding Support to Enable/Disble Share level Rescoring and Update Oversampling Factor[#2172](https://github.com/opensearch-project/k-NN/pull/2172) +* Add support to build vector data structures greedily and perform exact search when there are no engine files [#1942](https://github.com/opensearch-project/k-NN/issues/1942) ### Bug Fixes * Add DocValuesProducers for releasing memory when close index [#1946](https://github.com/opensearch-project/k-NN/pull/1946) * KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index f5980879a..fb9c26fef 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -67,6 +67,7 @@ public class KNNSettings { * Settings name */ public static final String KNN_SPACE_TYPE = "index.knn.space_type"; + public static final String INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD = "index.knn.advanced.approximate_threshold"; public static final String KNN_ALGO_PARAM_M = "index.knn.algo_param.m"; public static final String KNN_ALGO_PARAM_EF_CONSTRUCTION = "index.knn.algo_param.ef_construction"; public static final String KNN_ALGO_PARAM_EF_SEARCH = "index.knn.algo_param.ef_search"; @@ -97,6 +98,9 @@ public class KNNSettings { public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false; public static final boolean KNN_DEFAULT_FAISS_AVX512_DISABLED_VALUE = false; public static final String INDEX_KNN_DEFAULT_SPACE_TYPE = "l2"; + public static final Integer INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE = 0; + public static final Integer INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN = -1; + public static final Integer INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX = Integer.MAX_VALUE - 2; public static final String INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY = "hamming"; public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_M = 16; public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH = 100; @@ -156,6 +160,21 @@ public class KNNSettings { Setting.Property.Deprecated ); + /** + * build_vector_data_structure_threshold - This parameter determines when to build vector data structure for knn fields during indexing + * and merging. Setting -1 (min) will skip building graph, whereas on any other values, the graph will be built if + * number of live docs in segment is greater than this threshold. Since max number of documents in a segment can + * be Integer.MAX_VALUE - 1, this setting will allow threshold to be up to 1 less than max number of documents in a segment + */ + public static final Setting INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_SETTING = Setting.intSetting( + INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, + INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX, + IndexScope, + Dynamic + ); + /** * M - the number of bi-directional links created for every new element during construction. * Reasonable range for M is 2-100. Higher M work better on datasets with high intrinsic @@ -486,6 +505,7 @@ private Setting getSetting(String key) { public List> getSettings() { List> settings = Arrays.asList( INDEX_KNN_SPACE_TYPE, + INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_SETTING, INDEX_KNN_ALGO_PARAM_M_SETTING, INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING, INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING, diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index b06c2a1d8..72187516f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -11,7 +11,9 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.MapperService; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat; import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; @@ -129,7 +131,23 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { } private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { - return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())); + // mapperService is already checked for null or valid instance type at caller, hence we don't need + // addition isPresent check here. + int approximateThreshold = getApproximateThresholdValue(); + return new NativeEngines990KnnVectorsFormat( + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()), + approximateThreshold + ); + } + + private int getApproximateThresholdValue() { + // This is private method and mapperService is already checked for null or valid instance type before this call + // at caller, hence we don't need additional isPresent check here. + final IndexSettings indexSettings = mapperService.get().getIndexSettings(); + final Integer approximateThresholdValue = indexSettings.getValue(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_SETTING); + return approximateThresholdValue != null + ? approximateThresholdValue + : KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE; } @Override diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index 626210f25..520a9838d 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -19,6 +19,7 @@ import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.opensearch.knn.index.KNNSettings; import java.io.IOException; @@ -30,15 +31,20 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { /** The format for storing, reading, merging vectors on disk */ private static FlatVectorsFormat flatVectorsFormat; private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; + private static int approximateThreshold; public NativeEngines990KnnVectorsFormat() { - super(FORMAT_NAME); - flatVectorsFormat = new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()); + this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer())); + } + + public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat) { + this(flatVectorsFormat, KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE); } - public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat lucene99FlatVectorsFormat) { + public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int approximateThreshold) { super(FORMAT_NAME); - flatVectorsFormat = lucene99FlatVectorsFormat; + NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat; + NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold; } /** @@ -48,7 +54,7 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat lucene99FlatVect */ @Override public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException { - return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state)); + return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold); } /** @@ -63,6 +69,12 @@ public KnnVectorsReader fieldsReader(final SegmentReadState state) throws IOExce @Override public String toString() { - return "NativeEngines99KnnVectorsFormat(name=" + this.getClass().getSimpleName() + ", flatVectorsFormat=" + flatVectorsFormat + ")"; + return "NativeEngines99KnnVectorsFormat(name=" + + this.getClass().getSimpleName() + + ", flatVectorsFormat=" + + flatVectorsFormat + + ", approximateThreshold=" + + approximateThreshold + + ")"; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index eccad41c8..7c8636577 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -53,10 +53,16 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private KNN990QuantizationStateWriter quantizationStateWriter; private final List> fields = new ArrayList<>(); private boolean finished; + private final Integer approximateThreshold; - public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) { + public NativeEngines990KnnVectorsWriter( + SegmentWriteState segmentWriteState, + FlatVectorsWriter flatVectorsWriter, + Integer approximateThreshold + ) { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; + this.approximateThreshold = approximateThreshold; } /** @@ -98,6 +104,17 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); + // Check only after quantization state writer finish writing its state, since it is required + // even if there are no graph files in segment, which will be later used by exact search + if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + log.info( + "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush", + fieldInfo.name, + totalLiveDocs, + approximateThreshold + ); + continue; + } final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); @@ -127,6 +144,17 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState } final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); + // Check only after quantization state writer finish writing its state, since it is required + // even if there are no graph files in segment, which will be later used by exact search + if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + log.info( + "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge", + fieldInfo.name, + totalLiveDocs, + approximateThreshold + ); + return; + } final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState); final KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); @@ -257,4 +285,11 @@ private void initQuantizationStateWriterIfNecessary() throws IOException { quantizationStateWriter.writeHeader(segmentWriteState); } } + + private boolean shouldSkipBuildingVectorDataStructure(final long docCount) { + if (approximateThreshold < 0) { + return true; + } + return docCount < approximateThreshold; + } } diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 8e5849abb..77e993297 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index.query; +import com.google.common.base.Predicates; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.NonNull; import lombok.Value; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; @@ -21,6 +23,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator; @@ -36,7 +39,9 @@ import java.io.IOException; import java.util.HashMap; +import java.util.Locale; import java.util.Map; +import java.util.function.Predicate; @Log4j2 @AllArgsConstructor @@ -55,11 +60,41 @@ public class ExactSearcher { public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + if (exactSearcherContext.getKnnQuery().getRadius() != null) { + return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); + } if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); } - return searchTopK(iterator, exactSearcherContext.getK()); + return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue()); + } + + /** + * Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search. + * Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores + * to filter out the documents that does not have given min score. + * @param leafReaderContext + * @param exactSearcherContext + * @param iterator {@link KNNIterator} + * @return Map of docId and score + * @throws IOException exception raised by iterator during traversal + */ + private Map doRadialSearch( + LeafReaderContext leafReaderContext, + ExactSearcherContext exactSearcherContext, + KNNIterator iterator + ) throws IOException { + final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); + final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo); + if (KNNEngine.FAISS != engine) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine)); + } + final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); + final float minScore = spaceType.scoreTranslation(knnQuery.getRadius()); + return filterDocsByMinScore(exactSearcherContext, iterator, minScore); } private Map scoreAllDocs(KNNIterator iterator) throws IOException { @@ -71,15 +106,17 @@ private Map scoreAllDocs(KNNIterator iterator) throws IOExceptio return docToScore; } - private Map searchTopK(KNNIterator iterator, int k) throws IOException { + private Map searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate filterScore) + throws IOException { // Creating min heap and init with MAX DocID and Score as -INF. - final HitQueue queue = new HitQueue(k, true); + final HitQueue queue = new HitQueue(limit, true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { - if (iterator.score() > topDoc.score) { - topDoc.score = iterator.score(); + final float currentScore = iterator.score(); + if (filterScore.test(currentScore) && currentScore > topDoc.score) { + topDoc.score = currentScore; topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we // have seen till now on top. @@ -98,10 +135,16 @@ private Map searchTopK(KNNIterator iterator, int k) throws IOExc final ScoreDoc doc = queue.pop(); docToScore.put(doc.doc, doc.score); } - return docToScore; } + private Map filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore) + throws IOException { + int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow(); + Predicate scoreGreaterThanOrEqualToMinScore = score -> score >= minScore; + return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore); + } + private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 37695c208..87c99b884 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -22,6 +23,7 @@ import org.apache.lucene.util.FixedBitSet; import org.opensearch.common.io.PathUtils; import org.opensearch.common.lucene.Lucene; +import org.opensearch.knn.common.FieldInfoExtractor; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; @@ -33,6 +35,7 @@ import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -93,8 +96,13 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { } public static void initialize(ModelDao modelDao) { + initialize(modelDao, new ExactSearcher(modelDao)); + } + + @VisibleForTesting + static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { KNNWeight.modelDao = modelDao; - KNNWeight.DEFAULT_EXACT_SEARCHER = new ExactSearcher(modelDao); + KNNWeight.DEFAULT_EXACT_SEARCHER = exactSearcher; } @Override @@ -129,42 +137,21 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I if (filterWeight != null && cardinality == 0) { return Collections.emptyMap(); } - /* - * The idea for this optimization is to get K results, we need to atleast look at K vectors in the HNSW graph + * The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph * . Hence, if filtered results are less than K and filter query is present we should shift to exact search. * This improves the recall. */ - Map docIdsToScoreMap; - final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() - .k(k) - .isParentHits(true) - .matchedDocs(filterBitSet) - // setting to true, so that if quantization details are present we want to do search on the quantized - // vectors as this flow is used in first pass of search. - .useQuantizedVectorsForSearch(true) - .knnQuery(knnQuery) - .build(); - if (filterWeight != null && canDoExactSearch(cardinality)) { - docIdsToScoreMap = exactSearch(context, exactSearcherContext); - } else { - docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); - if (docIdsToScoreMap == null) { - return Collections.emptyMap(); - } - if (canDoExactSearchAfterANNSearch(cardinality, docIdsToScoreMap.size())) { - log.debug( - "Doing ExactSearch after doing ANNSearch as the number of documents returned are less than " - + "K, even when we have more than K filtered Ids. K: {}, ANNResults: {}, filteredIdCount: {}", - k, - docIdsToScoreMap.size(), - cardinality - ); - docIdsToScoreMap = exactSearch(context, exactSearcherContext); - } + if (isFilteredExactSearchPreferred(cardinality)) { + return doExactSearch(context, filterBitSet, k); } - if (docIdsToScoreMap.isEmpty()) { - return Collections.emptyMap(); + Map docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); + // See whether we have to perform exact search based on approx search results + // This is required if there are no native engine files or if approximate search returned + // results less than K, though we have more than k filtered docs + if (isExactSearchRequire(context, cardinality, docIdsToScoreMap.size())) { + final BitSet docs = filterWeight != null ? filterBitSet : null; + return doExactSearch(context, docs, k); } return docIdsToScoreMap; } @@ -221,6 +208,20 @@ private int[] bitSetToIntArray(final BitSet bitSet) { return intArray; } + private Map doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException { + final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + .k(k) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(knnQuery); + if (acceptedDocs != null) { + exactSearcherContextBuilder.matchedDocs(acceptedDocs); + } + return exactSearch(context, exactSearcherContextBuilder.build()); + } + private Map doANNSearch( final LeafReaderContext context, final BitSet filterIdsBitSet, @@ -234,7 +235,7 @@ private Map doANNSearch( if (fieldInfo == null) { log.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName()); - return null; + return Collections.emptyMap(); } KNNEngine knnEngine; @@ -273,8 +274,8 @@ private Map doANNSearch( List engineFiles = KNNCodecUtil.getEngineFiles(knnEngine.getExtension(), knnQuery.getField(), reader.getSegmentInfo().info); if (engineFiles.isEmpty()) { - log.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); - return null; + log.debug("[KNN] No native engine files found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); + return Collections.emptyMap(); } Path indexPath = PathUtils.get(directory, engineFiles.get(0)); @@ -364,16 +365,9 @@ private Map doANNSearch( indexAllocation.readUnlock(); indexAllocation.decRef(); } - - /* - * Scores represent the distance of the documents with respect to given query vector. - * Lesser the score, the closer the document is to the query vector. - * Since by default results are retrieved in the descending order of scores, to get the nearest - * neighbors we are inverting the scores. - */ if (results.length == 0) { log.debug("[KNN] Query yielded 0 results"); - return null; + return Collections.emptyMap(); } if (quantizedVector != null) { @@ -386,7 +380,6 @@ private Map doANNSearch( /** * Execute exact search for the given matched doc ids and return the results as a map of docId to score. - * * @return Map of docId to score for the exact search results. * @throws IOException If an error occurs during the search. */ @@ -407,18 +400,18 @@ public static float normalizeScore(float score) { return -score + 1; } - private boolean canDoExactSearch(final int filterIdsCount) { + private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { + if (filterWeight == null) { + return false; + } log.debug( "Info for doing exact search filterIdsLength : {}, Threshold value: {}", filterIdsCount, KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()) ); - if (knnQuery.getRadius() != null) { - return false; - } int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic - if (filterIdsCount <= knnQuery.getK()) { + if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { return true; } // See user has defined Exact Search filtered threshold. if yes, then use that setting. @@ -446,14 +439,59 @@ private boolean isExactSearchThresholdSettingSet(int filterThresholdValue) { return filterThresholdValue != KNNSettings.ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE; } + /** + * This condition mainly checks whether exact search should be performed or not + * @param context LeafReaderContext + * @param filterIdsCount count of filtered Doc ids + * @param annResultCount Count of Nearest Neighbours we got after doing filtered ANN Search. + * @return boolean - true if exactSearch needs to be done after ANNSearch. + */ + private boolean isExactSearchRequire(final LeafReaderContext context, final int filterIdsCount, final int annResultCount) { + if (annResultCount == 0 && isMissingNativeEngineFiles(context)) { + log.debug("Perform exact search after approximate search since no native engine files are available"); + return true; + } + if (isFilteredExactSearchRequireAfterANNSearch(filterIdsCount, annResultCount)) { + log.debug( + "Doing ExactSearch after doing ANNSearch as the number of documents returned are less than " + + "K, even when we have more than K filtered Ids. K: {}, ANNResults: {}, filteredIdCount: {}", + this.knnQuery.getK(), + annResultCount, + filterIdsCount + ); + return true; + } + return false; + } + /** * This condition mainly checks during filtered search we have more than K elements in filterIds but the ANN - * doesn't yeild K nearest neighbors. + * doesn't yield K nearest neighbors. * @param filterIdsCount count of filtered Doc ids * @param annResultCount Count of Nearest Neighbours we got after doing filtered ANN Search. * @return boolean - true if exactSearch needs to be done after ANNSearch. */ - private boolean canDoExactSearchAfterANNSearch(final int filterIdsCount, final int annResultCount) { + private boolean isFilteredExactSearchRequireAfterANNSearch(final int filterIdsCount, final int annResultCount) { return filterWeight != null && filterIdsCount >= knnQuery.getK() && knnQuery.getK() > annResultCount; } + + /** + * This condition mainly checks whether segments has native engine files or not + * @return boolean - false if exactSearch needs to be done since no native engine files are in segments. + */ + private boolean isMissingNativeEngineFiles(LeafReaderContext context) { + final SegmentReader reader = Lucene.segmentReader(context.reader()); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + // if segment has no documents with at least 1 vector field, field info will be null + if (fieldInfo == null) { + return false; + } + final KNNEngine knnEngine = FieldInfoExtractor.extractKNNEngine(fieldInfo); + final List engineFiles = KNNCodecUtil.getEngineFiles( + knnEngine.getExtension(), + knnQuery.getField(), + reader.getSegmentInfo().info + ); + return engineFiles.isEmpty(); + } } diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index 99152ef6b..b5166866c 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -88,6 +88,7 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest .indexName(indexName) .parentsFilter(parentFilter) .radius(radius) + .vectorDataType(vectorDataType) .methodParameters(methodParameters) .context(knnQueryContext) .filterQuery(filterQuery) diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index c97a0d061..8b861b430 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -57,7 +57,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo List leafReaderContexts = reader.leaves(); List> perLeafResults; RescoreContext rescoreContext = knnQuery.getRescoreContext(); - int finalK = knnQuery.getK(); + final int finalK = knnQuery.getK(); if (rescoreContext == null) { perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK); } else { diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 2df1d8a60..c494f7f1f 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -92,6 +92,7 @@ public class FaissIT extends KNNRestTestCase { private static final String INTEGER_FIELD_NAME = "int_field"; private static final String FILED_TYPE_INTEGER = "integer"; private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field"; + public static final int NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = -1; static TestUtils.TestData testData; @@ -575,6 +576,125 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { validateGraphEviction(); } + @SneakyThrows + public void testHNSWSQFP16_whenGraphThresholdIsNegative_whenIndexed_thenSkipCreatingGraph() { + final String indexName = "test-index-hnsw-sqfp16"; + final String fieldName = "test-field-hnsw-sqfp16"; + final SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; + final Random random = new Random(); + final SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; + + final int dimension = 128; + final int numDocs = 100; + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + final Map mappingMap = xContentBuilderToMap(builder); + final String mapping = builder.toString(); + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + createKnnIndex(indexName, knnIndexSettings, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + indexTestData(indexName, fieldName, dimension, numDocs); + + final float[] queryVector = new float[dimension]; + Arrays.fill(queryVector, (float) numDocs); + + // Assert we have the right number of documents in the index + assertEquals(numDocs, getDocCount(indexName)); + + final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); + final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); + // expect result due to exact search + assertEquals(1, results.size()); + + deleteKNNIndex(indexName); + validateGraphEviction(); + } + + @SneakyThrows + public void testHNSWSQFP16_whenGraphThresholdIsMetDuringMerge_thenCreateGraph() { + final String indexName = "test-index-hnsw-sqfp16"; + final String fieldName = "test-field-hnsw-sqfp16"; + final SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; + final Random random = new Random(); + final SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; + final int dimension = 128; + final int numDocs = 100; + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + final Map mappingMap = xContentBuilderToMap(builder); + final String mapping = builder.toString(); + final Settings knnIndexSettings = buildKNNIndexSettings(numDocs); + createKnnIndex(indexName, knnIndexSettings, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + indexTestData(indexName, fieldName, dimension, numDocs); + + final float[] queryVector = new float[dimension]; + Arrays.fill(queryVector, (float) numDocs); + + // Assert we have the right number of documents in the index + assertEquals(numDocs, getDocCount(indexName)); + + // KNN Query should return empty result + final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); + final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); + assertEquals(1, results.size()); + + // update index setting to build graph and do force merge + // update build vector data structure setting + forceMergeKnnIndex(indexName, 1); + + queryTestData(indexName, fieldName, dimension, numDocs); + + deleteKNNIndex(indexName); + validateGraphEviction(); + } + @SneakyThrows public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() { @@ -1708,6 +1828,111 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() { validateGraphEviction(); } + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenNoGraphFileIsCreated_whenDistanceThreshold_thenSucceed() { + final SpaceType spaceType = SpaceType.L2; + + final List mValues = ImmutableList.of(16, 32, 64, 128); + final List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + final List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + final float distance = 300000000000f; + final List> resultsFromDistance = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + distance, + null, + spaceType, + null, + null + ); + assertFalse(resultsFromDistance.isEmpty()); + resultsFromDistance.forEach(result -> { assertFalse(result.isEmpty()); }); + final float score = spaceType.scoreTranslation(distance); + final List> resultsFromScore = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + null, + score, + spaceType, + null, + null + ); + assertFalse(resultsFromScore.isEmpty()); + resultsFromScore.forEach(result -> { assertFalse(result.isEmpty()); }); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testRadialQueryWithFilter_whenNoGraphIsCreated_thenSuccess() { + setupKNNIndexForFilterQuery(buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD)); + + final float[][] searchVector = new float[][] { { 3.3f, 3.0f, 5.0f } }; + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("color", "red"); + List expectedDocIds = Arrays.asList(DOC_ID_3); + + float distance = 15f; + List> queryResult = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + searchVector, + distance, + null, + SpaceType.L2, + termQueryBuilder, + null + ); + + assertEquals(1, queryResult.get(0).size()); + assertEquals(expectedDocIds.get(0), queryResult.get(0).get(0).getDocId()); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + @SneakyThrows public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -1780,6 +2005,10 @@ public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful( } protected void setupKNNIndexForFilterQuery() throws Exception { + setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings()); + } + + protected void setupKNNIndexForFilterQuery(Settings settings) throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() @@ -1797,7 +2026,7 @@ protected void setupKNNIndexForFilterQuery() throws Exception { .endObject(); final String mapping = builder.toString(); - createKnnIndex(INDEX_NAME, mapping); + createKnnIndex(INDEX_NAME, settings, mapping); addKnnDocWithAttributes( DOC_ID_1, diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 0bb7947ba..481b307fa 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -15,6 +15,7 @@ import com.google.common.primitives.Floats; import java.util.Locale; import lombok.SneakyThrows; +import org.apache.hc.core5.http.ParseException; import org.junit.BeforeClass; import org.junit.Ignore; import org.opensearch.knn.KNNRestTestCase; @@ -42,6 +43,8 @@ import java.util.TreeMap; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN; public class OpenSearchIT extends KNNRestTestCase { @@ -611,4 +614,297 @@ public void testCacheClear_whenCloseIndex() throws Exception { fail("Graphs are not getting evicted"); } + + public void testKNNIndex_whenBuildGraphThresholdIsPresent_thenGetThresholdValue() throws Exception { + final Integer buildVectorDataStructureThreshold = randomIntBetween( + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN, + INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX + ); + final Settings settings = Settings.builder().put(buildKNNIndexSettings(buildVectorDataStructureThreshold)).build(); + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-graph-settings"; + createKnnIndex(indexName, settings, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD + ); + assertNotNull("build_vector_data_structure_threshold index setting is not found", buildVectorDataStructureThresholdSetting); + assertEquals( + "incorrect setting for build_vector_data_structure_threshold", + buildVectorDataStructureThreshold, + Integer.valueOf(buildVectorDataStructureThresholdSetting) + ); + deleteKNNIndex(indexName); + } + + public void testKNNIndex_whenBuildThresholdIsNotProvided_thenShouldNotReturnSetting() throws Exception { + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-graph-settings"; + createKnnIndex(indexName, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD + ); + assertNull( + "build_vector_data_structure_threshold index setting should not be added in index setting", + buildVectorDataStructureThresholdSetting + ); + deleteKNNIndex(indexName); + } + + public void testKNNIndex_whenGetIndexSettingWithDefaultIsCalled_thenReturnDefaultBuildGraphThresholdValue() throws Exception { + final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT)); + final String indexName = "test-index-with-build-vector-graph-settings"; + createKnnIndex(indexName, knnIndexMapping); + final String buildVectorDataStructureThresholdSetting = getIndexSettingByName( + indexName, + KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, + true + ); + assertNotNull("build_vector_data_structure index setting is not found", buildVectorDataStructureThresholdSetting); + assertEquals( + "incorrect default setting for build_vector_data_structure_threshold", + KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE, + Integer.valueOf(buildVectorDataStructureThresholdSetting) + ); + deleteKNNIndex(indexName); + } + + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold as -1, then index few documents, perform knn search, + then, confirm hits because of exact search though there are no graph. In next step, update setting to 0, force merge segment to 1, perform knn search and confirm expected + hits are returned. + */ + public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception { + final String indexName = "test-index-1"; + final String fieldName1 = "test-field-1"; + final String fieldName2 = "test-field-2"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final List nmslibNeighbors = getResults(indexName, fieldName1, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", nmslibNeighbors.size(), nmslibNeighbors.size()); + + final List faissNeighbors = getResults(indexName, fieldName2, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", faissNeighbors.size(), faissNeighbors.size()); + + // update build vector data structure setting + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, 0)); + forceMergeKnnIndex(indexName, 1); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search nmslib field + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, nmslibValidNeighbors.size()); + // Search faiss field + final List faissValidNeighbors = getResults(indexName, fieldName2, testData.queries[i], k); + assertEquals(k, faissValidNeighbors.size()); + } + + // Delete index + deleteKNNIndex(indexName); + } + + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold number of documents to ingest, then index x documents, perform knn search, + then, confirm expected hits are returned. Here, we don't need force merge to build graph, since, threshold is less than + actual number of documents in segments + */ + public void testKNNIndex_whenBuildVectorDataStructureIsLessThanDocCount_thenBuildGraphBasedSuccessfully() throws Exception { + final String indexName = "test-index-1"; + final String fieldName = "test-field-1"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(testData.indexData.docs.length); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + // Disable refresh + updateIndexSettings(indexName, Settings.builder().put("index.refresh_interval", -1)); + + // Index the test data without refresh on every document + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName), + ImmutableList.of(Floats.asList(testData.indexData.vectors[i]).toArray()), + false + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName); + assertEquals(k, nmslibValidNeighbors.size()); + } + // Delete index + deleteKNNIndex(indexName); + } + + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold as -1, then index few documents, perform knn search, + then, confirm hits because of exact search though there are no graph. In next step, update setting to 0, force merge segment to 1, perform knn search and confirm expected + hits are returned. + */ + public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSettingUsingRadialSearch() + throws Exception { + final String indexName = "test-index-1"; + final String fieldName1 = "test-field-1"; + final String fieldName2 = "test-field-2"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final List nmslibNeighbors = getResults(indexName, fieldName1, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", nmslibNeighbors.size(), nmslibNeighbors.size()); + + final List faissNeighbors = getResults(indexName, fieldName2, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", faissNeighbors.size(), faissNeighbors.size()); + + // update build vector data structure setting + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, 0)); + forceMergeKnnIndex(indexName, 1); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search nmslib field + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, nmslibValidNeighbors.size()); + // Search faiss field + final List faissValidNeighbors = getResults(indexName, fieldName2, testData.queries[i], k); + assertEquals(k, faissValidNeighbors.size()); + } + + // Delete index + deleteKNNIndex(indexName); + } + + private List getResults(final String indexName, final String fieldName, final float[] vector, final int k) + throws IOException, ParseException { + final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k); + final String searchResponseBody = EntityUtils.toString(searchResponseField.getEntity()); + return parseSearchResponse(searchResponseBody, fieldName); + } + } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java index 6e6a51b88..4f68a360e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java @@ -126,6 +126,5 @@ public void testRamByteUsed_whenValidInput_thenSuccess() { // testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too. Assert.assertTrue(byteWriter.ramBytesUsed() > 0); Mockito.verify(mockedFlatFieldVectorsWriter, Mockito.times(2)).ramBytesUsed(); - } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 5bb6d1926..03d0f6160 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -31,10 +31,12 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Predicate; @@ -52,6 +54,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @RequiredArgsConstructor @@ -76,12 +79,14 @@ public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCa private final String description; private final List> vectorsPerField; + private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0; + private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1; @Override public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter, BUILD_GRAPH_ALWAYS_THRESHOLD); mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any()); Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter); @@ -299,6 +304,539 @@ public void testFlush_WithQuantization() { } } + public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + maxThreshold + 1 + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + + final int minThreshold = sizeMap.values().stream().filter(count -> count != 0).min(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + minThreshold + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).size() > 0) { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWriterIsCalled() throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + expectedVectorValues.add(knnVectorValues); + + }); + final int threshold = 4; + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + threshold + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + }); + + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).size() >= threshold) { + verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } else { + verify(nativeIndexWriter, never()).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size()); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } + + public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNotMet_thenSkipBuildingGraph() + throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + maxThreshold + 1 // to avoid building graph using max doc threshold, the same can be achieved by -1 too + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + ); + } + } + + public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNegative_thenSkipBuildingGraph() + throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + ); + } + } + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { FieldInfo fieldInfo = mock(FieldInfo.class); when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index af18cd281..77f3fd8ed 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -34,6 +34,7 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -77,12 +78,14 @@ public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCa private final String description; private final Map mergedVectors; private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter; + private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0; + private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1; @Override public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); - objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter, BUILD_GRAPH_ALWAYS_THRESHOLD); mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any()); Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter); @@ -162,6 +165,126 @@ public void testMerge() { } } + public void testMerge_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + nativeEngineWriter.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + verifyNoInteractions(nativeIndexWriter); + } + } + + public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException { + // Given + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(mergedVectors.values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + mergedVectors.size() + ); + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedStatic mergedVectorValuesMockedStatic = mockStatic( + KnnVectorsWriter.MergedVectorValues.class + ); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + final FieldInfo fieldInfo = fieldInfo( + 0, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); + + mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) + .thenReturn(floatVectorValues); + knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues)) + .thenReturn(knnVectorValues); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null); + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null)) + .thenReturn(nativeIndexWriter); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).mergeIndex(any(), anyInt()); + + // When + nativeEngineWriter.mergeOneField(fieldInfo, mergeState); + + // Then + verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState); + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + if (!mergedVectors.isEmpty()) { + verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size()); + } else { + verifyNoInteractions(nativeIndexWriter); + } + } + } + @SneakyThrows public void testMerge_WithQuantization() { // Given diff --git a/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java b/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java new file mode 100644 index 000000000..8492ca1f0 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/ExactSearcherTests.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentCommitInfo; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.search.Sort; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.StringHelper; +import org.apache.lucene.util.Version; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.knn.KNNRestTestCase.FIELD_NAME; +import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; + +public class ExactSearcherTests extends KNNTestCase { + + private static final String SEGMENT_NAME = "0"; + + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenSuccess() { + try (MockedStatic valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + final List dataVectors = Arrays.asList( + new float[] { 11.0f, 12.0f, 13.0f }, + new float[] { 14.0f, 15.0f, 16.0f }, + new float[] { 17.0f, 18.0f, 19.0f } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + final Float score = Collections.min(expectedScores); + final float radius = KNNEngine.FAISS.scoreToRadialThreshold(score, spaceType); + final int maxResults = 1000; + final KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + KNNWeight.initialize(null); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .build(); + + final ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder exactSearcherContextBuilder = + ExactSearcher.ExactSearcherContext.builder() + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(false) + .knnQuery(query); + + ExactSearcher exactSearcher = new ExactSearcher(null); + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(spaceType.getValue()); + KNNFloatVectorValues floatVectorValues = mock(KNNFloatVectorValues.class); + valuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)).thenReturn(floatVectorValues); + when(floatVectorValues.nextDoc()).thenReturn(0, 1, 2, NO_MORE_DOCS); + when(floatVectorValues.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + final Map integerFloatMap = exactSearcher.searchLeaf(leafReaderContext, exactSearcherContextBuilder.build()); + assertEquals(integerFloatMap.size(), dataVectors.size()); + assertEquals(expectedScores, new ArrayList<>(integerFloatMap.values())); + } + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 2a2c3ed4d..482e7e0cb 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -90,6 +90,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; @@ -328,8 +329,9 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException { } @SneakyThrows - public void testShardWithoutFiles() { + public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned() { final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); + KNNWeight.initialize(null); final KNNWeight knnWeight = new KNNWeight(query, 0.0f); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); @@ -360,10 +362,9 @@ public void testShardWithoutFiles() { final Path path = mock(Path.class); when(directory.getDirectory()).thenReturn(path); final FieldInfos fieldInfos = mock(FieldInfos.class); - final FieldInfo fieldInfo = mock(FieldInfo.class); when(reader.getFieldInfos()).thenReturn(fieldInfos); - when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); - + // When no knn fields are available , field info for vector field will be null + when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(null); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); } @@ -868,6 +869,83 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean } } + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { + ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + KNNWeight.initialize(null, mockedExactSearcher); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder() + .isParentHits(true) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(query) + .build(); + when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES); + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + final List actualDocIds = new ArrayList<>(); + for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + actualDocIds.add(docId); + assertEquals(DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.00000001f); + } + assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // verify JNI Service is not called + jniServiceMockedStatic.verifyNoInteractions(); + verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext); + } + @SneakyThrows public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess() { ModelDao modelDao = mock(ModelDao.class); diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index d5b79768d..4fb267eb5 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.integ; +import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; @@ -13,11 +14,13 @@ import org.junit.After; import org.junit.BeforeClass; import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; import org.opensearch.knn.KNNJsonIndexMappingsBuilder; import org.opensearch.knn.KNNJsonQueryBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; @@ -36,6 +39,8 @@ @Log4j2 public class BinaryIndexIT extends KNNRestTestCase { private static TestUtils.TestData testData; + private static final int NEVER_BUILD_GRAPH = -1; + private static final int ALWAYS_BUILD_GRAPH = 0; @BeforeClass public static void setUpClass() throws IOException { @@ -104,6 +109,52 @@ public void testFaissHnswBinary_when1000Data_thenRecallIsAboveNinePointZero() { } } + @SneakyThrows + public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_thenBuildGraphBasedOnSetting() { + // Create Index + createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, NEVER_BUILD_GRAPH); + ingestTestData(INDEX_NAME, FIELD_NAME); + + assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); + + // update build vector data structure setting + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, ALWAYS_BUILD_GRAPH)); + forceMergeKnnIndex(INDEX_NAME, 1); + + int k = 100; + for (int i = 0; i < testData.queries.length; i++) { + List knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k); + float recall = getRecall( + Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)), + knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet()) + ); + assertTrue("Recall: " + recall, recall > 0.1); + } + } + + @SneakyThrows + public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { + // Create Index + createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); + ingestTestData(INDEX_NAME, FIELD_NAME, false); + + assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); + + // update build vector data structure setting + updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, ALWAYS_BUILD_GRAPH)); + forceMergeKnnIndex(INDEX_NAME, 1); + + int k = 100; + for (int i = 0; i < testData.queries.length; i++) { + List knnResults = runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[i], k); + float recall = getRecall( + Set.of(Arrays.copyOf(testData.groundTruthValues[i], k)), + knnResults.stream().map(KNNResult::getDocId).collect(Collectors.toSet()) + ); + assertTrue("Recall: " + recall, recall > 0.1); + } + } + @SneakyThrows public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { // Create Index @@ -157,13 +208,18 @@ private List runKnnQuery(final String indexName, final String fieldNa } private void ingestTestData(final String indexName, final String fieldName) throws Exception { + ingestTestData(indexName, fieldName, true); + } + + private void ingestTestData(final String indexName, final String fieldName, boolean refresh) throws Exception { // Index the test data for (int i = 0; i < testData.indexData.docs.length; i++) { addKnnDoc( indexName, Integer.toString(testData.indexData.docs[i]), - fieldName, - Floats.asList(testData.indexData.vectors[i]).toArray() + ImmutableList.of(fieldName), + ImmutableList.of(Floats.asList(testData.indexData.vectors[i]).toArray()), + refresh ); } @@ -172,8 +228,13 @@ private void ingestTestData(final String indexName, final String fieldName) thro assertEquals(testData.indexData.docs.length, getDocCount(indexName)); } - private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension) - throws IOException { + private void createKnnHnswBinaryIndex( + final KNNEngine knnEngine, + final String indexName, + final String fieldName, + final int dimension, + final int threshold + ) throws IOException { KNNJsonIndexMappingsBuilder.Method method = KNNJsonIndexMappingsBuilder.Method.builder() .methodName(METHOD_HNSW) .engine(knnEngine.getName()) @@ -186,7 +247,11 @@ private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String in .method(method) .build() .getIndexMapping(); + createKnnIndex(indexName, buildKNNIndexSettings(threshold), knnIndexMapping); + } - createKnnIndex(indexName, knnIndexMapping); + private void createKnnHnswBinaryIndex(final KNNEngine knnEngine, final String indexName, final String fieldName, final int dimension) + throws IOException { + createKnnHnswBinaryIndex(knnEngine, indexName, fieldName, dimension, ALWAYS_BUILD_GRAPH); } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 8c033ca03..2c2efed3d 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -100,6 +100,8 @@ import static org.opensearch.knn.TestUtils.computeGroundTruthValues; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; import static org.opensearch.knn.index.SpaceType.L2; import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT; import static org.opensearch.knn.index.engine.KNNEngine.FAISS; @@ -638,7 +640,12 @@ protected void addKnnDocWithNestedField(String index, String docId, String neste * Add a single KNN Doc to an index with multiple fields */ protected void addKnnDoc(String index, String docId, List fieldNames, List vectors) throws IOException { - Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + addKnnDoc(index, docId, fieldNames, vectors, true); + } + + protected void addKnnDoc(String index, String docId, List fieldNames, List vectors, boolean refresh) + throws IOException { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=" + refresh); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); for (int i = 0; i < fieldNames.size(); i++) { @@ -768,6 +775,15 @@ protected Settings getKNNSegmentReplicatedIndexSettings() { .build(); } + protected Settings buildKNNIndexSettings(int approximateThreshold) { + return Settings.builder() + .put("number_of_shards", 1) + .put("number_of_replicas", 0) + .put(KNN_INDEX, true) + .put(INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, approximateThreshold) + .build(); + } + @SneakyThrows protected int getDataNodeCount() { Request request = new Request("GET", "_nodes/stats?filter_path=nodes.*.roles");