Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yangzq50 committed Oct 17, 2024
1 parent 5214d31 commit 725f807
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 59 deletions.
11 changes: 6 additions & 5 deletions src/storage/invertedindex/search/and_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ import infinity_exception;
namespace infinity {

AndIterator::AndIterator(Vector<UniquePtr<DocIterator>> iterators) : MultiDocIterator(std::move(iterators)) {
std::sort(children_.begin(), children_.end(), [](const auto &lhs, const auto &rhs) { return lhs->GetDF() < rhs->GetDF(); });
// init df
doc_freq_ = std::numeric_limits<u32>::max();
std::sort(children_.begin(), children_.end(), [](const auto &lhs, const auto &rhs) {
return lhs->GetEstimateIterateCost() < rhs->GetEstimateIterateCost();
});
bm25_score_upper_bound_ = 0.0f;
estimate_iterate_cost_ = {};
for (SizeT i = 0; i < children_.size(); i++) {
const auto &it = children_[i];
doc_freq_ = std::min(doc_freq_, it->GetDF());
bm25_score_upper_bound_ += children_[i]->BM25ScoreUpperBound();
estimate_iterate_cost_ = std::min(estimate_iterate_cost_, it->GetEstimateIterateCost());
// for minimum_should_match parameter
switch (it->GetType()) {
case DocIteratorType::kTermDocIterator:
Expand All @@ -50,7 +51,7 @@ AndIterator::AndIterator(Vector<UniquePtr<DocIterator>> iterators) : MultiDocIte
}
}

bool AndIterator::Next(RowID doc_id) {
bool AndIterator::Next(const RowID doc_id) {
assert(doc_id != INVALID_ROWID);
if (doc_id_ != INVALID_ROWID && doc_id_ >= doc_id)
return true;
Expand Down
9 changes: 4 additions & 5 deletions src/storage/invertedindex/search/and_not_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ import internal_types;
namespace infinity {

AndNotIterator::AndNotIterator(Vector<UniquePtr<DocIterator>> iterators) : MultiDocIterator(std::move(iterators)) {
std::sort(children_.begin() + 1, children_.end(), [](const auto &lhs, const auto &rhs) { return lhs->GetDF() < rhs->GetDF(); });
// initialize doc_id_ to first doc
Next(0);
// init df
doc_freq_ = children_[0]->GetDF();
std::sort(children_.begin() + 1, children_.end(), [](const auto &lhs, const auto &rhs) {
return lhs->GetEstimateIterateCost() < rhs->GetEstimateIterateCost();
});
bm25_score_upper_bound_ = children_[0]->BM25ScoreUpperBound();
estimate_iterate_cost_ = children_[0]->GetEstimateIterateCost();
}

bool AndNotIterator::Next(RowID doc_id) {
Expand Down
7 changes: 5 additions & 2 deletions src/storage/invertedindex/search/blockmax_wand_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,15 @@ BlockMaxWandIterator::~BlockMaxWandIterator() {
BlockMaxWandIterator::BlockMaxWandIterator(Vector<UniquePtr<DocIterator>> &&iterators)
: MultiDocIterator(std::move(iterators)), pivot_(sorted_iterators_.size()) {
bm25_score_upper_bound_ = 0.0f;
estimate_iterate_cost_ = {};
SizeT num_iterators = children_.size();
for (SizeT i = 0; i < num_iterators; i++){
TermDocIterator *tdi = dynamic_cast<TermDocIterator *>(children_[i].get());
if (tdi == nullptr)
continue;
if (tdi == nullptr) {
UnrecoverableError("BMW only supports TermDocIterator");
}
bm25_score_upper_bound_ += tdi->BM25ScoreUpperBound();
estimate_iterate_cost_ += tdi->GetEstimateIterateCost();
sorted_iterators_.push_back(tdi);
}
next_sum_score_bm_low_cnt_dist_.resize(100, 0);
Expand Down
17 changes: 15 additions & 2 deletions src/storage/invertedindex/search/doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ export enum class DocIteratorType : u8 {
kFilterIterator,
};

export struct DocIteratorEstimateIterateCost {
u32 priority_ = 0;
u32 estimate_cost_ = 0;
friend auto operator<=>(const DocIteratorEstimateIterateCost &, const DocIteratorEstimateIterateCost &) = default;
auto &operator+=(DocIteratorEstimateIterateCost rhs) {
if (priority_ < rhs.priority_) {
std::swap(*this, rhs);
}
estimate_cost_ += (rhs.estimate_cost_ >> 3 * (priority_ - rhs.priority_));
return *this;
}
};

export class DocIterator {
public:
DocIterator() {}
Expand All @@ -45,7 +58,7 @@ public:

RowID DocID() const { return doc_id_; }

inline u32 GetDF() const { return doc_freq_; }
inline DocIteratorEstimateIterateCost GetEstimateIterateCost() const { return estimate_iterate_cost_; }

// Update doc_id_ to one larger than previous one.
// If has_blockmax is true, it ensures its BM25 score be larger than current threshold.
Expand Down Expand Up @@ -81,8 +94,8 @@ public:

protected:
RowID doc_id_{INVALID_ROWID};
u32 doc_freq_ = 0;
float threshold_ = 0.0f;
float bm25_score_upper_bound_ = std::numeric_limits<float>::max();
DocIteratorEstimateIterateCost estimate_iterate_cost_ = {};
};
} // namespace infinity
24 changes: 20 additions & 4 deletions src/storage/invertedindex/search/minimum_should_match_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,27 @@ MinimumShouldMatchIterator::MinimumShouldMatchIterator(Vector<UniquePtr<DocItera
}
tail_heap_.resize(minimum_should_match_ - 1u);
bm25_score_cache_docid_ = INVALID_ROWID;
bm25_score_upper_bound_ = 0.0f;
estimate_iterate_cost_ = {};
for (const auto &child : children_) {
bm25_score_upper_bound_ += child->BM25ScoreUpperBound();
estimate_iterate_cost_ += child->GetEstimateIterateCost();
}
}

MinimumShouldMatchIterator::~MinimumShouldMatchIterator() {}

void MinimumShouldMatchIterator::UpdateScoreThreshold(float threshold) { UnrecoverableError("Unreachable code"); }
void MinimumShouldMatchIterator::UpdateScoreThreshold(const float threshold) {
if (threshold <= threshold_) {
return;
}
threshold_ = threshold;
const float base_threshold = threshold - BM25ScoreUpperBound();
for (const auto &it : children_) {
const float new_threshold = std::max(0.0f, base_threshold + it->BM25ScoreUpperBound());
it->UpdateScoreThreshold(new_threshold);
}
}

bool MinimumShouldMatchIterator::Next(RowID doc_id) {
if (doc_id_ == INVALID_ROWID) {
Expand Down Expand Up @@ -160,13 +176,13 @@ u32 MinimumShouldMatchIterator::PopFromHeadHeap() {
}

Pair<bool, u32> MinimumShouldMatchIterator::PushToTailHeap(const u32 idx) {
auto comp = [&](const u32 lhs, const u32 rhs) { return children_[lhs]->GetDF() > children_[rhs]->GetDF(); };
auto comp = [&](const u32 lhs, const u32 rhs) { return children_[lhs]->GetEstimateIterateCost() > children_[rhs]->GetEstimateIterateCost(); };
if (tail_size_ < tail_heap_.size()) {
tail_heap_[tail_size_++] = idx;
std::push_heap(tail_heap_.begin(), tail_heap_.begin() + tail_size_, comp);
return {false, std::numeric_limits<u32>::max()};
}
if (children_[idx]->GetDF() <= children_[tail_heap_.front()]->GetDF()) {
if (children_[idx]->GetEstimateIterateCost() <= children_[tail_heap_.front()]->GetEstimateIterateCost()) {
return {true, idx};
}
const auto result = tail_heap_.front();
Expand All @@ -179,7 +195,7 @@ Pair<bool, u32> MinimumShouldMatchIterator::PushToTailHeap(const u32 idx) {
u32 MinimumShouldMatchIterator::PopFromTailHeap() {
assert(tail_size_ > 0);
std::pop_heap(tail_heap_.begin(), tail_heap_.begin() + tail_size_, [&](const u32 lhs, const u32 rhs) {
return children_[lhs]->GetDF() > children_[rhs]->GetDF();
return children_[lhs]->GetEstimateIterateCost() > children_[rhs]->GetEstimateIterateCost();
});
return tail_heap_[--tail_size_];
}
Expand Down
5 changes: 4 additions & 1 deletion src/storage/invertedindex/search/multi_doc_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ void MultiDocIterator::PrintTree(std::ostream &os, const String &prefix, bool is
os << prefix;
os << (is_final ? "└──" : "├──");
os << Name();
os << " (doc_freq: " << GetDF() << ")";
{
auto [level, cost] = GetEstimateIterateCost();
os << " (estimate_iterate_cost: " << level << ", " << cost << ")";
}
os << " (bm25_score_upper_bound: " << BM25ScoreUpperBound() << ")";
os << " (threshold: " << Threshold() << ")";
os << " (children count: " << children_.size() << ")";
Expand Down
33 changes: 18 additions & 15 deletions src/storage/invertedindex/search/or_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
module;

#include <cassert>
#include <vector>

module or_iterator;
import internal_types;
Expand Down Expand Up @@ -52,19 +53,21 @@ void DocIteratorHeap::AdjustDown(SizeT idx) {
}

OrIterator::OrIterator(Vector<UniquePtr<DocIterator>> iterators) : MultiDocIterator(std::move(iterators)) {
doc_freq_ = 0;
bm25_score_upper_bound_ = 0.0f;
estimate_iterate_cost_ = {};
for (const auto &child : children_) {
bm25_score_upper_bound_ += child->BM25ScoreUpperBound();
estimate_iterate_cost_ += child->GetEstimateIterateCost();
}
}

bool OrIterator::Next(RowID doc_id) {
bool OrIterator::Next(const RowID doc_id) {
assert(doc_id != INVALID_ROWID);
if (doc_id_ == INVALID_ROWID) {
for (u32 i = 0; i < children_.size(); ++i) {
doc_freq_ += children_[i]->GetDF();
children_[i]->Next();
DocIteratorEntry entry = {children_[i]->DocID(), i};
heap_.AddEntry(entry);
bm25_score_upper_bound_ += children_[i]->BM25ScoreUpperBound();
}
heap_.BuildHeap();
doc_id_ = heap_.TopEntry().doc_id_;
Expand All @@ -86,33 +89,33 @@ float OrIterator::BM25Score() {
return bm25_score_cache_;
}
float sum_score = 0;
for (u32 i = 0; i < children_.size(); ++i) {
if (children_[i]->DocID() == doc_id_)
sum_score += children_[i]->BM25Score();
for (const auto &child : children_) {
if (child->DocID() == doc_id_) {
sum_score += child->BM25Score();
}
}
bm25_score_cache_docid_ = doc_id_;
bm25_score_cache_ = sum_score;
return sum_score;
}

void OrIterator::UpdateScoreThreshold(float threshold) {
void OrIterator::UpdateScoreThreshold(const float threshold) {
if (threshold <= threshold_)
return;
threshold_ = threshold;
const float base_threshold = threshold - BM25ScoreUpperBound();
for (SizeT i = 0; i < children_.size(); i++) {
const auto &it = children_[i];
float new_threshold = std::max(0.0f, base_threshold + it->BM25ScoreUpperBound());
it->UpdateScoreThreshold(new_threshold);
for (const auto &child : children_) {
const float new_threshold = std::max(0.0f, base_threshold + child->BM25ScoreUpperBound());
child->UpdateScoreThreshold(new_threshold);
}
}

u32 OrIterator::MatchCount() const {
u32 count = 0;
if (const auto current_doc_id = DocID(); current_doc_id != INVALID_ROWID) {
for (u32 i = 0; i < children_.size(); ++i) {
if (children_[i]->DocID() == current_doc_id) {
count += children_[i]->MatchCount();
for (const auto &child : children_) {
if (child->DocID() == current_doc_id) {
count += child->MatchCount();
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/storage/invertedindex/search/or_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,5 @@ private:
RowID bm25_score_cache_docid_ = INVALID_ROWID;
float bm25_score_cache_ = 0.0f;
};
} // namespace infinity

} // namespace infinity
25 changes: 20 additions & 5 deletions src/storage/invertedindex/search/phrase_doc_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ import logger;

namespace infinity {

PhraseDocIterator::PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters, const float weight, const u32 slop)
: pos_iters_(std::move(iters)), weight_(weight), slop_(slop) {
doc_freq_ = 0;
phrase_freq_ = 0;
if (pos_iters_.size()) {
estimate_doc_freq_ = pos_iters_[0]->GetDocFreq();
} else {
estimate_doc_freq_ = 0;
}
for (SizeT i = 0; i < pos_iters_.size(); ++i) {
estimate_doc_freq_ = std::min(estimate_doc_freq_, pos_iters_[i]->GetDocFreq());
}
estimate_iterate_cost_ = {1, estimate_doc_freq_};
}

void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&column_length_reader) {
// BM25 parameters
constexpr float k1 = 1.2F;
Expand Down Expand Up @@ -44,13 +59,13 @@ void PhraseDocIterator::InitBM25Info(UniquePtr<FullTextColumnLengthReader> &&col
}
}

bool PhraseDocIterator::Next(RowID doc_id) {
bool PhraseDocIterator::Next(const RowID doc_id) {
assert(doc_id != INVALID_ROWID);
if (doc_id_ != INVALID_ROWID && doc_id_ >= doc_id)
return true;
assert(pos_iters_.size() > 0);
RowID target_doc_id = doc_id;
do {
while (true) {
for (const auto &it : pos_iters_) {
target_doc_id = it->SeekDoc(target_doc_id);
if (target_doc_id == INVALID_ROWID) {
Expand All @@ -64,9 +79,9 @@ bool PhraseDocIterator::Next(RowID doc_id) {
doc_id_ = target_doc_id;
return true;
}
target_doc_id++;
++target_doc_id;
}
} while (1);
}
}

float PhraseDocIterator::BM25Score() {
Expand All @@ -93,7 +108,7 @@ void PhraseDocIterator::PrintTree(std::ostream &os, const String &prefix, bool i
os << " " << term;
}
os << ")";
os << " (doc_freq: " << GetDF() << ")";
os << " (doc_freq: " << GetDocFreq() << ")";
os << '\n';
}

Expand Down
19 changes: 6 additions & 13 deletions src/storage/invertedindex/search/phrase_doc_iterator.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,9 @@ import column_length_io;
namespace infinity {
export class PhraseDocIterator final : public DocIterator {
public:
PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters, float weight, u32 slop = 0)
: pos_iters_(std::move(iters)), weight_(weight), slop_(slop) {
doc_freq_ = 0;
phrase_freq_ = 0;
if (pos_iters_.size()) {
estimate_doc_freq_ = pos_iters_[0]->GetDocFreq();
} else {
estimate_doc_freq_ = 0;
}
for (SizeT i = 0; i < pos_iters_.size(); ++i) {
estimate_doc_freq_ = std::min(estimate_doc_freq_, pos_iters_[i]->GetDocFreq());
}
}
PhraseDocIterator(Vector<UniquePtr<PostingIterator>> &&iters, float weight, u32 slop = 0);

inline u32 GetDocFreq() const { return doc_freq_; }

u32 GetEstimateDF() const { return estimate_doc_freq_; }

Expand Down Expand Up @@ -65,6 +55,9 @@ private:
}
bool GetExactPhraseMatchData();
bool GetSloppyPhraseMatchData();

u32 doc_freq_ = 0;

Vector<UniquePtr<PostingIterator>> pos_iters_;
float weight_;
u32 slop_{};
Expand Down
9 changes: 8 additions & 1 deletion src/storage/invertedindex/search/term_doc_iterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ import logger;

namespace infinity {

TermDocIterator::TermDocIterator(UniquePtr<PostingIterator> &&iter, const u64 column_id, const float weight)
: column_id_(column_id), iter_(std::move(iter)), weight_(weight) {
doc_freq_ = iter_->GetDocFreq();
term_freq_ = 0;
estimate_iterate_cost_ = {0, doc_freq_};
}

TermDocIterator::~TermDocIterator() {
if (SHOULD_LOG_TRACE()) {
OStringStream oss;
Expand Down Expand Up @@ -135,7 +142,7 @@ void TermDocIterator::PrintTree(std::ostream &os, const String &prefix, bool is_
os << " (weight: " << weight_ << ")";
os << " (column: " << *column_name_ptr_ << ")";
os << " (term: " << *term_ptr_ << ")";
os << " (doc_freq: " << GetDF() << ")";
os << " (doc_freq: " << GetDocFreq() << ")";
os << " (bm25_score_upper_bound: " << BM25ScoreUpperBound() << ")";
os << " (threshold: " << Threshold() << ")";
os << " (bm25_score_cache_docid_: " << bm25_score_cache_docid_.ToUint64() << ")";
Expand Down
Loading

0 comments on commit 725f807

Please sign in to comment.