Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support applying BlockMaxWand algorithm to PhraseDocIterator #2369

Merged
merged 9 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example/fulltext_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
r"Bloom filter", # OR multiple terms
r'"Bloom filter"', # phrase: adjacent multiple terms
r"space efficient", # OR multiple terms
r"space\-efficient", # Escape reserved character '-', equivalent to: `space efficient`
r'"space\-efficient"', # phrase and escape reserved character, equivalent to: `"space efficient"`
r"space\:efficient", # Escape reserved character ':', equivalent to: `space efficient`
r'"space\:efficient"', # phrase and escape reserved character, equivalent to: `"space efficient"`
r'"harmful chemical"~10', # sloppy phrase, refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query-phrase.html
]
for question in questions:
Expand Down
2 changes: 1 addition & 1 deletion example/fulltext_search_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
r"羽毛球", # single term
r'"羽毛球锦标赛"', # phrase: adjacent multiple terms
r"2018年世界羽毛球锦标赛在哪个城市举办?", # OR multiple terms
r"high\-tech", # Escape reserved character '-'
r"high\:tech", # Escape reserved character ':'
r'"high tech"', # phrase: adjacent multiple terms
r'"high-tech"', # phrase: adjacent multiple terms
r"graphics card", # OR multiple terms
Expand Down
42 changes: 42 additions & 0 deletions src/storage/invertedindex/search/blockmax_leaf_iterator.cppm
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright(C) 2023 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;

export module blockmax_leaf_iterator;

import stl;
import internal_types;
import doc_iterator;

namespace infinity {

// Abstract interface for leaf iterators that expose block-level information
// (block doc id range and block-max BM25 score) in addition to the DocIterator API.
// BlockMaxWandIterator casts each of its children to this type and reports an
// UnrecoverableError when a child does not implement it.
export class BlockMaxLeafIterator : public DocIterator {
public:
// Lower bound of the doc ids the current block may contain.
// NOTE(review): exact semantics are defined by each implementation — confirm before relying on tightness.
virtual RowID BlockMinPossibleDocID() const = 0;

// Last doc id covered by the current block.
virtual RowID BlockLastDocID() const = 0;

// BM25 score bound for the current block; non-const, so implementations may cache the value.
virtual float BlockMaxBM25Score() = 0;

// Moves the block cursor so that the block's last_doc_id is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine decodes the skip list only, and does not update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
virtual bool NextShallow(RowID doc_id) = 0;

// BM25 score of the iterator's current document.
virtual float BM25Score() = 0;
};

} // namespace infinity
42 changes: 24 additions & 18 deletions src/storage/invertedindex/search/blockmax_wand_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module blockmax_wand_iterator;
import stl;
import third_party;
import index_defines;
import term_doc_iterator;
import blockmax_leaf_iterator;
import multi_doc_iterator;
import internal_types;
import logger;
Expand All @@ -29,18 +29,24 @@ import infinity_exception;
namespace infinity {

BlockMaxWandIterator::~BlockMaxWandIterator() {
String msg = "BlockMaxWandIterator pivot_history: ";
SizeT num_history = pivot_history_.size();
for (SizeT i=0; i<num_history; i++) {
auto &p = pivot_history_[i];
u32 pivot = std::get<0>(p);
u64 row_id = std::get<1>(p);
float score = std::get<2>(p);
//oss << " (" << pivot << ", " << row_id << ", " << score << ")";
msg += fmt::format(" ({}, {}, {:6f})", pivot, row_id, score);
if (SHOULD_LOG_TRACE()) {
String msg = "BlockMaxWandIterator pivot_history: ";
SizeT num_history = pivot_history_.size();
for (SizeT i = 0; i < num_history; i++) {
auto &p = pivot_history_[i];
u32 pivot = std::get<0>(p);
u64 row_id = std::get<1>(p);
float score = std::get<2>(p);
//oss << " (" << pivot << ", " << row_id << ", " << score << ")";
msg += fmt::format(" ({}, {}, {:6f})", pivot, row_id, score);
}
msg += fmt::format("\nnext_sort_cnt_ {}, next_it0_docid_mismatch_cnt_ {}, next_sum_score_low_cnt_ {}, next_sum_score_bm_low_cnt_ {}",
next_sort_cnt_,
next_it0_docid_mismatch_cnt_,
next_sum_score_low_cnt_,
next_sum_score_bm_low_cnt_);
LOG_TRACE(msg);
}
msg += fmt::format("\nnext_sort_cnt_ {}, next_it0_docid_mismatch_cnt_ {}, next_sum_score_low_cnt_ {}, next_sum_score_bm_low_cnt_ {}", next_sort_cnt_, next_it0_docid_mismatch_cnt_, next_sum_score_low_cnt_, next_sum_score_bm_low_cnt_);
LOG_TRACE(msg);
}

BlockMaxWandIterator::BlockMaxWandIterator(Vector<UniquePtr<DocIterator>> &&iterators)
Expand All @@ -49,9 +55,9 @@ BlockMaxWandIterator::BlockMaxWandIterator(Vector<UniquePtr<DocIterator>> &&iter
estimate_iterate_cost_ = {};
SizeT num_iterators = children_.size();
for (SizeT i = 0; i < num_iterators; i++){
TermDocIterator *tdi = dynamic_cast<TermDocIterator *>(children_[i].get());
BlockMaxLeafIterator *tdi = dynamic_cast<BlockMaxLeafIterator *>(children_[i].get());
if (tdi == nullptr) {
UnrecoverableError("BMW only supports TermDocIterator");
UnrecoverableError("BMW only supports BlockMaxLeafIterator");
}
bm25_score_upper_bound_ += tdi->BM25ScoreUpperBound();
estimate_iterate_cost_ += tdi->GetEstimateIterateCost();
Expand Down Expand Up @@ -101,10 +107,10 @@ bool BlockMaxWandIterator::Next(RowID doc_id){
});
// remove exhausted lists
for (int i = int(num_iterators) - 1; i >= 0 && sorted_iterators_[i]->DocID() == INVALID_ROWID; i--) {
if (SHOULD_LOG_DEBUG()) {
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
sorted_iterators_[i]->PrintTree(oss, "Exhaused: ", true);
LOG_DEBUG(oss.str());
LOG_TRACE(oss.str());
}
bm25_score_upper_bound_ -= sorted_iterators_[i]->BM25ScoreUpperBound();
sorted_iterators_.pop_back();
Expand Down Expand Up @@ -142,10 +148,10 @@ bool BlockMaxWandIterator::Next(RowID doc_id){
if (ok) [[likely]] {
sum_score_bm += sorted_iterators_[i]->BlockMaxBM25Score();
} else {
if (SHOULD_LOG_DEBUG()) {
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
sorted_iterators_[i]->PrintTree(oss, "Exhausted: ", true);
LOG_DEBUG(oss.str());
LOG_TRACE(oss.str());
}
sorted_iterators_.erase(sorted_iterators_.begin() + i);
num_iterators = sorted_iterators_.size();
Expand Down
6 changes: 3 additions & 3 deletions src/storage/invertedindex/search/blockmax_wand_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export module blockmax_wand_iterator;
import stl;
import index_defines;
import doc_iterator;
import term_doc_iterator;
import blockmax_leaf_iterator;
import multi_doc_iterator;
import internal_types;

Expand Down Expand Up @@ -50,8 +50,8 @@ private:
RowID common_block_min_possible_doc_id_{}; // not always exist
RowID common_block_last_doc_id_{};
float common_block_max_bm25_score_{};
Vector<TermDocIterator *> sorted_iterators_; // sort by DocID(), in ascending order
Vector<TermDocIterator *> backup_iterators_;
Vector<BlockMaxLeafIterator *> sorted_iterators_; // sort by DocID(), in ascending order
Vector<BlockMaxLeafIterator *> backup_iterators_;
SizeT pivot_;
// bm25 score cache
bool bm25_score_cached_ = false;
Expand Down
70 changes: 69 additions & 1 deletion src/storage/invertedindex/search/phrase_doc_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module;

#include <cassert>
#include <iostream>
#include <vector>

module phrase_doc_iterator;

Expand Down Expand Up @@ -31,6 +32,8 @@ PhraseDocIterator::PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters,
estimate_doc_freq_ = std::min(estimate_doc_freq_, pos_iters_[i]->GetDocFreq());
}
estimate_iterate_cost_ = {1, estimate_doc_freq_};
block_max_bm25_score_cache_part_info_end_ids_.resize(pos_iters_.size(), INVALID_ROWID);
block_max_bm25_score_cache_part_info_vals_.resize(pos_iters_.size());
}

void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader) {
Expand All @@ -41,9 +44,12 @@ void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&col
column_length_reader_ = std::move(column_length_reader);
u64 total_df = column_length_reader_->GetTotalDF();
float avg_column_len = column_length_reader_->GetAvgColumnLength();
float smooth_idf = std::log(1.0F + (total_df - estimate_doc_freq_ + 0.5F) / (estimate_doc_freq_ + 0.5F));
float smooth_idf = std::log1p((total_df - estimate_doc_freq_ + 0.5F) / (estimate_doc_freq_ + 0.5F));
bm25_common_score_ = weight_ * smooth_idf * (k1 + 1.0F);
bm25_score_upper_bound_ = bm25_common_score_ / (1.0F + k1 * b / avg_column_len);
f1 = k1 * (1.0F - b);
f2 = k1 * b / avg_column_len;
f3 = f2 * std::numeric_limits<u16>::max();
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
oss << "TermDocIterator: ";
Expand Down Expand Up @@ -80,13 +86,74 @@ bool PhraseDocIterator::Next(const RowID doc_id) {
bool found = GetPhraseMatchData();
if (found && (threshold_ <= 0.0f || BM25Score() > threshold_)) {
doc_id_ = target_doc_id;
UpdateBlockRangeDocID();
return true;
}
++target_doc_id;
}
}
}

void PhraseDocIterator::UpdateBlockRangeDocID() {
RowID min_doc_id = 0;
RowID max_doc_id = INVALID_ROWID;
for (const auto &it : pos_iters_) {
min_doc_id = std::max(min_doc_id, it->BlockLowestPossibleDocID());
max_doc_id = std::min(max_doc_id, it->BlockLastDocID());
}
block_min_possible_doc_id_ = min_doc_id;
block_last_doc_id_ = max_doc_id;
}

float PhraseDocIterator::BlockMaxBM25Score() {
if (const auto last_doc_id = BlockLastDocID(); last_doc_id != block_max_bm25_score_cache_end_id_) {
block_max_bm25_score_cache_end_id_ = last_doc_id;
// bm25_common_score_ / (1.0F + k1 * ((1.0F - b) / block_max_tf + b / block_max_percentage / avg_column_len));
// block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + f1 / block_max_tf + f3 / block_max_percentage_u16);
float div_add_min = std::numeric_limits<float>::max();
for (SizeT i = 0; i < pos_iters_.size(); ++i) {
const auto *iter = pos_iters_[i].get();
float current_div_add_min = {};
if (const auto iter_block_last_doc_id = iter->BlockLastDocID();
iter_block_last_doc_id == block_max_bm25_score_cache_part_info_end_ids_[i]) {
current_div_add_min = block_max_bm25_score_cache_part_info_vals_[i];
} else {
block_max_bm25_score_cache_part_info_end_ids_[i] = iter_block_last_doc_id;
const auto [block_max_tf, block_max_percentage_u16] = iter->GetBlockMaxInfo();
current_div_add_min = f1 / block_max_tf + f3 / block_max_percentage_u16;
block_max_bm25_score_cache_part_info_vals_[i] = current_div_add_min;
}
div_add_min = std::min(div_add_min, current_div_add_min);
}
block_max_bm25_score_cache_ = bm25_common_score_ / (1.0F + div_add_min);
}
return block_max_bm25_score_cache_;
}

// Moves the block cursor so that the block's last_doc_id is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID if the iterator is exhausted, or if
// threshold_ already exceeds BM25ScoreUpperBound().
// Only SkipTo() is called on the term iterators here, and doc_id_ is left untouched
// when the routine returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
bool PhraseDocIterator::NextShallow(RowID doc_id) {
    if (threshold_ > BM25ScoreUpperBound()) [[unlikely]] {
        doc_id_ = INVALID_ROWID;
        return false;
    }
    RowID target = doc_id;
    for (;;) {
        // Every term must be able to reach the target; one exhausted term exhausts the phrase.
        const SizeT term_count = pos_iters_.size();
        for (SizeT term_idx = 0; term_idx < term_count; ++term_idx) {
            if (!pos_iters_[term_idx]->SkipTo(target)) {
                doc_id_ = INVALID_ROWID;
                return false;
            }
        }
        UpdateBlockRangeDocID();
        const bool block_may_qualify = threshold_ <= 0.0f || BlockMaxBM25Score() > threshold_;
        if (block_may_qualify) {
            return true;
        }
        // The whole block is below the threshold: continue right after it.
        target = BlockLastDocID() + 1;
    }
}

float PhraseDocIterator::BM25Score() {
if (doc_id_ == bm25_score_cache_docid_) [[unlikely]] {
return bm25_score_cache_;
Expand All @@ -112,6 +179,7 @@ void PhraseDocIterator::PrintTree(std::ostream &os, const String &prefix, bool i
}
os << ")";
os << " (doc_freq: " << GetDocFreq() << ")";
os << " (bm25_score_upper_bound: " << BM25ScoreUpperBound() << ")";
os << '\n';
}

Expand Down
23 changes: 21 additions & 2 deletions src/storage/invertedindex/search/phrase_doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import posting_iterator;
import index_defines;
import column_length_io;
import parse_fulltext_options;
import blockmax_leaf_iterator;

namespace infinity {

export class PhraseDocIterator final : public DocIterator {
export class PhraseDocIterator final : public BlockMaxLeafIterator {
public:
PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters, float weight, u32 slop, FulltextSimilarity ft_similarity);

Expand All @@ -32,7 +33,21 @@ public:

bool Next(RowID doc_id) override;

float BM25Score();
RowID BlockMinPossibleDocID() const override { return block_min_possible_doc_id_; }

RowID BlockLastDocID() const override { return block_last_doc_id_; }

void UpdateBlockRangeDocID();

float BlockMaxBM25Score() override;

// Moves the block cursor so that the block's last_doc_id is no less than the given doc_id.
// Returns false and sets doc_id_ to INVALID_ROWID if the iterator is exhausted.
// Note that this routine decodes the skip list only, and does not update doc_id_ when it returns true.
// The caller may invoke BlockMaxBM25Score() after this routine.
bool NextShallow(RowID doc_id) override;

float BM25Score() override;

float Score() override {
switch (ft_similarity_) {
Expand Down Expand Up @@ -86,6 +101,10 @@ private:
UniquePtr<FullTextColumnLengthReader> column_length_reader_ = nullptr;
float block_max_bm25_score_cache_ = 0.0f;
RowID block_max_bm25_score_cache_end_id_ = INVALID_ROWID;
Vector<RowID> block_max_bm25_score_cache_part_info_end_ids_;
Vector<float> block_max_bm25_score_cache_part_info_vals_;
RowID block_min_possible_doc_id_ = INVALID_ROWID;
RowID block_last_doc_id_ = INVALID_ROWID;

float tf_ = 0.0f; // current doc_id_'s tf
u32 estimate_doc_freq_{0}; // estimated at the beginning
Expand Down
Loading