From f258385955f67d21c065537b7ec82bf9f5f3e954 Mon Sep 17 00:00:00 2001 From: yangzq50 <58433399+yangzq50@users.noreply.github.com> Date: Wed, 16 Oct 2024 20:46:46 +0800 Subject: [PATCH] Add MinimumShouldMatchIterator (#2056) ### What problem does this PR solve? Add MinimumShouldMatchIterator to speed up fulltext filtering with the minimum_should_match option. Issue link: #1862 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring - [x] Performance Improvement --- src/common/stl.cppm | 4 +- src/executor/operator/physical_match.cpp | 60 ++---- .../index_scan/index_filter_evaluators.cpp | 50 ++--- .../index_scan/index_filter_evaluators.cppm | 7 +- .../invertedindex/search/and_iterator.cpp | 16 +- .../invertedindex/search/and_iterator.cppm | 2 - .../invertedindex/search/and_not_iterator.cpp | 2 - .../search/and_not_iterator.cppm | 2 - .../search/blockmax_maxscore_iterator.cpp | 5 - .../search/blockmax_maxscore_iterator.cppm | 2 - .../search/blockmax_wand_iterator.cpp | 6 - .../search/blockmax_wand_iterator.cppm | 4 +- .../invertedindex/search/doc_iterator.cppm | 2 +- .../search/minimum_should_match_iterator.cpp | 187 ++++++++++++++++++ .../search/minimum_should_match_iterator.cppm | 79 ++++++++ .../invertedindex/search/or_iterator.cpp | 6 - .../invertedindex/search/or_iterator.cppm | 2 - .../search/phrase_doc_iterator.cppm | 1 - .../invertedindex/search/query_builder.cpp | 11 +- .../invertedindex/search/query_builder.cppm | 5 +- .../invertedindex/search/query_node.cpp | 114 ++++++++--- src/storage/invertedindex/search/query_node.h | 46 ++++- .../search/term_doc_iterator.cppm | 1 - .../invertedindex/search/query_builder.cpp | 12 +- .../invertedindex/search/query_match.cpp | 3 +- 25 files changed, 463 insertions(+), 166 deletions(-) create mode 100644 src/storage/invertedindex/search/minimum_should_match_iterator.cpp create mode 100644 
src/storage/invertedindex/search/minimum_should_match_iterator.cppm diff --git a/src/common/stl.cppm b/src/common/stl.cppm index 6406870e42..4e3c1a92e9 100644 --- a/src/common/stl.cppm +++ b/src/common/stl.cppm @@ -141,11 +141,12 @@ export namespace std { using std::isnan; using std::log2; using std::make_heap; + using std::push_heap; + using std::pop_heap; using std::max_element; using std::min_element; using std::nearbyint; using std::partial_sort; - using std::pop_heap; using std::pow; using std::reduce; using std::remove_if; @@ -291,6 +292,7 @@ export namespace std { using std::conditional_t; using std::remove_pointer_t; using std::remove_reference_t; + using std::derived_from; using std::function; using std::monostate; diff --git a/src/executor/operator/physical_match.cpp b/src/executor/operator/physical_match.cpp index f379126991..dbc7314cc2 100644 --- a/src/executor/operator/physical_match.cpp +++ b/src/executor/operator/physical_match.cpp @@ -103,7 +103,6 @@ class FilterIterator final : public DocIterator { void UpdateScoreThreshold(float threshold) override { query_iterator_->UpdateScoreThreshold(threshold); } // for minimum_should_match parameter - u32 LeafCount() const override { return query_iterator_->LeafCount(); } u32 MatchCount() const override { return query_iterator_->MatchCount(); } void PrintTree(std::ostream &os, const String &prefix, bool is_final) const override { @@ -144,19 +143,24 @@ struct FilterQueryNode final : public QueryNode { query_tree_ = std::move(new_query_tree); } + uint32_t LeafCount() const override { return query_tree_->LeafCount(); } + void PushDownWeight(float factor) override { MultiplyWeight(factor); } - std::unique_ptr - CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const override { + std::unique_ptr CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + const EarlyTermAlgo early_term_algo, + const u32 minimum_should_match) const 
override { assert(common_query_filter_ != nullptr); if (!common_query_filter_->AlwaysTrue() && common_query_filter_->filter_result_count_ == 0) return nullptr; - auto search_iter = query_tree_->CreateSearch(table_entry, index_reader, early_term_algo); + auto search_iter = query_tree_->CreateSearch(table_entry, index_reader, early_term_algo, minimum_should_match); if (!search_iter) { return nullptr; } - if (common_query_filter_->AlwaysTrue()) + if (common_query_filter_->AlwaysTrue()) { return search_iter; + } return MakeUnique(common_query_filter_, std::move(search_iter)); } @@ -186,20 +190,18 @@ void ASSERT_FLOAT_EQ(float bar, u32 i, float a, float b) { } } -template -void ExecuteFTSearchT(UniquePtr &et_iter, FullTextScoreResultHeap &result_heap, u32 &blockmax_loop_cnt, const u32 minimum_should_match) { +void ExecuteFTSearch(UniquePtr &et_iter, FullTextScoreResultHeap &result_heap, u32 &blockmax_loop_cnt) { + // et_iter is nullptr if fulltext index is present but there's no data + if (et_iter == nullptr) { + LOG_DEBUG(fmt::format("et_iter is nullptr")); + return; + } while (true) { ++blockmax_loop_cnt; bool ok = et_iter->Next(); if (!ok) [[unlikely]] { break; } - if constexpr (use_minimum_should_match) { - assert(minimum_should_match >= 2); - if (et_iter->MatchCount() < minimum_should_match) { - continue; - } - } RowID id = et_iter->DocID(); float et_score = et_iter->BM25Score(); if (SHOULD_LOG_DEBUG()) { @@ -219,30 +221,6 @@ void ExecuteFTSearchT(UniquePtr &et_iter, FullTextScoreResultHeap & } } -void ExecuteFTSearch(UniquePtr &et_iter, - FullTextScoreResultHeap &result_heap, - u32 &blockmax_loop_cnt, - const MinimumShouldMatchOption &minimum_should_match_option) { - // et_iter is nullptr if fulltext index is present but there's no data - if (et_iter == nullptr) { - LOG_DEBUG(fmt::format("et_iter is nullptr")); - return; - } - u32 minimum_should_match_val = 0; - if (!minimum_should_match_option.empty()) { - const auto leaf_count = et_iter->LeafCount(); - 
minimum_should_match_val = GetMinimumShouldMatchParameter(minimum_should_match_option, leaf_count); - } - if (minimum_should_match_val <= 1) { - // no need for minimum_should_match - return ExecuteFTSearchT(et_iter, result_heap, blockmax_loop_cnt, 0); - } else { - // now minimum_should_match_val >= 2 - // use minimum_should_match - return ExecuteFTSearchT(et_iter, result_heap, blockmax_loop_cnt, minimum_should_match_val); - } -} - #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wunused-variable" #pragma clang diagnostic ignored "-Wunused-but-set-variable" @@ -310,13 +288,13 @@ bool PhysicalMatch::ExecuteInnerHomebrewed(QueryContext *query_context, Operator full_text_query_context.query_tree_ = MakeUnique(common_query_filter_.get(), std::move(query_tree_)); if (use_block_max_iter) { - et_iter = query_builder.CreateSearch(full_text_query_context, early_term_algo_); + et_iter = query_builder.CreateSearch(full_text_query_context, early_term_algo_, minimum_should_match_option_); // et_iter is nullptr if fulltext index is present but there's no data if (et_iter != nullptr) et_iter->UpdateScoreThreshold(begin_threshold_); } if (use_ordinary_iter) { - doc_iterator = query_builder.CreateSearch(full_text_query_context, EarlyTermAlgo::kNaive); + doc_iterator = query_builder.CreateSearch(full_text_query_context, EarlyTermAlgo::kNaive, minimum_should_match_option_); } // 3 full text search @@ -331,7 +309,7 @@ bool PhysicalMatch::ExecuteInnerHomebrewed(QueryContext *query_context, Operator #ifdef INFINITY_DEBUG auto blockmax_begin_ts = std::chrono::high_resolution_clock::now(); #endif - ExecuteFTSearch(et_iter, result_heap, blockmax_loop_cnt, minimum_should_match_option_); + ExecuteFTSearch(et_iter, result_heap, blockmax_loop_cnt); result_heap.Sort(); blockmax_result_count = result_heap.GetResultSize(); #ifdef INFINITY_DEBUG @@ -346,7 +324,7 @@ bool PhysicalMatch::ExecuteInnerHomebrewed(QueryContext *query_context, Operator #ifdef INFINITY_DEBUG auto 
ordinary_begin_ts = std::chrono::high_resolution_clock::now(); #endif - ExecuteFTSearch(doc_iterator, result_heap, ordinary_loop_cnt, minimum_should_match_option_); + ExecuteFTSearch(doc_iterator, result_heap, ordinary_loop_cnt); result_heap.Sort(); ordinary_result_count = result_heap.GetResultSize(); #ifdef INFINITY_DEBUG diff --git a/src/planner/optimizer/index_scan/index_filter_evaluators.cpp b/src/planner/optimizer/index_scan/index_filter_evaluators.cpp index 146b881365..2a8370c02c 100644 --- a/src/planner/optimizer/index_scan/index_filter_evaluators.cpp +++ b/src/planner/optimizer/index_scan/index_filter_evaluators.cpp @@ -451,35 +451,35 @@ UniquePtr IndexFilterEvaluatorSecondary::Make(con } } +void IndexFilterEvaluatorFulltext::OptimizeQueryTree() { + if (after_optimize_.test(std::memory_order_acquire)) { + UnrecoverableError(std::format("{}: Already optimized!", __func__)); + } + auto new_query_tree = QueryNode::GetOptimizedQueryTree(std::move(query_tree_)); + query_tree_ = std::move(new_query_tree); + if (!minimum_should_match_option_.empty()) { + const auto leaf_count = query_tree_->LeafCount(); + minimum_should_match_ = GetMinimumShouldMatchParameter(minimum_should_match_option_, leaf_count); + } + after_optimize_.test_and_set(std::memory_order_release); +} + Bitmask IndexFilterEvaluatorFulltext::Evaluate(const SegmentID segment_id, const SegmentOffset segment_row_count, Txn *txn) const { + if (!after_optimize_.test(std::memory_order_acquire)) { + UnrecoverableError(std::format("{}: Not optimized!", __func__)); + } Bitmask result(segment_row_count); result.SetAllFalse(); const RowID begin_rowid(segment_id, 0); const RowID end_rowid(segment_id, segment_row_count); - if (const auto ft_iter = query_tree_->CreateSearch(table_entry_, index_reader_, early_term_algo_); ft_iter && ft_iter->Next(begin_rowid)) { - u32 minimum_should_match_val = 0; - if (!minimum_should_match_option_.empty()) { - const auto leaf_count = ft_iter->LeafCount(); - 
minimum_should_match_val = GetMinimumShouldMatchParameter(minimum_should_match_option_, leaf_count); - } - if (minimum_should_match_val <= 1) { - // no need for minimum_should_match - while (ft_iter->DocID() < end_rowid) { - result.SetTrue(ft_iter->DocID().segment_offset_); - ft_iter->Next(); - } - } else { - // now minimum_should_match_val >= 2 - // use minimum_should_match - while (ft_iter->DocID() < end_rowid) { - if (ft_iter->MatchCount() >= minimum_should_match_val) { - result.SetTrue(ft_iter->DocID().segment_offset_); - } - ft_iter->Next(); - } + if (const auto ft_iter = query_tree_->CreateSearch(table_entry_, index_reader_, early_term_algo_, minimum_should_match_); + ft_iter && ft_iter->Next(begin_rowid)) { + while (ft_iter->DocID() < end_rowid) { + result.SetTrue(ft_iter->DocID().segment_offset_); + ft_iter->Next(); } + result.RunOptimize(); } - result.RunOptimize(); return result; } @@ -503,9 +503,13 @@ Bitmask IndexFilterEvaluatorAND::Evaluate(const SegmentID segment_id, const Segm const auto &roaring_end = result.End(); Bitmask new_result(segment_row_count); new_result.SetAllFalse(); + if (!fulltext_evaluator_->after_optimize_.test(std::memory_order_acquire)) { + UnrecoverableError(std::format("{}: Not optimized!", __func__)); + } const auto ft_iter = fulltext_evaluator_->query_tree_->CreateSearch(fulltext_evaluator_->table_entry_, fulltext_evaluator_->index_reader_, - fulltext_evaluator_->early_term_algo_); + fulltext_evaluator_->early_term_algo_, + fulltext_evaluator_->minimum_should_match_); if (ft_iter) { const RowID end_rowid(segment_id, segment_row_count); while (roaring_begin != roaring_end && ft_iter->Next(RowID(segment_id, *roaring_begin))) { diff --git a/src/planner/optimizer/index_scan/index_filter_evaluators.cppm b/src/planner/optimizer/index_scan/index_filter_evaluators.cppm index 08ce8d1887..50db759027 100644 --- a/src/planner/optimizer/index_scan/index_filter_evaluators.cppm +++ 
b/src/planner/optimizer/index_scan/index_filter_evaluators.cppm @@ -90,6 +90,8 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator { IndexReader index_reader_; UniquePtr query_tree_; MinimumShouldMatchOption minimum_should_match_option_; + u32 minimum_should_match_ = 0; + std::atomic_flag after_optimize_ = {}; IndexFilterEvaluatorFulltext(const FilterFulltextExpression *src_filter_fulltext_expression, const TableEntry *table_entry, @@ -102,10 +104,7 @@ export struct IndexFilterEvaluatorFulltext final : IndexFilterEvaluator { minimum_should_match_option_(std::move(minimum_should_match_option)) {} Bitmask Evaluate(SegmentID segment_id, SegmentOffset segment_row_count, Txn *txn) const override; bool HaveMinimumShouldMatchOption() const { return !minimum_should_match_option_.empty(); } - void OptimizeQueryTree() { - auto new_query_tree = QueryNode::GetOptimizedQueryTree(std::move(query_tree_)); - query_tree_ = std::move(new_query_tree); - } + void OptimizeQueryTree(); }; export UniquePtr IndexFilterEvaluatorBuildFromAnd(Vector> candidates); diff --git a/src/storage/invertedindex/search/and_iterator.cpp b/src/storage/invertedindex/search/and_iterator.cpp index 8f14b0d69d..481103c2f5 100644 --- a/src/storage/invertedindex/search/and_iterator.cpp +++ b/src/storage/invertedindex/search/and_iterator.cpp @@ -42,15 +42,7 @@ AndIterator::AndIterator(Vector> iterators) : MultiDocIte ++fixed_match_count_; break; } - case DocIteratorType::kAndIterator: - case DocIteratorType::kAndNotIterator: - case DocIteratorType::kFilterIterator: { - UnrecoverableError("Wrong optimization result"); - break; - } - case DocIteratorType::kOrIterator: - case DocIteratorType::kBMMIterator: - case DocIteratorType::kBMWIterator: { + default: { dyn_match_ids_.push_back(i); break; } @@ -124,12 +116,6 @@ void AndIterator::UpdateScoreThreshold(float threshold) { } } -u32 AndIterator::LeafCount() const { - return std::accumulate(children_.begin(), children_.end(), static_cast(0), 
[](const u32 cnt, const auto &it) { - return cnt + it->LeafCount(); - }); -} - u32 AndIterator::MatchCount() const { if (DocID() == INVALID_ROWID) { return 0; diff --git a/src/storage/invertedindex/search/and_iterator.cppm b/src/storage/invertedindex/search/and_iterator.cppm index 020bd662c3..405dadf72b 100644 --- a/src/storage/invertedindex/search/and_iterator.cppm +++ b/src/storage/invertedindex/search/and_iterator.cppm @@ -39,8 +39,6 @@ public: void UpdateScoreThreshold(float threshold) override; - u32 LeafCount() const override; - u32 MatchCount() const override; private: diff --git a/src/storage/invertedindex/search/and_not_iterator.cpp b/src/storage/invertedindex/search/and_not_iterator.cpp index f55fe93175..7b1a60233f 100644 --- a/src/storage/invertedindex/search/and_not_iterator.cpp +++ b/src/storage/invertedindex/search/and_not_iterator.cpp @@ -60,8 +60,6 @@ float AndNotIterator::BM25Score() { return children_[0]->BM25Score(); } void AndNotIterator::UpdateScoreThreshold(float threshold) { children_[0]->UpdateScoreThreshold(threshold); } -u32 AndNotIterator::LeafCount() const { return children_[0]->LeafCount(); } - u32 AndNotIterator::MatchCount() const { return children_[0]->MatchCount(); } } // namespace infinity \ No newline at end of file diff --git a/src/storage/invertedindex/search/and_not_iterator.cppm b/src/storage/invertedindex/search/and_not_iterator.cppm index c2dc0740ad..245bfb8488 100644 --- a/src/storage/invertedindex/search/and_not_iterator.cppm +++ b/src/storage/invertedindex/search/and_not_iterator.cppm @@ -38,8 +38,6 @@ public: void UpdateScoreThreshold(float threshold) override; - u32 LeafCount() const override; - u32 MatchCount() const override; }; diff --git a/src/storage/invertedindex/search/blockmax_maxscore_iterator.cpp b/src/storage/invertedindex/search/blockmax_maxscore_iterator.cpp index 13b1196d4a..8c74837e0e 100644 --- a/src/storage/invertedindex/search/blockmax_maxscore_iterator.cpp +++ 
b/src/storage/invertedindex/search/blockmax_maxscore_iterator.cpp @@ -105,11 +105,6 @@ void BlockMaxMaxscoreIterator::UpdateScoreThreshold(const float threshold) { threshold_ = threshold; } -u32 BlockMaxMaxscoreIterator::LeafCount() const { - UnrecoverableError("BMM not supported now"); - return 0; -} - u32 BlockMaxMaxscoreIterator::MatchCount() const { UnrecoverableError("BMM not supported now"); return 0; diff --git a/src/storage/invertedindex/search/blockmax_maxscore_iterator.cppm b/src/storage/invertedindex/search/blockmax_maxscore_iterator.cppm index 455ac66f51..b84c447087 100644 --- a/src/storage/invertedindex/search/blockmax_maxscore_iterator.cppm +++ b/src/storage/invertedindex/search/blockmax_maxscore_iterator.cppm @@ -41,8 +41,6 @@ public: float BM25Score() override; - u32 LeafCount() const override; - u32 MatchCount() const override; private: diff --git a/src/storage/invertedindex/search/blockmax_wand_iterator.cpp b/src/storage/invertedindex/search/blockmax_wand_iterator.cpp index 82aedc570f..a7c1767123 100644 --- a/src/storage/invertedindex/search/blockmax_wand_iterator.cpp +++ b/src/storage/invertedindex/search/blockmax_wand_iterator.cpp @@ -257,12 +257,6 @@ float BlockMaxWandIterator::BM25Score() { return sum_score; } -u32 BlockMaxWandIterator::LeafCount() const { - return std::accumulate(children_.begin(), children_.end(), static_cast(0), [](const u32 cnt, const auto &it) { - return cnt + it->LeafCount(); - }); -} - u32 BlockMaxWandIterator::MatchCount() const { u32 count = 0; if (const auto current_doc_id = DocID(); current_doc_id != INVALID_ROWID) { diff --git a/src/storage/invertedindex/search/blockmax_wand_iterator.cppm b/src/storage/invertedindex/search/blockmax_wand_iterator.cppm index b89b76007b..a3a1a1681d 100644 --- a/src/storage/invertedindex/search/blockmax_wand_iterator.cppm +++ b/src/storage/invertedindex/search/blockmax_wand_iterator.cppm @@ -25,7 +25,7 @@ import internal_types; namespace infinity { // Refers to 
https://engineering.nyu.edu/~suel/papers/bmw.pdf -export class BlockMaxWandIterator final : public MultiDocIterator { +export class BlockMaxWandIterator : public MultiDocIterator { public: explicit BlockMaxWandIterator(Vector> &&iterators); @@ -41,8 +41,6 @@ public: float BM25Score() override; - u32 LeafCount() const override; - u32 MatchCount() const override; private: diff --git a/src/storage/invertedindex/search/doc_iterator.cppm b/src/storage/invertedindex/search/doc_iterator.cppm index 00fbc00a5a..4bd49acaf1 100644 --- a/src/storage/invertedindex/search/doc_iterator.cppm +++ b/src/storage/invertedindex/search/doc_iterator.cppm @@ -31,6 +31,7 @@ export enum class DocIteratorType : u8 { kAndIterator, kAndNotIterator, kOrIterator, + kMinimumShouldMatchIterator, kBMMIterator, kBMWIterator, kFilterIterator, @@ -73,7 +74,6 @@ public: virtual void UpdateScoreThreshold(float threshold) = 0; // for minimum_should_match parameter - virtual u32 LeafCount() const = 0; virtual u32 MatchCount() const = 0; // print the query tree, for debugging diff --git a/src/storage/invertedindex/search/minimum_should_match_iterator.cpp b/src/storage/invertedindex/search/minimum_should_match_iterator.cpp new file mode 100644 index 0000000000..4f51298257 --- /dev/null +++ b/src/storage/invertedindex/search/minimum_should_match_iterator.cpp @@ -0,0 +1,187 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +module; + +#include +#include +#include +module minimum_should_match_iterator; +import stl; +import index_defines; +import multi_doc_iterator; +import internal_types; +import logger; +import infinity_exception; + +namespace infinity { + +MinimumShouldMatchIterator::MinimumShouldMatchIterator(Vector> &&iterators, const u32 minimum_should_match) + : MultiDocIterator(std::move(iterators)), minimum_should_match_(minimum_should_match) { +#ifndef NDEBUG + // validate children + for (const auto &child : children_) { + switch (child->GetType()) { + case DocIteratorType::kTermDocIterator: + case DocIteratorType::kPhraseIterator: { + // acceptable + break; + } + default: { + UnrecoverableError("MinimumShouldMatchIterator can only accept TermDocIterator or PhraseIterator as children"); + } + } + } +#endif + if (minimum_should_match_ <= 1u) { + UnrecoverableError("MinimumShouldMatchIterator should have minimum_should_match > 1"); + } + tail_heap_.resize(minimum_should_match_ - 1u); + bm25_score_cache_docid_ = INVALID_ROWID; +} + +MinimumShouldMatchIterator::~MinimumShouldMatchIterator() {} + +void MinimumShouldMatchIterator::UpdateScoreThreshold(float threshold) { UnrecoverableError("Unreachable code"); } + +bool MinimumShouldMatchIterator::Next(RowID doc_id) { + if (doc_id_ == INVALID_ROWID) { + // Initialize once. 
+ lead_.resize(children_.size()); + std::iota(lead_.begin(), lead_.end(), 0u); + } else if (doc_id_ >= doc_id) { + return true; + } + for (;; ++doc_id) { + for (const auto idx : lead_) { + if (const auto [have_pop, pop_idx] = PushToTailHeap(idx); have_pop) { + if (children_[pop_idx]->Next(doc_id)) { + PushToHeadHeap(pop_idx); + } + } + } + lead_.clear(); + while (HeadHeapTop() < doc_id) { + const auto idx = PopFromHeadHeap(); + if (const auto [have_pop, pop_idx] = PushToTailHeap(idx); have_pop) { + if (children_[pop_idx]->Next(doc_id)) { + PushToHeadHeap(pop_idx); + } + } + } + if (tail_size_ + head_heap_.size() < minimum_should_match_) { + doc_id_ = INVALID_ROWID; + return false; + } + doc_id = HeadHeapTop(); + assert(doc_id != INVALID_ROWID); + while (HeadHeapTop() == doc_id) { + lead_.push_back(PopFromHeadHeap()); + } + if (lead_.size() >= minimum_should_match_) { + doc_id_ = doc_id; + return true; + } + while (lead_.size() + tail_size_ >= minimum_should_match_) { + // advance tail + if (const auto tail_idx = PopFromTailHeap(); children_[tail_idx]->Next(doc_id)) { + if (children_[tail_idx]->DocID() > doc_id) { + PushToHeadHeap(tail_idx); + } else if (lead_.push_back(tail_idx); lead_.size() >= minimum_should_match_) { + doc_id_ = doc_id; + return true; + } + } + } + } +} + +float MinimumShouldMatchIterator::BM25Score() { + if (bm25_score_cache_docid_ == doc_id_) { + return bm25_score_cache_; + } + while (tail_size_) { + // advance tail + if (const auto tail_idx = PopFromTailHeap(); children_[tail_idx]->Next(doc_id_)) { + if (children_[tail_idx]->DocID() == doc_id_) { + lead_.push_back(tail_idx); + } else { + PushToHeadHeap(tail_idx); + } + } + } + float sum_score = 0; + for (const auto idx : lead_) { + sum_score += children_[idx]->BM25Score(); + } + bm25_score_cache_docid_ = doc_id_; + bm25_score_cache_ = sum_score; + return sum_score; +} + +u32 MinimumShouldMatchIterator::MatchCount() const { + UnrecoverableError("Unreachable code"); + return {}; +} + +RowID 
MinimumShouldMatchIterator::HeadHeapTop() const { + if (head_heap_.empty()) { + return INVALID_ROWID; + } + return children_[head_heap_.front()]->DocID(); +} + +void MinimumShouldMatchIterator::PushToHeadHeap(const u32 idx) { + head_heap_.push_back(idx); + std::push_heap(head_heap_.begin(), head_heap_.end(), [&](const u32 lhs, const u32 rhs) { + return children_[lhs]->DocID() > children_[rhs]->DocID(); + }); +} + +u32 MinimumShouldMatchIterator::PopFromHeadHeap() { + assert(!head_heap_.empty()); + std::pop_heap(head_heap_.begin(), head_heap_.end(), [&](const u32 lhs, const u32 rhs) { + return children_[lhs]->DocID() > children_[rhs]->DocID(); + }); + const u32 idx = head_heap_.back(); + head_heap_.pop_back(); + return idx; +} + +Pair MinimumShouldMatchIterator::PushToTailHeap(const u32 idx) { + auto comp = [&](const u32 lhs, const u32 rhs) { return children_[lhs]->GetDF() > children_[rhs]->GetDF(); }; + if (tail_size_ < tail_heap_.size()) { + tail_heap_[tail_size_++] = idx; + std::push_heap(tail_heap_.begin(), tail_heap_.begin() + tail_size_, comp); + return {false, std::numeric_limits::max()}; + } + if (children_[idx]->GetDF() <= children_[tail_heap_.front()]->GetDF()) { + return {true, idx}; + } + const auto result = tail_heap_.front(); + std::pop_heap(tail_heap_.begin(), tail_heap_.end(), comp); + tail_heap_.back() = idx; + std::push_heap(tail_heap_.begin(), tail_heap_.end(), comp); + return {true, result}; +} + +u32 MinimumShouldMatchIterator::PopFromTailHeap() { + assert(tail_size_ > 0); + std::pop_heap(tail_heap_.begin(), tail_heap_.begin() + tail_size_, [&](const u32 lhs, const u32 rhs) { + return children_[lhs]->GetDF() > children_[rhs]->GetDF(); + }); + return tail_heap_[--tail_size_]; +} + +} // namespace infinity diff --git a/src/storage/invertedindex/search/minimum_should_match_iterator.cppm b/src/storage/invertedindex/search/minimum_should_match_iterator.cppm new file mode 100644 index 0000000000..e234c21407 --- /dev/null +++ 
b/src/storage/invertedindex/search/minimum_should_match_iterator.cppm @@ -0,0 +1,79 @@ +// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +module; + +export module minimum_should_match_iterator; +import stl; +import doc_iterator; +import multi_doc_iterator; +import internal_types; + +namespace infinity { + +export class MinimumShouldMatchIterator final : public MultiDocIterator { +public: + MinimumShouldMatchIterator(Vector> &&iterators, u32 minimum_should_match); + + ~MinimumShouldMatchIterator() override; + + DocIteratorType GetType() const override { return DocIteratorType::kMinimumShouldMatchIterator; } + + String Name() const override { return "MinimumShouldMatchIterator"; } + + void UpdateScoreThreshold(float threshold) override; + + bool Next(RowID doc_id) override; + + float BM25Score() override; + + u32 MatchCount() const override; + +private: + RowID HeadHeapTop() const; + void PushToHeadHeap(u32 idx); + u32 PopFromHeadHeap(); + Pair PushToTailHeap(u32 idx); + u32 PopFromTailHeap(); + + const u32 minimum_should_match_ = 0; + Vector lead_{}; + Vector head_heap_{}; + Vector tail_heap_{}; + u32 tail_size_ = 0; + + // bm25 score cache + RowID bm25_score_cache_docid_ = {}; + float bm25_score_cache_ = {}; +}; + +export template T> +class MinimumShouldMatchWrapper final : public T { + u32 minimum_should_match_ = 0; + +public: + MinimumShouldMatchWrapper(Vector> &&iterators, const u32 
minimum_should_match) + : T(std::move(iterators)), minimum_should_match_(minimum_should_match) {} + ~MinimumShouldMatchWrapper() override = default; + bool Next(RowID doc_id) override { + for (; T::Next(doc_id); doc_id = this->doc_id_ + 1) { + if (this->MatchCount() >= minimum_should_match_) { + return true; + } + } + return false; + } +}; + +} // namespace infinity diff --git a/src/storage/invertedindex/search/or_iterator.cpp b/src/storage/invertedindex/search/or_iterator.cpp index 93da240585..13c0f4b465 100644 --- a/src/storage/invertedindex/search/or_iterator.cpp +++ b/src/storage/invertedindex/search/or_iterator.cpp @@ -107,12 +107,6 @@ void OrIterator::UpdateScoreThreshold(float threshold) { } } -u32 OrIterator::LeafCount() const { - return std::accumulate(children_.begin(), children_.end(), static_cast(0), [](const u32 cnt, const auto &it) { - return cnt + it->LeafCount(); - }); -} - u32 OrIterator::MatchCount() const { u32 count = 0; if (const auto current_doc_id = DocID(); current_doc_id != INVALID_ROWID) { diff --git a/src/storage/invertedindex/search/or_iterator.cppm b/src/storage/invertedindex/search/or_iterator.cppm index ff5fd45045..3775c4a39d 100644 --- a/src/storage/invertedindex/search/or_iterator.cppm +++ b/src/storage/invertedindex/search/or_iterator.cppm @@ -60,8 +60,6 @@ public: void UpdateScoreThreshold(float threshold) override; - u32 LeafCount() const override; - u32 MatchCount() const override; private: diff --git a/src/storage/invertedindex/search/phrase_doc_iterator.cppm b/src/storage/invertedindex/search/phrase_doc_iterator.cppm index 41f69e1e0e..bf574212c4 100644 --- a/src/storage/invertedindex/search/phrase_doc_iterator.cppm +++ b/src/storage/invertedindex/search/phrase_doc_iterator.cppm @@ -47,7 +47,6 @@ public: threshold_ = threshold; } - u32 LeafCount() const override { return 1; } u32 MatchCount() const override { return DocID() != INVALID_ROWID; } void PrintTree(std::ostream &os, const String &prefix, bool is_final) const override; 
diff --git a/src/storage/invertedindex/search/query_builder.cpp b/src/storage/invertedindex/search/query_builder.cpp index 6b96e6940c..8e33e69e54 100644 --- a/src/storage/invertedindex/search/query_builder.cpp +++ b/src/storage/invertedindex/search/query_builder.cpp @@ -33,6 +33,7 @@ import segment_entry; import term_doc_iterator; import logger; import third_party; +import parse_fulltext_options; namespace infinity { @@ -40,13 +41,19 @@ void QueryBuilder::Init(IndexReader index_reader) { index_reader_ = index_reader QueryBuilder::~QueryBuilder() {} -UniquePtr QueryBuilder::CreateSearch(FullTextQueryContext &context, EarlyTermAlgo early_term_algo) { +UniquePtr QueryBuilder::CreateSearch(FullTextQueryContext &context, + EarlyTermAlgo early_term_algo, + const MinimumShouldMatchOption &minimum_should_match_option) { // Optimize the query tree. if (!context.optimized_query_tree_) { context.optimized_query_tree_ = QueryNode::GetOptimizedQueryTree(std::move(context.query_tree_)); + if (!minimum_should_match_option.empty()) { + const auto leaf_count = context.optimized_query_tree_->LeafCount(); + context.minimum_should_match_ = GetMinimumShouldMatchParameter(minimum_should_match_option, leaf_count); + } } // Create the iterator from the query tree. 
- UniquePtr result = context.optimized_query_tree_->CreateSearch(table_entry_, index_reader_, early_term_algo); + auto result = context.optimized_query_tree_->CreateSearch(table_entry_, index_reader_, early_term_algo, context.minimum_should_match_); #ifdef INFINITY_DEBUG { OStringStream oss; diff --git a/src/storage/invertedindex/search/query_builder.cppm b/src/storage/invertedindex/search/query_builder.cppm index bb5088db65..0127ee83e3 100644 --- a/src/storage/invertedindex/search/query_builder.cppm +++ b/src/storage/invertedindex/search/query_builder.cppm @@ -23,6 +23,7 @@ import table_entry; import internal_types; import default_values; import base_table_ref; +import parse_fulltext_options; namespace infinity { @@ -31,6 +32,7 @@ struct QueryNode; export struct FullTextQueryContext { UniquePtr query_tree_; UniquePtr optimized_query_tree_; + u32 minimum_should_match_ = 0; }; export class QueryBuilder { @@ -44,7 +46,8 @@ public: const Map &GetColumn2Analyzer() { return index_reader_.GetColumn2Analyzer(); } - UniquePtr CreateSearch(FullTextQueryContext &context, EarlyTermAlgo early_term_algo); + UniquePtr + CreateSearch(FullTextQueryContext &context, EarlyTermAlgo early_term_algo, const MinimumShouldMatchOption &minimum_should_match_option); private: BaseTableRef* base_table_ref_{nullptr}; diff --git a/src/storage/invertedindex/search/query_node.cpp b/src/storage/invertedindex/search/query_node.cpp index 5e1d4ef27d..e7ae7cab21 100644 --- a/src/storage/invertedindex/search/query_node.cpp +++ b/src/storage/invertedindex/search/query_node.cpp @@ -1,4 +1,5 @@ #include "query_node.h" +#include #include import stl; @@ -20,6 +21,7 @@ import term_doc_iterator; import phrase_doc_iterator; import blockmax_wand_iterator; import blockmax_maxscore_iterator; +import minimum_should_match_iterator; namespace infinity { @@ -401,8 +403,10 @@ std::unique_ptr AndNotQueryNode::InnerGetNewOptimizedQueryTree() { } // create search iterator -std::unique_ptr 
-TermQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo /*early_term_algo*/) const { +std::unique_ptr TermQueryNode::CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo /*early_term_algo*/, + u32 /*minimum_should_match*/) const { ColumnID column_id = table_entry->GetColumnIdByName(column_); ColumnIndexReader *column_index_reader = index_reader.GetColumnIndexReader(column_id); if (!column_index_reader) { @@ -427,8 +431,10 @@ TermQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader &in return search; } -std::unique_ptr -PhraseQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo /*early_term_algo*/) const { +std::unique_ptr PhraseQueryNode::CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo /*early_term_algo*/, + u32 /*minimum_should_match*/) const { ColumnID column_id = table_entry->GetColumnIdByName(column_); ColumnIndexReader *column_index_reader = index_reader.GetColumnIndexReader(column_id); if (!column_index_reader) { @@ -456,38 +462,47 @@ PhraseQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader & return search; } -std::unique_ptr -AndQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const { +std::unique_ptr AndQueryNode::CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + const EarlyTermAlgo early_term_algo, + const u32 minimum_should_match) const { Vector> sub_doc_iters; sub_doc_iters.reserve(children_.size()); for (auto &child : children_) { - auto iter = child->CreateSearch(table_entry, index_reader, early_term_algo); - if (iter) { - sub_doc_iters.emplace_back(std::move(iter)); + auto iter = child->CreateSearch(table_entry, index_reader, early_term_algo, 0); + if (!iter) { + // no need to continue if any child is invalid + return 
nullptr; } + sub_doc_iters.emplace_back(std::move(iter)); } if (sub_doc_iters.empty()) { return nullptr; } else if (sub_doc_iters.size() == 1) { return std::move(sub_doc_iters[0]); - } else { + } else if (minimum_should_match <= sub_doc_iters.size()) { return MakeUnique(std::move(sub_doc_iters)); + } else { + assert(minimum_should_match > 2u); + return MakeUnique>(std::move(sub_doc_iters), minimum_should_match); } } -std::unique_ptr -AndNotQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const { +std::unique_ptr AndNotQueryNode::CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + const EarlyTermAlgo early_term_algo, + const u32 minimum_should_match) const { Vector> sub_doc_iters; sub_doc_iters.reserve(children_.size()); // check if the first child is a valid query - auto first_iter = children_.front()->CreateSearch(table_entry, index_reader, early_term_algo); + auto first_iter = children_.front()->CreateSearch(table_entry, index_reader, early_term_algo, minimum_should_match); if (!first_iter) { // no need to continue if the first child is invalid return nullptr; } sub_doc_iters.emplace_back(std::move(first_iter)); for (u32 i = 1; i < children_.size(); ++i) { - auto iter = children_[i]->CreateSearch(table_entry, index_reader, early_term_algo); + auto iter = children_[i]->CreateSearch(table_entry, index_reader, early_term_algo, 0); if (iter) { sub_doc_iters.emplace_back(std::move(iter)); } @@ -499,38 +514,65 @@ AndNotQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader & } } -std::unique_ptr -OrQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const { +std::unique_ptr OrQueryNode::CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + const EarlyTermAlgo early_term_algo, + const u32 minimum_should_match) const { Vector> sub_doc_iters; 
sub_doc_iters.reserve(children_.size()); bool all_are_term = true; + bool all_are_term_or_phrase = true; + const QueryNode *only_child_ = nullptr; for (auto &child : children_) { - if (child->GetType() != QueryNodeType::TERM) { - all_are_term = false; - } - auto iter = child->CreateSearch(table_entry, index_reader, early_term_algo); - if (iter) { + if (auto iter = child->CreateSearch(table_entry, index_reader, early_term_algo, 0); iter) { + only_child_ = child.get(); sub_doc_iters.emplace_back(std::move(iter)); + if (const auto child_type = child->GetType(); child_type != QueryNodeType::TERM) { + all_are_term = false; + if (child_type != QueryNodeType::PHRASE) { + all_are_term_or_phrase = false; + } + } } } if (sub_doc_iters.empty()) { return nullptr; } else if (sub_doc_iters.size() == 1) { - return std::move(sub_doc_iters[0]); - } else { - if (all_are_term && early_term_algo == EarlyTermAlgo::kBMW) + return only_child_->CreateSearch(table_entry, index_reader, early_term_algo, minimum_should_match); + } else if (all_are_term && early_term_algo == EarlyTermAlgo::kBMW) { + if (minimum_should_match <= 1u) { return MakeUnique(std::move(sub_doc_iters)); - else if (all_are_term && early_term_algo == EarlyTermAlgo::kBMM) - return MakeUnique(std::move(sub_doc_iters)); - else + } else if (minimum_should_match < sub_doc_iters.size()) { + return MakeUnique>(std::move(sub_doc_iters), minimum_should_match); + } else if (minimum_should_match == sub_doc_iters.size()) { + return MakeUnique(std::move(sub_doc_iters)); + } else { + return nullptr; + } + } else if (all_are_term_or_phrase) { + if (minimum_should_match <= 1u) { return MakeUnique(std::move(sub_doc_iters)); + } else if (minimum_should_match < sub_doc_iters.size()) { + return MakeUnique(std::move(sub_doc_iters), minimum_should_match); + } else if (minimum_should_match == sub_doc_iters.size()) { + return MakeUnique(std::move(sub_doc_iters)); + } else { + return nullptr; + } + } else { + if (minimum_should_match <= 1u) { + 
return MakeUnique(std::move(sub_doc_iters)); + } else { + return MakeUnique>(std::move(sub_doc_iters), minimum_should_match); + } } } -std::unique_ptr -NotQueryNode::CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const { - String error_message = "NOT query node should be optimized into AND_NOT query node"; - UnrecoverableError(error_message); +std::unique_ptr NotQueryNode::CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo early_term_algo, + u32 minimum_should_match) const { + UnrecoverableError("NOT query node should be optimized into AND_NOT query node"); return nullptr; } @@ -620,4 +662,14 @@ void MultiQueryNode::GetQueryColumnsTerms(std::vector &columns, std } } +uint32_t MultiQueryNode::LeafCount() const { + if (GetType() != QueryNodeType::OR && GetType() != QueryNodeType::AND) { + UnrecoverableError("LeafCount: Unexpected case!"); + } + return std::accumulate(children_.begin(), children_.end(), static_cast(0), [](const u32 cnt, const auto &it) { + return cnt + it->LeafCount(); + }); +} + + } // namespace infinity diff --git a/src/storage/invertedindex/search/query_node.h b/src/storage/invertedindex/search/query_node.h index 7abffccaee..e83d081b3d 100644 --- a/src/storage/invertedindex/search/query_node.h +++ b/src/storage/invertedindex/search/query_node.h @@ -72,6 +72,8 @@ struct QueryNode { void MultiplyWeight(float factor) { weight_ *= factor; } void ResetWeight() { weight_ = 1.0f; } + virtual uint32_t LeafCount() const = 0; + // will do two jobs: // 1. push down the weight to the leaf term node // 2. 
optimize the query tree @@ -81,8 +83,10 @@ struct QueryNode { // recursively multiply and push down the weight to the leaf term nodes virtual void PushDownWeight(float factor = 1.0f) = 0; // create the iterator from the query tree, need to be called after optimization - virtual std::unique_ptr - CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const = 0; + virtual std::unique_ptr CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo early_term_algo, + uint32_t minimum_should_match) const = 0; // print the query tree, for debugging virtual void PrintTree(std::ostream &os, const std::string &prefix = "", bool is_final = true) const = 0; @@ -96,8 +100,12 @@ struct TermQueryNode : public QueryNode { TermQueryNode() : QueryNode(QueryNodeType::TERM) {} + uint32_t LeafCount() const override { return 1; } void PushDownWeight(float factor) override { MultiplyWeight(factor); } - std::unique_ptr CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const override; + std::unique_ptr CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo early_term_algo, + uint32_t minimum_should_match) const override; void PrintTree(std::ostream &os, const std::string &prefix, bool is_final) const override; void GetQueryColumnsTerms(std::vector &columns, std::vector &terms) const override; }; @@ -109,8 +117,12 @@ struct PhraseQueryNode final : public QueryNode { PhraseQueryNode() : QueryNode(QueryNodeType::PHRASE) {} + uint32_t LeafCount() const override { return 1; } void PushDownWeight(float factor) override { MultiplyWeight(factor); } - std::unique_ptr CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const override; + std::unique_ptr CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo 
early_term_algo, + uint32_t minimum_should_match) const override; void PrintTree(std::ostream &os, const std::string &prefix, bool is_final) const override; void GetQueryColumnsTerms(std::vector &columns, std::vector &terms) const override; @@ -123,6 +135,7 @@ struct MultiQueryNode : public QueryNode { MultiQueryNode(QueryNodeType type) : QueryNode(type) {} void Add(std::unique_ptr &&node) { children_.emplace_back(std::move(node)); } + uint32_t LeafCount() const override; void PushDownWeight(float factor) final { // no need to update weight for MultiQueryNode, because it will be reset to 1.0 factor *= GetWeight(); @@ -142,23 +155,40 @@ struct MultiQueryNode : public QueryNode { // otherwise, query statement is invalid struct NotQueryNode final : public MultiQueryNode { NotQueryNode() : MultiQueryNode(QueryNodeType::NOT) {} + uint32_t LeafCount() const override { return 0; } std::unique_ptr InnerGetNewOptimizedQueryTree() override; - std::unique_ptr CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const override; + std::unique_ptr CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo early_term_algo, + uint32_t minimum_should_match) const override; }; + struct AndQueryNode final : public MultiQueryNode { AndQueryNode() : MultiQueryNode(QueryNodeType::AND) {} std::unique_ptr InnerGetNewOptimizedQueryTree() override; - std::unique_ptr CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const override; + std::unique_ptr CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo early_term_algo, + uint32_t minimum_should_match) const override; }; + struct AndNotQueryNode final : public MultiQueryNode { AndNotQueryNode() : MultiQueryNode(QueryNodeType::AND_NOT) {} + uint32_t LeafCount() const override { return children_[0]->LeafCount(); } std::unique_ptr 
InnerGetNewOptimizedQueryTree() override; - std::unique_ptr CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const override; + std::unique_ptr CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo early_term_algo, + uint32_t minimum_should_match) const override; }; + struct OrQueryNode final : public MultiQueryNode { OrQueryNode() : MultiQueryNode(QueryNodeType::OR) {} std::unique_ptr InnerGetNewOptimizedQueryTree() override; - std::unique_ptr CreateSearch(const TableEntry *table_entry, const IndexReader &index_reader, EarlyTermAlgo early_term_algo) const override; + std::unique_ptr CreateSearch(const TableEntry *table_entry, + const IndexReader &index_reader, + EarlyTermAlgo early_term_algo, + uint32_t minimum_should_match) const override; }; // unimplemented diff --git a/src/storage/invertedindex/search/term_doc_iterator.cppm b/src/storage/invertedindex/search/term_doc_iterator.cppm index e0949b9650..1ae7776cb9 100644 --- a/src/storage/invertedindex/search/term_doc_iterator.cppm +++ b/src/storage/invertedindex/search/term_doc_iterator.cppm @@ -70,7 +70,6 @@ public: threshold_ = threshold; } - u32 LeafCount() const override { return 1; } u32 MatchCount() const override { return DocID() != INVALID_ROWID; } void PrintTree(std::ostream &os, const String &prefix, bool is_final) const override; diff --git a/src/unit_test/storage/invertedindex/search/query_builder.cpp b/src/unit_test/storage/invertedindex/search/query_builder.cpp index b7bee17aa8..b6eea79f0c 100644 --- a/src/unit_test/storage/invertedindex/search/query_builder.cpp +++ b/src/unit_test/storage/invertedindex/search/query_builder.cpp @@ -31,6 +31,7 @@ import infinity_context; import global_resource_usage; import third_party; import logger; +import parse_fulltext_options; namespace infinity { @@ -61,7 +62,6 @@ class MockVectorDocIterator : public DocIterator { void UpdateScoreThreshold(float threshold) 
override {} - u32 LeafCount() const override { return 1; } u32 MatchCount() const override { return DocID() != INVALID_ROWID; } void PrintTree(std::ostream &os, const String &prefix, bool is_final = true) const override { @@ -94,7 +94,7 @@ struct MockQueryNode : public TermQueryNode { } void PushDownWeight(float factor) final { MultiplyWeight(factor); } - std::unique_ptr CreateSearch(const TableEntry *, const IndexReader &, EarlyTermAlgo early_term_algo) const override { + std::unique_ptr CreateSearch(const TableEntry *, const IndexReader &, EarlyTermAlgo, u32) const override { return MakeUnique(std::move(doc_ids_), term_, column_); } void PrintTree(std::ostream &os, const std::string &prefix, bool is_final) const final { @@ -204,7 +204,7 @@ TEST_F(QueryBuilderTest, test_and) { context.query_tree_ = std::move(and_root); FakeQueryBuilder fake_query_builder; QueryBuilder &builder = fake_query_builder.builder; - UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive); + UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); oss.str(""); oss << "DocIterator tree after optimization:" << std::endl; @@ -274,7 +274,7 @@ TEST_F(QueryBuilderTest, test_or) { context.query_tree_ = std::move(or_root); FakeQueryBuilder fake_query_builder; QueryBuilder &builder = fake_query_builder.builder; - UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive); + UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); oss.str(""); oss << "DocIterator tree after optimization:" << std::endl; @@ -350,7 +350,7 @@ TEST_F(QueryBuilderTest, test_and_not) { context.query_tree_ = std::move(and_not_root); FakeQueryBuilder fake_query_builder; QueryBuilder &builder = fake_query_builder.builder; - UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive); + UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive, 
MinimumShouldMatchOption{}); oss.str(""); oss << "DocIterator tree after optimization:" << std::endl; @@ -432,7 +432,7 @@ TEST_F(QueryBuilderTest, test_and_not2) { context.query_tree_ = std::move(and_not_root); FakeQueryBuilder fake_query_builder; QueryBuilder &builder = fake_query_builder.builder; - UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive); + UniquePtr result_iter = builder.CreateSearch(context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); oss.str(""); oss << "DocIterator tree after optimization:" << std::endl; diff --git a/src/unit_test/storage/invertedindex/search/query_match.cpp b/src/unit_test/storage/invertedindex/search/query_match.cpp index d40cc9c582..ef574601fc 100644 --- a/src/unit_test/storage/invertedindex/search/query_match.cpp +++ b/src/unit_test/storage/invertedindex/search/query_match.cpp @@ -40,6 +40,7 @@ import global_resource_usage; import term_doc_iterator; import logger; import column_index_reader; +import parse_fulltext_options; using namespace infinity; @@ -341,7 +342,7 @@ void QueryMatchTest::QueryMatch(const String &db_name, } FullTextQueryContext full_text_query_context; full_text_query_context.query_tree_ = std::move(query_tree); - UniquePtr doc_iterator = query_builder.CreateSearch(full_text_query_context, EarlyTermAlgo::kNaive); + UniquePtr doc_iterator = query_builder.CreateSearch(full_text_query_context, EarlyTermAlgo::kNaive, MinimumShouldMatchOption{}); RowID iter_row_id = doc_iterator.get() == nullptr ? INVALID_ROWID : (doc_iterator->Next(), doc_iterator->DocID()); if (iter_row_id == INVALID_ROWID) {